Commit 9a662cce by xlwang

add elapsed time logging

parent 8473c700
...@@ -2,7 +2,7 @@ import numpy as np ...@@ -2,7 +2,7 @@ import numpy as np
import torch import torch
from base import BaseTrainer from base import BaseTrainer
import math import math
# from lib.utils import inf_loop import time
class DCRNNTrainer(BaseTrainer): class DCRNNTrainer(BaseTrainer):
...@@ -49,7 +49,7 @@ class DCRNNTrainer(BaseTrainer): ...@@ -49,7 +49,7 @@ class DCRNNTrainer(BaseTrainer):
The metrics in log must have the key 'metrics'. The metrics in log must have the key 'metrics'.
""" """
self.model.train() self.model.train()
start_time = time.time()
total_loss = 0 total_loss = 0
total_metrics = np.zeros(len(self.metrics)) total_metrics = np.zeros(len(self.metrics))
for batch_idx, (data, target) in enumerate(self.data_loader.get_iterator()): for batch_idx, (data, target) in enumerate(self.data_loader.get_iterator()):
...@@ -90,7 +90,7 @@ class DCRNNTrainer(BaseTrainer): ...@@ -90,7 +90,7 @@ class DCRNNTrainer(BaseTrainer):
log = { log = {
'loss': total_loss / self.len_epoch, 'loss': total_loss / self.len_epoch,
'metrics': (total_metrics / self.len_epoch).tolist() 'metrics': (total_metrics / self.len_epoch).tolist(),
} }
if self.do_validation: if self.do_validation:
...@@ -99,7 +99,7 @@ class DCRNNTrainer(BaseTrainer): ...@@ -99,7 +99,7 @@ class DCRNNTrainer(BaseTrainer):
if self.lr_scheduler is not None: if self.lr_scheduler is not None:
self.lr_scheduler.step() self.lr_scheduler.step()
log.update({'Time': "{:.4f}s".format(time.time()-start_time)})
return log return log
def _valid_epoch(self, epoch): def _valid_epoch(self, epoch):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment