fix: global rank

This commit is contained in:
TING-JUN WANG 2024-05-16 23:31:24 +08:00
parent d4b9aaa1d6
commit 86e0c50a65

View File

@ -19,7 +19,7 @@ def ddp_init():
class Trainer():
def __init__(self, model, dataset, batch_size, optimizer, criterion):
self.local_rank = int(os.environ['LOCAL_RANK'])
self.global_rank = int(os.environ['GLOBAL_RANK'])
self.global_rank = int(os.environ['RANK'])
self.model = model.to(self.local_rank)
self.model = DDP(self.model, device_ids=[self.local_rank])