From 7322fb4f1529faeb7fb83fbbb05016c5c45a9193 Mon Sep 17 00:00:00 2001 From: Ting-Jun Wang Date: Thu, 21 Dec 2023 03:25:21 +0800 Subject: [PATCH] feat: complete CART (problem 9) --- hw6/p9.py | 162 ++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 162 insertions(+) create mode 100644 hw6/p9.py diff --git a/hw6/p9.py b/hw6/p9.py new file mode 100644 index 0000000..29a7c9f --- /dev/null +++ b/hw6/p9.py @@ -0,0 +1,162 @@ + +def load_data(path): + datas = [] + with open(path) as fp: + lines = fp.readlines() + + for line in lines: + tmp_data = [ 0 for _ in range(8) ] + numbers = line.split() + y = int(numbers[0]) + for i in numbers[1:]: + index, value = i.split(':') + index = int(index) + value = float(value) + tmp_data[index-1] = value + x = tmp_data + datas.append({ + 'x': x, + 'y': y, + }) + return datas + +class Node(): + def __init__(self, datas): + self.datas = datas + self.theta = None + self.feature_i = None + + self.value = None + + self.right = None + self.left = None + + def predict(self, data): + if self.value != None: + return self.value + else: + if data['x'][self.feature_i] <= self.theta: + return self.left.predict(data) + else: + return self.right.predict(data) + + def decision_stump(self): + def get_impurity(features_i_y_pair): + if len(features_i_y_pair) == 0: + return 0 + y_bar = sum([ data['y'] for data in features_i_y_pair ]) / len(features_i_y_pair) + + impurity = 0 + for data in features_i_y_pair: + impurity += (data['y']-y_bar) ** 2 + impurity = impurity / len(features_i_y_pair) + return impurity + + different_y = set() + for data in self.datas: + different_y.add(data['y']) + if len(different_y) == 1: + self.value = list(different_y)[0] + return + + min_impurity = 1e9 + best_feature_i, best_theta = -100, -100 + + # find the best feature_i + for feature_i in range(8): + # only get feature_i + # and sorted by feature_i, we want to get the best theta + features_i_y_pair = [ {'feature_i': data['x'][feature_i], 'y': data['y'] } for data in self.datas ] + features_i_y_pair = sorted(features_i_y_pair, key=lambda x: x['feature_i']) + + # if all x are same + if features_i_y_pair[0] == features_i_y_pair[-1]: + continue + thetas = [ (features_i_y_pair[index]['feature_i']+features_i_y_pair[index-1]['feature_i'])/2 for index in range(1, len(features_i_y_pair)) ] + + # find the best theta + for index, theta in enumerate(thetas): + front_data = [] + back_data = [] + + for data in features_i_y_pair: + if data['feature_i'] <= theta: + front_data.append(data) + else: + back_data.append(data) + + # get the impurity when (feature_i, theta) + impurity = len(front_data) * get_impurity(front_data) + len(back_data) * get_impurity(back_data) + if impurity < min_impurity: + min_impurity = impurity + best_feature_i = feature_i + best_theta = theta + + + if best_feature_i != -100: + self.theta = best_theta + self.feature_i = best_feature_i + else: + # print("NO DECISION") + self.value = sum([data['y'] for data in self.datas]) / len(self.datas) + + ''' + print("feature_i: ", self.feature_i) + print("Theta: ", self.theta) + print("Value: ", self.value) + for i in self.datas: + print(" ", i) + ''' + + + def expand(self): + left_data = [] + for data in self.datas: + if data['x'][self.feature_i] <= self.theta: + left_data.append(data) + + right_data = [] + for data in self.datas: + if data['x'][self.feature_i] > self.theta: + right_data.append(data) + + + if len(right_data) == 0 or len(left_data) == 0: + self.value = sum([data['y'] for data in self.datas]) / len(self.datas) + else: + self.left = Node(left_data) + self.right = Node(right_data) + + self.left.decision_stump() + self.right.decision_stump() + + if self.left.theta != None: + self.left.expand() + + if self.right.theta != None: + self.right.expand() + + print("theta: {}, feature_i: {}, value:{}".format(self.theta, self.feature_i, self.value)) + + +def square_error(target, predict): + return (predict-target) ** 2 + + +if __name__ == '__main__': + train = load_data('hw6_train.dat') + test = load_data('hw6_test.dat') + + root = Node(train) + root.decision_stump() + root.expand() + + errors = 0 + for data in test: + predict_y = root.predict(data) + error = square_error(data['y'], predict_y) + errors += error + print("ANS:", errors/len(test)) + + +