feat: torchrun on single machine

Author: TING-JUN WANG
Date:   2024-05-16 20:55:18 +08:00
Parent: 24240f1c3a
Commit: 233bec6d1c


@@ -13,17 +13,15 @@ from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.distributed import init_process_group, destroy_process_group
 import os
 
-def ddp_init(rank, world_size):
-    os.environ['MASTER_ADDR'] = 'localhost'
-    os.environ['MASTER_PORT'] = '21046'
-    init_process_group('nccl', rank=rank, world_size=world_size)
-    torch.cuda.set_device(rank)
+def ddp_init():
+    init_process_group('nccl')
+    torch.cuda.set_device(int(os.environ['LOCAL_RANK']))
 
 class Trainer():
-    def __init__(self, rank, model, dataset, batch_size, optimizer, criterion):
-        self.rank = rank
-        self.model = model.to(rank)
+    def __init__(self, model, dataset, batch_size, optimizer, criterion):
+        self.rank = int(os.environ['LOCAL_RANK'])
+        self.model = model.to(self.rank)
         self.model = DDP(self.model, device_ids=[self.rank])
         self.dataset = dataset
@@ -58,17 +56,23 @@ class Trainer():
-def main(rank, world_size, batch_size, epoch_num):
+def main(batch_size, epoch_num):
     print(f'training with {world_size} GPUs')
     print(f'training config: batch_size={batch_size}, epoch={epoch_num}')
-    ddp_init(rank, world_size)
-    model = Network().to(rank)
+    print(f'LOCAL_RANK={os.environ["LOCAL_RANK"]}')
+    print(f'RANK={os.environ["RANK"]}')
+    print(f'LOCAL_WORLD_SIZE={os.environ["LOCAL_WORLD_SIZE"]}')
+    print(f'WORLD_SIZE={os.environ["WORLD_SIZE"]}')
+    print(f'MASTER_ADDR={os.environ["MASTER_ADDR"]}')
+    print(f'MASTER_PORT={os.environ["MASTER_PORT"]}')
+    ddp_init()
+    model = Network()
     dataset = Cifar10Dataset('./dataset_dir/cifar-10-batches-py')
     optimizer = optim.Adam(model.parameters(), lr=0.001)
     criterion = nn.CrossEntropyLoss()
-    trainer = Trainer(rank, model, dataset, batch_size, optimizer, criterion)
+    trainer = Trainer(model, dataset, batch_size, optimizer, criterion)
     trainer.train(epoch_num)
@@ -80,5 +84,6 @@ if __name__ == '__main__':
     args = parser.parse_args()
     world_size = torch.cuda.device_count()
-    mp.spawn(main, args=(world_size, args.batch_size, args.epoch_num), nprocs=world_size)
+    main(args.batch_size, args.epoch_num)
+    # mp.spawn(main, args=(world_size, args.batch_size, args.epoch_num), nprocs=world_size)
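
Note on the new ddp_init(): when init_process_group('nccl') is called without an explicit rank or world_size, it falls back to the default env:// initialization method and reads RANK, WORLD_SIZE, MASTER_ADDR and MASTER_PORT from the environment. torchrun exports those variables (plus LOCAL_RANK and LOCAL_WORLD_SIZE) for every worker it starts, which is why the hard-coded MASTER_ADDR/MASTER_PORT lines could be dropped. A minimal standalone sketch of that setup, assuming the process is launched by torchrun on NCCL-capable GPUs (the final print is illustrative, not part of this commit):

    import os
    import torch
    from torch.distributed import init_process_group, destroy_process_group

    def ddp_init():
        # No rank/world_size arguments: the default env:// init method reads
        # RANK, WORLD_SIZE, MASTER_ADDR and MASTER_PORT exported by torchrun.
        init_process_group('nccl')
        # Bind this worker to its local GPU (one process per GPU).
        torch.cuda.set_device(int(os.environ['LOCAL_RANK']))

    if __name__ == '__main__':
        ddp_init()
        print(f"rank {os.environ['RANK']} of {os.environ['WORLD_SIZE']} "
              f"on GPU {os.environ['LOCAL_RANK']}")
        destroy_process_group()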
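
With mp.spawn() removed, the script now relies on torchrun to create the worker processes and export those variables. A single-machine launch example (the script name train.py and the GPU count are assumptions, not part of this commit; trailing arguments are whatever the script's argparse expects):

    torchrun --standalone --nproc_per_node=4 train.py <script args>

Here --standalone runs the rendezvous on localhost and picks MASTER_ADDR/MASTER_PORT automatically, and --nproc_per_node should match the number of GPUs reported by torch.cuda.device_count().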