import os
import time
import argparse

import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import init_process_group, destroy_process_group

from dataset import Cifar10Dataset
from model import Network


def ddp_init():
    # torchrun sets RANK, LOCAL_RANK, and WORLD_SIZE in the environment.
    init_process_group('nccl')
    torch.cuda.set_device(int(os.environ['LOCAL_RANK']))


class Trainer():
    def __init__(self, model, dataset, batch_size, optimizer, criterion):
        self.local_rank = int(os.environ['LOCAL_RANK'])
        self.global_rank = int(os.environ['RANK'])
        self.model = model.to(self.local_rank)
        self.model = DDP(self.model, device_ids=[self.local_rank])
        self.dataset = dataset
        # shuffle must stay False: DistributedSampler both shuffles and
        # partitions the dataset across ranks.
        self.loader = DataLoader(self.dataset, batch_size, shuffle=False,
                                 sampler=DistributedSampler(self.dataset))
        self.optimizer = optimizer
        self.criterion = criterion

    def train(self, epoch_num):
        print("Start training...")
        start_time = time.time()
        for epoch in range(epoch_num):
            # Reseed the sampler so each epoch sees a different shuffle order.
            self.loader.sampler.set_epoch(epoch)
            self.model.train()
            train_loss_sum = 0.0
            train_correct_sum = 0
            train_item_counter = 0
            for x, y in self.loader:
                x = x.float()
                x, y = x.to(self.local_rank), y.to(self.local_rank)
                self.optimizer.zero_grad()
                predict = self.model(x)
                loss = self.criterion(predict, y)
                loss.backward()
                self.optimizer.step()
                # Accumulate running statistics for this rank's shard.
                train_loss_sum += loss.item()
                predicted_classes = torch.argmax(predict, dim=1)
                train_correct_sum += (predicted_classes == y).sum().item()
                train_item_counter += x.shape[0]
            print(f"[RANK {self.global_rank}] EPOCH {epoch} "
                  f"loss={train_loss_sum / len(self.loader):.4f} "
                  f"acc={train_correct_sum / train_item_counter:.4f}")
        stop_time = time.time()
        print(f"Training took {stop_time - start_time:.1f}s")

    def save(self, model_path):
        # Unwrap the DDP module so the checkpoint can be loaded into a plain
        # model without the 'module.' key prefixes.
        torch.save(self.model.module.state_dict(), model_path)
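

# --- Hypothetical entry point (a sketch, not part of the original file) ---
# A minimal example of how Trainer might be wired up when the script is
# launched with torchrun, e.g.:
#
#   torchrun --nproc_per_node=<num_gpus> train.py --epochs 10
#
# It assumes Cifar10Dataset() and Network() take no constructor arguments;
# their real signatures live in dataset.py / model.py and are not shown here.
# The SGD hyperparameters and checkpoint path are placeholders.
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--epochs', type=int, default=10)
    parser.add_argument('--batch-size', type=int, default=128)
    parser.add_argument('--lr', type=float, default=0.01)
    parser.add_argument('--out', type=str, default='model.pt')
    args = parser.parse_args()

    ddp_init()
    model = Network()           # assumed no-arg constructor
    dataset = Cifar10Dataset()  # assumed no-arg constructor
    optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=0.9)
    criterion = nn.CrossEntropyLoss()

    trainer = Trainer(model, dataset, args.batch_size, optimizer, criterion)
    trainer.train(args.epochs)
    # Only rank 0 writes the checkpoint, so ranks don't race on the same file.
    if int(os.environ['RANK']) == 0:
        trainer.save(args.out)
    destroy_process_group()


if __name__ == '__main__':
    main()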