From 19443ffdf454805fbf7591d2f8e419b842ee517c Mon Sep 17 00:00:00 2001
From: snsd0805
Date: Fri, 19 May 2023 23:51:41 +0800
Subject: [PATCH] feat: Implement Q-learning in Python

---
 enviroment.py | 81 ++++++++++++++++++++++++++++++++++++++++++++++
 main.py       | 93 +++++++++++++++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 174 insertions(+)
 create mode 100644 enviroment.py
 create mode 100644 main.py

diff --git a/enviroment.py b/enviroment.py
new file mode 100644
index 0000000..8033726
--- /dev/null
+++ b/enviroment.py
@@ -0,0 +1,81 @@
+import random
+import pandas as pd
+import numpy as np
+
+class Enviroment():
+    def __init__(self):
+        self.board = [ 0 for i in range(9) ]
+
+        self.bot_symbol = 2
+        self.user_symbol = 1
+
+        # self.bot_action()
+
+    def reset(self):
+        self.board = [ 0 for i in range(9) ]
+
+    def show(self):
+        print("┼───┼───┼───┼")
+        for i in range(3):
+            print("│ ", end='')
+            for j in range(3):
+                if self.board[ 3*i + j ] == 0:
+                    print(" ", end=' │ ')
+                elif self.board[ 3*i + j ] == 1:
+                    print("○", end=' │ ')
+                else:
+                    print("✕", end=' │ ')
+            print()
+            print("┼───┼───┼───┼")
+        print()
+
+    def get_available_actions(self):
+        ans = []
+        for i in range(9):
+            if self.board[i] == 0:
+                ans.append(i)
+        return ans
+
+    def get_winner(self):
+        paths = [
+            [0, 1, 2], [3, 4, 5], [6, 7, 8],
+            [0, 3, 6], [1, 4, 7], [2, 5, 8],
+            [0, 4, 8], [2, 4, 6]
+        ]
+        for path in paths:
+            x, y, z = path
+            if (self.board[x] != 0) and (self.board[x] == self.board[y]) and (self.board[y] == self.board[z]):
+                return self.board[x]
+
+        return 0
+
+    def state_hash(self):
+        # encode the board as a base-3 number so every state maps to a unique index
+        ans = 0
+        for i in range(9):
+            ans += self.board[i] * (3**i)
+        return ans
+
+    def bot_action(self):
+        # the opponent plays a uniformly random legal move
+        available_actions = self.get_available_actions()
+        if len(available_actions) > 0:
+            loc = random.choice(available_actions)
+            self.board[loc] = self.bot_symbol
+
+    def action(self, loc):
+        assert loc in self.get_available_actions(), "It's a wrong action"
+        self.board[loc] = self.user_symbol
+
+        winner = self.get_winner() # if != 0: stop
+        if winner == 0:
+            self.bot_action()
+            winner = self.get_winner() # the bot's reply may also end the game
+        if winner == self.user_symbol:
+            reward = 1
+        elif winner == self.bot_symbol:
+            reward = -1
+        else:
+            reward = 0
+        state = self.state_hash()
+        return state, reward, winner
\ No newline at end of file
diff --git a/main.py b/main.py
new file mode 100644
index 0000000..f60fe68
--- /dev/null
+++ b/main.py
@@ -0,0 +1,93 @@
+from enviroment import Enviroment
+import random
+import numpy as np
+
+env = Enviroment()
+
+ACTION_NUM = 9
+STATE_NUM = (3**9)
+ACTIONS = range(9)
+EPSILON = 0.9 # epsilon-greedy: probability of taking the greedy action
+ALPHA = 0.1 # learning rate
+LAMBDA = 0.9 # discount factor
+FRESH_TIME = 0.3
+EPISODE_NUM = 1e4
+print("Training EPISODE_NUM: {}\n".format(EPISODE_NUM))
+
+def chooseAction(state, q_table, actions):
+    # epsilon-greedy: explore with probability 1-EPSILON, or when the state's Q values still sum to 0 (unvisited)
+    state_action = q_table[state]
+
+    random_num = random.random()
+    if random_num > EPSILON or sum(state_action)==0:
+        return random.choice(actions)
+    else:
+        available_actions = [ state_action[i] for i in actions ]
+        choice = np.argmax(available_actions)
+        return actions[choice]
+
+def getEstimateSReward(env, table, state):
+    # max Q(s', a') over the actions that are still available in the next state
+    state_action = table[state]
+    actions = env.get_available_actions()
+    available_actions = [ state_action[i] for i in actions ]
+    reward = np.max(available_actions)
+    return reward
+
+def evaluate(env, table, times):
+    # play `times` games with the current table and report the win rate (no Q-table update)
+    counter = 0
+    for episode in range(times):
+        env.reset()
+        S = env.state_hash()
+        while 1:
+            available_actions = env.get_available_actions()
+            action = chooseAction(S, table, available_actions)
+
+            S_, R, winner = env.action(action)
+            S = S_
+
+            if winner != 0 or len(available_actions) == 1:
+                if winner == env.user_symbol:
+                    counter += 1
+                break
+    print("{}/{} winning percentage: {}%\n".format(counter, times, counter/times*100))
+
+table = [ [ 0 for i in range(ACTION_NUM)] for j in range(STATE_NUM) ]
+table[0][6] = 0.9
+table[0][7] = 1
+
+env = Enviroment()
+
+
+print("Before Q-Learning")
+evaluate(env, table, 10000)
+
+# Q-learning update rule: Q(S, A) += ALPHA * (R + LAMBDA * max_a Q(S', a) - Q(S, A))
+for episode in range(int(EPISODE_NUM)):
+    env.reset()
+    S = env.state_hash()
+    # print(S)
+    while 1:
+        available_actions = env.get_available_actions()
+        action = chooseAction(S, table, available_actions)
+
+        estimate_R = table[S][action]
+        S_, R, winner = env.action(action)
+        # env.show()
+
+        if winner != 0 or len(available_actions) == 1:   # terminal: someone won or this move filled the last empty cell
+            real_R = R
+        else:
+            real_R = R + LAMBDA * getEstimateSReward(env, table, S_)
+
+        table[S][action] += ALPHA * (real_R - estimate_R)
+        S = S_
+
+        if winner != 0 or len(available_actions) == 1:
+            break
+# print("\n\n")
+# print("==============================")
+# print(counter)
+print("After Q-Learning")
+evaluate(env, table, 10000)
\ No newline at end of file
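A quick way to sanity-check the learned policy once the patch is applied is to roll out a single greedy game and print the board after every move. The snippet below is a minimal sketch rather than part of the patch: it assumes it is appended to the end of main.py, so that env, table, and np from that file are already in scope, and it reuses the same terminal condition as the training loop.

env.reset()
S = env.state_hash()
while 1:
    available_actions = env.get_available_actions()
    values = [ table[S][a] for a in available_actions ]
    greedy = available_actions[int(np.argmax(values))]  # always exploit, never explore

    S, R, winner = env.action(greedy)
    env.show()

    if winner != 0 or len(available_actions) == 1:
        print("winner:", winner)  # 1 = learned agent (○), 2 = random bot (✕), 0 = draw
        break

Note that the opponent inside Enviroment plays uniformly random moves, so the winning percentage reported by evaluate() measures performance against a random player rather than an optimal one.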