From 7c07b4aa96d1e635ede8e6ff0387f93308631bb1 Mon Sep 17 00:00:00 2001
From: Ting-Jun Wang
Date: Fri, 1 Jul 2022 03:34:50 +0800
Subject: [PATCH] feat: add RL model & Agent
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 agent.py      | 96 +++++++++++++++++++++++++++++++++++++++++++++++++++
 enviroment.py | 32 ++++++-----------
 game.py       | 10 ++++++
 network.py    | 17 +++++++++
 4 files changed, 133 insertions(+), 22 deletions(-)
 create mode 100644 agent.py
 create mode 100644 network.py

diff --git a/agent.py b/agent.py
new file mode 100644
index 0000000..3eee57e
--- /dev/null
+++ b/agent.py
@@ -0,0 +1,96 @@
+import torch
+from torch.optim import Adam
+from torch.distributions import Categorical
+from enviroment import TetrisEnviroment
+from network import TetrisRLModel
+import torch.nn.functional as F
+
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+print("use:", device)
+
+PATH = "model.h5"
+
+class TetrisRLAgent():
+    def __init__(self) -> None:
+        self.network = TetrisRLModel()
+        self.network = self.network.to(device)
+        self.optimizer = Adam(self.network.parameters(), lr=0.001)
+
+    def sample(self, observation):
+        action_prob = self.network(observation)
+        action_dist = Categorical(action_prob)
+        action = action_dist.sample()
+        log_prob = action_dist.log_prob(action)
+        return action, log_prob
+
+    def learn(self, rewards, log_probs):
+        rewards = rewards.to(device)
+        loss = (-log_probs * rewards).sum()
+
+        self.optimizer.zero_grad()
+        loss.backward()
+        self.optimizer.step()
+
+    def save(self, PATH): # You should not revise this
+        Agent_Dict = {
+            "network" : self.network.state_dict(),
+            "optimizer" : self.optimizer.state_dict()
+        }
+        torch.save(Agent_Dict, PATH)
+
+def showGame(views:list, score:int) -> None:
+    for i in range(20):
+        print(str(i).rjust(2), end=' ')
+        for j in range(10):
+            if views[i][j]:
+                print('■', end='')
+            else:
+                print('□', end='')
+        print()
+    print("Score:", score)
+    print()
+
+env = TetrisEnviroment()
+agent = TetrisRLAgent()
+
+avgTotalRewards = []
+for batch in range(100000):
+
+    log_probs, rewards = [], []
+    total_rewards = []
+
+    for episode in range(5):
+        total_reward = 0
+
+        pixel, reward, done, info = env.reset()
+        while 1:
+            showGame(pixel, total_reward)
+
+            pixel = torch.tensor(pixel, dtype=torch.float32)
+            blockType = F.one_hot(torch.tensor(info[2]), 7)
+            blockLoc = torch.tensor(info[:2], dtype=torch.float32)
+            # 20x10 board (200) + block (x, y) + 7-way one-hot block type = 209 features
+            observation = torch.cat([pixel.reshape(-1), blockLoc, blockType], dim=0)
+            observation = observation.to(device)
+
+            action, log_prob = agent.sample(observation)
+            pixel, reward, done, info = env.step(action)
+            rewards.append(reward)
+            log_probs.append(log_prob)
+            total_reward += reward
+            if done:
+                total_rewards.append(total_reward)
+                break
+
+    avgTotalReward = sum(total_rewards) / len(total_rewards)
+    avgTotalRewards.append(avgTotalReward)
+
+    rewards = torch.tensor(rewards)
+    log_probs = torch.stack(log_probs)
+
+    print(rewards.shape)
+    print(log_probs.shape)
+
+    agent.learn(rewards, log_probs)
+
+agent.save(PATH)
\ No newline at end of file
diff --git a/enviroment.py b/enviroment.py
index 4960b93..21d7cd7 100644
--- a/enviroment.py
+++ b/enviroment.py
@@ -5,6 +5,12 @@ class TetrisEnviroment():
         self.game = TetrisGame()
         self.score = 0
 
+    def reset(self):
+        self.game.reset()
+        self.score = 0
+        return self.game.view(), 0, self.game.done, \
+            (self.game.block.x, self.game.block.y, self.game.block.block_id)
+
     def step(self, mode):
         if mode == 0:
             # do nothing (None)
@@ -32,30 +38,12 @@ class TetrisEnviroment():
         elif mode == 9:
             # rotate 3
             for i in range(3):
                 self.game.action('f')
 
+        self.game.action('d')
         deltaScore = self.game.score - self.score
         self.score = self.game.score
 
-        return self.game.view(), deltaScore, self.game.done, (self.game.block.x, self.game.block.block.y)
-        # observation, reward, done, info(block location)
-
-    def observation(self):
-        return self.game.view()
-
-env = TetrisEnviroment()
-while 1:
-    views = env.observation()
-    for i in range(20):
-        print(str(i).rjust(2), end=' ')
-        for j in range(10):
-            if views[i][j]:
-                print('■', end='')
-            else:
-                print('□', end='')
-        print()
-    action = int(input("Action: "))
-    observation, reward, done = env.step(action)
-    print(observation, reward, done)
-    if done:
-        break
\ No newline at end of file
+        return self.game.view(), deltaScore, self.game.done, \
+            (self.game.block.x, self.game.block.y, self.game.block.block_id)
+        # observation, reward, done, info(block location)
\ No newline at end of file
diff --git a/game.py b/game.py
index d784ba4..6e7a3f2 100644
--- a/game.py
+++ b/game.py
@@ -30,6 +30,9 @@ class Board():
         Place the current block into the board, i.e. the block's position is now fixed
         '''
         self.block = view(self, block)
+
+    def reset(self) -> None:
+        self.block = [[0 for _ in range(10)] for _ in range(20)]
 
     def checkScore(self) -> int:
         '''
@@ -131,6 +134,13 @@ class TetrisGame():
         self.score = 0
         self.done = False
 
+    def reset(self) -> None:
+        self.block.reset()
+        self.board.reset()
+
+        self.score = 0
+        self.done = False
+
     def action(self, mode):
         if mode == 'd':
             try:
diff --git a/network.py b/network.py
new file mode 100644
index 0000000..e15f941
--- /dev/null
+++ b/network.py
@@ -0,0 +1,17 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+class TetrisRLModel(nn.Module):
+    def __init__(self) -> None:
+        super().__init__()
+        self.fc1 = nn.Linear(in_features=209, out_features=512)
+        self.fc2 = nn.Linear(in_features=512, out_features=32)
+        self.fc3 = nn.Linear(in_features=32, out_features=10)
+
+    def forward(self, observation):
+        observation = observation.to(dtype=torch.float32)
+        observation = torch.relu(self.fc1(observation))
+        observation = torch.relu(self.fc2(observation))
+        observation = F.softmax(self.fc3(observation), dim=0)
+        return observation
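
Note on the update rule: agent.learn() above weights each log-probability by the immediate
step reward. Plain REINFORCE instead weights it by the discounted reward-to-go of the episode,
which typically trains more stably. Below is a minimal sketch of that variant, computed per
episode before calling agent.learn(); gamma = 0.99 and the helper name discounted_returns are
assumptions for illustration, not part of this patch.

import torch

def discounted_returns(rewards, gamma=0.99, eps=1e-8):
    # reward-to-go for one episode: G_t = r_t + gamma * G_{t+1},
    # accumulated from the last step backwards (assumed gamma, not from the patch)
    returns = []
    running = 0.0
    for r in reversed(rewards):
        running = r + gamma * running
        returns.insert(0, running)
    returns = torch.tensor(returns, dtype=torch.float32)
    # normalize so the gradient scale does not depend on episode length or score
    return (returns - returns.mean()) / (returns.std(unbiased=False) + eps)

# usage sketch: collect rewards per episode, then pass the concatenated per-episode
# discounted returns to agent.learn() in place of the raw per-step rewards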