# [100, 259.3545730056009, -1512.3600819664614, 1716.1152938189734, -1672.818122897704, -970.519199560146, -2403.921658236915, 1394.6452345918879, -980.1711087815969, -389.09446341125613, 1270.5568919545074, -489.7829998466402, -814.814832351026]
import random
from datetime import datetime


class PLA:
    '''
    Perceptron learning algorithm over a list of labelled examples.

    Each example is a dict with an 'x' entry (list of numbers) and a 'y'
    entry (compared against the output of sign(), i.e. 1 or -1).
    '''

    def __init__(self, data):
        '''
        Inputs:
            data: non-empty list of {'x': list, 'y': number} dicts.
        '''
        self.data = data
        # Weight dimension is taken from the first example.
        self.dim = len(self.data[0]['x'])
        self.w = [0] * self.dim
        # Index of the example that the next iterate() call checks first.
        self.choose = random.randint(0, len(self.data) - 1)
        # Number of weight updates performed so far.
        self.update_time = 0

    def _is_mistake(self, index):
        '''Return 1 if the current weights misclassify example `index`, else 0.'''
        example = self.data[index]
        return 1 if sign(dot(self.w, example['x'])) != example['y'] else 0

    def iterate(self):
        '''
        Check the currently chosen example; if the weights are already
        correct on it, pick a new random example and check that one instead.
        The weights are updated if and only if the checked example is
        misclassified.
        Inputs:
            None
        Outputs:
            mistake: 1 if an update was performed in this call, otherwise 0.
        '''
        mistake = self._is_mistake(self.choose)
        if mistake == 0:
            # Draw from this instance's own data rather than a module-level
            # global, so the class also works when imported elsewhere.
            self.choose = random.randint(0, len(self.data) - 1)
            mistake = self._is_mistake(self.choose)
        if mistake:
            self.update(self.choose)
        return mistake

    def update(self, index):
        '''
        w_(t+1) = w_t + y_t*x_t
        Inputs:
            index: which row need to use for updating weight
        Outputs:
            None
        '''
        self.update_time += 1
        example = self.data[index]
        y = example['y']
        x = example['x']
        # Update in place so existing references to self.w stay valid.
        for i in range(self.dim):
            self.w[i] = self.w[i] + y * x[i]


def read_file(path='hw1_train.dat'):
    '''
    Read examples from a whitespace-separated text file.
    Inputs:
        path: file to read; defaults to 'hw1_train.dat'.
    Outputs:
        data: dict list. each dict has ('x', 'y') pair
            x is the list of feature values parsed as floats, with a
            constant 1 inserted at position 0.
            y is the last number on the line parsed as int.
    '''
    with open(path) as fp:
        # NOTE(review): the last line returned by readlines() is always
        # dropped, as in the original code. Presumably the data file ends
        # with a blank line; confirm, otherwise one example is lost.
        lines = fp.readlines()[:-1]

    data = []
    for line in lines:
        numbers = line.split()
        x = [1] + [float(value) for value in numbers[:-1]]
        y = int(numbers[-1])
        data.append({
            'x': x,
            'y': y,
        })
    return data


def dot(a, b):
    '''
    dot product
    Inputs:
        a: a list
        b: a list. its length must be equal to len(a)
    Outputs:
        ans: a number
    '''
    assert len(a) == len(b)
    return sum(a_i * b_i for a_i, b_i in zip(a, b))


def sign(a):
    '''
    positive || negative
    Inputs:
        a: a number
    Outputs:
        ans: 1 if a is strictly positive, otherwise -1 (so sign(0) is -1)
    '''
    return 1 if a > 0 else -1


def main():
    '''
    Run PLA 1000 times on the training file, then print and plot the
    distribution of the number of updates, and print the last run's weights.
    '''
    # Imported here so that importing this module for PLA/dot/sign does not
    # require the plotting dependency.
    import matplotlib.pyplot as plt

    data = read_file()
    log = []
    for _ in range(1000):
        # Re-seed from the wall clock for every run.
        random.seed(datetime.now().timestamp())
        pla = PLA(data)
        stopping_counter = 0
        # Stop after 5 * len(data) consecutive iterations without a mistake.
        while True:
            if pla.iterate():
                stopping_counter = 0
            else:
                stopping_counter += 1
            if stopping_counter == 5 * len(data):
                break
        log.append(pla.update_time)

    sorted_log = sorted(log)
    # Median of the 1000 runs: mean of the two middle values.
    median = (sorted_log[499] + sorted_log[500]) / 2
    print("medium: {}".format(median))
    plt.hist(log, bins=100)
    plt.title("medium: {}".format(median))
    plt.savefig('./hw1_12.png')
    plt.show()
    print(pla.w)


if __name__ == '__main__':
    main()