# 25428.0 25891.5 26446.5
# [1002.1399999999994, 223.21665372403237, -1291.848080071952, 1515.5600276924624, -1404.8070879687737, -887.5562485642265, -2129.790874533273, 1268.8912058979067, -893.1255882425288, -358.771941045434, 1161.2662431910446, -499.8519798398823, -806.2229919825804]
# NOTE(review): the two lines above look like outputs recorded from earlier
# runs (median iteration counts and a final weight vector) -- confirm.

import random
from datetime import datetime


class PLA:
    '''
    Perceptron learning algorithm that checks one randomly drawn example
    per iteration.

    Attributes:
        data: dict list. each dict has an ('x', 'y') pair.
        dim: length of the 'x' vector of the first example.
        w: weight list of length dim, initialised to zeros.
    '''

    def __init__(self, data):
        self.data = data
        self.dim = len(self.data[0]['x'])
        self.w = [0] * self.dim

    def iterate(self):
        '''
        randomly picks an example (x_n, y_n) and updates w if and only if
        the current w misclassifies that example.

        Inputs: None
        Outputs:
            mistake: 1 if the picked example was misclassified (and w was
                     updated), otherwise 0.
        '''
        # Use the instance's own data set, not a module-level global, so the
        # class also works when imported or given a different data set.
        index = random.randint(0, len(self.data) - 1)
        sample = self.data[index]
        mistake = 1 if sign(dot(self.w, sample['x'])) != sample['y'] else 0
        if mistake:
            self.update(index)
        return mistake

    def update(self, index):
        '''
        w_(t+1) = w_t + y_t * x_t

        Inputs:
            index: which row is used for updating the weights
        Outputs: None (self.w is modified in place)
        '''
        sample = self.data[index]
        label = sample['y']
        features = sample['x']
        for i in range(self.dim):
            self.w[i] += label * features[i]


def read_file(path='hw1_train.dat'):
    '''
    read numbers from a whitespace-separated data file.

    Inputs:
        path: file to read. defaults to 'hw1_train.dat'.
    Outputs:
        data: dict list. each dict has an ('x', 'y') pair.
              x is the list of floats parsed from every column except the
              last, with the constant 11.26 inserted at position 0.
              y is the last column parsed as int.
    '''
    data = []
    with open(path) as fp:
        # NOTE(review): the last line of the file is discarded, exactly as in
        # the original code. If the file does not end with a blank/trailer
        # line this silently drops one example -- confirm against the data.
        lines = fp.readlines()[:-1]
    for line in lines:
        numbers = line.split()
        x = [float(value) for value in numbers[:-1]]
        x.insert(0, 11.26)  # constant x_0 term prepended to every example
        data.append({
            'x': x,
            'y': int(numbers[-1]),
        })
    return data


def dot(a, b):
    '''
    dot product

    Inputs:
        a: a list of numbers
        b: a list of numbers. its length must be equal to len(a)
    Outputs:
        ans: a number (0 for two empty lists)
    '''
    assert len(a) == len(b)
    return sum(a_i * b_i for a_i, b_i in zip(a, b))


def sign(a):
    '''
    positive || negative

    Inputs:
        a: a number
    Outputs:
        ans: 1 if a is strictly positive, otherwise -1 (so sign(0) == -1)
    '''
    return 1 if a > 0 else -1


def _median(values):
    '''Return the median of a non-empty list of numbers.'''
    ordered = sorted(values)
    mid = len(ordered) // 2
    if len(ordered) % 2:
        return ordered[mid]
    return (ordered[mid - 1] + ordered[mid]) / 2


def main(runs=1000, patience_factor=5):
    '''
    Run PLA `runs` times on the training file, print the median number of
    iterations, plot a histogram of the iteration counts and print the
    weights of the last run.

    Inputs:
        runs: number of independent PLA experiments.
        patience_factor: a run stops after patience_factor * len(data)
                         consecutive iterations without a mistake.
    Outputs: None
    '''
    # Imported here because plotting is only needed when running as a script;
    # this keeps the module importable without matplotlib installed.
    import matplotlib.pyplot as plt

    data = read_file()
    patience = patience_factor * len(data)
    log = []
    pla = None
    for _ in range(runs):
        random.seed(datetime.now().timestamp())
        pla = PLA(data)
        counter = 0
        stopping_counter = 0
        while True:
            mistake = pla.iterate()
            counter += 1
            if mistake:
                stopping_counter = 0
            else:
                stopping_counter += 1
            if stopping_counter == patience:
                break
        log.append(counter)

    print("medium: {}".format(_median(log)))
    plt.hist(log, bins=100)
    plt.savefig('./hw1_11.png')
    plt.show()
    print(pla.w)


if __name__ == '__main__':
    main()