Fix nan issues in streaming transformer encoder (#53)
* [aps,tests,cmd] fix nan issues in streaming transformer encoder

* [aps] move sos padding from ASR forward to task class

* [aps,examples] add chimera++ nnet & loss

* [aps,conf,examples] update wham recipe

* [aps,conf,examples] update wham results

* [aps,tests] fix ci errors
funcwj authored Nov 29, 2021
1 parent 571c845 commit c814dc5
Showing 45 changed files with 1,143 additions and 618 deletions.
17 changes: 6 additions & 11 deletions aps/asr/att.py
@@ -5,7 +5,6 @@

import torch as th
import torch.nn as nn
import torch.nn.functional as tf

import aps.asr.beam_search.att as att_api
import aps.asr.beam_search.transformer as xfmr_api
@@ -98,7 +97,7 @@ def forward(self,
Args:
x_pad: N x Ti x D or N x S
x_len: N or None
y_pad: N x To
y_pad: N x To (start with sos)
y_len: N or None, not used here
ssr: schedule sampling rate
Return:
@@ -110,13 +109,11 @@
self.att_net.clear()
# go through feature extractor & encoder
enc_out, enc_ctc, enc_len = self._training_prep(x_pad, x_len)
# N x To+1
tgt_pad = tf.pad(y_pad, (1, 0), value=self.sos)
# N x (To+1), pad SOS
# N x To
dec_out, _ = self.decoder(self.att_net,
enc_out,
enc_len,
tgt_pad,
y_pad,
schedule_sampling=ssr)
return dec_out, enc_ctc, enc_len

@@ -262,7 +259,7 @@ def forward(self,
Args:
x_pad: N x Ti x D or N x S
x_len: N or None
y_pad: N x To
y_pad: N x To (start with eos)
y_len: N or None
ssr: not used here, left for future
Return:
@@ -272,10 +269,8 @@
"""
# go through feature extractor & encoder
enc_out, enc_ctc, enc_len = self._training_prep(x_pad, x_len)
# N x To+1
tgt_pad = tf.pad(y_pad, (1, 0), value=self.sos)
# N x To+1 x D
dec_out = self.decoder(enc_out, enc_len, tgt_pad, y_len + 1)
# N x To x D
dec_out = self.decoder(enc_out, enc_len, y_pad, y_len)
return dec_out, enc_ctc, enc_len

def greedy_search(self,
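The hunk above removes the in-model `tf.pad` call from the attention ASR forward: per the commit message, prepending `sos` now happens in the task class, so `y_pad` is expected to arrive already starting with `sos`. Below is a minimal caller-side sketch of that padding; the helper name and the toy ids are assumptions for illustration, not code from the repository's task classes.

```python
import torch as th
import torch.nn.functional as tf

def prepend_sos(y_pad: th.Tensor, y_len: th.Tensor, sos: int):
    """Hypothetical helper: prefix each target sequence with <sos>.

    y_pad: N x To token ids, y_len: N lengths (without sos).
    Returns (N x (To+1) targets, lengths incremented by 1).
    """
    # pad one column on the left of the last dim and fill it with the sos id
    tgt_pad = tf.pad(y_pad, (1, 0), value=sos)
    return tgt_pad, y_len + 1

# toy usage: two target sequences, sos id = 1
y_pad = th.tensor([[4, 5, 6], [7, 8, 0]])
y_len = th.tensor([3, 2])
tgt_pad, tgt_len = prepend_sos(y_pad, y_len, sos=1)
# tgt_pad == [[1, 4, 5, 6], [1, 7, 8, 0]], tgt_len == [4, 3]
```

With the padding done once in the task, the second forward above likewise passes `y_len` to the decoder instead of `y_len + 1`.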
3 changes: 2 additions & 1 deletion aps/asr/base/attention.py
@@ -9,6 +9,7 @@
import torch.nn.functional as tf

from typing import Optional, Tuple
from aps.const import NEG_INF
from aps.libs import Register

AsrAtt = Register("asr_att")
@@ -65,7 +66,7 @@ def softmax(self, score: th.Tensor, enc_len: Optional[th.Tensor],
if pad_mask is None:
raise RuntimeError("Attention: pad_mask should not be None "
"when enc_len is not None")
score = score.masked_fill(pad_mask, float("-inf"))
score = score.masked_fill(pad_mask, NEG_INF)
return tf.softmax(score, dim=-1)

def clear(self):
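The attention score masking now uses `NEG_INF` from `aps.const` instead of the literal `float("-inf")`. Assuming `NEG_INF` is a large but finite negative value (its exact definition in `aps/const.py` is not shown in this diff), the sketch below shows why the distinction matters: a row whose positions are all masked turns into NaN after softmax with true `-inf`, but degrades gracefully to a finite uniform distribution with a finite floor.

```python
import torch as th
import torch.nn.functional as tf

NEG_INF = -1e30  # assumed finite stand-in for aps.const.NEG_INF

score = th.zeros(2, 4)
# second row is completely padded (can happen with very short utterances)
pad_mask = th.tensor([[False, False, True, True],
                      [True, True, True, True]])

nan_row = tf.softmax(score.masked_fill(pad_mask, float("-inf")), dim=-1)
finite = tf.softmax(score.masked_fill(pad_mask, NEG_INF), dim=-1)

print(nan_row[1])  # tensor([nan, nan, nan, nan])
print(finite[1])   # tensor([0.2500, 0.2500, 0.2500, 0.2500])
```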
10 changes: 5 additions & 5 deletions aps/asr/base/encoder.py
@@ -92,7 +92,7 @@ class RNNEncoderBase(EncoderBase):
def __init__(self,
inp_features: int,
out_features: int,
rnns: nn.Module,
impl: nn.Module,
input_proj: int = -1,
hidden: int = -1,
bidirectional: bool = False,
@@ -105,7 +105,7 @@ def __init__(self,
self.proj = nn.Linear(inp_features, input_proj)
else:
self.proj = None
self.rnns = rnns
self.impl = impl
factor = 2 if bidirectional else 1
if out_features > 0:
self.outp = nn.Linear(hidden * factor, out_features)
@@ -173,11 +173,11 @@ def __init__(self,
non_linear=non_linear)

def flat(self):
self.rnns.flatten_parameters()
self.impl.flatten_parameters()

def _forward(self, inp: th.Tensor,
inp_len: Optional[th.Tensor]) -> th.Tensor:
return var_len_rnn_forward(self.rnns,
return var_len_rnn_forward(self.impl,
inp,
inp_len=inp_len,
enforce_sorted=False,
@@ -219,7 +219,7 @@ def __init__(self,

def _forward(self, inp: th.Tensor,
inp_len: Optional[th.Tensor]) -> th.Tensor:
return self.rnns(inp)[0]
return self.impl(inp)[0]


@BaseEncoder.register("variant_rnn")
16 changes: 8 additions & 8 deletions aps/asr/beam_search/ctc.py
@@ -7,7 +7,7 @@
import torch.nn as nn

from collections import defaultdict
from aps.const import NEG_INF
from aps.const import MIN_F32
from aps.utils import get_logger
from typing import Dict, List, Union

@@ -56,7 +56,7 @@ def beam_search(self,
topk_score, topk_token = th.topk(ctc_prob, beam_size, -1)
T, V = ctc_prob.shape
logger.info(f"--- shape of the encoder output (CTC): {T} x {V}")
neg_inf = th.tensor(NEG_INF).to(ctc_prob.device)
neg_inf = th.tensor(MIN_F32).to(ctc_prob.device)
zero = th.tensor(0.0).to(ctc_prob.device)
# (prefix, log_pb, log_pn)
# NOTE: actually do not need sos/eos here, just place it in the sentence
@@ -131,7 +131,7 @@ def viterbi_align(self, ctc_enc: th.Tensor, dec_seq: th.Tensor) -> Dict:

dec_seq = dec_seq.tolist()
# T x U*2+1
score = NEG_INF * th.ones(T, U * 2 + 1, device=ctc_prob.device)
score = MIN_F32 * th.ones(T, U * 2 + 1, device=ctc_prob.device)
point = -1 * th.ones(
T, U * 2 + 1, dtype=th.int32, device=ctc_prob.device)

@@ -223,9 +223,9 @@ def __init__(self,
self.blank = -1
self.offset = th.arange(batch_size, device=self.device)
self.beam_size = beam_size
# eq (51) NEG_INF ~ log(0), T x N
# eq (51) MIN_F32 ~ log(0), T x N
self.gamma_n_g = th.full((self.T, batch_size),
NEG_INF,
MIN_F32,
device=self.device)
self.gamma_b_g = th.zeros(self.T, batch_size, device=self.device)
# eq (52)
@@ -235,7 +235,7 @@ def __init__(self,
self.ctc_prob[t, self.blank])
# ctc score in previous steps
self.ctc_score = th.zeros(1, batch_size, device=self.device)
self.neg_inf = th.tensor(NEG_INF).to(self.device)
self.neg_inf = th.tensor(MIN_F32).to(self.device)

def update_var(self, point: Union[th.Tensor, int]) -> None:
"""
@@ -276,8 +276,8 @@ def forward(self, g: th.Tensor, c: th.Tensor) -> th.Tensor:
# zero based
glen = g.shape[-1] - 1
start = max(glen, 1)
gamma_n_h[start - 1] = self.ctc_prob[0, c] if glen == 0 else NEG_INF
gamma_b_h[start - 1] = NEG_INF
gamma_n_h[start - 1] = self.ctc_prob[0, c] if glen == 0 else MIN_F32
gamma_b_h[start - 1] = MIN_F32

# N*ctc_beam
score = gamma_n_h[start - 1]
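Throughout the CTC prefix scorer, `NEG_INF` is swapped for `MIN_F32` as the log(0) stand-in used in eq (51)/(52). Assuming `MIN_F32` is the float32 minimum (its definition in `aps/const.py` is not shown here), the point of a finite floor is that arithmetic between two "log-zero" scores stays finite, while `-inf - (-inf)` is NaN and can then poison later comparisons over the beam.

```python
import torch as th

MIN_F32 = th.finfo(th.float32).min  # assumed definition of aps.const.MIN_F32

# prefix-score style update: difference of two log(0) entries
ninf = float("-inf")
print(ninf - ninf)        # nan  -> propagates through the beam scores
print(MIN_F32 - MIN_F32)  # 0.0  -> stays finite

# log-add of log(0) with a regular log-prob is unaffected by the finite floor
print(th.logaddexp(th.tensor(MIN_F32), th.tensor(-2.3)))  # ~ -2.3
```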
4 changes: 2 additions & 2 deletions aps/asr/beam_search/transducer.py
@@ -13,7 +13,7 @@
from copy import deepcopy
from typing import List, Dict, Tuple, Optional
from aps.utils import get_logger
from aps.const import NEG_INF
from aps.const import MIN_F32
from aps.asr.beam_search.lm import lm_score_impl, LmType

logger = get_logger(__name__)
@@ -242,7 +242,7 @@ def forward(self,
cache_dec_out.append(dec_out)

# set -inf as it already used
cache_logp[best_idx][best_tok] = NEG_INF
cache_logp[best_idx][best_tok] = MIN_F32
# init as None and 0
best_val, best_idx, best_tok = None, 0, 0
for i, logp in enumerate(cache_logp):
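In the transducer beam search, once the best (hypothesis, token) pair has been expanded, its entry in `cache_logp` is floored to `MIN_F32` so the next pass over the cache picks the runner-up. A stripped-down sketch of that "mark as used" pattern; the rescan loop is truncated in this hunk, so the argmax below is an assumption about its shape.

```python
import torch as th

MIN_F32 = th.finfo(th.float32).min  # assumed value of aps.const.MIN_F32

cache_logp = [th.tensor([-0.1, -1.2, -0.7]), th.tensor([-0.3, -0.5, -2.0])]
best_idx, best_tok = 0, 0                 # best pair found in the last scan
cache_logp[best_idx][best_tok] = MIN_F32  # consume it with a finite floor

# rescan: the previously best entry can no longer win
best_val, best_idx, best_tok = None, 0, 0
for i, logp in enumerate(cache_logp):
    tok = th.argmax(logp).item()
    if best_val is None or logp[tok] > best_val:
        best_val, best_idx, best_tok = logp[tok].item(), i, tok
# now best_idx == 1, best_tok == 0 (score -0.3)
```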
10 changes: 5 additions & 5 deletions aps/asr/beam_search/utils.py
@@ -8,7 +8,7 @@
from dataclasses import dataclass
from typing import List, Dict, Union, Tuple, Optional, NoReturn
from aps.asr.beam_search.ctc import CtcScorer
from aps.const import NEG_INF
from aps.const import MIN_F32
from aps.utils import get_logger

logger = get_logger(__name__)
@@ -135,7 +135,7 @@ def disable_eos(self, score: th.Tensor) -> th.Tensor:
none_eos_best, _ = th.max(score[:, self.none_eos_idx], dim=-1)
# set inf to disable the eos
disable_eos = eos_prob < none_eos_best * self.param.eos_threshold
score[disable_eos, self.param.eos] = NEG_INF
score[disable_eos, self.param.eos] = MIN_F32
if verbose and th.sum(disable_eos):
disable_index = [i for i, s in enumerate(disable_eos) if s]
logger.info(f"--- disable <eos> in beam index: {disable_index}")
@@ -152,7 +152,7 @@ def disable_unk(self, token: th.Tensor, score: th.Tensor) -> th.Tensor:
"""
if self.param.unk >= 0:
unk_index = token == self.param.unk
score[unk_index] = NEG_INF
score[unk_index] = MIN_F32
return score

def beam_select(self, am_prob: th.Tensor,
@@ -401,7 +401,7 @@ def _trace_back_hypos(self,
point (Tensor): initial backward point
"""
score = self.score[point].tolist()
self.acmu_score[point] = NEG_INF
self.acmu_score[point] = MIN_F32
return self.trace_hypos(point,
score,
self.trans,
@@ -611,7 +611,7 @@ def _trace_back_hypos(self,
point (Tensor): initial backward point
"""
score = self.score[batch, point].tolist()
self.acmu_score[batch, point] = NEG_INF
self.acmu_score[batch, point] = MIN_F32
trans = th.chunk(self.trans, self.batch_size, 0)[batch]
align = th.chunk(self.align, self.batch_size, 0)[batch]
points = [p[batch] for p in self.point]
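`disable_eos` implements the usual end-detection heuristic: `<eos>` survives only when its score is competitive with the best non-eos token, and suppressed entries now receive the finite `MIN_F32` floor instead of `-inf`. A self-contained toy version of that check, with made-up scores and threshold; only the shapes and the comparison mirror the code above.

```python
import torch as th

MIN_F32 = th.finfo(th.float32).min  # assumed value of aps.const.MIN_F32

# beam of 2 hypotheses over a 4-token vocabulary, eos id = 3
score = th.tensor([[-0.5, -1.0, -2.0, -0.4],
                   [-0.3, -1.5, -2.5, -4.0]])
eos, eos_threshold = 3, 1.0
none_eos_idx = [i for i in range(score.shape[-1]) if i != eos]

eos_prob = score[:, eos]
none_eos_best, _ = th.max(score[:, none_eos_idx], dim=-1)
# eos is disabled when it is clearly worse than the best regular token
disable_eos = eos_prob < none_eos_best * eos_threshold
score[disable_eos, eos] = MIN_F32
# hypothesis 0 keeps eos (-0.4 >= -0.5), hypothesis 1 has it floored
```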
19 changes: 7 additions & 12 deletions aps/asr/transducers.py
@@ -5,7 +5,6 @@

import torch as th
import torch.nn as nn
import torch.nn.functional as tf

from typing import Optional, Dict, List
from aps.asr.ctc import ASREncoderBase, NoneOrTensor, AMForwardType
@@ -107,7 +106,7 @@ def forward(self, x_pad: th.Tensor, x_len: NoneOrTensor, y_pad: th.Tensor,
Args:
x_pad: N x Ti x D or N x S
x_len: N or None
y_pad: N x To
y_pad: N x To (start with blank)
y_len: N or None (not used here)
Return:
enc_out: N x Ti x D
@@ -116,10 +115,8 @@
"""
# go through feature extractor & encoder
enc_out, _, enc_len = self._training_prep(x_pad, x_len)
# N x To+1
tgt_pad = tf.pad(y_pad, (1, 0), value=self.blank)
# N x Ti x To+1 x V
dec_out = self.decoder(enc_out, tgt_pad)
# N x Ti x To x V
dec_out = self.decoder(enc_out, y_pad)
return enc_out, dec_out, enc_len


@@ -134,7 +131,7 @@ def __init__(self,
vocab_size: int = 40,
asr_transform: Optional[nn.Module] = None,
enc_type: str = "xfmr",
enc_proj: Optional[int] = None,
enc_proj: int = -1,
enc_kwargs: Dict = {},
dec_type: str = "xfmr",
dec_kwargs: Dict = {}) -> None:
@@ -158,7 +155,7 @@ def forward(self, x_pad: th.Tensor, x_len: NoneOrTensor, y_pad: th.Tensor,
Args:
x_pad: N x Ti x D or N x S
x_len: N or None
y_pad: N x To
y_pad: N x To (start with blank)
y_len: N or None
Return:
enc_out: N x Ti x D
@@ -167,8 +164,6 @@
"""
# go through feature extractor & encoder
enc_out, _, enc_len = self._training_prep(x_pad, x_len)
# N x To+1
tgt_pad = tf.pad(y_pad, (1, 0), value=self.blank)
# N x Ti x To+1 x V
dec_out = self.decoder(enc_out, tgt_pad, y_len + 1)
# N x Ti x To x V
dec_out = self.decoder(enc_out, y_pad, y_len)
return enc_out, dec_out, enc_len
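As with the attention models, both transducers now expect `y_pad` to arrive already prefixed with the blank symbol, so the in-model `tf.pad` call and the `y_len + 1` bookkeeping disappear from `forward` and the joint output keeps the documented `N x Ti x To x V` shape (with `To` counting the leading blank). A toy sketch of the shapes involved; the blank id and the padding step are assumptions about what the task class does upstream.

```python
import torch as th
import torch.nn.functional as tf

N, Ti, To, V, blank = 2, 50, 10, 41, 0   # toy sizes; blank id assumed to be 0

# what the task class is assumed to hand the model: blank-prefixed targets
y_raw = th.randint(1, V, (N, To))            # N x To, no blank yet
y_pad = tf.pad(y_raw, (1, 0), value=blank)   # N x (To + 1), starts with blank
y_len = th.full((N,), To + 1)

# the decoder / joint network is then expected to emit one posterior per
# (encoder frame, label position), i.e. N x Ti x (To + 1) x V, which is the
# layout RNN-T style losses consume
```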
6 changes: 4 additions & 2 deletions aps/asr/transformer/impl.py
@@ -12,6 +12,7 @@

from typing import Optional, Tuple, Dict, List
from aps.libs import Register
from aps.const import MIN_F32
from aps.asr.transformer.utils import digit_shift, get_activation_fn, get_relative_uv

TransformerEncoderLayers = Register("xfmr_encoder_layer")
@@ -45,6 +46,7 @@ def __init__(self,
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True)
self.dropout = nn.Dropout(p=dropout)
self.use_torch = use_torch
self.min_f32 = MIN_F32

def inp_proj(self, query: th.Tensor, key: th.Tensor,
value: th.Tensor) -> Tuple[th.Tensor, th.Tensor, th.Tensor]:
@@ -101,8 +103,8 @@ def context_weight(self,
"""
logit = logit / (self.head_dim)**0.5
if key_padding_mask is not None:
logit = logit.masked_fill(key_padding_mask[None, :, None, :],
float("-inf"))
logit = th.masked_fill(logit, key_padding_mask[None, :, None, :],
self.min_f32)
if attn_mask is not None:
logit += attn_mask[:, None, None, :]
# L x N x H x S
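This hunk is the heart of the NaN fix in the streaming transformer encoder: the scaled dot-product logits are floored with the cached, finite `self.min_f32` instead of `float("-inf")`, so a query frame whose every key is excluded (by the padding mask, the chunk/context mask, or both) still yields a finite softmax row. A shape-faithful sketch in the same `L x N x H x S` layout as the comment above; only the masking and softmax steps are reproduced, and the toy masks are assumptions.

```python
import torch as th
import torch.nn.functional as tf

MIN_F32 = th.finfo(th.float32).min   # assumed value of aps.const.MIN_F32
L, N, H, S = 4, 1, 2, 4              # query frames, batch, heads, key frames

logit = th.randn(L, N, H, S)
# the last two key frames of this utterance are padding
key_padding_mask = th.tensor([[False, False, True, True]])
# additive chunk-style mask: query frame 3 may only look at key frames 2..3
attn_mask = th.zeros(L, S)
attn_mask[3, :2] = MIN_F32

# same two steps as context_weight() above
logit = th.masked_fill(logit, key_padding_mask[None, :, None, :], MIN_F32)
logit = logit + attn_mask[:, None, None, :]
weight = tf.softmax(logit, dim=-1)

# query frame 3 has no valid key at all, yet its weights stay finite;
# with float("-inf") that row would be all -inf and softmax would return NaN
assert th.isfinite(weight).all()
```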
6 changes: 4 additions & 2 deletions aps/asr/transformer/utils.py
@@ -8,6 +8,7 @@
import torch.nn.functional as tf

from typing import Tuple
from aps.const import NEG_INF


def digit_shift(term: th.Tensor) -> th.Tensor:
@@ -53,14 +54,15 @@ def prep_sub_mask(num_frames: int, device: th.device = "cpu") -> th.Tensor:
"""
ones = th.ones(num_frames, num_frames, device=device)
mask = (th.triu(ones, diagonal=1) == 1).float()
mask = mask.masked_fill(mask == 1, float("-inf"))
mask = mask.masked_fill(mask == 1, NEG_INF)
return mask


def prep_context_mask(num_frames: int,
chunk_size: int = 1,
lctx: int = 0,
rctx: int = 0,
ninf: float = NEG_INF,
device: th.device = "cpu") -> th.Tensor:
"""
Prepare the square masks (-inf/0) for context masking
@@ -92,7 +94,7 @@ def prep_context_mask(num_frames: int,
# generate masks
zeros = th.zeros(num_frames, device=device)
ctx_mask = th.logical_or(left_mask, right_mask)
mask = zeros.masked_fill(ctx_mask, float("-inf"))
mask = zeros.masked_fill(ctx_mask, ninf)
return mask


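`prep_sub_mask` builds the standard causal mask (upper triangle set to `NEG_INF`, zeros elsewhere), and `prep_context_mask` now exposes a `ninf` argument so the caller can pick the floor value, for example the finite `MIN_F32` cached by the attention layer above. Below is a quick sanity check of the causal mask plus an illustrative call with the new argument; the chunk/left/right-context semantics live in the part of the function this hunk truncates.

```python
import torch as th
from aps.asr.transformer.utils import prep_context_mask, prep_sub_mask

# 3-frame causal mask: entries above the diagonal are NEG_INF, the rest are 0
print(prep_sub_mask(3))
# [[0, NEG_INF, NEG_INF],
#  [0, 0,       NEG_INF],
#  [0, 0,       0      ]]

# streaming-style context mask with an explicit finite floor
mask = prep_context_mask(16, chunk_size=4, lctx=1, rctx=0,
                         ninf=th.finfo(th.float32).min)
print(mask.shape)  # a num_frames x num_frames additive mask per the docstring
```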
18 changes: 7 additions & 11 deletions aps/conf.py
@@ -17,9 +17,7 @@
"enh_transform", "asr_transform", "cmd_args"
]
all_lm_conf_keys = required_keys + ["cmd_args"]
transducer_or_ctc_tasks = [
"asr@transducer", "asr@ctc", "streaming_asr@transducer", "streaming_asr@ctc"
]
transducer_or_ctc_tasks = ["asr@transducer", "asr@ctc"]


def load_dict(dict_path: str,
@@ -111,26 +109,24 @@ def load_am_conf(yaml_conf: str, dict_path: str) -> Tuple[Dict, Dict]:
with open(yaml_conf, "r") as f:
conf = yaml.full_load(f)
conf = check_conf(conf, required_keys, all_am_conf_keys)

# add dict info
nnet_conf = conf["nnet_conf"]
is_transducer_or_ctc = conf["task"] in transducer_or_ctc_tasks

# load and add dict info
required_units = [] if is_transducer_or_ctc else [SOS_TOKEN, EOS_TOKEN]
vocab = load_dict(dict_path, required=required_units)
nnet_conf["vocab_size"] = len(vocab)

# Generally we don't use sos/eos in transducer/CTC tasks
if not is_transducer_or_ctc:
nnet_conf["sos"] = vocab[SOS_TOKEN]
nnet_conf["eos"] = vocab[EOS_TOKEN]
# for transducer/CTC
task_conf = conf["task_conf"]
use_ctc = "ctc_weight" in task_conf and task_conf["ctc_weight"] > 0
if use_ctc or is_transducer_or_ctc:
conf["task_conf"]["blank"] = len(vocab)
ctc_att_hybrid = "ctc_weight" in task_conf and task_conf["ctc_weight"] > 0
if ctc_att_hybrid or is_transducer_or_ctc:
task_conf["blank"] = len(vocab)
# add blank
nnet_conf["vocab_size"] += 1
if use_ctc:
nnet_conf["ctc"] = use_ctc
if ctc_att_hybrid:
nnet_conf["ctc"] = True
return conf, vocab
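In `load_am_conf`, the blank symbol is now added both for pure transducer/CTC tasks and for hybrid CTC/attention training (`ctc_weight > 0`): it takes the id right after the last dictionary unit and the network output grows by one class. A worked toy run of that bookkeeping with a made-up 40-unit dictionary; the YAML and dict files themselves are not part of this diff.

```python
# toy stand-ins for the objects load_am_conf manipulates
vocab = {f"unit{i}": i for i in range(40)}   # 40 regular units, ids 0..39
nnet_conf = {"vocab_size": len(vocab)}       # 40
task_conf = {"ctc_weight": 0.2}              # hybrid CTC/attention training

is_transducer_or_ctc = False                 # e.g. task: asr@att
ctc_att_hybrid = "ctc_weight" in task_conf and task_conf["ctc_weight"] > 0

if ctc_att_hybrid or is_transducer_or_ctc:
    task_conf["blank"] = len(vocab)          # blank id = 40
    nnet_conf["vocab_size"] += 1             # network now outputs 41 classes
if ctc_att_hybrid:
    nnet_conf["ctc"] = True
```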