feat: complete DQN
parent 18b0cb34d6
commit 175b674f98
@@ -127,6 +127,8 @@ def main():
     action_dim = env.action_space.n
     state_dim = (args.num_envs, args.image_hw, args.image_hw)
+    print(action_dim)
+    print(state_dim)
     agent = DQN(state_dim=state_dim, action_dim=action_dim)
 
     # train
@@ -151,4 +153,4 @@ if __name__ == "__main__":
         evaluate(agent=None, eval_env=eval_env, capture_frames=False)
     else:
         main()
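The "# train" section that follows the agent construction is not part of this hunk. For orientation, here is a rough sketch of the kind of gym-style driver loop such an agent is typically plugged into; the dummy environment, the random-action stand-in, and the method names are illustrative assumptions, while the transition ordering follows the state, action, reward, next_state, terminated unpacking used in DQN.learn() further down in this commit.

# Illustrative sketch only: DummyEnv and the random policy stand in for the real
# Pacman environment and for the agent's action selection, which are not shown here.
import random

class DummyEnv:
    def reset(self):
        return [0.0, 0.0]
    def step(self, action):
        next_obs = [random.random(), random.random()]
        reward = random.random()
        terminated = random.random() < 0.05
        return next_obs, reward, terminated

env = DummyEnv()
state = env.reset()
for _ in range(1000):
    action = random.randrange(5)                      # stand-in for epsilon-greedy selection
    next_state, reward, terminated = env.step(action)
    transition = (state, action, reward, next_state, terminated)
    # a real run would call agent.process(transition), which stores the transition
    # and, once total_steps exceeds warmup_steps, performs one learn() update
    state = env.reset() if terminated else next_state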
@@ -11,15 +11,35 @@ class PacmanActionCNN(nn.Module):
         super(PacmanActionCNN, self).__init__()
         # build your own CNN model
         "*** YOUR CODE HERE ***"
-        utils.raiseNotDefined()
         # this is just an example, you can modify this.
-        self.conv1 = nn.Conv2d(state_dim, 16, kernel_size=8, stride=4)
+        self.conv1 = nn.Conv2d(state_dim, 16, kernel_size=3, stride=1, padding='same')
+        self.conv2 = nn.Conv2d(16, 16, kernel_size=3, stride=1, padding='same')
+        self.conv3 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding='same')
+        self.conv4 = nn.Conv2d(32, 32, kernel_size=3, stride=1, padding='same')
+        self.fc1 = nn.Linear(in_features=3200, out_features=512)
+        self.fc2 = nn.Linear(in_features=512, out_features=64)
+        self.fc3 = nn.Linear(in_features=64, out_features=action_dim)
+
+        self.pooling = nn.MaxPool2d(kernel_size=2, stride=2)
+        self.relu = nn.ReLU()
+        self.flatten = nn.Flatten()
 
     def forward(self, x):
-        x = F.relu(self.conv1(x))
-
         "*** YOUR CODE HERE ***"
-        utils.raiseNotDefined()
+        x = self.relu(self.conv1(x))
+        x = self.relu(self.conv2(x))
+        x = self.pooling(x)
+        x = self.relu(self.conv3(x))
+        x = self.pooling(x)
+        x = self.relu(self.conv4(x))
+        x = self.pooling(x)
+
+        x = self.flatten(x)
+
+        x = self.relu(self.fc1(x))
+        x = self.relu(self.fc2(x))
+        x = self.fc3(x)
 
         return x
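A quick check of in_features=3200 for fc1: the stride-1, padding='same' convolutions preserve spatial size, so only the three 2x2 max-poolings shrink it, leaving 32 x (H/8) x (H/8) features for an H x H input; 3200 therefore corresponds to H = 80. The snippet below reproduces that arithmetic with a dummy input. The 80x80 resolution and the 4 input channels are assumptions for illustration; the real values come from args.image_hw and the channel count passed in as state_dim.

# Shape check for the flatten -> fc1 boundary (assumed: 4 input channels, 80x80 frames).
import torch
import torch.nn as nn

convs = nn.Sequential(
    nn.Conv2d(4, 16, kernel_size=3, stride=1, padding='same'), nn.ReLU(),
    nn.Conv2d(16, 16, kernel_size=3, stride=1, padding='same'), nn.ReLU(), nn.MaxPool2d(2, 2),
    nn.Conv2d(16, 32, kernel_size=3, stride=1, padding='same'), nn.ReLU(), nn.MaxPool2d(2, 2),
    nn.Conv2d(32, 32, kernel_size=3, stride=1, padding='same'), nn.ReLU(), nn.MaxPool2d(2, 2),
)
x = torch.zeros(1, 4, 80, 80)
print(convs(x).flatten(1).shape)    # torch.Size([1, 3200]), matching fc1's in_features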
@@ -109,53 +129,56 @@ class DQN:
         x = torch.from_numpy(x).float().unsqueeze(0).to(self.device)
 
         "*** YOUR CODE HERE ***"
-        utils.raiseNotDefined()
+        # utils.raiseNotDefined()
         # get q-values from network
-        q_value = YOUR_CODE_HERE
+        q_value = self.network(x).squeeze()
         # get action with maximum q-value
-        action = YOUR_CODE_HERE
+        action = q_value.argmax().cpu()
 
         return action
 
     def learn(self):
         "*** YOUR CODE HERE ***"
-        utils.raiseNotDefined()
 
         # sample a mini-batch from replay buffer
+        state, action, reward, next_state, terminated = map(lambda x: x.to(self.device), self.buffer.sample(self.batch_size))
+        action = action.to(torch.int64)
+        terminated = terminated.bool()
 
         # get q-values from network
-        next_q = YOUR_CODE_HERE
+        pred_q = self.network(state).gather(1, action)
+        next_q = self.target_network(next_state)
         # td_target: if terminated, only reward, otherwise reward + gamma * max(next_q)
-        td_target = YOUR_CODE_HERE
+        td_target = torch.where(terminated, reward, reward + self.gamma * next_q.max(dim=1, keepdim=True).values)
         # compute loss with td_target and q-values
-        loss = YOUR_CODE_HERE
+        criterion = nn.MSELoss()
+        loss = criterion(pred_q, td_target)
 
         # initialize optimizer
-        "self.optimizer.YOUR_CODE_HERE"
+        self.optimizer.zero_grad()
         # backpropagation
-        YOUR_CODE_HERE
+        loss.backward()
         # update network
-        "self.optimizer.YOUR_CODE_HERE"
+        self.optimizer.step()
 
-        return {YOUR_CODE_HERE} # return dictionary for logging
+        return {'value_loss': loss.item()}  # return dictionary for logging
 
     def process(self, transition):
         "*** YOUR CODE HERE ***"
-        utils.raiseNotDefined()
 
         result = {}
         self.total_steps += 1
 
         # update replay buffer
-        "self.buffer.YOUR_CODE_HERE"
+        self.buffer.update(*transition)
 
         if self.total_steps > self.warmup_steps:
             result = self.learn()
 
         if self.total_steps % self.target_update_interval == 0:
             # update target network
-            "self.target_network.YOUR_CODE_HERE"
+            self.target_network.load_state_dict(self.network.state_dict())
 
         self.epsilon -= self.epsilon_decay
         return result
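The learn() method above implements the standard DQN temporal-difference update. Below is a minimal, self-contained version of that computation for reference, assuming batch tensors of shape (batch, 1) as produced by gather(1, action); the per-sample max over next-state Q-values (rather than a single max over the whole batch) and evaluating the target network without gradients are the two details that are easiest to get wrong. The linear stand-in networks and the values used here are illustrative, not the repository's.

# Minimal sketch of the DQN TD-target update; shapes and values are illustrative assumptions.
import torch
import torch.nn as nn

batch, n_actions, gamma = 4, 5, 0.99
q_net = nn.Linear(8, n_actions)          # stands in for self.network
target_net = nn.Linear(8, n_actions)     # stands in for self.target_network

state = torch.randn(batch, 8)
next_state = torch.randn(batch, 8)
action = torch.randint(0, n_actions, (batch, 1))
reward = torch.randn(batch, 1)
terminated = torch.rand(batch, 1) < 0.1

pred_q = q_net(state).gather(1, action)                               # Q(s, a), shape (batch, 1)
with torch.no_grad():                                                 # target side carries no gradient
    next_q = target_net(next_state).max(dim=1, keepdim=True).values   # max_a' Q_target(s', a')
td_target = torch.where(terminated, reward, reward + gamma * next_q)
loss = nn.MSELoss()(pred_q, td_target)
loss.backward()                                                       # gradients flow into q_net only
print({'value_loss': loss.item()})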