Distributed-Training-Example/train/main.py

# python3 -m torch.distributed.run --nproc_per_node=1 --nnodes=1 --node_rank=0 --rdzv_id=21046 --rdzv_backend=c10d --rdzv_endpoint=127.0.0.1:21046 main.py
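# The command above launches a single worker (one process on one node) via the
# c10d rendezvous at 127.0.0.1:21046. torch.distributed.run exports RANK,
# LOCAL_RANK, WORLD_SIZE, MASTER_ADDR and MASTER_PORT to each worker, which
# ddp_init() below can then read. Scale out by raising --nproc_per_node
# and/or --nnodes.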

from torch import nn
from torch import optim

from dataset import Cifar10Dataset
from model import Network
from trainer import Trainer, ddp_init

BATCH_SIZE = 64
EPOCH_NUM = 30


def main(batch_size, epoch_num):
    print(f'training config: batch_size={batch_size}, epoch={epoch_num}')
    # Set up the distributed process group before building the model,
    # so downstream code can query rank and world size.
    ddp_init()
    model = Network()
    dataset = Cifar10Dataset('./dataset_dir/cifar-10-batches-py')
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    criterion = nn.CrossEntropyLoss()
    trainer = Trainer(model, dataset, batch_size, optimizer, criterion)
    trainer.train(epoch_num)


if __name__ == '__main__':
    main(BATCH_SIZE, EPOCH_NUM)
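
# For reference, a minimal sketch of what trainer.ddp_init() is assumed to do
# (its source is not part of this file); the backend choice and env-var reads
# below are assumptions, not the repo's actual code:
#
#   import os
#   import torch
#   import torch.distributed as dist
#
#   def ddp_init():
#       # env:// init works because torch.distributed.run exports the
#       # rendezvous results (MASTER_ADDR/MASTER_PORT, RANK, WORLD_SIZE).
#       dist.init_process_group(backend='nccl' if torch.cuda.is_available() else 'gloo')
#       if torch.cuda.is_available():
#           torch.cuda.set_device(int(os.environ['LOCAL_RANK']))
#
# Trainer is likewise expected to wrap the model in
# torch.nn.parallel.DistributedDataParallel and feed the dataset through a
# torch.utils.data.DistributedSampler so each rank sees a distinct shard.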