diff --git a/train/main.py b/train/main.py
new file mode 100644
index 0000000..e59bf3c
--- /dev/null
+++ b/train/main.py
@@ -0,0 +1,27 @@
+# python3 -m torch.distributed.run --nproc_per_node=1 --nnodes=1 --node_rank=0 --rdzv_id=21046 --rdzv_backend=c10d --rdzv_endpoint=127.0.0.1:21046 main.py
+from torch import optim
+from torch import nn
+
+from model import Network
+from dataset import Cifar10Dataset
+from trainer import ddp_init, Trainer
+
+BATCH_SIZE = 64
+EPOCH_NUM = 30
+
+def main(batch_size, epoch_num):
+    print(f'training config: batch_size={batch_size}, epoch={epoch_num}')
+    ddp_init()
+
+    model = Network()
+    dataset = Cifar10Dataset('./dataset_dir/cifar-10-batches-py')
+    optimizer = optim.Adam(model.parameters(), lr=0.001)
+    criterion = nn.CrossEntropyLoss()
+
+    trainer = Trainer(model, dataset, batch_size, optimizer, criterion)
+    trainer.train(epoch_num)
+
+
+if __name__ == '__main__':
+    main(BATCH_SIZE, EPOCH_NUM)
+
diff --git a/train/train.py b/train/trainer.py
similarity index 61%
rename from train/train.py
rename to train/trainer.py
index fb3be73..bf14df8 100644
--- a/train/train.py
+++ b/train/trainer.py
@@ -55,35 +55,3 @@ class Trainer():
 
             print(f"[DEVICE {self.global_rank}] EPOCH {epoch} loss={train_loss_sum/len(self.loader)} acc={(train_correct_sum/train_item_counter).item()}")
 
-
-def main(batch_size, epoch_num):
-    print(f'training with {world_size} GPUs')
-    print(f'training config: batch_size={batch_size}, epoch={epoch_num}')
-    print(f'LOCAL_RANK={os.environ["LOCAL_RANK"]}')
-    print(f'RANK={os.environ["RANK"]}')
-    print(f'LOCAL_WORLD_SIZE={os.environ["LOCAL_WORLD_SIZE"]}')
-    print(f'WORLD_SIZE={os.environ["WORLD_SIZE"]}')
-    print(f'MASTER_ADDR={os.environ["MASTER_ADDR"]}')
-    print(f'MASTER_PORT={os.environ["MASTER_PORT"]}')
-    ddp_init()
-
-    model = Network()
-    dataset = Cifar10Dataset('./dataset_dir/cifar-10-batches-py')
-    optimizer = optim.Adam(model.parameters(), lr=0.001)
-    criterion = nn.CrossEntropyLoss()
-
-    trainer = Trainer(model, dataset, batch_size, optimizer, criterion)
-    trainer.train(epoch_num)
-
-
-if __name__ == '__main__':
-
-    parser = argparse.ArgumentParser()
-    parser.add_argument('--batch_size', type=int, default=32, help="batch size for training")
-    parser.add_argument('--epoch_num', type=int, default=50, help="training epoch")
-    args = parser.parse_args()
-
-    world_size = torch.cuda.device_count()
-    main(args.batch_size, args.epoch_num)
-    # mp.spawn(main, args=(world_size, args.batch_size, args.epoch_num), nprocs=world_size)
-
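Note: ddp_init() is imported from trainer but its body is outside this diff. Below is a minimal sketch of what such a helper typically does, assuming the torch.distributed.run launch shown in the header comment (which exports RANK, LOCAL_RANK, WORLD_SIZE, MASTER_ADDR and MASTER_PORT, the same variables the removed main() printed); the real implementation in train/trainer.py may differ.

import os

import torch
import torch.distributed as dist


def ddp_init():
    # assumed helper: the env:// rendezvous reads MASTER_ADDR/MASTER_PORT
    # exported by torch.distributed.run
    backend = 'nccl' if torch.cuda.is_available() else 'gloo'
    dist.init_process_group(backend=backend, init_method='env://')
    # pin this process to its local GPU before the model is built
    if torch.cuda.is_available():
        torch.cuda.set_device(int(os.environ['LOCAL_RANK']))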
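For context, a hypothetical sketch of the Trainer constructor that Trainer(model, dataset, batch_size, optimizer, criterion) implies, consistent with the attributes the surviving print statement uses (self.global_rank, self.loader). Everything here is an assumption for illustration, not the actual train/trainer.py.

import os

import torch
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler


class Trainer():
    # hypothetical constructor; the real class is not shown in this diff
    def __init__(self, model, dataset, batch_size, optimizer, criterion):
        self.global_rank = int(os.environ['RANK'])    # rank across all nodes
        local_rank = int(os.environ['LOCAL_RANK'])    # rank within this node
        self.device = (torch.device(f'cuda:{local_rank}')
                       if torch.cuda.is_available() else torch.device('cpu'))
        # DistributedSampler hands each rank a disjoint shard of the dataset
        self.loader = DataLoader(dataset, batch_size=batch_size,
                                 sampler=DistributedSampler(dataset))
        # DDP averages gradients across ranks during backward()
        self.model = DDP(model.to(self.device))
        self.optimizer = optimizer
        self.criterion = criterion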