feat: torchrun on single machine
commit 233bec6d1c (parent 24240f1c3a)
@@ -13,17 +13,15 @@ from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.distributed import init_process_group, destroy_process_group
 import os
 
-def ddp_init(rank, world_size):
-    os.environ['MASTER_ADDR'] = 'localhost'
-    os.environ['MASTER_PORT'] = '21046'
-    init_process_group('nccl', rank=rank, world_size=world_size)
-    torch.cuda.set_device(rank)
+def ddp_init():
+    init_process_group('nccl')
+    torch.cuda.set_device(int(os.environ['LOCAL_RANK']))
 
 class Trainer():
-    def __init__(self, rank, model, dataset, batch_size, optimizer, criterion):
-        self.rank = rank
+    def __init__(self, model, dataset, batch_size, optimizer, criterion):
+        self.rank = int(os.environ['LOCAL_RANK'])
 
-        self.model = model.to(rank)
+        self.model = model.to(self.rank)
         self.model = DDP(self.model, device_ids=[self.rank])
 
         self.dataset = dataset
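Note: the new ddp_init() can drop the explicit rank and world_size because torchrun exports MASTER_ADDR, MASTER_PORT, RANK, and WORLD_SIZE before the script starts, and init_process_group's default env:// init method reads them from the environment. A minimal standalone sketch of the same pattern (assumes a CUDA machine and a torchrun launch):

    import os
    import torch
    from torch.distributed import init_process_group, destroy_process_group

    def ddp_init():
        # torchrun sets MASTER_ADDR/MASTER_PORT/RANK/WORLD_SIZE, so the
        # default env:// init method needs no explicit arguments
        init_process_group('nccl')
        # LOCAL_RANK identifies this worker's GPU on the local node
        torch.cuda.set_device(int(os.environ['LOCAL_RANK']))

    if __name__ == '__main__':
        ddp_init()
        print(f"rank {os.environ['RANK']} of {os.environ['WORLD_SIZE']} initialized")
        destroy_process_group()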
@@ -58,17 +56,23 @@ class Trainer():
 
 
 
-def main(rank, world_size, batch_size, epoch_num):
-    print(f'training with {world_size} GPUs')
-    ddp_init(rank, world_size)
+def main(batch_size, epoch_num):
+    print(f'training config: batch_size={batch_size}, epoch={epoch_num}')
+    print(f'LOCAL_RANK={os.environ["LOCAL_RANK"]}')
+    print(f'RANK={os.environ["RANK"]}')
+    print(f'LOCAL_WORLD_SIZE={os.environ["LOCAL_WORLD_SIZE"]}')
+    print(f'WORLD_SIZE={os.environ["WORLD_SIZE"]}')
+    print(f'MASTER_ADDR={os.environ["MASTER_ADDR"]}')
+    print(f'MASTER_PORT={os.environ["MASTER_PORT"]}')
+    ddp_init()
 
-    model = Network().to(rank)
+    model = Network()
     dataset = Cifar10Dataset('./dataset_dir/cifar-10-batches-py')
     optimizer = optim.Adam(model.parameters(), lr=0.001)
     criterion = nn.CrossEntropyLoss()
 
-    trainer = Trainer(rank, model, dataset, batch_size, optimizer, criterion)
+    trainer = Trainer(model, dataset, batch_size, optimizer, criterion)
     trainer.train(epoch_num)
 
 
 
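For reference, on a single machine LOCAL_RANK equals RANK and LOCAL_WORLD_SIZE equals WORLD_SIZE; they diverge only in multi-node runs, where RANK and WORLD_SIZE are global across nodes. With four local GPUs, the worker on GPU 1 would print something like the following (the address and port shown are hypothetical; torchrun assigns them at rendezvous):

    LOCAL_RANK=1
    RANK=1
    LOCAL_WORLD_SIZE=4
    WORLD_SIZE=4
    MASTER_ADDR=127.0.0.1
    MASTER_PORT=29500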
@@ -80,5 +84,6 @@ if __name__ == '__main__':
     args = parser.parse_args()
 
-    world_size = torch.cuda.device_count()
-    mp.spawn(main, args=(world_size, args.batch_size, args.epoch_num), nprocs=world_size)
+    main(args.batch_size, args.epoch_num)
+    # mp.spawn(main, args=(world_size, args.batch_size, args.epoch_num), nprocs=world_size)
 
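With mp.spawn removed, the torchrun launcher is responsible for starting one process per GPU. A single-machine launch would look something like this (train.py stands in for this script's actual filename, the flag names are inferred from args.batch_size/args.epoch_num, and the values are placeholders):

    torchrun --standalone --nproc_per_node=gpu train.py --batch_size 128 --epoch_num 10

--standalone runs a self-contained single-node rendezvous, and --nproc_per_node=gpu starts one worker per visible GPU.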