Compare commits
No commits in common. "d0a3d80b777073b5ec5bf7e0bea179054661c45a" and "175b674f98a18e2ca1d42a93c111c9bff5023010" have entirely different histories.
d0a3d80b77
...
175b674f98
31
pacman.py
31
pacman.py
@ -11,7 +11,7 @@ from tqdm import tqdm
|
|||||||
|
|
||||||
from rl_algorithm import DQN
|
from rl_algorithm import DQN
|
||||||
from custom_env import ImageEnv
|
from custom_env import ImageEnv
|
||||||
from utils import seed_everything, YOUR_CODE_HERE, plot
|
from utils import seed_everything, YOUR_CODE_HERE
|
||||||
import utils
|
import utils
|
||||||
|
|
||||||
def parse_args():
|
def parse_args():
|
||||||
@ -22,17 +22,17 @@ def parse_args():
|
|||||||
parser.add_argument('--image_hw', type=int, default=84, help='The height and width of the image')
|
parser.add_argument('--image_hw', type=int, default=84, help='The height and width of the image')
|
||||||
parser.add_argument('--num_envs', type=int, default=4)
|
parser.add_argument('--num_envs', type=int, default=4)
|
||||||
# DQN hyperparameters
|
# DQN hyperparameters
|
||||||
parser.add_argument('--lr', type=float, default=1e-3)
|
parser.add_argument('--lr', type=float, default=1e-4)
|
||||||
parser.add_argument('--epsilon', type=float, default=0.9)
|
parser.add_argument('--epsilon', type=float, default=0.9)
|
||||||
parser.add_argument('--epsilon_min', type=float, default=0.05)
|
parser.add_argument('--epsilon_min', type=float, default=0.05)
|
||||||
parser.add_argument('--gamma', type=float, default=0.99)
|
parser.add_argument('--gamma', type=float, default=0.99)
|
||||||
parser.add_argument('--batch_size', type=int, default=64)
|
parser.add_argument('--batch_size', type=int, default=64)
|
||||||
parser.add_argument('--warmup_steps', type=int, default=1000)
|
parser.add_argument('--warmup_steps', type=int, default=5000)
|
||||||
parser.add_argument('--buffer_size', type=int, default=int(1e5))
|
parser.add_argument('--buffer_size', type=int, default=int(1e5))
|
||||||
parser.add_argument('--target_update_interval', type=int, default=10000)
|
parser.add_argument('--target_update_interval', type=int, default=10000)
|
||||||
# training hyperparameters
|
# training hyperparameters
|
||||||
parser.add_argument('--max_steps', type=int, default=int(2e5))
|
parser.add_argument('--max_steps', type=int, default=int(2.5e5))
|
||||||
parser.add_argument('--eval_interval', type=int, default=5000)
|
parser.add_argument('--eval_interval', type=int, default=10000)
|
||||||
# others
|
# others
|
||||||
parser.add_argument('--save_root', type=Path, default='./submissions')
|
parser.add_argument('--save_root', type=Path, default='./submissions')
|
||||||
parser.add_argument("--exp-name", type=str, default=os.path.basename(__file__).rstrip(".py"),
|
parser.add_argument("--exp-name", type=str, default=os.path.basename(__file__).rstrip(".py"),
|
||||||
@ -51,10 +51,11 @@ def validation(agent, num_evals=5):
|
|||||||
(state, _), done = eval_env.reset(), False
|
(state, _), done = eval_env.reset(), False
|
||||||
while not done:
|
while not done:
|
||||||
"*** YOUR CODE HERE ***"
|
"*** YOUR CODE HERE ***"
|
||||||
|
utils.raiseNotDefined()
|
||||||
# do action from your agent
|
# do action from your agent
|
||||||
action = agent.act(state, training=False)
|
action = YOUR_CODE_HERE
|
||||||
# get your action feedback from environment
|
# get your action feedback from environment
|
||||||
next_state, reward, terminated, truncated, info = eval_env.step(action)
|
next_state, reward, terminated, truncated, info = YOUR_CODE_HERE
|
||||||
|
|
||||||
state = next_state
|
state = next_state
|
||||||
scores += reward
|
scores += reward
|
||||||
@ -62,7 +63,7 @@ def validation(agent, num_evals=5):
|
|||||||
return np.round(scores / num_evals, 4)
|
return np.round(scores / num_evals, 4)
|
||||||
|
|
||||||
def train(agent, env):
|
def train(agent, env):
|
||||||
history = {'Step': [], 'AvgScore': [], 'value_loss': []}
|
history = {'Step': [], 'AvgScore': []}
|
||||||
|
|
||||||
(state, _) = env.reset()
|
(state, _) = env.reset()
|
||||||
|
|
||||||
@ -80,11 +81,9 @@ def train(agent, env):
|
|||||||
avg_score = validation(agent)
|
avg_score = validation(agent)
|
||||||
history['Step'].append(agent.total_steps)
|
history['Step'].append(agent.total_steps)
|
||||||
history['AvgScore'].append(avg_score)
|
history['AvgScore'].append(avg_score)
|
||||||
history['value_loss'].append(result['value_loss'])
|
|
||||||
|
|
||||||
# log info to plot your figure
|
# log info to plot your figure
|
||||||
"*** YOUR CODE HERE ***"
|
"*** YOUR CODE HERE ***"
|
||||||
plot(history['Step'], history['AvgScore'], history['value_loss'], 'output.png')
|
|
||||||
|
|
||||||
# save model
|
# save model
|
||||||
torch.save(agent.network.state_dict(), save_dir / 'pacma_dqn.pt')
|
torch.save(agent.network.state_dict(), save_dir / 'pacma_dqn.pt')
|
||||||
@ -130,17 +129,7 @@ def main():
|
|||||||
state_dim = (args.num_envs, args.image_hw, args.image_hw)
|
state_dim = (args.num_envs, args.image_hw, args.image_hw)
|
||||||
print(action_dim)
|
print(action_dim)
|
||||||
print(state_dim)
|
print(state_dim)
|
||||||
agent = DQN(state_dim=state_dim, action_dim=action_dim,
|
agent = DQN(state_dim=state_dim, action_dim=action_dim)
|
||||||
lr=args.lr,
|
|
||||||
epsilon=args.epsilon,
|
|
||||||
epsilon_min=args.epsilon_min,
|
|
||||||
gamma=args.gamma,
|
|
||||||
batch_size=args.batch_size,
|
|
||||||
warmup_steps=args.warmup_steps,
|
|
||||||
buffer_size=int(args.buffer_size),
|
|
||||||
target_update_interval=args.target_update_interval,
|
|
||||||
)
|
|
||||||
print(agent)
|
|
||||||
|
|
||||||
# train
|
# train
|
||||||
train(agent, env)
|
train(agent, env)
|
||||||
|
|||||||
@ -12,12 +12,15 @@ class PacmanActionCNN(nn.Module):
|
|||||||
# build your own CNN model
|
# build your own CNN model
|
||||||
"*** YOUR CODE HERE ***"
|
"*** YOUR CODE HERE ***"
|
||||||
# this is just an example, you can modify this.
|
# this is just an example, you can modify this.
|
||||||
self.conv1 = nn.Conv2d(state_dim, 16, kernel_size=3, stride=1)
|
self.conv1 = nn.Conv2d(state_dim, 16, kernel_size=3, stride=1, padding='same')
|
||||||
self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1)
|
self.conv2 = nn.Conv2d(16, 16, kernel_size=3, stride=1, padding='same')
|
||||||
self.conv3 = nn.Conv2d(32, 64, kernel_size=3, stride=1)
|
self.conv3 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding='same')
|
||||||
self.fc1 = nn.Linear(in_features=3136, out_features=action_dim)
|
self.conv4 = nn.Conv2d(32, 32, kernel_size=3, stride=1, padding='same')
|
||||||
|
self.fc1 = nn.Linear(in_features=3200, out_features=512)
|
||||||
|
self.fc2 = nn.Linear(in_features=512, out_features=64)
|
||||||
|
self.fc3 = nn.Linear(in_features=64, out_features=action_dim)
|
||||||
|
|
||||||
self.pooling = nn.MaxPool2d(kernel_size=3, stride=2)
|
self.pooling = nn.MaxPool2d(kernel_size=2, stride=2)
|
||||||
self.relu = nn.ReLU()
|
self.relu = nn.ReLU()
|
||||||
self.flatten = nn.Flatten()
|
self.flatten = nn.Flatten()
|
||||||
|
|
||||||
@ -25,15 +28,18 @@ class PacmanActionCNN(nn.Module):
|
|||||||
|
|
||||||
"*** YOUR CODE HERE ***"
|
"*** YOUR CODE HERE ***"
|
||||||
x = self.relu(self.conv1(x))
|
x = self.relu(self.conv1(x))
|
||||||
x = self.pooling(x)
|
|
||||||
x = self.relu(self.conv2(x))
|
x = self.relu(self.conv2(x))
|
||||||
x = self.pooling(x)
|
x = self.pooling(x)
|
||||||
x = self.relu(self.conv3(x))
|
x = self.relu(self.conv3(x))
|
||||||
x = self.pooling(x)
|
x = self.pooling(x)
|
||||||
|
x = self.relu(self.conv4(x))
|
||||||
|
x = self.pooling(x)
|
||||||
|
|
||||||
x = self.flatten(x)
|
x = self.flatten(x)
|
||||||
|
|
||||||
x = self.fc1(x)
|
x = self.relu(self.fc1(x))
|
||||||
|
x = self.relu(self.fc2(x))
|
||||||
|
x = self.fc3(x)
|
||||||
|
|
||||||
return x
|
return x
|
||||||
|
|
||||||
@ -121,7 +127,6 @@ class DQN:
|
|||||||
else:
|
else:
|
||||||
# output actions by following epsilon-greedy policy
|
# output actions by following epsilon-greedy policy
|
||||||
x = torch.from_numpy(x).float().unsqueeze(0).to(self.device)
|
x = torch.from_numpy(x).float().unsqueeze(0).to(self.device)
|
||||||
|
|
||||||
|
|
||||||
"*** YOUR CODE HERE ***"
|
"*** YOUR CODE HERE ***"
|
||||||
# utils.raiseNotDefined()
|
# utils.raiseNotDefined()
|
||||||
@ -146,7 +151,7 @@ class DQN:
|
|||||||
# td_target: if terminated, only reward, otherwise reward + gamma * max(next_q)
|
# td_target: if terminated, only reward, otherwise reward + gamma * max(next_q)
|
||||||
td_target = torch.where(terminated, reward, reward + self.gamma * next_q.max())
|
td_target = torch.where(terminated, reward, reward + self.gamma * next_q.max())
|
||||||
# compute loss with td_target and q-values
|
# compute loss with td_target and q-values
|
||||||
criterion = nn.SmoothL1Loss()
|
criterion = nn.MSELoss()
|
||||||
loss = criterion(pred_q, td_target)
|
loss = criterion(pred_q, td_target)
|
||||||
|
|
||||||
# initialize optimizer
|
# initialize optimizer
|
||||||
@ -175,6 +180,5 @@ class DQN:
|
|||||||
# update target networ
|
# update target networ
|
||||||
self.target_network.load_state_dict(self.network.state_dict())
|
self.target_network.load_state_dict(self.network.state_dict())
|
||||||
|
|
||||||
# self.epsilon -= self.epsilon_decay
|
self.epsilon -= self.epsilon_decay
|
||||||
self.epsilon *= 0.95
|
|
||||||
return result
|
return result
|
||||||
|
|||||||
Binary file not shown.
22
utils.py
22
utils.py
@ -6,8 +6,6 @@ import torch
|
|||||||
import random
|
import random
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
import matplotlib.pyplot as plt
|
|
||||||
|
|
||||||
def raiseNotDefined():
|
def raiseNotDefined():
|
||||||
filename = inspect.stack()[1][1]
|
filename = inspect.stack()[1][1]
|
||||||
line = inspect.stack()[1][2]
|
line = inspect.stack()[1][2]
|
||||||
@ -24,23 +22,5 @@ def seed_everything(seed, env):
|
|||||||
torch.cuda.manual_seed(seed)
|
torch.cuda.manual_seed(seed)
|
||||||
torch.backends.cudnn.deterministic = True
|
torch.backends.cudnn.deterministic = True
|
||||||
env.seed(seed)
|
env.seed(seed)
|
||||||
|
|
||||||
def plot(steps, avg_scores, value_loss, img_name='output.png'):
|
|
||||||
fig, ax1 = plt.subplots()
|
|
||||||
plt.xlabel('Steps')
|
|
||||||
ax2 = ax1.twinx()
|
|
||||||
|
|
||||||
ax1.set_ylabel('AvgScores', color='tab:red')
|
|
||||||
ax1.plot(steps, avg_scores, color='tab:red', alpha=0.75)
|
|
||||||
ax1.tick_params(axis='y', labelcolor='tab:red')
|
|
||||||
|
|
||||||
ax2.set_ylabel('ValueLoss', color='tab:blue')
|
|
||||||
ax2.plot(steps, value_loss, color='tab:blue', alpha=1)
|
|
||||||
ax2.tick_params(axis='y', labelcolor='tab:blue')
|
|
||||||
|
|
||||||
fig.tight_layout()
|
|
||||||
plt.savefig(img_name)
|
|
||||||
plt.close()
|
|
||||||
|
|
||||||
YOUR_CODE_HERE = "*** YOUR CODE HERE ***"
|
YOUR_CODE_HERE = "*** YOUR CODE HERE ***"
|
||||||
|
|
||||||
Loading…
Reference in New Issue
Block a user