feat: Policy Gradient, Rewards Delay
parent 7c07b4aa96
commit 1366659a41

70  agent.py
@@ -1,20 +1,13 @@
 import torch
 from torch.optim import Adam
 from torch.distributions import Categorical
-from zmq import device
-from enviroment import TetrisEnviroment
 from network import TetrisRLModel
-import torch.nn.functional as F
-
-device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-print("use:", device)
-
-PATH = "model.h5"

 class TetrisRLAgent():
-    def __init__(self) -> None:
+    def __init__(self, device) -> None:
+        self.device = device
         self.model = TetrisRLModel()
-        self.model = self.model.to(device)
+        self.model = self.model.to(self.device)
         self.optim = Adam(self.model.parameters(), lr=0.001)

     def sample(self, observation):
@@ -25,7 +18,7 @@ class TetrisRLAgent():
         return action, log_prob

     def learn(self, rewards, log_probes):
-        rewards = rewards.to(device)
+        rewards = rewards.to(self.device)
         loss = ((-log_probes * rewards)).sum()

         self.optim.zero_grad()
@@ -39,58 +32,3 @@ class TetrisRLAgent():
         }
         torch.save(Agent_Dict, PATH)

-def showGame(views:list, score:int) -> None:
-    for i in range(20):
-        print(str(i).rjust(2), end=' ')
-        for j in range(10):
-            if views[i][j]:
-                print('■', end='')
-            else:
-                print('□', end='')
-        print()
-    print("Score:", score)
-    print()
-
-env = TetrisEnviroment()
-agent = TetrisRLAgent()
-
-avgTotalRewards = []
-for batch in range(100000):
-
-    log_probs, rewards = [], []
-    total_rewards = []
-
-    for episode in range(5):
-        total_reward = 0
-
-        pixel, reward, done, info = env.reset()
-        while 1:
-            showGame(pixel, total_reward)
-
-            pixel = torch.tensor(pixel, dtype=torch.float32)
-            blockType = F.one_hot(torch.tensor(info[2]), 7)
-            blockLoc = torch.tensor(info[:2], dtype=torch.float32)
-            observation = torch.cat([pixel.reshape(-1), blockLoc, blockType], dim=0)
-            observation = observation.to(device)
-
-            action, log_prob = agent.sample(observation)
-            pixel, reward, done, info = env.step(action)
-            rewards.append(reward)
-            log_probs.append(log_prob)
-            total_reward += reward
-            if done:
-                total_rewards.append(total_reward)
-                break
-
-    avgTotalReward = sum(total_rewards) / len(total_rewards)
-    avgTotalRewards.append(avgTotalReward)
-
-    rewards = torch.tensor(rewards)
-    log_probs = torch.stack(log_probs)
-
-    print(rewards.shape)
-    print(log_probs.shape)
-
-    agent.learn(rewards, log_probs)
-
-agent.save(PATH)
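The learn() method above is a plain REINFORCE / policy-gradient step: the loss is the sum of -log_prob * reward over the collected transitions. A minimal, self-contained sketch of that update on a toy linear policy (the dimensions and names below are illustrative, not from this repo):

import torch
from torch.optim import Adam
from torch.distributions import Categorical

policy = torch.nn.Linear(4, 3)              # toy policy: 4-dim observation, 3 actions
optim = Adam(policy.parameters(), lr=0.001)

obs = torch.randn(10, 4)                    # a fake batch of observations
dist = Categorical(logits=policy(obs))
actions = dist.sample()
log_probs = dist.log_prob(actions)          # same role as the log_prob returned by sample()
returns = torch.randn(10)                   # stand-in for the collected rewards

loss = (-log_probs * returns).sum()         # the surrogate loss used in learn()
optim.zero_grad()
loss.backward()
optim.step()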
enviroment.py

@@ -4,10 +4,15 @@ class TetrisEnviroment():
     def __init__(self) -> None:
         self.game = TetrisGame()
         self.score = 0
+        self.pastHeights = 0
+        self.pastHoles = 0

     def reset(self):
         self.game.reset()
+        self.score = 0
+        self.pastHeights = 0
+        self.pastHoles = 0

         return self.game.view(), 0, self.game.done, \
             (self.game.block.x, self.game.block.y, self.game.block.block_id)

@@ -39,11 +44,48 @@ class TetrisEnviroment():
         for i in range(3):
             self.game.action('f')

-        self.game.action('d')
+        fallStatus = self.game.action('d')

-        deltaScore = self.game.score - self.score
+
+        # start computing the rewards
+        # 1. cleared lines
+        completeLines = self.game.score - self.score
         self.score = self.game.score

-        return self.game.view(), deltaScore, self.game.done, \
-            (self.game.block.x, self.game.block.y, self.game.block.block_id)
-        # observation, reward, done, info(block location)
+        # 2. total block height
+        heights = self.game.getAggregateHeight()
+
+        # 3. number of holes
+        holes = self.game.getHoleNumber()
+
+        # weighted sum of the reward terms
+        # https://codemyroad.wordpress.com/2013/04/14/tetris-ai-the-near-perfect-player/
+        rewards = -0.5*(heights-self.pastHeights) + 2*completeLines + -0.3*(holes-self.pastHoles)
+        if fallStatus == False:
+            rewards += 1
+
+        self.pastHeights = heights
+        self.pastHoles = holes
+
+        return self.game.view(), rewards, self.game.done, \
+            (self.game.block.x, self.game.block.y, self.game.block.block_id, completeLines)
+        # observation, reward, done, info(block location)
+
+
+# env = TetrisEnviroment()
+# pixel, reward, done, info = env.reset()
+# while 1:
+#     for i in range(20):
+#         print(str(i).rjust(2), end=' ')
+#         for j in range(10):
+#             if pixel[i][j]:
+#                 print('■', end='')
+#             else:
+#                 print('□', end='')
+#         print()
+#     action = int(input("Action: "))
+#     pixel, reward, done, info = env.step(action)
+#     print(pixel, reward, done, info)
+#     print("Rewards: ", reward)
+#     if done:
+#         break
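The shaped reward in step() combines four terms: -0.5 per unit of height gained, +2 per cleared line, -0.3 per new hole, and +1 once the piece has locked (fallStatus == False). A small worked example with made-up board numbers; the helper below is illustrative only, not part of enviroment.py:

def shaped_reward(heights, past_heights, complete_lines, holes, past_holes, locked):
    # Same weights as TetrisEnviroment.step()
    r = -0.5 * (heights - past_heights) + 2 * complete_lines - 0.3 * (holes - past_holes)
    if locked:
        r += 1
    return r

# Stack grew by 2 rows, one line cleared, one new hole, piece locked:
# -0.5*2 + 2*1 - 0.3*1 + 1 = 1.7
print(shaped_reward(heights=8, past_heights=6, complete_lines=1,
                    holes=3, past_holes=2, locked=True))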
56  game.py

@@ -144,7 +144,8 @@ class TetrisGame():
     def action(self, mode):
         if mode == 'd':
             try:
-                self.block.fall()
+                fallStatus = self.block.fall()
+                return fallStatus
             except:
                 print("GAME OVER")
                 self.done = True
@@ -156,9 +157,62 @@ class TetrisGame():
             self.block.rotate()

         self.score += self.board.checkScore()
+        return True

     def view(self):
         return view(self.board, self.block)
+
+    def getAggregateHeight(self) -> int:
+        height = 0
+        for col in range(10):
+            for row in range(20):
+                if self.board.block[row][col] == 1:
+                    if (20-row) > height:
+                        height = (20-row)
+                    break
+        return height
+
+    def getHoleNumber(self) -> int:
+        nowHoleNumber = 0
+        def checkLoc(row, col) -> bool:
+            if row < 0 or row > 19:
+                return False
+            if col < 0 or col > 9:
+                return False
+            return True
+
+        def updateCopyBoard(originBoard, flagBoard, row, col):
+            flagBoard[row][col] = nowHoleNumber
+            if checkLoc(row+1, col):
+                if originBoard[row+1][col] == 0 and flagBoard[row+1][col] == -1:
+                    # print("CALL:", row, col, row+1, col)
+                    updateCopyBoard(originBoard, flagBoard, row+1, col)
+            if checkLoc(row-1, col):
+                if originBoard[row-1][col] == 0 and flagBoard[row-1][col] == -1:
+                    # print("CALL:", row, col, row-1, col)
+                    updateCopyBoard(originBoard, flagBoard, row-1, col)
+            if checkLoc(row, col+1):
+                if originBoard[row][col+1] == 0 and flagBoard[row][col+1] == -1:
+                    # print("CALL:", row, col, row, col+1)
+                    updateCopyBoard(originBoard, flagBoard, row, col+1)
+            if checkLoc(row, col-1):
+                if originBoard[row][col-1] == 0 and flagBoard[row][col-1] == -1:
+                    # print("CALL:", row, col, row, col-1)
+                    updateCopyBoard(originBoard, flagBoard, row, col-1)
+
+        copyBoard = [[-1 for _ in range(10)] for _ in range(20)]
+        for row in range(20):
+            for col in range(10):
+                if self.board.block[row][col] == 0 and copyBoard[row][col] == -1:
+                    updateCopyBoard(self.board.block, copyBoard, row, col)
+                    nowHoleNumber += 1
+        # for row in range(20):
+        #     for col in range(10):
+        #         print(copyBoard[row][col], end=' ')
+        #     print()
+        # print()
+        return nowHoleNumber-1


 def view(board:Board, block:Block) -> list:
     views = []
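Note that getAggregateHeight() above returns the height of the tallest column, and getHoleNumber() counts connected empty regions with a recursive flood fill. The article linked in enviroment.py defines these features slightly differently (sum of all column heights; empty cells covered from above). Per-column variants in that spirit, as an illustrative sketch only (same 20x10 self.board.block grid assumed, method names are hypothetical):

def sumColumnHeights(self) -> int:
    # Sum of column heights; a column's height is 20 minus the row index of
    # its topmost filled cell (0 for an empty column).
    total = 0
    for col in range(10):
        for row in range(20):
            if self.board.block[row][col] == 1:
                total += (20 - row)
                break
    return total

def countCoveredHoles(self) -> int:
    # Count empty cells that have at least one filled cell above them
    # in the same column.
    holes = 0
    for col in range(10):
        seen_block = False
        for row in range(20):
            if self.board.block[row][col] == 1:
                seen_block = True
            elif seen_block:
                holes += 1
    return holes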
87  train.py (new file)

@@ -0,0 +1,87 @@
+from enviroment import TetrisEnviroment
+from agent import TetrisRLAgent
+import torch
+import torch.nn.functional as F
+from enviroment import TetrisEnviroment
+import time
+
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+print("use:", device)
+
+PATH = "model.h5"
+
+def showGame(views:list, reward:int, score:int, completeLines:int) -> None:
+    for i in range(20):
+        print(str(i).rjust(2), end=' ')
+        for j in range(10):
+            if views[i][j]:
+                print('■', end='')
+            else:
+                print('□', end='')
+        print()
+    print("reward:", reward)
+    print("total_reward:", score)
+    print("total_line :", completeLines)
+    print()
+
+
+env = TetrisEnviroment()
+agent = TetrisRLAgent(device)
+
+avgTotalRewards = []
+for batch in range(100000):
+
+    log_probs, rewards = [], []
+    total_rewards = []
+
+    for episode in range(5):
+        total_reward = 0
+        total_lines = 0
+
+        pixel, reward, done, info = env.reset()
+        while 1:
+            showGame(pixel, reward, total_reward, total_lines)
+
+            pixel = torch.tensor(pixel, dtype=torch.float32)
+            blockType = F.one_hot(torch.tensor(info[2]), 7)
+            blockLoc = torch.tensor(info[:2], dtype=torch.float32)
+            observation = torch.cat([pixel.reshape(-1), blockLoc, blockType], dim=0)
+            observation = observation.to(device)
+
+            action, log_prob = agent.sample(observation)
+            pixel, reward, done, info = env.step(action)
+            rewards.append(reward)
+            log_probs.append(log_prob)
+            total_reward += reward
+            total_lines += info[3]
+            if done:
+                total_rewards.append(total_reward)
+                break
+
+    print("TOTAL REWARDS", total_rewards)
+    avgTotalReward = sum(total_rewards) / len(total_rewards)
+    avgTotalRewards.append(avgTotalReward)
+
+    ALPHA = 0.98
+    delayRewards = []
+    for start in range(len(rewards)):
+        ans = rewards[start]
+        weight = ALPHA
+        for i in range(start+1, len(rewards)):
+            ans += (weight * rewards[i])
+            weight *= ALPHA
+        delayRewards.append(ans)
+
+    rewards = torch.tensor(delayRewards)
+    log_probs = torch.stack(log_probs)
+
+    print(batch)
+    print(rewards.shape)
+    print(log_probs.shape)
+    print(avgTotalReward)
+    # print(rewards.tolist())
+
+    agent.learn(rewards, log_probs)
+    # time.sleep(5)
+
+agent.save(PATH)
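The delayed-reward loop in train.py scales quadratically with episode length. The same discounted returns, G[t] = r[t] + ALPHA * G[t+1], can be computed in one backward pass; a sketch with the same ALPHA = 0.98 (the function name is illustrative):

import torch

ALPHA = 0.98

def delayed_rewards(rewards, alpha=ALPHA):
    # Backward accumulation: G[t] = r[t] + alpha * G[t+1],
    # identical to the nested loop above but O(n) instead of O(n^2).
    out = [0.0] * len(rewards)
    running = 0.0
    for t in reversed(range(len(rewards))):
        running = rewards[t] + alpha * running
        out[t] = running
    return torch.tensor(out)

# delayed_rewards([1, 0, 2]) -> tensor([2.9208, 1.9600, 2.0000])

Many REINFORCE implementations also standardize these returns (subtract the mean, divide by the standard deviation) before multiplying with the log-probabilities, which tends to stabilize training; this commit multiplies the raw delayed rewards directly.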