import torch
import argparse
from torch import optim
from torch import nn
from dataset import Cifar10Dataset
from model import Network
from torch.utils.data import DataLoader

import torch.multiprocessing as mp
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import init_process_group, destroy_process_group
import os
import time

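# Called once per spawned process, before the Trainer is built. torchrun
# exports RANK, LOCAL_RANK and WORLD_SIZE into each process's environment;
# 'nccl' is the backend to use for CUDA GPUs.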
def ddp_init():
    init_process_group('nccl')
    torch.cuda.set_device(int(os.environ['LOCAL_RANK']))

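# Wraps the model in DistributedDataParallel and shards the dataset across
# ranks, so each process trains on its own slice of the data.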
class Trainer():
    def __init__(self, model, dataset, batch_size, optimizer, criterion):
        self.local_rank = int(os.environ['LOCAL_RANK'])  # GPU index on this node
        self.global_rank = int(os.environ['RANK'])       # unique rank across all nodes

        self.model = model.to(self.local_rank)
        self.model = DDP(self.model, device_ids=[self.local_rank])

        self.dataset = dataset
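        # DistributedSampler partitions the dataset across ranks and shuffles
        # within each partition, so the DataLoader itself must not shuffle.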
        self.loader = DataLoader(self.dataset, batch_size, shuffle=False,
                                 sampler=DistributedSampler(self.dataset))

        self.optimizer = optimizer
        self.criterion = criterion

    def train(self, epoch_num):
        print("Start training...")
        start_time = time.time()
        for epoch in range(epoch_num):
            # Reseed the sampler so every epoch draws a different shuffle;
            # without this each rank sees the same batch order every epoch.
            self.loader.sampler.set_epoch(epoch)
            self.model.train()
            train_loss_sum = 0
            train_correct_sum = 0
            train_item_counter = 0
            for x, y in self.loader:
                x = x.float()
                x, y = x.to(self.local_rank), y.to(self.local_rank)

                predict = self.model(x)
                loss = self.criterion(predict, y)
                loss.backward()  # DDP all-reduces gradients across ranks here

                # evaluate
                train_loss_sum += loss.item()
                predicted_classes = torch.argmax(predict, dim=1)
                train_correct_sum += (predicted_classes == y).sum()
                train_item_counter += x.shape[0]

                self.optimizer.step()
                self.optimizer.zero_grad()
            print(f"[RANK {self.global_rank}] EPOCH {epoch} "
                  f"loss={train_loss_sum / len(self.loader)} "
                  f"acc={(train_correct_sum / train_item_counter).item()}")
        stop_time = time.time()
        print(f"Training took {stop_time - start_time}s")

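    # Note: in multi-process training you typically call save() from rank 0
    # only, e.g. `if trainer.global_rank == 0: trainer.save(path)`.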
    def save(self, model_path):
        # Unwrap the DDP container so checkpoint keys carry no 'module.' prefix.
        torch.save(self.model.module.state_dict(), model_path)
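

# A minimal launch sketch (not part of the original file): it assumes Network()
# and Cifar10Dataset() can be constructed without arguments, which may differ
# from the real model.py / dataset.py. Launch with, for example:
#   torchrun --nproc_per_node=4 train.py --epochs 10
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--epochs', type=int, default=10)
    parser.add_argument('--batch-size', type=int, default=128)
    parser.add_argument('--lr', type=float, default=1e-3)
    args = parser.parse_args()

    ddp_init()
    model = Network()
    # Module.to() moves parameters in place, so building the optimizer here
    # (before Trainer moves the model to the GPU) still works.
    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    criterion = nn.CrossEntropyLoss()

    trainer = Trainer(model, Cifar10Dataset(), args.batch_size, optimizer, criterion)
    trainer.train(args.epochs)
    if trainer.global_rank == 0:
        trainer.save('model.pt')
    destroy_process_group()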