feat: trainer class for single/multi GPU
This commit is contained in:
parent
8f3253ff24
commit
24240f1c3a
@ -1,4 +1,5 @@
|
|||||||
import torch
|
import torch
|
||||||
|
import argparse
|
||||||
from torch import optim
|
from torch import optim
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from dataset import Cifar10Dataset
|
from dataset import Cifar10Dataset
|
||||||
@ -18,41 +19,66 @@ def ddp_init(rank, world_size):
|
|||||||
init_process_group('nccl', rank=rank, world_size=world_size)
|
init_process_group('nccl', rank=rank, world_size=world_size)
|
||||||
torch.cuda.set_device(rank)
|
torch.cuda.set_device(rank)
|
||||||
|
|
||||||
|
class Trainer():
|
||||||
|
def __init__(self, rank, model, dataset, batch_size, optimizer, criterion):
|
||||||
|
self.rank = rank
|
||||||
|
|
||||||
def main(rank, world_size):
|
self.model = model.to(rank)
|
||||||
|
self.model = DDP(self.model, device_ids=[self.rank])
|
||||||
|
|
||||||
|
self.dataset = dataset
|
||||||
|
self.loader = DataLoader(self.dataset, batch_size, shuffle=False, sampler=DistributedSampler(self.dataset))
|
||||||
|
|
||||||
|
self.optimizer = optimizer
|
||||||
|
self.criterion = criterion
|
||||||
|
|
||||||
|
def train(self, epoch_num):
|
||||||
|
for epoch in range(epoch_num):
|
||||||
|
self.model.train()
|
||||||
|
train_loss_sum = 0
|
||||||
|
train_correct_sum = 0
|
||||||
|
train_item_counter = 0
|
||||||
|
for x, y in self.loader:
|
||||||
|
x = x.float()
|
||||||
|
x, y = x.to(self.rank), y.to(self.rank)
|
||||||
|
|
||||||
|
predict = self.model(x)
|
||||||
|
loss = self.criterion(predict, y)
|
||||||
|
loss.backward()
|
||||||
|
|
||||||
|
# evaluate
|
||||||
|
train_loss_sum += loss.item()
|
||||||
|
predicted_classes = torch.argmax(predict, dim=1)
|
||||||
|
train_correct_sum += (predicted_classes == y).sum()
|
||||||
|
train_item_counter += x.shape[0]
|
||||||
|
|
||||||
|
self.optimizer.step()
|
||||||
|
self.optimizer.zero_grad()
|
||||||
|
print(f"[DEVICE {self.rank}] EPOCH {epoch} loss={train_loss_sum/len(self.loader)} acc={(train_correct_sum/train_item_counter).item()}")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def main(rank, world_size, batch_size, epoch_num):
|
||||||
|
print(f'training with {world_size} GPUs')
|
||||||
|
print(f'training config: batch_size={batch_size}, epoch={epoch_num}')
|
||||||
ddp_init(rank, world_size)
|
ddp_init(rank, world_size)
|
||||||
|
|
||||||
model = Network().to(rank)
|
model = Network().to(rank)
|
||||||
model = DDP(model, device_ids=[rank])
|
|
||||||
|
|
||||||
dataset = Cifar10Dataset('./dataset_dir/cifar-10-batches-py')
|
dataset = Cifar10Dataset('./dataset_dir/cifar-10-batches-py')
|
||||||
loader = DataLoader(dataset, batch_size=32, shuffle=False, sampler=DistributedSampler(dataset, ))
|
|
||||||
|
|
||||||
optimizer = optim.Adam(model.parameters(), lr=0.001)
|
optimizer = optim.Adam(model.parameters(), lr=0.001)
|
||||||
criterion = nn.CrossEntropyLoss()
|
criterion = nn.CrossEntropyLoss()
|
||||||
|
|
||||||
for epoch in range(50):
|
trainer = Trainer(rank, model, dataset, batch_size, optimizer, criterion)
|
||||||
model.train()
|
trainer.train(epoch_num)
|
||||||
train_loss_sum = 0
|
|
||||||
train_correct_sum = 0
|
|
||||||
for x, y in loader:
|
|
||||||
x = x.float()
|
|
||||||
x, y = x.to(rank), y.to(rank)
|
|
||||||
|
|
||||||
predict = model(x)
|
|
||||||
loss = criterion(predict, y)
|
|
||||||
loss.backward()
|
|
||||||
|
|
||||||
# evaluate
|
|
||||||
train_loss_sum += loss.item()
|
|
||||||
predicted_classes = torch.argmax(predict, dim=1)
|
|
||||||
train_correct_sum += (predicted_classes == y).sum()
|
|
||||||
|
|
||||||
optimizer.step()
|
|
||||||
optimizer.zero_grad()
|
|
||||||
print(f"[DEVICE {rank}] EPOCH {epoch} loss={train_loss_sum/len(loader)} acc={(train_correct_sum/len(dataset)).item()}")
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
world_size = torch.cuda.device_count()
|
|
||||||
mp.spawn(main, args=(world_size, ), nprocs=world_size)
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument('--batch_size', type=int, default=32, help="batch size for training")
|
||||||
|
parser.add_argument('--epoch_num', type=int, default=50, help="training epoch")
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
world_size = torch.cuda.device_count()
|
||||||
|
mp.spawn(main, args=(world_size, args.batch_size, args.epoch_num), nprocs=world_size)
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user