NTU-AI-HW5/utils.py

47 lines
1.2 KiB
Python

import inspect
import sys
import os
import torch
import random
import numpy as np
import matplotlib.pyplot as plt
def raiseNotDefined():
    """Report the calling function as unimplemented and terminate the program.

    Prints the caller's function name, line number and source file to stdout,
    then exits through ``sys.exit()``.

    Raises:
        SystemExit: always. No status is passed to ``sys.exit()``, so the
            exit code is ``None`` (the process ends with status 0).
    """
    # Walk the stack once rather than once per field. context=0 skips loading
    # source lines, which the message does not use. Index 0 is this function's
    # own frame; index 1 is the frame that called it.
    caller = inspect.stack(0)[1]
    print(
        f"*** Method not implemented: {caller.function} "
        f"at line {caller.lineno} of {caller.filename} ***"
    )
    sys.exit()
def seed_everything(seed, env):
    """Seed every random source this module knows about with the same value.

    Covers Python's ``random`` module, NumPy's global RNG, PyTorch (CPU and
    the current CUDA device), the ``PYTHONHASHSEED`` environment variable, and
    the supplied environment object. Also switches cuDNN to deterministic mode.

    Args:
        seed: value passed unchanged to each seeding function.
        env: object exposing a ``seed(seed)`` method (presumably a Gym-style
            environment; confirm against callers).
    """
    # Only affects interpreters launched after this point (e.g. subprocesses);
    # hash randomization of the running process is fixed at startup.
    os.environ['PYTHONHASHSEED'] = str(seed)

    # Each library keeps its own generator state, so seed them one by one.
    for seed_rng in (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
    ):
        seed_rng(seed)

    torch.backends.cudnn.deterministic = True

    # The environment is seeded last, after all library RNGs.
    env.seed(seed)
def plot(steps, avg_scores, value_loss, img_name='output.png'):
    """Save a dual-axis line chart of average scores and value loss over steps.

    Both curves share the x-axis (labelled 'Steps'). Average scores are drawn
    in red against the left y-axis; value loss is drawn in blue against a twin
    y-axis on the right. The figure is written to ``img_name`` and then closed.

    Args:
        steps: x values shared by both curves.
        avg_scores: y values for the left axis, plotted against ``steps``.
        value_loss: y values for the right axis, plotted against ``steps``.
        img_name: output path passed to ``Figure.savefig``.
    """
    fig, ax1 = plt.subplots()
    try:
        ax1.set_xlabel('Steps')
        # Second y-axis sharing ax1's x-axis, so the two series can use
        # independent vertical scales.
        ax2 = ax1.twinx()

        for axis, label, series, colour, opacity in (
            (ax1, 'AvgScores', avg_scores, 'tab:red', 0.75),
            (ax2, 'ValueLoss', value_loss, 'tab:blue', 1),
        ):
            axis.set_ylabel(label, color=colour)
            axis.plot(steps, series, color=colour, alpha=opacity)
            axis.tick_params(axis='y', labelcolor=colour)

        fig.tight_layout()
        fig.savefig(img_name)
    finally:
        # Close this specific figure even if plotting or saving fails, so
        # repeated calls do not accumulate open figures in pyplot's registry.
        plt.close(fig)
# Placeholder marker string. It is not referenced anywhere else in this module;
# presumably assignment skeleton code uses it to flag sections that still need
# an implementation (verify against the files that import it).
YOUR_CODE_HERE = "*** YOUR CODE HERE ***"