From d4b9aaa1d6107ad43e9846ad014fe942363566e8 Mon Sep 17 00:00:00 2001
From: TING-JUN WANG
Date: Thu, 16 May 2024 23:17:35 +0800
Subject: [PATCH] feat: use the global rank when training on multiple machines

---
 train/train.py | 11 ++++++-----
 1 file changed, 6 insertions(+), 5 deletions(-)

diff --git a/train/train.py b/train/train.py
index d8b409b..582acf8 100644
--- a/train/train.py
+++ b/train/train.py
@@ -18,10 +18,11 @@ def ddp_init():
 
 class Trainer():
     def __init__(self, model, dataset, batch_size, optimizer, criterion):
-        self.rank = int(os.environ['LOCAL_RANK'])
+        self.local_rank = int(os.environ['LOCAL_RANK'])
+        self.global_rank = int(os.environ['RANK'])
 
-        self.model = model.to(self.rank)
-        self.model = DDP(self.model, device_ids=[self.rank])
+        self.model = model.to(self.local_rank)
+        self.model = DDP(self.model, device_ids=[self.local_rank])
 
         self.dataset = dataset
         self.loader = DataLoader(self.dataset, batch_size, shuffle=False, sampler=DistributedSampler(self.dataset))
@@ -37,7 +38,7 @@ class Trainer():
         train_item_counter = 0
         for x, y in self.loader:
             x = x.float()
-            x, y = x.to(self.rank), y.to(self.rank)
+            x, y = x.to(self.local_rank), y.to(self.local_rank)
 
             predict = self.model(x)
             loss = self.criterion(predict, y)
@@ -51,7 +52,7 @@ class Trainer():
             self.optimizer.step()
             self.optimizer.zero_grad()
 
-        print(f"[DEVICE {self.rank}] EPOCH {epoch} loss={train_loss_sum/len(self.loader)} acc={(train_correct_sum/train_item_counter).item()}")
+        print(f"[RANK {self.global_rank}] EPOCH {epoch} loss={train_loss_sum/len(self.loader)} acc={(train_correct_sum/train_item_counter).item()}")
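
Notes: a minimal, self-contained sketch of the two-rank scheme this patch
adopts, assuming the script is launched with torchrun, which exports
LOCAL_RANK (GPU index within one node) and RANK (process index across the
whole job); torchrun defines no GLOBAL_RANK variable, hence the use of RANK
above. The toy Linear model, the nccl backend, and the main() wrapper are
illustrative stand-ins, not the repository's actual ddp_init() and Trainer
code.

    import os

    import torch
    import torch.distributed as dist
    from torch.nn.parallel import DistributedDataParallel as DDP

    def main():
        # torchrun has already spawned one process per GPU on every node
        # and exported MASTER_ADDR/MASTER_PORT, so the env:// rendezvous
        # needs no explicit addresses here.
        dist.init_process_group(backend="nccl")

        # LOCAL_RANK indexes GPUs within one machine: use it for device
        # placement, since each node numbers its GPUs from 0.
        local_rank = int(os.environ["LOCAL_RANK"])
        # RANK is unique across all machines: use it for logging and for
        # rank-0-only work such as checkpointing.
        global_rank = int(os.environ["RANK"])

        torch.cuda.set_device(local_rank)
        model = torch.nn.Linear(10, 1).to(local_rank)  # placeholder model
        model = DDP(model, device_ids=[local_rank])

        if global_rank == 0:
            print(f"world size: {dist.get_world_size()}")

        dist.destroy_process_group()

    if __name__ == "__main__":
        main()

On each of two 4-GPU machines this would be launched with something like

    torchrun --nnodes=2 --nproc_per_node=4 --node_rank=<n> \
        --master_addr=<addr> --master_port=29500 train.py

so LOCAL_RANK runs 0-3 on both machines while RANK runs 0-7 across the job,
which is why the per-epoch log line in the patch tags output with the global
rank instead of the local one.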