120 lines
2.5 KiB
Python
120 lines
2.5 KiB
Python
# 26744.0, 27177.0, 26982.0
|
|
# [99, 289.36156150772626, -1438.5724638894396, 1665.4574907561664, -1563.7272899602135, -988.4793846879909, -2412.181094351847, 1355.672859940388, -948.9745573599807, -384.40018958384206, 1279.5203042150952, -537.1648580135336, -807.3189828731862]
|
|
|
|
import random
|
|
from datetime import datetime
|
|
import matplotlib.pyplot as plt
|
|
|
|
class PLA:
    '''
    Perceptron Learning Algorithm that visits training examples at random.

    Attributes:
        data: dict list. each dict has an 'x' feature list and a 'y' label.
        dim: length of each feature vector (taken from the first example).
        w: current weight vector, initialised to all zeros.
    '''

    def __init__(self, data):
        '''
        Inputs:
            data: non-empty dict list of {'x': list, 'y': label} examples.

        Raises:
            ValueError: if data is empty (dim cannot be determined).
        '''
        if not data:
            raise ValueError("PLA needs at least one training example")
        self.data = data
        self.dim = len(self.data[0]['x'])
        self.w = [0] * self.dim

    def iterate(self):
        '''
        randomly picks an example (x_n, y_n) in every iteration.
        updates w_t if and only if w_t is incorrect on the example.

        Inputs:
            None

        Outputs:
            mistake: bool. True when the picked example was misclassified
                     (and w was therefore updated).
        '''
        # Draw from self.data rather than a module-level name so the class
        # works with whatever data it was constructed with.
        index = random.randrange(len(self.data))
        example = self.data[index]
        mistake = sign(dot(self.w, example['x'])) != example['y']
        if mistake:
            self.update(index)
        return mistake

    def update(self, index):
        '''
        w_(t+1) = w_t + y_t*x_t   (updates self.w in place)

        Inputs:
            index: which row of self.data to use for updating the weight

        Outputs:
            None
        '''
        x = self.data[index]['x']
        y = self.data[index]['y']
        for i in range(self.dim):
            self.w[i] += y * x[i]
|
|
|
|
def read_file(path='hw1_train.dat'):
    '''
    read numbers from a whitespace-separated training file
    (default: 'hw1_train.dat')

    Each non-blank line holds the feature values followed by the label.
    A constant 1 is prepended to every feature vector as the bias term x_0.

    Inputs:
        path: str. file to read; defaults to 'hw1_train.dat' so existing
              callers keep working.

    Outputs:
        data: dict list. each dict has ('x', 'y') pair
            x is a list: [1, feature_1, ..., feature_d] (features as float).
            y is the integer label taken from the last column
              (expected to be 1 or -1).
    '''
    data = []
    with open(path) as fp:
        for line in fp:
            numbers = line.split()
            # Skip blank lines (e.g. a trailing empty line) instead of always
            # dropping the last line, which would discard a real example when
            # the file does not end with a blank line.
            if not numbers:
                continue
            x = [1] + [float(v) for v in numbers[:-1]]
            y = int(numbers[-1])
            data.append({'x': x, 'y': y})
    return data
|
|
|
|
def dot(a, b):
    '''
    dot product

    Inputs:
        a: a sequence of numbers
        b: a sequence of numbers. its length must be equal to len(a)

    Outputs:
        ans: a number, the sum of a[i] * b[i] (0 when both are empty)

    Raises:
        ValueError: if len(a) != len(b). An explicit check is used instead
                    of `assert`, which is stripped when Python runs with -O.
    '''
    if len(a) != len(b):
        raise ValueError(
            "dot: length mismatch ({} != {})".format(len(a), len(b)))
    return sum(x * y for x, y in zip(a, b))
|
|
|
|
def sign(a):
    '''
    map a score to a class label

    Inputs:
        a: a number

    Outputs:
        1 when a is strictly positive, otherwise -1 (so sign(0) == -1)
    '''
    if a > 0:
        return 1
    return -1
|
|
|
|
if __name__ == '__main__':
    import statistics

    # Number of independent PLA runs to collect iteration counts from.
    N_RUNS = 1000

    data = read_file()
    # A run ends once this many consecutive randomly drawn examples are all
    # classified correctly. NOTE(review): a run never terminates if PLA
    # cannot reach that streak (e.g. non-separable data).
    stop_after = 5 * len(data)

    log = []
    for _ in range(N_RUNS):
        random.seed(datetime.now().timestamp())
        pla = PLA(data)

        counter = 0           # total iterate() calls in this run
        stopping_counter = 0  # current streak of correctly classified draws
        while True:
            mistake = pla.iterate()
            counter += 1
            if mistake:
                stopping_counter = 0
            else:
                stopping_counter += 1
                if stopping_counter == stop_after:
                    break
        log.append(counter)

    # statistics.median works for any N_RUNS, unlike hard-coded middle indices.
    print("median: {}".format(statistics.median(log)))

    plt.hist(log, bins=100)
    plt.savefig('./hw1_9.png')
    plt.show()

    # Weights learned in the last run.
    print(pla.w)