feat: Policy Gradient, Rewards Delay

Ting-Jun Wang 2022-07-02 03:13:05 +08:00
parent 7c07b4aa96
commit 1366659a41
Signed by: snsd0805
GPG Key ID: 8DB0D22BC1217D33
4 changed files with 193 additions and 72 deletions

agent.py

@@ -1,20 +1,13 @@
import torch
from torch.optim import Adam
from torch.distributions import Categorical
from zmq import device
from enviroment import TetrisEnviroment
from network import TetrisRLModel
import torch.nn.functional as F
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("use:", device)
PATH = "model.h5"
class TetrisRLAgent():
def __init__(self) -> None:
def __init__(self, device) -> None:
self.device = device
self.model = TetrisRLModel()
self.model = self.model.to(device)
self.model = self.model.to(self.device)
self.optim = Adam(self.model.parameters(), lr=0.001)
def sample(self, observation):
@@ -25,7 +18,7 @@ class TetrisRLAgent():
return action, log_prob
def learn(self, rewards, log_probes):
rewards = rewards.to(device)
rewards = rewards.to(self.device)
loss = ((-log_probes * rewards)).sum()
self.optim.zero_grad()
@@ -39,58 +32,3 @@ class TetrisRLAgent():
}
torch.save(Agent_Dict, PATH)
def showGame(views:list, score:int) -> None:
for i in range(20):
print(str(i).rjust(2), end=' ')
for j in range(10):
if views[i][j]:
print('■', end='')
else:
print('□', end='')
print()
print("Score:", score)
print()
env = TetrisEnviroment()
agent = TetrisRLAgent()
avgTotalRewards = []
for batch in range(100000):
log_probs, rewards = [], []
total_rewards = []
for episode in range(5):
total_reward = 0
pixel, reward, done, info = env.reset()
while 1:
showGame(pixel, total_reward)
pixel = torch.tensor(pixel, dtype=torch.float32)
blockType = F.one_hot(torch.tensor(info[2]), 7)
blockLoc = torch.tensor(info[:2], dtype=torch.float32)
observation = torch.cat([pixel.reshape(-1), blockLoc, blockType], dim=0)
observation = observation.to(device)
action, log_prob = agent.sample(observation)
pixel, reward, done, info = env.step(action)
rewards.append(reward)
log_probs.append(log_prob)
total_reward += reward
if done:
total_rewards.append(total_reward)
break
avgTotalReward = sum(total_rewards) / len(total_rewards)
avgTotalRewards.append(avgTotalReward)
rewards = torch.tensor(rewards)
log_probs = torch.stack(log_probs)
print(rewards.shape)
print(log_probs.shape)
agent.learn(rewards, log_probs)
agent.save(PATH)
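The update in TetrisRLAgent.learn() is a plain REINFORCE step: each sampled action's negative log-probability is weighted by its reward and the products are summed into the loss. A minimal standalone sketch of that objective, using toy tensors that are not taken from the commit:

# Minimal REINFORCE sketch; the logits, reward, and 4-action space are illustrative toy values.
import torch
from torch.distributions import Categorical

logits = torch.zeros(1, 4, requires_grad=True)   # stand-in policy output over 4 actions
dist = Categorical(logits=logits)
action = dist.sample()                           # sample an action, as in agent.sample()
log_prob = dist.log_prob(action)                 # log pi(a|s), stored for the update
reward = torch.tensor([1.0])                     # toy reward for this step
loss = (-log_prob * reward).sum()                # same form as (-log_probes * rewards).sum()
loss.backward()                                  # gradients flow back into the policy parameters

In the diff this loss is computed over a whole batch of steps and optimized with Adam.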

enviroment.py

@@ -4,10 +4,15 @@ class TetrisEnviroment():
def __init__(self) -> None:
self.game = TetrisGame()
self.score = 0
self.pastHeights = 0
self.pastHoles = 0
def reset(self):
self.game.reset()
self.score = 0
self.pastHeights = 0
self.pastHoles = 0
return self.game.view(), 0, self.game.done, \
(self.game.block.x, self.game.block.y, self.game.block.block_id)
@@ -39,11 +44,48 @@ class TetrisEnviroment():
for i in range(3):
self.game.action('f')
self.game.action('d')
fallStatus = self.game.action('d')
deltaScore = self.game.score - self.score
# start computing rewards
# 1. cleared lines
completeLines = self.game.score - self.score
self.score = self.game.score
return self.game.view(), deltaScore, self.game.done, \
(self.game.block.x, self.game.block.y, self.game.block.block_id)
# observation, reward, done, info(block location)
# 2. total block height
heights = self.game.getAggregateHeight()
# 3. number of holes
holes = self.game.getHoleNumber()
# compute rewards according to the weights
# https://codemyroad.wordpress.com/2013/04/14/tetris-ai-the-near-perfect-player/
rewards = -0.5*(heights-self.pastHeights) + 2*completeLines + -0.3*(holes-self.pastHoles)
if fallStatus == False:
rewards += 1
self.pastHeights = heights
self.pastHoles = holes
return self.game.view(), rewards, self.game.done, \
(self.game.block.x, self.game.block.y, self.game.block.block_id, completeLines)
# observation, reward, done, info(block location, block id, completeLines)
# env = TetrisEnviroment()
# pixel, reward, done, info = env.reset()
# while 1:
# for i in range(20):
# print(str(i).rjust(2), end=' ')
# for j in range(10):
# if pixel[i][j]:
# print('■', end='')
# else:
# print('□', end='')
# print()
# action = int(input("Action: "))
# pixel, reward, done, info = env.step(action)
# print(pixel, reward, done, info)
# print("Rewards: ", reward)
# if done:
# break
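The shaped reward above combines the change in aggregate height, the number of cleared lines, the change in hole count, and a +1 bonus when the falling block lands (fallStatus == False), with weights loosely inspired by the linked heuristic. A small standalone sketch of the same computation; the function name and arguments are illustrative, not part of the commit:

# Sketch of the reward shaping in TetrisEnviroment.step(); weights copied from the diff.
def shaped_reward(heights, past_heights, holes, past_holes, complete_lines, landed):
    reward = -0.5 * (heights - past_heights)     # penalize growth in stack height
    reward += 2.0 * complete_lines               # reward cleared lines
    reward += -0.3 * (holes - past_holes)        # penalize newly created holes
    if landed:                                   # block came to rest this step
        reward += 1.0
    return reward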

game.py

@@ -144,7 +144,8 @@ class TetrisGame():
def action(self, mode):
if mode == 'd':
try:
self.block.fall()
fallStatus = self.block.fall()
return fallStatus
except:
print("GAME OVER")
self.done = True
@@ -156,9 +157,62 @@ class TetrisGame():
self.block.rotate()
self.score += self.board.checkScore()
return True
def view(self):
return view(self.board, self.block)
def getAggregateHeight(self) -> int:
height = 0
for col in range(10):
for row in range(20):
if self.board.block[row][col] == 1:
if (20-row) > height:
height = (20-row)
break
return height
def getHoleNumber(self) -> int:
nowHoleNumber = 0
def checkLoc(row, col) -> bool:
if row < 0 or row > 19:
return False
if col < 0 or col > 9:
return False
return True
def updateCopyBoard(originBoard, flagBoard, row, col):
flagBoard[row][col] = nowHoleNumber
if checkLoc(row+1, col):
if originBoard[row+1][col] == 0 and flagBoard[row+1][col] == -1:
# print("CALL:", row, col, row+1, col)
updateCopyBoard(originBoard, flagBoard, row+1, col)
if checkLoc(row-1, col):
if originBoard[row-1][col] == 0 and flagBoard[row-1][col] == -1:
# print("CALL:", row, col, row-1, col)
updateCopyBoard(originBoard, flagBoard, row-1, col)
if checkLoc(row, col+1):
if originBoard[row][col+1] == 0 and flagBoard[row][col+1] == -1:
# print("CALL:", row, col, row, col+1)
updateCopyBoard(originBoard, flagBoard, row, col+1)
if checkLoc(row, col-1):
if originBoard[row][col-1] == 0 and flagBoard[row][col-1] == -1:
# print("CALL:", row, col, row, col-1)
updateCopyBoard(originBoard, flagBoard, row, col-1)
copyBoard = [[-1 for _ in range(10)] for _ in range(20)]
for row in range(20):
for col in range(10):
if self.board.block[row][col] == 0 and copyBoard[row][col] == -1:
updateCopyBoard(self.board.block, copyBoard, row, col)
nowHoleNumber += 1
# for row in range(20):
# for col in range(10):
# print(copyBoard[row][col], end=' ')
# print()
# print()
return nowHoleNumber-1
def view(board:Board, block:Block) -> list:
views = []
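getHoleNumber() flood-fills connected empty regions of the board via updateCopyBoard() and returns the region count minus one (subtracting the open region above the stack), while getAggregateHeight() appears to return the height of the tallest column rather than a sum over columns. For comparison, a hedged sketch of the conventional column-wise definitions from the linked heuristic; board is assumed to be the same 20x10 grid of 0/1 values as self.board.block:

# Column-wise aggregate height and hole count (conventional definitions, for comparison).
def column_metrics(board):
    aggregate_height = 0
    holes = 0
    for col in range(10):
        top = 20                                  # row of the first filled cell, 20 if the column is empty
        for row in range(20):
            if board[row][col] == 1:
                top = row
                break
        aggregate_height += 20 - top              # height contributed by this column
        holes += sum(1 for row in range(top + 1, 20) if board[row][col] == 0)
    return aggregate_height, holes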

train.py (new file)

@@ -0,0 +1,87 @@
from enviroment import TetrisEnviroment
from agent import TetrisRLAgent
import torch
import torch.nn.functional as F
from enviroment import TetrisEnviroment
import time
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("use:", device)
PATH = "model.h5"
def showGame(views:list, reward:int, score:int, completeLines:int) -> None:
for i in range(20):
print(str(i).rjust(2), end=' ')
for j in range(10):
if views[i][j]:
print('■', end='')
else:
print('□', end='')
print()
print("reward:", reward)
print("total_reward:", score)
print("total_line :", completeLines)
print()
env = TetrisEnviroment()
agent = TetrisRLAgent(device)
avgTotalRewards = []
for batch in range(100000):
log_probs, rewards = [], []
total_rewards = []
for episode in range(5):
total_reward = 0
total_lines = 0
pixel, reward, done, info = env.reset()
while 1:
showGame(pixel, reward, total_reward, total_lines)
pixel = torch.tensor(pixel, dtype=torch.float32)
blockType = F.one_hot(torch.tensor(info[2]), 7)
blockLoc = torch.tensor(info[:2], dtype=torch.float32)
observation = torch.cat([pixel.reshape(-1), blockLoc, blockType], dim=0)
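# the observation is the flattened 20x10 board (200 values), the block's (x, y)
# location (2 values), and a 7-way one-hot block type: a 209-dimensional vector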
observation = observation.to(device)
action, log_prob = agent.sample(observation)
pixel, reward, done, info = env.step(action)
rewards.append(reward)
log_probs.append(log_prob)
total_reward += reward
total_lines += info[3]
if done:
total_rewards.append(total_reward)
break
print("TOTAL REWARDS", total_rewards)
avgTotalReward = sum(total_rewards) / len(total_rewards)
avgTotalRewards.append(avgTotalReward)
ALPHA = 0.98
delayRewards = []
for start in range(len(rewards)):
ans = rewards[start]
weight = ALPHA
for i in range(start+1, len(rewards)):
ans += (weight * rewards[i])
weight *= ALPHA
delayRewards.append(ans)
rewards = torch.tensor(delayRewards)
log_probs = torch.stack(log_probs)
print(batch)
print(rewards.shape)
print(log_probs.shape)
print(avgTotalReward)
# print(rewards.tolist())
agent.learn(rewards, log_probs)
# time.sleep(5)
agent.save(PATH)
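The nested loop that builds delayRewards computes the discounted return G_t = r_t + ALPHA*r_{t+1} + ALPHA^2*r_{t+2} + ..., which is O(n^2) in the number of steps. A hedged sketch of the same quantity computed with a single backward pass; the function is illustrative and not part of the commit:

# Same "rewards delay" (discounted return) as the nested loop, in one backward pass.
def delayed_rewards(rewards, alpha=0.98):
    out = [0.0] * len(rewards)
    running = 0.0
    for t in reversed(range(len(rewards))):
        running = rewards[t] + alpha * running    # G_t = r_t + alpha * G_{t+1}
        out[t] = running
    return out

With this helper, rewards = torch.tensor(delayed_rewards(rewards)) would produce the same tensor as the quadratic loop above.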