
Commit c814dc5

Fix nan issues in streaming transformer encoder (#53)
* [aps,tests,cmd] fix nan issues in streaming transformer encoder
* [aps] move sos padding from ASR forward to task class
* [aps,examples] add chimera++ nnet & loss
* [aps,conf,examples] update wham recipe
* [aps,conf,examples] update wham results
* [aps,tests] fix ci errors
1 parent 571c845 commit c814dc5


45 files changed: +1143, -618 lines

aps/asr/att.py

Lines changed: 6 additions & 11 deletions

@@ -5,7 +5,6 @@
 
 import torch as th
 import torch.nn as nn
-import torch.nn.functional as tf
 
 import aps.asr.beam_search.att as att_api
 import aps.asr.beam_search.transformer as xfmr_api
@@ -98,7 +97,7 @@ def forward(self,
 Args:
 x_pad: N x Ti x D or N x S
 x_len: N or None
-y_pad: N x To
+y_pad: N x To (start with sos)
 y_len: N or None, not used here
 ssr: schedule sampling rate
 Return:
@@ -110,13 +109,11 @@ def forward(self,
 self.att_net.clear()
 # go through feature extractor & encoder
 enc_out, enc_ctc, enc_len = self._training_prep(x_pad, x_len)
-# N x To+1
-tgt_pad = tf.pad(y_pad, (1, 0), value=self.sos)
-# N x (To+1), pad SOS
+# N x To
 dec_out, _ = self.decoder(self.att_net,
 enc_out,
 enc_len,
-tgt_pad,
+y_pad,
 schedule_sampling=ssr)
 return dec_out, enc_ctc, enc_len
 
@@ -262,7 +259,7 @@ def forward(self,
 Args:
 x_pad: N x Ti x D or N x S
 x_len: N or None
-y_pad: N x To
+y_pad: N x To (start with eos)
 y_len: N or None
 ssr: not used here, left for future
 Return:
@@ -272,10 +269,8 @@ def forward(self,
 """
 # go through feature extractor & encoder
 enc_out, enc_ctc, enc_len = self._training_prep(x_pad, x_len)
-# N x To+1
-tgt_pad = tf.pad(y_pad, (1, 0), value=self.sos)
-# N x To+1 x D
-dec_out = self.decoder(enc_out, enc_len, tgt_pad, y_len + 1)
+# N x To x D
+dec_out = self.decoder(enc_out, enc_len, y_pad, y_len)
 return dec_out, enc_ctc, enc_len
 
 def greedy_search(self,
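Note: the deleted tf.pad calls show what was moved. Per the commit message, prepending the sos token now happens in the task class, so forward() receives a y_pad that already starts with sos. A minimal sketch of that padding step; the ids and tensors below are made up for illustration and are not taken from the repo:

    import torch as th
    import torch.nn.functional as tf

    # hypothetical values; in aps the ids come from the dictionary and the
    # labels from the dataloader
    sos_id = 1
    y_pad = th.tensor([[5, 7, 9], [4, 2, 0]])  # N x To, raw label ids

    # what the removed model-side code did, presumably now done in the task class:
    # prepend <sos> so the decoder input becomes N x (To+1)
    tgt_pad = tf.pad(y_pad, (1, 0), value=sos_id)
    print(tgt_pad)
    # tensor([[1, 5, 7, 9],
    #         [1, 4, 2, 0]])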

aps/asr/base/attention.py

Lines changed: 2 additions & 1 deletion

@@ -9,6 +9,7 @@
 import torch.nn.functional as tf
 
 from typing import Optional, Tuple
+from aps.const import NEG_INF
 from aps.libs import Register
 
 AsrAtt = Register("asr_att")
@@ -65,7 +66,7 @@ def softmax(self, score: th.Tensor, enc_len: Optional[th.Tensor],
 if pad_mask is None:
 raise RuntimeError("Attention: pad_mask should not be None "
 "when enc_len is not None")
-score = score.masked_fill(pad_mask, float("-inf"))
+score = score.masked_fill(pad_mask, NEG_INF)
 return tf.softmax(score, dim=-1)
 
 def clear(self):
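For context, this softmax helper masks encoder frames beyond enc_len before normalizing; the commit only swaps the literal float("-inf") for the shared NEG_INF constant from aps.const (its exact value is not shown in this diff). A small sketch of the pattern, with an assumed pad_mask construction:

    import torch as th
    import torch.nn.functional as tf

    # toy scores over Ti=5 encoder frames for N=2 utterances
    score = th.randn(2, 5)
    enc_len = th.tensor([5, 3])
    # assumed construction: True marks padded frames beyond enc_len
    pad_mask = th.arange(5)[None, :] >= enc_len[:, None]

    fill = float("-inf")  # the old literal; the commit uses aps.const.NEG_INF here
    weights = tf.softmax(score.masked_fill(pad_mask, fill), dim=-1)
    print(weights[1])  # the last two frames get zero attention weight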

aps/asr/base/encoder.py

Lines changed: 5 additions & 5 deletions

@@ -92,7 +92,7 @@ class RNNEncoderBase(EncoderBase):
 def __init__(self,
 inp_features: int,
 out_features: int,
-rnns: nn.Module,
+impl: nn.Module,
 input_proj: int = -1,
 hidden: int = -1,
 bidirectional: bool = False,
@@ -105,7 +105,7 @@ def __init__(self,
 self.proj = nn.Linear(inp_features, input_proj)
 else:
 self.proj = None
-self.rnns = rnns
+self.impl = impl
 factor = 2 if bidirectional else 1
 if out_features > 0:
 self.outp = nn.Linear(hidden * factor, out_features)
@@ -173,11 +173,11 @@ def __init__(self,
 non_linear=non_linear)
 
 def flat(self):
-self.rnns.flatten_parameters()
+self.impl.flatten_parameters()
 
 def _forward(self, inp: th.Tensor,
 inp_len: Optional[th.Tensor]) -> th.Tensor:
-return var_len_rnn_forward(self.rnns,
+return var_len_rnn_forward(self.impl,
 inp,
 inp_len=inp_len,
 enforce_sorted=False,
@@ -219,7 +219,7 @@ def __init__(self,
 
 def _forward(self, inp: th.Tensor,
 inp_len: Optional[th.Tensor]) -> th.Tensor:
-return self.rnns(inp)[0]
+return self.impl(inp)[0]
 
 
 @BaseEncoder.register("variant_rnn")

aps/asr/beam_search/ctc.py

Lines changed: 8 additions & 8 deletions

@@ -7,7 +7,7 @@
 import torch.nn as nn
 
 from collections import defaultdict
-from aps.const import NEG_INF
+from aps.const import MIN_F32
 from aps.utils import get_logger
 from typing import Dict, List, Union
 
@@ -56,7 +56,7 @@ def beam_search(self,
 topk_score, topk_token = th.topk(ctc_prob, beam_size, -1)
 T, V = ctc_prob.shape
 logger.info(f"--- shape of the encoder output (CTC): {T} x {V}")
-neg_inf = th.tensor(NEG_INF).to(ctc_prob.device)
+neg_inf = th.tensor(MIN_F32).to(ctc_prob.device)
 zero = th.tensor(0.0).to(ctc_prob.device)
 # (prefix, log_pb, log_pn)
 # NOTE: actually do not need sos/eos here, just place it in the sentence
@@ -131,7 +131,7 @@ def viterbi_align(self, ctc_enc: th.Tensor, dec_seq: th.Tensor) -> Dict:
 
 dec_seq = dec_seq.tolist()
 # T x U*2+1
-score = NEG_INF * th.ones(T, U * 2 + 1, device=ctc_prob.device)
+score = MIN_F32 * th.ones(T, U * 2 + 1, device=ctc_prob.device)
 point = -1 * th.ones(
 T, U * 2 + 1, dtype=th.int32, device=ctc_prob.device)
 
@@ -223,9 +223,9 @@ def __init__(self,
 self.blank = -1
 self.offset = th.arange(batch_size, device=self.device)
 self.beam_size = beam_size
-# eq (51) NEG_INF ~ log(0), T x N
+# eq (51) MIN_F32 ~ log(0), T x N
 self.gamma_n_g = th.full((self.T, batch_size),
-NEG_INF,
+MIN_F32,
 device=self.device)
 self.gamma_b_g = th.zeros(self.T, batch_size, device=self.device)
 # eq (52)
@@ -235,7 +235,7 @@ def __init__(self,
 self.ctc_prob[t, self.blank])
 # ctc score in previous steps
 self.ctc_score = th.zeros(1, batch_size, device=self.device)
-self.neg_inf = th.tensor(NEG_INF).to(self.device)
+self.neg_inf = th.tensor(MIN_F32).to(self.device)
 
 def update_var(self, point: Union[th.Tensor, int]) -> None:
 """
@@ -276,8 +276,8 @@ def forward(self, g: th.Tensor, c: th.Tensor) -> th.Tensor:
 # zero based
 glen = g.shape[-1] - 1
 start = max(glen, 1)
-gamma_n_h[start - 1] = self.ctc_prob[0, c] if glen == 0 else NEG_INF
-gamma_b_h[start - 1] = NEG_INF
+gamma_n_h[start - 1] = self.ctc_prob[0, c] if glen == 0 else MIN_F32
+gamma_b_h[start - 1] = MIN_F32
 
 # N*ctc_beam
 score = gamma_n_h[start - 1]
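The CTC beam-search bookkeeping now uses MIN_F32 instead of NEG_INF for its "log(0)" entries. Assuming NEG_INF behaves like float("-inf") and MIN_F32 like the smallest finite float32 (their real definitions live in aps/const.py and are not shown in this diff), the difference is that a finite floor survives ordinary arithmetic where a true -inf can turn into nan:

    import torch as th

    neg_inf = float("-inf")              # assumed behaviour of the old NEG_INF constant
    min_f32 = th.finfo(th.float32).min   # assumed definition of MIN_F32 (about -3.4e38)

    a = th.tensor([neg_inf, min_f32])
    # subtracting two "log(0)" entries (e.g. when comparing or normalizing scores):
    print(a - a)            # nan for -inf, 0. for the finite floor
    # multiplying by zero (e.g. zeroing out part of a score table):
    print(a * th.zeros(2))  # nan for -inf, (-)0. for the finite floor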

aps/asr/beam_search/transducer.py

Lines changed: 2 additions & 2 deletions

@@ -13,7 +13,7 @@
 from copy import deepcopy
 from typing import List, Dict, Tuple, Optional
 from aps.utils import get_logger
-from aps.const import NEG_INF
+from aps.const import MIN_F32
 from aps.asr.beam_search.lm import lm_score_impl, LmType
 
 logger = get_logger(__name__)
@@ -242,7 +242,7 @@ def forward(self,
 cache_dec_out.append(dec_out)
 
 # set -inf as it already used
-cache_logp[best_idx][best_tok] = NEG_INF
+cache_logp[best_idx][best_tok] = MIN_F32
 # init as None and 0
 best_val, best_idx, best_tok = None, 0, 0
 for i, logp in enumerate(cache_logp):

aps/asr/beam_search/utils.py

Lines changed: 5 additions & 5 deletions

@@ -8,7 +8,7 @@
 from dataclasses import dataclass
 from typing import List, Dict, Union, Tuple, Optional, NoReturn
 from aps.asr.beam_search.ctc import CtcScorer
-from aps.const import NEG_INF
+from aps.const import MIN_F32
 from aps.utils import get_logger
 
 logger = get_logger(__name__)
@@ -135,7 +135,7 @@ def disable_eos(self, score: th.Tensor) -> th.Tensor:
 none_eos_best, _ = th.max(score[:, self.none_eos_idx], dim=-1)
 # set inf to disable the eos
 disable_eos = eos_prob < none_eos_best * self.param.eos_threshold
-score[disable_eos, self.param.eos] = NEG_INF
+score[disable_eos, self.param.eos] = MIN_F32
 if verbose and th.sum(disable_eos):
 disable_index = [i for i, s in enumerate(disable_eos) if s]
 logger.info(f"--- disable <eos> in beam index: {disable_index}")
@@ -152,7 +152,7 @@ def disable_unk(self, token: th.Tensor, score: th.Tensor) -> th.Tensor:
 """
 if self.param.unk >= 0:
 unk_index = token == self.param.unk
-score[unk_index] = NEG_INF
+score[unk_index] = MIN_F32
 return score
 
 def beam_select(self, am_prob: th.Tensor,
@@ -401,7 +401,7 @@ def _trace_back_hypos(self,
 point (Tensor): initial backward point
 """
 score = self.score[point].tolist()
-self.acmu_score[point] = NEG_INF
+self.acmu_score[point] = MIN_F32
 return self.trace_hypos(point,
 score,
 self.trans,
@@ -611,7 +611,7 @@ def _trace_back_hypos(self,
 point (Tensor): initial backward point
 """
 score = self.score[batch, point].tolist()
-self.acmu_score[batch, point] = NEG_INF
+self.acmu_score[batch, point] = MIN_F32
 trans = th.chunk(self.trans, self.batch_size, 0)[batch]
 align = th.chunk(self.align, self.batch_size, 0)[batch]
 points = [p[batch] for p in self.point]

aps/asr/transducers.py

Lines changed: 7 additions & 12 deletions

@@ -5,7 +5,6 @@
 
 import torch as th
 import torch.nn as nn
-import torch.nn.functional as tf
 
 from typing import Optional, Dict, List
 from aps.asr.ctc import ASREncoderBase, NoneOrTensor, AMForwardType
@@ -107,7 +106,7 @@ def forward(self, x_pad: th.Tensor, x_len: NoneOrTensor, y_pad: th.Tensor,
 Args:
 x_pad: N x Ti x D or N x S
 x_len: N or None
-y_pad: N x To
+y_pad: N x To (start with blank)
 y_len: N or None (not used here)
 Return:
 enc_out: N x Ti x D
@@ -116,10 +115,8 @@ def forward(self, x_pad: th.Tensor, x_len: NoneOrTensor, y_pad: th.Tensor,
 """
 # go through feature extractor & encoder
 enc_out, _, enc_len = self._training_prep(x_pad, x_len)
-# N x To+1
-tgt_pad = tf.pad(y_pad, (1, 0), value=self.blank)
-# N x Ti x To+1 x V
-dec_out = self.decoder(enc_out, tgt_pad)
+# N x Ti x To x V
+dec_out = self.decoder(enc_out, y_pad)
 return enc_out, dec_out, enc_len
 
 
@@ -134,7 +131,7 @@ def __init__(self,
 vocab_size: int = 40,
 asr_transform: Optional[nn.Module] = None,
 enc_type: str = "xfmr",
-enc_proj: Optional[int] = None,
+enc_proj: int = -1,
 enc_kwargs: Dict = {},
 dec_type: str = "xfmr",
 dec_kwargs: Dict = {}) -> None:
@@ -158,7 +155,7 @@ def forward(self, x_pad: th.Tensor, x_len: NoneOrTensor, y_pad: th.Tensor,
 Args:
 x_pad: N x Ti x D or N x S
 x_len: N or None
-y_pad: N x To
+y_pad: N x To (start with blank)
 y_len: N or None
 Return:
 enc_out: N x Ti x D
@@ -167,8 +164,6 @@ def forward(self, x_pad: th.Tensor, x_len: NoneOrTensor, y_pad: th.Tensor,
 """
 # go through feature extractor & encoder
 enc_out, _, enc_len = self._training_prep(x_pad, x_len)
-# N x To+1
-tgt_pad = tf.pad(y_pad, (1, 0), value=self.blank)
-# N x Ti x To+1 x V
-dec_out = self.decoder(enc_out, tgt_pad, y_len + 1)
+# N x Ti x To x V
+dec_out = self.decoder(enc_out, y_pad, y_len)
 return enc_out, dec_out, enc_len

aps/asr/transformer/impl.py

Lines changed: 4 additions & 2 deletions

@@ -12,6 +12,7 @@
 
 from typing import Optional, Tuple, Dict, List
 from aps.libs import Register
+from aps.const import MIN_F32
 from aps.asr.transformer.utils import digit_shift, get_activation_fn, get_relative_uv
 
 TransformerEncoderLayers = Register("xfmr_encoder_layer")
@@ -45,6 +46,7 @@ def __init__(self,
 self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True)
 self.dropout = nn.Dropout(p=dropout)
 self.use_torch = use_torch
+self.min_f32 = MIN_F32
 
 def inp_proj(self, query: th.Tensor, key: th.Tensor,
 value: th.Tensor) -> Tuple[th.Tensor, th.Tensor, th.Tensor]:
@@ -101,8 +103,8 @@ def context_weight(self,
 """
 logit = logit / (self.head_dim)**0.5
 if key_padding_mask is not None:
-logit = logit.masked_fill(key_padding_mask[None, :, None, :],
-float("-inf"))
+logit = th.masked_fill(logit, key_padding_mask[None, :, None, :],
+self.min_f32)
 if attn_mask is not None:
 logit += attn_mask[:, None, None, :]
 # L x N x H x S
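This change is the core of the streaming-encoder nan fix: in a chunked encoder the key_padding_mask together with the context mask (attn_mask) can exclude every key for some query positions, and softmax over a row that is entirely -inf returns nan, which then propagates through the whole network. Filling with the finite MIN_F32 keeps the row finite (effectively uniform weights that are discarded downstream). A minimal sketch of the difference; the shapes are toy values and MIN_F32's definition is assumed, not shown in the diff:

    import torch as th

    MIN_F32 = th.finfo(th.float32).min  # assumed definition of aps.const.MIN_F32

    logit = th.zeros(1, 4)                                # one query, four keys
    fully_masked = th.tensor([[True, True, True, True]])  # e.g. a padded query position

    print(th.softmax(logit.masked_fill(fully_masked, float("-inf")), dim=-1))
    # tensor([[nan, nan, nan, nan]])  -> nan spreads through the encoder output

    print(th.softmax(th.masked_fill(logit, fully_masked, MIN_F32), dim=-1))
    # tensor([[0.2500, 0.2500, 0.2500, 0.2500]])  -> finite, harmless for padded positions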

aps/asr/transformer/utils.py

Lines changed: 4 additions & 2 deletions

@@ -8,6 +8,7 @@
 import torch.nn.functional as tf
 
 from typing import Tuple
+from aps.const import NEG_INF
 
 
 def digit_shift(term: th.Tensor) -> th.Tensor:
@@ -53,14 +54,15 @@ def prep_sub_mask(num_frames: int, device: th.device = "cpu") -> th.Tensor:
 """
 ones = th.ones(num_frames, num_frames, device=device)
 mask = (th.triu(ones, diagonal=1) == 1).float()
-mask = mask.masked_fill(mask == 1, float("-inf"))
+mask = mask.masked_fill(mask == 1, NEG_INF)
 return mask
 
 
 def prep_context_mask(num_frames: int,
 chunk_size: int = 1,
 lctx: int = 0,
 rctx: int = 0,
+ninf: float = NEG_INF,
 device: th.device = "cpu") -> th.Tensor:
 """
 Prepare the square masks (-inf/0) for context masking
@@ -92,7 +94,7 @@ def prep_context_mask(num_frames: int,
 # generate masks
 zeros = th.zeros(num_frames, device=device)
 ctx_mask = th.logical_or(left_mask, right_mask)
-mask = zeros.masked_fill(ctx_mask, float("-inf"))
+mask = zeros.masked_fill(ctx_mask, ninf)
 return mask
 
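These mask helpers now take their fill value from aps.const, and the new ninf argument lets a caller pass a finite value (for example MIN_F32) instead of the default. A quick sketch of the triu-based causal mask built the same way as prep_sub_mask in the diff; sub_mask is a local stand-in name, not the repo's API:

    import torch as th

    def sub_mask(num_frames, ninf=float("-inf")):
        # same construction as prep_sub_mask, with the fill value exposed the
        # way prep_context_mask now exposes it via `ninf`
        ones = th.ones(num_frames, num_frames)
        mask = (th.triu(ones, diagonal=1) == 1).float()
        return mask.masked_fill(mask == 1, ninf)

    print(sub_mask(3))
    # tensor([[0., -inf, -inf],
    #         [0., 0., -inf],
    #         [0., 0., 0.]])
    # Added to attention logits, row t only attends to frames <= t; passing a
    # finite ninf (e.g. th.finfo(th.float32).min) avoids all--inf rows once the
    # mask is combined with a key padding mask.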

aps/conf.py

Lines changed: 7 additions & 11 deletions

@@ -17,9 +17,7 @@
 "enh_transform", "asr_transform", "cmd_args"
 ]
 all_lm_conf_keys = required_keys + ["cmd_args"]
-transducer_or_ctc_tasks = [
-"asr@transducer", "asr@ctc", "streaming_asr@transducer", "streaming_asr@ctc"
-]
+transducer_or_ctc_tasks = ["asr@transducer", "asr@ctc"]
 
 
 def load_dict(dict_path: str,
@@ -111,26 +109,24 @@ def load_am_conf(yaml_conf: str, dict_path: str) -> Tuple[Dict, Dict]:
 with open(yaml_conf, "r") as f:
 conf = yaml.full_load(f)
 conf = check_conf(conf, required_keys, all_am_conf_keys)
-
-# add dict info
 nnet_conf = conf["nnet_conf"]
 is_transducer_or_ctc = conf["task"] in transducer_or_ctc_tasks
 
+# load and add dict info
 required_units = [] if is_transducer_or_ctc else [EOS_TOKEN, EOS_TOKEN]
 vocab = load_dict(dict_path, required=required_units)
 nnet_conf["vocab_size"] = len(vocab)
-
 # Generally we don't use eos/sos in
 if not is_transducer_or_ctc:
 nnet_conf["sos"] = vocab[SOS_TOKEN]
 nnet_conf["eos"] = vocab[EOS_TOKEN]
 # for transducer/CTC
 task_conf = conf["task_conf"]
-use_ctc = "ctc_weight" in task_conf and task_conf["ctc_weight"] > 0
-if use_ctc or is_transducer_or_ctc:
-conf["task_conf"]["blank"] = len(vocab)
+ctc_att_hybrid = "ctc_weight" in task_conf and task_conf["ctc_weight"] > 0
+if ctc_att_hybrid or is_transducer_or_ctc:
+task_conf["blank"] = len(vocab)
 # add blank
 nnet_conf["vocab_size"] += 1
-if use_ctc:
-nnet_conf["ctc"] = use_ctc
+if ctc_att_hybrid:
+nnet_conf["ctc"] = True
 return conf, vocab
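The config change keeps blank-id handling in one place: for transducer/CTC tasks, or a hybrid CTC + attention setup with ctc_weight > 0, the blank id is appended after the existing vocabulary, so task_conf["blank"] = len(vocab) and the network's vocab_size grows by one. A tiny worked example of that bookkeeping; the four-symbol dictionary is made up, in aps it comes from load_dict(dict_path):

    # hypothetical dictionary for illustration only
    vocab = {"<sos>": 0, "<eos>": 1, "a": 2, "b": 3}

    nnet_conf = {"vocab_size": len(vocab)}   # 4 output units so far
    task_conf = {"ctc_weight": 0.2}          # hybrid CTC + attention

    ctc_att_hybrid = "ctc_weight" in task_conf and task_conf["ctc_weight"] > 0
    if ctc_att_hybrid:
        task_conf["blank"] = len(vocab)      # blank takes the next free id: 4
        nnet_conf["vocab_size"] += 1         # the network now predicts 5 classes
        nnet_conf["ctc"] = True

    print(task_conf["blank"], nnet_conf["vocab_size"])  # 4 5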
