29 lines
829 B
Python
29 lines
829 B
Python
# Launch command (single node, one process per node):
# python3 -m torch.distributed.run --nproc_per_node=1 --nnodes=1 --node_rank=0 --rdzv_id=21046 --rdzv_backend=c10d --rdzv_endpoint=127.0.0.1:21046 main.py
|
|
from torch import optim
|
|
from torch import nn
|
|
|
|
from model import Network
|
|
from dataset import Cifar10Dataset
|
|
from trainer import ddp_init, Trainer
|
|
from model import Network
|
|
|
|
# Default training settings passed to main() when this file is run as a script.
BATCH_SIZE = 64
EPOCH_NUM = 30
|
|
|
|
def main(batch_size, epoch_num, *, lr=0.001,
         dataset_dir='./dataset_dir/cifar-10-batches-py'):
    """Build the model, dataset, optimizer and loss, then run training.

    Args:
        batch_size: Batch size forwarded to ``Trainer``.
        epoch_num: Number of epochs passed to ``Trainer.train``.
        lr: Learning rate given to the Adam optimizer.
        dataset_dir: Path handed to ``Cifar10Dataset``.
    """
    print(f'training config: batch_size={batch_size}, epoch={epoch_num}')

    # Distributed setup is performed before any model/optimizer is created.
    ddp_init()

    model = Network()
    dataset = Cifar10Dataset(dataset_dir)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    trainer = Trainer(model, dataset, batch_size, optimizer, criterion)
    trainer.train(epoch_num)
|
|
|
|
|
|
# Script entry point: start training with the module-level defaults.
if __name__ == '__main__':
    main(batch_size=BATCH_SIZE, epoch_num=EPOCH_NUM)
|
|
|