feat: add RL model & Agent

Author: Ting-Jun Wang 2022-07-01 03:34:50 +08:00
parent 6cfd08095e
commit 7c07b4aa96
Signed by: snsd0805
GPG Key ID: 8DB0D22BC1217D33
4 changed files with 133 additions and 22 deletions

agent.py (new file, 96 additions)

@@ -0,0 +1,96 @@
import torch
import torch.nn.functional as F
from torch.optim import Adam
from torch.distributions import Categorical

from enviroment import TetrisEnviroment
from network import TetrisRLModel

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("use:", device)

PATH = "model.h5"
class TetrisRLAgent():
    def __init__(self) -> None:
        self.model = TetrisRLModel()
        self.model = self.model.to(device)
        self.optim = Adam(self.model.parameters(), lr=0.001)

    def sample(self, observation):
        # Sample an action from the policy's categorical distribution and
        # keep its log-probability for the policy-gradient update.
        action_prob = self.model(observation)
        action_dist = Categorical(action_prob)
        action = action_dist.sample()
        log_prob = action_dist.log_prob(action)
        return action, log_prob

    def learn(self, rewards, log_probs):
        # REINFORCE-style update: minimise -sum(log_prob * reward).
        rewards = rewards.to(device)
        loss = (-log_probs * rewards).sum()
        self.optim.zero_grad()
        loss.backward()
        self.optim.step()

    def save(self, PATH):    # You should not revise this
        Agent_Dict = {
            "network": self.model.state_dict(),
            "optimizer": self.optim.state_dict()
        }
        torch.save(Agent_Dict, PATH)
def showGame(views: list, score: int) -> None:
    # Render the 20x10 board to the terminal.
    for i in range(20):
        print(str(i).rjust(2), end=' ')
        for j in range(10):
            if views[i][j]:
                print('█', end='')   # occupied cell (glyph assumed; original character lost)
            else:
                print('·', end='')   # empty cell (glyph assumed)
        print()
    print("Score:", score)
    print()
env = TetrisEnviroment()
agent = TetrisRLAgent()

avgTotalRewards = []
for batch in range(100000):
    log_probs, rewards = [], []
    total_rewards = []
    for episode in range(5):
        total_reward = 0
        pixel, reward, done, info = env.reset()
        while 1:
            showGame(pixel, total_reward)

            # Build the 209-dim observation: 200 board cells + (x, y) + 7-way block-type one-hot.
            pixel = torch.tensor(pixel, dtype=torch.float32)
            blockType = F.one_hot(torch.tensor(info[2]), 7).float()
            blockLoc = torch.tensor(info[:2], dtype=torch.float32)
            observation = torch.cat([pixel.reshape(-1), blockLoc, blockType], dim=0)
            observation = observation.to(device)

            action, log_prob = agent.sample(observation)
            pixel, reward, done, info = env.step(action.item())

            rewards.append(reward)
            log_probs.append(log_prob)
            total_reward += reward
            if done:
                total_rewards.append(total_reward)
                break

    avgTotalReward = sum(total_rewards) / len(total_rewards)
    avgTotalRewards.append(avgTotalReward)

    rewards = torch.tensor(rewards)
    log_probs = torch.stack(log_probs)
    print(rewards.shape)
    print(log_probs.shape)
    agent.learn(rewards, log_probs)

    agent.save(PATH)    # checkpoint the policy after every batch
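
For reference, a minimal sketch of reloading the checkpoint written by TetrisRLAgent.save(); the dictionary keys ("network", "optimizer"), the file name, and the learning rate come from this commit, while the use of map_location is an illustrative assumption:

# Hypothetical reload of the checkpoint produced by TetrisRLAgent.save().
import torch
from torch.optim import Adam
from network import TetrisRLModel

checkpoint = torch.load("model.h5", map_location="cpu")   # PATH used above
model = TetrisRLModel()
model.load_state_dict(checkpoint["network"])
optim = Adam(model.parameters(), lr=0.001)
optim.load_state_dict(checkpoint["optimizer"])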

enviroment.py (10 additions, 22 deletions)

@@ -5,6 +5,12 @@ class TetrisEnviroment():
        self.game = TetrisGame()
        self.score = 0

    def reset(self):
        self.game.reset()
        self.score = 0
        return self.game.view(), 0, self.game.done, \
            (self.game.block.x, self.game.block.y, self.game.block.block_id)

    def step(self, mode):
        if mode == 0:    # do nothing
            None
@@ -32,30 +38,12 @@ class TetrisEnviroment():
        elif mode == 9:    # rotate 3
            for i in range(3):
                self.game.action('f')
        self.game.action('d')
        deltaScore = self.game.score - self.score
        self.score = self.game.score
        return self.game.view(), deltaScore, self.game.done, (self.game.block.x, self.game.block.block.y)
        # observation, reward, done, info(block location)

    def observation(self):
        return self.game.view()

env = TetrisEnviroment()
while 1:
    views = env.observation()
    for i in range(20):
        print(str(i).rjust(2), end=' ')
        for j in range(10):
            if views[i][j]:
                print('█', end='')   # occupied cell (glyph assumed)
            else:
                print('·', end='')   # empty cell (glyph assumed)
        print()
    action = int(input("Action: "))
    observation, reward, done = env.step(action)
    print(observation, reward, done)
    if done:
        break

        return self.game.view(), deltaScore, self.game.done, \
            (self.game.block.x, self.game.block.y, self.game.block.block_id)
        # observation, reward, done, info(block location)
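
With reset() and step() now returning the same four values, a minimal interaction sketch (the constant action 0 is illustrative only; the loop ends once the game sets done):

# Illustrative use of the updated environment API.
from enviroment import TetrisEnviroment

env = TetrisEnviroment()
pixel, reward, done, info = env.reset()    # info = (block.x, block.y, block.block_id)
while not done:
    # modes 0-9 map to the moves/rotations handled in step(); 0 means do nothing
    pixel, reward, done, info = env.step(0)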

game.py (10 additions)

@@ -30,6 +30,9 @@ class Board():
        Place the current block onto the board, i.e. the block's position is now fixed.
        '''
        self.block = view(self, block)

    def reset(self) -> None:
        self.block = [[0 for _ in range(10)] for _ in range(20)]

    def checkScore(self) -> int:
        '''
@@ -131,6 +134,13 @@ class TetrisGame():
        self.score = 0
        self.done = False

    def reset(self) -> None:
        self.block.reset()
        self.board.reset()
        self.score = 0
        self.done = False

    def action(self, mode):
        if mode == 'd':
            try:

network.py (new file, 17 additions)

@@ -0,0 +1,17 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


class TetrisRLModel(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        # 209 input features: 200 board cells + 2 block coordinates + 7-way block-type one-hot
        self.fc1 = nn.Linear(in_features=209, out_features=512)
        self.fc2 = nn.Linear(in_features=512, out_features=32)
        self.fc3 = nn.Linear(in_features=32, out_features=10)

    def forward(self, observation):
        observation = observation.to(dtype=torch.float32)
        observation = torch.relu(self.fc1(observation))
        observation = torch.relu(self.fc2(observation))
        # Probability distribution over the 10 actions handled by TetrisEnviroment.step()
        observation = F.softmax(self.fc3(observation), dim=0)
        return observation
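
The 209 input features of fc1 correspond to the observation built in agent.py: a flattened 20x10 board (200 values), the block's (x, y) location (2 values), and a 7-way one-hot block type. A quick, illustrative shape check assuming an empty board and block id 0:

# Sanity check of the observation layout consumed by TetrisRLModel.
import torch
import torch.nn.functional as F

pixel = torch.zeros(20, 10)                           # empty board: 200 cells
blockLoc = torch.tensor([0.0, 0.0])                   # block (x, y)
blockType = F.one_hot(torch.tensor(0), 7).float()     # block id as one-hot
observation = torch.cat([pixel.reshape(-1), blockLoc, blockType], dim=0)
assert observation.shape == (209,)                    # matches fc1 in_features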