feat: complete DQN
parent 18b0cb34d6
commit 175b674f98
@@ -127,6 +127,8 @@ def main():
 
     action_dim = env.action_space.n
     state_dim = (args.num_envs, args.image_hw, args.image_hw)
+    print(action_dim)
+    print(state_dim)
     agent = DQN(state_dim=state_dim, action_dim=action_dim)
 
     # train
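For context, here is a minimal sketch of where action_dim and state_dim come from, assuming a Gymnasium Atari environment; the env id and the num_envs/image_hw values below are illustrative only, since the env construction and argparse setup sit outside this diff.

import gymnasium as gym

# Hypothetical setup mirroring main(); requires gymnasium plus ale-py.
env = gym.make("ALE/MsPacman-v5")   # env id is an assumption, not taken from this repo
action_dim = env.action_space.n     # discrete action count (typically 9 for Ms. Pac-Man)
state_dim = (1, 84, 84)             # stands in for (args.num_envs, args.image_hw, args.image_hw)
print(action_dim)                   # what the newly added prints would show
print(state_dim)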
@@ -151,4 +153,4 @@ if __name__ == "__main__":
         evaluate(agent=None, eval_env=eval_env, capture_frames=False)
     else:
         main()
 
@@ -11,15 +11,35 @@ class PacmanActionCNN(nn.Module):
         super(PacmanActionCNN, self).__init__()
         # build your own CNN model
         "*** YOUR CODE HERE ***"
-        utils.raiseNotDefined()
         # this is just an example, you can modify this.
-        self.conv1 = nn.Conv2d(state_dim, 16, kernel_size=8, stride=4)
+        self.conv1 = nn.Conv2d(state_dim, 16, kernel_size=3, stride=1, padding='same')
+        self.conv2 = nn.Conv2d(16, 16, kernel_size=3, stride=1, padding='same')
+        self.conv3 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding='same')
+        self.conv4 = nn.Conv2d(32, 32, kernel_size=3, stride=1, padding='same')
+        self.fc1 = nn.Linear(in_features=3200, out_features=512)
+        self.fc2 = nn.Linear(in_features=512, out_features=64)
+        self.fc3 = nn.Linear(in_features=64, out_features=action_dim)
+
+        self.pooling = nn.MaxPool2d(kernel_size=2, stride=2)
+        self.relu = nn.ReLU()
+        self.flatten = nn.Flatten()
 
     def forward(self, x):
-        x = F.relu(self.conv1(x))
 
         "*** YOUR CODE HERE ***"
-        utils.raiseNotDefined()
+        x = self.relu(self.conv1(x))
+        x = self.relu(self.conv2(x))
+        x = self.pooling(x)
+        x = self.relu(self.conv3(x))
+        x = self.pooling(x)
+        x = self.relu(self.conv4(x))
+        x = self.pooling(x)
+
+        x = self.flatten(x)
+
+        x = self.relu(self.fc1(x))
+        x = self.relu(self.fc2(x))
+        x = self.fc3(x)
 
         return x
 
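As a sanity check on the new architecture, here is a standalone sketch that runs the same conv/pool stack on a dummy input. The 4 input channels and the 84x84 resolution are assumptions (the diff itself only hard-codes in_features=3200); they are plausible because three 2x2 max-pools take 84 -> 42 -> 21 -> 10, and 32 * 10 * 10 = 3200 is exactly what fc1 expects.

import torch
import torch.nn as nn

# Dummy batch of one state with an assumed 4 stacked 84x84 frames.
x = torch.zeros(1, 4, 84, 84)

# Same layer order as PacmanActionCNN.forward above.
features = nn.Sequential(
    nn.Conv2d(4, 16, kernel_size=3, stride=1, padding='same'), nn.ReLU(),
    nn.Conv2d(16, 16, kernel_size=3, stride=1, padding='same'), nn.ReLU(),
    nn.MaxPool2d(kernel_size=2, stride=2),   # 84 -> 42
    nn.Conv2d(16, 32, kernel_size=3, stride=1, padding='same'), nn.ReLU(),
    nn.MaxPool2d(kernel_size=2, stride=2),   # 42 -> 21
    nn.Conv2d(32, 32, kernel_size=3, stride=1, padding='same'), nn.ReLU(),
    nn.MaxPool2d(kernel_size=2, stride=2),   # 21 -> 10
    nn.Flatten(),
)

print(features(x).shape)   # torch.Size([1, 3200]), matching fc1's in_features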
@@ -109,53 +129,56 @@ class DQN:
         x = torch.from_numpy(x).float().unsqueeze(0).to(self.device)
 
         "*** YOUR CODE HERE ***"
-        utils.raiseNotDefined()
+        # utils.raiseNotDefined()
         # get q-values from network
-        q_value = YOUR_CODE_HERE
+        q_value = self.network(x).squeeze()
         # get action with maximum q-value
-        action = YOUR_CODE_HERE
+        action = q_value.argmax().cpu()
 
         return action
 
     def learn(self):
         "*** YOUR CODE HERE ***"
-        utils.raiseNotDefined()
 
         # sample a mini-batch from replay buffer
         state, action, reward, next_state, terminated = map(lambda x: x.to(self.device), self.buffer.sample(self.batch_size))
+        action = action.to(torch.int64)
+        terminated = terminated.bool()
 
         # get q-values from network
-        next_q = YOUR_CODE_HERE
+        pred_q = self.network(state).gather(1, action)
+        next_q = self.target_network(next_state)
         # td_target: if terminated, only reward, otherwise reward + gamma * max(next_q)
-        td_target = YOUR_CODE_HERE
+        td_target = torch.where(terminated, reward, reward + self.gamma * next_q.max(dim=1, keepdim=True).values)
         # compute loss with td_target and q-values
-        loss = YOUR_CODE_HERE
+        criterion = nn.MSELoss()
+        loss = criterion(pred_q, td_target)
 
         # initialize optimizer
-        "self.optimizer.YOUR_CODE_HERE"
+        self.optimizer.zero_grad()
         # backpropagation
-        YOUR_CODE_HERE
+        loss.backward()
         # update network
-        "self.optimizer.YOUR_CODE_HERE"
+        self.optimizer.step()
 
-        return {YOUR_CODE_HERE} # return dictionary for logging
+        return {'value_loss': loss.item()} # return dictionary for logging
 
     def process(self, transition):
         "*** YOUR CODE HERE ***"
-        utils.raiseNotDefined()
 
         result = {}
         self.total_steps += 1
 
         # update replay buffer
         "self.buffer.YOUR_CODE_HERE"
+        self.buffer.update(*transition)
 
         if self.total_steps > self.warmup_steps:
             result = self.learn()
 
         if self.total_steps % self.target_update_interval == 0:
             # update target network
"self.target_network.YOUR_CODE_HERE"
|
self.target_network.load_state_dict(self.network.state_dict())
|
||||||
|
|
||||||
self.epsilon -= self.epsilon_decay
|
self.epsilon -= self.epsilon_decay
|
||||||
return result
|
return result
|
||||||
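The learn() update above is a one-step TD regression. The sketch below replays that computation on random tensors with assumed shapes (batch of 32, 9 actions, rewards and termination flags stored as column vectors, which is what the gather/where calls imply); names like batch_size and action_dim here are illustrative only.

import torch
import torch.nn.functional as F

batch_size, action_dim, gamma = 32, 9, 0.99

pred_all   = torch.randn(batch_size, action_dim)         # stands in for network(state)
action     = torch.randint(action_dim, (batch_size, 1))  # replayed actions, int64
reward     = torch.randn(batch_size, 1)
terminated = torch.randint(2, (batch_size, 1)).bool()
next_q     = torch.randn(batch_size, action_dim)         # stands in for target_network(next_state)

pred_q = pred_all.gather(1, action)                      # Q(s, a) for the action actually taken
# Per-sample bootstrap: max over the action dimension, replaced by the bare
# reward wherever the episode terminated.
td_target = torch.where(
    terminated, reward,
    reward + gamma * next_q.max(dim=1, keepdim=True).values,
)
loss = F.mse_loss(pred_q, td_target)
print(pred_q.shape, td_target.shape, loss.item())        # torch.Size([32, 1]) twice, plus a scalar loss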