feat: Add RL model & Agent
parent 6cfd08095e
commit 7c07b4aa96

agent.py  (new file, +96 lines)
@@ -0,0 +1,96 @@
import torch
from torch.optim import Adam
from torch.distributions import Categorical
from enviroment import TetrisEnviroment
from network import TetrisRLModel
import torch.nn.functional as F

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("use:", device)

PATH = "model.h5"


class TetrisRLAgent():
    def __init__(self) -> None:
        # Named `network` / `optimizer` so that save() below finds them.
        self.network = TetrisRLModel().to(device)
        self.optimizer = Adam(self.network.parameters(), lr=0.001)

    def sample(self, observation):
        # Sample one action from the policy and keep its log-probability for the update.
        action_prob = self.network(observation)
        action_dist = Categorical(action_prob)
        action = action_dist.sample()
        log_prob = action_dist.log_prob(action)
        return action, log_prob

    def learn(self, rewards, log_probs):
        # REINFORCE update: minimise the negative reward-weighted log-probabilities.
        rewards = rewards.to(device)
        loss = (-log_probs * rewards).sum()

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

    def save(self, PATH):  # You should not revise this
        Agent_Dict = {
            "network": self.network.state_dict(),
            "optimizer": self.optimizer.state_dict()
        }
        torch.save(Agent_Dict, PATH)


def showGame(views: list, score: int) -> None:
    # Print the 20x10 board and the current score.
    for i in range(20):
        print(str(i).rjust(2), end=' ')
        for j in range(10):
            if views[i][j]:
                print('■', end='')
            else:
                print('□', end='')
        print()
    print("Score:", score)
    print()


env = TetrisEnviroment()
agent = TetrisRLAgent()

avgTotalRewards = []
for batch in range(100000):

    log_probs, rewards = [], []
    total_rewards = []

    for episode in range(5):
        total_reward = 0

        pixel, reward, done, info = env.reset()
        while 1:
            showGame(pixel, total_reward)

            # 209-dim observation: flattened 20x10 board + block (x, y) + one-hot block type.
            pixel = torch.tensor(pixel, dtype=torch.float32)
            blockType = F.one_hot(torch.tensor(info[2]), 7).to(torch.float32)
            blockLoc = torch.tensor(info[:2], dtype=torch.float32)
            observation = torch.cat([pixel.reshape(-1), blockLoc, blockType], dim=0)
            observation = observation.to(device)

            action, log_prob = agent.sample(observation)
            pixel, reward, done, info = env.step(action.item())
            rewards.append(reward)
            log_probs.append(log_prob)
            total_reward += reward
            if done:
                total_rewards.append(total_reward)
                break

    avgTotalReward = sum(total_rewards) / len(total_rewards)
    avgTotalRewards.append(avgTotalReward)

    rewards = torch.tensor(rewards, dtype=torch.float32)
    log_probs = torch.stack(log_probs)

    print(rewards.shape)
    print(log_probs.shape)

    agent.learn(rewards, log_probs)

    agent.save(PATH)
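The update in learn() above is plain REINFORCE: each sampled action's log-probability is weighted by the reward it received, and the negative sum is minimised. A minimal sketch of the same objective with made-up numbers, just to make the sign convention explicit (values are hypothetical, not from this commit):

import torch

# Two hypothetical sampled actions with their log-probabilities and rewards.
log_probs = torch.tensor([-1.2, -0.7], requires_grad=True)
rewards = torch.tensor([1.0, 0.0])

# Same form as TetrisRLAgent.learn: gradient ascent on sum(log_prob * reward),
# written as gradient descent on its negative.
loss = (-log_probs * rewards).sum()
loss.backward()
print(loss.item(), log_probs.grad)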
enviroment.py
@@ -5,6 +5,12 @@ class TetrisEnviroment():
         self.game = TetrisGame()
         self.score = 0
 
+    def reset(self):
+        self.game.reset()
+        self.score = 0
+        return self.game.view(), 0, self.game.done, \
+            (self.game.block.x, self.game.block.y, self.game.block.block_id)
+
     def step(self, mode):
         if mode == 0:  # no move
             pass
@@ -32,30 +38,12 @@ class TetrisEnviroment():
         elif mode == 9:  # rotate 3
             for i in range(3):
                 self.game.action('f')
 
         self.game.action('d')
 
         deltaScore = self.game.score - self.score
         self.score = self.game.score
 
-        return self.game.view(), deltaScore, self.game.done, (self.game.block.x, self.game.block.block.y)
-        # observation, reward, done, info(block location)
-
-    def observation(self):
-        return self.game.view()
-
-env = TetrisEnviroment()
-while 1:
-    views = env.observation()
-    for i in range(20):
-        print(str(i).rjust(2), end=' ')
-        for j in range(10):
-            if views[i][j]:
-                print('■', end='')
-            else:
-                print('□', end='')
-        print()
-    action = int(input("Action: "))
-    observation, reward, done = env.step(action)
-    print(observation, reward, done)
-    if done:
-        break
+        return self.game.view(), deltaScore, self.game.done, \
+            (self.game.block.x, self.game.block.y, self.game.block.block_id)
+        # observation, reward, done, info (block location and type)
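With this commit both reset() and step() return the same 4-tuple: the 20x10 board view, a reward (score delta), the done flag, and (block_x, block_y, block_id). A minimal interaction sketch under that assumption (action 0 is "no move", so the block simply falls):

from enviroment import TetrisEnviroment

env = TetrisEnviroment()
view, reward, done, info = env.reset()
while not done:
    # Take the "do nothing" action each step; step() still drops the block one row.
    view, reward, done, info = env.step(0)
print("last reward:", reward, "block info:", info)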
game.py  (+10 lines)
@@ -30,6 +30,9 @@ class Board():
         Place the current block into the board, i.e. the block's position is now fixed.
         '''
         self.block = view(self, block)
 
+    def reset(self) -> None:
+        self.block = [[0 for _ in range(10)] for _ in range(20)]
+
     def checkScore(self) -> int:
         '''
@@ -131,6 +134,13 @@ class TetrisGame():
         self.score = 0
         self.done = False
 
+    def reset(self) -> None:
+        self.block.reset()
+        self.board.reset()
+
+        self.score = 0
+        self.done = False
+
     def action(self, mode):
         if mode == 'd':
             try:
network.py  (new file, +17 lines)
@@ -0,0 +1,17 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


class TetrisRLModel(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        # 209 inputs = 20*10 board cells + block (x, y) + 7-way one-hot block type.
        self.fc1 = nn.Linear(in_features=209, out_features=512)
        self.fc2 = nn.Linear(in_features=512, out_features=32)
        self.fc3 = nn.Linear(in_features=32, out_features=10)

    def forward(self, observation):
        observation = observation.to(dtype=torch.float32)
        observation = torch.relu(self.fc1(observation))
        observation = torch.relu(self.fc2(observation))
        # Probabilities over the 10 discrete actions (dim=0: a single, unbatched observation).
        observation = F.softmax(self.fc3(observation), dim=0)
        return observation
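For reference, the 209 input features line up with how agent.py builds its observation: 200 flattened board cells, the block's (x, y), and a 7-way one-hot block type. A quick shape sanity check with hypothetical values (empty board, block id 3):

import torch
import torch.nn.functional as F

from network import TetrisRLModel

pixel = torch.zeros(20, 10)                        # empty 20x10 board
blockLoc = torch.tensor([4.0, 0.0])                # hypothetical block (x, y)
blockType = F.one_hot(torch.tensor(3), 7).float()  # one of the 7 tetromino ids
obs = torch.cat([pixel.reshape(-1), blockLoc, blockType], dim=0)

assert obs.shape == torch.Size([209])              # matches fc1's in_features
print(TetrisRLModel()(obs).shape)                  # torch.Size([10]) action probabilities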