feat: time evaluation in trainer
parent e5c97ca8a1
commit 62e86d5e8b
@@ -11,6 +11,7 @@ from torch.utils.data.distributed import DistributedSampler
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.distributed import init_process_group, destroy_process_group
 import os
+import time
 
 def ddp_init():
     init_process_group('nccl')
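For context, `init_process_group('nccl')` is the per-process entry point the rest of the trainer relies on. Below is a minimal sketch of the usual setup/teardown around it, assuming a `torchrun` launch (which provides `RANK`, `LOCAL_RANK`, and `WORLD_SIZE`) and a `ddp_cleanup()` helper that is not part of this diff:

# Sketch only: typical DDP lifecycle around ddp_init(); everything beyond
# init_process_group('nccl') is an assumption, not shown in this commit.
import os
import torch
from torch.distributed import init_process_group, destroy_process_group

def ddp_init():
    init_process_group('nccl')                             # NCCL backend for multi-GPU training
    torch.cuda.set_device(int(os.environ['LOCAL_RANK']))   # pin this process to its own GPU

def ddp_cleanup():
    destroy_process_group()                                # tear the process group down at exit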
@@ -32,6 +33,7 @@ class Trainer():
 
     def train(self, epoch_num):
         print("Start traininig...")
+        start_time = time.time()
         for epoch in range(epoch_num):
             self.model.train()
             train_loss_sum = 0
@@ -54,6 +56,8 @@ class Trainer():
                 self.optimizer.step()
                 self.optimizer.zero_grad()
             print(f"[RANK {self.global_rank}] EPOCH {epoch} loss={train_loss_sum/len(self.loader)} acc={(train_correct_sum/train_item_counter).item()}")
+        stop_time = time.time()
+        print(f"Training need {stop_time-start_time}s")
 
     def save(self, model_path):
         torch.save(self.model.state_dict(), model_path)
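The change brackets the whole epoch loop with `time.time()` calls and prints the elapsed wall-clock seconds once after training finishes. Here is a standalone sketch of the same pattern, with two variations that are assumptions rather than part of this commit: `time.perf_counter()` as a monotonic clock, and an optional `torch.cuda.synchronize()` so queued GPU kernels are included in the measurement.

# Sketch only: the wall-clock timing pattern this commit adds, shown outside the
# Trainer class. perf_counter/synchronize are suggested variations, not the diff itself.
import time
import torch
import torch.nn.functional as F

def timed_training_loop(model, loader, optimizer, epoch_num, device='cuda'):
    start_time = time.perf_counter()            # monotonic; the commit uses time.time()
    for epoch in range(epoch_num):
        model.train()
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            loss = F.cross_entropy(model(x), y)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
    if torch.cuda.is_available():
        torch.cuda.synchronize()                # count GPU work still in flight
    stop_time = time.perf_counter()
    print(f"Training need {stop_time - start_time}s")

Under DDP every rank prints its own elapsed time; if a single number per run is preferred, the final print could be guarded with something like `if self.global_rank == 0:`.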