diff --git a/train/train.py b/train/train.py index 582acf8..fb3be73 100644 --- a/train/train.py +++ b/train/train.py @@ -19,7 +19,7 @@ def ddp_init(): class Trainer(): def __init__(self, model, dataset, batch_size, optimizer, criterion): self.local_rank = int(os.environ['LOCAL_RANK']) - self.global_rank = int(os.environ['GLOBAL_RANK']) + self.global_rank = int(os.environ['RANK']) self.model = model.to(self.local_rank) self.model = DDP(self.model, device_ids=[self.local_rank])