fix: training log
This commit is contained in:
parent
17949bd1a6
commit
bdaa1d4846
@ -31,6 +31,7 @@ class Trainer():
|
||||
self.criterion = criterion
|
||||
|
||||
def train(self, epoch_num):
|
||||
print("Start traininig...")
|
||||
for epoch in range(epoch_num):
|
||||
self.model.train()
|
||||
train_loss_sum = 0
|
||||
@ -52,7 +53,7 @@ class Trainer():
|
||||
|
||||
self.optimizer.step()
|
||||
self.optimizer.zero_grad()
|
||||
print(f"[DEVICE {self.global_rank}] EPOCH {epoch} loss={train_loss_sum/len(self.loader)} acc={(train_correct_sum/train_item_counter).item()}")
|
||||
print(f"[RANK {self.global_rank}] EPOCH {epoch} loss={train_loss_sum/len(self.loader)} acc={(train_correct_sum/train_item_counter).item()}")
|
||||
|
||||
def save(self, model_path):
|
||||
torch.save(self.model.state_dict(), model_path)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user