Q-learning-in-Python/main.py

from enviroment import Enviroment
import random
import numpy as np
env = Enviroment()
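# The Enviroment class (spelling as in the repo's module name) is assumed, from its
# use below, to provide reset(), state_hash(), get_available_actions(),
# action(a) -> (next_state_hash, reward, winner), and a user_symbol attribute.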
ACTION_NUM = 9        # one action per board cell
STATE_NUM = 3 ** 9    # each of the 9 cells can be empty or hold either player's symbol
ACTIONS = range(ACTION_NUM)
EPSILON = 0.9    # epsilon-greedy: exploit the best known action with probability EPSILON
ALPHA = 0.1      # learning rate
LAMBDA = 0.9     # discount factor
FRESH_TIME = 0.3    # unused in this script
EPISODE_NUM = 1e4
print("Training for {} episodes\n".format(int(EPISODE_NUM)))

def chooseAction(state, q_table, actions):
    """Epsilon-greedy selection among the currently available actions."""
    state_action = q_table[state]
    random_num = random.random()
    if random_num > EPSILON or sum(state_action) == 0:
        # Explore (or the state has no learned values yet): pick a random legal action.
        return random.choice(actions)
    else:
        # Exploit: pick the legal action with the highest Q-value.
        available_actions = [state_action[i] for i in actions]
        choice = np.argmax(available_actions)
        return actions[choice]
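
# chooseAction is reused unchanged by both the training loop and evaluate() below,
# so the 1 - EPSILON random moves also happen during evaluation games.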

def getEstimateSReward(env, table, state):
    """Best (max) Q-value of `state`, taken over the actions still available in `env`."""
    state_action = table[state]
    actions = env.get_available_actions()
    available_actions = [state_action[i] for i in actions]
    reward = np.max(available_actions)
    return reward

def evaluate(env, table, times):
    """Play `times` games with the current Q-table and print the winning percentage."""
    counter = 0
    for episode in range(times):
        env.reset()
        S = env.state_hash()
        while True:
            available_actions = env.get_available_actions()
            action = chooseAction(S, table, available_actions)
            estimate_R = table[S][action]
            S_, R, winner = env.action(action)
            if winner != 0 or len(available_actions) == 1:
                real_R = R
            else:
                real_R = R + LAMBDA * getEstimateSReward(env, table, S_)
            S = S_
            if winner != 0 or len(available_actions) == 1:
                # Game over: count it as a win if the learner's symbol won.
                if winner == env.user_symbol:
                    counter += 1
                break
    print("{}/{} winning percentage: {}%\n".format(counter, times, counter / times * 100))

# Q-table: one row per board state, one entry per action (cell).
table = [[0 for i in range(ACTION_NUM)] for j in range(STATE_NUM)]
table[0][6] = 0.9
table[0][7] = 1

env = Enviroment()
print("Before Q-Learning")
evaluate(env, table, 10000)
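
# Tabular Q-learning: after each move, nudge Q(S, action) toward the target
#     R + LAMBDA * max_a' Q(S_, a')        (or just R when the move ends the game)
# by a step of size ALPHA:  Q(S, action) += ALPHA * (target - Q(S, action)).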
for episode in range(int(EPISODE_NUM)):
    env.reset()
    S = env.state_hash()
    # print(S)
    while True:
        available_actions = env.get_available_actions()
        action = chooseAction(S, table, available_actions)
        estimate_R = table[S][action]   # current estimate Q(S, action)
        S_, R, winner = env.action(action)
        # env.show()
        if winner != 0 or len(available_actions) == 1:
            # Terminal move (someone won, or this filled the last empty cell):
            # the target is just the immediate reward.
            real_R = R
        else:
            # Otherwise bootstrap with the discounted best Q-value of the next state.
            real_R = R + LAMBDA * getEstimateSReward(env, table, S_)
        table[S][action] += ALPHA * (real_R - estimate_R)
        S = S_
        if winner != 0 or len(available_actions) == 1:
            break
    # print("\n\n")
    # print("==============================")
    # print(counter)

print("After Q-Learning")
evaluate(env, table, 10000)