
Commit eb132da: additional instruction for the grad_scale is too small error (k2-fsa#1550)

1 parent: 15bd9a8

File tree: 49 files changed, +145 −147 lines changed

(Large commits have some content hidden by default; only a subset of the 49 changed files is shown below.)
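
The hunks below all make the same swap: an inline RuntimeError is replaced by a shared helper, raise_grad_scale_is_too_small_error, imported from the new module icefall.err. That module is among the 49 changed files but is hidden on this page, so here is a minimal hypothetical sketch of it, inferred from the call sites (one caller passes cur_grad_scale, another passes nothing, so the parameter needs a default) and from the commit title. The default value and the wording of the extra guidance are assumptions, not the actual file contents.

# icefall/err.py -- hypothetical sketch only; the real file is not shown on this page.
def raise_grad_scale_is_too_small_error(cur_grad_scale: float = -1.0) -> None:
    # Assumption: the helper keeps the old message and appends the
    # "additional instruction" referenced in the commit title.
    raise RuntimeError(
        f"grad_scale is too small, exiting: {cur_grad_scale}. "
        "This usually means the gradients have become inf/nan; "
        "consider lowering the learning rate, inspecting the recent "
        "batches, or training without --use-fp16."
    )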

egs/aishell/ASR/pruned_transducer_stateless7/do_not_use_it_directly.py (+2 −3)

@@ -89,6 +89,7 @@
 )
 from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
+from icefall.err import raise_grad_scale_is_too_small_error
 from icefall.lexicon import Lexicon
 from icefall.utils import (
     AttributeDict,
@@ -881,9 +882,7 @@ def train_one_epoch(
             if cur_grad_scale < 0.01:
                 logging.warning(f"Grad scale is small: {cur_grad_scale}")
             if cur_grad_scale < 1.0e-05:
-                raise RuntimeError(
-                    f"grad_scale is too small, exiting: {cur_grad_scale}"
-                )
+                raise_grad_scale_is_too_small_error()
         if batch_idx % params.log_interval == 0:
             cur_lr = scheduler.get_last_lr()[0]
             cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0

egs/aishell/ASR/pruned_transducer_stateless7/train.py (+2 −3)

@@ -85,6 +85,7 @@
 )
 from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
+from icefall.err import raise_grad_scale_is_too_small_error
 from icefall.hooks import register_inf_check_hooks
 from icefall.lexicon import Lexicon
 from icefall.utils import (
@@ -878,9 +879,7 @@ def train_one_epoch(
             if cur_grad_scale < 0.01:
                 logging.warning(f"Grad scale is small: {cur_grad_scale}")
             if cur_grad_scale < 1.0e-05:
-                raise RuntimeError(
-                    f"grad_scale is too small, exiting: {cur_grad_scale}"
-                )
+                raise_grad_scale_is_too_small_error(cur_grad_scale)
         if batch_idx % params.log_interval == 0:
             cur_lr = scheduler.get_last_lr()[0]
             cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0

egs/aishell/ASR/pruned_transducer_stateless7_bbpe/train.py (+2 −3)

@@ -78,6 +78,7 @@
 )
 from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
+from icefall.err import raise_grad_scale_is_too_small_error
 from icefall.hooks import register_inf_check_hooks
 from icefall.utils import (
     AttributeDict,
@@ -871,9 +872,7 @@ def train_one_epoch(
             if cur_grad_scale < 0.01:
                 logging.warning(f"Grad scale is small: {cur_grad_scale}")
             if cur_grad_scale < 1.0e-05:
-                raise RuntimeError(
-                    f"grad_scale is too small, exiting: {cur_grad_scale}"
-                )
+                raise_grad_scale_is_too_small_error(cur_grad_scale)

         if batch_idx % params.log_interval == 0:
             cur_lr = scheduler.get_last_lr()[0]

egs/aishell/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py (+2 −3)

@@ -78,6 +78,7 @@
 )
 from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
+from icefall.err import raise_grad_scale_is_too_small_error
 from icefall.hooks import register_inf_check_hooks
 from icefall.lexicon import Lexicon
 from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
@@ -882,9 +883,7 @@ def train_one_epoch(
             if cur_grad_scale < 0.01:
                 logging.warning(f"Grad scale is small: {cur_grad_scale}")
             if cur_grad_scale < 1.0e-05:
-                raise RuntimeError(
-                    f"grad_scale is too small, exiting: {cur_grad_scale}"
-                )
+                raise_grad_scale_is_too_small_error(cur_grad_scale)

         if batch_idx % params.log_interval == 0:
             cur_lr = scheduler.get_last_lr()[0]

egs/aishell/ASR/pruned_transducer_stateless7_streaming/train.py (+2 −3)

@@ -78,6 +78,7 @@
 )
 from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
+from icefall.err import raise_grad_scale_is_too_small_error
 from icefall.hooks import register_inf_check_hooks
 from icefall.lexicon import Lexicon
 from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
@@ -881,9 +882,7 @@ def train_one_epoch(
             if cur_grad_scale < 0.01:
                 logging.warning(f"Grad scale is small: {cur_grad_scale}")
             if cur_grad_scale < 1.0e-05:
-                raise RuntimeError(
-                    f"grad_scale is too small, exiting: {cur_grad_scale}"
-                )
+                raise_grad_scale_is_too_small_error(cur_grad_scale)

         if batch_idx % params.log_interval == 0:
             cur_lr = scheduler.get_last_lr()[0]

egs/aishell/ASR/zipformer/train.py (+2 −3)

@@ -86,6 +86,7 @@
 )
 from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
+from icefall.err import raise_grad_scale_is_too_small_error
 from icefall.hooks import register_inf_check_hooks
 from icefall.lexicon import Lexicon
 from icefall.utils import (
@@ -985,9 +986,7 @@ def save_bad_model(suffix: str = ""):
                 logging.warning(f"Grad scale is small: {cur_grad_scale}")
             if cur_grad_scale < 1.0e-05:
                 save_bad_model()
-                raise RuntimeError(
-                    f"grad_scale is too small, exiting: {cur_grad_scale}"
-                )
+                raise_grad_scale_is_too_small_error(cur_grad_scale)

         if batch_idx % params.log_interval == 0:
             cur_lr = max(scheduler.get_last_lr())

egs/aishell/ASR/zipformer/train_bbpe.py (+2 −3)

@@ -83,6 +83,7 @@
     update_averaged_model,
 )
 from icefall.dist import cleanup_dist, setup_dist
+from icefall.err import raise_grad_scale_is_too_small_error
 from icefall.hooks import register_inf_check_hooks
 from icefall.utils import (
     AttributeDict,
@@ -570,9 +571,7 @@ def save_bad_model(suffix: str = ""):
                 logging.warning(f"Grad scale is small: {cur_grad_scale}")
             if cur_grad_scale < 1.0e-05:
                 save_bad_model()
-                raise RuntimeError(
-                    f"grad_scale is too small, exiting: {cur_grad_scale}"
-                )
+                raise_grad_scale_is_too_small_error(cur_grad_scale)

         if batch_idx % params.log_interval == 0:
             cur_lr = max(scheduler.get_last_lr())

egs/alimeeting/ASR_v2/pruned_transducer_stateless7/train.py (+2 −3)

@@ -70,6 +70,7 @@
 )
 from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
+from icefall.err import raise_grad_scale_is_too_small_error
 from icefall.hooks import register_inf_check_hooks
 from icefall.lexicon import Lexicon
 from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
@@ -851,9 +852,7 @@ def train_one_epoch(
             if cur_grad_scale < 0.01:
                 logging.warning(f"Grad scale is small: {cur_grad_scale}")
             if cur_grad_scale < 1.0e-05:
-                raise RuntimeError(
-                    f"grad_scale is too small, exiting: {cur_grad_scale}"
-                )
+                raise_grad_scale_is_too_small_error(cur_grad_scale)

         if batch_idx % params.log_interval == 0:
             cur_lr = scheduler.get_last_lr()[0]

egs/ami/ASR/pruned_transducer_stateless7/train.py (+2 −3)

@@ -69,6 +69,7 @@
 )
 from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
+from icefall.err import raise_grad_scale_is_too_small_error
 from icefall.hooks import register_inf_check_hooks
 from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool

@@ -842,9 +843,7 @@ def train_one_epoch(
             if cur_grad_scale < 0.01:
                 logging.warning(f"Grad scale is small: {cur_grad_scale}")
             if cur_grad_scale < 1.0e-05:
-                raise RuntimeError(
-                    f"grad_scale is too small, exiting: {cur_grad_scale}"
-                )
+                raise_grad_scale_is_too_small_error(cur_grad_scale)

         if batch_idx % params.log_interval == 0:
             cur_lr = scheduler.get_last_lr()[0]

egs/ami/SURT/dprnn_zipformer/train.py (+2 −3)

@@ -75,6 +75,7 @@
 )
 from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
+from icefall.err import raise_grad_scale_is_too_small_error
 from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool

 LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@@ -1138,9 +1139,7 @@ def train_one_epoch(
             if cur_grad_scale < 0.01:
                 logging.warning(f"Grad scale is small: {cur_grad_scale}")
             if cur_grad_scale < 1.0e-05:
-                raise RuntimeError(
-                    f"grad_scale is too small, exiting: {cur_grad_scale}"
-                )
+                raise_grad_scale_is_too_small_error(cur_grad_scale)

         if batch_idx % params.log_interval == 0:
             cur_lr = scheduler.get_last_lr()[0]

egs/ami/SURT/dprnn_zipformer/train_adapt.py (+2 −3)

@@ -75,6 +75,7 @@
 )
 from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
+from icefall.err import raise_grad_scale_is_too_small_error
 from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool

 LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@@ -1129,9 +1130,7 @@ def train_one_epoch(
             if cur_grad_scale < 0.01:
                 logging.warning(f"Grad scale is small: {cur_grad_scale}")
             if cur_grad_scale < 1.0e-05:
-                raise RuntimeError(
-                    f"grad_scale is too small, exiting: {cur_grad_scale}"
-                )
+                raise_grad_scale_is_too_small_error(cur_grad_scale)

         if batch_idx % params.log_interval == 0:
             cur_lr = scheduler.get_last_lr()[0]

egs/commonvoice/ASR/pruned_transducer_stateless7/train.py (+2 −3)

@@ -79,6 +79,7 @@
 )
 from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
+from icefall.err import raise_grad_scale_is_too_small_error
 from icefall.hooks import register_inf_check_hooks
 from icefall.utils import (
     AttributeDict,
@@ -871,9 +872,7 @@ def train_one_epoch(
             if cur_grad_scale < 0.01:
                 logging.warning(f"Grad scale is small: {cur_grad_scale}")
             if cur_grad_scale < 1.0e-05:
-                raise RuntimeError(
-                    f"grad_scale is too small, exiting: {cur_grad_scale}"
-                )
+                raise_grad_scale_is_too_small_error(cur_grad_scale)

         if batch_idx % params.log_interval == 0:
             cur_lr = scheduler.get_last_lr()[0]

egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py (+1 −3)

(Note: this file alone does not switch to the new helper; the RuntimeError is kept, but its message loses the "grad_scale is too small" prefix, likely a side effect of the mechanical replacement.)

@@ -889,9 +889,7 @@ def train_one_epoch(
             if cur_grad_scale < 0.01:
                 logging.warning(f"Grad scale is small: {cur_grad_scale}")
             if cur_grad_scale < 1.0e-05:
-                raise RuntimeError(
-                    f"grad_scale is too small, exiting: {cur_grad_scale}"
-                )
+                raise RuntimeError(f", exiting: {cur_grad_scale}")

         if batch_idx % params.log_interval == 0:
             cur_lr = scheduler.get_last_lr()[0]

egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/finetune.py (+2 −3)

@@ -81,6 +81,7 @@
 )
 from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
+from icefall.err import raise_grad_scale_is_too_small_error
 from icefall.hooks import register_inf_check_hooks
 from icefall.utils import (
     AttributeDict,
@@ -965,9 +966,7 @@ def train_one_epoch(
             if cur_grad_scale < 0.01:
                 logging.warning(f"Grad scale is small: {cur_grad_scale}")
             if cur_grad_scale < 1.0e-05:
-                raise RuntimeError(
-                    f"grad_scale is too small, exiting: {cur_grad_scale}"
-                )
+                raise_grad_scale_is_too_small_error(cur_grad_scale)

         if batch_idx % params.log_interval == 0:
             cur_lr = scheduler.get_last_lr()[0]

egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/train.py (+2 −3)

@@ -78,6 +78,7 @@
 )
 from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
+from icefall.err import raise_grad_scale_is_too_small_error
 from icefall.hooks import register_inf_check_hooks
 from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool

@@ -888,9 +889,7 @@ def train_one_epoch(
             if cur_grad_scale < 0.01:
                 logging.warning(f"Grad scale is small: {cur_grad_scale}")
             if cur_grad_scale < 1.0e-05:
-                raise RuntimeError(
-                    f"grad_scale is too small, exiting: {cur_grad_scale}"
-                )
+                raise_grad_scale_is_too_small_error(cur_grad_scale)

         if batch_idx % params.log_interval == 0:
             cur_lr = scheduler.get_last_lr()[0]

egs/csj/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py (+2 −3)

@@ -81,6 +81,7 @@
 )
 from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
+from icefall.err import raise_grad_scale_is_too_small_error
 from icefall.hooks import register_inf_check_hooks
 from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool

@@ -909,9 +910,7 @@ def train_one_epoch(
             if cur_grad_scale < 0.01:
                 logging.warning(f"Grad scale is small: {cur_grad_scale}")
             if cur_grad_scale < 1.0e-05:
-                raise RuntimeError(
-                    f"grad_scale is too small, exiting: {cur_grad_scale}"
-                )
+                raise_grad_scale_is_too_small_error(cur_grad_scale)

         if batch_idx % params.log_interval == 0:
             cur_lr = scheduler.get_last_lr()[0]

egs/csj/ASR/pruned_transducer_stateless7_streaming/train.py (+2 −3)

@@ -81,6 +81,7 @@
 )
 from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
+from icefall.err import raise_grad_scale_is_too_small_error
 from icefall.hooks import register_inf_check_hooks
 from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool

@@ -908,9 +909,7 @@ def train_one_epoch(
             if cur_grad_scale < 0.01:
                 logging.warning(f"Grad scale is small: {cur_grad_scale}")
             if cur_grad_scale < 1.0e-05:
-                raise RuntimeError(
-                    f"grad_scale is too small, exiting: {cur_grad_scale}"
-                )
+                raise_grad_scale_is_too_small_error(cur_grad_scale)

         if batch_idx % params.log_interval == 0:
             cur_lr = scheduler.get_last_lr()[0]

egs/gigaspeech/ASR/zipformer/train.py (+2 −3)

@@ -89,6 +89,7 @@
 )
 from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
+from icefall.err import raise_grad_scale_is_too_small_error
 from icefall.hooks import register_inf_check_hooks
 from icefall.utils import (
     AttributeDict,
@@ -1031,9 +1032,7 @@ def save_bad_model(suffix: str = ""):
                 logging.warning(f"Grad scale is small: {cur_grad_scale}")
             if cur_grad_scale < 1.0e-05:
                 save_bad_model()
-                raise RuntimeError(
-                    f"grad_scale is too small, exiting: {cur_grad_scale}"
-                )
+                raise_grad_scale_is_too_small_error(cur_grad_scale)

         if batch_idx % params.log_interval == 0:
             cur_lr = max(scheduler.get_last_lr())

egs/gigaspeech/KWS/zipformer/finetune.py (+2 −3)

@@ -100,6 +100,7 @@
 )
 from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
+from icefall.err import raise_grad_scale_is_too_small_error
 from icefall.hooks import register_inf_check_hooks
 from icefall.utils import (
     AttributeDict,
@@ -371,9 +372,7 @@ def save_bad_model(suffix: str = ""):
                 logging.warning(f"Grad scale is small: {cur_grad_scale}")
             if cur_grad_scale < 1.0e-05:
                 save_bad_model()
-                raise RuntimeError(
-                    f"grad_scale is too small, exiting: {cur_grad_scale}"
-                )
+                raise_grad_scale_is_too_small_error(cur_grad_scale)

         if batch_idx % params.log_interval == 0:
             cur_lr = max(scheduler.get_last_lr())

egs/gigaspeech/KWS/zipformer/train.py (+2 −3)

@@ -89,6 +89,7 @@
 )
 from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
+from icefall.err import raise_grad_scale_is_too_small_error
 from icefall.hooks import register_inf_check_hooks
 from icefall.utils import (
     AttributeDict,
@@ -1034,9 +1035,7 @@ def save_bad_model(suffix: str = ""):
                 logging.warning(f"Grad scale is small: {cur_grad_scale}")
             if cur_grad_scale < 1.0e-05:
                 save_bad_model()
-                raise RuntimeError(
-                    f"grad_scale is too small, exiting: {cur_grad_scale}"
-                )
+                raise_grad_scale_is_too_small_error(cur_grad_scale)

         if batch_idx % params.log_interval == 0:
             cur_lr = max(scheduler.get_last_lr())

egs/libricss/SURT/dprnn_zipformer/train.py (+2 −3)

@@ -85,6 +85,7 @@
 )
 from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
+from icefall.err import raise_grad_scale_is_too_small_error
 from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool

 LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@@ -1169,9 +1170,7 @@ def train_one_epoch(
             if cur_grad_scale < 0.01:
                 logging.warning(f"Grad scale is small: {cur_grad_scale}")
             if cur_grad_scale < 1.0e-05:
-                raise RuntimeError(
-                    f"grad_scale is too small, exiting: {cur_grad_scale}"
-                )
+                raise_grad_scale_is_too_small_error(cur_grad_scale)

         if batch_idx % params.log_interval == 0:
             cur_lr = scheduler.get_last_lr()[0]
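
For context on why this check exists at all: in a torch.cuda.amp loop, GradScaler.update() multiplies the loss scale by 0.5 (the default backoff) after every optimizer step whose gradients contain inf/nan, so a scale that has decayed below 1.0e-05 means the gradients have been non-finite for many consecutive steps and continuing is pointless. A generic, self-contained sketch of that pattern follows; the model, data, and hyperparameters are placeholders, not icefall's actual training loop.

import logging

import torch

# Requires a CUDA device. Placeholder model and data; only the grad-scale
# bookkeeping mirrors the icefall loops shown above.
model = torch.nn.Linear(10, 1).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler()

for batch_idx in range(1000):
    x = torch.randn(8, 10, device="cuda")
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():
        loss = model(x).pow(2).mean()
    scaler.scale(loss).backward()
    scaler.step(optimizer)  # silently skipped if gradients contain inf/nan
    scaler.update()  # halves the scale whenever the step was skipped

    # The same check this commit refactors; _scale is a private attribute,
    # but the icefall call sites above read it the same way.
    cur_grad_scale = scaler._scale.item()
    if cur_grad_scale < 0.01:
        logging.warning(f"Grad scale is small: {cur_grad_scale}")
    if cur_grad_scale < 1.0e-05:
        raise RuntimeError(f"grad_scale is too small, exiting: {cur_grad_scale}")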
