# NTU-AI-HW5/pacman.py
import os
import time
import argparse
from pathlib import Path
import numpy as np
import gymnasium as gym
import torch
import imageio
from tqdm import tqdm
from rl_algorithm import DQN
from custom_env import ImageEnv
from utils import seed_everything
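
# ImageEnv (custom_env) is assumed to grayscale/resize each frame to
# image_hw x image_hw and stack the most recent `stack_frames` frames, so
# observations have shape (stack_frames, image_hw, image_hw) = (4, 84, 84).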


def parse_args():
    parser = argparse.ArgumentParser()
    # environment hyperparameters
    parser.add_argument('--env_name', type=str, default='ALE/MsPacman-v5')
    # nargs=3 so `--state_dim 4 84 84` parses to three ints; argparse's
    # `type=tuple` would split the argument string into characters.
    parser.add_argument('--state_dim', type=int, nargs=3, default=(4, 84, 84))
    parser.add_argument('--image_hw', type=int, default=84, help='The height and width of the image')
    # also used as the frame-stack depth passed to ImageEnv
    parser.add_argument('--num_envs', type=int, default=4)
    # DQN hyperparameters
    parser.add_argument('--lr', type=float, default=1e-4)
    parser.add_argument('--epsilon', type=float, default=0.9)
    parser.add_argument('--epsilon_min', type=float, default=0.05)
    parser.add_argument('--gamma', type=float, default=0.99)
    parser.add_argument('--batch_size', type=int, default=64)
    parser.add_argument('--warmup_steps', type=int, default=5000)
    parser.add_argument('--buffer_size', type=int, default=int(1e5))
    parser.add_argument('--target_update_interval', type=int, default=10000)
    # training hyperparameters
    parser.add_argument('--max_steps', type=int, default=int(2.5e5))
    parser.add_argument('--eval_interval', type=int, default=10000)
    # others
    parser.add_argument('--save_root', type=Path, default='./submissions')
    # os.path.splitext drops the extension correctly; rstrip(".py") would
    # strip any trailing '.', 'p', or 'y' characters from the name instead.
    parser.add_argument("--exp-name", type=str, default=os.path.splitext(os.path.basename(__file__))[0],
                        help="the name of this experiment")
    # evaluation
    parser.add_argument('--eval', action="store_true", help='evaluate the model')
    parser.add_argument('--eval_model_path', type=str, default=None, help='the path of the model to evaluate')
    return parser.parse_args()


def validation(agent, num_evals=5):
    """Average the undiscounted return of `num_evals` evaluation episodes."""
    eval_env = gym.make(args.env_name)
    eval_env = ImageEnv(eval_env, stack_frames=args.num_envs, image_hw=args.image_hw)
    scores = 0
    for i in range(num_evals):
        (state, _), done = eval_env.reset(), False
        while not done:
            # One reasonable completion, mirroring the loop in evaluate():
            # act in evaluation mode and step the environment.
            # do action from your agent
            action = agent.act(state, training=False)
            # get your action feedback from environment
            next_state, reward, terminated, truncated, info = eval_env.step(action)
            state = next_state
            scores += reward
            done = terminated or truncated
    return np.round(scores / num_evals, 4)


def train(agent, env):
    history = {'Step': [], 'AvgScore': []}
    (state, _) = env.reset()
    for _ in tqdm(range(args.max_steps)):
        action = agent.act(state)
        next_state, reward, terminated, truncated, _ = env.step(action)
        # You can track q-losses over training from the `result` variable.
        result = agent.process((state, action, reward, next_state, terminated))
        state = next_state
        if terminated or truncated:
            state, _ = env.reset()
        if agent.total_steps % args.eval_interval == 0:
            avg_score = validation(agent)
            history['Step'].append(agent.total_steps)
            history['AvgScore'].append(avg_score)
            # log info to plot your figure; one option (assumption: a raw .npy
            # dump is acceptable) is to persist the history dict each time:
            np.save(save_dir / 'history.npy', history)
            # save model
            torch.save(agent.network.state_dict(), save_dir / 'pacman_dqn.pt')
            print("Step: {}, AvgScore: {}, ValueLoss: {}".format(
                agent.total_steps, avg_score, result["value_loss"]))


def evaluate(agent, eval_env, capture_frames=True):
    seed_everything(0, eval_env)  # don't modify
    # load the model from disk when no trained agent is passed in
    if agent is None:
        assert args.eval_model_path is not None, "--eval_model_path is required with --eval"
        action_dim = eval_env.action_space.n
        state_dim = (args.num_envs, args.image_hw, args.image_hw)
        agent = DQN(state_dim=state_dim, action_dim=action_dim)
        agent.network.load_state_dict(torch.load(args.eval_model_path))
    (state, _), done = eval_env.reset(), False
    scores = 0
    # Record the frames (writing .mp4 needs imageio's ffmpeg backend,
    # i.e. `pip install imageio[ffmpeg]`)
    if capture_frames:
        writer = imageio.get_writer(save_dir / 'mspacman.mp4', fps=10)
    while not done:
        if capture_frames:
            writer.append_data(eval_env.render())
        else:
            eval_env.render()
        action = agent.act(state, training=False)
        next_state, reward, terminated, truncated, info = eval_env.step(action)
        state = next_state
        scores += reward
        done = terminated or truncated
    if capture_frames:
        writer.close()
    print("The score of the agent:", scores)


def main():
    env = gym.make(args.env_name)
    env = ImageEnv(env, stack_frames=args.num_envs, image_hw=args.image_hw)
    action_dim = env.action_space.n
    state_dim = (args.num_envs, args.image_hw, args.image_hw)
    print("action_dim:", action_dim)
    print("state_dim:", state_dim)
    agent = DQN(state_dim=state_dim, action_dim=action_dim)
    # train
    train(agent, env)
    # evaluate and record a rollout video
    eval_env = gym.make(args.env_name, render_mode='rgb_array')
    eval_env = ImageEnv(eval_env, stack_frames=args.num_envs, image_hw=args.image_hw)
    evaluate(agent, eval_env)
if __name__ == "__main__":
args = parse_args()
# save_dir = args.save_root / f"{args.env_name.replace('/', '-')}__{args.exp_name}__{int(time.time())}"
save_dir = args.save_root
if not save_dir.exists():
save_dir.mkdir(parents=True)
if args.eval:
eval_env = gym.make(args.env_name, render_mode='rgb_array')
eval_env = ImageEnv(eval_env, stack_frames=args.num_envs, image_hw=args.image_hw)
evaluate(agent=None, eval_env=eval_env, capture_frames=False)
else:
main()
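
# Example invocations (sketch; assumes the ALE ROMs are available, e.g. via
# `pip install "gymnasium[atari,accept-rom-license]"`):
#   python pacman.py                                              # train, then record a rollout video
#   python pacman.py --eval --eval_model_path submissions/pacman_dqn.pt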