# NTU_HTML/hw1_11.py
#
# 120 lines
# 2.5 KiB
# Python

# 25428.0 25891.5 26446.5
# [1002.1399999999994, 223.21665372403237, -1291.848080071952, 1515.5600276924624, -1404.8070879687737, -887.5562485642265, -2129.790874533273, 1268.8912058979067, -893.1255882425288, -358.771941045434, 1161.2662431910446, -499.8519798398823, -806.2229919825804]
import random
from datetime import datetime
import matplotlib.pyplot as plt
class PLA:
    '''
    Perceptron Learning Algorithm that checks one randomly chosen example
    per iteration.

    Attributes:
        data: list of dicts, each with keys 'x' (list of numbers) and
              'y' (the label, compared against sign(...) i.e. 1 or -1).
        dim:  length of the feature vector, taken from the first example.
        w:    current weight vector, a list of `dim` numbers starting at 0.
    '''
    def __init__(self, data):
        if not data:
            # Without at least one example there is no feature dimension.
            raise ValueError("PLA needs at least one training example")
        self.data = data
        self.dim = len(self.data[0]['x'])
        self.w = [0] * self.dim

    def iterate(self):
        '''
        Randomly pick an example (x_n, y_n) and update w_t if and only if
        w_t misclassifies it.
        Inputs:
            None
        Outputs:
            mistake: bool. True when the picked example was misclassified
                     (and an update was performed).
        '''
        # Draw from self.data, not from a module-level `data` global, so the
        # class works on whatever dataset it was constructed with.
        index = random.randrange(len(self.data))
        example = self.data[index]
        mistake = sign(dot(self.w, example['x'])) != example['y']
        if mistake:
            self.update(index)
        return mistake

    def update(self, index):
        '''
        Apply the PLA update in place: w_(t+1) = w_t + y_t * x_t
        Inputs:
            index: position in self.data of the example used for the update
        Outputs:
            None
        '''
        example = self.data[index]
        y = example['y']
        x = example['x']
        # Mutate self.w in place so existing references to it see the update.
        for i in range(self.dim):
            self.w[i] += y * x[i]
def read_file(path='hw1_train.dat', x0=11.26):
    '''
    Read labelled examples from a whitespace-separated data file.

    Each non-blank line holds the feature values followed by the label.
    The constant x0 is prepended to every feature vector.
    Inputs:
        path: file to read (default 'hw1_train.dat', relative to the CWD)
        x0:   constant inserted as x[0] of every example (default 11.26)
    Outputs:
        data: dict list. each dict has ('x', 'y') pair
            x is [x0] followed by the line's features as floats.
            y is the line's last field as an int (expected to be 1 or -1).
    '''
    data = []
    with open(path) as fp:
        for line in fp:
            numbers = line.split()
            # Skip blank lines (e.g. a trailing empty line). The previous
            # `readlines()[:-1]` always dropped the final line, which threw
            # away a real example when the file ended with a single newline.
            if not numbers:
                continue
            x = [x0] + [float(v) for v in numbers[:-1]]
            data.append({
                'x': x,
                'y': int(numbers[-1]),
            })
    return data
def dot(a, b):
    '''
    Dot product of two equal-length sequences.
    Inputs:
        a: a list of numbers
        b: a list of numbers with len(b) == len(a)
    Outputs:
        ans: sum of a[i] * b[i] (0 for empty inputs)
    Raises:
        ValueError: if the lengths differ (an explicit check, unlike
                    `assert`, is not stripped when running with -O)
    '''
    if len(a) != len(b):
        raise ValueError(
            "dot() needs equal-length inputs, got {} and {}".format(len(a), len(b)))
    ans = 0
    # Plain left-to-right accumulation keeps the original float rounding
    # (built-in sum() uses compensated summation for floats on 3.12+).
    for a_i, b_i in zip(a, b):
        ans += a_i * b_i
    return ans
def sign(a):
    '''
    Map a number to a label in {1, -1}.
    Inputs:
        a: a number
    Outputs:
        1 when a is strictly positive, otherwise -1 (so sign(0) == -1)
    '''
    if a > 0:
        return 1
    return -1
if __name__ == '__main__':
    import statistics  # local import keeps the file's top-level imports unchanged

    NUM_RUNS = 1000
    # Kept as a module-level global on purpose: PLA.iterate may look up
    # `data` globally, so do not move this into a function without checking.
    data = read_file()
    # A run stops after this many consecutive correctly classified picks.
    stop_after = 5 * len(data)
    log = []
    for _ in range(NUM_RUNS):
        # Fresh time-based seed for every run, as before.
        random.seed(datetime.now().timestamp())
        pla = PLA(data)
        # NOTE(review): counter counts every iteration (each random pick),
        # not only the updates made on mistakes. Confirm which quantity the
        # assignment asks for.
        counter = 0
        stopping_counter = 0
        while stopping_counter < stop_after:
            mistake = pla.iterate()
            counter += 1
            stopping_counter = 0 if mistake else stopping_counter + 1
        log.append(counter)
    # statistics.median averages the two middle values for an even count,
    # i.e. (sorted[499] + sorted[500]) / 2 for 1000 runs, without hard-coding
    # the run count.
    print("median: {}".format(statistics.median(log)))
    plt.hist(log, bins=100)
    plt.savefig('./hw1_11.png')
    plt.show()
    print(pla.w)