import os

import torch
import torch.multiprocessing as mp
from torch import nn, optim
from torch.distributed import init_process_group, destroy_process_group
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

from dataset import Cifar10Dataset
from model import Network


def ddp_init(rank, world_size):
    """Initialize the default process group for this rank."""
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '21046'
    init_process_group('nccl', rank=rank, world_size=world_size)
    # Bind this process to its GPU so NCCL collectives run on the right device.
    torch.cuda.set_device(rank)


def main(rank, world_size):
    ddp_init(rank, world_size)

    # Move the model to this rank's GPU before wrapping it in DDP;
    # device_ids must be a list, not a bare int.
    model = Network().to(rank)
    model = DDP(model, device_ids=[rank])

    dataset = Cifar10Dataset('./dataset_dir/cifar-10-batches-py')
    # DistributedSampler partitions the dataset across ranks and handles
    # shuffling, so the DataLoader itself must not shuffle.
    sampler = DistributedSampler(dataset)
    loader = DataLoader(dataset, batch_size=32, shuffle=False, sampler=sampler)

    optimizer = optim.Adam(model.parameters(), lr=0.001)
    criterion = nn.CrossEntropyLoss()

    for epoch in range(50):
        # Re-seed the sampler each epoch so each epoch gets a fresh shuffle.
        sampler.set_epoch(epoch)
        model.train()
        train_loss_sum = 0.0
        train_correct_sum = 0

        for x, y in loader:
            x, y = x.float().to(rank), y.to(rank)

            optimizer.zero_grad()
            predict = model(x)
            loss = criterion(predict, y)
            loss.backward()  # DDP all-reduces gradients across ranks here
            optimizer.step()

            # Track per-rank loss and accuracy.
            train_loss_sum += loss.item()
            predicted_classes = torch.argmax(predict, dim=1)
            train_correct_sum += (predicted_classes == y).sum().item()

        # Each rank only sees len(sampler) samples, so normalize by that,
        # not by len(dataset). Print from rank 0 only to avoid duplicate logs.
        if rank == 0:
            print(f'epoch {epoch}: loss {train_loss_sum / len(loader):.4f}, '
                  f'accuracy {100.0 * train_correct_sum / len(sampler):.2f}%')

    destroy_process_group()


if __name__ == '__main__':
    world_size = torch.cuda.device_count()
    # mp.spawn passes the rank as the first argument to main; nprocs must be
    # set explicitly, otherwise only a single process is launched.
    mp.spawn(main, args=(world_size,), nprocs=world_size)