feat: Implement Q-learning in Python
commit 19443ffdf4
enviroment.py (new file, 77 lines)
@@ -0,0 +1,77 @@
import random


class Enviroment():

    def __init__(self):
        # 0 = empty cell, 1 = user (the learning agent), 2 = bot (random opponent)
        self.board = [0 for i in range(9)]

        self.bot_symbol = 2
        self.user_symbol = 1

        # self.bot_action()

    def reset(self):
        self.board = [0 for i in range(9)]

    def show(self):
        print("┼───┼───┼───┼")
        for i in range(3):
            print("│ ", end='')
            for j in range(3):
                if self.board[3*i + j] == 0:
                    print(" ", end=' │ ')
                elif self.board[3*i + j] == 1:
                    print("○", end=' │ ')
                else:
                    print("✕", end=' │ ')
            print()
            print("┼───┼───┼───┼")
        print()

    def get_available_actions(self):
        ans = []
        for i in range(9):
            if self.board[i] == 0:
                ans.append(i)
        return ans

    def get_winner(self):
        paths = [
            [0, 1, 2], [3, 4, 5], [6, 7, 8],
            [0, 3, 6], [1, 4, 7], [2, 5, 8],
            [0, 4, 8], [2, 4, 6]
        ]
        for path in paths:
            x, y, z = path
            # A line only wins if it is non-empty; without the != 0 check an
            # empty line would return 0 early and hide a winner on a later path.
            if self.board[x] != 0 and self.board[x] == self.board[y] == self.board[z]:
                return self.board[x]

        return 0

    def state_hash(self):
        # Encode the board as a base-3 number, giving a unique index in [0, 3**9).
        ans = 0
        for i in range(9):
            ans += self.board[i] * (3**i)
        return ans

    def bot_action(self):
        # The opponent plays uniformly at random on a free cell.
        available_actions = self.get_available_actions()
        if len(available_actions) > 0:
            loc = random.choice(available_actions)
            self.board[loc] = self.bot_symbol

    def action(self, loc):
        assert loc in self.get_available_actions(), "It's a wrong action"
        self.board[loc] = self.user_symbol

        winner = self.get_winner()  # if != 0: stop
        if winner == 0:
            # Only let the bot reply while the game is still open, then
            # re-check so a bot win is reported (and penalised) immediately.
            self.bot_action()
            winner = self.get_winner()

        if winner == self.user_symbol:
            reward = 1
        elif winner == self.bot_symbol:
            reward = -1
        else:
            reward = 0

        state = self.state_hash()
        return state, reward, winner
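
For reference, a minimal sketch of how the Enviroment class can be driven by hand from a terminal. It uses only the methods defined above; the interactive loop itself is illustrative and not part of the commit.

from enviroment import Enviroment

env = Enviroment()
env.show()
while True:
    moves = env.get_available_actions()
    loc = int(input("Your move {}: ".format(moves)))  # assumes a free cell is entered
    state, reward, winner = env.action(loc)
    env.show()
    if winner != 0 or len(env.get_available_actions()) == 0:
        print("winner:", winner, "reward:", reward)
        break
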
main.py (new file, 96 lines)
@@ -0,0 +1,96 @@
from enviroment import Enviroment

import random

import numpy as np


env = Enviroment()

ACTION_NUM = 9
STATE_NUM = (3**9)
ACTIONS = range(9)
EPSILON = 0.9       # probability of picking the greedy action (epsilon-greedy)
ALPHA = 0.1         # learning rate
LAMBDA = 0.9        # discount factor
FRESH_TIME = 0.3
EPISODE_NUM = 1e4
print("Training EPISODE_NUM: {}\n".format(EPISODE_NUM))


def chooseAction(state, q_table, actions):
    state_action = q_table[state]

    # Explore with probability 1 - EPSILON, or when this state has no learned values yet.
    random_num = random.random()
    if random_num > EPSILON or sum(state_action) == 0:
        return random.choice(actions)
    else:
        # Exploit: take the available action with the largest Q-value.
        available_actions = [state_action[i] for i in actions]
        choice = np.argmax(available_actions)
        return actions[choice]


def getEstimateSReward(env, table, state):
    # max over a' of Q(S', a'), restricted to the actions still available in S'
    state_action = table[state]
    actions = env.get_available_actions()
    available_actions = [state_action[i] for i in actions]
    reward = np.max(available_actions)
    return reward


def evaluate(env, table, times):
    counter = 0
    for episode in range(times):
        env.reset()
        S = env.state_hash()
        while 1:
            available_actions = env.get_available_actions()
            action = chooseAction(S, table, available_actions)

            S_, R, winner = env.action(action)
            S = S_

            if winner != 0 or len(available_actions) == 1:
                if winner == env.user_symbol:
                    counter += 1
                break
    print("{}/{} winning percentage: {}%\n".format(counter, times, counter / times * 100))


table = [[0 for i in range(ACTION_NUM)] for j in range(STATE_NUM)]
# Hand-seeded preferences for the empty board (state 0).
table[0][6] = 0.9
table[0][7] = 1

env = Enviroment()


print("Before Q-Learning")
evaluate(env, table, 10000)

for episode in range(int(EPISODE_NUM)):
    env.reset()
    S = env.state_hash()
    # print(S)
    while 1:
        available_actions = env.get_available_actions()
        action = chooseAction(S, table, available_actions)

        estimate_R = table[S][action]
        S_, R, winner = env.action(action)
        # env.show()

        # Q-learning target: immediate reward, plus the discounted best
        # next-state value while the episode is still running.
        if winner != 0 or len(available_actions) == 1:
            real_R = R
        else:
            real_R = R + LAMBDA * getEstimateSReward(env, table, S_)

        table[S][action] += ALPHA * (real_R - estimate_R)
        S = S_

        if winner != 0 or len(available_actions) == 1:
            break
    # print("\n\n")

# print("==============================")
# print(counter)
print("After Q-Learning")
evaluate(env, table, 10000)
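
The update in the training loop, table[S][action] += ALPHA * (real_R - estimate_R) with real_R = R + LAMBDA * max_a' Q(S', a'), is the standard tabular Q-learning rule Q(S, a) ← Q(S, a) + α (r + γ max_a' Q(S', a') − Q(S, a)), with LAMBDA playing the role of the discount factor γ. Note that evaluate() still explores 10% of the time through chooseAction; a purely greedy evaluation is a small variation. The greedyAction/evaluateGreedy helpers below are only a sketch reusing env, table and np from main.py, and are not part of the commit.

def greedyAction(state, q_table, actions):
    # Always take the available action with the largest learned Q-value.
    values = [q_table[state][a] for a in actions]
    return actions[int(np.argmax(values))]

def evaluateGreedy(env, q_table, times):
    wins = 0
    for _ in range(times):
        env.reset()
        S = env.state_hash()
        while True:
            actions = env.get_available_actions()
            a = greedyAction(S, q_table, actions)
            S, R, winner = env.action(a)
            if winner != 0 or len(actions) == 1:
                if winner == env.user_symbol:
                    wins += 1
                break
    print("greedy evaluation: {}/{} wins".format(wins, times))

evaluateGreedy(env, table, 10000)
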