# NTU_HTML/hw1/hw1_10.py
# 2023-09-26 21:49:51 +08:00
#
# 120 lines
# 2.6 KiB
# Python

# [1092.2199999999993, 3051.922507549896, -15613.841537063543, 18802.99559498076, -17113.540540993912, -10687.976818076537, -26180.86407586591, 15575.121450136705, -10674.54133415231, -4220.047076038129, 14460.732130644228, -5803.682172610006, -9355.444131276588]
import random
from datetime import datetime
import matplotlib.pyplot as plt
class PLA():
    '''
    Perceptron learning algorithm that checks one randomly chosen example
    per iteration and corrects the weight vector when that example is
    misclassified.
    '''

    def __init__(self, data):
        '''
        Inputs:
            data: non-empty list of dicts, each with an 'x' list of numbers
                  and a 'y' label that is compared against sign(w . x).
        Outputs:
            None
        '''
        self.data = data
        # Dimension is taken from the first example.
        self.dim = len(self.data[0]['x'])
        # Weights start at the zero vector.
        self.w = [0] * self.dim
        # Number of corrections applied so far.
        self.update_time = 0

    def iterate(self):
        '''
        randomly picks an example (x_n, y_n) in every iteration.
        updates w_t if and only if w_t is incorrect on the example.
        Inputs:
            None
        Outputs:
            mistake: bool. True when the picked example was misclassified
                     (and the weights were therefore updated).
        '''
        # Draw from this instance's own data set, not from a module-level
        # global that merely happens to have the same name.
        index = random.randrange(len(self.data))
        example = self.data[index]
        mistake = sign(dot(self.w, example['x'])) != example['y']
        if mistake:
            self.update(index)
        return mistake

    def update(self, index):
        '''
        w_(t+1) = w_t + y_t*x_t
        Inputs:
            index: which row need to use for updating weight
        Outputs:
            None
        '''
        self.update_time += 1
        example = self.data[index]
        label = example['y']
        features = example['x']
        # Mutate the weight list in place so existing references to it
        # keep observing the current weights.
        for i in range(self.dim):
            self.w[i] = self.w[i] + label * features[i]
def read_file(path='hw1_train.dat', scale=11.26):
    '''
    read numbers from a whitespace-separated data file
    Inputs:
        path: file to read. defaults to 'hw1_train.dat'.
        scale: factor applied to every feature; it is also the constant
               placed in front of the features as x_0. defaults to 11.26.
    Outputs:
        data: dict list. each dict has ('x', 'y') pair
            x is a list: [scale, scale*f_1, ..., scale*f_k] where f_1..f_k
              are all columns of the line except the last.
            y is the last column parsed as an int.
    '''
    data = []
    with open(path) as fp:
        for line in fp:
            numbers = line.split()
            # Skip blank lines (e.g. a trailing empty line) instead of
            # unconditionally discarding the last line, which would drop a
            # real example when the file has no trailing blank line.
            if not numbers:
                continue
            x = [scale] + [float(i) * scale for i in numbers[:-1]]
            data.append({
                'x': x,
                'y': int(numbers[-1]),
            })
    return data
def dot(a, b):
    '''
    inner product of two equally long sequences
    Inputs:
        a: a list of numbers
        b: a list of numbers, same length as a (checked with an assert)
    Outputs:
        total: sum of a[i]*b[i] over all i; 0 for empty lists
    '''
    assert len(a) == len(b)
    total = 0
    # Accumulate left to right, pairing the elements of both lists.
    for left, right in zip(a, b):
        total += left * right
    return total
def sign(a):
    '''
    two-valued sign of a number
    Inputs:
        a: a number
    Outputs:
        1 when a is strictly greater than 0, otherwise -1
        (so 0 maps to -1)
    '''
    if a > 0:
        return 1
    return -1
if __name__ == '__main__':
    # Load the training examples once; every experiment reuses them.
    data = read_file()
    # A run ends after this many consecutive checks without a mistake.
    patience = 5 * len(data)

    update_counts = []
    for _ in range(1000):
        # Re-seed from the wall clock at the start of each experiment.
        random.seed(datetime.now().timestamp())
        pla = PLA(data)
        streak = 0
        while streak != patience:
            # Any mistake resets the streak of correct checks.
            streak = 0 if pla.iterate() else streak + 1
        update_counts.append(pla.update_time)

    # Average the two middle values of the 1000 sorted update counts.
    ordered = sorted(update_counts)
    summary = "medium: {}".format((ordered[499] + ordered[500]) / 2)
    print(summary)

    # Histogram of update counts, saved to disk and then displayed.
    plt.hist(update_counts, bins=100)
    plt.title(summary)
    plt.savefig('./hw1_10.png')
    plt.show()

    # Weights of the last experiment only.
    print(pla.w)