feat: torchrun on single machine
commit 233bec6d1c
parent 24240f1c3a
@@ -13,17 +13,15 @@ from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.distributed import init_process_group, destroy_process_group
 import os
 
-def ddp_init(rank, world_size):
-    os.environ['MASTER_ADDR'] = 'localhost'
-    os.environ['MASTER_PORT'] = '21046'
-    init_process_group('nccl', rank=rank, world_size=world_size)
-    torch.cuda.set_device(rank)
+def ddp_init():
+    init_process_group('nccl')
+    torch.cuda.set_device(int(os.environ['LOCAL_RANK']))
 
 class Trainer():
-    def __init__(self, rank, model, dataset, batch_size, optimizer, criterion):
-        self.rank = rank
+    def __init__(self, model, dataset, batch_size, optimizer, criterion):
+        self.rank = int(os.environ['LOCAL_RANK'])
 
-        self.model = model.to(rank)
+        self.model = model.to(self.rank)
         self.model = DDP(self.model, device_ids=[self.rank])
 
         self.dataset = dataset
@@ -58,17 +56,23 @@ class Trainer():
 
 
 
-def main(rank, world_size, batch_size, epoch_num):
+def main(batch_size, epoch_num):
     print(f'training with {world_size} GPUs')
     print(f'training config: batch_size={batch_size}, epoch={epoch_num}')
-    ddp_init(rank, world_size)
+    print(f'LOCAL_RANK={os.environ["LOCAL_RANK"]}')
+    print(f'RANK={os.environ["RANK"]}')
+    print(f'LOCAL_WORLD_SIZE={os.environ["LOCAL_WORLD_SIZE"]}')
+    print(f'WORLD_SIZE={os.environ["WORLD_SIZE"]}')
+    print(f'MASTER_ADDR={os.environ["MASTER_ADDR"]}')
+    print(f'MASTER_PORT={os.environ["MASTER_PORT"]}')
+    ddp_init()
 
-    model = Network().to(rank)
+    model = Network()
     dataset = Cifar10Dataset('./dataset_dir/cifar-10-batches-py')
     optimizer = optim.Adam(model.parameters(), lr=0.001)
     criterion = nn.CrossEntropyLoss()
 
-    trainer = Trainer(rank, model, dataset, batch_size, optimizer, criterion)
+    trainer = Trainer(model, dataset, batch_size, optimizer, criterion)
     trainer.train(epoch_num)
 
 
@@ -80,5 +84,6 @@ if __name__ == '__main__':
     args = parser.parse_args()
 
     world_size = torch.cuda.device_count()
-    mp.spawn(main, args=(world_size, args.batch_size, args.epoch_num), nprocs=world_size)
+    main(args.batch_size, args.epoch_num)
+    # mp.spawn(main, args=(world_size, args.batch_size, args.epoch_num), nprocs=world_size)
 
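With this change the script no longer spawns worker processes itself: it expects to be launched by torchrun, which starts one process per GPU and exports LOCAL_RANK, RANK, LOCAL_WORLD_SIZE, WORLD_SIZE, MASTER_ADDR and MASTER_PORT for init_process_group('nccl') and the Trainer to read. A minimal single-machine launch might look like the sketch below; the script name train.py and the --batch_size/--epoch_num flag spellings are assumptions, since the argparse setup is not part of this diff.

    # torchrun starts one worker per local GPU (--nproc_per_node=gpu) and fills in
    # the rendezvous environment variables that ddp_init() now relies on
    torchrun --standalone --nproc_per_node=gpu train.py --batch_size 64 --epoch_num 10

An explicit process count (e.g. --nproc_per_node=4) works as well if only some of the GPUs should be used.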
|