"""
|
|
Copyright (c) Microsoft Corporation.
|
|
Licensed under the MIT license.
|
|
|
|
saving utilities
|
|
"""
|
|
import json
|
|
import os
|
|
import torch
|
|
|
|
|
|
def save_training_meta(args):
    """Dump the training args and model config as JSON under `output_dir/logs`."""
    os.makedirs(os.path.join(args.output_dir, 'logs'), exist_ok=True)
    os.makedirs(os.path.join(args.output_dir, 'ckpts'), exist_ok=True)

    with open(os.path.join(args.output_dir, 'logs',
                           'training_args.json'), 'w') as writer:
        json.dump(vars(args), writer, indent=4)
    # re-serialize the model config next to the training args so the
    # run can be reproduced from the log directory alone
    with open(args.model_config) as reader:
        model_config = json.load(reader)
    with open(os.path.join(args.output_dir, 'logs',
                           'model_config.json'), 'w') as writer:
        json.dump(model_config, writer, indent=4)


class ModelSaver(object):
    """Writes model (and optionally optimizer) checkpoints for a given step."""

    def __init__(self, output_dir, prefix='model_step', suffix='pt'):
        self.output_dir = output_dir
        self.prefix = prefix
        self.suffix = suffix

    def save(self, model, step, optimizer=None):
        output_model_file = os.path.join(
            self.output_dir, f"{self.prefix}_{step}.{self.suffix}")
        state_dict = {}
        for k, v in model.state_dict().items():
            # strip the 'module.' prefix added by (Distributed)DataParallel
            if k.startswith('module.'):
                k = k[7:]
            # move tensors to CPU so the checkpoint loads on CPU-only hosts
            if isinstance(v, torch.Tensor):
                state_dict[k] = v.cpu()
            else:
                state_dict[k] = v
        torch.save(state_dict, output_model_file)
        if optimizer is not None:
            dump = {'step': step, 'optimizer': optimizer.state_dict()}
            if hasattr(optimizer, '_amp_stash'):
                pass  # TODO fp16 optimizer
            torch.save(dump,
                       os.path.join(self.output_dir, f'train_state_{step}.pt'))
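

# --- Illustrative usage (a minimal sketch; not part of the original module).
# The '/tmp/demo_run' path, the dummy config file, and the toy Linear model
# below are hypothetical stand-ins for the real training setup.
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--output_dir', default='/tmp/demo_run')
    parser.add_argument('--model_config', default='/tmp/demo_run/config.json')
    args = parser.parse_args([])

    os.makedirs(args.output_dir, exist_ok=True)
    # a dummy model config so save_training_meta has something to copy
    with open(args.model_config, 'w') as f:
        json.dump({'hidden_size': 4}, f)

    save_training_meta(args)  # creates logs/ and ckpts/ under output_dir

    model = torch.nn.Linear(4, 2)  # toy stand-in for the training model
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

    saver = ModelSaver(os.path.join(args.output_dir, 'ckpts'))
    saver.save(model, step=100, optimizer=optimizer)
    # -> writes ckpts/model_step_100.pt and ckpts/train_state_100.pt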