feat: use separate local and global ranks for multi-machine training

Author: TING-JUN WANG
Date: 2024-05-16 23:17:35 +08:00
parent 874c160eae
commit d4b9aaa1d6


@@ -18,10 +18,11 @@ def ddp_init():
 class Trainer():
     def __init__(self, model, dataset, batch_size, optimizer, criterion):
-        self.rank = int(os.environ['LOCAL_RANK'])
+        self.local_rank = int(os.environ['LOCAL_RANK'])
+        self.global_rank = int(os.environ['GLOBAL_RANK'])
-        self.model = model.to(self.rank)
-        self.model = DDP(self.model, device_ids=[self.rank])
+        self.model = model.to(self.local_rank)
+        self.model = DDP(self.model, device_ids=[self.local_rank])
         self.dataset = dataset
         self.loader = DataLoader(self.dataset, batch_size, shuffle=False, sampler=DistributedSampler(self.dataset))
@@ -37,7 +38,7 @@ class Trainer():
         train_item_counter = 0
         for x, y in self.loader:
             x = x.float()
-            x, y = x.to(self.rank), y.to(self.rank)
+            x, y = x.to(self.local_rank), y.to(self.local_rank)
             predict = self.model(x)
             loss = self.criterion(predict, y)
@@ -51,7 +52,7 @@ class Trainer():
             self.optimizer.step()
             self.optimizer.zero_grad()
-        print(f"[DEVICE {self.rank}] EPOCH {epoch} loss={train_loss_sum/len(self.loader)} acc={(train_correct_sum/train_item_counter).item()}")
+        print(f"[DEVICE {self.global_rank}] EPOCH {epoch} loss={train_loss_sum/len(self.loader)} acc={(train_correct_sum/train_item_counter).item()}")