From 8f3253ff246db676a32f3f6318a9b3ebbdb510a7 Mon Sep 17 00:00:00 2001
From: TING-JUN WANG
Date: Thu, 16 May 2024 16:28:22 +0800
Subject: [PATCH] fix: single machine parallel training success

---
 train/train.py | 11 +++++------
 1 file changed, 5 insertions(+), 6 deletions(-)

diff --git a/train/train.py b/train/train.py
index 428f8ff..6d47033 100644
--- a/train/train.py
+++ b/train/train.py
@@ -16,13 +16,14 @@ def ddp_init(rank, world_size):
     os.environ['MASTER_ADDR'] = 'localhost'
     os.environ['MASTER_PORT'] = '21046'
     init_process_group('nccl', rank=rank, world_size=world_size)
+    torch.cuda.set_device(rank)
 
 
 def main(rank, world_size):
     ddp_init(rank, world_size)
 
-    model = Network()
-    model = DDP(model, device_ids=rank)
+    model = Network().to(rank)
+    model = DDP(model, device_ids=[rank])
 
     dataset = Cifar10Dataset('./dataset_dir/cifar-10-batches-py')
     loader = DataLoader(dataset, batch_size=32, shuffle=False, sampler=DistributedSampler(dataset, ))
@@ -49,11 +50,9 @@ def main(rank, world_size):
             optimizer.step()
             optimizer.zero_grad()
 
-        print(train_loss_sum / len(loader))
-        print((train_correct_sum / len(dataset)).item(),'%')
-        print()
+        print(f"[DEVICE {rank}] EPOCH {epoch} loss={train_loss_sum/len(loader)} acc={(train_correct_sum/len(dataset)).item()}")
 
 
 if __name__ == '__main__':
     world_size = torch.cuda.device_count()
-    mp.spawn(main, args=(world_size, ))
+    mp.spawn(main, args=(world_size, ), nprocs=world_size)
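
Note: below is a minimal, self-contained sketch of the single-machine DDP launch pattern this patch applies (set_device per rank, model moved to its rank's GPU before wrapping in DDP with device_ids as a list, DistributedSampler sharding, and mp.spawn with nprocs). ToyDataset and the Linear stand-in model are assumptions for illustration only; they are not the repo's Network / Cifar10Dataset.

# ddp_sketch.py -- illustrative only; stand-in model/dataset, not the repo's code.
import os

import torch
import torch.multiprocessing as mp
from torch.distributed import init_process_group, destroy_process_group
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, Dataset, DistributedSampler


class ToyDataset(Dataset):
    # Random tensors standing in for CIFAR-10 samples (assumption).
    def __len__(self):
        return 256

    def __getitem__(self, idx):
        return torch.randn(3 * 32 * 32), torch.randint(0, 10, (1,)).item()


def ddp_init(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '21046'
    init_process_group('nccl', rank=rank, world_size=world_size)
    # Bind this process to its own GPU so NCCL collectives and .to(rank)
    # target the right device.
    torch.cuda.set_device(rank)


def main(rank, world_size):
    ddp_init(rank, world_size)

    # Move the model to this rank's GPU first, then wrap it;
    # device_ids must be a list, not a bare int.
    model = torch.nn.Linear(3 * 32 * 32, 10).to(rank)
    model = DDP(model, device_ids=[rank])

    dataset = ToyDataset()
    # DistributedSampler shards the data across ranks; shuffle stays False
    # on the DataLoader because the sampler controls ordering.
    loader = DataLoader(dataset, batch_size=32, shuffle=False,
                        sampler=DistributedSampler(dataset))

    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    criterion = torch.nn.CrossEntropyLoss()

    for epoch in range(2):
        loader.sampler.set_epoch(epoch)
        for x, y in loader:
            x, y = x.to(rank), y.to(rank)
            loss = criterion(model(x), y)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
        print(f"[DEVICE {rank}] EPOCH {epoch} loss={loss.item():.4f}")

    destroy_process_group()


if __name__ == '__main__':
    world_size = torch.cuda.device_count()
    # nprocs launches one process per GPU; mp.spawn passes the process index
    # as the first argument (rank), followed by the entries of args.
    mp.spawn(main, args=(world_size,), nprocs=world_size)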