NTU_HTML/hw1/hw1_9.py
2023-09-17 23:52:15 +08:00

120 lines
2.5 KiB
Python

# 26744.0, 27177.0, 26982.0
# [99, 289.36156150772626, -1438.5724638894396, 1665.4574907561664, -1563.7272899602135, -988.4793846879909, -2412.181094351847, 1355.672859940388, -948.9745573599807, -384.40018958384206, 1279.5203042150952, -537.1648580135336, -807.3189828731862]
import random
from datetime import datetime
import matplotlib.pyplot as plt
class PLA():
    '''
    Perceptron Learning Algorithm driven by randomly sampled examples.
    Attributes:
    data: dict list. each dict has an 'x' list and a 'y' label.
    dim: length of the 'x' list of the first example.
    w: weight list of length dim, starting as all zeros.
    '''
    def __init__(self, data):
        '''
        store the examples and start from the zero weight vector.
        Inputs:
        data: non-empty dict list. each dict has ('x', 'y') pair
        Outputs:
        None
        '''
        self.data = data
        self.dim = len(self.data[0]['x'])
        self.w = [0] * self.dim

    def iterate(self):
        '''
        randomly picks an example (x_n, y_n) in every iteration.
        updates w_t if and only if w_t is incorrect on the example.
        Inputs:
        None
        Outputs:
        mistake: 1 if the picked example was misclassified
        (and the weights were updated), otherwise 0.
        '''
        # Sample from this instance's own data rather than a module-level
        # global, so the class works wherever it is instantiated.
        index = random.randint(0, len(self.data) - 1)
        example = self.data[index]
        mistake = 1 if sign(dot(self.w, example['x'])) != example['y'] else 0
        if mistake:
            self.update(index)
        return mistake

    def update(self, index):
        '''
        w_(t+1) = w_t + y_t*x_t
        Inputs:
        index: which row need to use for updating weight
        Outputs:
        None
        '''
        example = self.data[index]
        label = example['y']
        features = example['x']
        # Modify the list in place so existing references to self.w stay valid.
        for i in range(self.dim):
            self.w[i] = self.w[i] + label * features[i]
def read_file():
    '''
    read numbers from 'hw1_train.dat' (looked up in the current directory)
    Inputs:
    None
    Outputs:
    data: dict list. each dict has ('x', 'y') pair
    x is a list: a constant 1 at index 0, followed by every
    column of the row except the last, converted to float.
    y is the last column of the row converted to int.
    '''
    data = []
    with open('hw1_train.dat') as fp:
        for line in fp:
            numbers = line.split()
            # Skip blank lines (such as a trailing empty line) instead of
            # unconditionally dropping the last line, which would discard a
            # real example when the file does not end with a blank line.
            if not numbers:
                continue
            # Prepend the constant coordinate x_0 = 1.
            x = [1] + [float(value) for value in numbers[:-1]]
            y = int(numbers[-1])
            data.append({
                'x': x,
                'y': y,
            })
    return data
def dot(a, b):
    '''
    inner product of two equally long lists
    Inputs:
    a: a list of numbers
    b: a list of numbers. len(b) must equal len(a)
    Outputs:
    ans: the sum of a[i] * b[i] over all positions (0 for empty lists)
    '''
    assert len(a)==len(b)
    total = 0
    # Accumulate left to right with plain addition so the floating-point
    # result matches a straightforward index loop exactly.
    for left, right in zip(a, b):
        total += left * right
    return total
def sign(a):
    '''
    map a number to a +1 / -1 label
    Inputs:
    a: a number
    Outputs:
    ans: 1 when a is strictly positive, otherwise -1
    (zero is therefore mapped to -1)
    '''
    if a > 0:
        return 1
    return -1
if __name__ == '__main__':
    # Repeat the PLA experiment 1000 times and record how many iterations
    # each run needed before it stopped.
    data = read_file()
    log = []
    # A run stops once this many consecutively drawn examples were all
    # classified correctly.
    required_streak = 5 * len(data)
    for _ in range(1000):
        # Re-seed from the wall clock so every run uses a different sequence.
        random.seed(datetime.now().timestamp())
        pla = PLA(data)
        counter = 0
        stopping_counter = 0
        while True:
            made_mistake = pla.iterate()
            counter += 1
            if made_mistake:
                # Any mistake resets the streak of correct checks.
                stopping_counter = 0
                continue
            stopping_counter += 1
            if stopping_counter == required_streak:
                break
        log.append(counter)

    # With 1000 runs the median is the mean of the two middle values.
    # NOTE: the printed label reads "medium"; it is kept as-is.
    ordered = sorted(log)
    lower_mid, upper_mid = ordered[499], ordered[500]
    print("medium: {}".format((lower_mid + upper_mid) / 2))

    # Histogram of the per-run iteration counts, saved and then displayed.
    plt.hist(log, bins=100)
    plt.savefig('./hw1_9.png')
    plt.show()

    # Weights of the last run only.
    print(pla.w)