adversarial_VLNDUET/pretrain_src/utils/save.py
Shizhe Chen 747cf0587b init
2021-11-24 13:29:08 +01:00

47 lines
1.6 KiB
Python

"""
Copyright (c) Microsoft Corporation.
Licensed under the MIT license.
saving utilities
"""
import json
import os
import torch
def save_training_meta(args):
    """Create the run's output folders and snapshot its configuration.

    Creates ``<output_dir>/logs`` and ``<output_dir>/ckpts`` (no error if they
    already exist), then writes two JSON files into ``logs``:

    * ``training_args.json`` -- the attributes of ``args`` (via ``vars``).
    * ``model_config.json`` -- a re-serialized copy of the JSON file found at
      ``args.model_config``.

    Args:
        args: Namespace-like object exposing ``output_dir`` and
            ``model_config`` attributes. Every attribute must be
            JSON-serializable because the whole namespace is dumped.
    """
    log_dir = os.path.join(args.output_dir, 'logs')
    os.makedirs(log_dir, exist_ok=True)
    os.makedirs(os.path.join(args.output_dir, 'ckpts'), exist_ok=True)

    with open(os.path.join(log_dir, 'training_args.json'), 'w') as writer:
        json.dump(vars(args), writer, indent=4)

    # Read the model config inside a context manager so the file handle is
    # closed deterministically instead of being left to the garbage collector.
    with open(args.model_config) as reader:
        model_config = json.load(reader)

    with open(os.path.join(log_dir, 'model_config.json'), 'w') as writer:
        json.dump(model_config, writer, indent=4)
class ModelSaver(object):
    """Serialize model weights, and optionally optimizer state, to a folder.

    Weight files are named ``{prefix}_{step}.{suffix}``; optimizer snapshots
    are named ``train_state_{step}.pt``. Both are written to ``output_dir``.
    """

    def __init__(self, output_dir, prefix='model_step', suffix='pt'):
        """Remember the destination folder and the weight-file naming parts."""
        self.output_dir = output_dir
        self.prefix = prefix
        self.suffix = suffix

    def save(self, model, step, optimizer=None):
        """Write a checkpoint for ``step``.

        Args:
            model: Object whose ``state_dict()`` is saved. A leading
                ``'module.'`` on any key is removed, and tensor values are
                moved to CPU before saving.
            step: Value embedded in the output file names and stored in the
                optimizer snapshot.
            optimizer: When given, ``{'step': step, 'optimizer':
                optimizer.state_dict()}`` is saved next to the weights.
        """
        # Presumably added by DataParallel-style wrappers -- TODO confirm.
        wrapper_prefix = 'module.'

        weights = {
            (name[len(wrapper_prefix):]
             if name.startswith(wrapper_prefix) else name):
            (value.cpu() if isinstance(value, torch.Tensor) else value)
            for name, value in model.state_dict().items()
        }
        weights_filename = f"{self.prefix}_{step}.{self.suffix}"
        torch.save(weights, os.path.join(self.output_dir, weights_filename))

        if optimizer is None:
            return

        train_state = {'step': step, 'optimizer': optimizer.state_dict()}
        if hasattr(optimizer, '_amp_stash'):
            pass  # TODO fp16 optimizer
        torch.save(train_state, f'{self.output_dir}/train_state_{step}.pt')