Distributed-Training-Example/train/train.py
import torch
import argparse
from torch import optim
from torch import nn
from dataset import Cifar10Dataset
from model import Network
from torch.utils.data import DataLoader
import torch.multiprocessing as mp
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import init_process_group, destroy_process_group
import os
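
# The environment variables read below (LOCAL_RANK, RANK, LOCAL_WORLD_SIZE,
# WORLD_SIZE, MASTER_ADDR, MASTER_PORT) are populated by the launcher,
# typically torchrun / torch.distributed.run, one process per GPU.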
def ddp_init():
    # join the default process group (NCCL backend for GPU collectives) and
    # bind this process to the GPU that matches its local rank on the node
    init_process_group(backend='nccl')
    torch.cuda.set_device(int(os.environ['LOCAL_RANK']))


class Trainer:
    def __init__(self, model, dataset, batch_size, optimizer, criterion):
        self.local_rank = int(os.environ['LOCAL_RANK'])   # GPU index on this node
        self.global_rank = int(os.environ['RANK'])         # rank across all nodes
        self.model = model.to(self.local_rank)
        self.model = DDP(self.model, device_ids=[self.local_rank])
        self.dataset = dataset
        # DistributedSampler shards the dataset across ranks and handles shuffling,
        # so the DataLoader itself is created with shuffle=False
        self.loader = DataLoader(self.dataset, batch_size, shuffle=False,
                                 sampler=DistributedSampler(self.dataset))
        self.optimizer = optimizer
        self.criterion = criterion
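
    # Note: DDP averages gradients across ranks inside loss.backward(), so every
    # process applies the same update in optimizer.step().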
    def train(self, epoch_num):
        for epoch in range(epoch_num):
            # reshuffle the DistributedSampler with an epoch-dependent seed
            self.loader.sampler.set_epoch(epoch)
            self.model.train()
            train_loss_sum = 0
            train_correct_sum = 0
            train_item_counter = 0
            for x, y in self.loader:
                x = x.float()
                x, y = x.to(self.local_rank), y.to(self.local_rank)
                predict = self.model(x)
                loss = self.criterion(predict, y)
                loss.backward()
                # accumulate per-epoch metrics
                train_loss_sum += loss.item()
                predicted_classes = torch.argmax(predict, dim=1)
                train_correct_sum += (predicted_classes == y).sum()
                train_item_counter += x.shape[0]
                self.optimizer.step()
                self.optimizer.zero_grad()
            print(f"[RANK {self.global_rank}] EPOCH {epoch} "
                  f"loss={train_loss_sum / len(self.loader):.4f} "
                  f"acc={(train_correct_sum / train_item_counter).item():.4f}")


def main(batch_size, epoch_num):
    print(f'training with {os.environ["WORLD_SIZE"]} processes (one per GPU)')
    print(f'training config: batch_size={batch_size}, epoch_num={epoch_num}')
    print(f'LOCAL_RANK={os.environ["LOCAL_RANK"]}')
    print(f'RANK={os.environ["RANK"]}')
    print(f'LOCAL_WORLD_SIZE={os.environ["LOCAL_WORLD_SIZE"]}')
    print(f'WORLD_SIZE={os.environ["WORLD_SIZE"]}')
    print(f'MASTER_ADDR={os.environ["MASTER_ADDR"]}')
    print(f'MASTER_PORT={os.environ["MASTER_PORT"]}')
    ddp_init()
    model = Network()
    dataset = Cifar10Dataset('./dataset_dir/cifar-10-batches-py')
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    criterion = nn.CrossEntropyLoss()
    trainer = Trainer(model, dataset, batch_size, optimizer, criterion)
    trainer.train(epoch_num)
    # tear down the process group once training finishes
    destroy_process_group()


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--batch_size', type=int, default=32, help="batch size for training")
    parser.add_argument('--epoch_num', type=int, default=50, help="number of training epochs")
    args = parser.parse_args()
    # world_size is only needed by the (commented-out) mp.spawn alternative below;
    # under torchrun the world size comes from the WORLD_SIZE environment variable
    world_size = torch.cuda.device_count()
    main(args.batch_size, args.epoch_num)
    # mp.spawn(main, args=(world_size, args.batch_size, args.epoch_num), nprocs=world_size)
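
# One way to launch this script (a sketch, assuming a single node with N GPUs):
#   torchrun --standalone --nproc_per_node=N train.py --batch_size 32 --epoch_num 50
# For multi-node runs, torchrun additionally needs --nnodes, --node_rank and a
# rendezvous endpoint (e.g. --master_addr / --master_port), which is where the
# MASTER_ADDR / MASTER_PORT values printed in main() come from.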