# Launch with (single node, one process):
# python3 -m torch.distributed.run --nproc_per_node=1 --nnodes=1 --node_rank=0 \
#     --rdzv_id=21046 --rdzv_backend=c10d --rdzv_endpoint=127.0.0.1:21046 main.py
"""Entry point for DDP training of Network on CIFAR-10."""

from torch import nn
from torch import optim

from dataset import Cifar10Dataset
from model import Network
from trainer import Trainer, ddp_init

# Default training hyperparameters used when the script is run directly.
BATCH_SIZE = 64
EPOCH_NUM = 5


def main(batch_size, epoch_num):
    """Initialize DDP, train the model, and save the final checkpoint.

    Args:
        batch_size: per-process mini-batch size passed to the Trainer.
        epoch_num: number of training epochs to run.

    Side effects: initializes the distributed process group (ddp_init),
    reads the CIFAR-10 dataset from /dataset/cifar-10-batches-py, and
    writes the trained weights to /output/model.pth.
    """
    print(f'training config: batch_size={batch_size}, epoch={epoch_num}')
    ddp_init()
    model = Network()
    dataset = Cifar10Dataset('/dataset/cifar-10-batches-py')
    # Adam with lr=0.001 and cross-entropy loss for classification.
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    criterion = nn.CrossEntropyLoss()
    trainer = Trainer(model, dataset, batch_size, optimizer, criterion)
    trainer.train(epoch_num)
    trainer.save('/output/model.pth')


if __name__ == '__main__':
    main(BATCH_SIZE, EPOCH_NUM)