fix: change model & hyperparameters

Ting-Jun Wang 2024-05-24 03:04:10 +08:00
parent 175b674f98
commit 970809f3fb
Signed by: snsd0805
GPG Key ID: 48D331A3D6160354
3 changed files with 54 additions and 27 deletions

View File

@@ -11,7 +11,7 @@ from tqdm import tqdm
 from rl_algorithm import DQN
 from custom_env import ImageEnv
-from utils import seed_everything, YOUR_CODE_HERE
+from utils import seed_everything, YOUR_CODE_HERE, plot
 import utils
 def parse_args():
@@ -22,17 +22,17 @@ def parse_args():
     parser.add_argument('--image_hw', type=int, default=84, help='The height and width of the image')
     parser.add_argument('--num_envs', type=int, default=4)
     # DQN hyperparameters
-    parser.add_argument('--lr', type=float, default=1e-4)
+    parser.add_argument('--lr', type=float, default=1e-3)
     parser.add_argument('--epsilon', type=float, default=0.9)
     parser.add_argument('--epsilon_min', type=float, default=0.05)
     parser.add_argument('--gamma', type=float, default=0.99)
     parser.add_argument('--batch_size', type=int, default=64)
-    parser.add_argument('--warmup_steps', type=int, default=5000)
+    parser.add_argument('--warmup_steps', type=int, default=1000)
     parser.add_argument('--buffer_size', type=int, default=int(1e5))
     parser.add_argument('--target_update_interval', type=int, default=10000)
     # training hyperparameters
-    parser.add_argument('--max_steps', type=int, default=int(2.5e5))
-    parser.add_argument('--eval_interval', type=int, default=10000)
+    parser.add_argument('--max_steps', type=int, default=int(2e5))
+    parser.add_argument('--eval_interval', type=int, default=5000)
     # others
     parser.add_argument('--save_root', type=Path, default='./submissions')
     parser.add_argument("--exp-name", type=str, default=os.path.basename(__file__).rstrip(".py"),
@@ -51,11 +51,10 @@ def validation(agent, num_evals=5):
         (state, _), done = eval_env.reset(), False
         while not done:
             "*** YOUR CODE HERE ***"
-            utils.raiseNotDefined()
             # do action from your agent
-            action = YOUR_CODE_HERE
+            action = agent.act(state, training=False)
             # get your action feedback from environment
-            next_state, reward, terminated, truncated, info = YOUR_CODE_HERE
+            next_state, reward, terminated, truncated, info = eval_env.step(action)
             state = next_state
             scores += reward
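A minimal sketch of one greedy evaluation episode under the Gymnasium step API, assuming agent.act(state, training=False) returns the argmax-Q action and that the surrounding lines (not shown in this hunk) set done = terminated or truncated:

    (state, _), done, score = eval_env.reset(), False, 0.0
    while not done:
        action = agent.act(state, training=False)    # greedy, no exploration
        state, reward, terminated, truncated, info = eval_env.step(action)
        done = terminated or truncated               # Gymnasium's two end flags
        score += reward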
@@ -63,7 +62,7 @@ def validation(agent, num_evals=5):
     return np.round(scores / num_evals, 4)
 def train(agent, env):
-    history = {'Step': [], 'AvgScore': []}
+    history = {'Step': [], 'AvgScore': [], 'value_loss': []}
     (state, _) = env.reset()
@@ -81,9 +80,11 @@ def train(agent, env):
             avg_score = validation(agent)
             history['Step'].append(agent.total_steps)
             history['AvgScore'].append(avg_score)
+            history['value_loss'].append(result['value_loss'])
             # log info to plot your figure
             "*** YOUR CODE HERE ***"
+            plot(history['Step'], history['AvgScore'], history['value_loss'], 'output.png')
             # save model
             torch.save(agent.network.state_dict(), save_dir / 'pacma_dqn.pt')
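The new history key assumes that `result` is the dict returned by the agent's update step earlier in train(); the DQN class below does end its update with `return result`. A sketch of the assumed contract (the method name here is a hypothetical placeholder, not shown in this diff):

    result = agent.update()  # hypothetical name; returns e.g. {'value_loss': 0.012}
    history['value_loss'].append(result['value_loss'])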
@@ -129,7 +130,17 @@ def main():
     state_dim = (args.num_envs, args.image_hw, args.image_hw)
     print(action_dim)
     print(state_dim)
-    agent = DQN(state_dim=state_dim, action_dim=action_dim)
+    agent = DQN(state_dim=state_dim, action_dim=action_dim,
+                lr=args.lr,
+                epsilon=args.epsilon,
+                epsilon_min=args.epsilon_min,
+                gamma=args.gamma,
+                batch_size=args.batch_size,
+                warmup_steps=args.warmup_steps,
+                buffer_size=int(args.buffer_size),
+                target_update_interval=args.target_update_interval,
+                )
+    print(agent)
     # train
     train(agent, env)
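The explicit keyword wiring above could also be written more compactly; a sketch, assuming DQN's keyword argument names match the CLI flag names exactly (as they do in the call above):

    hp_keys = ['lr', 'epsilon', 'epsilon_min', 'gamma', 'batch_size',
               'warmup_steps', 'buffer_size', 'target_update_interval']
    agent = DQN(state_dim=state_dim, action_dim=action_dim,
                **{k: getattr(args, k) for k in hp_keys})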

View File

@@ -12,15 +12,12 @@ class PacmanActionCNN(nn.Module):
         # build your own CNN model
         "*** YOUR CODE HERE ***"
         # this is just an example, you can modify this.
-        self.conv1 = nn.Conv2d(state_dim, 16, kernel_size=3, stride=1, padding='same')
-        self.conv2 = nn.Conv2d(16, 16, kernel_size=3, stride=1, padding='same')
-        self.conv3 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding='same')
-        self.conv4 = nn.Conv2d(32, 32, kernel_size=3, stride=1, padding='same')
-        self.fc1 = nn.Linear(in_features=3200, out_features=512)
-        self.fc2 = nn.Linear(in_features=512, out_features=64)
-        self.fc3 = nn.Linear(in_features=64, out_features=action_dim)
-        self.pooling = nn.MaxPool2d(kernel_size=2, stride=2)
+        self.conv1 = nn.Conv2d(state_dim, 16, kernel_size=3, stride=1)
+        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1)
+        self.conv3 = nn.Conv2d(32, 64, kernel_size=3, stride=1)
+        self.fc1 = nn.Linear(in_features=3136, out_features=action_dim)
+        self.pooling = nn.MaxPool2d(kernel_size=3, stride=2)
         self.relu = nn.ReLU()
         self.flatten = nn.Flatten()
@@ -28,18 +25,15 @@ class PacmanActionCNN(nn.Module):
         "*** YOUR CODE HERE ***"
         x = self.relu(self.conv1(x))
-        x = self.pooling(x)
         x = self.relu(self.conv2(x))
         x = self.pooling(x)
         x = self.relu(self.conv3(x))
         x = self.pooling(x)
-        x = self.relu(self.conv4(x))
-        x = self.pooling(x)
         x = self.flatten(x)
-        x = self.relu(self.fc1(x))
-        x = self.relu(self.fc2(x))
-        x = self.fc3(x)
+        x = self.fc1(x)
         return x
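One thing worth double-checking in the new architecture: with the default --image_hw of 84 and no padding, this conv/pool stack does not appear to produce 3136 flattened features. A self-contained sketch to verify (4 input channels is an assumption; the channel count does not change the spatial math):

    import torch
    import torch.nn as nn

    stack = nn.Sequential(
        nn.Conv2d(4, 16, kernel_size=3, stride=1), nn.ReLU(),   # 84 -> 82
        nn.Conv2d(16, 32, kernel_size=3, stride=1), nn.ReLU(),  # 82 -> 80
        nn.MaxPool2d(kernel_size=3, stride=2),                  # 80 -> 39
        nn.Conv2d(32, 64, kernel_size=3, stride=1), nn.ReLU(),  # 39 -> 37
        nn.MaxPool2d(kernel_size=3, stride=2),                  # 37 -> 18
        nn.Flatten(),
    )
    with torch.no_grad():
        print(stack(torch.zeros(1, 4, 84, 84)).shape[1])  # 20736 = 64 * 18 * 18

3136 (= 64 * 7 * 7) is the figure the classic DQN stack (8/4, 4/2, 3/1 convs, no pooling) produces for 84x84 inputs, so in_features=3136 may need adjusting to match this model.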
@@ -128,6 +122,7 @@ class DQN:
         # output actions by following epsilon-greedy policy
         x = torch.from_numpy(x).float().unsqueeze(0).to(self.device)
         "*** YOUR CODE HERE ***"
         # utils.raiseNotDefined()
         # get q-values from network
@@ -151,7 +146,7 @@ class DQN:
         # td_target: if terminated, only reward, otherwise reward + gamma * max(next_q)
         td_target = torch.where(terminated, reward, reward + self.gamma * next_q.max())
         # compute loss with td_target and q-values
-        criterion = nn.MSELoss()
+        criterion = nn.SmoothL1Loss()
         loss = criterion(pred_q, td_target)
         # initialize optimizer
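Two notes on this hunk. SmoothL1Loss is the Huber loss: quadratic for small TD errors, linear for large ones, so outlier targets produce bounded gradients (the usual DQN choice). Separately, in the unchanged context line above, next_q.max() reduces over the entire batch; a per-transition target usually wants the max over the action dimension. A small self-contained sketch of the difference:

    import torch

    gamma = 0.99
    reward = torch.tensor([1.0, 0.0])
    terminated = torch.tensor([False, True])
    next_q = torch.tensor([[0.5, 2.0],    # transition 0's Q-values per action
                           [1.0, 3.0]])   # transition 1's

    # Global max (what next_q.max() computes): one scalar, 3.0, for every row.
    global_target = torch.where(terminated, reward, reward + gamma * next_q.max())
    # Per-transition max over actions: the usual Q-learning target.
    per_row = torch.where(terminated, reward,
                          reward + gamma * next_q.max(dim=1).values)
    print(global_target)  # tensor([3.9700, 0.0000])
    print(per_row)        # tensor([2.9800, 0.0000])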
@@ -180,5 +175,6 @@ class DQN:
             # update target network
             self.target_network.load_state_dict(self.network.state_dict())
-            self.epsilon -= self.epsilon_decay
+            # self.epsilon -= self.epsilon_decay
+            self.epsilon *= 0.95
         return result
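The exploration schedule changes from linear decay (subtracting epsilon_decay) to multiplicative decay, applied on each target-network update. A minimal sketch, assuming --epsilon_min (default 0.05) is still meant to act as a floor (no floor is visible in this hunk):

    epsilon, epsilon_min = 0.9, 0.05
    for _ in range(100):                           # one step per target-network update
        epsilon = max(epsilon * 0.95, epsilon_min)
    print(epsilon)  # 0.05: hits the floor after ~57 updates (0.9 * 0.95**57 < 0.05)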

View File

@@ -6,6 +6,8 @@ import torch
 import random
 import numpy as np
+import matplotlib.pyplot as plt
 def raiseNotDefined():
     filename = inspect.stack()[1][1]
     line = inspect.stack()[1][2]
@@ -23,4 +25,22 @@ def seed_everything(seed, env):
     torch.backends.cudnn.deterministic = True
     env.seed(seed)
+def plot(steps, avg_scores, value_loss, img_name='output.png'):
+    fig, ax1 = plt.subplots()
+    plt.xlabel('Steps')
+    ax2 = ax1.twinx()
+    ax1.set_ylabel('AvgScores', color='tab:red')
+    ax1.plot(steps, avg_scores, color='tab:red', alpha=0.75)
+    ax1.tick_params(axis='y', labelcolor='tab:red')
+    ax2.set_ylabel('ValueLoss', color='tab:blue')
+    ax2.plot(steps, value_loss, color='tab:blue', alpha=1)
+    ax2.tick_params(axis='y', labelcolor='tab:blue')
+    fig.tight_layout()
+    plt.savefig(img_name)
+    plt.close()
 YOUR_CODE_HERE = "*** YOUR CODE HERE ***"
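For reference, a hypothetical usage of the new plot() helper with made-up numbers:

    from utils import plot

    steps = [5000, 10000, 15000]
    avg_scores = [120.0, 210.5, 305.2]
    value_loss = [0.80, 0.45, 0.30]
    plot(steps, avg_scores, value_loss, img_name='output.png')
    # writes a dual-axis figure: AvgScores on the left y-axis (red),
    # ValueLoss on the right (blue), sharing the Steps x-axis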