From bdaa1d4846fbeae2c6f480434892ba0186faa665 Mon Sep 17 00:00:00 2001 From: Ting-Jun Wang Date: Wed, 29 May 2024 04:25:15 +0800 Subject: [PATCH] fix: training log --- train/trainer.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/train/trainer.py b/train/trainer.py index 5a8913f..4d9c04c 100644 --- a/train/trainer.py +++ b/train/trainer.py @@ -31,6 +31,7 @@ class Trainer(): self.criterion = criterion def train(self, epoch_num): + print("Start traininig...") for epoch in range(epoch_num): self.model.train() train_loss_sum = 0 @@ -52,7 +53,7 @@ class Trainer(): self.optimizer.step() self.optimizer.zero_grad() - print(f"[DEVICE {self.global_rank}] EPOCH {epoch} loss={train_loss_sum/len(self.loader)} acc={(train_correct_sum/train_item_counter).item()}") + print(f"[RANK {self.global_rank}] EPOCH {epoch} loss={train_loss_sum/len(self.loader)} acc={(train_correct_sum/train_item_counter).item()}") def save(self, model_path): torch.save(self.model.state_dict(), model_path)