feat: time evaluation in trainer

This commit is contained in:
Ting-Jun Wang 2024-05-29 04:47:50 +08:00
parent e5c97ca8a1
commit 62e86d5e8b
Signed by: snsd0805
GPG Key ID: 48D331A3D6160354

View File

@ -11,6 +11,7 @@ from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import init_process_group, destroy_process_group from torch.distributed import init_process_group, destroy_process_group
import os import os
import time
def ddp_init(): def ddp_init():
init_process_group('nccl') init_process_group('nccl')
@ -32,6 +33,7 @@ class Trainer():
def train(self, epoch_num): def train(self, epoch_num):
print("Start traininig...") print("Start traininig...")
start_time = time.time()
for epoch in range(epoch_num): for epoch in range(epoch_num):
self.model.train() self.model.train()
train_loss_sum = 0 train_loss_sum = 0
@ -54,6 +56,8 @@ class Trainer():
self.optimizer.step() self.optimizer.step()
self.optimizer.zero_grad() self.optimizer.zero_grad()
print(f"[RANK {self.global_rank}] EPOCH {epoch} loss={train_loss_sum/len(self.loader)} acc={(train_correct_sum/train_item_counter).item()}") print(f"[RANK {self.global_rank}] EPOCH {epoch} loss={train_loss_sum/len(self.loader)} acc={(train_correct_sum/train_item_counter).item()}")
stop_time = time.time()
print(f"Training need {stop_time-start_time}s")
def save(self, model_path): def save(self, model_path):
torch.save(self.model.state_dict(), model_path) torch.save(self.model.state_dict(), model_path)