import torch
import torch.nn.functional as F
from torch.optim import Adam
from torch.distributions import Categorical

from enviroment import TetrisEnviroment
from network import TetrisRLModel

# Run on the GPU when available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("Using device:", device)

# Checkpoint path; torch.save writes a pickle-based file despite the .h5 suffix.
PATH = "model.h5"
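
# REINFORCE-style agent: the policy network maps an observation to a
# distribution over actions, and updates weight -log_prob by the reward.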
class TetrisRLAgent:
    def __init__(self) -> None:
        # Attribute names must match the ones save() uses below.
        self.network = TetrisRLModel().to(device)
        self.optimizer = Adam(self.network.parameters(), lr=0.001)
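
    # Sample an action from the categorical policy distribution and keep
    # its log-probability for the policy-gradient update.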
    def sample(self, observation):
        action_prob = self.network(observation)
        action_dist = Categorical(action_prob)
        action = action_dist.sample()
        log_prob = action_dist.log_prob(action)
        return action, log_prob
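
    # REINFORCE update: weight each step's negative log-probability by the
    # reward collected at that step and sum into a scalar loss.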
    def learn(self, rewards, log_probs):
        rewards = rewards.to(device)
        loss = (-log_probs * rewards).sum()

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
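
    # Persist both the network and optimizer state in one checkpoint dict.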
    def save(self, PATH): # You should not revise this
        Agent_Dict = {
            "network" : self.network.state_dict(),
            "optimizer" : self.optimizer.state_dict()
        }
        torch.save(Agent_Dict, PATH)
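
# Render the 20x10 board as text, one row per line, followed by the score.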
def showGame(views: list, score: int) -> None:
    for i in range(20):
        print(str(i).rjust(2), end=' ')
        for j in range(10):
            if views[i][j]:
                print('■', end='')
            else:
                print('□', end='')
        print()
    print("Score:", score)
    print()

env = TetrisEnviroment()
agent = TetrisRLAgent()
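
# Training loop: each batch plays 5 episodes with the current policy, then
# performs one policy-gradient update over every step collected.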
avgTotalRewards = []
for batch in range(100000):

    log_probs, rewards = [], []
    total_rewards = []

    for episode in range(5):
        total_reward = 0

        pixel, reward, done, info = env.reset()
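        # Play one episode to completion, storing per-step rewards and
        # log-probabilities.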
        while True:
            showGame(pixel, total_reward)

            # Observation: the flattened 20x10 board, the block location,
            # and a one-hot encoding of the block type (7 tetrominoes).
            # one_hot returns int64, so cast to float before concatenating.
            pixel = torch.tensor(pixel, dtype=torch.float32)
            blockType = F.one_hot(torch.tensor(info[2]), 7).float()
            blockLoc = torch.tensor(info[:2], dtype=torch.float32)
            observation = torch.cat([pixel.reshape(-1), blockLoc, blockType], dim=0)
            observation = observation.to(device)

            action, log_prob = agent.sample(observation)
            pixel, reward, done, info = env.step(action)
            rewards.append(reward)
            log_probs.append(log_prob)
            total_reward += reward
            if done:
                total_rewards.append(total_reward)
                break

    avgTotalReward = sum(total_rewards) / len(total_rewards)
    avgTotalRewards.append(avgTotalReward)
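
    # Stack per-step values into tensors for a single batched update.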
    rewards = torch.tensor(rewards, dtype=torch.float32)
    log_probs = torch.stack(log_probs)

    print(rewards.shape)
    print(log_probs.shape)

    agent.learn(rewards, log_probs)

    # Checkpoint after each batch of episodes.
    agent.save(PATH)