47 lines
1.2 KiB
Python
47 lines
1.2 KiB
Python
import inspect
|
|
import sys
|
|
import os
|
|
|
|
import torch
|
|
import random
|
|
import numpy as np
|
|
|
|
import matplotlib.pyplot as plt
|
|
|
|
def raiseNotDefined():
    """Report the calling function as unimplemented and terminate the program.

    Prints the name, line number and source file of the function that called
    this helper, then exits the interpreter with a non-zero status so that
    shells, CI jobs and test runners see the abort as a failure rather than
    a successful run.

    Raises:
        SystemExit: always, with exit code 1.
    """
    # A single stack walk is enough; context=0 skips loading source-code
    # context lines for every frame, which only the caller's frame needs anyway.
    caller = inspect.stack(context=0)[1]
    filename = caller.filename
    line = caller.lineno
    method = caller.function

    print(f"*** Method not implemented: {method} at line {line} of {filename} ***")
    # Status 1 signals failure; a bare sys.exit() would report success (0).
    sys.exit(1)
|
|
|
|
def seed_everything(seed, env=None):
    """Seed every random number generator used in training for reproducible runs.

    Seeds Python's ``random`` module, NumPy's global RNG and PyTorch (CPU and
    every CUDA device). It also switches cuDNN to deterministic algorithm
    selection and seeds ``env`` when one is supplied.

    Args:
        seed: Integer seed applied to every generator.
        env: Optional environment object.
            - If it exposes a callable ``seed`` method (older Gym API), that
              method is called.
            - Otherwise ``env.reset(seed=seed)`` is used, which is the
              replacement in newer Gym/Gymnasium releases.
            - ``None`` skips environment seeding entirely.
    """
    random.seed(seed)
    np.random.seed(seed)
    # Only affects Python processes started after this call (e.g. spawned
    # workers); the running interpreter's str-hash seed is fixed at startup.
    os.environ['PYTHONHASHSEED'] = str(seed)
    torch.manual_seed(seed)
    # manual_seed_all covers every visible GPU, not just the current one, and
    # is silently ignored when CUDA is unavailable.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # The cuDNN autotuner chooses algorithms by timing them, which can differ
    # between runs; disable it so the deterministic flag is effective.
    torch.backends.cudnn.benchmark = False

    if env is None:
        return
    if callable(getattr(env, 'seed', None)):
        env.seed(seed)
    else:
        # NOTE(review): reset() also starts a fresh episode; callers on the
        # newer API typically reset right after seeding anyway.
        env.reset(seed=seed)
|
|
|
|
def plot(steps, avg_scores, value_loss, img_name='output.png'):
    """Save a dual-y-axis line chart of training progress to ``img_name``.

    The left y-axis (red) shows ``avg_scores`` and the right y-axis (blue)
    shows ``value_loss``. Both series are drawn against the shared ``steps``
    x-axis. The figure is closed after saving so repeated calls do not
    accumulate open figures.
    """
    fig, score_ax = plt.subplots()
    score_ax.set_xlabel('Steps')
    loss_ax = score_ax.twinx()

    # (axes, y-label, data, colour, line alpha) for each series, in draw order.
    series = (
        (score_ax, 'AvgScores', avg_scores, 'tab:red', 0.75),
        (loss_ax, 'ValueLoss', value_loss, 'tab:blue', 1),
    )
    for ax, label, values, colour, alpha in series:
        ax.set_ylabel(label, color=colour)
        ax.plot(steps, values, color=colour, alpha=alpha)
        ax.tick_params(axis='y', labelcolor=colour)

    fig.tight_layout()
    fig.savefig(img_name)
    plt.close(fig)
|
|
|
|
# Placeholder marker string for stubbed-out sections; presumably referenced by
# exercise/template code elsewhere to flag unfinished work — confirm at call sites.
YOUR_CODE_HERE: str = "*** YOUR CODE HERE ***"
|
|
|