diff --git a/train/train.py b/train/train.py
index 33bc55b..eb8f4f4 100644
--- a/train/train.py
+++ b/train/train.py
@@ -13,17 +13,15 @@
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.distributed import init_process_group, destroy_process_group
 import os
 
-def ddp_init(rank, world_size):
-    os.environ['MASTER_ADDR'] = 'localhost'
-    os.environ['MASTER_PORT'] = '21046'
-    init_process_group('nccl', rank=rank, world_size=world_size)
-    torch.cuda.set_device(rank)
+def ddp_init():
+    init_process_group('nccl')
+    torch.cuda.set_device(int(os.environ['LOCAL_RANK']))
 
 
 class Trainer():
-    def __init__(self, rank, model, dataset, batch_size, optimizer, criterion):
-        self.rank = rank
-        self.model = model.to(rank)
+    def __init__(self, model, dataset, batch_size, optimizer, criterion):
+        self.rank = int(os.environ['LOCAL_RANK'])
+        self.model = model.to(self.rank)
         self.model = DDP(self.model, device_ids=[self.rank])
         self.dataset = dataset
@@ -58,17 +56,23 @@
-def main(rank, world_size, batch_size, epoch_num):
-    print(f'training with {world_size} GPUs')
+def main(batch_size, epoch_num):
+    print(f'training with {os.environ["WORLD_SIZE"]} GPUs')
     print(f'training config: batch_size={batch_size}, epoch={epoch_num}')
-    ddp_init(rank, world_size)
+    print(f'LOCAL_RANK={os.environ["LOCAL_RANK"]}')
+    print(f'RANK={os.environ["RANK"]}')
+    print(f'LOCAL_WORLD_SIZE={os.environ["LOCAL_WORLD_SIZE"]}')
+    print(f'WORLD_SIZE={os.environ["WORLD_SIZE"]}')
+    print(f'MASTER_ADDR={os.environ["MASTER_ADDR"]}')
+    print(f'MASTER_PORT={os.environ["MASTER_PORT"]}')
+    ddp_init()
 
-    model = Network().to(rank)
+    model = Network()
     dataset = Cifar10Dataset('./dataset_dir/cifar-10-batches-py')
     optimizer = optim.Adam(model.parameters(), lr=0.001)
     criterion = nn.CrossEntropyLoss()
 
-    trainer = Trainer(rank, model, dataset, batch_size, optimizer, criterion)
+    trainer = Trainer(model, dataset, batch_size, optimizer, criterion)
     trainer.train(epoch_num)
@@ -80,5 +84,6 @@
 
 if __name__ == '__main__':
     args = parser.parse_args()
     world_size = torch.cuda.device_count()
-    mp.spawn(main, args=(world_size, args.batch_size, args.epoch_num), nprocs=world_size)
+    main(args.batch_size, args.epoch_num)
+    # mp.spawn(main, args=(world_size, args.batch_size, args.epoch_num), nprocs=world_size)
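With the hard-coded MASTER_ADDR/MASTER_PORT and the mp.spawn entry point gone, the script now expects an external launcher to export the rendezvous environment variables it reads: LOCAL_RANK, RANK, LOCAL_WORLD_SIZE, WORLD_SIZE, MASTER_ADDR, and MASTER_PORT. torchrun sets all six, so a single-node run would look something like the line below; the exact --batch_size/--epoch_num flag spellings are assumptions here, since the argparse setup sits outside this diff:

    torchrun --standalone --nproc_per_node=4 train/train.py --batch_size 64 --epoch_num 10

The --standalone flag stands up a single-node rendezvous and fills in MASTER_ADDR/MASTER_PORT automatically, which is what lets the new ddp_init() call init_process_group('nccl') without passing an explicit rank or world_size.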
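One loose end the diff leaves visible: destroy_process_group is imported at the top of train.py but never called in any changed hunk. A minimal teardown sketch, assuming main() owns the process-group lifetime (the try/finally and comments are illustrative, not part of the patch):

    def main(batch_size, epoch_num):
        ddp_init()  # joins the process group torchrun already configured
        try:
            ...  # build the model, dataset, and Trainer exactly as above
        finally:
            destroy_process_group()  # each worker leaves the group cleanly

Recent PyTorch versions log a warning at interpreter exit if the NCCL process group was never destroyed, so the finally block is cheap insurance.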