diff --git a/hw2/hw2_12.py b/hw2/hw2_12.py index 75cc8bc..de80c61 100644 --- a/hw2/hw2_12.py +++ b/hw2/hw2_12.py @@ -14,7 +14,7 @@ def decision_stump(x, y): best_Ein = 1e9 theta_ans = np.random.uniform(-1, 1, 1)[0] - sign_ans = ( np.random.uniform(-1, 1, 1)[0] > 0 ) + sign_ans = 1 if ( np.random.uniform(-1, 1, 1)[0] > 0 ) else -1 h_of_x = sign_ans * np.sign(x-np.array([theta_ans]*x.shape[0])) Ein = (h_of_x != y).sum()