feat: save()
This commit is contained in:
parent
20fd2fbe08
commit
0a287e3b46
@ -8,20 +8,20 @@ from trainer import ddp_init, Trainer
|
|||||||
from model import Network
|
from model import Network
|
||||||
|
|
||||||
# Training hyperparameters consumed by main() at the bottom of the script.
BATCH_SIZE = 64  # per-process mini-batch size handed to the Trainer
EPOCH_NUM = 5    # number of training epochs
def main(batch_size, epoch_num):
    """Run one distributed training session end to end.

    Initializes the DDP process group, builds the network, dataset,
    optimizer and loss, trains for ``epoch_num`` epochs, then saves
    the resulting weights.

    Args:
        batch_size: per-process mini-batch size passed to the Trainer.
        epoch_num: number of epochs to train for.
    """
    print(f'training config: batch_size={batch_size}, epoch={epoch_num}')

    # Join the distributed process group before constructing anything
    # that depends on the local rank/device.
    ddp_init()

    net = Network()
    data = Cifar10Dataset('/dataset/cifar-10-batches-py')
    opt = optim.Adam(net.parameters(), lr=0.001)
    loss_fn = nn.CrossEntropyLoss()

    # Trainer owns the training loop; train then persist the weights.
    runner = Trainer(net, data, batch_size, opt, loss_fn)
    runner.train(epoch_num)
    runner.save('/output/model.pth')
|
|
||||||
# Script entry point: train with the module-level default hyperparameters.
if __name__ == '__main__':
    main(BATCH_SIZE, EPOCH_NUM)
|
|||||||
@ -54,4 +54,7 @@ class Trainer():
|
|||||||
self.optimizer.zero_grad()
|
self.optimizer.zero_grad()
|
||||||
print(f"[DEVICE {self.global_rank}] EPOCH {epoch} loss={train_loss_sum/len(self.loader)} acc={(train_correct_sum/train_item_counter).item()}")
|
print(f"[DEVICE {self.global_rank}] EPOCH {epoch} loss={train_loss_sum/len(self.loader)} acc={(train_correct_sum/train_item_counter).item()}")
|
||||||
|
|
||||||
|
def save(self, model_path):
    """Persist the model's weights (state_dict) to ``model_path``.

    Under DDP every rank runs the same script, so without a guard all
    processes would write the same checkpoint file concurrently and
    could corrupt it. Only rank 0 writes; every other rank returns
    without touching the filesystem.

    Args:
        model_path: destination file path for the checkpoint.
    """
    # NOTE(review): if self.model is wrapped in DistributedDataParallel,
    # the saved keys carry a 'module.' prefix — confirm how Trainer
    # stores the model before relying on the key names at load time.
    if self.global_rank != 0:
        return
    torch.save(self.model.state_dict(), model_path)
|
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user