feat: Implement Q-learning in Python
commit 19443ffdf4
enviroment.py  77  Normal file
@@ -0,0 +1,77 @@
import random
import pandas as pd
import numpy as np


class Enviroment():

    def __init__(self):
        # Board cells: 0 = empty, 1 = user ("○"), 2 = bot ("✕").
        self.board = [ 0 for i in range(9) ]

        self.bot_symbol = 2
        self.user_symbol = 1

        # self.bot_action()

    def reset(self):
        self.board = [ 0 for i in range(9) ]

    def show(self):
        print("┼───┼───┼───┼")
        for i in range(3):
            print("│ ", end='')
            for j in range(3):
                if self.board[ 3*i + j ] == 0:
                    print(" ", end=' │ ')
                elif self.board[ 3*i + j ] == 1:
                    print("○", end=' │ ')
                else:
                    print("✕", end=' │ ')
            print()
            print("┼───┼───┼───┼")
        print()

    def get_available_actions(self):
        # Indices of all empty cells.
        ans = []
        for i in range(9):
            if self.board[i] == 0:
                ans.append(i)
        return ans

    def get_winner(self):
        paths = [
            [0, 1, 2], [3, 4, 5], [6, 7, 8],
            [0, 3, 6], [1, 4, 7], [2, 5, 8],
            [0, 4, 8], [2, 4, 6]
        ]
        for path in paths:
            x, y, z = path
            # A line only wins if all three cells hold the same non-empty symbol.
            if self.board[x] != 0 and self.board[x] == self.board[y] == self.board[z]:
                return self.board[x]

        return 0

    def state_hash(self):
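        # Base-3 encoding: cell i contributes board[i] * 3**i, so every board
        # maps to a unique index in [0, 3**9). For example, a board whose first
        # two cells are [1, 2] and the rest empty hashes to 1*1 + 2*3 = 7.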
        ans = 0
        for i in range(9):
            ans += self.board[i] * (3**i)
        return ans

    def bot_action(self):
        # The opponent plays a uniformly random legal move.
        available_actions = self.get_available_actions()
        if len(available_actions) > 0:
            loc = random.choice(available_actions)
            self.board[loc] = self.bot_symbol

    def action(self, loc):
        assert loc in self.get_available_actions(), "invalid action"
        self.board[loc] = self.user_symbol

        winner = self.get_winner()
        if winner == 0:
            # The bot only replies while the game is still open, and its move
            # can itself decide the game.
            self.bot_action()
            winner = self.get_winner()

        if winner == self.user_symbol:
            reward = 1
        elif winner == self.bot_symbol:
            reward = -1
        else:
            reward = 0

        state = self.state_hash()
        return state, reward, winner
main.py  96  Normal file
@@ -0,0 +1,96 @@
from enviroment import Enviroment
import random
import numpy as np

env = Enviroment()

ACTION_NUM = 9
STATE_NUM = (3**9)
ACTIONS = range(9)
EPSILON = 0.9   # epsilon-greedy
ALPHA = 0.1     # learning rate
LAMBDA = 0.9    # discount factor
FRESH_TIME = 0.3
EPISODE_NUM = 1e4
print("Training EPISODE_NUM: {}\n".format(EPISODE_NUM))

def chooseAction(state, q_table, actions):
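    # Epsilon-greedy selection: with probability EPSILON pick the action with the
    # highest Q-value among the legal ones; otherwise, or when this state's
    # Q-values still sum to zero (untrained), pick a random legal move.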
    state_action = q_table[state]

    random_num = random.random()
    if random_num > EPSILON or sum(state_action) == 0:
        return random.choice(actions)
    else:
        available_actions = [ state_action[i] for i in actions ]
        choice = np.argmax(available_actions)
        return actions[choice]

def getEstimateSReward(env, table, state):
    # Greedy value of `state`: the maximum Q-value over its legal actions.
    state_action = table[state]
    actions = env.get_available_actions()
    available_actions = [ state_action[i] for i in actions ]
    reward = np.max(available_actions)
    return reward

def evaluate(env, table, times):
    # Play `times` games with the current Q-table and report the winning percentage.
    counter = 0
    for episode in range(times):
        env.reset()
        S = env.state_hash()
        while 1:
            available_actions = env.get_available_actions()
            action = chooseAction(S, table, available_actions)

            # estimate_R / real_R mirror the training loop, but evaluation never
            # updates the table.
            estimate_R = table[S][action]
            S_, R, winner = env.action(action)

            if winner != 0 or len(available_actions) == 1:
                real_R = R
            else:
                real_R = R + LAMBDA * getEstimateSReward(env, table, S_)

            S = S_

            if winner != 0 or len(available_actions) == 1:
                if winner == env.user_symbol:
                    counter += 1
                break
    print("{}/{} winning percentage: {}%\n".format(counter, times, counter / times * 100))

table = [ [ 0 for i in range(ACTION_NUM)] for j in range(STATE_NUM) ]
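# State 0 is the hash of the empty board, so the two assignments below hand-seed
# the opening move: cells 6 and 7 start with non-zero Q-values.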
table[0][6] = 0.9
table[0][7] = 1

env = Enviroment()

print("Before Q-Learning")
evaluate(env, table, 10000)

for episode in range(int(EPISODE_NUM)):
    env.reset()
    S = env.state_hash()
    # print(S)
    while 1:
        available_actions = env.get_available_actions()
        action = chooseAction(S, table, available_actions)

        estimate_R = table[S][action]
        S_, R, winner = env.action(action)
        # env.show()

        if winner != 0 or len(available_actions) == 1:
            # Terminal step: the target is just the final reward.
            real_R = R
        else:
            # Non-terminal step: bootstrap from the best Q-value of the next state.
            real_R = R + LAMBDA * getEstimateSReward(env, table, S_)

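        # Tabular Q-learning update:
        #   Q(S, action) <- Q(S, action) + ALPHA * (target - Q(S, action))
        # where the target is R + LAMBDA * max_a' Q(S_, a') (or just R at a
        # terminal step), so real_R - estimate_R is the temporal-difference error.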
        table[S][action] += ALPHA * (real_R - estimate_R)
        S = S_

        if winner != 0 or len(available_actions) == 1:
            break
    # print("\n\n")
    # print("==============================")
    # print(counter)

print("After Q-Learning")
evaluate(env, table, 10000)