NTU_HTML/hw3/hw3_12.py
2023-10-26 23:23:08 +08:00

103 lines
3.1 KiB
Python

import numpy as np
import matplotlib.pyplot as plt
import time
def generate_data(N):
    """Sample N labelled points from a two-class Gaussian mixture.

    Each label is drawn uniformly from {+1, -1}. A +1 point is drawn from
    N([3, 2], 0.4 I) and a -1 point from N([5, 0], 0.6 I), one draw per
    sample in label order.

    Returns:
        x: (N, 3) float array whose rows are [1, x1, x2] (leading bias term).
        y: (N,) array of labels in {+1, -1}.
    """
    labels = np.random.choice([1, -1], N)
    features = np.empty((N, 3))
    for row, label in enumerate(labels):
        if label == 1:
            center, spread = [3, 2], [[0.4, 0], [0, 0.4]]
        else:
            center, spread = [5, 0], [[0.6, 0], [0, 0.6]]
        # Prepend the constant 1 so the first weight acts as a bias.
        features[row] = (1, *np.random.multivariate_normal(center, spread))
    return features, labels
def generate_outliers(N):
    """Sample N outlier points, all labelled +1.

    Each point is drawn independently from N([0, 6], [[0.1, 0], [0, 0.3]]).

    Returns:
        x: (N, 3) float array whose rows are [1, x1, x2] (leading bias term).
        y: (N,) float array of ones.
    """
    center = [0, 6]
    spread = [[0.1, 0], [0, 0.3]]
    labels = np.ones(N)
    features = np.empty((N, 3))
    for row in range(N):
        features[row, 0] = 1  # bias term
        features[row, 1:] = np.random.multivariate_normal(center, spread)
    return features, labels
def average_square_error(y, y_hat):
    """Return the 0/1 misclassification rate between two label vectors.

    Despite the name, this is not a squared error: it is the fraction of
    positions where ``y`` and ``y_hat`` differ.

    Args:
        y: array-like of labels.
        y_hat: array-like of labels with the same number of elements as ``y``.
            Column vectors such as (N, 1) are flattened, so they compare
            element-wise against (N,) instead of silently broadcasting to
            an (N, N) comparison.

    Returns:
        float in [0, 1].

    Raises:
        ValueError: if the inputs differ in size or are empty.
    """
    y = np.ravel(y)
    y_hat = np.ravel(y_hat)
    if y.shape != y_hat.shape:
        raise ValueError(
            "label vectors differ in size: {} vs {}".format(y.size, y_hat.size)
        )
    if y.size == 0:
        raise ValueError("cannot compute an error rate over zero samples")
    return float(np.mean(y != y_hat))
def sigmoid(s):
    """Return the logistic function 1 / (1 + exp(-s)), element-wise.

    ``exp`` is only ever applied to ``-|s|``, so very negative inputs no
    longer overflow (and emit RuntimeWarnings) the way the naive formula
    does. For s >= 0 the result equals the naive formula exactly; for s < 0
    it uses the algebraically equal exp(s) / (1 + exp(s)).

    Args:
        s: scalar or array-like of real numbers.

    Returns:
        Value(s) in [0, 1]: a numpy scalar for scalar input, else an array
        with the shape of ``s``.
    """
    z = np.exp(-np.abs(s))  # always in (0, 1], never overflows
    result = np.where(np.asarray(s) >= 0, 1 / (1 + z), z / (1 + z))
    # [()] unwraps a 0-d result to a numpy scalar; arrays pass through unchanged.
    return result[()]
def gradient(w, x, y):
    """Return the averaged logistic-regression gradient at ``w``.

    Computes mean_n theta(-y_n w^T x_n) * (-y_n x_n), where theta is
    ``sigmoid``; this is the gradient of mean_n ln(1 + exp(-y_n w^T x_n)).

    Args:
        w: weight vector of shape (d, 1) or (d,).
        x: (N, d) design matrix (callers in this file include a bias column).
        y: N labels in {+1, -1}, shape (N,) or (N, 1).

    Returns:
        The gradient, with the same shape as ``w``.
    """
    # Work with column vectors so the per-sample products broadcast row-wise
    # instead of silently expanding to an (N, N) matrix when w is 1-D.
    w_col = np.reshape(w, (-1, 1))  # (d, 1)
    y_col = np.reshape(y, (-1, 1))  # (N, 1)
    theta = sigmoid(-x @ w_col * y_col)  # (N, 1)
    grad = np.mean(theta * (-y_col * x), axis=0)  # (d,)
    return grad.reshape(np.shape(w))
def _run_trial(seed, n_train=256, n_test=4096, n_outliers=16, lr=0.1, T=500):
    """Run one seeded trial; return (linear_error, logistic_error) on test data."""
    np.random.seed(seed)
    train_x, train_y = generate_data(n_train)  # (n_train, 3), (n_train,)
    test_x, test_y = generate_data(n_test)
    outlier_x, outlier_y = generate_outliers(n_outliers)
    train_x = np.concatenate((train_x, outlier_x), axis=0)
    train_y = np.concatenate((train_y, outlier_y), axis=0)

    # Linear regression for classification: one-shot least squares via pinv.
    w_lin = np.linalg.pinv(train_x) @ train_y  # (3,)
    linear_error = average_square_error(np.sign(test_x @ w_lin), test_y)

    # Logistic regression: T fixed-rate gradient-descent steps from w = 0.
    w_log = np.zeros((train_x.shape[1], 1))
    for _ in range(T):
        w_log -= lr * gradient(w_log, train_x, train_y)
    logistic_error = average_square_error(np.sign(test_x @ w_log.reshape(-1)), test_y)
    return linear_error, logistic_error


def main(n_trials=128):
    """Compare both classifiers over ``n_trials`` seeds, then plot and save "12.png"."""
    linear_regression_errors = []
    logistic_regression_errors = []
    for times in range(n_trials):
        linear_error, logistic_error = _run_trial(times)
        linear_regression_errors.append(linear_error)
        print(times, linear_error)
        logistic_regression_errors.append(logistic_error)
        print(times, logistic_error)
        print()
    # np.median averages the two middle values for an even count
    # (the previous code summed them, reporting twice the median).
    linear_regression_median = np.median(linear_regression_errors)
    logistic_regression_median = np.median(logistic_regression_errors)
    # Scatter the unsorted lists so each point pairs errors from the same trial.
    plt.scatter(linear_regression_errors, logistic_regression_errors)
    plt.xlabel("linear regression error")
    plt.ylabel("logistic regression error")
    plt.title("linear regression: {}\nlogistic regression: {}".format(linear_regression_median, logistic_regression_median))
    plt.savefig("12.png")


if __name__ == '__main__':
    main()