fix: plot unsorted data & mistake 0/1 error
This commit is contained in:
parent
b3ddd2d11e
commit
0c07cd70e1
@ -22,7 +22,7 @@ def generate_data(N):
|
|||||||
return x, y
|
return x, y
|
||||||
|
|
||||||
def average_square_error(y, y_hat):
|
def average_square_error(y, y_hat):
|
||||||
error = (y==y_hat)
|
error = (y!=y_hat)
|
||||||
return error.sum()/error.shape[0]
|
return error.sum()/error.shape[0]
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
@ -42,10 +42,12 @@ if __name__ == '__main__':
|
|||||||
errors.append(error)
|
errors.append(error)
|
||||||
print(times, error)
|
print(times, error)
|
||||||
|
|
||||||
errors = sorted(errors)
|
sorted_errors = sorted(errors)
|
||||||
median = ( errors[63] + errors[64] ) / 2
|
median = ( sorted_errors[63] + sorted_errors[64] ) / 2
|
||||||
|
|
||||||
plt.hist(errors, bins=10)
|
plt.hist(errors, bins=10)
|
||||||
plt.xlabel("Ein")
|
plt.xlabel("Ein")
|
||||||
plt.title("median: {}".format(median))
|
plt.title("median: {}".format(median))
|
||||||
plt.savefig("10.png")
|
plt.savefig("10.png")
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -77,13 +77,13 @@ if __name__ == '__main__':
|
|||||||
print()
|
print()
|
||||||
|
|
||||||
|
|
||||||
linear_regression_errors = sorted(linear_regression_errors)
|
sorted_linear_regression_errors = sorted(linear_regression_errors)
|
||||||
logistic_regression_errors = sorted(logistic_regression_errors)
|
sorted_logistic_regression_errors = sorted(logistic_regression_errors)
|
||||||
linear_regression_median = linear_regression_errors[63] + linear_regression_errors[64]
|
linear_regression_median = sorted_linear_regression_errors[63] + sorted_linear_regression_errors[64]
|
||||||
logistic_regression_median = logistic_regression_errors[63] + logistic_regression_errors[64]
|
logistic_regression_median = sorted_logistic_regression_errors[63] + sorted_logistic_regression_errors[64]
|
||||||
|
|
||||||
plt.scatter(linear_regression_errors, logistic_regression_errors)
|
plt.scatter(linear_regression_errors, logistic_regression_errors)
|
||||||
plt.xlabel("linear regression error")
|
plt.xlabel("linear regression error")
|
||||||
plt.xlabel("logistic regression error")
|
plt.xlabel("logistic regression error")
|
||||||
plt.title("linear regression: {}\nlogistic regression: {}".format(linear_regression_median, logistic_regression_median))
|
plt.title("linear regression: {}\nlogistic regression: {}".format(linear_regression_median, logistic_regression_median))
|
||||||
plt.savefig("11.png")
|
plt.savefig("11.png")
|
||||||
|
|||||||
@ -90,14 +90,13 @@ if __name__ == '__main__':
|
|||||||
print(times, error)
|
print(times, error)
|
||||||
print()
|
print()
|
||||||
|
|
||||||
|
sorted_linear_regression_errors = sorted(linear_regression_errors)
|
||||||
linear_regression_errors = sorted(linear_regression_errors)
|
sorted_logistic_regression_errors = sorted(logistic_regression_errors)
|
||||||
logistic_regression_errors = sorted(logistic_regression_errors)
|
linear_regression_median = sorted_linear_regression_errors[63] + sorted_linear_regression_errors[64]
|
||||||
linear_regression_median = linear_regression_errors[63] + linear_regression_errors[64]
|
logistic_regression_median = sorted_logistic_regression_errors[63] + sorted_logistic_regression_errors[64]
|
||||||
logistic_regression_median = logistic_regression_errors[63] + logistic_regression_errors[64]
|
|
||||||
|
|
||||||
plt.scatter(linear_regression_errors, logistic_regression_errors)
|
plt.scatter(linear_regression_errors, logistic_regression_errors)
|
||||||
plt.xlabel("linear regression error")
|
plt.xlabel("linear regression error")
|
||||||
plt.xlabel("logistic regression error")
|
plt.xlabel("logistic regression error")
|
||||||
plt.title("linear regression: {}\nlogistic regression: {}".format(linear_regression_median, logistic_regression_median))
|
plt.title("linear regression: {}\nlogistic regression: {}".format(linear_regression_median, logistic_regression_median))
|
||||||
plt.savefig("12.png")
|
plt.savefig("12.png")
|
||||||
|
|||||||
@ -2,7 +2,6 @@ import numpy as np
|
|||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
import time
|
import time
|
||||||
|
|
||||||
|
|
||||||
def generate_data(N):
|
def generate_data(N):
|
||||||
y = np.random.choice([1, -1], N)
|
y = np.random.choice([1, -1], N)
|
||||||
|
|
||||||
@ -42,10 +41,10 @@ if __name__ == '__main__':
|
|||||||
errors.append(error)
|
errors.append(error)
|
||||||
print(times, error)
|
print(times, error)
|
||||||
|
|
||||||
errors = sorted(errors)
|
sorted_errors = sorted(errors)
|
||||||
median = ( errors[63] + errors[64] ) / 2
|
median = ( sorted_errors[63] + sorted_errors[64] ) / 2
|
||||||
|
|
||||||
plt.hist(errors, bins=10)
|
plt.hist(errors, bins=10)
|
||||||
plt.xlabel("Ein")
|
plt.xlabel("Ein")
|
||||||
plt.title("median: {}".format(median))
|
plt.title("median: {}".format(median))
|
||||||
plt.savefig("9.png")
|
plt.savefig("9.png")
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user