feat: add RL model & Agent
parent 6cfd08095e
commit 7c07b4aa96

agent.py (new file, 96 lines)
@@ -0,0 +1,96 @@
import torch
import torch.nn.functional as F
from torch.optim import Adam
from torch.distributions import Categorical

from enviroment import TetrisEnviroment
from network import TetrisRLModel

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("use:", device)

PATH = "model.h5"


class TetrisRLAgent():
    def __init__(self) -> None:
        # Attribute names match what save() below expects.
        self.network = TetrisRLModel().to(device)
        self.optimizer = Adam(self.network.parameters(), lr=0.001)

    def sample(self, observation):
        # Forward pass yields a probability distribution over the 10 actions;
        # sample one and keep its log-probability for the policy-gradient update.
        action_prob = self.network(observation)
        action_dist = Categorical(action_prob)
        action = action_dist.sample()
        log_prob = action_dist.log_prob(action)
        return action, log_prob

    def learn(self, rewards, log_probs):
        # REINFORCE-style update: weight each log-probability by its reward.
        rewards = rewards.to(device)
        loss = (-log_probs * rewards).sum()

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

    def save(self, PATH):  # You should not revise this
        Agent_Dict = {
            "network": self.network.state_dict(),
            "optimizer": self.optimizer.state_dict()
        }
        torch.save(Agent_Dict, PATH)


def showGame(views: list, score: int) -> None:
    # Print the 20x10 board and the current score to the terminal.
    for i in range(20):
        print(str(i).rjust(2), end=' ')
        for j in range(10):
            if views[i][j]:
                print('■', end='')
            else:
                print('□', end='')
        print()
    print("Score:", score)
    print()


env = TetrisEnviroment()
agent = TetrisRLAgent()

avgTotalRewards = []
for batch in range(100000):

    log_probs, rewards = [], []
    total_rewards = []

    for episode in range(5):
        total_reward = 0

        pixel, reward, done, info = env.reset()
        while 1:
            showGame(pixel, total_reward)

            # Build the 209-dim observation: 200 board cells + block (x, y) + 7-way one-hot block type.
            pixel = torch.tensor(pixel, dtype=torch.float32)
            blockType = F.one_hot(torch.tensor(info[2]), 7).to(torch.float32)  # cast so torch.cat with float features works
            blockLoc = torch.tensor(info[:2], dtype=torch.float32)
            observation = torch.cat([pixel.reshape(-1), blockLoc, blockType], dim=0)
            observation = observation.to(device)

            action, log_prob = agent.sample(observation)
            pixel, reward, done, info = env.step(action.item())  # step() expects a plain int mode
            rewards.append(reward)
            log_probs.append(log_prob)
            total_reward += reward
            if done:
                total_rewards.append(total_reward)
                break

    avgTotalReward = sum(total_rewards) / len(total_rewards)
    avgTotalRewards.append(avgTotalReward)

    rewards = torch.tensor(rewards)
    log_probs = torch.stack(log_probs)

    print(rewards.shape)
    print(log_probs.shape)

    agent.learn(rewards, log_probs)

    agent.save(PATH)
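A quick way to sanity-check the sample/learn cycle above is a one-step dry run. This is only a sketch: the zero observation and the reward of 1.0 are placeholders, and it assumes TetrisRLAgent and device from agent.py are in scope.

agent = TetrisRLAgent()
dummy_obs = torch.zeros(209, device=device)               # 200 board cells + block (x, y) + 7-way one-hot block type
action, log_prob = agent.sample(dummy_obs)                # action is a tensor holding an integer in 0..9
agent.learn(torch.tensor([1.0]), log_prob.unsqueeze(0))   # single REINFORCE step with a dummy reward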
enviroment.py
@@ -5,6 +5,12 @@ class TetrisEnviroment():
         self.game = TetrisGame()
         self.score = 0

+    def reset(self):
+        self.game.reset()
+        self.score = 0
+        return self.game.view(), 0, self.game.done, \
+            (self.game.block.x, self.game.block.y, self.game.block.block_id)
+
     def step(self, mode):
         if mode == 0:  # no movement
             None
@@ -32,30 +38,12 @@ class TetrisEnviroment():
         elif mode == 9:  # rotate 3
             for i in range(3):
                 self.game.action('f')

         self.game.action('d')

         deltaScore = self.game.score - self.score
         self.score = self.game.score

-        return self.game.view(), deltaScore, self.game.done, (self.game.block.x, self.game.block.block.y)
+        return self.game.view(), deltaScore, self.game.done, \
+            (self.game.block.x, self.game.block.y, self.game.block.block_id)
         # observation, reward, done, info(block location)
-
-    def observation(self):
-        return self.game.view()
-
-env = TetrisEnviroment()
-while 1:
-    views = env.observation()
-    for i in range(20):
-        print(str(i).rjust(2), end=' ')
-        for j in range(10):
-            if views[i][j]:
-                print('■', end='')
-            else:
-                print('□', end='')
-        print()
-    action = int(input("Action: "))
-    observation, reward, done = env.step(action)
-    print(observation, reward, done)
-    if done:
-        break
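With step() now returning the same four-tuple as reset(), the environment can be driven like a standard (observation, reward, done, info) interface. A minimal random-rollout sketch, assuming enviroment.py imports cleanly after this change; the random policy is only for illustration:

import random
from enviroment import TetrisEnviroment

env = TetrisEnviroment()
view, reward, done, info = env.reset()                        # info = (block.x, block.y, block.block_id)
while not done:
    view, reward, done, info = env.step(random.randint(0, 9))  # modes 0-9, e.g. 0 = no movement, 9 = triple rotate
print("final info:", info)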
game.py (+10 lines)
@@ -31,6 +31,9 @@ class Board():
         '''
         self.block = view(self, block)

+    def reset(self) -> None:
+        self.block = [[0 for _ in range(10)] for _ in range(20)]
+
     def checkScore(self) -> int:
         '''
         Check whether any lines are completed and return the score gained from this action.
@@ -131,6 +134,13 @@ class TetrisGame():
         self.score = 0
         self.done = False

+    def reset(self) -> None:
+        self.block.reset()
+        self.board.reset()
+
+        self.score = 0
+        self.done = False
+
     def action(self, mode):
         if mode == 'd':
             try:
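The new reset() methods are what let the training loop in agent.py reuse one game across episodes. A small usage sketch, assuming game.py is importable and the Block class has its own reset() (not shown in this diff):

from game import TetrisGame

game = TetrisGame()
# ... play some moves via game.action(...) ...
game.reset()                                   # resets the current block, board grid, score, and done flag
assert game.score == 0 and game.done is False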
network.py (new file, 17 lines)
@@ -0,0 +1,17 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


class TetrisRLModel(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.fc1 = nn.Linear(in_features=209, out_features=512)
        self.fc2 = nn.Linear(in_features=512, out_features=32)
        self.fc3 = nn.Linear(in_features=32, out_features=10)

    def forward(self, observation):
        observation = observation.to(dtype=torch.float32)
        observation = torch.relu(self.fc1(observation))
        observation = torch.relu(self.fc2(observation))
        # softmax over dim=0 assumes a single unbatched 1-D observation
        observation = F.softmax(self.fc3(observation), dim=0)
        return observation
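A shape check for the model, as a sketch (it assumes the imports above; the zero input is arbitrary): the 209-feature observation built in agent.py should map to a 10-way probability vector that sums to 1 because of the softmax.

model = TetrisRLModel()
probs = model(torch.zeros(209))
print(probs.shape)          # torch.Size([10])
print(probs.sum().item())   # ~1.0, softmax over dim=0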