feat: single machine DDP (test)

This commit is contained in:
Ting-Jun Wang 2024-05-16 15:59:20 +08:00
parent 939aa6d92e
commit aedc6b46e9
Signed by: snsd0805
GPG Key ID: 48D331A3D6160354

View File

@ -6,34 +6,54 @@ from model import Network
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
import torch.multiprocessing as mp
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import init_process_group, destroy_process_group
import os
def ddp_init(rank, world_size):
    """Join the single-machine NCCL process group as worker ``rank``.

    Args:
        rank: Index of this worker process; it is also used as the CUDA
            device index the process is bound to.
        world_size: Total number of worker processes in the group.
    """
    # All workers run on this machine, so the rendezvous address is fixed.
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '21046'
    # Bind this process to its own GPU before creating the NCCL
    # communicator; otherwise every rank defaults to device 0.
    torch.cuda.set_device(rank)
    init_process_group('nccl', rank=rank, world_size=world_size)
def main(rank, world_size):
    """Run the training loop as one DDP worker.

    Args:
        rank: Index of this worker process; also used as the CUDA device
            index for the model and the batches.
        world_size: Total number of worker processes.
    """
    ddp_init(rank, world_size)
    try:
        # DDP needs the module to already live on the target device, and
        # device_ids must be a list rather than a bare int.
        model = Network().to(rank)
        model = DDP(model, device_ids=[rank])

        dataset = Cifar10Dataset('./dataset_dir/cifar-10-batches-py')
        # The sampler does the shuffling and hands each rank its own shard,
        # so the DataLoader itself must not shuffle.
        sampler = DistributedSampler(dataset)
        loader = DataLoader(dataset, batch_size=32, shuffle=False, sampler=sampler)

        optimizer = optim.Adam(model.parameters(), lr=0.001)
        criterion = nn.CrossEntropyLoss()

        for epoch in range(50):
            # Re-seed the sampler so each epoch sees a different ordering.
            sampler.set_epoch(epoch)
            model.train()
            train_loss_sum = 0
            train_correct_sum = 0
            for x, y in loader:
                x = x.float()
                x, y = x.to(rank), y.to(rank)

                predict = model(x)
                loss = criterion(predict, y)
                loss.backward()

                # evaluate
                train_loss_sum += loss.item()
                predicted_classes = torch.argmax(predict, dim=1)
                train_correct_sum += (predicted_classes == y).sum().item()

                optimizer.step()
                optimizer.zero_grad()

            print(train_loss_sum / len(loader))
            # Each rank only sees len(sampler) samples, so that is the
            # denominator; scale to match the '%' label. These are per-rank
            # statistics, not an average over all workers.
            print(100 * train_correct_sum / len(sampler), '%')
            print()
    finally:
        # Release the NCCL communicator even if training raised.
        destroy_process_group()
if __name__ == '__main__':
    # One worker process per visible GPU.
    world_size = torch.cuda.device_count()
    if world_size < 1:
        raise SystemExit('No CUDA devices found; NCCL-based DDP training needs at least one GPU.')
    # mp.spawn passes the process index as the first argument to main and
    # starts only a single process unless nprocs is given; without it the
    # rendezvous would wait forever for the missing ranks.
    mp.spawn(main, args=(world_size,), nprocs=world_size)