# NTU_HTML/hw2/hw2_12.py
# 2023-10-26 23:21:25 +08:00
#
# 43 lines
# 1.1 KiB
# Python

import numpy as np
import matplotlib.pyplot as plt
def generate_data(length, noise_prob):
    """Draw a sorted 1-D sample with sign labels and random label flips.

    Args:
        length: Number of points to draw.
        noise_prob: Probability with which each label is independently
            flipped.

    Returns:
        A tuple ``(x, y)`` of arrays of shape ``(length,)``. ``x`` holds
        values drawn uniformly from [-1, 1) in ascending order, and ``y``
        holds ``sign(x)`` with the selected labels negated.
    """
    x = np.sort(np.random.uniform(-1, 1, size=length))
    y = np.sign(x)
    # np.random.rand samples the half-open interval [0, 1), so a strict
    # comparison flips each label with probability exactly noise_prob
    # (in particular, noise_prob == 0 can never flip a label).
    flip = np.random.rand(length) < noise_prob
    y[flip] = -y[flip]
    return x, y
def decision_stump(x, y):
    """Score a randomly drawn decision stump on the sample ``(x, y)``.

    The hypothesis is ``h(x) = s * sign(x - theta)``, where ``theta`` is
    drawn uniformly from [-1, 1) and ``s`` is +1 or -1 depending on the sign
    of a second uniform draw. No search over thresholds is performed; the
    stump is picked at random and only evaluated.

    Args:
        x: 1-D array of sample points.
        y: 1-D array of labels, same length as ``x``.

    Returns:
        A tuple ``(Ein, theta, s)``: the fraction of points where the stump
        disagrees with ``y``, the drawn threshold and the drawn direction.
    """
    # Size-1 draws are kept so that seeded runs reproduce the same values
    # and theta stays a NumPy scalar, as callers received before.
    theta_ans = np.random.uniform(-1, 1, 1)[0]
    sign_ans = 1 if np.random.uniform(-1, 1, 1)[0] > 0 else -1

    # Broadcasting subtracts the scalar threshold from every point. A point
    # exactly equal to theta predicts 0 and so counts as a mismatch.
    h_of_x = sign_ans * np.sign(x - theta_ans)
    Ein = (h_of_x != y).sum()
    return Ein / x.shape[0], theta_ans, sign_ans
# Repeat the random-stump experiment and record both error measures per trial.
Ein_log = []
Eout_log = []
for _trial in range(2000):
    # Fresh sample of 8 points, each label flipped with probability 0.1.
    x, y = generate_data(8, 0.1)
    Ein, theta, sign = decision_stump(x, y)
    Ein_log.append(Ein)
    # Closed-form expression for Eout in terms of the drawn (sign, theta).
    # NOTE(review): presumably the analytical out-of-sample error for 10%
    # label noise -- confirm against the assignment statement.
    Eout_log.append(0.5-0.4*sign+0.4*sign*abs(theta))

# Median of Eout - Ein over the 2000 trials: with an even count this is the
# mean of the two middle values of the sorted gaps.
gap = sorted(e_out - e_in for e_in, e_out in zip(Ein_log, Eout_log))
median = (gap[999]+gap[1000])/2

# Scatter Ein against Eout, write the figure to disk, then display it.
plt.scatter(Ein_log, Eout_log)
plt.title("median: {}".format(median))
plt.xlabel("Ein")
plt.ylabel("Eout")
plt.savefig("hw2_12.png")
plt.show()