feat: single machine DDP (test)

This commit is contained in:
Ting-Jun Wang 2024-05-16 15:59:20 +08:00
parent 939aa6d92e
commit aedc6b46e9
Signed by: snsd0805
GPG Key ID: 48D331A3D6160354

View File

@ -6,11 +6,26 @@ from model import Network
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
import torch.multiprocessing as mp
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import init_process_group, destroy_process_group
import os
def ddp_init(rank, world_size, master_addr='localhost', master_port='21046'):
    """Initialize the default torch.distributed process group for this worker.

    Generalized: the rendezvous address/port were hard-coded; they are now
    defaulted keyword parameters, so existing ``ddp_init(rank, world_size)``
    callers are unaffected.

    Args:
        rank: Index of this worker process (one per GPU).
        world_size: Total number of worker processes in the job.
        master_addr: Rendezvous host; 'localhost' for single-machine DDP.
        master_port: Rendezvous port. Kept as a string because the env://
            rendezvous reads it from the environment.
    """
    os.environ['MASTER_ADDR'] = master_addr
    os.environ['MASTER_PORT'] = master_port
    # 'nccl' is the recommended backend for CUDA tensors; this call blocks
    # until all world_size ranks have joined.
    init_process_group('nccl', rank=rank, world_size=world_size)
# NOTE(review): this span is a diff rendering — lines like `@ -21,7 +36,7 @@ ...`
# are hunk headers marking omitted lines, and some lines below are the
# *removed* pre-change versions shown interleaved with their replacements.
def main(rank, world_size):
# Per-worker entry point: rank identifies this process / its GPU.
ddp_init(rank, world_size)
model = Network()
# NOTE(review): two bugs here. (1) `device_ids` expects a list of device
# indices — this should be `device_ids=[rank]`. (2) the model is never moved
# to the GPU in the new code; it needs `model = Network().to(rank)` (the
# `.to(device)` line below is the OLD, removed line, not part of the result).
model = DDP(model, device_ids=rank)
model = Network().to(device)
dataset = Cifar10Dataset('./dataset_dir/cifar-10-batches-py')
# Removed line: plain shuffling DataLoader replaced by the sampler version below.
loader = DataLoader(dataset, batch_size=32, shuffle=True)
# shuffle=False because DistributedSampler partitions (and can shuffle) the
# dataset across ranks itself — shuffle=True would conflict with a sampler.
loader = DataLoader(dataset, batch_size=32, shuffle=False, sampler=DistributedSampler(dataset, ))
optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()
@ -21,7 +36,7 @@ for epoch in range(50):
train_correct_sum = 0
for x, y in loader:
x = x.float()
# Removed line: old single-device transfer; replaced by the per-rank one below.
x, y = x.to(device), y.to(device)
# Move the batch to this worker's GPU (device index == rank).
x, y = x.to(rank), y.to(rank)
predict = model(x)
loss = criterion(predict, y)
@ -37,3 +52,8 @@ for epoch in range(50):
# NOTE(review): every rank prints these stats; presumably only rank 0 should
# log, and len(dataset) counts the FULL dataset while each rank only sees its
# sampler shard — accuracy shown here will be deflated. Verify intent.
print(train_loss_sum / len(loader))
print((train_correct_sum / len(dataset)).item(),'%')
print()
if __name__ == '__main__':
    # One worker process per visible GPU.
    world_size = torch.cuda.device_count()
    # BUG FIX: mp.spawn defaults to nprocs=1, so only rank 0 would start and
    # init_process_group would hang forever waiting for the other ranks.
    # Spawn one process per GPU; spawn passes each process its rank as the
    # implicit first argument, followed by `args`.
    mp.spawn(main, args=(world_size,), nprocs=world_size)