Skip to content

Commit

Permalink
[EDIT] clean decoder
Browse files Browse the repository at this point in the history
  • Loading branch information
thanhtvt committed Apr 27, 2023
1 parent 278ed30 commit fa4bc91
Show file tree
Hide file tree
Showing 5 changed files with 31 additions and 15 deletions.
8 changes: 6 additions & 2 deletions uetasr/searchers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
from .beam_search import BeamSearch, BeamRNNT
from .greedy import GreedySearch, GreedyRNNT, GreedyRNNTV2
from .alsd import ALSDBeamRNNT
from .alsd import ALSDBeamRNNT, ALSDSearch
from .tsd import TSDSearch


# Public API of the ``searchers`` package: each searcher class imported
# above is re-exported exactly once.  Every entry ends with a comma so two
# adjacent string literals can never be implicitly concatenated into a
# single, non-existent name (which would break ``from ... import *``).
__all__ = [
    "ALSDBeamRNNT",
    "ALSDSearch",
    "BeamSearch",
    "BeamRNNT",
    "GreedySearch",
    "GreedyRNNT",
    "GreedyRNNTV2",
    "TSDSearch",
]
18 changes: 11 additions & 7 deletions uetasr/searchers/alsd.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def search(self, enc_out: tf.Tensor,
beam_dec_out, beam_state = self.decoder.infer(labels,
states=beam_state,
training=False)
beam_env_out = tf.stack([x[1] for x in B_enc_out])
beam_enc_out = tf.stack([x[1] for x in B_enc_out])

beam_logp = tf.nn.log_softmax(
self.joint_network(beam_enc_out, beam_dec_out)
Expand Down Expand Up @@ -141,7 +141,6 @@ def search(self, enc_out: tf.Tensor,
else:
return self.sort_nbest(B)


def recombine_hyps(self, hyps: List[Hypothesis]) -> List[Hypothesis]:
"""Recombine hypotheses with same label ID sequence.
Expand Down Expand Up @@ -174,25 +173,28 @@ class ALSDBeamRNNT(tf.keras.layers.Layer):
"""

def __init__(self,
model: tf.keras.Model,
decoder: tf.keras.Model,
jointer: tf.keras.Model,
text_decoder: PreprocessingLayer,
fraction: float = 0.65,
beam_size: int = 16,
temperature: float = 1.0,
use_lm: bool = False,
lmwt: float = 0.5,
lm_path: str = '',
return_scores: bool = True,
name: str = 'alsd_rnnt',
**kwargs):
super().__init__(name=name, **kwargs)
self.decoder = model.decoder
self.jointer = model.jointer
self.decoder = decoder
self.jointer = jointer
self.text_decoder = text_decoder
self.blank_id = text_decoder.pad_id
self.beam_size = beam_size
self.temperature = temperature
self.fraction = fraction
self.use_lm = use_lm
self.return_scores = return_scores

if use_lm:
self.lm = kenlm.LanguageModel(lm_path)
Expand Down Expand Up @@ -357,7 +359,6 @@ def infer(self, encoder_outputs: tf.Tensor,
hyps = tf.concat([best_hyps, hyp], axis=1)
scores = self.recombine_hypotheses(hyps, scores)

new_cur_states = []
cur_states = tf.where(_equal, cur_states, next_states)

i = tf.where(i < total_lengths - 1, i + 1, i)
Expand Down Expand Up @@ -388,7 +389,10 @@ def infer(self, encoder_outputs: tf.Tensor,
else:
preds = tf.zeros([batch_size, 1], dtype=tf.int32)

return self.text_decoder.decode(preds), best_scores
if self.return_scores:
return self.text_decoder.decode(preds), best_scores
else:
return self.text_decoder.decode(preds)

def recombine_hypotheses(self, hyps, scores):
"""
Expand Down
8 changes: 6 additions & 2 deletions uetasr/searchers/beam_search.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import logging
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers.experimental.preprocessing import \
Expand Down Expand Up @@ -164,6 +163,7 @@ def __init__(self,
lmwt: float = 0.3,
lm_path: str = '',
lm: tf.keras.Model = None,
return_scores: bool = True,
name: str = 'beam_rnnt',
**kwargs):
super().__init__(name=name, **kwargs)
Expand All @@ -179,6 +179,7 @@ def __init__(self,
self.lm = lm
self.blank_id = text_decoder.pad_id
self.max_symbols_per_step = max_symbols_per_step
self.return_scores = return_scores

def infer(self, encoder_outputs: Union[tf.Tensor, np.ndarray],
encoder_lengths: Union[tf.Tensor, np.ndarray]) -> tf.Tensor:
Expand Down Expand Up @@ -401,4 +402,7 @@ def infer(self, encoder_outputs: Union[tf.Tensor, np.ndarray],
best_scores = tf.reduce_max(scores, axis=-1)
outputs = self.text_decoder.decode(best_hyps)

return outputs, best_scores
if self.return_scores:
return outputs, best_scores
else:
return outputs
7 changes: 6 additions & 1 deletion uetasr/searchers/greedy.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ def __init__(self,
jointer: tf.keras.Model,
text_decoder: PreprocessingLayer,
max_symbols_per_step: int = 3,
return_scores: bool = True,
name: str = 'greedy_rnnt',
**kwargs):
super().__init__(name=name, **kwargs)
Expand All @@ -97,6 +98,7 @@ def __init__(self,
self.text_decoder = text_decoder
self.blank_id = text_decoder.pad_id
self.max_symbols_per_step = max_symbols_per_step
self.return_scores = return_scores

@tf.function
def call(
Expand Down Expand Up @@ -192,7 +194,10 @@ def infer(
if _equal.numpy().all():
break

return self.text_decoder.decode(hyps), scores
if self.return_scores:
return self.text_decoder.decode(hyps), scores
else:
return self.text_decoder.decode(hyps)

def infer_step(self,
encoder_outputs: tf.Tensor,
Expand Down
5 changes: 2 additions & 3 deletions uetasr/searchers/tsd.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,10 @@
import tensorflow as tf
from tensorflow.keras.layers.experimental.preprocessing import \
PreprocessingLayer
from typing import List, Union
from typing import List

from .hypothesis import Hypothesis
from .base import BaseSearch
from ..utils.common import get_shape


class TSDSearch(BaseSearch):
Expand Down Expand Up @@ -37,7 +36,7 @@ def __init__(
name: str = "alsd_search",
**kwargs,
):
super(ALSDSearch, self).__init__(
super(TSDSearch, self).__init__(
decoder=decoder,
joint_network=joint_network,
text_decoder=text_decoder,
Expand Down

0 comments on commit fa4bc91

Please sign in to comment.