# NTU_HTML/hw1/hw1_10.py
# 2023-09-17 23:52:15 +08:00

# 120 lines
# 2.5 KiB
# Python

# 26696.5 26399.0 27074.0
# [1204.8199999999993, 3279.9473790382926, -17453.52050319462, 20734.836596930876, -19276.958435444365, -11703.66716416129, -29009.135804905836, 17108.712660770405, -11751.94058615113, -4730.670939823587, 15953.080371721595, -6169.31607964013, -10440.140127478617]
import random
from datetime import datetime
import matplotlib.pyplot as plt
class PLA():
    '''
    perceptron learning algorithm that checks one randomly drawn example
    per iteration and corrects the weights when that example is misclassified.
    Attributes:
        data: the example list passed to __init__ (dicts with 'x' and 'y')
        dim: length of the first example's 'x' list
        w: weight list of length dim, starts as all zeros
    '''
    def __init__(self, data):
        '''
        store the examples and start from the zero weight vector.
        Inputs:
            data: non-empty list of dicts, each with an 'x' list and a 'y' label
        Outputs:
            None
        '''
        self.data = data
        self.dim = len(self.data[0]['x'])
        self.w = [0] * self.dim

    def iterate(self):
        '''
        randomly picks an example (x_n, y_n) in every iteration.
        updates w_t if and only if w_t is incorrect on the example.
        Inputs:
            None
        Outputs:
            mistake: 1 if the picked example was misclassified (and the
                weights were updated), otherwise 0.
        '''
        # Draw from this instance's own examples, not from a module-level
        # global, so the class works no matter what the caller names its data.
        index = random.randint(0, len(self.data)-1)
        example = self.data[index]
        mistake = 1 if sign(dot(self.w, example['x'])) != example['y'] else 0
        if mistake:
            self.update(index)
        return mistake

    def update(self, index):
        '''
        w_(t+1) = w_t + y_t*x_t
        Inputs:
            index: which row need to use for updating weight
        Outputs:
            None
        '''
        example = self.data[index]
        label = example['y']
        features = example['x']
        # Mutate the existing list in place so outside references to w stay valid.
        for i in range(self.dim):
            self.w[i] = self.w[i] + label * features[i]
def read_file():
    '''
    load the training examples stored in 'hw1_train.dat'.
    Inputs:
        None
    Outputs:
        data: dict list. each dict has ('x', 'y') pair
            x is a list of floats: a leading 11.26 followed by every
                column except the last one, each multiplied by 11.26.
            y is the last column parsed as an int.
    '''
    with open('hw1_train.dat') as fp:
        raw_lines = fp.readlines()
    examples = []
    # NOTE(review): the final line of the file is always skipped --
    # presumably it is a trailing blank line; confirm against the data file.
    for raw in raw_lines[:-1]:
        fields = raw.split()
        # Constant first component (11.26) goes in front of the scaled columns.
        features = [11.26] + [float(value)*11.26 for value in fields[:-1]]
        label = int(fields[-1])
        examples.append({'x': features, 'y': label})
    return examples
def dot(a, b):
    '''
    inner product of two equally long lists
    Inputs:
        a: a list
        b: a list. its length must be equal to len(a)
    Outputs:
        ans: a number, a[0]*b[0] + a[1]*b[1] + ...
    '''
    total = 0
    assert len(a)==len(b)
    # Accumulate term by term, left to right, starting from the int 0.
    for left, right in zip(a, b):
        total += left * right
    return total
def sign(a):
    '''
    map a number to +1 or -1
    Inputs:
        a: a number
    Outputs:
        ans: 1 when a is strictly greater than zero, otherwise -1
            (so zero maps to -1)
    '''
    if a > 0:
        return 1
    return -1
if __name__ == '__main__':
    # Only this script entry point needs the median helper.
    from statistics import median

    # Number of independent PLA experiments; the median below adapts to it
    # instead of relying on hard-coded list positions.
    num_experiments = 1000

    data = read_file()
    # An experiment ends after this many consecutive picks without a mistake.
    patience = 5 * len(data)

    # Iteration count (picks, including the correct ones) of every experiment.
    log = []
    for _ in range(num_experiments):
        # Re-seed from the wall clock so every experiment draws a different sequence.
        random.seed(datetime.now().timestamp())
        pla = PLA(data)
        counter = 0
        stopping_counter = 0
        while True:
            mistake = pla.iterate()
            counter += 1
            if mistake:
                # Any mistake restarts the run of consecutive correct picks.
                stopping_counter = 0
            else:
                stopping_counter += 1
            if stopping_counter == patience:
                break
        log.append(counter)

    print("median: {}".format(median(log)))
    plt.hist(log, bins=100)
    plt.savefig('./hw1_10.png')
    plt.show()
    # Weights reached by the last experiment only.
    print(pla.w)