feat: Policy Gradient, Rewards Delay

parent 7c07b4aa96
commit 1366659a41

 agent.py | 70
@@ -1,20 +1,13 @@
 import torch
 from torch.optim import Adam
 from torch.distributions import Categorical
-from zmq import device
-from enviroment import TetrisEnviroment
 from network import TetrisRLModel
-import torch.nn.functional as F
-
-device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-print("use:", device)
-
-PATH = "model.h5"

 class TetrisRLAgent():
-    def __init__(self) -> None:
+    def __init__(self, device) -> None:
+        self.device = device
         self.model = TetrisRLModel()
-        self.model = self.model.to(device)
+        self.model = self.model.to(self.device)
         self.optim = Adam(self.model.parameters(), lr=0.001)

     def sample(self, observation):
@@ -25,7 +18,7 @@ class TetrisRLAgent():
         return action, log_prob

     def learn(self, rewards, log_probes):
-        rewards = rewards.to(device)
+        rewards = rewards.to(self.device)
         loss = ((-log_probes * rewards)).sum()

         self.optim.zero_grad()
@@ -39,58 +32,3 @@ class TetrisRLAgent():
         }
         torch.save(Agent_Dict, PATH)

-def showGame(views:list, score:int) -> None:
-    for i in range(20):
-        print(str(i).rjust(2), end=' ')
-        for j in range(10):
-            if views[i][j]:
-                print('■', end='')
-            else:
-                print('□', end='')
-        print()
-    print("Score:", score)
-    print()
-
-env = TetrisEnviroment()
-agent = TetrisRLAgent()
-
-avgTotalRewards = []
-for batch in range(100000):
-
-    log_probs, rewards = [], []
-    total_rewards = []
-
-    for episode in range(5):
-        total_reward = 0
-
-        pixel, reward, done, info = env.reset()
-        while 1:
-            showGame(pixel, total_reward)
-
-            pixel = torch.tensor(pixel, dtype=torch.float32)
-            blockType = F.one_hot(torch.tensor(info[2]), 7)
-            blockLoc = torch.tensor(info[:2], dtype=torch.float32)
-            observation = torch.cat([pixel.reshape(-1), blockLoc, blockType], dim=0)
-            observation = observation.to(device)
-
-            action, log_prob = agent.sample(observation)
-            pixel, reward, done, info = env.step(action)
-            rewards.append(reward)
-            log_probs.append(log_prob)
-            total_reward += reward
-            if done:
-                total_rewards.append(total_reward)
-                break
-
-    avgTotalReward = sum(total_rewards) / len(total_rewards)
-    avgTotalRewards.append(avgTotalReward)
-
-    rewards = torch.tensor(rewards)
-    log_probs = torch.stack(log_probs)
-
-    print(rewards.shape)
-    print(log_probs.shape)
-
-    agent.learn(rewards, log_probs)
-
-    agent.save(PATH)
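Note on learn() above: the loss (-log_probes * rewards).sum() is the REINFORCE policy-gradient objective, i.e. each sampled action's log-probability weighted by the return that followed it, so minimizing it raises the probability of actions that preceded high returns. A minimal standalone sketch of the same update follows; the function name, the optimizer argument and the normalize flag are illustrative, and return normalization is an assumed variance-reduction tweak that is not in this commit.

import torch

def reinforce_update(optimizer, log_probs, returns, device, normalize=True):
    # REINFORCE: loss = sum_t ( -log pi(a_t | s_t) * G_t )
    returns = returns.to(device)
    if normalize:
        # assumed tweak, not used in this commit: normalize returns
        returns = (returns - returns.mean()) / (returns.std() + 1e-8)
    loss = (-log_probs * returns).sum()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
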
 enviroment.py
@@ -4,10 +4,15 @@ class TetrisEnviroment():
     def __init__(self) -> None:
         self.game = TetrisGame()
         self.score = 0
+        self.pastHeights = 0
+        self.pastHoles = 0

     def reset(self):
         self.game.reset()
         self.score = 0
+        self.pastHeights = 0
+        self.pastHoles = 0
+
         return self.game.view(), 0, self.game.done, \
             (self.game.block.x, self.game.block.y, self.game.block.block_id)

@@ -39,11 +44,48 @@ class TetrisEnviroment():
         for i in range(3):
             self.game.action('f')

-        self.game.action('d')
+        fallStatus = self.game.action('d')

-        deltaScore = self.game.score - self.score
+        # start computing rewards
+        # 1. cleared lines
+        completeLines = self.game.score - self.score
         self.score = self.game.score

-        return self.game.view(), deltaScore, self.game.done, \
-            (self.game.block.x, self.game.block.y, self.game.block.block_id)
-        # observation, reward, done, info(block location)
+        # 2. aggregate block height
+        heights = self.game.getAggregateHeight()
+
+        # 3. number of holes
+        holes = self.game.getHoleNumber()
+
+        # compute the reward from the weights
+        # https://codemyroad.wordpress.com/2013/04/14/tetris-ai-the-near-perfect-player/
+        rewards = -0.5*(heights-self.pastHeights) + 2*completeLines + -0.3*(holes-self.pastHoles)
+        if fallStatus == False:
+            rewards += 1
+
+        self.pastHeights = heights
+        self.pastHoles = holes
+
+        return self.game.view(), rewards, self.game.done, \
+            (self.game.block.x, self.game.block.y, self.game.block.block_id, completeLines)
+        # observation, reward, done, info(block location)
+
+
+# env = TetrisEnviroment()
+# pixel, reward, done, info = env.reset()
+# while 1:
+#     for i in range(20):
+#         print(str(i).rjust(2), end=' ')
+#         for j in range(10):
+#             if pixel[i][j]:
+#                 print('■', end='')
+#             else:
+#                 print('□', end='')
+#         print()
+#     action = int(input("Action: "))
+#     pixel, reward, done, info = env.step(action)
+#     print(pixel, reward, done, info)
+#     print("Rewards: ", reward)
+#     if done:
+#         break
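Note on the shaped reward in step() above: it combines the change in aggregate height (weight -0.5), the cleared lines (weight +2) and the change in hole count (weight -0.3), loosely following the linked heuristic, and the fallStatus == False branch appears to add a +1 bonus once the falling piece can no longer move down. A hypothetical helper that mirrors the same arithmetic; the function name and the piece_locked flag are illustrative, not the repo's API.

def shaped_reward(delta_height, cleared_lines, delta_holes, piece_locked):
    # same weights as TetrisEnviroment.step(): -0.5, +2, -0.3, plus the +1 bonus
    r = -0.5 * delta_height + 2.0 * cleared_lines - 0.3 * delta_holes
    if piece_locked:
        r += 1
    return r

# Worked example: one cleared line, stack height up by 1, holes unchanged,
# piece locked in place:
#   shaped_reward(1, 1, 0, True) == -0.5 + 2.0 - 0.0 + 1 == 2.5
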
 game.py | 56
@@ -144,7 +144,8 @@ class TetrisGame():
     def action(self, mode):
         if mode == 'd':
             try:
-                self.block.fall()
+                fallStatus = self.block.fall()
+                return fallStatus
             except:
                 print("GAME OVER")
                 self.done = True
@@ -156,9 +157,62 @@ class TetrisGame():
             self.block.rotate()

         self.score += self.board.checkScore()
+        return True

     def view(self):
         return view(self.board, self.block)

+    def getAggregateHeight(self) -> int:
+        height = 0
+        for col in range(10):
+            for row in range(20):
+                if self.board.block[row][col] == 1:
+                    if (20-row) > height:
+                        height = (20-row)
+                    break
+        return height
+
+    def getHoleNumber(self) -> int:
+        nowHoleNumber = 0
+        def checkLoc(row, col) -> bool:
+            if row < 0 or row > 19:
+                return False
+            if col < 0 or col > 9:
+                return False
+            return True
+
+        def updateCopyBoard(originBoard, flagBoard, row, col):
+            flagBoard[row][col] = nowHoleNumber
+            if checkLoc(row+1, col):
+                if originBoard[row+1][col] == 0 and flagBoard[row+1][col] == -1:
+                    # print("CALL:", row, col, row+1, col)
+                    updateCopyBoard(originBoard, flagBoard, row+1, col)
+            if checkLoc(row-1, col):
+                if originBoard[row-1][col] == 0 and flagBoard[row-1][col] == -1:
+                    # print("CALL:", row, col, row-1, col)
+                    updateCopyBoard(originBoard, flagBoard, row-1, col)
+            if checkLoc(row, col+1):
+                if originBoard[row][col+1] == 0 and flagBoard[row][col+1] == -1:
+                    # print("CALL:", row, col, row, col+1)
+                    updateCopyBoard(originBoard, flagBoard, row, col+1)
+            if checkLoc(row, col-1):
+                if originBoard[row][col-1] == 0 and flagBoard[row][col-1] == -1:
+                    # print("CALL:", row, col, row, col-1)
+                    updateCopyBoard(originBoard, flagBoard, row, col-1)
+
+        copyBoard = [[-1 for _ in range(10)] for _ in range(20)]
+        for row in range(20):
+            for col in range(10):
+                if self.board.block[row][col] == 0 and copyBoard[row][col] == -1:
+                    updateCopyBoard(self.board.block, copyBoard, row, col)
+                    nowHoleNumber += 1
+        # for row in range(20):
+        #     for col in range(10):
+        #         print(copyBoard[row][col], end=' ')
+        #     print()
+        # print()
+        return nowHoleNumber-1
+

 def view(board:Board, block:Block) -> list:
     views = []
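Note on getHoleNumber() above: it flood-fills the empty cells of the 20x10 board into 4-connected regions via the recursive updateCopyBoard helper and returns the region count minus one, on the assumption that exactly one empty region (the open area above the stack) is not a hole. A standalone iterative sketch of the same counting idea, for clarity only; this is not the repo's API.

def count_hole_regions(board):
    # board: 20x10 grid of 0 (empty) / 1 (filled), like TetrisGame's board.block
    seen = [[False] * 10 for _ in range(20)]
    regions = 0
    for r in range(20):
        for c in range(10):
            if board[r][c] == 0 and not seen[r][c]:
                regions += 1
                stack = [(r, c)]
                seen[r][c] = True
                while stack:  # iterative flood fill over 4-neighbours
                    y, x = stack.pop()
                    for dy, dx in ((1, 0), (-1, 0), (0, 1), (0, -1)):
                        ny, nx = y + dy, x + dx
                        if 0 <= ny < 20 and 0 <= nx < 10 and \
                                board[ny][nx] == 0 and not seen[ny][nx]:
                            seen[ny][nx] = True
                            stack.append((ny, nx))
    return regions - 1  # subtract the one region open to the sky
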
 train.py | 87 (new file)
@@ -0,0 +1,87 @@
+from enviroment import TetrisEnviroment
+from agent import TetrisRLAgent
+import torch
+import torch.nn.functional as F
+from enviroment import TetrisEnviroment
+import time
+
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+print("use:", device)
+
+PATH = "model.h5"
+
+def showGame(views:list, reward:int, score:int, completeLines:int) -> None:
+    for i in range(20):
+        print(str(i).rjust(2), end=' ')
+        for j in range(10):
+            if views[i][j]:
+                print('■', end='')
+            else:
+                print('□', end='')
+        print()
+    print("reward:", reward)
+    print("total_reward:", score)
+    print("total_line :", completeLines)
+    print()
+
+
+env = TetrisEnviroment()
+agent = TetrisRLAgent(device)
+
+avgTotalRewards = []
+for batch in range(100000):
+
+    log_probs, rewards = [], []
+    total_rewards = []
+
+    for episode in range(5):
+        total_reward = 0
+        total_lines = 0
+
+        pixel, reward, done, info = env.reset()
+        while 1:
+            showGame(pixel, reward, total_reward, total_lines)
+
+            pixel = torch.tensor(pixel, dtype=torch.float32)
+            blockType = F.one_hot(torch.tensor(info[2]), 7)
+            blockLoc = torch.tensor(info[:2], dtype=torch.float32)
+            observation = torch.cat([pixel.reshape(-1), blockLoc, blockType], dim=0)
+            observation = observation.to(device)
+
+            action, log_prob = agent.sample(observation)
+            pixel, reward, done, info = env.step(action)
+            rewards.append(reward)
+            log_probs.append(log_prob)
+            total_reward += reward
+            total_lines += info[3]
+            if done:
+                total_rewards.append(total_reward)
+                break
+
+    print("TOTAL REWARDS", total_rewards)
+    avgTotalReward = sum(total_rewards) / len(total_rewards)
+    avgTotalRewards.append(avgTotalReward)
+
+    ALPHA = 0.98
+    delayRewards = []
+    for start in range(len(rewards)):
+        ans = rewards[start]
+        weight = ALPHA
+        for i in range(start+1, len(rewards)):
+            ans += (weight * rewards[i])
+            weight *= ALPHA
+        delayRewards.append(ans)
+
+    rewards = torch.tensor(delayRewards)
+    log_probs = torch.stack(log_probs)
+
+    print(batch)
+    print(rewards.shape)
+    print(log_probs.shape)
+    print(avgTotalReward)
+    # print(rewards.tolist())
+
+    agent.learn(rewards, log_probs)
+    # time.sleep(5)
+
+    agent.save(PATH)
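Note on the "Rewards Delay" loop in train.py above: for each step t it computes the discounted return-to-go G_t = r_t + ALPHA*r_{t+1} + ALPHA^2*r_{t+2} + ..., with ALPHA = 0.98 playing the role of the discount factor, and as written the returns run over the whole concatenated batch of five episodes. The nested loop is O(n^2); an equivalent O(n) backward accumulation is sketched below (not part of the commit, function name illustrative).

def discounted_returns(rewards, alpha=0.98):
    # G_t = r_t + alpha * G_{t+1}, accumulated from the last step backwards;
    # yields the same values as the nested loop in train.py
    returns = [0.0] * len(rewards)
    running = 0.0
    for t in reversed(range(len(rewards))):
        running = rewards[t] + alpha * running
        returns[t] = running
    return returns

Used in place of the delayRewards list, the result would feed torch.tensor(...) unchanged.
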