From 62e86d5e8b8fb1de1808103837f73b3e7bf1fda2 Mon Sep 17 00:00:00 2001
From: Ting-Jun Wang
Date: Wed, 29 May 2024 04:47:50 +0800
Subject: [PATCH] feat: time evaluation in trainer

---
 train/trainer.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/train/trainer.py b/train/trainer.py
index 4d9c04c..06af416 100644
--- a/train/trainer.py
+++ b/train/trainer.py
@@ -11,6 +11,7 @@ from torch.utils.data.distributed import DistributedSampler
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.distributed import init_process_group, destroy_process_group
 import os
+import time
 
 def ddp_init():
     init_process_group('nccl')
@@ -32,6 +33,7 @@ class Trainer():
 
     def train(self, epoch_num):
         print("Start training...")
+        start_time = time.time()
         for epoch in range(epoch_num):
             self.model.train()
             train_loss_sum = 0
@@ -54,6 +56,8 @@ class Trainer():
                 self.optimizer.step()
                 self.optimizer.zero_grad()
             print(f"[RANK {self.global_rank}] EPOCH {epoch} loss={train_loss_sum/len(self.loader)} acc={(train_correct_sum/train_item_counter).item()}")
+        stop_time = time.time()
+        print(f"Training took {stop_time - start_time:.2f}s")
 
     def save(self, model_path):
         torch.save(self.model.state_dict(), model_path)
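
A note on the timing added here: time.time() can stop the clock while CUDA
kernels are still in flight, and under DDP the summary line is printed once
per rank. Below is a minimal sketch of a tighter variant, assuming the Trainer
keeps self.global_rank as in this patch; the rank0_timer helper is
illustrative, not part of the patch.

import time
from contextlib import contextmanager

import torch

@contextmanager
def rank0_timer(label, global_rank):
    # Hypothetical helper: time a block and report only on rank 0,
    # so a DDP run does not emit one timing line per process.
    if torch.cuda.is_available():
        torch.cuda.synchronize()  # flush queued kernels before starting the clock
    start = time.time()
    try:
        yield
    finally:
        if torch.cuda.is_available():
            torch.cuda.synchronize()  # wait for in-flight CUDA work before stopping
        if global_rank == 0:
            print(f"{label} took {time.time() - start:.2f}s")

Usage inside Trainer.train would then wrap the epoch loop:

    with rank0_timer("Training", self.global_rank):
        for epoch in range(epoch_num):
            ...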