From 704a5bbf3e06386748bdb8ec959b973b2ac8b1d5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xingchen=20Song=28=E5=AE=8B=E6=98=9F=E8=BE=B0=29?= Date: Sun, 12 May 2024 10:35:41 +0800 Subject: [PATCH] [utils] update precision of speed metric (#2524) when `accum_grad` is large, say 16, `steps/sec` might be less than 0.1 --- wenet/utils/train_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/wenet/utils/train_utils.py b/wenet/utils/train_utils.py index c16e085ea..cdf6da2b3 100644 --- a/wenet/utils/train_utils.py +++ b/wenet/utils/train_utils.py @@ -788,7 +788,7 @@ def log_per_step(writer, info_dict, timer: Optional[StepTimer] = None): if info_dict.get("cv_step", None) is not None: timer_step = info_dict['cv_step'] steps_per_second = timer.steps_per_second(timer_step) - log_str += 'steps/sec {:.1f}| '.format(steps_per_second) + log_str += 'steps/sec {:.3f}| '.format(steps_per_second) log_str += 'Batch {}/{} loss {:.6f} '.format( epoch, batch_idx + 1 if 'save_interval' not in info_dict else (step + 1) * accum_grad,