diff --git a/train/train.py b/train/train.py
index 6d47033..33bc55b 100644
--- a/train/train.py
+++ b/train/train.py
@@ -1,4 +1,5 @@
 import torch
+import argparse
 from torch import optim
 from torch import nn
 from dataset import Cifar10Dataset
@@ -18,41 +19,66 @@ def ddp_init(rank, world_size):
     init_process_group('nccl', rank=rank, world_size=world_size)
     torch.cuda.set_device(rank)
 
+class Trainer():
+    def __init__(self, rank, model, dataset, batch_size, optimizer, criterion):
+        self.rank = rank
 
-def main(rank, world_size):
+        self.model = model.to(rank)
+        self.model = DDP(self.model, device_ids=[self.rank])
+
+        self.dataset = dataset
+        self.loader = DataLoader(self.dataset, batch_size, shuffle=False, sampler=DistributedSampler(self.dataset))
+
+        self.optimizer = optimizer
+        self.criterion = criterion
+
+    def train(self, epoch_num):
+        for epoch in range(epoch_num):
+            self.model.train()
+            train_loss_sum = 0
+            train_correct_sum = 0
+            train_item_counter = 0
+            for x, y in self.loader:
+                x = x.float()
+                x, y = x.to(self.rank), y.to(self.rank)
+
+                predict = self.model(x)
+                loss = self.criterion(predict, y)
+                loss.backward()
+
+                # evaluate
+                train_loss_sum += loss.item()
+                predicted_classes = torch.argmax(predict, dim=1)
+                train_correct_sum += (predicted_classes == y).sum()
+                train_item_counter += x.shape[0]
+
+                self.optimizer.step()
+                self.optimizer.zero_grad()
+            print(f"[DEVICE {self.rank}] EPOCH {epoch} loss={train_loss_sum/len(self.loader)} acc={(train_correct_sum/train_item_counter).item()}")
+
+
+
+def main(rank, world_size, batch_size, epoch_num):
+    print(f'training with {world_size} GPUs')
+    print(f'training config: batch_size={batch_size}, epoch={epoch_num}')
     ddp_init(rank, world_size)
 
     model = Network().to(rank)
-    model = DDP(model, device_ids=[rank])
-
     dataset = Cifar10Dataset('./dataset_dir/cifar-10-batches-py')
-    loader = DataLoader(dataset, batch_size=32, shuffle=False, sampler=DistributedSampler(dataset, ))
-
     optimizer = optim.Adam(model.parameters(), lr=0.001)
     criterion = nn.CrossEntropyLoss()
 
-    for epoch in range(50):
-        model.train()
-        train_loss_sum = 0
-        train_correct_sum = 0
-        for x, y in loader:
-            x = x.float()
-            x, y = x.to(rank), y.to(rank)
+    trainer = Trainer(rank, model, dataset, batch_size, optimizer, criterion)
+    trainer.train(epoch_num)
 
-            predict = model(x)
-            loss = criterion(predict, y)
-            loss.backward()
-
-            # evaluate
-            train_loss_sum += loss.item()
-            predicted_classes = torch.argmax(predict, dim=1)
-            train_correct_sum += (predicted_classes == y).sum()
-
-            optimizer.step()
-            optimizer.zero_grad()
-        print(f"[DEVICE {rank}] EPOCH {epoch} loss={train_loss_sum/len(loader)} acc={(train_correct_sum/len(dataset)).item()}")
 
 
 if __name__ == '__main__':
-    world_size = torch.cuda.device_count()
-    mp.spawn(main, args=(world_size, ), nprocs=world_size)
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--batch_size', type=int, default=32, help="batch size for training")
+    parser.add_argument('--epoch_num', type=int, default=50, help="training epoch")
+    args = parser.parse_args()
+
+    world_size = torch.cuda.device_count()
+    mp.spawn(main, args=(world_size, args.batch_size, args.epoch_num), nprocs=world_size)
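
With the new flags, the entry point can be launched as, for example, `python train/train.py --batch_size 64 --epoch_num 20`; both arguments default to the values that were previously hard-coded (32 and 50).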
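
Two standard DDP details fall outside this patch and may be worth a follow-up. First, `DistributedSampler` shuffles by default, but it derives its permutation from an internal epoch counter, so unless `set_epoch` is called at the top of each epoch, every epoch replays the same data ordering. Second, the spawned workers never tear down the NCCL process group. A minimal sketch of both, assuming the `Trainer` and `main` from the diff (`sampler` is the standard `DataLoader` attribute; `destroy_process_group` lives in `torch.distributed` alongside `init_process_group`):

    from torch.distributed import destroy_process_group

    def train(self, epoch_num):
        for epoch in range(epoch_num):
            # Re-seed the sampler so every epoch gets a fresh shuffle;
            # otherwise DistributedSampler replays the epoch-0 ordering.
            self.loader.sampler.set_epoch(epoch)
            self.model.train()
            ...  # batch loop unchanged from the diff

    def main(rank, world_size, batch_size, epoch_num):
        ...  # setup and trainer.train(epoch_num) unchanged from the diff
        destroy_process_group()  # clean shutdown of each spawned worker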