diff --git a/hw6/p9.py b/hw6/p9.py index 29a7c9f..116202d 100644 --- a/hw6/p9.py +++ b/hw6/p9.py @@ -151,12 +151,19 @@ if __name__ == '__main__': root.decision_stump() root.expand() + errors = 0 + for data in train: + predict_y = root.predict(data) + error = square_error(data['y'], predict_y) + errors += error + print("E_in:", errors/len(test)) + errors = 0 for data in test: predict_y = root.predict(data) error = square_error(data['y'], predict_y) errors += error - print("ANS:", errors/len(test)) + print("E_out:", errors/len(test))