From aedc6b46e962169a3ec42cdf6c4908d919e9d24f Mon Sep 17 00:00:00 2001 From: Ting-Jun Wang Date: Thu, 16 May 2024 15:59:20 +0800 Subject: [PATCH] feat: single machine DDP (test) --- train/train.py | 70 ++++++++++++++++++++++++++++++++------------------ 1 file changed, 45 insertions(+), 25 deletions(-) diff --git a/train/train.py b/train/train.py index a8d8fdf..428f8ff 100644 --- a/train/train.py +++ b/train/train.py @@ -6,34 +6,54 @@ from model import Network from torch.utils.data import DataLoader import matplotlib.pyplot as plt -device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') +import torch.multiprocessing as mp +from torch.utils.data.distributed import DistributedSampler +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.distributed import init_process_group, destroy_process_group +import os -model = Network().to(device) -dataset = Cifar10Dataset('./dataset_dir/cifar-10-batches-py') -loader = DataLoader(dataset, batch_size=32, shuffle=True) +def ddp_init(rank, world_size): + os.environ['MASTER_ADDR'] = 'localhost' + os.environ['MASTER_PORT'] = '21046' + init_process_group('nccl', rank=rank, world_size=world_size) -optimizer = optim.Adam(model.parameters(), lr=0.001) -criterion = nn.CrossEntropyLoss() -for epoch in range(50): - model.train() - train_loss_sum = 0 - train_correct_sum = 0 - for x, y in loader: - x = x.float() - x, y = x.to(device), y.to(device) +def main(rank, world_size): + ddp_init(rank, world_size) - predict = model(x) - loss = criterion(predict, y) - loss.backward() + model = Network().to(rank) + model = DDP(model, device_ids=[rank]) - # evaluate - train_loss_sum += loss.item() - predicted_classes = torch.argmax(predict, dim=1) - train_correct_sum += (predicted_classes == y).sum() + dataset = Cifar10Dataset('./dataset_dir/cifar-10-batches-py') + loader = DataLoader(dataset, batch_size=32, shuffle=False, sampler=DistributedSampler(dataset, )) - optimizer.step() - optimizer.zero_grad() - print(train_loss_sum / 
len(loader)) - print((train_correct_sum / len(dataset)).item(),'%') - print() + optimizer = optim.Adam(model.parameters(), lr=0.001) + criterion = nn.CrossEntropyLoss() + + for epoch in range(50): + model.train() + train_loss_sum = 0 + train_correct_sum = 0 + for x, y in loader: + x = x.float() + x, y = x.to(rank), y.to(rank) + + predict = model(x) + loss = criterion(predict, y) + loss.backward() + + # evaluate + train_loss_sum += loss.item() + predicted_classes = torch.argmax(predict, dim=1) + train_correct_sum += (predicted_classes == y).sum() + + optimizer.step() + optimizer.zero_grad() + print(train_loss_sum / len(loader)) + print((train_correct_sum / len(dataset)).item(),'%') + print() + +if __name__ == '__main__': + world_size = torch.cuda.device_count() + mp.spawn(main, args=(world_size,), nprocs=world_size) + 