feat: complete DQN

snsd0805 2024-05-22 01:48:44 +08:00
parent 18b0cb34d6
commit 175b674f98
Signed by: snsd0805
GPG Key ID: 569349933C77A854
2 changed files with 45 additions and 20 deletions

View File

@@ -127,6 +127,8 @@ def main():
     action_dim = env.action_space.n
     state_dim = (args.num_envs, args.image_hw, args.image_hw)
+    print(action_dim)
+    print(state_dim)
     agent = DQN(state_dim=state_dim, action_dim=action_dim)
     # train

View File

@@ -11,15 +11,35 @@ class PacmanActionCNN(nn.Module):
         super(PacmanActionCNN, self).__init__()
         # build your own CNN model
         "*** YOUR CODE HERE ***"
-        utils.raiseNotDefined()
         # this is just an example, you can modify this.
-        self.conv1 = nn.Conv2d(state_dim, 16, kernel_size=8, stride=4)
+        self.conv1 = nn.Conv2d(state_dim, 16, kernel_size=3, stride=1, padding='same')
+        self.conv2 = nn.Conv2d(16, 16, kernel_size=3, stride=1, padding='same')
+        self.conv3 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding='same')
+        self.conv4 = nn.Conv2d(32, 32, kernel_size=3, stride=1, padding='same')
+        self.fc1 = nn.Linear(in_features=3200, out_features=512)
+        self.fc2 = nn.Linear(in_features=512, out_features=64)
+        self.fc3 = nn.Linear(in_features=64, out_features=action_dim)
+        self.pooling = nn.MaxPool2d(kernel_size=2, stride=2)
+        self.relu = nn.ReLU()
+        self.flatten = nn.Flatten()

     def forward(self, x):
-        x = F.relu(self.conv1(x))
         "*** YOUR CODE HERE ***"
-        utils.raiseNotDefined()
+        x = self.relu(self.conv1(x))
+        x = self.relu(self.conv2(x))
+        x = self.pooling(x)
+        x = self.relu(self.conv3(x))
+        x = self.pooling(x)
+        x = self.relu(self.conv4(x))
+        x = self.pooling(x)
+        x = self.flatten(x)
+        x = self.relu(self.fc1(x))
+        x = self.relu(self.fc2(x))
+        x = self.fc3(x)
         return x
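A quick sanity check on the new architecture: fc1's in_features=3200 equals 32 channels x 10 x 10, which is what three MaxPool2d(2, 2) stages produce from an 80x80 input, so args.image_hw is presumably 80. A minimal shape-check sketch, assuming state_dim here is the input channel count; the 4 channels below are an illustrative assumption, not the repository's actual value:

    import torch
    import torch.nn as nn

    # Sketch of the committed stack, for a shape check only.
    # Assumes an 80x80 observation and 4 input channels (both assumptions).
    stack = nn.Sequential(
        nn.Conv2d(4, 16, kernel_size=3, stride=1, padding='same'), nn.ReLU(),
        nn.Conv2d(16, 16, kernel_size=3, stride=1, padding='same'), nn.ReLU(),
        nn.MaxPool2d(2, 2),   # 80x80 -> 40x40
        nn.Conv2d(16, 32, kernel_size=3, stride=1, padding='same'), nn.ReLU(),
        nn.MaxPool2d(2, 2),   # 40x40 -> 20x20
        nn.Conv2d(32, 32, kernel_size=3, stride=1, padding='same'), nn.ReLU(),
        nn.MaxPool2d(2, 2),   # 20x20 -> 10x10
        nn.Flatten(),         # 32 * 10 * 10 = 3200
    )
    print(stack(torch.randn(1, 4, 80, 80)).shape)  # torch.Size([1, 3200])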
@@ -109,53 +129,56 @@ class DQN:
         x = torch.from_numpy(x).float().unsqueeze(0).to(self.device)
         "*** YOUR CODE HERE ***"
-        utils.raiseNotDefined()
+        # utils.raiseNotDefined()
         # get q-values from network
-        q_value = YOUR_CODE_HERE
+        q_value = self.network(x).squeeze()
         # get action with maximum q-value
-        action = YOUR_CODE_HERE
+        action = q_value.argmax().cpu()
         return action
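Note that the selection above is purely greedy; since the class maintains self.epsilon and decays it in process(), epsilon-greedy exploration presumably wraps this method. A hedged sketch of the usual wrapper, assuming a self.action_dim attribute exists; the names here are illustrative, not the repository's exact API:

    import random
    import torch

    def act(self, x, training=True):
        # Epsilon-greedy: explore with probability epsilon, otherwise exploit.
        if training and random.random() < self.epsilon:
            return torch.tensor(random.randrange(self.action_dim))
        x = torch.from_numpy(x).float().unsqueeze(0).to(self.device)
        with torch.no_grad():  # no gradients needed for action selection
            q_value = self.network(x).squeeze()
        return q_value.argmax().cpu()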
     def learn(self):
         "*** YOUR CODE HERE ***"
-        utils.raiseNotDefined()
         # sample a mini-batch from replay buffer
         state, action, reward, next_state, terminated = map(lambda x: x.to(self.device), self.buffer.sample(self.batch_size))
+        action = action.to(torch.int64)
+        terminated = terminated.bool()
         # get q-values from network
-        next_q = YOUR_CODE_HERE
+        pred_q = self.network(state).gather(1, action)
+        next_q = self.target_network(next_state)
         # td_target: if terminated, only reward, otherwise reward + gamma * max(next_q)
-        td_target = YOUR_CODE_HERE
+        td_target = torch.where(terminated, reward, reward + self.gamma * next_q.max())
         # compute loss with td_target and q-values
-        loss = YOUR_CODE_HERE
+        criterion = nn.MSELoss()
+        loss = criterion(pred_q, td_target)
         # initialize optimizer
-        "self.optimizer.YOUR_CODE_HERE"
+        self.optimizer.zero_grad()
         # backpropagation
-        YOUR_CODE_HERE
+        loss.backward()
         # update network
-        "self.optimizer.YOUR_CODE_HERE"
+        self.optimizer.step()
-        return {YOUR_CODE_HERE} # return dictionary for logging
+        return {'value_loss': loss.item()} # return dictionary for logging
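One caveat worth flagging in the committed learn(): next_q.max() with no dim argument reduces over the entire batch, so every transition bootstraps from the same scalar, and the target is computed inside the autograd graph. The standard DQN target takes a per-sample max over the action dimension and blocks gradients. A sketch of that formulation, assuming action, reward, and terminated come out of the buffer with shape (batch, 1):

    # Inside learn(), after sampling the mini-batch (shapes assumed (batch, 1)):
    pred_q = self.network(state).gather(1, action)  # Q(s, a) for the taken actions
    with torch.no_grad():
        # Per-sample max over actions, not a batch-wide scalar.
        next_q = self.target_network(next_state).max(dim=1, keepdim=True).values
        td_target = torch.where(terminated, reward, reward + self.gamma * next_q)
    loss = nn.functional.mse_loss(pred_q, td_target)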
     def process(self, transition):
         "*** YOUR CODE HERE ***"
-        utils.raiseNotDefined()
         result = {}
         self.total_steps += 1
         # update replay buffer
         "self.buffer.YOUR_CODE_HERE"
+        self.buffer.update(*transition)
         if self.total_steps > self.warmup_steps:
             result = self.learn()
         if self.total_steps % self.target_update_interval == 0:
             # update target network
-            "self.target_network.YOUR_CODE_HERE"
+            self.target_network.load_state_dict(self.network.state_dict())
         self.epsilon -= self.epsilon_decay
         return result
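Two small observations on process(): the decay self.epsilon -= self.epsilon_decay runs unguarded, so epsilon can drift below zero on long runs (a common guard is self.epsilon = max(self.epsilon - self.epsilon_decay, epsilon_min)), and the target sync is a hard copy via load_state_dict. For context, a hypothetical driver loop for process(); the gym-style env API and the names below are assumptions for illustration, not the repository's code:

    # Hypothetical training driver (gym-style API assumed, not the repo's code).
    state, _ = env.reset()
    for step in range(max_steps):
        action = agent.act(state)
        next_state, reward, terminated, truncated, _ = env.step(int(action))
        # Store the transition; the agent learns, syncs targets, decays epsilon.
        log = agent.process((state, action, reward, next_state, terminated))
        state = next_state
        if terminated or truncated:
            state, _ = env.reset()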