feat: distinguish local rank and global rank for multi-machine training
This commit is contained in:
parent
874c160eae
commit
d4b9aaa1d6
@ -18,10 +18,11 @@ def ddp_init():
|
|||||||
|
|
||||||
class Trainer():
|
class Trainer():
|
||||||
def __init__(self, model, dataset, batch_size, optimizer, criterion):
|
def __init__(self, model, dataset, batch_size, optimizer, criterion):
|
||||||
self.rank = int(os.environ['LOCAL_RANK'])
|
self.local_rank = int(os.environ['LOCAL_RANK'])
|
||||||
|
self.global_rank = int(os.environ['GLOBAL_RANK'])
|
||||||
|
|
||||||
self.model = model.to(self.rank)
|
self.model = model.to(self.local_rank)
|
||||||
self.model = DDP(self.model, device_ids=[self.rank])
|
self.model = DDP(self.model, device_ids=[self.local_rank])
|
||||||
|
|
||||||
self.dataset = dataset
|
self.dataset = dataset
|
||||||
self.loader = DataLoader(self.dataset, batch_size, shuffle=False, sampler=DistributedSampler(self.dataset))
|
self.loader = DataLoader(self.dataset, batch_size, shuffle=False, sampler=DistributedSampler(self.dataset))
|
||||||
@ -37,7 +38,7 @@ class Trainer():
|
|||||||
train_item_counter = 0
|
train_item_counter = 0
|
||||||
for x, y in self.loader:
|
for x, y in self.loader:
|
||||||
x = x.float()
|
x = x.float()
|
||||||
x, y = x.to(self.rank), y.to(self.rank)
|
x, y = x.to(self.local_rank), y.to(self.local_rank)
|
||||||
|
|
||||||
predict = self.model(x)
|
predict = self.model(x)
|
||||||
loss = self.criterion(predict, y)
|
loss = self.criterion(predict, y)
|
||||||
@ -51,7 +52,7 @@ class Trainer():
|
|||||||
|
|
||||||
self.optimizer.step()
|
self.optimizer.step()
|
||||||
self.optimizer.zero_grad()
|
self.optimizer.zero_grad()
|
||||||
print(f"[DEVICE {self.rank}] EPOCH {epoch} loss={train_loss_sum/len(self.loader)} acc={(train_correct_sum/train_item_counter).item()}")
|
print(f"[DEVICE {self.global_rank}] EPOCH {epoch} loss={train_loss_sum/len(self.loader)} acc={(train_correct_sum/train_item_counter).item()}")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user