diff --git a/egs/wsj/s5/steps/decode_combine_test.sh b/egs/wsj/s5/steps/decode_combine_test.sh
new file mode 100755
index 00000000000..7d53f67faad
--- /dev/null
+++ b/egs/wsj/s5/steps/decode_combine_test.sh
@@ -0,0 +1,128 @@
+#!/bin/bash
+
+# Copyright 2012  Johns Hopkins University (Author: Daniel Povey)
+# Apache 2.0
+
+# Begin configuration.
+nj=4
+cmd=run.pl
+maxactive=7000
+beam=15.0
+lattice_beam=8.0
+expand_beam=30.0
+acwt=1.0
+skip_scoring=false
+combine_version=false
+
+stage=0
+online_ivector_dir=
+post_decode_acwt=10.0
+extra_left_context=0
+extra_right_context=0
+extra_left_context_initial=0
+extra_right_context_final=0
+chunk_width=140,100,160
+use_gpu=no
+# End configuration.
+
+echo "$0 $@"  # Print the command line for logging
+
+[ -f ./path.sh ] && . ./path.sh; # source the path.
+. parse_options.sh || exit 1;
+
+if [ $# != 3 ]; then
+  echo "Usage: steps/decode_combine_test.sh [options] <graph-dir> <data-dir> <decode-dir>"
+  echo "... where <decode-dir> is assumed to be a sub-directory of the directory"
+  echo " where the model is."
+  echo "e.g.: steps/decode_combine_test.sh exp/mono/graph_tgpr data/test_dev93 exp/mono/decode_dev93_tgpr"
+  echo ""
+  echo "This script computes per-frame posteriors with nnet3-compute and then"
+  echo "decodes them with latgen-faster-mapped (or latgen-faster-mapped-combine"
+  echo "when --combine-version true is given)."
+  echo ""
+  echo "main options (for others, see top of script file)"
+  echo "  --config <config-file>                   # config containing options"
+  echo "  --nj <nj>                                # number of parallel jobs"
+  echo "  --cmd (utils/run.pl|utils/queue.pl <queue opts>) # how to run jobs."
+  exit 1;
+fi
+
+
+graphdir=$1
+data=$2
+dir=$3
+
+srcdir=`dirname $dir`; # The model directory is one level up from decoding directory.
+sdata=$data/split$nj;
+splice_opts=`cat $srcdir/splice_opts 2>/dev/null`
+cmvn_opts=`cat $srcdir/cmvn_opts 2>/dev/null`
+delta_opts=`cat $srcdir/delta_opts 2>/dev/null`
+
+mkdir -p $dir/log
+[[ -d $sdata && $data/feats.scp -ot $sdata ]] || split_data.sh $data $nj || exit 1;
+echo $nj > $dir/num_jobs
+
+
+for f in $sdata/1/feats.scp $sdata/1/cmvn.scp $srcdir/final.mdl $graphdir/HCLG.fst; do
+  [ ! -f $f ] && echo "decode_combine_test.sh: no such file $f" && exit 1;
+done
+
+
+if [ -f $srcdir/final.mat ]; then feat_type=lda; else feat_type=delta; fi
+echo "decode_combine_test.sh: feature type is $feat_type"
+
+feats="ark,s,cs:apply-cmvn $cmvn_opts --utt2spk=ark:$sdata/JOB/utt2spk scp:$sdata/JOB/cmvn.scp scp:$sdata/JOB/feats.scp ark:- |"
+
+posteriors="ark,scp:$sdata/JOB/posterior.ark,$sdata/JOB/posterior.scp"
+posteriors_scp="scp:$sdata/JOB/posterior.scp"
+
+if [ ! -z "$online_ivector_dir" ]; then
+  ivector_period=$(cat $online_ivector_dir/ivector_period) || exit 1;
+  ivector_opts="--online-ivectors=scp:$online_ivector_dir/ivector_online.scp --online-ivector-period=$ivector_period"
+fi
+
+if [ "$post_decode_acwt" == 1.0 ]; then
+  lat_wspecifier="ark:|gzip -c >$dir/lat.JOB.gz"
+else
+  lat_wspecifier="ark:|lattice-scale --acoustic-scale=$post_decode_acwt ark:- ark:- | gzip -c >$dir/lat.JOB.gz"
+fi
+
+frame_subsampling_opt=
+if [ -f $srcdir/frame_subsampling_factor ]; then
+  # e.g. for 'chain' systems
+  frame_subsampling_opt="--frame-subsampling-factor=$(cat $srcdir/frame_subsampling_factor)"
+fi
+
+frames_per_chunk=$(echo $chunk_width | cut -d, -f1)
+# Generate log-likelihoods.
+if [ $stage -le 1 ]; then
+  $cmd JOB=1:$nj $dir/log/nnet_compute.JOB.log \
+    nnet3-compute $ivector_opts $frame_subsampling_opt \
+     --acoustic-scale=$acwt \
+     --extra-left-context=$extra_left_context \
+     --extra-right-context=$extra_right_context \
+     --extra-left-context-initial=$extra_left_context_initial \
+     --extra-right-context-final=$extra_right_context_final \
+     --frames-per-chunk=$frames_per_chunk \
+     --use-gpu=$use_gpu --use-priors=true \
+     $srcdir/final.mdl "$feats" "$posteriors"
+fi
+
+if [ $stage -le 2 ]; then
+  suffix=
+  if $combine_version ; then
+    suffix="-combine"
+  fi
+  $cmd JOB=1:$nj $dir/log/decode.JOB.log \
+    latgen-faster-mapped$suffix --max-active=$maxactive --beam=$beam --lattice-beam=$lattice_beam \
+    --acoustic-scale=$acwt --allow-partial=true --word-symbol-table=$graphdir/words.txt \
+    $srcdir/final.mdl $graphdir/HCLG.fst "$posteriors_scp" "$lat_wspecifier" || exit 1;
+fi
+
+if ! $skip_scoring ; then
+  [ ! -x local/score.sh ] && \
+    echo "Not scoring because local/score.sh does not exist or not executable." && exit 1;
+  local/score.sh --cmd "$cmd" $data $graphdir $dir ||
+    { echo "$0: Scoring failed. (ignore by '--skip-scoring true')"; exit 1; }
+fi
+
+exit 0;
diff --git a/src/bin/Makefile b/src/bin/Makefile
index 7cb01b50120..3fcbef6ad32 100644
--- a/src/bin/Makefile
+++ b/src/bin/Makefile
@@ -22,7 +22,7 @@ BINFILES = align-equal align-equal-compiled acc-tree-stats \
            matrix-sum build-pfile-from-ali get-post-on-ali tree-info am-info \
            vector-sum matrix-sum-rows est-pca sum-lda-accs sum-mllt-accs \
            transform-vec align-text matrix-dim post-to-smat compile-graph \
-           compare-int-vector
+           compare-int-vector latgen-faster-mapped-combine
 
 OBJFILES =
diff --git a/src/bin/latgen-faster-mapped-combine.cc b/src/bin/latgen-faster-mapped-combine.cc
new file mode 100644
index 00000000000..ae5946d9e8e
--- /dev/null
+++ b/src/bin/latgen-faster-mapped-combine.cc
@@ -0,0 +1,179 @@
+// bin/latgen-faster-mapped-combine.cc
+
+// Copyright 2009-2012  Microsoft Corporation, Karel Vesely
+//                2013  Johns Hopkins University (author: Daniel Povey)
+//                2014  Guoguo Chen
+
+// See ../../COPYING for clarification regarding multiple authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//  http://www.apache.org/licenses/LICENSE-2.0
+//
+// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
+// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
+// MERCHANTABLITY OR NON-INFRINGEMENT.
+// See the Apache 2 License for the specific language governing permissions and
+// limitations under the License.
+
+
+#include "base/kaldi-common.h"
+#include "util/common-utils.h"
+#include "tree/context-dep.h"
+#include "hmm/transition-model.h"
+#include "fstext/fstext-lib.h"
+#include "decoder/decoder-wrappers.h"
+#include "decoder/decodable-matrix.h"
+#include "base/timer.h"
+
+
+int main(int argc, char *argv[]) {
+  try {
+    using namespace kaldi;
+    typedef kaldi::int32 int32;
+    using fst::SymbolTable;
+    using fst::Fst;
+    using fst::StdArc;
+
+    const char *usage =
+        "Generate lattices, reading log-likelihoods as matrices\n"
+        " (model is needed only for the integer mappings in its transition-model)\n"
+        "Usage: latgen-faster-mapped-combine [options] trans-model-in (fst-in|fsts-rspecifier) loglikes-rspecifier"
+        " lattice-wspecifier [ words-wspecifier [alignments-wspecifier] ]\n";
+    ParseOptions po(usage);
+    Timer timer;
+    bool allow_partial = false;
+    BaseFloat acoustic_scale = 0.1;
+    LatticeFasterDecoderCombineConfig config;
+
+    std::string word_syms_filename;
+    config.Register(&po);
+    po.Register("acoustic-scale", &acoustic_scale, "Scaling factor for acoustic likelihoods");
+
+    po.Register("word-symbol-table", &word_syms_filename, "Symbol table for words [for debug output]");
+    po.Register("allow-partial", &allow_partial, "If true, produce output even if end state was not reached.");
+
+    po.Read(argc, argv);
+
+    if (po.NumArgs() < 4 || po.NumArgs() > 6) {
+      po.PrintUsage();
+      exit(1);
+    }
+
+    std::string model_in_filename = po.GetArg(1),
+        fst_in_str = po.GetArg(2),
+        feature_rspecifier = po.GetArg(3),
+        lattice_wspecifier = po.GetArg(4),
+        words_wspecifier = po.GetOptArg(5),
+        alignment_wspecifier = po.GetOptArg(6);
+
+    TransitionModel trans_model;
+    ReadKaldiObject(model_in_filename, &trans_model);
+
+    bool determinize = config.determinize_lattice;
+    CompactLatticeWriter compact_lattice_writer;
+    LatticeWriter lattice_writer;
+    if (! (determinize ? compact_lattice_writer.Open(lattice_wspecifier)
+           : lattice_writer.Open(lattice_wspecifier)))
+      KALDI_ERR << "Could not open table for writing lattices: "
+                << lattice_wspecifier;
+
+    Int32VectorWriter words_writer(words_wspecifier);
+
+    Int32VectorWriter alignment_writer(alignment_wspecifier);
+
+    fst::SymbolTable *word_syms = NULL;
+    if (word_syms_filename != "")
+      if (!(word_syms = fst::SymbolTable::ReadText(word_syms_filename)))
+        KALDI_ERR << "Could not read symbol table from file "
+                  << word_syms_filename;
+
+    double tot_like = 0.0;
+    kaldi::int64 frame_count = 0;
+    int num_success = 0, num_fail = 0;
+
+    if (ClassifyRspecifier(fst_in_str, NULL, NULL) == kNoRspecifier) {
+      SequentialBaseFloatMatrixReader loglike_reader(feature_rspecifier);
+      // Input FST is just one FST, not a table of FSTs.
+      Fst<StdArc> *decode_fst = fst::ReadFstKaldiGeneric(fst_in_str);
+      timer.Reset();
+
+      {
+        LatticeFasterDecoderCombine decoder(*decode_fst, config);
+
+        for (; !loglike_reader.Done(); loglike_reader.Next()) {
+          std::string utt = loglike_reader.Key();
+          Matrix<BaseFloat> loglikes (loglike_reader.Value());
+          loglike_reader.FreeCurrent();
+          if (loglikes.NumRows() == 0) {
+            KALDI_WARN << "Zero-length utterance: " << utt;
+            num_fail++;
+            continue;
+          }
+
+          DecodableMatrixScaledMapped decodable(trans_model, loglikes, acoustic_scale);
+
+          double like;
+          if (DecodeUtteranceLatticeFasterCombine(
+                  decoder, decodable, trans_model, word_syms, utt,
+                  acoustic_scale, determinize, allow_partial, &alignment_writer,
+                  &words_writer, &compact_lattice_writer, &lattice_writer,
+                  &like)) {
+            tot_like += like;
+            frame_count += loglikes.NumRows();
+            num_success++;
+          } else num_fail++;
+        }
+      }
+      delete decode_fst; // delete this only after decoder goes out of scope.
+    } else { // We have different FSTs for different utterances.
+      SequentialTableReader<fst::VectorFstHolder> fst_reader(fst_in_str);
+      RandomAccessBaseFloatMatrixReader loglike_reader(feature_rspecifier);
+      for (; !fst_reader.Done(); fst_reader.Next()) {
+        std::string utt = fst_reader.Key();
+        if (!loglike_reader.HasKey(utt)) {
+          KALDI_WARN << "Not decoding utterance " << utt
+                     << " because no loglikes available.";
+          num_fail++;
+          continue;
+        }
+        const Matrix<BaseFloat> &loglikes = loglike_reader.Value(utt);
+        if (loglikes.NumRows() == 0) {
+          KALDI_WARN << "Zero-length utterance: " << utt;
+          num_fail++;
+          continue;
+        }
+        LatticeFasterDecoderCombine decoder(fst_reader.Value(), config);
+        DecodableMatrixScaledMapped decodable(trans_model, loglikes, acoustic_scale);
+        double like;
+        if (DecodeUtteranceLatticeFasterCombine(
+                decoder, decodable, trans_model, word_syms, utt, acoustic_scale,
+                determinize, allow_partial, &alignment_writer, &words_writer,
+                &compact_lattice_writer, &lattice_writer, &like)) {
+          tot_like += like;
+          frame_count += loglikes.NumRows();
+          num_success++;
+        } else num_fail++;
+      }
+    }
+
+    double elapsed = timer.Elapsed();
+    KALDI_LOG << "Time taken " << elapsed
+              << "s: real-time factor assuming 100 frames/sec is "
+              << (elapsed * 100.0 / frame_count);
+    KALDI_LOG << "Done " << num_success << " utterances, failed for "
+              << num_fail;
+    KALDI_LOG << "Overall log-likelihood per frame is " << (tot_like / frame_count)
+              << " over " << frame_count << " frames.";
+
+    delete word_syms;
+    if (num_success != 0) return 0;
+    else return 1;
+  } catch(const std::exception &e) {
+    std::cerr << e.what();
+    return -1;
+  }
+}
diff --git a/src/decoder/Makefile b/src/decoder/Makefile
index 020fe358fe9..53d469f4860 100644
--- a/src/decoder/Makefile
+++ b/src/decoder/Makefile
@@ -7,7 +7,7 @@ TESTFILES =
 
 OBJFILES = training-graph-compiler.o lattice-simple-decoder.o lattice-faster-decoder.o \
            lattice-faster-online-decoder.o simple-decoder.o faster-decoder.o \
-           decoder-wrappers.o grammar-fst.o decodable-matrix.o
+           decoder-wrappers.o grammar-fst.o decodable-matrix.o lattice-faster-decoder-combine.o
 
 LIBNAME = kaldi-decoder
diff --git a/src/decoder/decoder-wrappers.cc b/src/decoder/decoder-wrappers.cc
index ff573c74d15..3c1dbd7ed8d 100644
--- a/src/decoder/decoder-wrappers.cc
+++ b/src/decoder/decoder-wrappers.cc
@@ -546,4 +546,303 @@ void AlignUtteranceWrapper(
   }
 }
 
+// For lattice-faster-decoder-combine
+DecodeUtteranceLatticeFasterCombineClass::DecodeUtteranceLatticeFasterCombineClass(
+    LatticeFasterDecoderCombine *decoder,
+    DecodableInterface *decodable,
+    const TransitionModel &trans_model,
+    const fst::SymbolTable *word_syms,
+    std::string utt,
+    BaseFloat acoustic_scale,
+    bool determinize,
+    bool allow_partial,
+    Int32VectorWriter *alignments_writer,
+    Int32VectorWriter *words_writer,
+    CompactLatticeWriter *compact_lattice_writer,
+    LatticeWriter *lattice_writer,
+    double *like_sum, // on success, adds likelihood to this.
+    int64 *frame_sum, // on success, adds #frames to this.
+    int32 *num_done, // on success (including partial decode), increments this.
+    int32 *num_err,  // on failure, increments this.
+    int32 *num_partial):  // If partial decode (final-state not reached), increments this.
+    decoder_(decoder), decodable_(decodable), trans_model_(&trans_model),
+    word_syms_(word_syms), utt_(utt), acoustic_scale_(acoustic_scale),
+    determinize_(determinize), allow_partial_(allow_partial),
+    alignments_writer_(alignments_writer),
+    words_writer_(words_writer),
+    compact_lattice_writer_(compact_lattice_writer),
+    lattice_writer_(lattice_writer),
+    like_sum_(like_sum), frame_sum_(frame_sum),
+    num_done_(num_done), num_err_(num_err),
+    num_partial_(num_partial),
+    computed_(false), success_(false), partial_(false),
+    clat_(NULL), lat_(NULL) { }
+
+
+void DecodeUtteranceLatticeFasterCombineClass::operator () () {
+  // Decoding and lattice determinization happens here.
+  computed_ = true; // Just means this function was called-- a check on the
+                    // calling code.
+  success_ = true;
+  using fst::VectorFst;
+  if (!decoder_->Decode(decodable_)) {
+    KALDI_WARN << "Failed to decode file " << utt_;
+    success_ = false;
+  }
+  if (!decoder_->ReachedFinal()) {
+    if (allow_partial_) {
+      KALDI_WARN << "Outputting partial output for utterance " << utt_
+                 << " since no final-state reached\n";
+      partial_ = true;
+    } else {
+      KALDI_WARN << "Not producing output for utterance " << utt_
+                 << " since no final-state reached and "
+                 << "--allow-partial=false.\n";
+      success_ = false;
+    }
+  }
+  if (!success_) return;
+
+  // Get lattice, and do determinization if requested.
+  lat_ = new Lattice;
+  decoder_->GetRawLattice(lat_);
+  if (lat_->NumStates() == 0)
+    KALDI_ERR << "Unexpected problem getting lattice for utterance " << utt_;
+  fst::Connect(lat_);
+  if (determinize_) {
+    clat_ = new CompactLattice;
+    if (!DeterminizeLatticePhonePrunedWrapper(
+            *trans_model_,
+            lat_,
+            decoder_->GetOptions().lattice_beam,
+            clat_,
+            decoder_->GetOptions().det_opts))
+      KALDI_WARN << "Determinization finished earlier than the beam for "
+                 << "utterance " << utt_;
+    delete lat_;
+    lat_ = NULL;
+    // We'll write the lattice without acoustic scaling.
+    if (acoustic_scale_ != 0.0)
+      fst::ScaleLattice(fst::AcousticLatticeScale(1.0 / acoustic_scale_), clat_);
+  } else {
+    // We'll write the lattice without acoustic scaling.
+    if (acoustic_scale_ != 0.0)
+      fst::ScaleLattice(fst::AcousticLatticeScale(1.0 / acoustic_scale_), lat_);
+  }
+}
+
+DecodeUtteranceLatticeFasterCombineClass::~DecodeUtteranceLatticeFasterCombineClass() {
+  if (!computed_)
+    KALDI_ERR << "Destructor called without operator (), error in calling code.";
+
+  if (!success_) {
+    if (num_err_ != NULL) (*num_err_)++;
+  } else { // successful decode.
+    // Getting the one-best output is lightweight enough that we can do it in
+    // the destructor (easier than adding more variables to the class, and
+    // will rarely slow down the main thread.)
+    double likelihood;
+    LatticeWeight weight;
+    int32 num_frames;
+    { // First do some stuff with word-level traceback...
+      // This is basically for diagnostics.
+      fst::VectorFst<LatticeArc> decoded;
+      decoder_->GetBestPath(&decoded);
+      if (decoded.NumStates() == 0) {
+        // Shouldn't really reach this point as already checked success.
+        KALDI_ERR << "Failed to get traceback for utterance " << utt_;
+      }
+      std::vector<int32> alignment;
+      std::vector<int32> words;
+      GetLinearSymbolSequence(decoded, &alignment, &words, &weight);
+      num_frames = alignment.size();
+      if (words_writer_->IsOpen())
+        words_writer_->Write(utt_, words);
+      if (alignments_writer_->IsOpen())
+        alignments_writer_->Write(utt_, alignment);
+      if (word_syms_ != NULL) {
+        std::cerr << utt_ << ' ';
+        for (size_t i = 0; i < words.size(); i++) {
+          std::string s = word_syms_->Find(words[i]);
+          if (s == "")
+            KALDI_ERR << "Word-id " << words[i] << " not in symbol table.";
+          std::cerr << s << ' ';
+        }
+        std::cerr << '\n';
+      }
+      likelihood = -(weight.Value1() + weight.Value2());
+    }
+
+    // Output the lattices.
+    if (determinize_) { // CompactLattice output.
+      KALDI_ASSERT(compact_lattice_writer_ != NULL && clat_ != NULL);
+      if (clat_->NumStates() == 0) {
+        KALDI_WARN << "Empty lattice for utterance " << utt_;
+      } else {
+        compact_lattice_writer_->Write(utt_, *clat_);
+      }
+      delete clat_;
+      clat_ = NULL;
+    } else {
+      KALDI_ASSERT(lattice_writer_ != NULL && lat_ != NULL);
+      if (lat_->NumStates() == 0) {
+        KALDI_WARN << "Empty lattice for utterance " << utt_;
+      } else {
+        lattice_writer_->Write(utt_, *lat_);
+      }
+      delete lat_;
+      lat_ = NULL;
+    }
+
+    // Print out logging information.
+    KALDI_LOG << "Log-like per frame for utterance " << utt_ << " is "
+              << (likelihood / num_frames) << " over "
+              << num_frames << " frames.";
+    KALDI_VLOG(2) << "Cost for utterance " << utt_ << " is "
+                  << weight.Value1() << " + " << weight.Value2();
+
+    // Now output the various diagnostic variables.
+    if (like_sum_ != NULL) *like_sum_ += likelihood;
+    if (frame_sum_ != NULL) *frame_sum_ += num_frames;
+    if (num_done_ != NULL) (*num_done_)++;
+    if (partial_ && num_partial_ != NULL) (*num_partial_)++;
+  }
+  // We were given ownership of these two objects that were passed in in
+  // the initializer.
+  delete decoder_;
+  delete decodable_;
+}
+
+
+// Takes care of output. Returns true on success.
+template <typename FST>
+bool DecodeUtteranceLatticeFasterCombine(
+    LatticeFasterDecoderCombineTpl<FST> &decoder, // not const but is really an input.
+    DecodableInterface &decodable, // not const but is really an input.
+    const TransitionModel &trans_model,
+    const fst::SymbolTable *word_syms,
+    std::string utt,
+    double acoustic_scale,
+    bool determinize,
+    bool allow_partial,
+    Int32VectorWriter *alignment_writer,
+    Int32VectorWriter *words_writer,
+    CompactLatticeWriter *compact_lattice_writer,
+    LatticeWriter *lattice_writer,
+    double *like_ptr) { // puts utterance's like in like_ptr on success.
+  using fst::VectorFst;
+
+  if (!decoder.Decode(&decodable)) {
+    KALDI_WARN << "Failed to decode file " << utt;
+    return false;
+  }
+  if (!decoder.ReachedFinal()) {
+    if (allow_partial) {
+      KALDI_WARN << "Outputting partial output for utterance " << utt
+                 << " since no final-state reached\n";
+    } else {
+      KALDI_WARN << "Not producing output for utterance " << utt
+                 << " since no final-state reached and "
+                 << "--allow-partial=false.\n";
+      return false;
+    }
+  }
+
+  double likelihood;
+  LatticeWeight weight;
+  int32 num_frames;
+  { // First do some stuff with word-level traceback...
+    VectorFst<LatticeArc> decoded;
+    if (!decoder.GetBestPath(&decoded))
+      // Shouldn't really reach this point as already checked success.
+ KALDI_ERR << "Failed to get traceback for utterance " << utt; + + std::vector alignment; + std::vector words; + GetLinearSymbolSequence(decoded, &alignment, &words, &weight); + num_frames = alignment.size(); + if (words_writer->IsOpen()) + words_writer->Write(utt, words); + if (alignment_writer->IsOpen()) + alignment_writer->Write(utt, alignment); + if (word_syms != NULL) { + std::cerr << utt << ' '; + for (size_t i = 0; i < words.size(); i++) { + std::string s = word_syms->Find(words[i]); + if (s == "") + KALDI_ERR << "Word-id " << words[i] << " not in symbol table."; + std::cerr << s << ' '; + } + std::cerr << '\n'; + } + likelihood = -(weight.Value1() + weight.Value2()); + } + + // Get lattice, and do determinization if requested. + Lattice lat; + decoder.GetRawLattice(&lat); + if (lat.NumStates() == 0) + KALDI_ERR << "Unexpected problem getting lattice for utterance " << utt; + fst::Connect(&lat); + if (determinize) { + CompactLattice clat; + if (!DeterminizeLatticePhonePrunedWrapper( + trans_model, + &lat, + decoder.GetOptions().lattice_beam, + &clat, + decoder.GetOptions().det_opts)) + KALDI_WARN << "Determinization finished earlier than the beam for " + << "utterance " << utt; + // We'll write the lattice without acoustic scaling. + if (acoustic_scale != 0.0) + fst::ScaleLattice(fst::AcousticLatticeScale(1.0 / acoustic_scale), &clat); + compact_lattice_writer->Write(utt, clat); + } else { + // We'll write the lattice without acoustic scaling. + if (acoustic_scale != 0.0) + fst::ScaleLattice(fst::AcousticLatticeScale(1.0 / acoustic_scale), &lat); + lattice_writer->Write(utt, lat); + } + KALDI_LOG << "Log-like per frame for utterance " << utt << " is " + << (likelihood / num_frames) << " over " + << num_frames << " frames."; + KALDI_VLOG(2) << "Cost for utterance " << utt << " is " + << weight.Value1() << " + " << weight.Value2(); + *like_ptr = likelihood; + return true; +} + +// Instantiate the template above for the two required FST types. +template bool DecodeUtteranceLatticeFasterCombine( + LatticeFasterDecoderCombineTpl > &decoder, + DecodableInterface &decodable, + const TransitionModel &trans_model, + const fst::SymbolTable *word_syms, + std::string utt, + double acoustic_scale, + bool determinize, + bool allow_partial, + Int32VectorWriter *alignment_writer, + Int32VectorWriter *words_writer, + CompactLatticeWriter *compact_lattice_writer, + LatticeWriter *lattice_writer, + double *like_ptr); + +template bool DecodeUtteranceLatticeFasterCombine( + LatticeFasterDecoderCombineTpl &decoder, + DecodableInterface &decodable, + const TransitionModel &trans_model, + const fst::SymbolTable *word_syms, + std::string utt, + double acoustic_scale, + bool determinize, + bool allow_partial, + Int32VectorWriter *alignment_writer, + Int32VectorWriter *words_writer, + CompactLatticeWriter *compact_lattice_writer, + LatticeWriter *lattice_writer, + double *like_ptr); + + } // end namespace kaldi. 
diff --git a/src/decoder/decoder-wrappers.h b/src/decoder/decoder-wrappers.h
index fc81137f356..19d01e4316a 100644
--- a/src/decoder/decoder-wrappers.h
+++ b/src/decoder/decoder-wrappers.h
@@ -23,6 +23,7 @@
 #include "itf/options-itf.h"
 #include "decoder/lattice-faster-decoder.h"
 #include "decoder/lattice-simple-decoder.h"
+#include "decoder/lattice-faster-decoder-combine.h"
 
 // This header contains declarations from various convenience functions that are called
 // from binary-level programs such as gmm-decode-faster.cc, gmm-align-compiled.cc, and
@@ -196,6 +197,78 @@ bool DecodeUtteranceLatticeSimple(
     double *like_ptr);  // puts utterance's likelihood in like_ptr on success.
 
 
+// For lattice-faster-decoder-combine
+template <typename FST>
+bool DecodeUtteranceLatticeFasterCombine(
+    LatticeFasterDecoderCombineTpl<FST> &decoder, // not const but is really an input.
+    DecodableInterface &decodable, // not const but is really an input.
+    const TransitionModel &trans_model,
+    const fst::SymbolTable *word_syms,
+    std::string utt,
+    double acoustic_scale,
+    bool determinize,
+    bool allow_partial,
+    Int32VectorWriter *alignments_writer,
+    Int32VectorWriter *words_writer,
+    CompactLatticeWriter *compact_lattice_writer,
+    LatticeWriter *lattice_writer,
+    double *like_ptr);  // puts utterance's likelihood in like_ptr on success.
+
+
+class DecodeUtteranceLatticeFasterCombineClass {
+ public:
+  // Initializer sets various variables.
+  // NOTE: we "take ownership" of "decoder" and "decodable".  These
+  // are deleted by the destructor.  On error, "num_err" is incremented.
+  DecodeUtteranceLatticeFasterCombineClass(
+      LatticeFasterDecoderCombine *decoder,
+      DecodableInterface *decodable,
+      const TransitionModel &trans_model,
+      const fst::SymbolTable *word_syms,
+      std::string utt,
+      BaseFloat acoustic_scale,
+      bool determinize,
+      bool allow_partial,
+      Int32VectorWriter *alignments_writer,
+      Int32VectorWriter *words_writer,
+      CompactLatticeWriter *compact_lattice_writer,
+      LatticeWriter *lattice_writer,
+      double *like_sum, // on success, adds likelihood to this.
+      int64 *frame_sum, // on success, adds #frames to this.
+      int32 *num_done, // on success (including partial decode), increments this.
+      int32 *num_err,  // on failure, increments this.
+      int32 *num_partial);  // If partial decode (final-state not reached), increments this.
+  void operator () ();  // The decoding happens here.
+  ~DecodeUtteranceLatticeFasterCombineClass();  // Output happens here.
+ private:
+  // The following variables correspond to inputs:
+  LatticeFasterDecoderCombine *decoder_;
+  DecodableInterface *decodable_;
+  const TransitionModel *trans_model_;
+  const fst::SymbolTable *word_syms_;
+  std::string utt_;
+  BaseFloat acoustic_scale_;
+  bool determinize_;
+  bool allow_partial_;
+  Int32VectorWriter *alignments_writer_;
+  Int32VectorWriter *words_writer_;
+  CompactLatticeWriter *compact_lattice_writer_;
+  LatticeWriter *lattice_writer_;
+  double *like_sum_;
+  int64 *frame_sum_;
+  int32 *num_done_;
+  int32 *num_err_;
+  int32 *num_partial_;
+
+  // The following variables are stored by the computation.
+  bool computed_; // operator () was called.
+  bool success_; // decoding succeeded (possibly partial)
+  bool partial_; // decoding was partial.
+  CompactLattice *clat_; // Stored output, if determinize_ == true.
+  Lattice *lat_; // Stored output, if determinize_ == false.
+};
+
+
 } // end namespace kaldi.
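
The operator()() / destructor split in this class mirrors the existing DecodeUtteranceLatticeFasterClass: decoding runs in operator()() on a worker thread, while all output happens in the destructor, which kaldi::TaskSequencer (util/kaldi-thread.h) invokes in submission order so lattices are written in input order. No binary in this patch uses the class yet; the driver below is a hedged sketch of the intended usage, with every surrounding variable assumed to be set up as in latgen-faster-mapped-combine.cc.

    #include "decoder/decoder-wrappers.h"
    #include "util/kaldi-thread.h"  // for kaldi::TaskSequencer

    // Sketch only: assumes `decoder`, `decodable`, the writers etc. exist.
    void DecodeAllUtterances(/* ... readers, writers, model ... */) {
      using namespace kaldi;
      TaskSequencerConfig sequencer_config;  // e.g. --num-threads on the command line
      TaskSequencer<DecodeUtteranceLatticeFasterCombineClass> sequencer(sequencer_config);
      // For each utterance:
      //   sequencer.Run(new DecodeUtteranceLatticeFasterCombineClass(
      //       decoder, decodable, trans_model, word_syms, utt, acoustic_scale,
      //       determinize, allow_partial, &alignment_writer, &words_writer,
      //       &compact_lattice_writer, &lattice_writer,
      //       &tot_like, &frame_count, &num_done, &num_err, NULL));
      // Run() takes ownership; operator()() executes on a worker thread and the
      // destructor (which writes the output) runs in submission order.
      sequencer.Wait();  // drain remaining tasks
    }
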
diff --git a/src/decoder/lattice-faster-decoder-combine-bucketqueue.cc b/src/decoder/lattice-faster-decoder-combine-bucketqueue.cc
new file mode 100644
index 00000000000..f30fc36b872
--- /dev/null
+++ b/src/decoder/lattice-faster-decoder-combine-bucketqueue.cc
@@ -0,0 +1,1088 @@
+// decoder/lattice-faster-decoder-combine-bucketqueue.cc
+
+// Copyright 2009-2012  Microsoft Corporation  Mirko Hannemann
+//           2013-2019  Johns Hopkins University (Author: Daniel Povey)
+//                2014  Guoguo Chen
+//                2018  Zhehuai Chen
+//                2019  Hang Lyu
+
+// See ../../COPYING for clarification regarding multiple authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//  http://www.apache.org/licenses/LICENSE-2.0
+//
+// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
+// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
+// MERCHANTABLITY OR NON-INFRINGEMENT.
+// See the Apache 2 License for the specific language governing permissions and
+// limitations under the License.
+
+#include "decoder/lattice-faster-decoder-combine.h"
+#include "lat/lattice-functions.h"
+
+namespace kaldi {
+
+template <typename Token>
+BucketQueue<Token>::BucketQueue(BaseFloat cost_scale) :
+    cost_scale_(cost_scale) {
+  // NOTE: we reserve plenty of elements to avoid expensive reallocations
+  // later on.  Normally, the size is a little bigger than
+  // (adaptive_beam + 15) * cost_scale.
+  int32 bucket_size = (15 + 20) * cost_scale_;
+  buckets_.resize(bucket_size);
+  bucket_offset_ = 15 * cost_scale_;
+  first_nonempty_bucket_index_ = bucket_size - 1;
+  first_nonempty_bucket_ = &buckets_[first_nonempty_bucket_index_];
+  bucket_size_tolerance_ = 1.2 * bucket_size;
+}
+
+template <typename Token>
+void BucketQueue<Token>::Push(Token *tok) {
+  size_t bucket_index = std::floor(tok->tot_cost * cost_scale_) +
+                        bucket_offset_;
+  if (bucket_index >= buckets_.size()) {
+    int32 margin = 10;  // a margin used to avoid reallocating space too
+                        // frequently
+    if (static_cast<int32>(bucket_index) > 0) {
+      buckets_.resize(bucket_index + margin);
+    } else {  // less than 0
+      int32 increase_size = - static_cast<int32>(bucket_index) + margin;
+      buckets_.resize(buckets_.size() + increase_size);
+      // translation: shift the existing buckets up by increase_size.
+      for (size_t i = buckets_.size() - 1; i >= increase_size; i--) {
+        buckets_[i].swap(buckets_[i - increase_size]);
+      }
+      bucket_offset_ = bucket_offset_ + increase_size;
+      bucket_index += increase_size;
+      first_nonempty_bucket_index_ = bucket_index;
+    }
+    first_nonempty_bucket_ = &buckets_[first_nonempty_bucket_index_];
+  }
+  tok->in_queue = true;
+  buckets_[bucket_index].push_back(tok);
+  if (bucket_index < first_nonempty_bucket_index_) {
+    first_nonempty_bucket_index_ = bucket_index;
+    first_nonempty_bucket_ = &buckets_[first_nonempty_bucket_index_];
+  }
+}
+
+template <typename Token>
+Token* BucketQueue<Token>::Pop() {
+  while (true) {
+    if (!first_nonempty_bucket_->empty()) {
+      Token *ans = first_nonempty_bucket_->back();
+      first_nonempty_bucket_->pop_back();
+      if (ans->in_queue) {  // If ans->in_queue is false, this means it is a
+                            // duplicate instance of this Token that was left
+                            // over when a Token's best_cost changed, and the
+                            // Token has already been processed (so
+                            // conceptually, it is not in the queue).
+        ans->in_queue = false;
+        return ans;
+      }
+    }
+    if (first_nonempty_bucket_->empty()) {
+      for (; first_nonempty_bucket_index_ + 1 < buckets_.size();
+           first_nonempty_bucket_index_++) {
+        if (!buckets_[first_nonempty_bucket_index_].empty()) break;
+      }
+      first_nonempty_bucket_ = &buckets_[first_nonempty_bucket_index_];
+      if (first_nonempty_bucket_->empty()) return NULL;
+    }
+  }
+}
+
+template <typename Token>
+void BucketQueue<Token>::Clear() {
+  for (size_t i = first_nonempty_bucket_index_; i < buckets_.size(); i++) {
+    buckets_[i].clear();
+  }
+  if (buckets_.size() > bucket_size_tolerance_) {
+    buckets_.resize(bucket_size_tolerance_);
+    bucket_offset_ = 15 * cost_scale_;
+  }
+  first_nonempty_bucket_index_ = buckets_.size() - 1;
+  first_nonempty_bucket_ = &buckets_[first_nonempty_bucket_index_];
+}
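
To make the BucketQueue approximation concrete, here is a small standalone sketch of the same idea, simplified: no bucket_offset_ shifting for negative costs, no in_queue lazy-deletion flag, and invented names throughout. Costs are quantized to floor(cost * cost_scale), so two entries within 1/cost_scale of each other may pop in either order; that is acceptable because the decoder only needs approximately best-first processing.

    #include <cmath>
    #include <cstdio>
    #include <vector>

    // Simplified bucket queue over non-negative costs; pops in approximate
    // best-(lowest-)cost-first order.  Mirrors BucketQueue's indexing:
    // bucket = floor(cost * cost_scale).
    struct MiniBucketQueue {
      MiniBucketQueue(float cost_scale, int num_buckets)
          : cost_scale_(cost_scale), buckets_(num_buckets), first_(num_buckets) {}

      void Push(float cost) {
        size_t idx = static_cast<size_t>(std::floor(cost * cost_scale_));
        if (idx >= buckets_.size()) buckets_.resize(idx + 1);
        buckets_[idx].push_back(cost);
        if (idx < first_) first_ = idx;  // track the lowest non-empty bucket
      }

      bool Pop(float *cost) {
        for (; first_ < buckets_.size(); first_++) {
          if (!buckets_[first_].empty()) {
            *cost = buckets_[first_].back();
            buckets_[first_].pop_back();
            return true;
          }
        }
        return false;  // queue exhausted
      }

      float cost_scale_;
      std::vector<std::vector<float> > buckets_;
      size_t first_;
    };

    int main() {
      MiniBucketQueue q(/* cost_scale = */ 1.0f, /* num_buckets = */ 64);
      float costs[] = { 7.3f, 2.1f, 2.9f, 11.0f };
      for (float c : costs) q.Push(c);
      float c;
      while (q.Pop(&c)) std::printf("%.1f\n", c);
      // Prints 2.9 then 2.1 (same bucket, LIFO within it), then 7.3, 11.0.
      return 0;
    }
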
+
+// instantiate this class once for each thing you have to decode.
+template <typename FST, typename Token>
+LatticeFasterDecoderCombineTpl<FST, Token>::LatticeFasterDecoderCombineTpl(
+    const FST &fst,
+    const LatticeFasterDecoderCombineConfig &config):
+    fst_(&fst), delete_fst_(false), config_(config), num_toks_(0),
+    cur_queue_(config_.cost_scale) {
+  config.Check();
+  cur_toks_.reserve(1000);
+  next_toks_.reserve(1000);
+}
+
+
+template <typename FST, typename Token>
+LatticeFasterDecoderCombineTpl<FST, Token>::LatticeFasterDecoderCombineTpl(
+    const LatticeFasterDecoderCombineConfig &config, FST *fst):
+    fst_(fst), delete_fst_(true), config_(config), num_toks_(0),
+    cur_queue_(config_.cost_scale) {
+  config.Check();
+  cur_toks_.reserve(1000);
+  next_toks_.reserve(1000);
+}
+
+
+template <typename FST, typename Token>
+LatticeFasterDecoderCombineTpl<FST, Token>::~LatticeFasterDecoderCombineTpl() {
+  ClearActiveTokens();
+  if (delete_fst_) delete fst_;
+}
+
+template <typename FST, typename Token>
+void LatticeFasterDecoderCombineTpl<FST, Token>::InitDecoding() {
+  // clean up from last time:
+  cur_toks_.clear();
+  next_toks_.clear();
+  cost_offsets_.clear();
+  ClearActiveTokens();
+
+  warned_ = false;
+  num_toks_ = 0;
+  decoding_finalized_ = false;
+  final_costs_.clear();
+  StateId start_state = fst_->Start();
+  KALDI_ASSERT(start_state != fst::kNoStateId);
+  active_toks_.resize(1);
+  Token *start_tok = new Token(0.0, 0.0, start_state, NULL, NULL, NULL);
+  active_toks_[0].toks = start_tok;
+  next_toks_[start_state] = start_tok;  // initialize current tokens map
+  num_toks_++;
+  adaptive_beam_ = config_.beam;
+  cost_offsets_.resize(1);
+  cost_offsets_[0] = 0.0;
+}
+
+// Returns true if any kind of traceback is available (not necessarily from
+// a final state).  It should only very rarely return false; this indicates
+// an unusual search error.
+template <typename FST, typename Token>
+bool LatticeFasterDecoderCombineTpl<FST, Token>::Decode(DecodableInterface *decodable) {
+  InitDecoding();
+
+  // We use 1-based indexing for frames in this decoder (if you view it in
+  // terms of features), but note that the decodable object uses zero-based
+  // numbering, which we have to correct for when we call it.
+
+  while (!decodable->IsLastFrame(NumFramesDecoded() - 1)) {
+    if (NumFramesDecoded() % config_.prune_interval == 0)
+      PruneActiveTokens(config_.lattice_beam * config_.prune_scale);
+    ProcessForFrame(decodable);
+  }
+  // A complete token list for the last frame will be generated in
+  // FinalizeDecoding().
+  FinalizeDecoding();
+
+  // Returns true if we have any kind of traceback available (not necessarily
+  // to the end state; query ReachedFinal() for that).
+  return !active_toks_.empty() && active_toks_.back().toks != NULL;
+}
+
+
+// Outputs an FST corresponding to the single best path through the lattice.
+template <typename FST, typename Token>
+bool LatticeFasterDecoderCombineTpl<FST, Token>::GetBestPath(
+    Lattice *olat,
+    bool use_final_probs) {
+  Lattice raw_lat;
+  GetRawLattice(&raw_lat, use_final_probs);
+  ShortestPath(raw_lat, olat);
+  return (olat->NumStates() != 0);
+}
+
+
+// Outputs an FST corresponding to the raw, state-level lattice
+template <typename FST, typename Token>
+bool LatticeFasterDecoderCombineTpl<FST, Token>::GetRawLattice(
+    Lattice *ofst,
+    bool use_final_probs) {
+  typedef LatticeArc Arc;
+  typedef Arc::StateId StateId;
+  typedef Arc::Weight Weight;
+  typedef Arc::Label Label;
+  // Note: you can't use the old interface (Decode()) if you want to
+  // get the lattice with use_final_probs = false.  You'd have to do
+  // InitDecoding() and then AdvanceDecoding().
+  if (decoding_finalized_ && !use_final_probs)
+    KALDI_ERR << "You cannot call FinalizeDecoding() and then call "
+              << "GetRawLattice() with use_final_probs == false";
+
+  if (!decoding_finalized_ && use_final_probs) {
+    // Process the non-emitting arcs for the unfinished last frame.
+    ProcessNonemitting();
+  }
+
+
+  unordered_map<Token*, BaseFloat> final_costs_local;
+
+  const unordered_map<Token*, BaseFloat> &final_costs =
+      (decoding_finalized_ ? final_costs_ : final_costs_local);
+  if (!decoding_finalized_ && use_final_probs)
+    ComputeFinalCosts(&final_costs_local, NULL, NULL);
+
+  ofst->DeleteStates();
+  // num-frames plus one (since frames are one-based, and we have
+  // an extra frame for the start-state).
+  int32 num_frames = active_toks_.size() - 1;
+  KALDI_ASSERT(num_frames > 0);
+  const int32 bucket_count = num_toks_/2 + 3;
+  unordered_map<Token*, StateId> tok_map(bucket_count);
+  // First create all states.
+  std::vector<Token*> token_list;
+  for (int32 f = 0; f <= num_frames; f++) {
+    if (active_toks_[f].toks == NULL) {
+      KALDI_WARN << "GetRawLattice: no tokens active on frame " << f
+                 << ": not producing lattice.\n";
+      return false;
+    }
+    TopSortTokens(active_toks_[f].toks, &token_list);
+    for (size_t i = 0; i < token_list.size(); i++)
+      if (token_list[i] != NULL)
+        tok_map[token_list[i]] = ofst->AddState();
+  }
+  // The next statement sets the start state of the output FST.  Because we
+  // topologically sorted the tokens, state zero must be the start-state.
+  ofst->SetStart(0);
+
+  KALDI_VLOG(4) << "init:" << num_toks_/2 + 3 << " buckets:"
+                << tok_map.bucket_count() << " load:" << tok_map.load_factor()
+                << " max:" << tok_map.max_load_factor();
+  // Now create all arcs.
+  for (int32 f = 0; f <= num_frames; f++) {
+    for (Token *tok = active_toks_[f].toks; tok != NULL; tok = tok->next) {
+      StateId cur_state = tok_map[tok];
+      for (ForwardLinkT *l = tok->links;
+           l != NULL;
+           l = l->next) {
+        typename unordered_map<Token*, StateId>::const_iterator
+            iter = tok_map.find(l->next_tok);
+        KALDI_ASSERT(iter != tok_map.end());
+        StateId nextstate = iter->second;
+        BaseFloat cost_offset = 0.0;
+        if (l->ilabel != 0) {  // emitting..
+          KALDI_ASSERT(f >= 0 && f < cost_offsets_.size());
+          cost_offset = cost_offsets_[f];
+        }
+        Arc arc(l->ilabel, l->olabel,
+                Weight(l->graph_cost, l->acoustic_cost - cost_offset),
+                nextstate);
+        ofst->AddArc(cur_state, arc);
+      }
+      if (f == num_frames) {
+        if (use_final_probs && !final_costs.empty()) {
+          typename unordered_map<Token*, BaseFloat>::const_iterator
+              iter = final_costs.find(tok);
+          if (iter != final_costs.end())
+            ofst->SetFinal(cur_state, LatticeWeight(iter->second, 0));
+        } else {
+          ofst->SetFinal(cur_state, LatticeWeight::One());
+        }
+      }
+    }
+  }
+
+  return (ofst->NumStates() > 0);
+}
Outputs an FST corresponding to the +// lattice-determinized lattice (one path per word sequence). +template +bool LatticeFasterDecoderCombineTpl::GetLattice( + CompactLattice *ofst, + bool use_final_probs) { + Lattice raw_fst; + GetRawLattice(&raw_fst, use_final_probs); + Invert(&raw_fst); // make it so word labels are on the input. + // (in phase where we get backward-costs). + fst::ILabelCompare ilabel_comp; + ArcSort(&raw_fst, ilabel_comp); // sort on ilabel; makes + // lattice-determinization more efficient. + + fst::DeterminizeLatticePrunedOptions lat_opts; + lat_opts.max_mem = config_.det_opts.max_mem; + + DeterminizeLatticePruned(raw_fst, config_.lattice_beam, ofst, lat_opts); + raw_fst.DeleteStates(); // Free memory-- raw_fst no longer needed. + Connect(ofst); // Remove unreachable states... there might be + // a small number of these, in some cases. + // Note: if something went wrong and the raw lattice was empty, + // we should still get to this point in the code without warnings or failures. + return (ofst->NumStates() != 0); +} + +/* + A note on the definition of extra_cost. + + extra_cost is used in pruning tokens, to save memory. + + Define the 'forward cost' of a token as zero for any token on the frame + we're currently decoding; and for other frames, as the shortest-path cost + between that token and a token on the frame we're currently decoding. + (by "currently decoding" I mean the most recently processed frame). + + Then define the extra_cost of a token (always >= 0) as the forward-cost of + the token minus the smallest forward-cost of any token on the same frame. + + We can use the extra_cost to accurately prune away tokens that we know will + never appear in the lattice. If the extra_cost is greater than the desired + lattice beam, the token would provably never appear in the lattice, so we can + prune away the token. + + The advantage of storing the extra_cost rather than the forward-cost, is that + it is less costly to keep the extra_cost up-to-date when we process new frames. + When we process a new frame, *all* the previous frames' forward-costs would change; + but in general the extra_cost will change only for a finite number of frames. + (Actually we don't update all the extra_costs every time we update a frame; we + only do it every 'config_.prune_interval' frames). + */ + +// FindOrAddToken either locates a token in hash map "token_map" +// or if necessary inserts a new, empty token (i.e. with no forward links) +// for the current frame. [note: it's inserted if necessary into hash toks_ +// and also into the singly linked list of tokens active on this frame +// (whose head is at active_toks_[frame]). +template +inline Token* LatticeFasterDecoderCombineTpl::FindOrAddToken( + StateId state, int32 token_list_index, BaseFloat tot_cost, + Token *backpointer, StateIdToTokenMap *token_map, bool *changed) { + // Returns the Token pointer. Sets "changed" (if non-NULL) to true + // if the token was newly created or the cost changed. + KALDI_ASSERT(token_list_index < active_toks_.size()); + Token *&toks = active_toks_[token_list_index].toks; + typename StateIdToTokenMap::iterator e_found = token_map->find(state); + if (e_found == token_map->end()) { // no such token presently. + const BaseFloat extra_cost = 0.0; + // tokens on the currently final frame have zero extra_cost + // as any of them could end up + // on the winning path. 
+
+// FindOrAddToken either locates a token in hash map "token_map"
+// or if necessary inserts a new, empty token (i.e. with no forward links)
+// for the current frame.  [note: it's inserted if necessary into hash toks_
+// and also into the singly linked list of tokens active on this frame
+// (whose head is at active_toks_[frame]).
+template <typename FST, typename Token>
+inline Token* LatticeFasterDecoderCombineTpl<FST, Token>::FindOrAddToken(
+    StateId state, int32 token_list_index, BaseFloat tot_cost,
+    Token *backpointer, StateIdToTokenMap *token_map, bool *changed) {
+  // Returns the Token pointer.  Sets "changed" (if non-NULL) to true
+  // if the token was newly created or the cost changed.
+  KALDI_ASSERT(token_list_index < active_toks_.size());
+  Token *&toks = active_toks_[token_list_index].toks;
+  typename StateIdToTokenMap::iterator e_found = token_map->find(state);
+  if (e_found == token_map->end()) {  // no such token presently.
+    const BaseFloat extra_cost = 0.0;
+    // tokens on the currently final frame have zero extra_cost
+    // as any of them could end up
+    // on the winning path.
+    Token *new_tok = new Token (tot_cost, extra_cost, state,
+                                NULL, toks, backpointer);
+    // NULL: no forward links yet
+    toks = new_tok;
+    num_toks_++;
+    // insert into the map
+    (*token_map)[state] = new_tok;
+    if (changed) *changed = true;
+    return new_tok;
+  } else {
+    Token *tok = e_found->second;  // There is an existing Token for this state.
+    if (tok->tot_cost > tot_cost) {  // replace old token
+      tok->tot_cost = tot_cost;
+      // SetBackpointer() just does tok->backpointer = backpointer in
+      // the case where Token == BackpointerToken, else nothing.
+      tok->SetBackpointer(backpointer);
+      // we don't allocate a new token, the old stays linked in active_toks_
+      // we only replace the tot_cost
+      // in the current frame, there are no forward links (and no extra_cost)
+      // only in ProcessNonemitting we have to delete forward links
+      // in case we visit a state for the second time
+      // those forward links, that lead to this replaced token before:
+      // they remain and will hopefully be pruned later (PruneForwardLinks...)
+      if (changed) *changed = true;
+    } else {
+      if (changed) *changed = false;
+    }
+    return tok;
+  }
+}
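
The find-or-relax logic above is standard Viterbi token recombination: at most one token per FST state per frame, replaced only when a cheaper path arrives. A stripped-down sketch of the same hash-map idiom, with invented types that are not the decoder's:

    #include <cstdio>
    #include <unordered_map>

    struct MiniTok { float tot_cost; int backpointer; };
    typedef std::unordered_map<int, MiniTok> StateToTok;  // state-id -> token

    // Insert a token for `state`, or relax the existing one if `tot_cost`
    // is better; returns true if anything changed (mirrors FindOrAddToken's
    // "changed" output).
    bool FindOrRelax(StateToTok *map, int state, float tot_cost, int bp) {
      StateToTok::iterator it = map->find(state);
      if (it == map->end()) {          // first path reaching this state
        MiniTok t = { tot_cost, bp };
        (*map)[state] = t;
        return true;
      } else if (it->second.tot_cost > tot_cost) {  // cheaper path found
        it->second.tot_cost = tot_cost;
        it->second.backpointer = bp;
        return true;
      }
      return false;                    // existing path is at least as good
    }

    int main() {
      StateToTok frame_toks;
      std::printf("%d\n", FindOrRelax(&frame_toks, 42, 10.0f, 1));  // 1: new
      std::printf("%d\n", FindOrRelax(&frame_toks, 42, 12.0f, 2));  // 0: worse
      std::printf("%d\n", FindOrRelax(&frame_toks, 42, 9.5f, 3));   // 1: relaxed
      return 0;
    }
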
+
+// prunes outgoing links for all tokens in active_toks_[frame]
+// it's called by PruneActiveTokens
+// all links, that have link_extra_cost > lattice_beam are pruned
+template <typename FST, typename Token>
+void LatticeFasterDecoderCombineTpl<FST, Token>::PruneForwardLinks(
+    int32 frame_plus_one, bool *extra_costs_changed,
+    bool *links_pruned, BaseFloat delta) {
+  // delta is the amount by which the extra_costs must change
+  // If delta is larger,  we'll tend to go back less far
+  //    toward the beginning of the file.
+  // extra_costs_changed is set to true if extra_cost was changed for any token
+  // links_pruned is set to true if any link in any token was pruned
+
+  *extra_costs_changed = false;
+  *links_pruned = false;
+  KALDI_ASSERT(frame_plus_one >= 0 && frame_plus_one < active_toks_.size());
+  if (active_toks_[frame_plus_one].toks == NULL) {  // empty list; should not happen.
+    if (!warned_) {
+      KALDI_WARN << "No tokens alive [doing pruning].. warning first "
+          "time only for each utterance\n";
+      warned_ = true;
+    }
+  }
+
+  // We have to iterate until there is no more change, because the links
+  // are not guaranteed to be in topological order.
+  bool changed = true;  // difference new minus old extra cost >= delta ?
+  while (changed) {
+    changed = false;
+    for (Token *tok = active_toks_[frame_plus_one].toks;
+         tok != NULL; tok = tok->next) {
+      ForwardLinkT *link, *prev_link = NULL;
+      // will recompute tok_extra_cost for tok.
+      BaseFloat tok_extra_cost = std::numeric_limits<BaseFloat>::infinity();
+      // tok_extra_cost is the best (min) of link_extra_cost of outgoing links
+      for (link = tok->links; link != NULL; ) {
+        // See if we need to excise this link...
+        Token *next_tok = link->next_tok;
+        BaseFloat link_extra_cost = next_tok->extra_cost +
+            ((tok->tot_cost + link->acoustic_cost + link->graph_cost)
+             - next_tok->tot_cost);  // difference in brackets is >= 0
+        // link_extra_cost is the difference in score between the best paths
+        // through link source state and through link destination state
+        KALDI_ASSERT(link_extra_cost == link_extra_cost);  // check for NaN
+        if (link_extra_cost > config_.lattice_beam) {  // excise link
+          ForwardLinkT *next_link = link->next;
+          if (prev_link != NULL) prev_link->next = next_link;
+          else tok->links = next_link;
+          delete link;
+          link = next_link;  // advance link but leave prev_link the same.
+          *links_pruned = true;
+        } else {   // keep the link and update the tok_extra_cost if needed.
+          if (link_extra_cost < 0.0) {  // this is just a precaution.
+            if (link_extra_cost < -0.01)
+              KALDI_WARN << "Negative extra_cost: " << link_extra_cost;
+            link_extra_cost = 0.0;
+          }
+          if (link_extra_cost < tok_extra_cost)
+            tok_extra_cost = link_extra_cost;
+          prev_link = link;  // move to next link
+          link = link->next;
+        }
+      }  // for all outgoing links
+      if (fabs(tok_extra_cost - tok->extra_cost) > delta)
+        changed = true;   // difference new minus old is bigger than delta
+      tok->extra_cost = tok_extra_cost;
+      // will be +infinity or <= lattice_beam_.
+      // infinity indicates, that no forward link survived pruning
+    }  // for all Token on active_toks_[frame]
+    if (changed) *extra_costs_changed = true;
+
+    // Note: it's theoretically possible that aggressive compiler
+    // optimizations could cause an infinite loop here for small delta and
+    // high-dynamic-range scores.
+  }  // while changed
+}
+
+// PruneForwardLinksFinal is a version of PruneForwardLinks that we call
+// on the final frame.  If there are final tokens active, it uses
+// the final-probs for pruning, otherwise it treats all tokens as final.
+template <typename FST, typename Token>
+void LatticeFasterDecoderCombineTpl<FST, Token>::PruneForwardLinksFinal() {
+  KALDI_ASSERT(!active_toks_.empty());
+  int32 frame_plus_one = active_toks_.size() - 1;
+
+  if (active_toks_[frame_plus_one].toks == NULL)  // empty list; should not happen.
+    KALDI_WARN << "No tokens alive at end of file";
+
+  typedef typename unordered_map<Token*, BaseFloat>::const_iterator IterType;
+  ComputeFinalCosts(&final_costs_, &final_relative_cost_, &final_best_cost_);
+  decoding_finalized_ = true;
+
+  // Now go through tokens on this frame, pruning forward links...  may have to
+  // iterate a few times until there is no more change, because the list is not
+  // in topological order.  This is a modified version of the code in
+  // PruneForwardLinks, but here we also take account of the final-probs.
+  bool changed = true;
+  BaseFloat delta = 1.0e-05;
+  while (changed) {
+    changed = false;
+    for (Token *tok = active_toks_[frame_plus_one].toks;
+         tok != NULL; tok = tok->next) {
+      ForwardLinkT *link, *prev_link = NULL;
+      // will recompute tok_extra_cost.  It has a term in it that corresponds
+      // to the "final-prob", so instead of initializing tok_extra_cost to
+      // infinity below we set it to the difference between the
+      // (score+final_prob) of this token, and the best such (score+final_prob).
+      BaseFloat final_cost;
+      if (final_costs_.empty()) {
+        final_cost = 0.0;
+      } else {
+        IterType iter = final_costs_.find(tok);
+        if (iter != final_costs_.end())
+          final_cost = iter->second;
+        else
+          final_cost = std::numeric_limits<BaseFloat>::infinity();
+      }
+      BaseFloat tok_extra_cost = tok->tot_cost + final_cost - final_best_cost_;
+      // tok_extra_cost will be a "min" over either directly being final, or
+      // being indirectly final through other links, and the loop below may
+      // decrease its value:
+      for (link = tok->links; link != NULL; ) {
+        // See if we need to excise this link...
+        Token *next_tok = link->next_tok;
+        BaseFloat link_extra_cost = next_tok->extra_cost +
+            ((tok->tot_cost + link->acoustic_cost + link->graph_cost)
+             - next_tok->tot_cost);
+        if (link_extra_cost > config_.lattice_beam) {  // excise link
+          ForwardLinkT *next_link = link->next;
+          if (prev_link != NULL) prev_link->next = next_link;
+          else tok->links = next_link;
+          delete link;
+          link = next_link;  // advance link but leave prev_link the same.
+        } else {  // keep the link and update the tok_extra_cost if needed.
+          if (link_extra_cost < 0.0) {  // this is just a precaution.
+            if (link_extra_cost < -0.01)
+              KALDI_WARN << "Negative extra_cost: " << link_extra_cost;
+            link_extra_cost = 0.0;
+          }
+          if (link_extra_cost < tok_extra_cost)
+            tok_extra_cost = link_extra_cost;
+          prev_link = link;
+          link = link->next;
+        }
+      }
+      // prune away tokens worse than lattice_beam above best path.  This step
+      // was not necessary in the non-final case because then, this case
+      // showed up as having no forward links.  Here, the tok_extra_cost has
+      // an extra component relating to the final-prob.
+      if (tok_extra_cost > config_.lattice_beam)
+        tok_extra_cost = std::numeric_limits<BaseFloat>::infinity();
+      // to be pruned in PruneTokensForFrame
+
+      if (!ApproxEqual(tok->extra_cost, tok_extra_cost, delta))
+        changed = true;
+      tok->extra_cost = tok_extra_cost; // will be +infinity or <= lattice_beam_.
+    }
+  }  // while changed
+}
+
+template <typename FST, typename Token>
+BaseFloat LatticeFasterDecoderCombineTpl<FST, Token>::FinalRelativeCost() const {
+  if (!decoding_finalized_) {
+    BaseFloat relative_cost;
+    ComputeFinalCosts(NULL, &relative_cost, NULL);
+    return relative_cost;
+  } else {
+    // we're not allowed to call that function if FinalizeDecoding() has
+    // been called; return a cached value.
+    return final_relative_cost_;
+  }
+}
+
+
+// Prune away any tokens on this frame that have no forward links.
+// [we don't do this in PruneForwardLinks because it would give us
+// a problem with dangling pointers].
+// It's called by PruneActiveTokens if any forward links have been pruned
+template <typename FST, typename Token>
+void LatticeFasterDecoderCombineTpl<FST, Token>::PruneTokensForFrame(
+    int32 frame_plus_one) {
+  KALDI_ASSERT(frame_plus_one >= 0 && frame_plus_one < active_toks_.size());
+  Token *&toks = active_toks_[frame_plus_one].toks;
+  if (toks == NULL)
+    KALDI_WARN << "No tokens alive [doing pruning]";
+  Token *tok, *next_tok, *prev_tok = NULL;
+  for (tok = toks; tok != NULL; tok = next_tok) {
+    next_tok = tok->next;
+    if (tok->extra_cost == std::numeric_limits<BaseFloat>::infinity()) {
+      // token is unreachable from end of graph; (no forward links survived)
+      // excise tok from list and delete tok.
+      if (prev_tok != NULL) prev_tok->next = tok->next;
+      else toks = tok->next;
+      delete tok;
+      num_toks_--;
+    } else {  // fetch next Token
+      prev_tok = tok;
+    }
+  }
+}
+
+// Go backwards through still-alive tokens, pruning them, starting not from
+// the current frame (where we want to keep all tokens) but from the frame before
+// that.  We go backwards through the frames and stop when we reach a point
+// where the delta-costs are not changing (and the delta controls when we consider
+// a cost to have "not changed").
+template <typename FST, typename Token>
+void LatticeFasterDecoderCombineTpl<FST, Token>::PruneActiveTokens(
+    BaseFloat delta) {
+  int32 cur_frame_plus_one = NumFramesDecoded();
+  int32 num_toks_begin = num_toks_;
+  // The index "f" below represents a "frame plus one", i.e. you'd have to subtract
+  // one to get the corresponding index for the decodable object.
+  for (int32 f = cur_frame_plus_one - 1; f >= 0; f--) {
+    // Reason why we need to prune forward links in this situation:
+    // (1) we have never pruned them (new TokenList)
+    // (2) we have not yet pruned the forward links to the next f,
+    // after any of those tokens have changed their extra_cost.
+    if (active_toks_[f].must_prune_forward_links) {
+      bool extra_costs_changed = false, links_pruned = false;
+      PruneForwardLinks(f, &extra_costs_changed, &links_pruned, delta);
+      if (extra_costs_changed && f > 0)  // any token has changed extra_cost
+        active_toks_[f-1].must_prune_forward_links = true;
+      if (links_pruned)  // any link was pruned
+        active_toks_[f].must_prune_tokens = true;
+      active_toks_[f].must_prune_forward_links = false;  // job done
+    }
+    if (f+1 < cur_frame_plus_one &&  // except for last f (no forward links)
+        active_toks_[f+1].must_prune_tokens) {
+      PruneTokensForFrame(f+1);
+      active_toks_[f+1].must_prune_tokens = false;
+    }
+  }
+  KALDI_VLOG(4) << "PruneActiveTokens: pruned tokens from " << num_toks_begin
+                << " to " << num_toks_;
+}
+
+template <typename FST, typename Token>
+void LatticeFasterDecoderCombineTpl<FST, Token>::ComputeFinalCosts(
+    unordered_map<Token*, BaseFloat> *final_costs,
+    BaseFloat *final_relative_cost,
+    BaseFloat *final_best_cost) const {
+  KALDI_ASSERT(!decoding_finalized_);
+  if (final_costs != NULL)
+    final_costs->clear();
+  BaseFloat infinity = std::numeric_limits<BaseFloat>::infinity();
+  BaseFloat best_cost = infinity,
+      best_cost_with_final = infinity;
+
+  // The final tokens are recorded in active_toks_[last_frame]
+  for (Token *tok = active_toks_[active_toks_.size() - 1].toks; tok != NULL;
+       tok = tok->next) {
+    StateId state = tok->state_id;
+    BaseFloat final_cost = fst_->Final(state).Value();
+    BaseFloat cost = tok->tot_cost,
+        cost_with_final = cost + final_cost;
+    best_cost = std::min(cost, best_cost);
+    best_cost_with_final = std::min(cost_with_final, best_cost_with_final);
+    if (final_costs != NULL && final_cost != infinity)
+      (*final_costs)[tok] = final_cost;
+  }
+  if (final_relative_cost != NULL) {
+    if (best_cost == infinity && best_cost_with_final == infinity) {
+      // Likely this will only happen if there are no tokens surviving.
+      // This seems the least bad way to handle it.
+      *final_relative_cost = infinity;
+    } else {
+      *final_relative_cost = best_cost_with_final - best_cost;
+    }
+  }
+  if (final_best_cost != NULL) {
+    if (best_cost_with_final != infinity) { // final-state exists.
+      *final_best_cost = best_cost_with_final;
+    } else { // no final-state exists.
+      *final_best_cost = best_cost;
+    }
+  }
+}
+
+template <typename FST, typename Token>
+void LatticeFasterDecoderCombineTpl<FST, Token>::AdvanceDecoding(
+    DecodableInterface *decodable,
+    int32 max_num_frames) {
+  if (std::is_same<FST, fst::Fst<fst::StdArc> >::value) {
+    // if the type 'FST' is the FST base-class, then see if the FST type of fst_
+    // is actually VectorFst or ConstFst.  If so, call the AdvanceDecoding()
+    // function after casting *this to the more specific type.
+    if (fst_->Type() == "const") {
+      LatticeFasterDecoderCombineTpl<fst::ConstFst<fst::StdArc>, Token> *this_cast =
+          reinterpret_cast<LatticeFasterDecoderCombineTpl<fst::ConstFst<fst::StdArc>, Token>* >(this);
+      this_cast->AdvanceDecoding(decodable, max_num_frames);
+      return;
+    } else if (fst_->Type() == "vector") {
+      LatticeFasterDecoderCombineTpl<fst::VectorFst<fst::StdArc>, Token> *this_cast =
+          reinterpret_cast<LatticeFasterDecoderCombineTpl<fst::VectorFst<fst::StdArc>, Token>* >(this);
+      this_cast->AdvanceDecoding(decodable, max_num_frames);
+      return;
+    }
+  }
+
+
+  KALDI_ASSERT(!active_toks_.empty() && !decoding_finalized_ &&
+               "You must call InitDecoding() before AdvanceDecoding");
+  int32 num_frames_ready = decodable->NumFramesReady();
+  // num_frames_ready must be >= num_frames_decoded, or else
+  // the number of frames ready must have decreased (which doesn't
+  // make sense) or the decodable object changed between calls
+  // (which isn't allowed).
+  KALDI_ASSERT(num_frames_ready >= NumFramesDecoded());
+  int32 target_frames_decoded = num_frames_ready;
+  if (max_num_frames >= 0)
+    target_frames_decoded = std::min(target_frames_decoded,
+                                     NumFramesDecoded() + max_num_frames);
+  while (NumFramesDecoded() < target_frames_decoded) {
+    if (NumFramesDecoded() % config_.prune_interval == 0) {
+      PruneActiveTokens(config_.lattice_beam * config_.prune_scale);
+    }
+    ProcessForFrame(decodable);
+  }
+}
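
The std::is_same block at the top of AdvanceDecoding is a devirtualization trick: when the decoder was instantiated on the fst::Fst<fst::StdArc> base class, it re-dispatches once to the same method templated on the concrete ConstFst or VectorFst type, so the per-frame loops use non-virtual arc iteration. Below is a minimal sketch of the pattern with invented types; like the original, it relies on the template's layout being independent of its parameter, which is why the reinterpret_cast works in practice.

    #include <cstdio>
    #include <string>
    #include <type_traits>

    struct Base { virtual std::string Type() const = 0; virtual ~Base() {} };
    struct ConstImpl : Base { std::string Type() const { return "const"; } };

    template <typename T>
    struct Walker {
      explicit Walker(const T *t) : t_(t) {}
      void Run() {
        if (std::is_same<T, Base>::value && t_->Type() == "const") {
          // Re-dispatch so subsequent calls are compiled against the concrete
          // type (no virtual calls in the hot loop).  Works only because
          // Walker<T> has an identical layout for every T.
          reinterpret_cast<Walker<ConstImpl>*>(this)->Run();
          return;
        }
        std::printf("running with static type\n");  // hot loop would go here
      }
      const T *t_;
    };

    int main() {
      ConstImpl fst;
      Walker<Base> w(&fst);  // built against the base type...
      w.Run();               // ...but runs the ConstImpl-specialized path
      return 0;
    }
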
+
+// FinalizeDecoding() is a version of PruneActiveTokens that we call
+// (optionally) on the final frame.  Takes into account the final-prob of
+// tokens.  This function used to be called PruneActiveTokensFinal().
+template <typename FST, typename Token>
+void LatticeFasterDecoderCombineTpl<FST, Token>::FinalizeDecoding() {
+  ProcessNonemitting();
+  int32 final_frame_plus_one = NumFramesDecoded();
+  int32 num_toks_begin = num_toks_;
+  // PruneForwardLinksFinal() prunes final frame (with final-probs), and
+  // sets decoding_finalized_.
+  PruneForwardLinksFinal();
+  for (int32 f = final_frame_plus_one - 1; f >= 0; f--) {
+    bool b1, b2; // values not used.
+    BaseFloat dontcare = 0.0; // delta of zero means we must always update
+    PruneForwardLinks(f, &b1, &b2, dontcare);
+    PruneTokensForFrame(f + 1);
+  }
+  PruneTokensForFrame(0);
+  KALDI_VLOG(4) << "pruned tokens from " << num_toks_begin
+                << " to " << num_toks_;
+}
+
+
+template <typename FST, typename Token>
+void LatticeFasterDecoderCombineTpl<FST, Token>::ProcessForFrame(
+    DecodableInterface *decodable) {
+  KALDI_ASSERT(active_toks_.size() > 0);
+  int32 cur_frame = active_toks_.size() - 1, // frame is the frame-index (zero-
+                                             // based) used to get likelihoods
+                                             // from the decodable object.
+      next_frame = cur_frame + 1;
+
+  active_toks_.resize(active_toks_.size() + 1);
+
+  cur_toks_.swap(next_toks_);
+  next_toks_.clear();
+  if (cur_toks_.empty()) {
+    if (!warned_) {
+      KALDI_WARN << "Error, no surviving tokens on frame " << cur_frame;
+      warned_ = true;
+    }
+  }
+
+  cur_queue_.Clear();
+  // Add tokens to queue
+  for (Token* tok = active_toks_[cur_frame].toks; tok != NULL; tok = tok->next)
+    cur_queue_.Push(tok);
+
+  // Declare a local variable so the compiler can put it in a register, since
+  // C++ assumes other threads could be modifying class members.
+  BaseFloat adaptive_beam = adaptive_beam_;
+  // "cur_cutoff" will be kept to the best-seen-so-far token on this frame
+  // + adaptive_beam
+  BaseFloat cur_cutoff = std::numeric_limits<BaseFloat>::infinity();
+  // "next_cutoff" is used to decide whether a new token on the next frame
+  // should be handled or not.  It is kept updated to the best-seen-so-far
+  // token "on next frame" + adaptive_beam as processing proceeds.
+  BaseFloat next_cutoff = std::numeric_limits<BaseFloat>::infinity();
+  // "cost_offset" contains the acoustic log-likelihoods on current frame in
+  // order to keep everything in a nice dynamic range.  Reduces roundoff
+  // errors.
+  BaseFloat cost_offset = cost_offsets_[cur_frame];
+
+  // Iterate over "cur_queue_" to process the non-emitting and emitting arcs
+  // of the FST in a single pass.
+  Token *tok = NULL;
+  int32 num_toks_processed = 0;
+  int32 max_active = config_.max_active;
+  for (; num_toks_processed < max_active && (tok = cur_queue_.Pop()) != NULL;
+       num_toks_processed++) {
+    BaseFloat cur_cost = tok->tot_cost;
+    StateId state = tok->state_id;
+    if (cur_cost > cur_cutoff &&
+        num_toks_processed > config_.min_active) { // Don't bother processing
+                                                   // successors.
+      break;  // This is a priority queue; the following tokens will be worse.
+    } else if (cur_cost + adaptive_beam < cur_cutoff) {
+      cur_cutoff = cur_cost + adaptive_beam; // a tighter boundary
+    }
+    // If "tok" has any existing forward links, delete them,
+    // because we're about to regenerate them.  This is a kind
+    // of non-optimality (remember, this is the simple decoder),
+    DeleteForwardLinks(tok);  // necessary when re-visiting
+    for (fst::ArcIterator<FST> aiter(*fst_, state);
+         !aiter.Done();
+         aiter.Next()) {
+      const Arc &arc = aiter.Value();
+      bool changed;
+      if (arc.ilabel == 0) {  // propagate nonemitting
+        BaseFloat graph_cost = arc.weight.Value();
+        BaseFloat tot_cost = cur_cost + graph_cost;
+        if (tot_cost < cur_cutoff) {
+          Token *new_tok = FindOrAddToken(arc.nextstate, cur_frame, tot_cost,
+                                          tok, &cur_toks_, &changed);
+
+          // Add ForwardLink from tok to new_tok.  Put it on the head of
+          // tok->link list
+          tok->links = new ForwardLinkT(new_tok, 0, arc.olabel,
+                                        graph_cost, 0, tok->links);
+
+          // "changed" tells us whether the new token has a different
+          // cost from before, or is new.
+          if (changed) {
+            cur_queue_.Push(new_tok);
+          }
+        }
+      } else {  // propagate emitting
+        BaseFloat graph_cost = arc.weight.Value(),
+            ac_cost = cost_offset - decodable->LogLikelihood(cur_frame,
+                                                             arc.ilabel),
+            tot_cost = cur_cost + ac_cost + graph_cost;
+        if (tot_cost > next_cutoff) continue;
+        else if (tot_cost + adaptive_beam < next_cutoff) {
+          next_cutoff = tot_cost + adaptive_beam;  // a tighter boundary for
+                                                   // emitting
+        }
+
+        // no change flag is needed
+        Token *next_tok = FindOrAddToken(arc.nextstate, next_frame, tot_cost,
+                                         tok, &next_toks_, NULL);
+        // Add ForwardLink from tok to next_tok.  Put it on the head of
+        // tok->link list
+        tok->links = new ForwardLinkT(next_tok, arc.ilabel, arc.olabel,
+                                      graph_cost, ac_cost, tok->links);
+      }
+    }  // for all arcs
+  }  // end of for loop over the queue
+
+  // Store the offset on the acoustic likelihoods that we're applying.
+  // Could just do cost_offsets_.push_back(cost_offset), but we
+  // do it this way as it's more robust to future code changes.
+  // Set the cost_offset_ for next frame; it equals "- best_cost_on_next_frame".
+  cost_offsets_.resize(cur_frame + 2, 0.0);
+  cost_offsets_[next_frame] = adaptive_beam - next_cutoff;
+
+  { // This block updates adaptive_beam_
+    BaseFloat beam_used_this_frame = adaptive_beam;
+    Token *tok = cur_queue_.Pop();
+    if (tok != NULL) {
+      // We hit the max-active constraint, meaning we effectively pruned to a
+      // beam tighter than 'beam'.  Work out what this was; it will be used to
+      // update 'adaptive_beam'.
+      BaseFloat best_cost_this_frame = cur_cutoff - adaptive_beam;
+      beam_used_this_frame = tok->tot_cost - best_cost_this_frame;
+    }
+    if (num_toks_processed <= config_.min_active) {
+      // num-toks active is dangerously low, increase the beam even if it
+      // already exceeds the user-specified beam.
+      adaptive_beam_ = std::max<BaseFloat>(
+          config_.beam, beam_used_this_frame + 2.0 * config_.beam_delta);
+    } else {
+      // have adaptive_beam_ approach beam_ in intervals of config_.beam_delta
+      BaseFloat diff_from_beam = beam_used_this_frame - config_.beam;
+      if (std::abs(diff_from_beam) < config_.beam_delta) {
+        adaptive_beam_ = config_.beam;
+      } else {
+        // make it close to beam_
+        adaptive_beam_ = beam_used_this_frame -
+            config_.beam_delta * (diff_from_beam > 0 ? 1 : -1);
+      }
+    }
+  }
+}
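
Two pieces of bookkeeping in ProcessForFrame are worth making concrete (an editorial sketch; the numbers are invented). Setting cost_offsets_[next_frame] = adaptive_beam - next_cutoff makes the best token entering the next frame land near cost zero, keeping tot_cost in a small dynamic range; and when the max-active break fires, the beam that was effectively used is recovered from the first token left unpopped in the queue:

    #include <cstdio>

    int main() {
      const float adaptive_beam = 15.0f;

      // --- cost offsets ---------------------------------------------------
      // Suppose the best token reaching the next frame had cost 1234.5, so
      // next_cutoff ended up at best + adaptive_beam.
      float best_next = 1234.5f;
      float next_cutoff = best_next + adaptive_beam;
      float cost_offset = adaptive_beam - next_cutoff;  // == -best_next
      std::printf("offset-corrected best cost: %.1f\n",
                  best_next + cost_offset);             // 0.0

      // --- effective beam under max-active --------------------------------
      // cur_cutoff tracked best_this_frame + adaptive_beam; if the queue
      // still holds a token costing 107.0 when we stop, that token defines
      // how wide the pruning really was.
      float cur_cutoff = 100.0f + adaptive_beam;        // best cost was 100.0
      float first_unpopped = 107.0f;
      float best_this_frame = cur_cutoff - adaptive_beam;
      float beam_used = first_unpopped - best_this_frame;  // 7.0
      std::printf("beam actually used: %.1f of %.1f\n", beam_used, adaptive_beam);
      return 0;
    }
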
+
+
+template <typename FST, typename Token>
+void LatticeFasterDecoderCombineTpl<FST, Token>::ProcessNonemitting() {
+  int32 cur_frame = active_toks_.size() - 1;
+  StateIdToTokenMap &cur_toks = next_toks_;
+
+  cur_queue_.Clear();
+  for (Token* tok = active_toks_[cur_frame].toks; tok != NULL; tok = tok->next)
+    cur_queue_.Push(tok);
+
+  // Declare a local variable so the compiler can put it in a register, since
+  // C++ assumes other threads could be modifying class members.
+  BaseFloat adaptive_beam = adaptive_beam_;
+  // "cur_cutoff" will be kept to the best-seen-so-far token on this frame
+  // + adaptive_beam
+  BaseFloat cur_cutoff = std::numeric_limits<BaseFloat>::infinity();
+
+  Token *tok = NULL;
+  int32 num_toks_processed = 0;
+  int32 max_active = config_.max_active;
+
+  for (; num_toks_processed < max_active && (tok = cur_queue_.Pop()) != NULL;
+       num_toks_processed++) {
+    BaseFloat cur_cost = tok->tot_cost;
+    StateId state = tok->state_id;
+    if (cur_cost > cur_cutoff &&
+        num_toks_processed > config_.min_active) {  // Don't bother processing
+                                                    // successors.
+      break;  // This is a priority queue.  The following tokens will be worse
+    } else if (cur_cost + adaptive_beam < cur_cutoff) {
+      cur_cutoff = cur_cost + adaptive_beam;  // a tighter boundary
+    }
+    // If "tok" has any existing forward links, delete them,
+    // because we're about to regenerate them.  This is a kind
+    // of non-optimality (remember, this is the simple decoder),
+    DeleteForwardLinks(tok);  // necessary when re-visiting
+    for (fst::ArcIterator<FST> aiter(*fst_, state);
+         !aiter.Done();
+         aiter.Next()) {
+      const Arc &arc = aiter.Value();
+      bool changed;
+      if (arc.ilabel == 0) {  // propagate nonemitting
+        BaseFloat graph_cost = arc.weight.Value();
+        BaseFloat tot_cost = cur_cost + graph_cost;
+        if (tot_cost < cur_cutoff) {
+          Token *new_tok = FindOrAddToken(arc.nextstate, cur_frame, tot_cost,
+                                          tok, &cur_toks, &changed);
+
+          // Add ForwardLink from tok to new_tok.  Put it on the head of the
+          // tok->links list.
+          tok->links = new ForwardLinkT(new_tok, 0, arc.olabel,
+                                        graph_cost, 0, tok->links);
+
+          // "changed" tells us whether the new token has a different
+          // cost from before, or is new.
+          if (changed) {
+            cur_queue_.Push(new_tok);
+          }
+        }
+      }
+    }  // end of for loop
+  }  // end of while loop
+  if (!decoding_finalized_) {
+    // Update cost_offsets_; it equals "- best_cost".
+    cost_offsets_[cur_frame] = adaptive_beam - cur_cutoff;
+    // No need to update adaptive_beam_, since we will still process this
+    // frame in ProcessForFrame.
+  }
+}
+
+
+
+// static inline
+template <typename FST, typename Token>
+void LatticeFasterDecoderCombineTpl<FST, Token>::DeleteForwardLinks(Token *tok) {
+  ForwardLinkT *l = tok->links, *m;
+  while (l != NULL) {
+    m = l->next;
+    delete l;
+    l = m;
+  }
+  tok->links = NULL;
+}
+
+
+template <typename FST, typename Token>
+void LatticeFasterDecoderCombineTpl<FST, Token>::ClearActiveTokens() {
+  // a cleanup routine, at utt end/begin
+  for (size_t i = 0; i < active_toks_.size(); i++) {
+    // Delete all tokens alive on this frame, and any forward
+    // links they may have.
+    for (Token *tok = active_toks_[i].toks; tok != NULL; ) {
+      DeleteForwardLinks(tok);
+      Token *next_tok = tok->next;
+      delete tok;
+      num_toks_--;
+      tok = next_tok;
+    }
+  }
+  active_toks_.clear();
+  KALDI_ASSERT(num_toks_ == 0);
+}
+
+// static
+template <typename FST, typename Token>
+void LatticeFasterDecoderCombineTpl<FST, Token>::TopSortTokens(
+    Token *tok_list, std::vector<Token*> *topsorted_list) {
+  unordered_map<Token*, int32> token2pos;
+  typedef typename unordered_map<Token*, int32>::iterator IterType;
+  int32 num_toks = 0;
+  for (Token *tok = tok_list; tok != NULL; tok = tok->next)
+    num_toks++;
+  int32 cur_pos = 0;
+  // We assign the tokens numbers num_toks - 1, ... , 2, 1, 0.
+  // This is likely to be closer to topological order than
+  // if we had given them ascending order, because of the way
+  // new tokens are put at the front of the list.
+  for (Token *tok = tok_list; tok != NULL; tok = tok->next)
+    token2pos[tok] = num_toks - ++cur_pos;
+
+  unordered_set<Token*> reprocess;
+
+  for (IterType iter = token2pos.begin(); iter != token2pos.end(); ++iter) {
+    Token *tok = iter->first;
+    int32 pos = iter->second;
+    for (ForwardLinkT *link = tok->links; link != NULL; link = link->next) {
+      if (link->ilabel == 0) {
+        // We only need to consider epsilon links, since non-epsilon links
+        // transition between frames and this function only needs to sort a
+        // list of tokens from a single frame.
+        IterType following_iter = token2pos.find(link->next_tok);
+        if (following_iter != token2pos.end()) {  // another token on this
+                                                  // frame, so must consider it.
+          int32 next_pos = following_iter->second;
+          if (next_pos < pos) {  // reassign the position of the next Token.
+            following_iter->second = cur_pos++;
+            reprocess.insert(link->next_tok);
+          }
+        }
+      }
+    }
+    // In case we had previously assigned this token to be reprocessed, we can
+    // erase it from that set because it's "happy now" (we just processed it).
+    reprocess.erase(tok);
+  }
+
+  size_t max_loop = 1000000, loop_count;  // max_loop is to detect epsilon cycles.
+  for (loop_count = 0;
+       !reprocess.empty() && loop_count < max_loop; ++loop_count) {
+    std::vector<Token*> reprocess_vec;
+    for (typename unordered_set<Token*>::iterator iter = reprocess.begin();
+         iter != reprocess.end(); ++iter)
+      reprocess_vec.push_back(*iter);
+    reprocess.clear();
+    for (typename std::vector<Token*>::iterator iter = reprocess_vec.begin();
+         iter != reprocess_vec.end(); ++iter) {
+      Token *tok = *iter;
+      int32 pos = token2pos[tok];
+      // Repeat the processing we did above (for comments, see above).
+      for (ForwardLinkT *link = tok->links; link != NULL; link = link->next) {
+        if (link->ilabel == 0) {
+          IterType following_iter = token2pos.find(link->next_tok);
+          if (following_iter != token2pos.end()) {
+            int32 next_pos = following_iter->second;
+            if (next_pos < pos) {
+              following_iter->second = cur_pos++;
+              reprocess.insert(link->next_tok);
+            }
+          }
+        }
+      }
+    }
+  }
+  KALDI_ASSERT(loop_count < max_loop && "Epsilon loops exist in your decoding "
+               "graph (this is not allowed!)");
+
+  topsorted_list->clear();
+  topsorted_list->resize(cur_pos, NULL);  // create a list with NULLs in between.
+  for (IterType iter = token2pos.begin(); iter != token2pos.end(); ++iter)
+    (*topsorted_list)[iter->second] = iter->first;
+}
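TopSortTokens' renumbering scheme is easier to follow on a toy case. The sketch below replays the same reprocess-set idea on plain integers with two invented epsilon arcs; none of these names appear in the decoder:

// Toy replay of the renumbering: tokens 2, 1, 0 (newest first) get initial
// positions 2, 1, 0; epsilon arcs 2->0 and 0->1 point "backwards", so their
// destinations are renumbered past the current maximum and reprocessed.
#include <cassert>
#include <map>
#include <set>

int main() {
  std::map<int, int> pos = {{2, 2}, {1, 1}, {0, 0}};  // token -> position
  std::multimap<int, int> eps = {{2, 0}, {0, 1}};     // epsilon arcs src->dst
  int cur_pos = 3;
  std::set<int> reprocess;
  for (auto &p : pos) {                     // first pass over all tokens
    auto range = eps.equal_range(p.first);
    for (auto it = range.first; it != range.second; ++it)
      if (pos[it->second] < p.second) {     // arc goes backwards: renumber
        pos[it->second] = cur_pos++;
        reprocess.insert(it->second);
      }
    reprocess.erase(p.first);               // this token is "happy now"
  }
  while (!reprocess.empty()) {              // keep fixing until stable
    std::set<int> batch;
    batch.swap(reprocess);
    for (int t : batch) {
      auto range = eps.equal_range(t);
      for (auto it = range.first; it != range.second; ++it)
        if (pos[it->second] < pos[t]) {
          pos[it->second] = cur_pos++;
          reprocess.insert(it->second);
        }
    }
  }
  // Every epsilon arc now goes forward: the order is 2, then 0, then 1.
  assert(pos[2] < pos[0] && pos[0] < pos[1]);
  return 0;
}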
+
+// Instantiate the template for the combination of token types and FST types
+// that we'll need.
+template class LatticeFasterDecoderCombineTpl<fst::Fst<fst::StdArc>,
+    decodercombine::StdToken<fst::Fst<fst::StdArc> > >;
+template class LatticeFasterDecoderCombineTpl<fst::VectorFst<fst::StdArc>,
+    decodercombine::StdToken<fst::VectorFst<fst::StdArc> > >;
+template class LatticeFasterDecoderCombineTpl<fst::ConstFst<fst::StdArc>,
+    decodercombine::StdToken<fst::ConstFst<fst::StdArc> > >;
+template class LatticeFasterDecoderCombineTpl<fst::GrammarFst,
+    decodercombine::StdToken<fst::GrammarFst> >;
+
+template class LatticeFasterDecoderCombineTpl<fst::Fst<fst::StdArc>,
+    decodercombine::BackpointerToken<fst::Fst<fst::StdArc> > >;
+template class LatticeFasterDecoderCombineTpl<fst::VectorFst<fst::StdArc>,
+    decodercombine::BackpointerToken<fst::VectorFst<fst::StdArc> > >;
+template class LatticeFasterDecoderCombineTpl<fst::ConstFst<fst::StdArc>,
+    decodercombine::BackpointerToken<fst::ConstFst<fst::StdArc> > >;
+template class LatticeFasterDecoderCombineTpl<fst::GrammarFst,
+    decodercombine::BackpointerToken<fst::GrammarFst> >;
+
+
+} // end namespace kaldi.
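With the templates instantiated, the decoder is driven like the other lattice decoders in Kaldi. A minimal sketch of the expected call sequence, assuming a ready-made DecodableInterface (the beam values are borrowed from decode_combine_test.sh above; error handling omitted):

// Sketch only; DecodeUtterance is an illustrative name.
#include "decoder/lattice-faster-decoder-combine.h"

void DecodeUtterance(const fst::StdFst &hclg,
                     kaldi::DecodableInterface *decodable,
                     kaldi::Lattice *raw_lat) {
  kaldi::LatticeFasterDecoderCombineConfig config;
  config.beam = 15.0;
  config.lattice_beam = 8.0;
  kaldi::LatticeFasterDecoderCombine decoder(hclg, config);
  decoder.InitDecoding();
  decoder.AdvanceDecoding(decodable);  // decodes all frames that are ready
  decoder.FinalizeDecoding();          // extra, final-prob-aware pruning
  decoder.GetRawLattice(raw_lat, true);
  // Determinization is done by the caller these days, e.g. with
  // DeterminizeLatticePhonePrunedWrapper(); see GetLattice()'s comment.
}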
diff --git a/src/decoder/lattice-faster-decoder-combine-bucketqueue.h b/src/decoder/lattice-faster-decoder-combine-bucketqueue.h
new file mode 100644
index 00000000000..094e9765d73
--- /dev/null
+++ b/src/decoder/lattice-faster-decoder-combine-bucketqueue.h
@@ -0,0 +1,624 @@
+// decoder/lattice-faster-decoder-combine.h
+
+// Copyright 2009-2013  Microsoft Corporation;  Mirko Hannemann;
+//           2013-2019  Johns Hopkins University (Author: Daniel Povey)
+//                2014  Guoguo Chen
+//                2018  Zhehuai Chen
+//                2019  Hang Lyu
+
+// See ../../COPYING for clarification regarding multiple authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//  http://www.apache.org/licenses/LICENSE-2.0
+//
+// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
+// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
+// MERCHANTABLITY OR NON-INFRINGEMENT.
+// See the Apache 2 License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef KALDI_DECODER_LATTICE_FASTER_DECODER_COMBINE_H_
+#define KALDI_DECODER_LATTICE_FASTER_DECODER_COMBINE_H_
+
+
+#include "util/stl-utils.h"
+#include "fst/fstlib.h"
+#include "itf/decodable-itf.h"
+#include "fstext/fstext-lib.h"
+#include "lat/determinize-lattice-pruned.h"
+#include "lat/kaldi-lattice.h"
+#include "decoder/grammar-fst.h"
+#include "decoder/lattice-faster-decoder.h"
+
+namespace kaldi {
+
+struct LatticeFasterDecoderCombineConfig {
+  BaseFloat beam;
+  int32 max_active;
+  int32 min_active;
+  BaseFloat lattice_beam;
+  int32 prune_interval;
+  bool determinize_lattice;  // not inspected by this class... used in
+                             // command-line program.
+  BaseFloat beam_delta;  // has nothing to do with beam_ratio
+  BaseFloat hash_ratio;
+  BaseFloat cost_scale;
+  BaseFloat prune_scale;  // Note: we don't make this configurable on the
+                          // command line, it's not a very important parameter.
+                          // It affects the algorithm that prunes the tokens
+                          // as we go.
+  // Most of the options inside det_opts are not actually queried by the
+  // LatticeFasterDecoder class itself, but by the code that calls it, for
+  // example in the function DecodeUtteranceLatticeFaster.
+  fst::DeterminizeLatticePhonePrunedOptions det_opts;
+
+  LatticeFasterDecoderCombineConfig(): beam(16.0),
+                                       max_active(std::numeric_limits<int32>::max()),
+                                       min_active(200),
+                                       lattice_beam(10.0),
+                                       prune_interval(25),
+                                       determinize_lattice(true),
+                                       beam_delta(0.5),
+                                       hash_ratio(2.0),
+                                       cost_scale(1.0),
+                                       prune_scale(0.1) { }
+  void Register(OptionsItf *opts) {
+    det_opts.Register(opts);
+    opts->Register("beam", &beam, "Decoding beam.  Larger->slower, more accurate.");
+    opts->Register("max-active", &max_active, "Decoder max active states.  "
+                   "Larger->slower; more accurate");
+    opts->Register("min-active", &min_active, "Decoder minimum #active states.");
+    opts->Register("lattice-beam", &lattice_beam, "Lattice generation beam.  "
+                   "Larger->slower, and deeper lattices");
+    opts->Register("prune-interval", &prune_interval, "Interval (in frames) at "
+                   "which to prune tokens");
+    opts->Register("determinize-lattice", &determinize_lattice, "If true, "
+                   "determinize the lattice (lattice-determinization, keeping only "
+                   "best pdf-sequence for each word-sequence).");
+    opts->Register("beam-delta", &beam_delta, "Increment used in decoding-- this "
+                   "parameter is obscure and relates to a speedup in the way the "
+                   "max-active constraint is applied.  Larger is more accurate.");
+    opts->Register("hash-ratio", &hash_ratio, "Setting used in decoder to "
+                   "control hash behavior");
+    opts->Register("cost-scale", &cost_scale, "A scale that we multiply the "
+                   "token costs by before integerizing; a larger value means "
+                   "more buckets and a more precise search.");
+
+  }
+  void Check() const {
+    KALDI_ASSERT(beam > 0.0 && max_active > 1 && lattice_beam > 0.0
+                 && min_active <= max_active
+                 && prune_interval > 0 && beam_delta > 0.0 && hash_ratio >= 1.0
+                 && prune_scale > 0.0 && prune_scale < 1.0);
+  }
+};
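The Register()/Check() pair plugs into Kaldi's standard option machinery. A minimal sketch of how a binary would wire this config up, assuming the usual ParseOptions flow:

#include "util/parse-options.h"
#include "decoder/lattice-faster-decoder-combine.h"

int main(int argc, char *argv[]) {
  kaldi::ParseOptions po("Usage: see latgen-faster-mapped-combine [options]");
  kaldi::LatticeFasterDecoderCombineConfig config;
  config.Register(&po);  // exposes --beam, --max-active, --lattice-beam, ...
  po.Read(argc, argv);
  config.Check();        // asserts the parsed values are mutually consistent
  return 0;
}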
+
+
+namespace decodercombine {
+// We will template the decoder on the token type as well as the FST type; this
+// is a mechanism so that we can use the same underlying decoder code for
+// versions of the decoder that support quickly getting the best path
+// (LatticeFasterOnlineDecoder, see lattice-faster-online-decoder.h) and also
+// those that do not (LatticeFasterDecoder).
+
+
+// ForwardLinks are the links from a token to a token on the next frame,
+// or sometimes on the current frame (for input-epsilon links).
+template <typename Token>
+struct ForwardLink {
+  using Label = fst::StdArc::Label;
+
+  Token *next_tok;  // the next token [or NULL if represents final-state]
+  Label ilabel;  // ilabel on arc
+  Label olabel;  // olabel on arc
+  BaseFloat graph_cost;  // graph cost of traversing arc (contains LM, etc.)
+  BaseFloat acoustic_cost;  // acoustic cost (pre-scaled) of traversing arc
+  ForwardLink *next;  // next in singly-linked list of forward arcs (arcs
+                      // in the state-level lattice) from a token.
+  inline ForwardLink(Token *next_tok, Label ilabel, Label olabel,
+                     BaseFloat graph_cost, BaseFloat acoustic_cost,
+                     ForwardLink *next):
+      next_tok(next_tok), ilabel(ilabel), olabel(olabel),
+      graph_cost(graph_cost), acoustic_cost(acoustic_cost),
+      next(next) { }
+};
+
+template <typename Fst>
+struct StdToken {
+  using ForwardLinkT = ForwardLink<StdToken>;
+  using Token = StdToken;
+  using StateId = typename Fst::Arc::StateId;
+
+  // Standard token type for LatticeFasterDecoder.  Each active HCLG
+  // (decoding-graph) state on each frame has one token.
+
+  // tot_cost is the total (LM + acoustic) cost from the beginning of the
+  // utterance up to this point.  (but see cost_offset_, which is subtracted
+  // to keep it in a good numerical range).
+  BaseFloat tot_cost;
+
+  // extra_cost is >= 0.  After calling PruneForwardLinks, this equals the
+  // minimum difference between the cost of the best path that this token is
+  // on, and the cost of the absolute best path, under the assumption
+  // that any of the currently active states at the decoding front may
+  // eventually succeed (e.g. if you were to take the currently active states
+  // one by one and compute this difference, and then take the minimum).
+  BaseFloat extra_cost;
+
+  // Record the state id of the token
+  StateId state_id;
+
+  // 'links' is the head of singly-linked list of ForwardLinks, which is what we
+  // use for lattice generation.
+  ForwardLinkT *links;
+
+  // 'next' is the next in the singly-linked list of tokens for this frame.
+  Token *next;
+
+  // identify whether the token is in the current queue or not, to prevent
+  // duplication in function ProcessForFrame().
+  bool in_queue;
+
+
+  // This function does nothing and should be optimized out; it's needed
+  // so we can share the regular LatticeFasterDecoderTpl code and the code
+  // for LatticeFasterOnlineDecoder that supports fast traceback.
+  inline void SetBackpointer (Token *backpointer) { }
+
+  // This constructor just ignores the 'backpointer' argument.  That argument is
+  // needed so that we can use the same decoder code for LatticeFasterDecoderTpl
+  // and LatticeFasterOnlineDecoderTpl (which needs backpointers to support a
+  // fast way to obtain the best path).
+  inline StdToken(BaseFloat tot_cost, BaseFloat extra_cost, StateId state_id,
+                  ForwardLinkT *links, Token *next, Token *backpointer):
+      tot_cost(tot_cost), extra_cost(extra_cost), state_id(state_id),
+      links(links), next(next), in_queue(false) { }
+};
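The empty SetBackpointer() is what lets one decoder template serve both token types: the call is always made, but for StdToken it compiles to nothing. A condensed illustration of the pattern, with invented type names:

struct PlainTok {
  inline void SetBackpointer(PlainTok *) { }   // compiled out
};
struct TracebackTok {
  TracebackTok *backpointer = nullptr;
  inline void SetBackpointer(TracebackTok *bp) { backpointer = bp; }
};

template <typename Token>
inline void Relax(Token *src, Token *dst) {
  dst->SetBackpointer(src);  // no-op or a pointer store, chosen at compile time
}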
+template <typename Fst>
+struct BackpointerToken {
+  using ForwardLinkT = ForwardLink<BackpointerToken>;
+  using Token = BackpointerToken;
+  using StateId = typename Fst::Arc::StateId;
+
+  // BackpointerToken is like StdToken, but also records a 'backpointer' to
+  // the best preceding token.  Each active HCLG (decoding-graph) state on
+  // each frame has one token.
+
+  // tot_cost is the total (LM + acoustic) cost from the beginning of the
+  // utterance up to this point.  (but see cost_offset_, which is subtracted
+  // to keep it in a good numerical range).
+  BaseFloat tot_cost;
+
+  // extra_cost is >= 0.  After calling PruneForwardLinks, this equals the
+  // minimum difference between the cost of the best path that this token is
+  // on, and the cost of the absolute best path, under the assumption
+  // that any of the currently active states at the decoding front may
+  // eventually succeed (e.g. if you were to take the currently active states
+  // one by one and compute this difference, and then take the minimum).
+  BaseFloat extra_cost;
+
+  // Record the state id of the token
+  StateId state_id;
+
+  // 'links' is the head of singly-linked list of ForwardLinks, which is what we
+  // use for lattice generation.
+  ForwardLinkT *links;
+
+  // 'next' is the next in the singly-linked list of tokens for this frame.
+  BackpointerToken *next;
+
+  // Best preceding BackpointerToken (could be a BackpointerToken on this frame,
+  // connected to this via an epsilon transition, or on a previous frame).  This
+  // is only required for an efficient GetBestPath function in
+  // LatticeFasterOnlineDecoderTpl; it plays no part in the lattice generation
+  // (the "links" list is what stores the forward links, for that).
+  Token *backpointer;
+
+  // identify whether the token is in the current queue or not, to prevent
+  // duplication in function ProcessForFrame().
+  bool in_queue;
+
+  inline void SetBackpointer (Token *backpointer) {
+    this->backpointer = backpointer;
+  }
+
+  inline BackpointerToken(BaseFloat tot_cost, BaseFloat extra_cost,
+                          StateId state_id, ForwardLinkT *links,
+                          Token *next, Token *backpointer):
+      tot_cost(tot_cost), extra_cost(extra_cost), state_id(state_id),
+      links(links), next(next), backpointer(backpointer),
+      in_queue(false) { }
+};
+
+}  // namespace decodercombine
+
+
+template <typename Token>
+class BucketQueue {
+ public:
+  // Constructor.  'cost_scale' is a scale that we multiply the token costs by
+  // before integerizing; a larger value means more buckets.
+  // 'bucket_offset_' is initialized to "15 * cost_scale_".  It is an empirical
+  // value, chosen so that we rarely trigger reallocation in the normal case,
+  // since we do in fact normalize costs to be not far from zero on each frame.
+  BucketQueue(BaseFloat cost_scale = 1.0);
+
+  // Adds Token to the queue; sets the field tok->in_queue to true (it is not
+  // an error if it was already true).
+  // If a Token was already in the queue but its cost improves, you should
+  // just Push it again.  It will be added to (possibly) a different bucket, but
+  // the old entry will remain.  We use "tok->in_queue" to decide whether an
+  // entry exists: when a Token is popped off, the field 'tok->in_queue' is set
+  // to false, so any stale entry of it remaining in the queue will be treated
+  // as nonexistent when we try to pop it.
+  void Push(Token *tok);
+
+  // Removes and returns the next Token 'tok' in the queue, or NULL if there
+  // were no Tokens left.  Sets tok->in_queue to false for the returned Token.
+  Token* Pop();
+
+  // Clears all the individual buckets.  Sets 'first_nonempty_bucket_index_' to
+  // the end of buckets_.
+  void Clear();
+
+ private:
+  // Configuration value that is multiplied by tokens' costs before integerizing
+  // them to determine the bucket index
+  BaseFloat cost_scale_;
+
+  // buckets_ is a list of Tokens 'tok' for each bucket.
+  // If tok->in_queue is false, then the item is considered as not
+  // existing (this is to avoid having to explicitly remove Tokens when their
+  // costs change).  The index into buckets_ is determined as follows:
+  //   bucket_index = std::floor(tok->cost * cost_scale_);
+  //   vec_index = bucket_index + bucket_offset_;
+  // then access buckets_[vec_index].
+  std::vector<std::vector<Token*> > buckets_;
+
+  // An offset that determines how we index into the buckets_ vector;
+  // In the constructor this will be initialized to something like
+  // "15 * cost_scale_" which will make it unlikely that we have to change this
+  // value in future if we get a much better Token (this is expensive because it
+  // involves reallocating 'buckets_').
+  int32 bucket_offset_;
+
+  // first_nonempty_bucket_index_ is an integer in the range [0,
+  // buckets_.size() - 1] which is not larger than the index of the first
+  // nonempty element of buckets_.
+  int32 first_nonempty_bucket_index_;
+
+  // Synchronizes with first_nonempty_bucket_index_.
+  std::vector<Token*> *first_nonempty_bucket_;
+
+  // If the size of the BucketQueue is larger than "bucket_size_tolerance_", we
+  // will resize it to "bucket_size_tolerance_" in Clear.  A weirdly long
+  // BucketQueue might be caused when the min-active was activated and an
+  // unusually large log-likelihood range was encountered.
+  size_t bucket_size_tolerance_;
+};
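The mapping from a token's cost to a bucket is the only arithmetic in this queue. A toy sketch of the index computation described in the comments, assuming the offset is added so that slightly negative normalized costs still land at a nonnegative vector index (names invented):

#include <cmath>
#include <cstdio>

int main() {
  float cost_scale = 1.0f;
  int bucket_offset = static_cast<int>(15 * cost_scale);
  float tot_cost = -3.7f;  // costs are roughly normalized around zero per frame
  int bucket_index = static_cast<int>(std::floor(tot_cost * cost_scale));
  int vec_index = bucket_index + bucket_offset;  // slot within buckets_
  std::printf("cost %.1f -> bucket %d -> vector slot %d\n",
              tot_cost, bucket_index, vec_index);
  return 0;
}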
+
+/** This is the "normal" lattice-generating decoder.
+    See \ref lattices_generation \ref decoders_faster and \ref decoders_simple
+    for more information.
+
+    The decoder is templated on the FST type and the token type.  The token
+    type will normally be StdToken, but may also be BackpointerToken, which
+    supports quick lookup of the current best path (see
+    lattice-faster-online-decoder.h).
+
+    The FST you invoke this decoder with is expected to be of type
+    fst::Fst<fst::StdArc>, a.k.a. StdFst, or GrammarFst.  If you invoke it with
+    FST == StdFst and it notices that the actual FST type is
+    fst::VectorFst<fst::StdArc> or fst::ConstFst<fst::StdArc>, the decoder
+    object will internally cast itself to one that is templated on those more
+    specific types; this is an optimization for speed.
+ */
+template <typename FST, typename Token = decodercombine::StdToken<FST> >
+class LatticeFasterDecoderCombineTpl {
+ public:
+  using Arc = typename FST::Arc;
+  using Label = typename Arc::Label;
+  using StateId = typename Arc::StateId;
+  using Weight = typename Arc::Weight;
+  using ForwardLinkT = decodercombine::ForwardLink<Token>;
+
+  using StateIdToTokenMap = typename std::unordered_map<StateId, Token*>;
+  //using StateIdToTokenMap = typename std::unordered_map<StateId, Token*,
+  //    std::hash<StateId>, std::equal_to<StateId>,
+  //    fst::PoolAllocator<std::pair<const StateId, Token*> > >;
+  using IterType = typename StateIdToTokenMap::const_iterator;
+
+  using BucketQueue = typename kaldi::BucketQueue<Token>;
+
+  // Instantiate this class once for each thing you have to decode.
+  // This version of the constructor does not take ownership of
+  // 'fst'.
+  LatticeFasterDecoderCombineTpl(const FST &fst,
+                                 const LatticeFasterDecoderCombineConfig &config);
+
+  // This version of the constructor takes ownership of the fst, and will delete
+  // it when this object is destroyed.
+  LatticeFasterDecoderCombineTpl(const LatticeFasterDecoderCombineConfig &config,
+                                 FST *fst);
+
+  void SetOptions(const LatticeFasterDecoderCombineConfig &config) {
+    config_ = config;
+  }
+
+  const LatticeFasterDecoderCombineConfig &GetOptions() const {
+    return config_;
+  }
+
+  ~LatticeFasterDecoderCombineTpl();
+
+  /// Decodes until there are no more frames left in the "decodable" object.
+  /// Note: this may block waiting for input if the "decodable" object blocks.
+  /// Returns true if any kind of traceback is available (not necessarily from a
+  /// final state).
+  bool Decode(DecodableInterface *decodable);
+
+
+  /// says whether a final-state was active on the last frame.  If it was not, the
+  /// lattice (or traceback) will end with states that are not final-states.
+  bool ReachedFinal() const {
+    return FinalRelativeCost() != std::numeric_limits<BaseFloat>::infinity();
+  }
+
+  /// Outputs an FST corresponding to the single best path through the lattice.
+  /// Returns true if result is nonempty (using the return status is deprecated,
+  /// it will become void).  If "use_final_probs" is true AND we reached the
+  /// final-state of the graph then it will include those as final-probs, else
+  /// it will treat all final-probs as one.  Note: this just calls GetRawLattice()
+  /// and figures out the shortest path.
+  bool GetBestPath(Lattice *ofst,
+                   bool use_final_probs = true);
+
+  /// Outputs an FST corresponding to the raw, state-level
+  /// tracebacks.  Returns true if result is nonempty.
+  /// If "use_final_probs" is true AND we reached the final-state
+  /// of the graph then it will include those as final-probs, else
+  /// it will treat all final-probs as one.
+  /// The raw lattice will be topologically sorted.
+  /// The function can be called during decoding; it will process the
+  /// non-emitting arcs from the "next_toks_" map, so that the raw lattice
+  /// contains the tokens reached by both non-emitting and emitting arcs.
+  ///
+  /// See also GetRawLatticePruned in lattice-faster-online-decoder.h,
+  /// which also supports a pruning beam, in case for some reason
+  /// you want it pruned tighter than the regular lattice beam.
+  /// We could add that here in future if needed.
+  bool GetRawLattice(Lattice *ofst, bool use_final_probs = true);
+
+
+
+  /// [Deprecated, users should now use GetRawLattice and determinize it
+  /// themselves, e.g. using DeterminizeLatticePhonePrunedWrapper].
+  /// Outputs an FST corresponding to the lattice-determinized
+  /// lattice (one path per word sequence).  Returns true if result is nonempty.
+  /// If "use_final_probs" is true AND we reached the final-state of the graph
+  /// then it will include those as final-probs, else it will treat all
+  /// final-probs as one.
+  bool GetLattice(CompactLattice *ofst,
+                  bool use_final_probs = true);
+
+  /// InitDecoding initializes the decoding, and should only be used if you
+  /// intend to call AdvanceDecoding().  If you call Decode(), you don't need to
+  /// call this.  You can also call InitDecoding if you have already decoded an
+  /// utterance and want to start with a new utterance.
+  void InitDecoding();
+
+  /// This will decode until there are no more frames ready in the decodable
+  /// object.  You can keep calling it each time more frames become available.
+  /// If max_num_frames is specified, it specifies the maximum number of frames
+  /// the function will decode before returning.
+  void AdvanceDecoding(DecodableInterface *decodable,
+                       int32 max_num_frames = -1);
+
+  /// This function may be optionally called after AdvanceDecoding(), when you
+  /// do not plan to decode any further.  It does an extra pruning step that
+  /// will help to prune the lattices output by GetLattice and (particularly)
+  /// GetRawLattice more accurately, particularly toward the end of the
+  /// utterance.  It does this by using the final-probs in pruning (if any
+  /// final-state survived); it also does a final pruning step that visits all
+  /// states (the pruning that is done during decoding may fail to prune states
+  /// that are within kPruningScale = 0.1 outside of the beam).  If you call
+  /// this, you cannot call AdvanceDecoding again (it will fail), and you
+  /// cannot call GetLattice() and related functions with use_final_probs =
+  /// false.
+  /// Used to be called PruneActiveTokensFinal().
+  void FinalizeDecoding();
+
+  /// FinalRelativeCost() serves the same purpose as ReachedFinal(), but gives
+  /// more information.  It returns the difference between the best (final-cost
+  /// plus cost) of any token on the final frame, and the best cost of any token
+  /// on the final frame.  If it is infinity it means no final-states were
+  /// present on the final frame.  It will usually be nonnegative.  If it is not
+  /// too positive (e.g. < 5 is my first guess, but this is not tested) you can
+  /// take it as a good indication that we reached the final-state with
+  /// reasonable likelihood.
+  BaseFloat FinalRelativeCost() const;
+
+
+  // Returns the number of frames decoded so far.  The value returned changes
+  // whenever we call ProcessForFrame().
+  inline int32 NumFramesDecoded() const { return active_toks_.size() - 1; }
+
+ protected:
+  // we make things protected instead of private, as code in
+  // LatticeFasterOnlineDecoderTpl, which inherits from this, also uses the
+  // internals.
+
+  // Deletes the elements of the singly linked list tok->links.
+  inline static void DeleteForwardLinks(Token *tok);
+
+  // head of per-frame list of Tokens (list is in topological order),
+  // and something saying whether we ever pruned it using PruneForwardLinks.
+  struct TokenList {
+    Token *toks;
+    bool must_prune_forward_links;
+    bool must_prune_tokens;
+    TokenList(): toks(NULL), must_prune_forward_links(true),
+                 must_prune_tokens(true) { }
+  };
+
+  // FindOrAddToken either locates a token in hash map "token_map", or if
+  // necessary inserts a new, empty token (i.e. with no forward links) for the
+  // current frame.  [note: it's inserted if necessary into the hash map and
+  // also into the singly linked list of tokens active on this frame (whose
+  // head is at active_toks_[frame]).  The frame_plus_one argument is the
+  // acoustic frame index plus one, which is used to index into the
+  // active_toks_ array.
+  // Returns the Token pointer.  Sets "changed" (if non-NULL) to true if the
+  // token was newly created or the cost changed.
+  // If Token == StdToken, the 'backpointer' argument has no purpose (and will
+  // hopefully be optimized out).
+  inline Token *FindOrAddToken(StateId state, int32 token_list_index,
+                               BaseFloat tot_cost, Token *backpointer,
+                               StateIdToTokenMap *token_map,
+                               bool *changed);
+
+  // prunes outgoing links for all tokens in active_toks_[frame]
+  // it's called by PruneActiveTokens
+  // all links, that have link_extra_cost > lattice_beam are pruned
+  // delta is the amount by which the extra_costs must change
+  // before we set *extra_costs_changed = true.
+  // If delta is larger, we'll tend to go back less far
+  // toward the beginning of the file.
+  // extra_costs_changed is set to true if extra_cost was changed for any token
+  // links_pruned is set to true if any link in any token was pruned
+  void PruneForwardLinks(int32 frame_plus_one, bool *extra_costs_changed,
+                         bool *links_pruned,
+                         BaseFloat delta);
+
+  // This function computes the final-costs for tokens active on the final
+  // frame.  It outputs to final-costs, if non-NULL, a map from the Token*
+  // pointer to the final-prob of the corresponding state, for all Tokens
+  // that correspond to states that have final-probs.  This map will be
+  // empty if there were no final-probs.  It outputs to
+  // final_relative_cost, if non-NULL, the difference between the best
+  // forward-cost including the final-prob cost, and the best forward-cost
+  // without including the final-prob cost (this will usually be positive), or
+  // infinity if there were no final-probs.  [c.f. FinalRelativeCost(), which
+  // outputs this quantity].  It outputs to final_best_cost, if
+  // non-NULL, the lowest for any token t active on the final frame, of
+  // forward-cost[t] + final-cost[t], where final-cost[t] is the final-cost in
+  // the graph of the state corresponding to token t, or the best of
+  // forward-cost[t] if there were no final-probs active on the final frame.
+  // You cannot call this after FinalizeDecoding() has been called; in that
+  // case you should get the answer from class-member variables.
+  void ComputeFinalCosts(unordered_map<Token*, BaseFloat> *final_costs,
+                         BaseFloat *final_relative_cost,
+                         BaseFloat *final_best_cost) const;
+
+  // PruneForwardLinksFinal is a version of PruneForwardLinks that we call
+  // on the final frame.  If there are final tokens active, it uses
+  // the final-probs for pruning, otherwise it treats all tokens as final.
+  void PruneForwardLinksFinal();
+
+  // Prune away any tokens on this frame that have no forward links.
+  // [we don't do this in PruneForwardLinks because it would give us
+  // a problem with dangling pointers].
+  // It's called by PruneActiveTokens if any forward links have been pruned
+  void PruneTokensForFrame(int32 frame_plus_one);
+
+
+  // Go backwards through still-alive tokens, pruning them if the
+  // forward+backward cost is more than lat_beam away from the best path.  It's
+  // possible to prove that this is "correct" in the sense that we won't lose
+  // anything outside of lat_beam, regardless of what happens in the future.
+  // delta controls when it considers a cost to have changed enough to continue
+  // going backward and propagating the change.  larger delta -> will recurse
+  // less far.
+  void PruneActiveTokens(BaseFloat delta);
+
+  /// Processes the non-emitting (epsilon) arcs and the emitting arcs for one
+  /// frame together.  It takes the emitting tokens in "cur_toks_" from the
+  /// last frame, and generates non-emitting tokens for the previous frame and
+  /// emitting tokens for the next frame.
+  /// Notice: the emitting tokens for the current frame are the tokens that
+  /// take the acoustic scores of the current frame (i.e. the destinations of
+  /// the emitting arcs).
+  void ProcessForFrame(DecodableInterface *decodable);
+
+  /// Processes nonemitting (epsilon) arcs for one frame.
+  /// This function is called from FinalizeDecoding(), and also from
+  /// GetRawLattice() if GetRawLattice() is called before FinalizeDecoding() is
+  /// called.
+  void ProcessNonemitting();
+
+  /// The "cur_toks_" and "next_toks_" allow us to maintain the tokens of the
+  /// current and the next frame.  They are indexed by StateId.
+  StateIdToTokenMap cur_toks_;
+  StateIdToTokenMap next_toks_;
+
+  /// Notice: in the traditional version, the histogram pruning method is
+  /// applied to a complete token list for one frame.  But in this version, it
+  /// is applied to a token list which only contains the emitting part, so the
+  /// max_active and min_active values might need to be narrowed.
+  std::vector<TokenList> active_toks_;  // Lists of tokens, indexed by
+  // frame-index plus one, where the frame-index is zero-based, as used in the
+  // decodable object.  That is, the emitting probs of frame t are accounted
+  // for in tokens at active_toks_[t+1].  The zeroth frame is for the
+  // nonemitting transitions at the start of the graph.  (Members of TokenList
+  // are toks, must_prune_forward_links, must_prune_tokens.)
+
+  // fst_ is a pointer to the FST we are decoding from.
+  const FST *fst_;
+  // delete_fst_ is true if the pointer fst_ needs to be deleted when this
+  // object is destroyed.
+  bool delete_fst_;
+
+  std::vector<BaseFloat> cost_offsets_;  // This contains, for each
+  // frame, an offset that was added to the acoustic log-likelihoods on that
+  // frame in order to keep everything in a nice dynamic range i.e. close to
+  // zero, to reduce roundoff errors.
+  // Notice: it will only be added to emitting arcs (i.e. cost_offsets_[t] is
+  // added to arcs from "frame t" to "frame t+1").
+  LatticeFasterDecoderCombineConfig config_;
+  int32 num_toks_;  // current total #toks allocated...
+  bool warned_;
+
+  /// decoding_finalized_ is true if someone called FinalizeDecoding().  [note,
+  /// calling this is optional].  If true, it's forbidden to decode more.  Also,
+  /// if this is set, then the output of ComputeFinalCosts() is in the next
+  /// three variables.  The reason we need to do this is that after
+  /// FinalizeDecoding() calls PruneTokensForFrame() for the final frame, some
+  /// of the tokens on the last frame are freed, so we free the list from toks_
+  /// to avoid having dangling pointers hanging around.
+  bool decoding_finalized_;
+  /// For the meaning of the next 3 variables, see the comment for
+  /// decoding_finalized_ above, and ComputeFinalCosts().
+  unordered_map<Token*, BaseFloat> final_costs_;
+  BaseFloat final_relative_cost_;
+  BaseFloat final_best_cost_;
+
+  BaseFloat adaptive_beam_;  // will be set to beam_ when we start
+  BucketQueue cur_queue_;  // temp variable used in
+                           // ProcessForFrame/ProcessNonemitting
+
+  // This function takes a singly linked list of tokens for a single frame, and
+  // outputs a list of them in topological order (it will crash if no such order
+  // can be found, which will typically be due to decoding graphs with epsilon
+  // cycles, which are not allowed).  Note: the output list may contain NULLs,
+  // which the caller should pass over; it just happens to be more efficient for
+  // the algorithm to output a list that contains NULLs.
+  static void TopSortTokens(Token *tok_list,
+                            std::vector<Token*> *topsorted_list);
+
+  void ClearActiveTokens();
+
+  KALDI_DISALLOW_COPY_AND_ASSIGN(LatticeFasterDecoderCombineTpl);
+};
+
+typedef LatticeFasterDecoderCombineTpl<fst::StdFst,
+    decodercombine::StdToken<fst::StdFst> > LatticeFasterDecoderCombine;
+
+
+
+} // end namespace kaldi.
+
+#endif
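The file that follows is an alternative implementation of the same decoder that keeps its per-frame token maps in Kaldi's HashList rather than std::unordered_map plus a BucketQueue. For readers unfamiliar with that container, here is a minimal sketch of the HashList calls the file relies on (SetSize, Insert, Find, GetList, Clear, Delete; see util/hash-list.h; the int/double payload is arbitrary):

#include "util/hash-list.h"

void HashListDemo() {
  typedef kaldi::HashList<int, double> MapType;
  typedef MapType::Elem Elem;
  MapType toks;
  toks.SetSize(1000);            // hash size, as in the constructors below
  toks.Insert(42, 1.5);          // state-id -> token payload
  Elem *e = toks.Find(42);
  if (e != NULL) e->val += 1.0;  // update in place
  for (const Elem *p = toks.GetList(); p != NULL; p = p->tail) {
    // read p->key / p->val; this is how ComputeFinalCosts() iterates below
  }
  // Clear() detaches the list; each Elem must then be given back explicitly:
  for (Elem *p = toks.Clear(), *q; p != NULL; p = q) {
    q = p->tail;
    toks.Delete(p);
  }
}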
diff --git a/src/decoder/lattice-faster-decoder-combine-hashlist.cc b/src/decoder/lattice-faster-decoder-combine-hashlist.cc
new file mode 100644
index 00000000000..bd45a83a3c9
--- /dev/null
+++ b/src/decoder/lattice-faster-decoder-combine-hashlist.cc
@@ -0,0 +1,1129 @@
+// decoder/lattice-faster-decoder-combine.cc
+
+// Copyright 2009-2012  Microsoft Corporation  Mirko Hannemann
+//           2013-2019  Johns Hopkins University (Author: Daniel Povey)
+//                2014  Guoguo Chen
+//                2018  Zhehuai Chen
+//                2019  Hang Lyu
+
+// See ../../COPYING for clarification regarding multiple authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//  http://www.apache.org/licenses/LICENSE-2.0
+//
+// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
+// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
+// MERCHANTABLITY OR NON-INFRINGEMENT.
+// See the Apache 2 License for the specific language governing permissions and
+// limitations under the License.
+
+#include "decoder/lattice-faster-decoder-combine.h"
+#include "lat/lattice-functions.h"
+
+namespace kaldi {
+
+// instantiate this class once for each thing you have to decode.
+template <typename FST, typename Token>
+LatticeFasterDecoderCombineTpl<FST, Token>::LatticeFasterDecoderCombineTpl(
+    const FST &fst,
+    const LatticeFasterDecoderCombineConfig &config):
+    fst_(&fst), delete_fst_(false), config_(config), num_toks_(0) {
+  config.Check();
+  prev_toks_ = new StateIdToTokenMap();
+  prev_toks_->SetSize(1000);
+  cur_toks_ = new StateIdToTokenMap();
+  cur_toks_->SetSize(1000);
+}
+
+
+template <typename FST, typename Token>
+LatticeFasterDecoderCombineTpl<FST, Token>::LatticeFasterDecoderCombineTpl(
+    const LatticeFasterDecoderCombineConfig &config, FST *fst):
+    fst_(fst), delete_fst_(true), config_(config), num_toks_(0) {
+  config.Check();
+  prev_toks_ = new StateIdToTokenMap();
+  prev_toks_->SetSize(1000);
+  cur_toks_ = new StateIdToTokenMap();
+  cur_toks_->SetSize(1000);
+}
+
+
+template <typename FST, typename Token>
+void LatticeFasterDecoderCombineTpl<FST, Token>::DeleteElems(
+    Elem *list, HashList<StateId, Token*> *toks) {
+  for (Elem *e = list, *e_tail; e != NULL; e = e_tail) {
+    e_tail = e->tail;
+    toks->Delete(e);
+  }
+}
+
+
+template <typename FST, typename Token>
+LatticeFasterDecoderCombineTpl<FST, Token>::~LatticeFasterDecoderCombineTpl() {
+  DeleteElems(cur_toks_->Clear(), cur_toks_);
+  ClearActiveTokens();
+  if (delete_fst_) delete fst_;
+  delete prev_toks_;
+  delete cur_toks_;
+}
+
+template <typename FST, typename Token>
+void LatticeFasterDecoderCombineTpl<FST, Token>::InitDecoding() {
+  // clean up from last time:
+  DeleteElems(prev_toks_->Clear(), prev_toks_);
+  DeleteElems(cur_toks_->Clear(), cur_toks_);
+  cost_offsets_.clear();
+  ClearActiveTokens();
+
+  warned_ = false;
+  num_toks_ = 0;
+  decoding_finalized_ = false;
+  final_costs_.clear();
+  StateId start_state = fst_->Start();
+  KALDI_ASSERT(start_state != fst::kNoStateId);
+  active_toks_.resize(1);
+  Token *start_tok = new Token(0.0, 0.0, NULL, NULL, NULL);
+  active_toks_[0].toks = start_tok;
+  //cur_toks_[start_state] = start_tok;  // initialize current tokens map
+  cur_toks_->Insert(start_state, start_tok);
+  num_toks_++;
+}
+
+// Returns true if any kind of traceback is available (not necessarily from
+// a final state).  It should only very rarely return false; this indicates
+// an unusual search error.
+template <typename FST, typename Token>
+bool LatticeFasterDecoderCombineTpl<FST, Token>::Decode(DecodableInterface *decodable) {
+  InitDecoding();
+
+  // We use 1-based indexing for frames in this decoder (if you view it in
+  // terms of features), but note that the decodable object uses zero-based
+  // numbering, which we have to correct for when we call it.
+
+  while (!decodable->IsLastFrame(NumFramesDecoded() - 1)) {
+    if (NumFramesDecoded() % config_.prune_interval == 0)
+      PruneActiveTokens(config_.lattice_beam * config_.prune_scale);
+    ProcessForFrame(decodable);
+  }
+  // A complete token list of the last frame will be generated in FinalizeDecoding()
+  FinalizeDecoding();
+
+  // Returns true if we have any kind of traceback available (not necessarily
+  // to the end state; query ReachedFinal() for that).
+  return !active_toks_.empty() && active_toks_.back().toks != NULL;
+}
+
+
+// Outputs an FST corresponding to the single best path through the lattice.
+template <typename FST, typename Token>
+bool LatticeFasterDecoderCombineTpl<FST, Token>::GetBestPath(
+    Lattice *olat,
+    bool use_final_probs) {
+  Lattice raw_lat;
+  GetRawLattice(&raw_lat, use_final_probs);
+  ShortestPath(raw_lat, olat);
+  return (olat->NumStates() != 0);
+}
+
+
+// Outputs an FST corresponding to the raw, state-level lattice
+template <typename FST, typename Token>
+bool LatticeFasterDecoderCombineTpl<FST, Token>::GetRawLattice(
+    Lattice *ofst,
+    bool use_final_probs) {
+  typedef LatticeArc Arc;
+  typedef Arc::StateId StateId;
+  typedef Arc::Weight Weight;
+  typedef Arc::Label Label;
+  // Note: you can't use the old interface (Decode()) if you want to
+  // get the lattice with use_final_probs = false.  You'd have to do
+  // InitDecoding() and then AdvanceDecoding().
+  if (decoding_finalized_ && !use_final_probs)
+    KALDI_ERR << "You cannot call FinalizeDecoding() and then call "
+              << "GetRawLattice() with use_final_probs == false";
+
+  std::unordered_map<Token*, BaseFloat> token_orig_cost;
+  if (!decoding_finalized_) {
+    // Process the non-emitting arcs for the unfinished last frame.
+    ProcessNonemitting(&token_orig_cost);
+  }
+
+
+  unordered_map<Token*, BaseFloat> final_costs_local;
+
+  const unordered_map<Token*, BaseFloat> &final_costs =
+      (decoding_finalized_ ? final_costs_ : final_costs_local);
+  if (!decoding_finalized_ && use_final_probs)
+    ComputeFinalCosts(&final_costs_local, NULL, NULL);
+
+  ofst->DeleteStates();
+  // num-frames plus one (since frames are one-based, and we have
+  // an extra frame for the start-state).
+  int32 num_frames = active_toks_.size() - 1;
+  KALDI_ASSERT(num_frames > 0);
+  const int32 bucket_count = num_toks_/2 + 3;
+  unordered_map<Token*, StateId> tok_map(bucket_count);
+  // First create all states.
+  std::vector<Token*> token_list;
+  for (int32 f = 0; f <= num_frames; f++) {
+    if (active_toks_[f].toks == NULL) {
+      KALDI_WARN << "GetRawLattice: no tokens active on frame " << f
+                 << ": not producing lattice.\n";
+      return false;
+    }
+    TopSortTokens(active_toks_[f].toks, &token_list);
+    for (size_t i = 0; i < token_list.size(); i++)
+      if (token_list[i] != NULL)
+        tok_map[token_list[i]] = ofst->AddState();
+  }
+  // The next statement sets the start state of the output FST.  Because we
+  // topologically sorted the tokens, state zero must be the start-state.
+  ofst->SetStart(0);
+
+  KALDI_VLOG(4) << "init:" << num_toks_/2 + 3 << " buckets:"
+                << tok_map.bucket_count() << " load:" << tok_map.load_factor()
+                << " max:" << tok_map.max_load_factor();
+  // Now create all arcs.
+  for (int32 f = 0; f <= num_frames; f++) {
+    for (Token *tok = active_toks_[f].toks; tok != NULL; tok = tok->next) {
+      StateId cur_state = tok_map[tok];
+      for (ForwardLinkT *l = tok->links;
+           l != NULL;
+           l = l->next) {
+        typename unordered_map<Token*, StateId>::const_iterator
+            iter = tok_map.find(l->next_tok);
+        KALDI_ASSERT(iter != tok_map.end());
+        StateId nextstate = iter->second;
+        BaseFloat cost_offset = 0.0;
+        if (l->ilabel != 0) {  // emitting..
+          KALDI_ASSERT(f >= 0 && f < cost_offsets_.size());
+          cost_offset = cost_offsets_[f];
+        }
+        Arc arc(l->ilabel, l->olabel,
+                Weight(l->graph_cost, l->acoustic_cost - cost_offset),
+                nextstate);
+        ofst->AddArc(cur_state, arc);
+      }
+      if (f == num_frames) {
+        if (use_final_probs && !final_costs.empty()) {
+          typename unordered_map<Token*, BaseFloat>::const_iterator
+              iter = final_costs.find(tok);
+          if (iter != final_costs.end())
+            ofst->SetFinal(cur_state, LatticeWeight(iter->second, 0));
+        } else {
+          ofst->SetFinal(cur_state, LatticeWeight::One());
+        }
+      }
+    }
+  }
+
+  if (!decoding_finalized_) {  // recover last token list
+    RecoverLastTokenList(token_orig_cost);
+  }
+  return (ofst->NumStates() > 0);
+}
+
+
+// When GetRawLattice() is called during decoding, the
+// active_toks_[last_frame] is changed.  To keep the consistency of function
+// ProcessForFrame(), recover it.
+// Notice: as a new token will be added to the head of the TokenList, tok->next
+// will not be affected.
+template <typename FST, typename Token>
+void LatticeFasterDecoderCombineTpl<FST, Token>::RecoverLastTokenList(
+    const std::unordered_map<Token*, BaseFloat> &token_orig_cost) {
+  if (!token_orig_cost.empty()) {
+    for (const Elem *e = cur_toks_->GetList(); e != NULL; e = e->tail) {
+      Token *tok = e->val;
+      if (token_orig_cost.find(tok) != token_orig_cost.end()) {
+        DeleteForwardLinks(tok);
+        tok->tot_cost = token_orig_cost.find(tok)->second;
+        tok->in_current_queue = false;
+        tok = tok->next;
+      } else {
+        DeleteForwardLinks(tok);
+        Token *next_tok = tok->next;
+        delete tok;
+        num_toks_--;
+        tok = next_tok;
+      }
+    }
+  }
+}
+
+// This function is now deprecated, since now we do determinization from outside
+// the LatticeFasterDecoder class.  Outputs an FST corresponding to the
+// lattice-determinized lattice (one path per word sequence).
+template <typename FST, typename Token>
+bool LatticeFasterDecoderCombineTpl<FST, Token>::GetLattice(
+    CompactLattice *ofst,
+    bool use_final_probs) {
+  Lattice raw_fst;
+  GetRawLattice(&raw_fst, use_final_probs);
+  Invert(&raw_fst);  // make it so word labels are on the input.
+  // (in phase where we get backward-costs).
+  fst::ILabelCompare<LatticeArc> ilabel_comp;
+  ArcSort(&raw_fst, ilabel_comp);  // sort on ilabel; makes
+  // lattice-determinization more efficient.
+
+  fst::DeterminizeLatticePrunedOptions lat_opts;
+  lat_opts.max_mem = config_.det_opts.max_mem;
+
+  DeterminizeLatticePruned(raw_fst, config_.lattice_beam, ofst, lat_opts);
+  raw_fst.DeleteStates();  // Free memory-- raw_fst no longer needed.
+  Connect(ofst);  // Remove unreachable states... there might be
+  // a small number of these, in some cases.
+  // Note: if something went wrong and the raw lattice was empty,
+  // we should still get to this point in the code without warnings or failures.
+  return (ofst->NumStates() != 0);
+}
+
+/*
+  A note on the definition of extra_cost.
+
+  extra_cost is used in pruning tokens, to save memory.
+
+  Define the 'forward cost' of a token as zero for any token on the frame
+  we're currently decoding; and for other frames, as the shortest-path cost
+  between that token and a token on the frame we're currently decoding.
+  (by "currently decoding" I mean the most recently processed frame).
+
+  Then define the extra_cost of a token (always >= 0) as the forward-cost of
+  the token minus the smallest forward-cost of any token on the same frame.
+
+  We can use the extra_cost to accurately prune away tokens that we know will
+  never appear in the lattice.  If the extra_cost is greater than the desired
+  lattice beam, the token would provably never appear in the lattice, so we can
+  prune away the token.
+
+  The advantage of storing the extra_cost rather than the forward-cost, is that
+  it is less costly to keep the extra_cost up-to-date when we process new frames.
+  When we process a new frame, *all* the previous frames' forward-costs would
+  change; but in general the extra_cost will change only for a finite number of
+  frames.  (Actually we don't update all the extra_costs every time we update a
+  frame; we only do it every 'config_.prune_interval' frames).
+ */
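A small worked example of the extra_cost definition above, with invented numbers:

#include <algorithm>
#include <cstdio>

int main() {
  // Forward-costs of two tokens, A and B, that sit on the same frame.
  float fwd_a = 3.0f, fwd_b = 5.5f, lattice_beam = 2.0f;
  float best = std::min(fwd_a, fwd_b);
  float extra_a = fwd_a - best;   // 0.0: A is on the frame's best path
  float extra_b = fwd_b - best;   // 2.5: B is 2.5 worse than the best
  std::printf("prune A: %s, prune B: %s\n",
              extra_a > lattice_beam ? "yes" : "no",    // no
              extra_b > lattice_beam ? "yes" : "no");   // yes, 2.5 > 2.0
  return 0;
}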
+
+// FindOrAddToken either locates a token in hash map "token_map"
+// or if necessary inserts a new, empty token (i.e. with no forward links)
+// for the current frame.  [note: it's inserted if necessary into hash toks_
+// and also into the singly linked list of tokens active on this frame
+// (whose head is at active_toks_[frame]).
+template <typename FST, typename Token>
+inline Token* LatticeFasterDecoderCombineTpl<FST, Token>::FindOrAddToken(
+    StateId state, int32 token_list_index, BaseFloat tot_cost,
+    Token *backpointer, StateIdToTokenMap *token_map, bool *changed) {
+  // Returns the Token pointer.  Sets "changed" (if non-NULL) to true
+  // if the token was newly created or the cost changed.
+  KALDI_ASSERT(token_list_index < active_toks_.size());
+  Token *&toks = active_toks_[token_list_index].toks;
+  Elem *e_found = token_map->Find(state);
+  if (e_found == NULL) {  // no such token presently.
+    const BaseFloat extra_cost = 0.0;
+    // tokens on the currently final frame have zero extra_cost
+    // as any of them could end up
+    // on the winning path.
+    Token *new_tok = new Token(tot_cost, extra_cost, NULL, toks, backpointer);
+    // NULL: no forward links yet
+    toks = new_tok;
+    num_toks_++;
+    // insert into the map
+    token_map->Insert(state, new_tok);
+    if (changed) *changed = true;
+    return new_tok;
+  } else {
+    Token *tok = e_found->val;  // There is an existing Token for this state.
+    if (tok->tot_cost > tot_cost) {  // replace old token
+      tok->tot_cost = tot_cost;
+      // SetBackpointer() just does tok->backpointer = backpointer in
+      // the case where Token == BackpointerToken, else nothing.
+      tok->SetBackpointer(backpointer);
+      // we don't allocate a new token, the old stays linked in active_toks_
+      // we only replace the tot_cost
+      // in the current frame, there are no forward links (and no extra_cost)
+      // only in ProcessNonemitting we have to delete forward links
+      // in case we visit a state for the second time
+      // those forward links, that lead to this replaced token before:
+      // they remain and will hopefully be pruned later (PruneForwardLinks...)
+      if (changed) *changed = true;
+    } else {
+      if (changed) *changed = false;
+    }
+    return tok;
+  }
+}
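Stripped of the HashList and backpointer details, FindOrAddToken is the classic find-or-relax step of a Viterbi beam search. A minimal sketch of the same logic over a plain std::unordered_map (Tok and FindOrAdd are illustrative names; the real code also threads the token into active_toks_):

#include <unordered_map>

struct Tok { float tot_cost; };

Tok *FindOrAdd(std::unordered_map<int, Tok*> *map, int state, float tot_cost,
               bool *changed) {
  auto it = map->find(state);
  if (it == map->end()) {         // no token for this state yet: create one
    Tok *t = new Tok{tot_cost};
    (*map)[state] = t;
    if (changed) *changed = true;
    return t;
  }
  Tok *t = it->second;
  if (t->tot_cost > tot_cost) {   // found a better path into this state
    t->tot_cost = tot_cost;
    if (changed) *changed = true;
  } else {
    if (changed) *changed = false;
  }
  return t;
}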
+
+
+// prunes outgoing links for all tokens in active_toks_[frame]
+// it's called by PruneActiveTokens
+// all links, that have link_extra_cost > lattice_beam are pruned
+template <typename FST, typename Token>
+void LatticeFasterDecoderCombineTpl<FST, Token>::PruneForwardLinks(
+    int32 frame_plus_one, bool *extra_costs_changed,
+    bool *links_pruned, BaseFloat delta) {
+  // delta is the amount by which the extra_costs must change
+  // If delta is larger, we'll tend to go back less far
+  // toward the beginning of the file.
+  // extra_costs_changed is set to true if extra_cost was changed for any token
+  // links_pruned is set to true if any link in any token was pruned
+
+  *extra_costs_changed = false;
+  *links_pruned = false;
+  KALDI_ASSERT(frame_plus_one >= 0 && frame_plus_one < active_toks_.size());
+  if (active_toks_[frame_plus_one].toks == NULL) {  // empty list; should not happen.
+    if (!warned_) {
+      KALDI_WARN << "No tokens alive [doing pruning].. warning first "
+          "time only for each utterance\n";
+      warned_ = true;
+    }
+  }
+
+  // We have to iterate until there is no more change, because the links
+  // are not guaranteed to be in topological order.
+  bool changed = true;  // difference new minus old extra cost >= delta ?
+  while (changed) {
+    changed = false;
+    for (Token *tok = active_toks_[frame_plus_one].toks;
+         tok != NULL; tok = tok->next) {
+      ForwardLinkT *link, *prev_link = NULL;
+      // will recompute tok_extra_cost for tok.
+      BaseFloat tok_extra_cost = std::numeric_limits<BaseFloat>::infinity();
+      // tok_extra_cost is the best (min) of link_extra_cost of outgoing links
+      for (link = tok->links; link != NULL; ) {
+        // See if we need to excise this link...
+        Token *next_tok = link->next_tok;
+        BaseFloat link_extra_cost = next_tok->extra_cost +
+            ((tok->tot_cost + link->acoustic_cost + link->graph_cost)
+             - next_tok->tot_cost);  // difference in brackets is >= 0
+        // link_extra_cost is the difference in score between the best paths
+        // through link source state and through link destination state
+        KALDI_ASSERT(link_extra_cost == link_extra_cost);  // check for NaN
+        if (link_extra_cost > config_.lattice_beam) {  // excise link
+          ForwardLinkT *next_link = link->next;
+          if (prev_link != NULL) prev_link->next = next_link;
+          else tok->links = next_link;
+          delete link;
+          link = next_link;  // advance link but leave prev_link the same.
+          *links_pruned = true;
+        } else {  // keep the link and update the tok_extra_cost if needed.
+          if (link_extra_cost < 0.0) {  // this is just a precaution.
+            if (link_extra_cost < -0.01)
+              KALDI_WARN << "Negative extra_cost: " << link_extra_cost;
+            link_extra_cost = 0.0;
+          }
+          if (link_extra_cost < tok_extra_cost)
+            tok_extra_cost = link_extra_cost;
+          prev_link = link;  // move to next link
+          link = link->next;
+        }
+      }  // for all outgoing links
+      if (fabs(tok_extra_cost - tok->extra_cost) > delta)
+        changed = true;  // difference new minus old is bigger than delta
+      tok->extra_cost = tok_extra_cost;
+      // will be +infinity or <= lattice_beam_.
+      // infinity indicates, that no forward link survived pruning
+    }  // for all Token on active_toks_[frame]
+    if (changed) *extra_costs_changed = true;
+
+    // Note: it's theoretically possible that aggressive compiler
+    // optimizations could cause an infinite loop here for small delta and
+    // high-dynamic-range scores.
+  }  // while changed
+}
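The link_extra_cost expression is worth unpacking once with concrete numbers (all invented; 'tok' is the link's source, 'next_tok' its destination):

#include <cstdio>

int main() {
  float tok_tot_cost = 10.0f;    // best cost into the source token
  float acoustic_cost = 2.0f, graph_cost = 1.0f;  // costs on the link itself
  float next_tot_cost = 12.5f;   // best cost into the destination token
  float next_extra_cost = 0.5f;  // destination's own extra_cost
  // Using this link reaches the destination at 13.0, i.e. 0.5 worse than the
  // best way in; the bracketed difference is therefore >= 0, as the comment
  // in PruneForwardLinks says.
  float link_extra_cost = next_extra_cost +
      ((tok_tot_cost + acoustic_cost + graph_cost) - next_tot_cost);  // 1.0
  std::printf("link_extra_cost = %.1f (pruned iff > lattice_beam)\n",
              link_extra_cost);
  return 0;
}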
+
+// PruneForwardLinksFinal is a version of PruneForwardLinks that we call
+// on the final frame.  If there are final tokens active, it uses
+// the final-probs for pruning, otherwise it treats all tokens as final.
+template <typename FST, typename Token>
+void LatticeFasterDecoderCombineTpl<FST, Token>::PruneForwardLinksFinal() {
+  KALDI_ASSERT(!active_toks_.empty());
+  int32 frame_plus_one = active_toks_.size() - 1;
+
+  if (active_toks_[frame_plus_one].toks == NULL)  // empty list; should not happen.
+    KALDI_WARN << "No tokens alive at end of file";
+
+  typedef typename unordered_map<Token*, BaseFloat>::const_iterator IterType;
+  ComputeFinalCosts(&final_costs_, &final_relative_cost_, &final_best_cost_);
+  decoding_finalized_ = true;
+
+  // Now go through tokens on this frame, pruning forward links...  may have to
+  // iterate a few times until there is no more change, because the list is not
+  // in topological order.  This is a modified version of the code in
+  // PruneForwardLinks, but here we also take account of the final-probs.
+  bool changed = true;
+  BaseFloat delta = 1.0e-05;
+  while (changed) {
+    changed = false;
+    for (Token *tok = active_toks_[frame_plus_one].toks;
+         tok != NULL; tok = tok->next) {
+      ForwardLinkT *link, *prev_link = NULL;
+      // will recompute tok_extra_cost.  It has a term in it that corresponds
+      // to the "final-prob", so instead of initializing tok_extra_cost to infinity
+      // below we set it to the difference between the (score+final_prob) of this token,
+      // and the best such (score+final_prob).
+      BaseFloat final_cost;
+      if (final_costs_.empty()) {
+        final_cost = 0.0;
+      } else {
+        IterType iter = final_costs_.find(tok);
+        if (iter != final_costs_.end())
+          final_cost = iter->second;
+        else
+          final_cost = std::numeric_limits<BaseFloat>::infinity();
+      }
+      BaseFloat tok_extra_cost = tok->tot_cost + final_cost - final_best_cost_;
+      // tok_extra_cost will be a "min" over either directly being final, or
+      // being indirectly final through other links, and the loop below may
+      // decrease its value:
+      for (link = tok->links; link != NULL; ) {
+        // See if we need to excise this link...
+        Token *next_tok = link->next_tok;
+        BaseFloat link_extra_cost = next_tok->extra_cost +
+            ((tok->tot_cost + link->acoustic_cost + link->graph_cost)
+             - next_tok->tot_cost);
+        if (link_extra_cost > config_.lattice_beam) {  // excise link
+          ForwardLinkT *next_link = link->next;
+          if (prev_link != NULL) prev_link->next = next_link;
+          else tok->links = next_link;
+          delete link;
+          link = next_link;  // advance link but leave prev_link the same.
+        } else {  // keep the link and update the tok_extra_cost if needed.
+          if (link_extra_cost < 0.0) {  // this is just a precaution.
+            if (link_extra_cost < -0.01)
+              KALDI_WARN << "Negative extra_cost: " << link_extra_cost;
+            link_extra_cost = 0.0;
+          }
+          if (link_extra_cost < tok_extra_cost)
+            tok_extra_cost = link_extra_cost;
+          prev_link = link;
+          link = link->next;
+        }
+      }
+      // prune away tokens worse than lattice_beam above best path.  This step
+      // was not necessary in the non-final case because then, this case
+      // showed up as having no forward links.  Here, the tok_extra_cost has
+      // an extra component relating to the final-prob.
+      if (tok_extra_cost > config_.lattice_beam)
+        tok_extra_cost = std::numeric_limits<BaseFloat>::infinity();
+      // to be pruned in PruneTokensForFrame
+
+      if (!ApproxEqual(tok->extra_cost, tok_extra_cost, delta))
+        changed = true;
+      tok->extra_cost = tok_extra_cost;  // will be +infinity or <= lattice_beam_.
+    }
+  }  // while changed
+}
+
+template <typename FST, typename Token>
+BaseFloat LatticeFasterDecoderCombineTpl<FST, Token>::FinalRelativeCost() const {
+  if (!decoding_finalized_) {
+    BaseFloat relative_cost;
+    ComputeFinalCosts(NULL, &relative_cost, NULL);
+    return relative_cost;
+  } else {
+    // we're not allowed to call that function if FinalizeDecoding() has
+    // been called; return a cached value.
+    return final_relative_cost_;
+  }
+}
+
+
+// Prune away any tokens on this frame that have no forward links.
+// [we don't do this in PruneForwardLinks because it would give us
+// a problem with dangling pointers].
+// It's called by PruneActiveTokens if any forward links have been pruned
+template <typename FST, typename Token>
+void LatticeFasterDecoderCombineTpl<FST, Token>::PruneTokensForFrame(
+    int32 frame_plus_one) {
+  KALDI_ASSERT(frame_plus_one >= 0 && frame_plus_one < active_toks_.size());
+  Token *&toks = active_toks_[frame_plus_one].toks;
+  if (toks == NULL)
+    KALDI_WARN << "No tokens alive [doing pruning]";
+  Token *tok, *next_tok, *prev_tok = NULL;
+  for (tok = toks; tok != NULL; tok = next_tok) {
+    next_tok = tok->next;
+    if (tok->extra_cost == std::numeric_limits<BaseFloat>::infinity()) {
+      // token is unreachable from end of graph; (no forward links survived)
+      // excise tok from list and delete tok.
+      if (prev_tok != NULL) prev_tok->next = tok->next;
+      else toks = tok->next;
+      delete tok;
+      num_toks_--;
+    } else {  // fetch next Token
+      prev_tok = tok;
+    }
+  }
+}
+
+// Go backwards through still-alive tokens, pruning them, starting not from
+// the current frame (where we want to keep all tokens) but from the frame before
+// that.  We go backwards through the frames and stop when we reach a point
+// where the delta-costs are not changing (and the delta controls when we consider
+// a cost to have "not changed").
+template <typename FST, typename Token>
+void LatticeFasterDecoderCombineTpl<FST, Token>::PruneActiveTokens(
+    BaseFloat delta) {
+  int32 cur_frame_plus_one = NumFramesDecoded();
+  int32 num_toks_begin = num_toks_;
+  // The index "f" below represents a "frame plus one", i.e. you'd have to subtract
+  // one to get the corresponding index for the decodable object.
+  for (int32 f = cur_frame_plus_one - 1; f >= 0; f--) {
+    // Reason why we need to prune forward links in this situation:
+    // (1) we have never pruned them (new TokenList)
+    // (2) we have not yet pruned the forward links to the next f,
+    //     after any of those tokens have changed their extra_cost.
+    if (active_toks_[f].must_prune_forward_links) {
+      bool extra_costs_changed = false, links_pruned = false;
+      PruneForwardLinks(f, &extra_costs_changed, &links_pruned, delta);
+      if (extra_costs_changed && f > 0)  // any token has changed extra_cost
+        active_toks_[f-1].must_prune_forward_links = true;
+      if (links_pruned)  // any link was pruned
+        active_toks_[f].must_prune_tokens = true;
+      active_toks_[f].must_prune_forward_links = false;  // job done
+    }
+    if (f+1 < cur_frame_plus_one &&  // except for last f (no forward links)
+        active_toks_[f+1].must_prune_tokens) {
+      PruneTokensForFrame(f+1);
+      active_toks_[f+1].must_prune_tokens = false;
+    }
+  }
+  KALDI_VLOG(4) << "PruneActiveTokens: pruned tokens from " << num_toks_begin
+                << " to " << num_toks_;
+}
+ const Elem *final_toks = cur_toks_->GetList(); + while (final_toks != NULL) { + StateId state = final_toks->key; + Token *tok = final_toks->val; + const Elem *next = final_toks->tail; + BaseFloat final_cost = fst_->Final(state).Value(); + BaseFloat cost = tok->tot_cost, + cost_with_final = cost + final_cost; + best_cost = std::min(cost, best_cost); + best_cost_with_final = std::min(cost_with_final, best_cost_with_final); + if (final_costs != NULL && final_cost != infinity) + (*final_costs)[tok] = final_cost; + final_toks = next; + } + if (final_relative_cost != NULL) { + if (best_cost == infinity && best_cost_with_final == infinity) { + // Likely this will only happen if there are no tokens surviving. + // This seems the least bad way to handle it. + *final_relative_cost = infinity; + } else { + *final_relative_cost = best_cost_with_final - best_cost; + } + } + if (final_best_cost != NULL) { + if (best_cost_with_final != infinity) { // final-state exists. + *final_best_cost = best_cost_with_final; + } else { // no final-state exists. + *final_best_cost = best_cost; + } + } +} + +template +void LatticeFasterDecoderCombineTpl::AdvanceDecoding( + DecodableInterface *decodable, + int32 max_num_frames) { + if (std::is_same >::value) { + // if the type 'FST' is the FST base-class, then see if the FST type of fst_ + // is actually VectorFst or ConstFst. If so, call the AdvanceDecoding() + // function after casting *this to the more specific type. + if (fst_->Type() == "const") { + LatticeFasterDecoderCombineTpl, Token> *this_cast = + reinterpret_cast, Token>* >(this); + this_cast->AdvanceDecoding(decodable, max_num_frames); + return; + } else if (fst_->Type() == "vector") { + LatticeFasterDecoderCombineTpl, Token> *this_cast = + reinterpret_cast, Token>* >(this); + this_cast->AdvanceDecoding(decodable, max_num_frames); + return; + } + } + + + KALDI_ASSERT(!active_toks_.empty() && !decoding_finalized_ && + "You must call InitDecoding() before AdvanceDecoding"); + int32 num_frames_ready = decodable->NumFramesReady(); + // num_frames_ready must be >= num_frames_decoded, or else + // the number of frames ready must have decreased (which doesn't + // make sense) or the decodable object changed between calls + // (which isn't allowed). + KALDI_ASSERT(num_frames_ready >= NumFramesDecoded()); + int32 target_frames_decoded = num_frames_ready; + if (max_num_frames >= 0) + target_frames_decoded = std::min(target_frames_decoded, + NumFramesDecoded() + max_num_frames); + while (NumFramesDecoded() < target_frames_decoded) { + if (NumFramesDecoded() % config_.prune_interval == 0) { + PruneActiveTokens(config_.lattice_beam * config_.prune_scale); + } + ProcessForFrame(decodable); + } +} + +// FinalizeDecoding() is a version of PruneActiveTokens that we call +// (optionally) on the final frame. Takes into account the final-prob of +// tokens. This function used to be called PruneActiveTokensFinal(). +template +void LatticeFasterDecoderCombineTpl::FinalizeDecoding() { + ProcessNonemitting(NULL); + int32 final_frame_plus_one = NumFramesDecoded(); + int32 num_toks_begin = num_toks_; + // PruneForwardLinksFinal() prunes final frame (with final-probs), and + // sets decoding_finalized_. + PruneForwardLinksFinal(); + for (int32 f = final_frame_plus_one - 1; f >= 0; f--) { + bool b1, b2; // values not used. 
+ BaseFloat dontcare = 0.0; // delta of zero means we must always update + PruneForwardLinks(f, &b1, &b2, dontcare); + PruneTokensForFrame(f + 1); + } + PruneTokensForFrame(0); + DeleteElems(prev_toks_->Clear(), prev_toks_); + DeleteElems(cur_toks_->Clear(), cur_toks_); + KALDI_VLOG(4) << "pruned tokens from " << num_toks_begin + << " to " << num_toks_; +} + +/// Gets the weight cutoff. +template +BaseFloat LatticeFasterDecoderCombineTpl::GetCutoff( + const Elem *list_head, size_t *tok_count, BaseFloat *adaptive_beam, + Elem **best_elem) { + BaseFloat best_weight = std::numeric_limits::infinity(); + // positive == high cost == bad. + size_t count = 0; + if (config_.max_active == std::numeric_limits::max() && + config_.min_active == 0) { + for (const Elem *e = list_head; e != NULL; e = e->tail, count++) { + BaseFloat w = static_cast(e->val->tot_cost); + if (w < best_weight) { + best_weight = w; + if (best_elem) *best_elem = const_cast(e); + } + } + if (tok_count != NULL) *tok_count = count; + if (adaptive_beam != NULL) *adaptive_beam = config_.beam; + return best_weight + config_.beam; + } else { + tmp_array_.clear(); + for (const Elem *e = list_head; e != NULL; e = e->tail, count++) { + BaseFloat w = e->val->tot_cost; + tmp_array_.push_back(w); + if (w < best_weight) { + best_weight = w; + if (best_elem) *best_elem = const_cast(e); + } + } + if (tok_count != NULL) *tok_count = count; + + BaseFloat beam_cutoff = best_weight + config_.beam, + min_active_cutoff = std::numeric_limits::infinity(), + max_active_cutoff = std::numeric_limits::infinity(); + + KALDI_VLOG(5) << "Number of tokens active on frame " << NumFramesDecoded() + << " is " << tmp_array_.size(); + + if (tmp_array_.size() > static_cast(config_.max_active)) { + std::nth_element(tmp_array_.begin(), + tmp_array_.begin() + config_.max_active, + tmp_array_.end()); + max_active_cutoff = tmp_array_[config_.max_active]; + } + if (max_active_cutoff < beam_cutoff) { // max_active is tighter than beam. + if (adaptive_beam) + *adaptive_beam = max_active_cutoff - best_weight + config_.beam_delta; + return max_active_cutoff; + } + if (tmp_array_.size() > static_cast(config_.min_active)) { + if (config_.min_active == 0) min_active_cutoff = best_weight; + else { + std::nth_element(tmp_array_.begin(), + tmp_array_.begin() + config_.min_active, + tmp_array_.size() > static_cast(config_.max_active) ? + tmp_array_.begin() + config_.max_active : + tmp_array_.end()); + min_active_cutoff = tmp_array_[config_.min_active]; + } + } + if (min_active_cutoff > beam_cutoff) { // min_active is looser than beam. + if (adaptive_beam) + *adaptive_beam = min_active_cutoff - best_weight + config_.beam_delta; + return min_active_cutoff; + } else { + *adaptive_beam = config_.beam; + return beam_cutoff; + } + } +} + +template +void LatticeFasterDecoderCombineTpl::PossiblyResizeHash( + size_t num_toks) { + size_t new_sz = static_cast(static_cast(num_toks) + * config_.hash_ratio); + if (new_sz > cur_toks_->Size()) { + cur_toks_->SetSize(new_sz); + } +} + +template +void LatticeFasterDecoderCombineTpl::ProcessForFrame( + DecodableInterface *decodable) { + KALDI_ASSERT(active_toks_.size() > 0); + int32 frame = active_toks_.size() - 1; // frame is the frame-index + // (zero-based) used to get likelihoods + // from the decodable object. 
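+  // Overview of this function (summarizing the code below): we swap
+  // "prev_toks_" and "cur_toks_", then expand every surviving token of the
+  // previous frame through non-emitting arcs (bounded by "cur_cutoff") and
+  // through emitting arcs into the new frame (bounded by "next_cutoff",
+  // which is tightened whenever a better emitting token is found).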
+  active_toks_.resize(active_toks_.size() + 1);
+
+  StateIdToTokenMap *tmp = prev_toks_;
+  prev_toks_ = cur_toks_;
+  cur_toks_ = tmp;
+  DeleteElems(cur_toks_->Clear(), cur_toks_);
+
+  if (prev_toks_->GetList() == NULL) {
+    if (!warned_) {
+      KALDI_WARN << "Error, no surviving tokens on frame " << frame;
+      warned_ = true;
+    }
+  }
+
+  Elem *best_elem = NULL;
+  BaseFloat adaptive_beam;
+  size_t tok_cnt;
+  // "cur_cutoff" is used to constrain the epsilon emission on the current
+  // frame.  It will not be updated.
+  BaseFloat cur_cutoff = GetCutoff(prev_toks_->GetList(), &tok_cnt,
+                                   &adaptive_beam, &best_elem);
+  KALDI_VLOG(6) << "Adaptive beam on frame " << NumFramesDecoded() << " is "
+                << adaptive_beam;
+
+  PossiblyResizeHash(tok_cnt);
+
+  // pruning "online" before having seen all tokens
+
+  // "next_cutoff" is used to decide whether a new token on the next frame
+  // should be handled or not.  It will be updated as the processing goes on.
+  BaseFloat next_cutoff = std::numeric_limits<BaseFloat>::infinity();
+  // "cost_offset" contains the acoustic log-likelihoods on the current frame,
+  // in order to keep everything in a nice dynamic range and reduce roundoff
+  // errors.
+  BaseFloat cost_offset = 0.0;
+
+  // First process the best token to get a hopefully
+  // reasonably tight bound on the next cutoff.  The only
+  // products of the next block are "next_cutoff" and "cost_offset".
+  // Notice: this is a difference between the combine version and the
+  // traditional version -- here "best_tok" is chosen from the emitting tokens
+  // only.  Normally the best token of a frame is reached through an epsilon
+  // (non-emitting) arc, so this best token only gives a looser bound.  We use
+  // it to estimate a bound on the next cutoff, and "next_cutoff" will be
+  // tightened in the further processing once we see better tokens.
+  if (best_elem) {
+    StateId state = best_elem->key;
+    Token *best_tok = best_elem->val;
+    cost_offset = - best_tok->tot_cost;
+    for (fst::ArcIterator<FST> aiter(*fst_, state);
+         !aiter.Done();
+         aiter.Next()) {
+      const Arc &arc = aiter.Value();
+      if (arc.ilabel != 0) {  // propagate..
+        // ac_cost + graph_cost
+        BaseFloat new_weight = arc.weight.Value() + cost_offset -
+            decodable->LogLikelihood(frame, arc.ilabel) + best_tok->tot_cost;
+        if (new_weight + adaptive_beam < next_cutoff)
+          next_cutoff = new_weight + adaptive_beam;
+      }
+    }
+  }
+
+  // Store the offset on the acoustic likelihoods that we're applying.
+  // Could just do cost_offsets_.push_back(cost_offset), but we
+  // do it this way as it's more robust to future code changes.
+  cost_offsets_.resize(frame + 1, 0.0);
+  cost_offsets_[frame] = cost_offset;
+
+  // Build a queue which contains the emitting tokens from the previous frame.
+  for (const Elem *e = prev_toks_->GetList(); e != NULL; e = e->tail) {
+    cur_queue_.push(e->key);
+    e->val->in_current_queue = true;
+  }
+
+  // Iterate over "cur_queue_" to process both the non-emitting and the
+  // emitting arcs in the FST.
+  while (!cur_queue_.empty()) {
+    StateId state = cur_queue_.front();
+    cur_queue_.pop();
+
+    Token *tok = prev_toks_->Find(state)->val;
+
+    BaseFloat cur_cost = tok->tot_cost;
+    tok->in_current_queue = false;  // out of queue
+    if (cur_cost > cur_cutoff)  // Don't bother processing successors.
+      continue;
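+    // Worked example with made-up numbers: if cur_cost = 5.0, cost_offset =
+    // -4.0 (i.e. the best emitting token of the previous frame had tot_cost =
+    // 4.0), and an emitting arc has graph cost 1.0 on a frame whose
+    // log-likelihood for that arc is 8.0, then below tot_cost =
+    // 5.0 + (-4.0 - 8.0) + 1.0 = -6.0, and if -6.0 + adaptive_beam <
+    // next_cutoff, then next_cutoff is tightened to -6.0 + adaptive_beam.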
+    // If "tok" has any existing forward links, delete them,
+    // because we're about to regenerate them.  This is a kind
+    // of non-optimality (remember, this is the simple decoder).
+    DeleteForwardLinks(tok);  // necessary when re-visiting
+    tok->links = NULL;
+    for (fst::ArcIterator<FST> aiter(*fst_, state);
+         !aiter.Done();
+         aiter.Next()) {
+      const Arc &arc = aiter.Value();
+      bool changed;
+      if (arc.ilabel == 0) {  // propagate non-emitting
+        BaseFloat graph_cost = arc.weight.Value();
+        BaseFloat tot_cost = cur_cost + graph_cost;
+        if (tot_cost < cur_cutoff) {
+          Token *new_tok = FindOrAddToken(arc.nextstate, frame, tot_cost,
+                                          tok, prev_toks_, &changed);
+
+          // Add ForwardLink from tok to new_tok.  Put it on the head of
+          // tok->link list
+          tok->links = new ForwardLinkT(new_tok, 0, arc.olabel,
+                                        graph_cost, 0, tok->links);
+
+          // "changed" tells us whether the new token has a different
+          // cost from before, or is new.
+          if (changed && !new_tok->in_current_queue) {
+            cur_queue_.push(arc.nextstate);
+            new_tok->in_current_queue = true;
+          }
+        }
+      } else {  // propagate emitting
+        BaseFloat graph_cost = arc.weight.Value(),
+                  ac_cost = cost_offset - decodable->LogLikelihood(frame, arc.ilabel),
+                  cur_cost = tok->tot_cost,
+                  tot_cost = cur_cost + ac_cost + graph_cost;
+        if (tot_cost > next_cutoff) continue;
+        else if (tot_cost + adaptive_beam < next_cutoff)
+          next_cutoff = tot_cost + adaptive_beam;  // a tighter boundary for emitting
+
+        // no change flag is needed
+        Token *next_tok = FindOrAddToken(arc.nextstate, frame + 1, tot_cost,
+                                         tok, cur_toks_, NULL);
+        // Add ForwardLink from tok to next_tok.  Put it on the head of
+        // tok->link list
+        tok->links = new ForwardLinkT(next_tok, arc.ilabel, arc.olabel,
+                                      graph_cost, ac_cost, tok->links);
+      }
+    }  // for all arcs
+  }  // end of while loop
+}
+
+
+template <typename FST, typename Token>
+void LatticeFasterDecoderCombineTpl<FST, Token>::ProcessNonemitting(
+    std::unordered_map<Token*, BaseFloat> *token_orig_cost) {
+  if (token_orig_cost) {  // Build the elements which are used to recover
+    for (const Elem *e = cur_toks_->GetList(); e != NULL; e = e->tail) {
+      (*token_orig_cost)[e->val] = e->val->tot_cost;
+    }
+  }
+
+  StateIdToTokenMap *tmp_toks = cur_toks_;
+
+  int32 frame = active_toks_.size() - 1;
+  // Build the queue to process non-emitting arcs.
+  for (const Elem *e = tmp_toks->GetList(); e != NULL; e = e->tail) {
+    if (fst_->NumInputEpsilons(e->key) != 0) {
+      cur_queue_.push(e->key);
+      e->val->in_current_queue = true;
+    }
+  }
+
+  // "cur_cutoff" is used to constrain the epsilon emission on the current
+  // frame.  It will not be updated.
+  BaseFloat adaptive_beam;
+  BaseFloat cur_cutoff = GetCutoff(tmp_toks->GetList(), NULL, &adaptive_beam, NULL);
+
+  while (!cur_queue_.empty()) {
+    StateId state = cur_queue_.front();
+    cur_queue_.pop();
+
+    Token *tok = tmp_toks->Find(state)->val;
+    BaseFloat cur_cost = tok->tot_cost;
+    if (cur_cost > cur_cutoff)  // Don't bother processing successors.
+      continue;
+    // If "tok" has any existing forward links, delete them,
+    // because we're about to regenerate them.  This is a kind
+    // of non-optimality (remember, this is the simple decoder).
+    DeleteForwardLinks(tok);  // necessary when re-visiting
+    tok->links = NULL;
+    for (fst::ArcIterator<FST> aiter(*fst_, state);
+         !aiter.Done();
+         aiter.Next()) {
+      const Arc &arc = aiter.Value();
+      bool changed;
+      if (arc.ilabel == 0) {  // propagate non-emitting
+        BaseFloat graph_cost = arc.weight.Value();
+        BaseFloat tot_cost = cur_cost + graph_cost;
+        if (tot_cost < cur_cutoff) {
+          Token *new_tok = FindOrAddToken(arc.nextstate, frame, tot_cost,
+                                          tok, tmp_toks, &changed);
+
+          // Add ForwardLink from tok to new_tok.
Put it on the head of + // tok->link list + tok->links = new ForwardLinkT(new_tok, 0, arc.olabel, + graph_cost, 0, tok->links); + + // "changed" tells us whether the new token has a different + // cost from before, or is new. + if (changed && !new_tok->in_current_queue) { + cur_queue_.push(arc.nextstate); + new_tok->in_current_queue = true; + } + } + } + } // end of for loop + tok->in_current_queue = false; + } // end of while loop +} + + + +// static inline +template +void LatticeFasterDecoderCombineTpl::DeleteForwardLinks(Token *tok) { + ForwardLinkT *l = tok->links, *m; + while (l != NULL) { + m = l->next; + delete l; + l = m; + } + tok->links = NULL; +} + + +template +void LatticeFasterDecoderCombineTpl::ClearActiveTokens() { + // a cleanup routine, at utt end/begin + for (size_t i = 0; i < active_toks_.size(); i++) { + // Delete all tokens alive on this frame, and any forward + // links they may have. + for (Token *tok = active_toks_[i].toks; tok != NULL; ) { + DeleteForwardLinks(tok); + Token *next_tok = tok->next; + delete tok; + num_toks_--; + tok = next_tok; + } + } + active_toks_.clear(); + KALDI_ASSERT(num_toks_ == 0); +} + +// static +template +void LatticeFasterDecoderCombineTpl::TopSortTokens( + Token *tok_list, std::vector *topsorted_list) { + unordered_map token2pos; + typedef typename unordered_map::iterator IterType; + int32 num_toks = 0; + for (Token *tok = tok_list; tok != NULL; tok = tok->next) + num_toks++; + int32 cur_pos = 0; + // We assign the tokens numbers num_toks - 1, ... , 2, 1, 0. + // This is likely to be in closer to topological order than + // if we had given them ascending order, because of the way + // new tokens are put at the front of the list. + for (Token *tok = tok_list; tok != NULL; tok = tok->next) + token2pos[tok] = num_toks - ++cur_pos; + + unordered_set reprocess; + + for (IterType iter = token2pos.begin(); iter != token2pos.end(); ++iter) { + Token *tok = iter->first; + int32 pos = iter->second; + for (ForwardLinkT *link = tok->links; link != NULL; link = link->next) { + if (link->ilabel == 0) { + // We only need to consider epsilon links, since non-epsilon links + // transition between frames and this function only needs to sort a list + // of tokens from a single frame. + IterType following_iter = token2pos.find(link->next_tok); + if (following_iter != token2pos.end()) { // another token on this frame, + // so must consider it. + int32 next_pos = following_iter->second; + if (next_pos < pos) { // reassign the position of the next Token. + following_iter->second = cur_pos++; + reprocess.insert(link->next_tok); + } + } + } + } + // In case we had previously assigned this token to be reprocessed, we can + // erase it from that set because it's "happy now" (we just processed it). + reprocess.erase(tok); + } + + size_t max_loop = 1000000, loop_count; // max_loop is to detect epsilon cycles. + for (loop_count = 0; + !reprocess.empty() && loop_count < max_loop; ++loop_count) { + std::vector reprocess_vec; + for (typename unordered_set::iterator iter = reprocess.begin(); + iter != reprocess.end(); ++iter) + reprocess_vec.push_back(*iter); + reprocess.clear(); + for (typename std::vector::iterator iter = reprocess_vec.begin(); + iter != reprocess_vec.end(); ++iter) { + Token *tok = *iter; + int32 pos = token2pos[tok]; + // Repeat the processing we did above (for comments, see above). 
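+      // For example, if token A at position 5 has an epsilon link to token B
+      // at position 3, B is moved to a fresh position after A (cur_pos++) and
+      // queued for reprocessing, so that eventually every epsilon link points
+      // "forward" in the ordering.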
+ for (ForwardLinkT *link = tok->links; link != NULL; link = link->next) { + if (link->ilabel == 0) { + IterType following_iter = token2pos.find(link->next_tok); + if (following_iter != token2pos.end()) { + int32 next_pos = following_iter->second; + if (next_pos < pos) { + following_iter->second = cur_pos++; + reprocess.insert(link->next_tok); + } + } + } + } + } + } + KALDI_ASSERT(loop_count < max_loop && "Epsilon loops exist in your decoding " + "graph (this is not allowed!)"); + + topsorted_list->clear(); + topsorted_list->resize(cur_pos, NULL); // create a list with NULLs in between. + for (IterType iter = token2pos.begin(); iter != token2pos.end(); ++iter) + (*topsorted_list)[iter->second] = iter->first; +} + +// Instantiate the template for the combination of token types and FST types +// that we'll need. +template class LatticeFasterDecoderCombineTpl, decodercombine::StdToken>; +template class LatticeFasterDecoderCombineTpl, decodercombine::StdToken >; +template class LatticeFasterDecoderCombineTpl, decodercombine::StdToken >; +template class LatticeFasterDecoderCombineTpl; + +template class LatticeFasterDecoderCombineTpl , decodercombine::BackpointerToken>; +template class LatticeFasterDecoderCombineTpl, decodercombine::BackpointerToken >; +template class LatticeFasterDecoderCombineTpl, decodercombine::BackpointerToken >; +template class LatticeFasterDecoderCombineTpl; + + +} // end namespace kaldi. diff --git a/src/decoder/lattice-faster-decoder-combine-hashlist.h b/src/decoder/lattice-faster-decoder-combine-hashlist.h new file mode 100644 index 00000000000..ca67cf4c531 --- /dev/null +++ b/src/decoder/lattice-faster-decoder-combine-hashlist.h @@ -0,0 +1,567 @@ +// decoder/lattice-faster-decoder-combine.h + +// Copyright 2009-2013 Microsoft Corporation; Mirko Hannemann; +// 2013-2019 Johns Hopkins University (Author: Daniel Povey) +// 2014 Guoguo Chen +// 2018 Zhehuai Chen +// 2019 Hang Lyu + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef KALDI_DECODER_LATTICE_FASTER_DECODER_COMBINE_H_ +#define KALDI_DECODER_LATTICE_FASTER_DECODER_COMBINE_H_ + + +#include "util/stl-utils.h" +#include "util/hash-list.h" +#include "fst/fstlib.h" +#include "itf/decodable-itf.h" +#include "fstext/fstext-lib.h" +#include "lat/determinize-lattice-pruned.h" +#include "lat/kaldi-lattice.h" +#include "decoder/grammar-fst.h" +#include "decoder/lattice-faster-decoder.h" + +namespace kaldi { + +struct LatticeFasterDecoderCombineConfig { + BaseFloat beam; + int32 max_active; + int32 min_active; + BaseFloat lattice_beam; + int32 prune_interval; + bool determinize_lattice; // not inspected by this class... used in + // command-line program. + BaseFloat beam_delta; // has nothing to do with beam_ratio + BaseFloat hash_ratio; + BaseFloat prune_scale; // Note: we don't make this configurable on the command line, + // it's not a very important parameter. 
It affects the + // algorithm that prunes the tokens as we go. + // Most of the options inside det_opts are not actually queried by the + // LatticeFasterDecoder class itself, but by the code that calls it, for + // example in the function DecodeUtteranceLatticeFaster. + fst::DeterminizeLatticePhonePrunedOptions det_opts; + + LatticeFasterDecoderCombineConfig(): beam(16.0), + max_active(std::numeric_limits::max()), + min_active(200), + lattice_beam(10.0), + prune_interval(25), + determinize_lattice(true), + beam_delta(0.5), + hash_ratio(2.0), + prune_scale(0.1) { } + void Register(OptionsItf *opts) { + det_opts.Register(opts); + opts->Register("beam", &beam, "Decoding beam. Larger->slower, more accurate."); + opts->Register("max-active", &max_active, "Decoder max active states. Larger->slower; " + "more accurate"); + opts->Register("min-active", &min_active, "Decoder minimum #active states."); + opts->Register("lattice-beam", &lattice_beam, "Lattice generation beam. Larger->slower, " + "and deeper lattices"); + opts->Register("prune-interval", &prune_interval, "Interval (in frames) at " + "which to prune tokens"); + opts->Register("determinize-lattice", &determinize_lattice, "If true, " + "determinize the lattice (lattice-determinization, keeping only " + "best pdf-sequence for each word-sequence)."); + opts->Register("beam-delta", &beam_delta, "Increment used in decoding-- this " + "parameter is obscure and relates to a speedup in the way the " + "max-active constraint is applied. Larger is more accurate."); + opts->Register("hash-ratio", &hash_ratio, "Setting used in decoder to " + "control hash behavior"); + } + void Check() const { + KALDI_ASSERT(beam > 0.0 && max_active > 1 && lattice_beam > 0.0 + && min_active <= max_active + && prune_interval > 0 && beam_delta > 0.0 && hash_ratio >= 1.0 + && prune_scale > 0.0 && prune_scale < 1.0); + } +}; + + +namespace decodercombine { +// We will template the decoder on the token type as well as the FST type; this +// is a mechanism so that we can use the same underlying decoder code for +// versions of the decoder that support quickly getting the best path +// (LatticeFasterOnlineDecoder, see lattice-faster-online-decoder.h) and also +// those that do not (LatticeFasterDecoder). + + +// ForwardLinks are the links from a token to a token on the next frame. +// or sometimes on the current frame (for input-epsilon links). +template +struct ForwardLink { + using Label = fst::StdArc::Label; + + Token *next_tok; // the next token [or NULL if represents final-state] + Label ilabel; // ilabel on arc + Label olabel; // olabel on arc + BaseFloat graph_cost; // graph cost of traversing arc (contains LM, etc.) + BaseFloat acoustic_cost; // acoustic cost (pre-scaled) of traversing arc + ForwardLink *next; // next in singly-linked list of forward arcs (arcs + // in the state-level lattice) from a token. + inline ForwardLink(Token *next_tok, Label ilabel, Label olabel, + BaseFloat graph_cost, BaseFloat acoustic_cost, + ForwardLink *next): + next_tok(next_tok), ilabel(ilabel), olabel(olabel), + graph_cost(graph_cost), acoustic_cost(acoustic_cost), + next(next) { } +}; + + +struct StdToken { + using ForwardLinkT = ForwardLink; + using Token = StdToken; + + // Standard token type for LatticeFasterDecoder. Each active HCLG + // (decoding-graph) state on each frame has one token. + + // tot_cost is the total (LM + acoustic) cost from the beginning of the + // utterance up to this point. 
(but see cost_offset_, which is subtracted
+  // to keep it in a good numerical range).
+  BaseFloat tot_cost;
+
+  // extra_cost is >= 0.  After calling PruneForwardLinks, this equals the
+  // minimum difference between the cost of the best path that this token is
+  // on, and the cost of the absolute best path, under the assumption that any
+  // of the currently active states at the decoding front may eventually
+  // succeed (e.g. if you were to take the currently active states one by one
+  // and compute this difference, and then take the minimum).
+  BaseFloat extra_cost;
+
+  // 'links' is the head of a singly-linked list of ForwardLinks, which is
+  // what we use for lattice generation.
+  ForwardLinkT *links;
+
+  // 'next' is the next in the singly-linked list of tokens for this frame.
+  Token *next;
+
+  // Identifies whether the token is in the current queue or not, to prevent
+  // duplicates in the function ProcessForFrame().
+  bool in_current_queue;
+
+  // This function does nothing and should be optimized out; it's needed
+  // so we can share the regular LatticeFasterDecoderTpl code and the code
+  // for LatticeFasterOnlineDecoder that supports fast traceback.
+  inline void SetBackpointer (Token *backpointer) { }
+
+  // This constructor just ignores the 'backpointer' argument.  That argument is
+  // needed so that we can use the same decoder code for LatticeFasterDecoderTpl
+  // and LatticeFasterOnlineDecoderTpl (which needs backpointers to support a
+  // fast way to obtain the best path).
+  inline StdToken(BaseFloat tot_cost, BaseFloat extra_cost, ForwardLinkT *links,
+                  Token *next, Token *backpointer):
+      tot_cost(tot_cost), extra_cost(extra_cost), links(links), next(next),
+      in_current_queue(false) { }
+};
+
+struct BackpointerToken {
+  using ForwardLinkT = ForwardLink<BackpointerToken>;
+  using Token = BackpointerToken;
+
+  // BackpointerToken is like StdToken, but also stores a backpointer to the
+  // best preceding token.  Each active HCLG (decoding-graph) state on each
+  // frame has one token.
+
+  // tot_cost is the total (LM + acoustic) cost from the beginning of the
+  // utterance up to this point.  (but see cost_offset_, which is subtracted
+  // to keep it in a good numerical range).
+  BaseFloat tot_cost;
+
+  // extra_cost is >= 0.  After calling PruneForwardLinks, this equals the
+  // minimum difference between the cost of the best path that this token is
+  // on, and the cost of the absolute best path, under the assumption that any
+  // of the currently active states at the decoding front may eventually
+  // succeed (e.g. if you were to take the currently active states one by one
+  // and compute this difference, and then take the minimum).
+  BaseFloat extra_cost;
+
+  // 'links' is the head of a singly-linked list of ForwardLinks, which is
+  // what we use for lattice generation.
+  ForwardLinkT *links;
+
+  // 'next' is the next in the singly-linked list of tokens for this frame.
+  BackpointerToken *next;
+
+  // Best preceding BackpointerToken (could be a token on this frame, connected
+  // to this one via an epsilon transition, or a token on a previous frame).
+  // This is only required for an efficient GetBestPath function in
+  // LatticeFasterOnlineDecoderTpl; it plays no part in the lattice generation
+  // (the "links" list is what stores the forward links, for that).
+  Token *backpointer;
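+  // (Following these backpointers from the best token on the last frame is
+  // what lets LatticeFasterOnlineDecoderTpl recover the current best path
+  // quickly, without constructing the full lattice first.)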
+  // Identifies whether the token is in the current queue or not, to prevent
+  // duplicates in the function ProcessForFrame().
+  bool in_current_queue;
+
+  inline void SetBackpointer (Token *backpointer) {
+    this->backpointer = backpointer;
+  }
+
+  inline BackpointerToken(BaseFloat tot_cost, BaseFloat extra_cost, ForwardLinkT *links,
+                          Token *next, Token *backpointer):
+      tot_cost(tot_cost), extra_cost(extra_cost), links(links), next(next),
+      backpointer(backpointer), in_current_queue(false) { }
+};
+
+}  // namespace decodercombine
+
+
+/** This is the "normal" lattice-generating decoder.
+    See \ref lattices_generation \ref decoders_faster and \ref decoders_simple
+    for more information.
+
+    The decoder is templated on the FST type and the token type.  The token
+    type will normally be StdToken, but may also be BackpointerToken, which is
+    there to support quick lookup of the current best path (see
+    lattice-faster-online-decoder.h).
+
+    The FST you invoke this decoder with is expected to be of type
+    fst::Fst<fst::StdArc>, a.k.a. StdFst, or GrammarFst.  If you invoke it with
+    FST == StdFst and it notices that the actual FST type is
+    fst::VectorFst<fst::StdArc> or fst::ConstFst<fst::StdArc>, the decoder
+    object will internally cast itself to one that is templated on those more
+    specific types; this is an optimization for speed.
+ */
+template <typename FST, typename Token = decodercombine::StdToken>
+class LatticeFasterDecoderCombineTpl {
+ public:
+  using Arc = typename FST::Arc;
+  using Label = typename Arc::Label;
+  using StateId = typename Arc::StateId;
+  using Weight = typename Arc::Weight;
+  using ForwardLinkT = decodercombine::ForwardLink<Token>;
+
+  using StateIdToTokenMap = HashList<StateId, Token*>;
+  using Elem = typename HashList<StateId, Token*>::Elem;
+  //using StateIdToTokenMap = typename std::unordered_map<StateId, Token*>;
+  //using StateIdToTokenMap = typename std::unordered_map<StateId, Token*,
+  //    std::hash<StateId>, std::equal_to<StateId>,
+  //    fst::PoolAllocator<std::pair<const StateId, Token*> > >;
+  //using IterType = typename StateIdToTokenMap::const_iterator;
+
+  // Instantiate this class once for each thing you have to decode.
+  // This version of the constructor does not take ownership of
+  // 'fst'.
+  LatticeFasterDecoderCombineTpl(const FST &fst,
+                                 const LatticeFasterDecoderCombineConfig &config);
+
+  // This version of the constructor takes ownership of the fst, and will delete
+  // it when this object is destroyed.
+  LatticeFasterDecoderCombineTpl(const LatticeFasterDecoderCombineConfig &config,
+                                 FST *fst);
+
+  void SetOptions(const LatticeFasterDecoderCombineConfig &config) {
+    config_ = config;
+  }
+
+  const LatticeFasterDecoderCombineConfig &GetOptions() const {
+    return config_;
+  }
+
+  ~LatticeFasterDecoderCombineTpl();
+
+  /// Decodes until there are no more frames left in the "decodable" object.
+  /// Note, this may block waiting for input if the "decodable" object blocks.
+  /// Returns true if any kind of traceback is available (not necessarily from
+  /// a final state).
+  bool Decode(DecodableInterface *decodable);
+
+
+  /// Says whether a final-state was active on the last frame.  If it was not,
+  /// the lattice (or traceback) will end with states that are not final-states.
+  bool ReachedFinal() const {
+    return FinalRelativeCost() != std::numeric_limits<BaseFloat>::infinity();
+  }
+
+  /// Outputs an FST corresponding to the single best path through the lattice.
+  /// Returns true if the result is nonempty (using the return status is
+  /// deprecated, it will become void).  If "use_final_probs" is true AND we
+  /// reached the final-state of the graph then it will include those as
+  /// final-probs, else it will treat all final-probs as one.  Note: this just
+  /// calls GetRawLattice() and figures out the shortest path.
+ bool GetBestPath(Lattice *ofst, + bool use_final_probs = true); + + /// Outputs an FST corresponding to the raw, state-level + /// tracebacks. Returns true if result is nonempty. + /// If "use_final_probs" is true AND we reached the final-state + /// of the graph then it will include those as final-probs, else + /// it will treat all final-probs as one. + /// The raw lattice will be topologically sorted. + /// The function can be called during decoding, it will process non-emitting + /// arcs from "cur_toks_" map to get tokens from both non-emitting and + /// emitting arcs for getting raw lattice. Then recover it to ensure the + /// consistency of ProcessForFrame(). + /// + /// See also GetRawLatticePruned in lattice-faster-online-decoder.h, + /// which also supports a pruning beam, in case for some reason + /// you want it pruned tighter than the regular lattice beam. + /// We could put that here in future needed. + bool GetRawLattice(Lattice *ofst, bool use_final_probs = true); + + + + /// [Deprecated, users should now use GetRawLattice and determinize it + /// themselves, e.g. using DeterminizeLatticePhonePrunedWrapper]. + /// Outputs an FST corresponding to the lattice-determinized + /// lattice (one path per word sequence). Returns true if result is nonempty. + /// If "use_final_probs" is true AND we reached the final-state of the graph + /// then it will include those as final-probs, else it will treat all + /// final-probs as one. + bool GetLattice(CompactLattice *ofst, + bool use_final_probs = true); + + /// InitDecoding initializes the decoding, and should only be used if you + /// intend to call AdvanceDecoding(). If you call Decode(), you don't need to + /// call this. You can also call InitDecoding if you have already decoded an + /// utterance and want to start with a new utterance. + void InitDecoding(); + + /// This will decode until there are no more frames ready in the decodable + /// object. You can keep calling it each time more frames become available. + /// If max_num_frames is specified, it specifies the maximum number of frames + /// the function will decode before returning. + void AdvanceDecoding(DecodableInterface *decodable, + int32 max_num_frames = -1); + + /// This function may be optionally called after AdvanceDecoding(), when you + /// do not plan to decode any further. It does an extra pruning step that + /// will help to prune the lattices output by GetLattice and (particularly) + /// GetRawLattice more accurately, particularly toward the end of the + /// utterance. It does this by using the final-probs in pruning (if any + /// final-state survived); it also does a final pruning step that visits all + /// states (the pruning that is done during decoding may fail to prune states + /// that are within kPruningScale = 0.1 outside of the beam). If you call + /// this, you cannot call AdvanceDecoding again (it will fail), and you + /// cannot call GetLattice() and related functions with use_final_probs = + /// false. + /// Used to be called PruneActiveTokensFinal(). + void FinalizeDecoding(); + + /// FinalRelativeCost() serves the same purpose as ReachedFinal(), but gives + /// more information. It returns the difference between the best (final-cost + /// plus cost) of any token on the final frame, and the best cost of any token + /// on the final frame. If it is infinity it means no final-states were + /// present on the final frame. It will usually be nonnegative. If it not + /// too positive (e.g. 
< 5 is my first guess, but this is not tested) you can
+  /// take it as a good indication that we reached the final-state with
+  /// reasonable likelihood.
+  BaseFloat FinalRelativeCost() const;
+
+
+  // Returns the number of frames decoded so far.  The value returned changes
+  // whenever we call ProcessForFrame().
+  inline int32 NumFramesDecoded() const { return active_toks_.size() - 1; }
+
+ protected:
+  // We make things protected instead of private, as code in
+  // LatticeFasterOnlineDecoderTpl, which inherits from this, also uses the
+  // internals.
+
+  // Deletes the elements of the singly linked list tok->links.
+  inline static void DeleteForwardLinks(Token *tok);
+
+  // Head of the per-frame list of Tokens (list is in topological order),
+  // and something saying whether we ever pruned it using PruneForwardLinks.
+  struct TokenList {
+    Token *toks;
+    bool must_prune_forward_links;
+    bool must_prune_tokens;
+    TokenList(): toks(NULL), must_prune_forward_links(true),
+                 must_prune_tokens(true) { }
+  };
+
+  // FindOrAddToken either locates a token in the hash map "token_map", or, if
+  // necessary, inserts a new, empty token (i.e. with no forward links) for the
+  // current frame.  [note: it's inserted if necessary into the hash map and
+  // also into the singly linked list of tokens active on this frame (whose
+  // head is at active_toks_[frame]).  The frame_plus_one argument is the
+  // acoustic frame index plus one, which is used to index into the
+  // active_toks_ array.  Returns the Token pointer.  Sets "changed" (if
+  // non-NULL) to true if the token was newly created or the cost changed.
+  // If Token == StdToken, the 'backpointer' argument has no purpose (and will
+  // hopefully be optimized out).
+  inline Token *FindOrAddToken(StateId state, int32 token_list_index,
+                               BaseFloat tot_cost, Token *backpointer,
+                               StateIdToTokenMap *token_map,
+                               bool *changed);
+
+  // Prunes outgoing links for all tokens in active_toks_[frame];
+  // it's called by PruneActiveTokens.
+  // All links that have link_extra_cost > lattice_beam are pruned.
+  // delta is the amount by which the extra_costs must change
+  // before we set *extra_costs_changed = true.
+  // If delta is larger, we'll tend to go back less far
+  // toward the beginning of the file.
+  // extra_costs_changed is set to true if extra_cost was changed for any token.
+  // links_pruned is set to true if any link in any token was pruned.
+  void PruneForwardLinks(int32 frame_plus_one, bool *extra_costs_changed,
+                         bool *links_pruned,
+                         BaseFloat delta);
+
+  // This function computes the final-costs for tokens active on the final
+  // frame.  It outputs to final-costs, if non-NULL, a map from the Token*
+  // pointer to the final-prob of the corresponding state, for all Tokens
+  // that correspond to states that have final-probs.  This map will be
+  // empty if there were no final-probs.  It outputs to
+  // final_relative_cost, if non-NULL, the difference between the best
+  // forward-cost including the final-prob cost, and the best forward-cost
+  // without including the final-prob cost (this will usually be positive), or
+  // infinity if there were no final-probs.  [c.f. FinalRelativeCost(), which
+  // outputs this quantity].  It outputs to final_best_cost, if
+  // non-NULL, the lowest for any token t active on the final frame, of
+  // forward-cost[t] + final-cost[t], where final-cost[t] is the final-cost in
+  // the graph of the state corresponding to token t, or the best of
+  // forward-cost[t] if there were no final-probs active on the final frame.
+  // You cannot call this after FinalizeDecoding() has been called; in that
+  // case you should get the answer from class-member variables.
+  void ComputeFinalCosts(unordered_map<Token*, BaseFloat> *final_costs,
+                         BaseFloat *final_relative_cost,
+                         BaseFloat *final_best_cost) const;
+
+  // PruneForwardLinksFinal is a version of PruneForwardLinks that we call
+  // on the final frame.  If there are final tokens active, it uses
+  // the final-probs for pruning, otherwise it treats all tokens as final.
+  void PruneForwardLinksFinal();
+
+  // Prune away any tokens on this frame that have no forward links.
+  // [we don't do this in PruneForwardLinks because it would give us
+  // a problem with dangling pointers].
+  // It's called by PruneActiveTokens if any forward links have been pruned.
+  void PruneTokensForFrame(int32 frame_plus_one);
+
+
+  // Go backwards through still-alive tokens, pruning them if the
+  // forward+backward cost is more than lat_beam away from the best path.  It's
+  // possible to prove that this is "correct" in the sense that we won't lose
+  // anything outside of lat_beam, regardless of what happens in the future.
+  // delta controls when it considers a cost to have changed enough to continue
+  // going backward and propagating the change.  larger delta -> will recurse
+  // less far.
+  void PruneActiveTokens(BaseFloat delta);
+
+  /// Processes the non-emitting (epsilon) arcs and the emitting arcs for one
+  /// frame together.  It takes the emitting tokens in "prev_toks_" from the
+  /// last frame, and generates the non-emitting tokens for the previous frame
+  /// and the emitting tokens for the next frame.
+  /// Notice: the emitting tokens for the current frame are the tokens that
+  /// take the acoustic scores of the current frame (i.e. the destinations of
+  /// the emitting arcs).
+  void ProcessForFrame(DecodableInterface *decodable);
+
+  /// Processes the non-emitting (epsilon) arcs for one frame.  This function
+  /// is called once after all frames have been processed, or from
+  /// GetRawLattice() to generate the complete token list for the last frame.
+  /// [It deals with the tokens in the map "cur_toks_", which at that point
+  /// contains only the emitting tokens from the previous frame.]
+  /// If the map "token_orig_cost" isn't NULL, we build the map which will
+  /// be used to recover the "active_toks_[last_frame]" token list for the
+  /// last frame.
+  void ProcessNonemitting(std::unordered_map<Token*, BaseFloat> *token_orig_cost);
+
+  /// When GetRawLattice() is called during decoding, the
+  /// active_toks_[last_frame] list is changed.  To keep the consistency of the
+  /// function ProcessForFrame(), we recover it.
+  /// Notice: as new tokens are added to the head of the TokenList, tok->next
+  /// will not be affected.
+  /// "token_orig_cost" is a mapping from token pointer to the tot_cost of the
+  /// token before propagating the non-emitting arcs.  It is used to recover
+  /// the change of the original tokens in the last frame and to remove the new
+  /// tokens which come from propagating the non-emitting arcs, so that we can
+  /// guarantee the consistency of the function ProcessForFrame().
+  void RecoverLastTokenList(
+      const std::unordered_map<Token*, BaseFloat> &token_orig_cost);
+
+
+  /// The "prev_toks_" and "cur_toks_" maps allow us to maintain the token
+  /// lists of the current and the next frame.  They are indexed by StateId.
+  /// The per-frame token lists in "active_toks_" are indexed by frame-index
+  /// plus one, where the frame-index is zero-based, as used in the decodable
+  /// object.  That is, the emitting probs of frame t are accounted for in the
+  /// tokens at active_toks_[t+1].  The zeroth frame is for the non-emitting
+  /// transitions at the start of the graph.
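+  /// For example, after ProcessForFrame() has handled frame 0, the tokens that
+  /// absorbed the acoustic scores of frame 0 live in "cur_toks_" and in
+  /// active_toks_[1], while "prev_toks_" holds the frame-0 list that was just
+  /// expanded.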
+  StateIdToTokenMap *prev_toks_;
+  StateIdToTokenMap *cur_toks_;
+
+  void PossiblyResizeHash(size_t num_toks);
+
+  /// Gets the weight cutoff.
+  /// Notice: in the traditional version, the histogram-pruning method is
+  /// applied on the complete token list of one frame.  In this version it is
+  /// applied on a token list which contains only the emitting part, so the
+  /// effective max_active and min_active constraints might be narrowed.
+  BaseFloat GetCutoff(const Elem *list_head, size_t *tok_count,
+                      BaseFloat *adaptive_beam, Elem **best_elem);
+
+
+  std::vector<TokenList> active_toks_;  // Lists of tokens, indexed by
+  // frame (members of TokenList are toks, must_prune_forward_links,
+  // must_prune_tokens).
+  std::queue<StateId> cur_queue_;  // temp variable used in ProcessForFrame
+                                   // and ProcessNonemitting
+  std::vector<BaseFloat> tmp_array_;  // used in GetCutoff.
+
+  // fst_ is a pointer to the FST we are decoding from.
+  const FST *fst_;
+  // delete_fst_ is true if the pointer fst_ needs to be deleted when this
+  // object is destroyed.
+  bool delete_fst_;
+
+  std::vector<BaseFloat> cost_offsets_;  // This contains, for each
+  // frame, an offset that was added to the acoustic log-likelihoods on that
+  // frame in order to keep everything in a nice dynamic range, i.e. close to
+  // zero, to reduce roundoff errors.
+  // Notice: it will only be added to emitting arcs (i.e. cost_offsets_[t] is
+  // added to arcs from "frame t" to "frame t+1").
+  LatticeFasterDecoderCombineConfig config_;
+  int32 num_toks_;  // current total #toks allocated...
+  bool warned_;
+
+  /// decoding_finalized_ is true if someone called FinalizeDecoding().  [note,
+  /// calling this is optional].  If true, it's forbidden to decode more.  Also,
+  /// if this is set, then the output of ComputeFinalCosts() is in the next
+  /// three variables.  The reason we need to do this is that after
+  /// FinalizeDecoding() calls PruneTokensForFrame() for the final frame, some
+  /// of the tokens on the last frame are freed, so we free the list from
+  /// cur_toks_ to avoid having dangling pointers hanging around.
+  bool decoding_finalized_;
+  /// For the meaning of the next 3 variables, see the comment for
+  /// decoding_finalized_ above, and ComputeFinalCosts().
+  unordered_map<Token*, BaseFloat> final_costs_;
+  BaseFloat final_relative_cost_;
+  BaseFloat final_best_cost_;
+
+  // This function takes a singly linked list of tokens for a single frame, and
+  // outputs a list of them in topological order (it will crash if no such order
+  // can be found, which will typically be due to decoding graphs with epsilon
+  // cycles, which are not allowed).  Note: the output list may contain NULLs,
+  // which the caller should pass over; it just happens to be more efficient for
+  // the algorithm to output a list that contains NULLs.
+  static void TopSortTokens(Token *tok_list,
+                            std::vector<Token*> *topsorted_list);
+
+  void DeleteElems(Elem *list, HashList<StateId, Token*> *toks);
+  void ClearActiveTokens();
+
+  KALDI_DISALLOW_COPY_AND_ASSIGN(LatticeFasterDecoderCombineTpl);
+};
+
+typedef LatticeFasterDecoderCombineTpl<fst::StdFst, decodercombine::StdToken>
+    LatticeFasterDecoderCombine;
+
+
+
+}  // end namespace kaldi.
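+// Usage sketch (illustrative only -- "decodable" stands for any
+// DecodableInterface implementation, e.g. a DecodableAmDiagGmmScaled, and
+// "hclg_fst" and the config values are placeholders):
+//
+//   kaldi::LatticeFasterDecoderCombineConfig config;
+//   config.beam = 13.0;
+//   config.lattice_beam = 6.0;
+//   kaldi::LatticeFasterDecoderCombine decoder(*hclg_fst, config);
+//   decoder.Decode(&decodable);     // or InitDecoding() + AdvanceDecoding()
+//   decoder.FinalizeDecoding();     // optional extra final-prob pruning
+//   if (decoder.ReachedFinal()) {
+//     kaldi::Lattice raw_lat;
+//     decoder.GetRawLattice(&raw_lat, true);  // topologically sorted
+//   }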
+ +#endif diff --git a/src/decoder/lattice-faster-decoder-combine-heap.h b/src/decoder/lattice-faster-decoder-combine-heap.h new file mode 100644 index 00000000000..48719aa347c --- /dev/null +++ b/src/decoder/lattice-faster-decoder-combine-heap.h @@ -0,0 +1,697 @@ +// decoder/lattice-faster-decoder-combine.h + +// Copyright 2009-2013 Microsoft Corporation; Mirko Hannemann; +// 2013-2019 Johns Hopkins University (Author: Daniel Povey) +// 2014 Guoguo Chen +// 2018 Zhehuai Chen +// 2019 Hang Lyu + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef KALDI_DECODER_LATTICE_FASTER_DECODER_COMBINE_H_ +#define KALDI_DECODER_LATTICE_FASTER_DECODER_COMBINE_H_ + + +#include "util/stl-utils.h" +#include "fst/fstlib.h" +#include "itf/decodable-itf.h" +#include "fstext/fstext-lib.h" +#include "lat/determinize-lattice-pruned.h" +#include "lat/kaldi-lattice.h" +#include "decoder/grammar-fst.h" +#include "decoder/lattice-faster-decoder.h" + +namespace kaldi { + +struct LatticeFasterDecoderCombineConfig { + BaseFloat beam; + int32 max_active; + int32 min_active; + BaseFloat lattice_beam; + int32 prune_interval; + bool determinize_lattice; // not inspected by this class... used in + // command-line program. + BaseFloat beam_delta; // has nothing to do with beam_ratio + BaseFloat hash_ratio; + BaseFloat prune_scale; // Note: we don't make this configurable on the command line, + // it's not a very important parameter. It affects the + // algorithm that prunes the tokens as we go. + // Most of the options inside det_opts are not actually queried by the + // LatticeFasterDecoder class itself, but by the code that calls it, for + // example in the function DecodeUtteranceLatticeFaster. + fst::DeterminizeLatticePhonePrunedOptions det_opts; + + LatticeFasterDecoderCombineConfig(): beam(16.0), + max_active(std::numeric_limits::max()), + min_active(200), + lattice_beam(10.0), + prune_interval(25), + determinize_lattice(true), + beam_delta(0.5), + hash_ratio(2.0), + prune_scale(0.1) { } + void Register(OptionsItf *opts) { + det_opts.Register(opts); + opts->Register("beam", &beam, "Decoding beam. Larger->slower, more accurate."); + opts->Register("max-active", &max_active, "Decoder max active states. Larger->slower; " + "more accurate"); + opts->Register("min-active", &min_active, "Decoder minimum #active states."); + opts->Register("lattice-beam", &lattice_beam, "Lattice generation beam. 
Larger->slower, " + "and deeper lattices"); + opts->Register("prune-interval", &prune_interval, "Interval (in frames) at " + "which to prune tokens"); + opts->Register("determinize-lattice", &determinize_lattice, "If true, " + "determinize the lattice (lattice-determinization, keeping only " + "best pdf-sequence for each word-sequence)."); + opts->Register("beam-delta", &beam_delta, "Increment used in decoding-- this " + "parameter is obscure and relates to a speedup in the way the " + "max-active constraint is applied. Larger is more accurate."); + opts->Register("hash-ratio", &hash_ratio, "Setting used in decoder to " + "control hash behavior"); + } + void Check() const { + KALDI_ASSERT(beam > 0.0 && max_active > 1 && lattice_beam > 0.0 + && min_active <= max_active + && prune_interval > 0 && beam_delta > 0.0 && hash_ratio >= 1.0 + && prune_scale > 0.0 && prune_scale < 1.0); + } +}; + + +namespace decodercombine { +// We will template the decoder on the token type as well as the FST type; this +// is a mechanism so that we can use the same underlying decoder code for +// versions of the decoder that support quickly getting the best path +// (LatticeFasterOnlineDecoder, see lattice-faster-online-decoder.h) and also +// those that do not (LatticeFasterDecoder). + + +// ForwardLinks are the links from a token to a token on the next frame. +// or sometimes on the current frame (for input-epsilon links). +template +struct ForwardLink { + using Label = fst::StdArc::Label; + + Token *next_tok; // the next token [or NULL if represents final-state] + Label ilabel; // ilabel on arc + Label olabel; // olabel on arc + BaseFloat graph_cost; // graph cost of traversing arc (contains LM, etc.) + BaseFloat acoustic_cost; // acoustic cost (pre-scaled) of traversing arc + ForwardLink *next; // next in singly-linked list of forward arcs (arcs + // in the state-level lattice) from a token. + inline ForwardLink(Token *next_tok, Label ilabel, Label olabel, + BaseFloat graph_cost, BaseFloat acoustic_cost, + ForwardLink *next): + next_tok(next_tok), ilabel(ilabel), olabel(olabel), + graph_cost(graph_cost), acoustic_cost(acoustic_cost), + next(next) { } +}; + +template +struct StdToken { + using ForwardLinkT = ForwardLink; + using Token = StdToken; + using StateId = typename Fst::Arc::StateId; + + // Standard token type for LatticeFasterDecoder. Each active HCLG + // (decoding-graph) state on each frame has one token. + + // tot_cost is the total (LM + acoustic) cost from the beginning of the + // utterance up to this point. (but see cost_offset_, which is subtracted + // to keep it in a good numerical range). + BaseFloat tot_cost; + + // exta_cost is >= 0. After calling PruneForwardLinks, this equals + // the minimum difference between the cost of the best path, and the cost of + // this is on, and the cost of the absolute best path, under the assumption + // that any of the currently active states at the decoding front may + // eventually succeed (e.g. if you were to take the currently active states + // one by one and compute this difference, and then take the minimum). + BaseFloat extra_cost; + + // Record the state id of the token + StateId state_id; + + // 'links' is the head of singly-linked list of ForwardLinks, which is what we + // use for lattice generation. + ForwardLinkT *links; + + //'next' is the next in the singly-linked list of tokens for this frame. + Token *next; + + // identitfy the token is in current heap or not(-1). 
It also records the token's position
+  // in the current heap, so that fixing the heap after updating the cost of an
+  // existing token is more convenient and faster.
+  size_t position_in_heap;
+
+  // This function does nothing and should be optimized out; it's needed
+  // so we can share the regular LatticeFasterDecoderTpl code and the code
+  // for LatticeFasterOnlineDecoder that supports fast traceback.
+  inline void SetBackpointer (Token *backpointer) { }
+
+  // This constructor just ignores the 'backpointer' argument.  That argument is
+  // needed so that we can use the same decoder code for LatticeFasterDecoderTpl
+  // and LatticeFasterOnlineDecoderTpl (which needs backpointers to support a
+  // fast way to obtain the best path).
+  inline StdToken(BaseFloat tot_cost, BaseFloat extra_cost, StateId state_id,
+                  ForwardLinkT *links, Token *next, Token *backpointer):
+      tot_cost(tot_cost), extra_cost(extra_cost), state_id(state_id),
+      links(links), next(next), position_in_heap(-1) { }
+};
+
+template <typename Fst>
+struct BackpointerToken {
+  using ForwardLinkT = ForwardLink<BackpointerToken<Fst> >;
+  using Token = BackpointerToken<Fst>;
+  using StateId = typename Fst::Arc::StateId;
+
+  // BackpointerToken is like StdToken, but also stores a backpointer to the
+  // best preceding token.  Each active HCLG (decoding-graph) state on each
+  // frame has one token.
+
+  // tot_cost is the total (LM + acoustic) cost from the beginning of the
+  // utterance up to this point.  (but see cost_offset_, which is subtracted
+  // to keep it in a good numerical range).
+  BaseFloat tot_cost;
+
+  // extra_cost is >= 0.  After calling PruneForwardLinks, this equals the
+  // minimum difference between the cost of the best path that this token is
+  // on, and the cost of the absolute best path, under the assumption that any
+  // of the currently active states at the decoding front may eventually
+  // succeed (e.g. if you were to take the currently active states one by one
+  // and compute this difference, and then take the minimum).
+  BaseFloat extra_cost;
+
+  // Record the state id of the token
+  StateId state_id;
+
+  // 'links' is the head of a singly-linked list of ForwardLinks, which is
+  // what we use for lattice generation.
+  ForwardLinkT *links;
+
+  // 'next' is the next in the singly-linked list of tokens for this frame.
+  BackpointerToken *next;
+
+  // Best preceding BackpointerToken (could be a token on this frame, connected
+  // to this one via an epsilon transition, or a token on a previous frame).
+  // This is only required for an efficient GetBestPath function in
+  // LatticeFasterOnlineDecoderTpl; it plays no part in the lattice generation
+  // (the "links" list is what stores the forward links, for that).
+  Token *backpointer;
+
+  // Identifies whether the token is in the current heap (position_in_heap is
+  // -1, i.e. (size_t)-1, if it is not).  It also records the token's position
+  // in the current heap, so that fixing the heap after updating the cost of an
+  // existing token is more convenient and faster.
+  size_t position_in_heap;
+
+  inline void SetBackpointer (Token *backpointer) {
+    this->backpointer = backpointer;
+  }
+
+  inline BackpointerToken(BaseFloat tot_cost, BaseFloat extra_cost,
+                          StateId state_id, ForwardLinkT *links,
+                          Token *next, Token *backpointer):
+      tot_cost(tot_cost), extra_cost(extra_cost), state_id(state_id),
+      links(links), next(next), backpointer(backpointer),
+      position_in_heap(-1) { }
+};
+
+}  // namespace decodercombine
+
+
+/** This is the "normal" lattice-generating decoder.
+    See \ref lattices_generation \ref decoders_faster and \ref decoders_simple
+    for more information.
+
+    The decoder is templated on the FST type and the token type.
The token type + will normally be StdToken, but also may be BackpointerToken which is to support + quick lookup of the current best path (see lattice-faster-online-decoder.h) + + The FST you invoke this decoder with is expected to equal + Fst::Fst, a.k.a. StdFst, or GrammarFst. If you invoke it with + FST == StdFst and it notices that the actual FST type is + fst::VectorFst or fst::ConstFst, the decoder object + will internally cast itself to one that is templated on those more specific + types; this is an optimization for speed. + */ +template > +class LatticeFasterDecoderCombineTpl { + public: + using Arc = typename FST::Arc; + using Label = typename Arc::Label; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + using ForwardLinkT = decodercombine::ForwardLink; + + using StateIdToTokenMap = typename std::unordered_map; + //using StateIdToTokenMap = typename std::unordered_map, std::equal_to, + // fst::PoolAllocator > >; + using IterType = typename StateIdToTokenMap::const_iterator; + + // Instantiate this class once for each thing you have to decode. + // This version of the constructor does not take ownership of + // 'fst'. + LatticeFasterDecoderCombineTpl(const FST &fst, + const LatticeFasterDecoderCombineConfig &config); + + // This version of the constructor takes ownership of the fst, and will delete + // it when this object is destroyed. + LatticeFasterDecoderCombineTpl(const LatticeFasterDecoderCombineConfig &config, + FST *fst); + + void SetOptions(const LatticeFasterDecoderCombineConfig &config) { + config_ = config; + } + + const LatticeFasterDecoderCombineConfig &GetOptions() const { + return config_; + } + + ~LatticeFasterDecoderCombineTpl(); + + /// Decodes until there are no more frames left in the "decodable" object.. + /// note, this may block waiting for input if the "decodable" object blocks. + /// Returns true if any kind of traceback is available (not necessarily from a + /// final state). + bool Decode(DecodableInterface *decodable); + + + /// says whether a final-state was active on the last frame. If it was not, the + /// lattice (or traceback) will end with states that are not final-states. + bool ReachedFinal() const { + return FinalRelativeCost() != std::numeric_limits::infinity(); + } + + /// Outputs an FST corresponding to the single best path through the lattice. + /// Returns true if result is nonempty (using the return status is deprecated, + /// it will become void). If "use_final_probs" is true AND we reached the + /// final-state of the graph then it will include those as final-probs, else + /// it will treat all final-probs as one. Note: this just calls GetRawLattice() + /// and figures out the shortest path. + bool GetBestPath(Lattice *ofst, + bool use_final_probs = true); + + /// Outputs an FST corresponding to the raw, state-level + /// tracebacks. Returns true if result is nonempty. + /// If "use_final_probs" is true AND we reached the final-state + /// of the graph then it will include those as final-probs, else + /// it will treat all final-probs as one. + /// The raw lattice will be topologically sorted. + /// The function can be called during decoding, it will process non-emitting + /// arcs from "cur_toks_" map to get tokens from both non-emitting and + /// emitting arcs for getting raw lattice. Then recover it to ensure the + /// consistency of ProcessForFrame(). 
+
+  /// Decodes until there are no more frames left in the "decodable" object.
+  /// Note: this may block waiting for input if the "decodable" object blocks.
+  /// Returns true if any kind of traceback is available (not necessarily from a
+  /// final state).
+  bool Decode(DecodableInterface *decodable);
+
+
+  /// says whether a final-state was active on the last frame.  If it was not, the
+  /// lattice (or traceback) will end with states that are not final-states.
+  bool ReachedFinal() const {
+    return FinalRelativeCost() != std::numeric_limits<BaseFloat>::infinity();
+  }
+
+  /// Outputs an FST corresponding to the single best path through the lattice.
+  /// Returns true if result is nonempty (using the return status is deprecated,
+  /// it will become void).  If "use_final_probs" is true AND we reached the
+  /// final-state of the graph then it will include those as final-probs, else
+  /// it will treat all final-probs as one.  Note: this just calls GetRawLattice()
+  /// and figures out the shortest path.
+  bool GetBestPath(Lattice *ofst,
+                   bool use_final_probs = true);
+
+  /// Outputs an FST corresponding to the raw, state-level
+  /// tracebacks.  Returns true if result is nonempty.
+  /// If "use_final_probs" is true AND we reached the final-state
+  /// of the graph then it will include those as final-probs, else
+  /// it will treat all final-probs as one.
+  /// The raw lattice will be topologically sorted.
+  /// The function can be called during decoding; it processes the non-emitting
+  /// arcs out of the tokens in the "cur_toks_" map, so that the raw lattice
+  /// contains the tokens reached by both non-emitting and emitting arcs, and
+  /// afterwards it recovers the last token list to keep ProcessForFrame()
+  /// consistent.
+  ///
+  /// See also GetRawLatticePruned in lattice-faster-online-decoder.h,
+  /// which also supports a pruning beam, in case for some reason
+  /// you want it pruned tighter than the regular lattice beam.
+  /// We could add that here in the future if needed.
+  bool GetRawLattice(Lattice *ofst, bool use_final_probs = true);
+
+
+
+  /// [Deprecated, users should now use GetRawLattice and determinize it
+  /// themselves, e.g. using DeterminizeLatticePhonePrunedWrapper].
+  /// Outputs an FST corresponding to the lattice-determinized
+  /// lattice (one path per word sequence).  Returns true if result is nonempty.
+  /// If "use_final_probs" is true AND we reached the final-state of the graph
+  /// then it will include those as final-probs, else it will treat all
+  /// final-probs as one.
+  bool GetLattice(CompactLattice *ofst,
+                  bool use_final_probs = true);
+
+  /// InitDecoding initializes the decoding, and should only be used if you
+  /// intend to call AdvanceDecoding().  If you call Decode(), you don't need to
+  /// call this.  You can also call InitDecoding if you have already decoded an
+  /// utterance and want to start with a new utterance.
+  void InitDecoding();
+
+  /// This will decode until there are no more frames ready in the decodable
+  /// object.  You can keep calling it each time more frames become available.
+  /// If max_num_frames is specified, it specifies the maximum number of frames
+  /// the function will decode before returning.
+  void AdvanceDecoding(DecodableInterface *decodable,
+                       int32 max_num_frames = -1);
+
+  /// This function may be optionally called after AdvanceDecoding(), when you
+  /// do not plan to decode any further.  It does an extra pruning step that
+  /// will help to prune the lattices output by GetLattice and (particularly)
+  /// GetRawLattice more accurately, particularly toward the end of the
+  /// utterance.  It does this by using the final-probs in pruning (if any
+  /// final-state survived); it also does a final pruning step that visits all
+  /// states (the pruning that is done during decoding may fail to prune states
+  /// that are within kPruningScale = 0.1 outside of the beam).  If you call
+  /// this, you cannot call AdvanceDecoding again (it will fail), and you
+  /// cannot call GetLattice() and related functions with use_final_probs =
+  /// false.
+  /// Used to be called PruneActiveTokensFinal().
+  void FinalizeDecoding();
+
+  /// FinalRelativeCost() serves the same purpose as ReachedFinal(), but gives
+  /// more information.  It returns the difference between the best (final-cost
+  /// plus cost) of any token on the final frame, and the best cost of any token
+  /// on the final frame.  If it is infinity it means no final-states were
+  /// present on the final frame.  It will usually be nonnegative.  If it is not
+  /// too positive (e.g. < 5 is my first guess, but this is not tested) you can
+  /// take it as a good indication that we reached the final-state with
+  /// reasonable likelihood.
+  BaseFloat FinalRelativeCost() const;
+
+
+  // Returns the number of frames decoded so far.  The value returned changes
+  // whenever we call ProcessForFrame().
+  inline int32 NumFramesDecoded() const { return active_toks_.size() - 1; }
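A sketch of the intended incremental (online) calling pattern for the methods above; the decodable object, the 'decoder' instance and the loop condition are assumed to be supplied by the caller:

    decoder.InitDecoding();
    while (!input_finished) {
      // ... append newly available feature frames to 'decodable' ...
      decoder.AdvanceDecoding(&decodable);
    }
    decoder.FinalizeDecoding();  // optional, sharpens the final pruning

    Lattice raw_lat;
    if (decoder.GetRawLattice(&raw_lat, true))
      KALDI_VLOG(1) << "Decoded " << decoder.NumFramesDecoded() << " frames.";

For offline use, a single call to Decode(&decodable) replaces the InitDecoding()/AdvanceDecoding() loop.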
+
+ protected:
+  // we make things protected instead of private, as code in
+  // LatticeFasterOnlineDecoderTpl, which inherits from this, also uses the
+  // internals.
+
+  // Deletes the elements of the singly linked list tok->links.
+  inline static void DeleteForwardLinks(Token *tok);
+
+  // head of per-frame list of Tokens (list is in topological order),
+  // and something saying whether we ever pruned it using PruneForwardLinks.
+  struct TokenList {
+    Token *toks;
+    bool must_prune_forward_links;
+    bool must_prune_tokens;
+    TokenList(): toks(NULL), must_prune_forward_links(true),
+                 must_prune_tokens(true) { }
+  };
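As a small illustration of how these per-frame lists are meant to be traversed (illustrative only, not part of the patch):

    // Count the surviving tokens on each frame; active_toks_[f].toks is the
    // head of frame f's singly-linked token list.
    for (size_t f = 0; f < active_toks_.size(); f++) {
      int32 num_toks_on_frame = 0;
      for (Token *tok = active_toks_[f].toks; tok != NULL; tok = tok->next)
        num_toks_on_frame++;
      KALDI_VLOG(3) << "frame " << f << ": " << num_toks_on_frame << " tokens";
    }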
+
+  // It is a min-heap (A[parent(i)] <= A[i]), since in decoding the lower the
+  // cost, the better.
+  // The indexing of the heap's nodes is zero-based.
+  // Given a parent index i, the left child is 2*i + 1 and the right child is
+  // 2*i + 2.
+  // Given a child index i, the parent is ((i + 1) / 2) - 1.
+  struct TokenHeap {
+    std::vector<Token*> elements;
+
+    inline void Siftup(size_t child_index, size_t length) {
+      while(true) {
+        if (child_index == 0) break;  // it doesn't have a parent node
+        KALDI_ASSERT(child_index < length);
+        size_t parent_index = (child_index + 1) / 2 - 1;
+        if (elements[child_index]->tot_cost < elements[parent_index]->tot_cost) {
+          // Update the indexes of the tokens
+          elements[parent_index]->position_in_heap = child_index;
+          elements[child_index]->position_in_heap = parent_index;
+          // Swap
+          std::swap(elements[parent_index], elements[child_index]);
+          // Update child_index for the next turn
+          child_index = parent_index;
+        } else {
+          break;  // finish
+        }
+      }
+    }
+
+    inline void Siftdown(size_t parent_index, size_t length) {
+      while(true) {
+        if (parent_index >= length / 2) break;  // it doesn't have children
+        // Prepare indexes
+        size_t left_child_index = parent_index * 2 + 1;
+        size_t right_child_index = parent_index * 2 + 2;
+        size_t smallest = parent_index;
+        // Find the index with the smallest cost (it is a min-heap)
+        if (left_child_index < length &&
+            elements[left_child_index]->tot_cost < elements[smallest]->tot_cost) {
+          smallest = left_child_index;
+        }
+        if (right_child_index < length &&
+            elements[right_child_index]->tot_cost < elements[smallest]->tot_cost) {
+          smallest = right_child_index;
+        }
+        // Swap
+        if (smallest != parent_index) {
+          // Update the indexes of the tokens
+          elements[smallest]->position_in_heap = parent_index;
+          elements[parent_index]->position_in_heap = smallest;
+          // Swap
+          std::swap(elements[parent_index], elements[smallest]);
+          // Update parent_index for the next turn
+          parent_index = smallest;
+        } else {
+          break;  // finish
+        }
+      }
+    }
+
+    inline bool Empty() {
+      return elements.empty();
+    }
+
+    inline Token* Top() {
+      KALDI_ASSERT(!elements.empty());
+      return elements[0];
+    }
+
+    inline void Pop() {
+      KALDI_ASSERT(!elements.empty());
+      // Swap the top with the last element of the heap
+      std::swap(elements[0], elements[elements.size() - 1]);
+      // Mark the popped token as no longer being in the heap.  (This must be
+      // done after the swap, or popping from a heap of size one would leave
+      // the popped token with a stale position.)
+      elements.back()->position_in_heap = -1;
+      // Delete it from the heap
+      elements.pop_back();
+      // Restore the heap property from the top downwards
+      if (!elements.empty()) {
+        elements[0]->position_in_heap = 0;
+        Siftdown(0, elements.size());
+      }
+    }
+
+    // Push a new token into the heap
+    inline void Push(Token* tok) {
+      KALDI_ASSERT(tok->position_in_heap == static_cast<size_t>(-1));  // not in heap
+      // Push to the end of the heap
+      tok->position_in_heap = elements.size();
+      elements.push_back(tok);
+      // Tune
+      Siftup(tok->position_in_heap, elements.size());
+    }
+
+    inline void Clear() {
+      while (!Empty())
+        Pop();
+    }
+
+    // Builds the heap.  The complexity of this function is O(n) rather than
+    // O(n log n), as the corresponding series converges.
+    inline void BuildTokenHeap(const TokenList &token_list, size_t num) {
+      KALDI_ASSERT(elements.empty());
+      elements.reserve(num * 1.5);
+      // Add elements
+      for (Token* tok = token_list.toks; tok != NULL; tok = tok->next) {
+        tok->position_in_heap = elements.size();
+        elements.push_back(tok);
+      }
+      if (elements.size() < 2) return;  // already a heap
+      // Heapify with Siftdown, starting from the last parent node.  Note: the
+      // loop index must be signed; a size_t counter would wrap around at zero
+      // and the loop would never terminate.
+      int32 length = elements.size();
+      for (int32 i = length / 2 - 1; i >= 0; i--) {
+        Siftdown(i, length);
+      }
+    }
+
+    TokenHeap() {
+      elements.resize(0);
+    }
+  };
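The point of position_in_heap is to make a decrease-key operation cheap: when a token that is already in the heap gets a lower cost, the heap can be repaired in O(log n) starting from that token's slot, with no search. A sketch of the pattern (the helper below is illustrative, not a function in this patch):

    inline void DecreaseCost(TokenHeap *heap, Token *tok, BaseFloat new_cost) {
      KALDI_ASSERT(new_cost <= tok->tot_cost);
      tok->tot_cost = new_cost;
      if (tok->position_in_heap != static_cast<size_t>(-1)) {
        // The cost only went down, so the token can only need to move up.
        heap->Siftup(tok->position_in_heap, heap->elements.size());
      }
    }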
+
+  // FindOrAddToken either locates a token in hash map "token_map", or if
+  // necessary inserts a new, empty token (i.e. with no forward links) for the
+  // current frame.  [note: it's inserted if necessary into the hash map and
+  // also into the singly linked list of tokens active on this frame (whose
+  // head is at active_toks_[frame]).  The frame_plus_one argument is the
+  // acoustic frame index plus one, which is used to index into the
+  // active_toks_ array.
+  // Returns the Token pointer.  Sets "changed" (if non-NULL) to true if the
+  // token was newly created or the cost changed.
+  // If Token == StdToken, the 'backpointer' argument has no purpose (and will
+  // hopefully be optimized out).
+  inline Token *FindOrAddToken(StateId state, int32 token_list_index,
+                               BaseFloat tot_cost, Token *backpointer,
+                               StateIdToTokenMap *token_map,
+                               bool *changed);
+
+  // prunes outgoing links for all tokens in active_toks_[frame]
+  // it's called by PruneActiveTokens
+  // all links, that have link_extra_cost > lattice_beam are pruned
+  // delta is the amount by which the extra_costs must change
+  // before we set *extra_costs_changed = true.
+  // If delta is larger, we'll tend to go back less far
+  //    toward the beginning of the file.
+  // extra_costs_changed is set to true if extra_cost was changed for any token
+  // links_pruned is set to true if any link in any token was pruned
+  void PruneForwardLinks(int32 frame_plus_one, bool *extra_costs_changed,
+                         bool *links_pruned,
+                         BaseFloat delta);
+
+  // This function computes the final-costs for tokens active on the final
+  // frame.  It outputs to final-costs, if non-NULL, a map from the Token*
+  // pointer to the final-prob of the corresponding state, for all Tokens
+  // that correspond to states that have final-probs.  This map will be
+  // empty if there were no final-probs.  It outputs to
+  // final_relative_cost, if non-NULL, the difference between the best
+  // forward-cost including the final-prob cost, and the best forward-cost
+  // without including the final-prob cost (this will usually be positive), or
+  // infinity if there were no final-probs.  [c.f. FinalRelativeCost(), which
+  // outputs this quantity].  It outputs to final_best_cost, if
+  // non-NULL, the lowest for any token t active on the final frame, of
+  // forward-cost[t] + final-cost[t], where final-cost[t] is the final-cost in
+  // the graph of the state corresponding to token t, or the best of
+  // forward-cost[t] if there were no final-probs active on the final frame.
+  // You cannot call this after FinalizeDecoding() has been called; in that
+  // case you should get the answer from class-member variables.
+  void ComputeFinalCosts(unordered_map<Token*, BaseFloat> *final_costs,
+                         BaseFloat *final_relative_cost,
+                         BaseFloat *final_best_cost) const;
+
+  // PruneForwardLinksFinal is a version of PruneForwardLinks that we call
+  // on the final frame.  If there are final tokens active, it uses
+  // the final-probs for pruning, otherwise it treats all tokens as final.
+  void PruneForwardLinksFinal();
+
+  // Prune away any tokens on this frame that have no forward links.
+  // [we don't do this in PruneForwardLinks because it would give us
+  // a problem with dangling pointers].
+  // It's called by PruneActiveTokens if any forward links have been pruned
+  void PruneTokensForFrame(int32 frame_plus_one);
+
+
+  // Go backwards through still-alive tokens, pruning them if the
+  // forward+backward cost is more than lat_beam away from the best path.  It's
+  // possible to prove that this is "correct" in the sense that we won't lose
+  // anything outside of lat_beam, regardless of what happens in the future.
+  // delta controls when it considers a cost to have changed enough to continue
+  // going backward and propagating the change.  larger delta -> will recurse
+  // less far.
+  void PruneActiveTokens(BaseFloat delta);
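A worked example of the pruning quantity these functions use (illustrative numbers only; the formula is the one applied per link in PruneForwardLinks in the .cc file):

    // For a link tok -> next_tok:
    BaseFloat tok_tot_cost = 100.0,      // cost of best path into 'tok'
              link_ac_cost = 5.0, link_graph_cost = 2.0,
              next_tot_cost = 104.0,     // best cost into 'next_tok'
              next_extra_cost = 1.5;     // how far 'next_tok' is off the best path
    BaseFloat link_extra_cost = next_extra_cost +
        ((tok_tot_cost + link_ac_cost + link_graph_cost) - next_tot_cost);
    // link_extra_cost == 4.5: the best path through this link is 4.5 worse
    // than the overall best path, so the link survives with lattice_beam = 8.0
    // but is pruned with lattice_beam = 4.0.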
+
+  /// Processes non-emitting (epsilon) arcs and emitting arcs for one frame
+  /// together.  It takes the emitting tokens in "prev_toks_" from the last
+  /// frame, and generates non-emitting tokens for the previous frame and
+  /// emitting tokens for the next frame.
+  /// Notice: the emitting tokens for the current frame are the tokens that
+  /// take the acoustic scores of the current frame (i.e. the destinations of
+  /// emitting arcs).
+  void ProcessForFrame(DecodableInterface *decodable);
+
+  /// Processes nonemitting (epsilon) arcs for one frame.
+  /// Call this function once after all frames have been processed, or call it
+  /// from GetRawLattice() to generate the complete token list for the last
+  /// frame.  [It deals with the tokens in the map "cur_toks_", which would
+  /// otherwise contain only the emitting tokens from the previous frame.]
+  /// If the map "token_orig_cost" isn't NULL, we build the map which will
+  /// be used to recover the "active_toks_[last_frame]" token list for the
+  /// last frame.
+  void ProcessNonemitting(std::unordered_map<Token*, BaseFloat> *token_orig_cost);
+
+  /// When GetRawLattice() is called during decoding, the
+  /// active_toks_[last_frame] is changed.  To keep the consistency of function
+  /// ProcessForFrame(), recover it.
+  /// Notice: as a new token will be added to the head of the TokenList,
+  /// tok->next will not be affected.
+  /// "token_orig_cost" is a mapping from token pointer to the tot_cost of the
+  /// token before propagating non-emitting arcs.  It is used to undo the
+  /// changes to the original tokens of the last frame and to remove the new
+  /// tokens which come from propagating non-emitting arcs, so that we can
+  /// guarantee the consistency of function ProcessForFrame().
+  void RecoverLastTokenList(
+      const std::unordered_map<Token*, BaseFloat> &token_orig_cost);
+
+
+  /// The maps "prev_toks_" and "cur_toks_" hold the tokens of the previous
+  /// and the current frame; they are indexed by StateId.  (The per-frame
+  /// lists in active_toks_ are indexed by frame-index plus one, where the
+  /// frame-index is zero-based, as used in the decodable object.  That is,
+  /// the emitting probs of frame t are accounted for in the tokens of
+  /// active_toks_[t+1].  The zeroth entry is for the nonemitting transitions
+  /// at the start of the graph.)
+  StateIdToTokenMap prev_toks_;
+  StateIdToTokenMap cur_toks_;
+
+  /// Gets the weight cutoff.
+  /// Notice: in the traditional version, the histogram pruning method is
+  /// applied to the complete token list of one frame.  In this version, it is
+  /// applied to a token list which contains only the emitting part, so the
+  /// max_active and min_active values might need to be narrowed.
+  BaseFloat GetCutoff(const TokenList &token_list,
+                      BaseFloat *adaptive_beam,
+                      StateId *best_state_id, Token **best_token);
+
+  std::vector<TokenList> active_toks_;  // Lists of tokens, indexed by
+  // frame (members of TokenList are toks, must_prune_forward_links,
+  // must_prune_tokens).
+  TokenHeap cur_heap_;  // temp variable used in ProcessForFrame
+                        // and ProcessNonemitting
+  std::vector<BaseFloat> tmp_array_;  // used in GetCutoff.
+
+  // fst_ is a pointer to the FST we are decoding from.
+  const FST *fst_;
+  // delete_fst_ is true if the pointer fst_ needs to be deleted when this
+  // object is destroyed.
+  bool delete_fst_;
+
+  std::vector<BaseFloat> cost_offsets_;  // This contains, for each
+  // frame, an offset that was added to the acoustic log-likelihoods on that
+  // frame in order to keep everything in a nice dynamic range i.e. close to
+  // zero, to reduce roundoff errors.
+  // Notice: It will only be added to emitting arcs (i.e. cost_offsets_[t] is
+  // added to arcs from "frame t" to "frame t+1").
+  LatticeFasterDecoderCombineConfig config_;
+  int32 num_toks_;  // current total #toks allocated...
+  bool warned_;
+
+  /// decoding_finalized_ is true if someone called FinalizeDecoding().  [note,
+  /// calling this is optional].  If true, it's forbidden to decode more.  Also,
+  /// if this is set, then the output of ComputeFinalCosts() is in the next
+  /// three variables.  The reason we need to do this is that after
+  /// FinalizeDecoding() calls PruneTokensForFrame() for the final frame, some
+  /// of the tokens on the last frame are freed, so we free the list from toks_
+  /// to avoid having dangling pointers hanging around.
+  bool decoding_finalized_;
+  /// For the meaning of the next 3 variables, see the comment for
+  /// decoding_finalized_ above, and ComputeFinalCosts().
+  unordered_map<Token*, BaseFloat> final_costs_;
+  BaseFloat final_relative_cost_;
+  BaseFloat final_best_cost_;
+
+  // This function takes a singly linked list of tokens for a single frame, and
+  // outputs a list of them in topological order (it will crash if no such order
+  // can be found, which will typically be due to decoding graphs with epsilon
+  // cycles, which are not allowed).  Note: the output list may contain NULLs,
+  // which the caller should pass over; it just happens to be more efficient for
+  // the algorithm to output a list that contains NULLs.
+  static void TopSortTokens(Token *tok_list,
+                            std::vector<Token*> *topsorted_list);
+
+  void ClearActiveTokens();
+
+  KALDI_DISALLOW_COPY_AND_ASSIGN(LatticeFasterDecoderCombineTpl);
+};
+
+typedef LatticeFasterDecoderCombineTpl<fst::StdFst,
+    decodercombine::StdToken<fst::StdFst> > LatticeFasterDecoderCombine;
+
+
+
+}  // end namespace kaldi.
+
+#endif
diff --git a/src/decoder/lattice-faster-decoder-combine-iterlist.cc b/src/decoder/lattice-faster-decoder-combine-iterlist.cc
new file mode 100644
index 00000000000..5c87d72fe14
--- /dev/null
+++ b/src/decoder/lattice-faster-decoder-combine-iterlist.cc
@@ -0,0 +1,1111 @@
+// decoder/lattice-faster-decoder-combine.cc
+
+// Copyright 2009-2012  Microsoft Corporation  Mirko Hannemann
+//           2013-2019  Johns Hopkins University (Author: Daniel Povey)
+//                2014  Guoguo Chen
+//                2018  Zhehuai Chen
+//                2019  Hang Lyu
+
+// See ../../COPYING for clarification regarding multiple authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#include "decoder/lattice-faster-decoder-combine.h" +#include "lat/lattice-functions.h" + +namespace kaldi { + +// instantiate this class once for each thing you have to decode. +template +LatticeFasterDecoderCombineTpl::LatticeFasterDecoderCombineTpl( + const FST &fst, + const LatticeFasterDecoderCombineConfig &config): + fst_(&fst), delete_fst_(false), config_(config), num_toks_(0) { + config.Check(); +} + + +template +LatticeFasterDecoderCombineTpl::LatticeFasterDecoderCombineTpl( + const LatticeFasterDecoderCombineConfig &config, FST *fst): + fst_(fst), delete_fst_(true), config_(config), num_toks_(0) { + config.Check(); + prev_toks_.reserve(1000); + cur_toks_.reserve(1000); +} + + +template +LatticeFasterDecoderCombineTpl::~LatticeFasterDecoderCombineTpl() { + ClearActiveTokens(); + if (delete_fst_) delete fst_; + //prev_toks_.clear(); + //cur_toks_.clear(); +} + +template +void LatticeFasterDecoderCombineTpl::InitDecoding() { + // clean up from last time: + prev_toks_.clear(); + cur_toks_.clear(); + cost_offsets_.clear(); + ClearActiveTokens(); + + warned_ = false; + num_toks_ = 0; + decoding_finalized_ = false; + final_costs_.clear(); + StateId start_state = fst_->Start(); + KALDI_ASSERT(start_state != fst::kNoStateId); + active_toks_.resize(1); + Token *start_tok = new Token(0.0, 0.0, start_state, NULL, NULL, NULL); + active_toks_[0].toks = start_tok; + cur_toks_[start_state] = start_tok; // initialize current tokens map + num_toks_++; +} + +// Returns true if any kind of traceback is available (not necessarily from +// a final state). It should only very rarely return false; this indicates +// an unusual search error. +template +bool LatticeFasterDecoderCombineTpl::Decode(DecodableInterface *decodable) { + InitDecoding(); + + // We use 1-based indexing for frames in this decoder (if you view it in + // terms of features), but note that the decodable object uses zero-based + // numbering, which we have to correct for when we call it. + + while (!decodable->IsLastFrame(NumFramesDecoded() - 1)) { + if (NumFramesDecoded() % config_.prune_interval == 0) + PruneActiveTokens(config_.lattice_beam * config_.prune_scale); + ProcessForFrame(decodable); + } + // A complete token list of the last frame will be generated in FinalizeDecoding() + FinalizeDecoding(); + + // Returns true if we have any kind of traceback available (not necessarily + // to the end state; query ReachedFinal() for that). + return !active_toks_.empty() && active_toks_.back().toks != NULL; +} + + +// Outputs an FST corresponding to the single best path through the lattice. 
+template <typename FST, typename Token>
+bool LatticeFasterDecoderCombineTpl<FST, Token>::GetBestPath(
+    Lattice *olat,
+    bool use_final_probs) {
+  Lattice raw_lat;
+  GetRawLattice(&raw_lat, use_final_probs);
+  ShortestPath(raw_lat, olat);
+  return (olat->NumStates() != 0);
+}
+
+
+// Outputs an FST corresponding to the raw, state-level lattice
+template <typename FST, typename Token>
+bool LatticeFasterDecoderCombineTpl<FST, Token>::GetRawLattice(
+    Lattice *ofst,
+    bool use_final_probs) {
+  typedef LatticeArc Arc;
+  typedef Arc::StateId StateId;
+  typedef Arc::Weight Weight;
+  typedef Arc::Label Label;
+  // Note: you can't use the old interface (Decode()) if you want to
+  // get the lattice with use_final_probs = false.  You'd have to do
+  // InitDecoding() and then AdvanceDecoding().
+  if (decoding_finalized_ && !use_final_probs)
+    KALDI_ERR << "You cannot call FinalizeDecoding() and then call "
+              << "GetRawLattice() with use_final_probs == false";
+
+  std::unordered_map<Token*, BaseFloat> token_orig_cost;
+  if (!decoding_finalized_) {
+    // Process the non-emitting arcs for the unfinished last frame.
+    ProcessNonemitting(&token_orig_cost);
+  }
+
+
+  unordered_map<Token*, BaseFloat> final_costs_local;
+
+  const unordered_map<Token*, BaseFloat> &final_costs =
+      (decoding_finalized_ ? final_costs_ : final_costs_local);
+  if (!decoding_finalized_ && use_final_probs)
+    ComputeFinalCosts(&final_costs_local, NULL, NULL);
+
+  ofst->DeleteStates();
+  // num-frames plus one (since frames are one-based, and we have
+  // an extra frame for the start-state).
+  int32 num_frames = active_toks_.size() - 1;
+  KALDI_ASSERT(num_frames > 0);
+  const int32 bucket_count = num_toks_/2 + 3;
+  unordered_map<Token*, StateId> tok_map(bucket_count);
+  // First create all states.
+  std::vector<Token*> token_list;
+  for (int32 f = 0; f <= num_frames; f++) {
+    if (active_toks_[f].toks == NULL) {
+      KALDI_WARN << "GetRawLattice: no tokens active on frame " << f
+                 << ": not producing lattice.\n";
+      return false;
+    }
+    TopSortTokens(active_toks_[f].toks, &token_list);
+    for (size_t i = 0; i < token_list.size(); i++)
+      if (token_list[i] != NULL)
+        tok_map[token_list[i]] = ofst->AddState();
+  }
+  // The next statement sets the start state of the output FST.  Because we
+  // topologically sorted the tokens, state zero must be the start-state.
+  ofst->SetStart(0);
+
+  KALDI_VLOG(4) << "init:" << num_toks_/2 + 3 << " buckets:"
+                << tok_map.bucket_count() << " load:" << tok_map.load_factor()
+                << " max:" << tok_map.max_load_factor();
+  // Now create all arcs.
+  for (int32 f = 0; f <= num_frames; f++) {
+    for (Token *tok = active_toks_[f].toks; tok != NULL; tok = tok->next) {
+      StateId cur_state = tok_map[tok];
+      for (ForwardLinkT *l = tok->links;
+           l != NULL;
+           l = l->next) {
+        typename unordered_map<Token*, StateId>::const_iterator
+            iter = tok_map.find(l->next_tok);
+        // Check that the destination token exists before dereferencing the
+        // iterator.
+        KALDI_ASSERT(iter != tok_map.end());
+        StateId nextstate = iter->second;
+        BaseFloat cost_offset = 0.0;
+        if (l->ilabel != 0) {  // emitting..
+ KALDI_ASSERT(f >= 0 && f < cost_offsets_.size()); + cost_offset = cost_offsets_[f]; + } + Arc arc(l->ilabel, l->olabel, + Weight(l->graph_cost, l->acoustic_cost - cost_offset), + nextstate); + ofst->AddArc(cur_state, arc); + } + if (f == num_frames) { + if (use_final_probs && !final_costs.empty()) { + typename unordered_map::const_iterator + iter = final_costs.find(tok); + if (iter != final_costs.end()) + ofst->SetFinal(cur_state, LatticeWeight(iter->second, 0)); + } else { + ofst->SetFinal(cur_state, LatticeWeight::One()); + } + } + } + } + + if (!decoding_finalized_) { // recover last token list + RecoverLastTokenList(token_orig_cost); + } + return (ofst->NumStates() > 0); +} + + +// When GetRawLattice() is called during decoding, the +// active_toks_[last_frame] is changed. To keep the consistency of function +// ProcessForFrame(), recover it. +// Notice: as new token will be added to the head of TokenList, tok->next +// will not be affacted. +template +void LatticeFasterDecoderCombineTpl::RecoverLastTokenList( + const std::unordered_map &token_orig_cost) { + if (!token_orig_cost.empty()) { + for (Token* tok = active_toks_[active_toks_.size() - 1].toks; + tok != NULL;) { + if (token_orig_cost.find(tok) != token_orig_cost.end()) { + DeleteForwardLinks(tok); + tok->tot_cost = token_orig_cost.find(tok)->second; + tok->in_current_queue = false; + tok = tok->next; + } else { + DeleteForwardLinks(tok); + Token *next_tok = tok->next; + delete tok; + num_toks_--; + tok = next_tok; + } + } + } +} + +// This function is now deprecated, since now we do determinization from outside +// the LatticeFasterDecoder class. Outputs an FST corresponding to the +// lattice-determinized lattice (one path per word sequence). +template +bool LatticeFasterDecoderCombineTpl::GetLattice( + CompactLattice *ofst, + bool use_final_probs) { + Lattice raw_fst; + GetRawLattice(&raw_fst, use_final_probs); + Invert(&raw_fst); // make it so word labels are on the input. + // (in phase where we get backward-costs). + fst::ILabelCompare ilabel_comp; + ArcSort(&raw_fst, ilabel_comp); // sort on ilabel; makes + // lattice-determinization more efficient. + + fst::DeterminizeLatticePrunedOptions lat_opts; + lat_opts.max_mem = config_.det_opts.max_mem; + + DeterminizeLatticePruned(raw_fst, config_.lattice_beam, ofst, lat_opts); + raw_fst.DeleteStates(); // Free memory-- raw_fst no longer needed. + Connect(ofst); // Remove unreachable states... there might be + // a small number of these, in some cases. + // Note: if something went wrong and the raw lattice was empty, + // we should still get to this point in the code without warnings or failures. + return (ofst->NumStates() != 0); +} + +/* + A note on the definition of extra_cost. + + extra_cost is used in pruning tokens, to save memory. + + Define the 'forward cost' of a token as zero for any token on the frame + we're currently decoding; and for other frames, as the shortest-path cost + between that token and a token on the frame we're currently decoding. + (by "currently decoding" I mean the most recently processed frame). + + Then define the extra_cost of a token (always >= 0) as the forward-cost of + the token minus the smallest forward-cost of any token on the same frame. + + We can use the extra_cost to accurately prune away tokens that we know will + never appear in the lattice. If the extra_cost is greater than the desired + lattice beam, the token would provably never appear in the lattice, so we can + prune away the token. 
+ + The advantage of storing the extra_cost rather than the forward-cost, is that + it is less costly to keep the extra_cost up-to-date when we process new frames. + When we process a new frame, *all* the previous frames' forward-costs would change; + but in general the extra_cost will change only for a finite number of frames. + (Actually we don't update all the extra_costs every time we update a frame; we + only do it every 'config_.prune_interval' frames). + */ + +// FindOrAddToken either locates a token in hash map "token_map" +// or if necessary inserts a new, empty token (i.e. with no forward links) +// for the current frame. [note: it's inserted if necessary into hash toks_ +// and also into the singly linked list of tokens active on this frame +// (whose head is at active_toks_[frame]). +template +inline Token* LatticeFasterDecoderCombineTpl::FindOrAddToken( + StateId state, int32 token_list_index, BaseFloat tot_cost, + Token *backpointer, StateIdToTokenMap *token_map, bool *changed) { + // Returns the Token pointer. Sets "changed" (if non-NULL) to true + // if the token was newly created or the cost changed. + KALDI_ASSERT(token_list_index < active_toks_.size()); + Token *&toks = active_toks_[token_list_index].toks; + typename StateIdToTokenMap::iterator e_found = token_map->find(state); + if (e_found == token_map->end()) { // no such token presently. + const BaseFloat extra_cost = 0.0; + // tokens on the currently final frame have zero extra_cost + // as any of them could end up + // on the winning path. + Token *new_tok = new Token (tot_cost, extra_cost, state, + NULL, toks, backpointer); + // NULL: no forward links yet + toks = new_tok; + num_toks_++; + // insert into the map + (*token_map)[state] = new_tok; + if (changed) *changed = true; + return new_tok; + } else { + Token *tok = e_found->second; // There is an existing Token for this state. + if (tok->tot_cost > tot_cost) { // replace old token + tok->tot_cost = tot_cost; + // SetBackpointer() just does tok->backpointer = backpointer in + // the case where Token == BackpointerToken, else nothing. + tok->SetBackpointer(backpointer); + // we don't allocate a new token, the old stays linked in active_toks_ + // we only replace the tot_cost + // in the current frame, there are no forward links (and no extra_cost) + // only in ProcessNonemitting we have to delete forward links + // in case we visit a state for the second time + // those forward links, that lead to this replaced token before: + // they remain and will hopefully be pruned later (PruneForwardLinks...) + if (changed) *changed = true; + } else { + if (changed) *changed = false; + } + return tok; + } +} + +// prunes outgoing links for all tokens in active_toks_[frame] +// it's called by PruneActiveTokens +// all links, that have link_extra_cost > lattice_beam are pruned +template +void LatticeFasterDecoderCombineTpl::PruneForwardLinks( + int32 frame_plus_one, bool *extra_costs_changed, + bool *links_pruned, BaseFloat delta) { + // delta is the amount by which the extra_costs must change + // If delta is larger, we'll tend to go back less far + // toward the beginning of the file. + // extra_costs_changed is set to true if extra_cost was changed for any token + // links_pruned is set to true if any link in any token was pruned + + *extra_costs_changed = false; + *links_pruned = false; + KALDI_ASSERT(frame_plus_one >= 0 && frame_plus_one < active_toks_.size()); + if (active_toks_[frame_plus_one].toks == NULL) { // empty list; should not happen. 
+ if (!warned_) { + KALDI_WARN << "No tokens alive [doing pruning].. warning first " + "time only for each utterance\n"; + warned_ = true; + } + } + + // We have to iterate until there is no more change, because the links + // are not guaranteed to be in topological order. + bool changed = true; // difference new minus old extra cost >= delta ? + while (changed) { + changed = false; + for (Token *tok = active_toks_[frame_plus_one].toks; + tok != NULL; tok = tok->next) { + ForwardLinkT *link, *prev_link = NULL; + // will recompute tok_extra_cost for tok. + BaseFloat tok_extra_cost = std::numeric_limits::infinity(); + // tok_extra_cost is the best (min) of link_extra_cost of outgoing links + for (link = tok->links; link != NULL; ) { + // See if we need to excise this link... + Token *next_tok = link->next_tok; + BaseFloat link_extra_cost = next_tok->extra_cost + + ((tok->tot_cost + link->acoustic_cost + link->graph_cost) + - next_tok->tot_cost); // difference in brackets is >= 0 + // link_exta_cost is the difference in score between the best paths + // through link source state and through link destination state + KALDI_ASSERT(link_extra_cost == link_extra_cost); // check for NaN + if (link_extra_cost > config_.lattice_beam) { // excise link + ForwardLinkT *next_link = link->next; + if (prev_link != NULL) prev_link->next = next_link; + else tok->links = next_link; + delete link; + link = next_link; // advance link but leave prev_link the same. + *links_pruned = true; + } else { // keep the link and update the tok_extra_cost if needed. + if (link_extra_cost < 0.0) { // this is just a precaution. + if (link_extra_cost < -0.01) + KALDI_WARN << "Negative extra_cost: " << link_extra_cost; + link_extra_cost = 0.0; + } + if (link_extra_cost < tok_extra_cost) + tok_extra_cost = link_extra_cost; + prev_link = link; // move to next link + link = link->next; + } + } // for all outgoing links + if (fabs(tok_extra_cost - tok->extra_cost) > delta) + changed = true; // difference new minus old is bigger than delta + tok->extra_cost = tok_extra_cost; + // will be +infinity or <= lattice_beam_. + // infinity indicates, that no forward link survived pruning + } // for all Token on active_toks_[frame] + if (changed) *extra_costs_changed = true; + + // Note: it's theoretically possible that aggressive compiler + // optimizations could cause an infinite loop here for small delta and + // high-dynamic-range scores. + } // while changed +} + +// PruneForwardLinksFinal is a version of PruneForwardLinks that we call +// on the final frame. If there are final tokens active, it uses +// the final-probs for pruning, otherwise it treats all tokens as final. +template +void LatticeFasterDecoderCombineTpl::PruneForwardLinksFinal() { + KALDI_ASSERT(!active_toks_.empty()); + int32 frame_plus_one = active_toks_.size() - 1; + + if (active_toks_[frame_plus_one].toks == NULL) // empty list; should not happen. + KALDI_WARN << "No tokens alive at end of file"; + + typedef typename unordered_map::const_iterator IterType; + ComputeFinalCosts(&final_costs_, &final_relative_cost_, &final_best_cost_); + decoding_finalized_ = true; + + // Now go through tokens on this frame, pruning forward links... may have to + // iterate a few times until there is no more change, because the list is not + // in topological order. This is a modified version of the code in + // PruneForwardLinks, but here we also take account of the final-probs. 
+ bool changed = true; + BaseFloat delta = 1.0e-05; + while (changed) { + changed = false; + for (Token *tok = active_toks_[frame_plus_one].toks; + tok != NULL; tok = tok->next) { + ForwardLinkT *link, *prev_link = NULL; + // will recompute tok_extra_cost. It has a term in it that corresponds + // to the "final-prob", so instead of initializing tok_extra_cost to infinity + // below we set it to the difference between the (score+final_prob) of this token, + // and the best such (score+final_prob). + BaseFloat final_cost; + if (final_costs_.empty()) { + final_cost = 0.0; + } else { + IterType iter = final_costs_.find(tok); + if (iter != final_costs_.end()) + final_cost = iter->second; + else + final_cost = std::numeric_limits::infinity(); + } + BaseFloat tok_extra_cost = tok->tot_cost + final_cost - final_best_cost_; + // tok_extra_cost will be a "min" over either directly being final, or + // being indirectly final through other links, and the loop below may + // decrease its value: + for (link = tok->links; link != NULL; ) { + // See if we need to excise this link... + Token *next_tok = link->next_tok; + BaseFloat link_extra_cost = next_tok->extra_cost + + ((tok->tot_cost + link->acoustic_cost + link->graph_cost) + - next_tok->tot_cost); + if (link_extra_cost > config_.lattice_beam) { // excise link + ForwardLinkT *next_link = link->next; + if (prev_link != NULL) prev_link->next = next_link; + else tok->links = next_link; + delete link; + link = next_link; // advance link but leave prev_link the same. + } else { // keep the link and update the tok_extra_cost if needed. + if (link_extra_cost < 0.0) { // this is just a precaution. + if (link_extra_cost < -0.01) + KALDI_WARN << "Negative extra_cost: " << link_extra_cost; + link_extra_cost = 0.0; + } + if (link_extra_cost < tok_extra_cost) + tok_extra_cost = link_extra_cost; + prev_link = link; + link = link->next; + } + } + // prune away tokens worse than lattice_beam above best path. This step + // was not necessary in the non-final case because then, this case + // showed up as having no forward links. Here, the tok_extra_cost has + // an extra component relating to the final-prob. + if (tok_extra_cost > config_.lattice_beam) + tok_extra_cost = std::numeric_limits::infinity(); + // to be pruned in PruneTokensForFrame + + if (!ApproxEqual(tok->extra_cost, tok_extra_cost, delta)) + changed = true; + tok->extra_cost = tok_extra_cost; // will be +infinity or <= lattice_beam_. + } + } // while changed +} + +template +BaseFloat LatticeFasterDecoderCombineTpl::FinalRelativeCost() const { + if (!decoding_finalized_) { + BaseFloat relative_cost; + ComputeFinalCosts(NULL, &relative_cost, NULL); + return relative_cost; + } else { + // we're not allowed to call that function if FinalizeDecoding() has + // been called; return a cached value. + return final_relative_cost_; + } +} + + +// Prune away any tokens on this frame that have no forward links. +// [we don't do this in PruneForwardLinks because it would give us +// a problem with dangling pointers]. 
+// It's called by PruneActiveTokens if any forward links have been pruned
+template <typename FST, typename Token>
+void LatticeFasterDecoderCombineTpl<FST, Token>::PruneTokensForFrame(
+    int32 frame_plus_one) {
+  KALDI_ASSERT(frame_plus_one >= 0 && frame_plus_one < active_toks_.size());
+  Token *&toks = active_toks_[frame_plus_one].toks;
+  if (toks == NULL)
+    KALDI_WARN << "No tokens alive [doing pruning]";
+  Token *tok, *next_tok, *prev_tok = NULL;
+  for (tok = toks; tok != NULL; tok = next_tok) {
+    next_tok = tok->next;
+    if (tok->extra_cost == std::numeric_limits<BaseFloat>::infinity()) {
+      // token is unreachable from end of graph; (no forward links survived)
+      // excise tok from list and delete tok.
+      if (prev_tok != NULL) prev_tok->next = tok->next;
+      else toks = tok->next;
+      delete tok;
+      num_toks_--;
+    } else {  // fetch next Token
+      prev_tok = tok;
+    }
+  }
+}
+
+// Go backwards through still-alive tokens, pruning them, starting not from
+// the current frame (where we want to keep all tokens) but from the frame before
+// that.  We go backwards through the frames and stop when we reach a point
+// where the delta-costs are not changing (and the delta controls when we consider
+// a cost to have "not changed").
+template <typename FST, typename Token>
+void LatticeFasterDecoderCombineTpl<FST, Token>::PruneActiveTokens(
+    BaseFloat delta) {
+  int32 cur_frame_plus_one = NumFramesDecoded();
+  int32 num_toks_begin = num_toks_;
+  // The index "f" below represents a "frame plus one", i.e. you'd have to subtract
+  // one to get the corresponding index for the decodable object.
+  for (int32 f = cur_frame_plus_one - 1; f >= 0; f--) {
+    // Reason why we need to prune forward links in this situation:
+    // (1) we have never pruned them (new TokenList)
+    // (2) we have not yet pruned the forward links to the next f,
+    // after any of those tokens have changed their extra_cost.
+ if (active_toks_[f].must_prune_forward_links) { + bool extra_costs_changed = false, links_pruned = false; + PruneForwardLinks(f, &extra_costs_changed, &links_pruned, delta); + if (extra_costs_changed && f > 0) // any token has changed extra_cost + active_toks_[f-1].must_prune_forward_links = true; + if (links_pruned) // any link was pruned + active_toks_[f].must_prune_tokens = true; + active_toks_[f].must_prune_forward_links = false; // job done + } + if (f+1 < cur_frame_plus_one && // except for last f (no forward links) + active_toks_[f+1].must_prune_tokens) { + PruneTokensForFrame(f+1); + active_toks_[f+1].must_prune_tokens = false; + } + } + KALDI_VLOG(4) << "PruneActiveTokens: pruned tokens from " << num_toks_begin + << " to " << num_toks_; +} + +template +void LatticeFasterDecoderCombineTpl::ComputeFinalCosts( + unordered_map *final_costs, + BaseFloat *final_relative_cost, + BaseFloat *final_best_cost) const { + KALDI_ASSERT(!decoding_finalized_); + if (final_costs != NULL) + final_costs->clear(); + BaseFloat infinity = std::numeric_limits::infinity(); + BaseFloat best_cost = infinity, + best_cost_with_final = infinity; + + // The final tokens are recorded in active_toks_[last_frame] + for (Token *tok = active_toks_[active_toks_.size() - 1].toks; tok != NULL; + tok = tok->next) { + StateId state = tok->state_id; + BaseFloat final_cost = fst_->Final(state).Value(); + BaseFloat cost = tok->tot_cost, + cost_with_final = cost + final_cost; + best_cost = std::min(cost, best_cost); + best_cost_with_final = std::min(cost_with_final, best_cost_with_final); + if (final_costs != NULL && final_cost != infinity) + (*final_costs)[tok] = final_cost; + } + if (final_relative_cost != NULL) { + if (best_cost == infinity && best_cost_with_final == infinity) { + // Likely this will only happen if there are no tokens surviving. + // This seems the least bad way to handle it. + *final_relative_cost = infinity; + } else { + *final_relative_cost = best_cost_with_final - best_cost; + } + } + if (final_best_cost != NULL) { + if (best_cost_with_final != infinity) { // final-state exists. + *final_best_cost = best_cost_with_final; + } else { // no final-state exists. + *final_best_cost = best_cost; + } + } +} + +template +void LatticeFasterDecoderCombineTpl::AdvanceDecoding( + DecodableInterface *decodable, + int32 max_num_frames) { + if (std::is_same >::value) { + // if the type 'FST' is the FST base-class, then see if the FST type of fst_ + // is actually VectorFst or ConstFst. If so, call the AdvanceDecoding() + // function after casting *this to the more specific type. + if (fst_->Type() == "const") { + LatticeFasterDecoderCombineTpl, Token> *this_cast = + reinterpret_cast, Token>* >(this); + this_cast->AdvanceDecoding(decodable, max_num_frames); + return; + } else if (fst_->Type() == "vector") { + LatticeFasterDecoderCombineTpl, Token> *this_cast = + reinterpret_cast, Token>* >(this); + this_cast->AdvanceDecoding(decodable, max_num_frames); + return; + } + } + + + KALDI_ASSERT(!active_toks_.empty() && !decoding_finalized_ && + "You must call InitDecoding() before AdvanceDecoding"); + int32 num_frames_ready = decodable->NumFramesReady(); + // num_frames_ready must be >= num_frames_decoded, or else + // the number of frames ready must have decreased (which doesn't + // make sense) or the decodable object changed between calls + // (which isn't allowed). 
+ KALDI_ASSERT(num_frames_ready >= NumFramesDecoded()); + int32 target_frames_decoded = num_frames_ready; + if (max_num_frames >= 0) + target_frames_decoded = std::min(target_frames_decoded, + NumFramesDecoded() + max_num_frames); + while (NumFramesDecoded() < target_frames_decoded) { + if (NumFramesDecoded() % config_.prune_interval == 0) { + PruneActiveTokens(config_.lattice_beam * config_.prune_scale); + } + ProcessForFrame(decodable); + } +} + +// FinalizeDecoding() is a version of PruneActiveTokens that we call +// (optionally) on the final frame. Takes into account the final-prob of +// tokens. This function used to be called PruneActiveTokensFinal(). +template +void LatticeFasterDecoderCombineTpl::FinalizeDecoding() { + ProcessNonemitting(NULL); + int32 final_frame_plus_one = NumFramesDecoded(); + int32 num_toks_begin = num_toks_; + // PruneForwardLinksFinal() prunes final frame (with final-probs), and + // sets decoding_finalized_. + PruneForwardLinksFinal(); + for (int32 f = final_frame_plus_one - 1; f >= 0; f--) { + bool b1, b2; // values not used. + BaseFloat dontcare = 0.0; // delta of zero means we must always update + PruneForwardLinks(f, &b1, &b2, dontcare); + PruneTokensForFrame(f + 1); + } + PruneTokensForFrame(0); + KALDI_VLOG(4) << "pruned tokens from " << num_toks_begin + << " to " << num_toks_; +} + +/// Gets the weight cutoff. +template +BaseFloat LatticeFasterDecoderCombineTpl::GetCutoff( + const TokenList &token_list, BaseFloat *adaptive_beam, + StateId *best_state_id, Token **best_token) { + // positive == high cost == bad. + // best_weight is the minimum value. + BaseFloat best_weight = std::numeric_limits::infinity(); + if (config_.max_active == std::numeric_limits::max() && + config_.min_active == 0) { + for (Token* tok = token_list.toks; tok != NULL; tok = tok->next) { + BaseFloat w = static_cast(tok->tot_cost); + if (w < best_weight) { + best_weight = w; + if (best_token) { + *best_state_id = tok->state_id; + *best_token = tok; + } + } + } + if (adaptive_beam != NULL) *adaptive_beam = config_.beam; + return best_weight + config_.beam; + } else { + tmp_array_.clear(); + for (Token* tok = token_list.toks; tok != NULL; tok = tok->next) { + BaseFloat w = static_cast(tok->tot_cost); + tmp_array_.push_back(w); + if (w < best_weight) { + best_weight = w; + if (best_token) { + *best_state_id = tok->state_id; + *best_token = tok; + } + } + } + + BaseFloat beam_cutoff = best_weight + config_.beam, + min_active_cutoff = std::numeric_limits::infinity(), + max_active_cutoff = std::numeric_limits::infinity(); + + KALDI_VLOG(6) << "Number of emitting tokens on frame " + << NumFramesDecoded() - 1 << " is " << tmp_array_.size(); + + if (tmp_array_.size() > static_cast(config_.max_active)) { + std::nth_element(tmp_array_.begin(), + tmp_array_.begin() + config_.max_active, + tmp_array_.end()); + max_active_cutoff = tmp_array_[config_.max_active]; + } + if (max_active_cutoff < beam_cutoff) { // max_active is tighter than beam. + if (adaptive_beam) + *adaptive_beam = max_active_cutoff - best_weight + config_.beam_delta; + return max_active_cutoff; + } + if (tmp_array_.size() > static_cast(config_.min_active)) { + if (config_.min_active == 0) min_active_cutoff = best_weight; + else { + std::nth_element(tmp_array_.begin(), + tmp_array_.begin() + config_.min_active, + tmp_array_.size() > static_cast(config_.max_active) ? 
tmp_array_.begin() + config_.max_active : tmp_array_.end());
+        min_active_cutoff = tmp_array_[config_.min_active];
+      }
+    }
+    if (min_active_cutoff > beam_cutoff) { // min_active is looser than beam.
+      if (adaptive_beam)
+        *adaptive_beam = min_active_cutoff - best_weight + config_.beam_delta;
+      return min_active_cutoff;
+    } else {
+      *adaptive_beam = config_.beam;
+      return beam_cutoff;
+    }
+  }
+}
+
+template <typename FST, typename Token>
+void LatticeFasterDecoderCombineTpl<FST, Token>::ProcessForFrame(
+    DecodableInterface *decodable) {
+  KALDI_ASSERT(active_toks_.size() > 0);
+  int32 frame = active_toks_.size() - 1;  // frame is the frame-index
+                                          // (zero-based) used to get likelihoods
+                                          // from the decodable object.
+  active_toks_.resize(active_toks_.size() + 1);
+
+  prev_toks_.swap(cur_toks_);
+  cur_toks_.clear();
+  if (prev_toks_.empty()) {
+    if (!warned_) {
+      KALDI_WARN << "Error, no surviving tokens on frame " << frame;
+      warned_ = true;
+    }
+  }
+
+  BaseFloat adaptive_beam;
+  Token *best_tok = NULL;
+  StateId best_tok_state_id;
+  // "cur_cutoff" is used to constrain the propagation of epsilon (non-emitting)
+  // arcs in the current frame.  It will not be updated.
+  BaseFloat cur_cutoff = GetCutoff(active_toks_[frame], &adaptive_beam,
+                                   &best_tok_state_id, &best_tok);
+  KALDI_VLOG(6) << "Adaptive beam on frame " << NumFramesDecoded() << " is "
+                << adaptive_beam;
+
+
+  // pruning "online" before having seen all tokens
+
+  // "next_cutoff" is used to decide whether a new token in the next frame
+  // should be handled or not.  It will be updated during further processing.
+  BaseFloat next_cutoff = std::numeric_limits<BaseFloat>::infinity();
+  // "cost_offset" contains the acoustic log-likelihoods on the current frame,
+  // in order to keep everything in a nice dynamic range and reduce roundoff
+  // errors.
+  BaseFloat cost_offset = 0.0;
+
+  // First process the best token to get a hopefully
+  // reasonably tight bound on the next cutoff.  The only
+  // products of the next block are "next_cutoff" and "cost_offset".
+  // Notice: this is a difference between the combine version and the
+  // traditional version: here, "best_tok" is chosen from among the emitting
+  // tokens.  Normally the best token of a frame is reached through an epsilon
+  // (non-emitting) arc, so this best token only gives a looser bound.  We use
+  // it to estimate a bound on the next cutoff, and "next_cutoff" will be
+  // updated in further processing once we have better tokens.
+  if (best_tok) {
+    cost_offset = - best_tok->tot_cost;
+    for (fst::ArcIterator<FST> aiter(*fst_, best_tok_state_id);
+         !aiter.Done();
+         aiter.Next()) {
+      const Arc &arc = aiter.Value();
+      if (arc.ilabel != 0) {  // propagate..
+        // ac_cost + graph_cost
+        BaseFloat new_weight = arc.weight.Value() + cost_offset -
+            decodable->LogLikelihood(frame, arc.ilabel) + best_tok->tot_cost;
+        if (new_weight + adaptive_beam < next_cutoff)
+          next_cutoff = new_weight + adaptive_beam;
+      }
+    }
+  }
+
+  // Store the offset on the acoustic likelihoods that we're applying.
+  // Could just do cost_offsets_.push_back(cost_offset), but we
+  // do it this way as it's more robust to future code changes.
+  cost_offsets_.resize(frame + 1, 0.0);
+  cost_offsets_[frame] = cost_offset;
+
+  // Build a queue which contains the emitting tokens from the previous frame.
+  for (Token* tok = active_toks_[frame].toks; tok != NULL; tok = tok->next) {
+    cur_queue_.push(tok->state_id);
+    tok->in_current_queue = true;
+  }
+
+  // Iterate over "cur_queue_" to process both non-emitting and emitting arcs
+  // in the FST.
+ while (!cur_queue_.empty()) { + StateId state = cur_queue_.front(); + cur_queue_.pop(); + + KALDI_ASSERT(prev_toks_.find(state) != prev_toks_.end()); + Token *tok = prev_toks_[state]; + + BaseFloat cur_cost = tok->tot_cost; + tok->in_current_queue = false; // out of queue + if (cur_cost > cur_cutoff) // Don't bother processing successors. + continue; + // If "tok" has any existing forward links, delete them, + // because we're about to regenerate them. This is a kind + // of non-optimality (remember, this is the simple decoder), + DeleteForwardLinks(tok); // necessary when re-visiting + tok->links = NULL; + for (fst::ArcIterator aiter(*fst_, state); + !aiter.Done(); + aiter.Next()) { + const Arc &arc = aiter.Value(); + bool changed; + if (arc.ilabel == 0) { // propagate nonemitting + BaseFloat graph_cost = arc.weight.Value(); + BaseFloat tot_cost = cur_cost + graph_cost; + if (tot_cost < cur_cutoff) { + Token *new_tok = FindOrAddToken(arc.nextstate, frame, tot_cost, + tok, &prev_toks_, &changed); + + // Add ForwardLink from tok to new_tok. Put it on the head of + // tok->link list + tok->links = new ForwardLinkT(new_tok, 0, arc.olabel, + graph_cost, 0, tok->links); + + // "changed" tells us whether the new token has a different + // cost from before, or is new. + if (changed && !new_tok->in_current_queue) { + cur_queue_.push(arc.nextstate); + new_tok->in_current_queue = true; + } + } + } else { // propagate emitting + BaseFloat graph_cost = arc.weight.Value(), + ac_cost = cost_offset - decodable->LogLikelihood(frame, arc.ilabel), + cur_cost = tok->tot_cost, + tot_cost = cur_cost + ac_cost + graph_cost; + if (tot_cost > next_cutoff) continue; + else if (tot_cost + adaptive_beam < next_cutoff) + next_cutoff = tot_cost + adaptive_beam; // a tighter boundary for emitting + + // no change flag is needed + Token *next_tok = FindOrAddToken(arc.nextstate, frame + 1, tot_cost, + tok, &cur_toks_, NULL); + // Add ForwardLink from tok to next_tok. Put it on the head of tok->link + // list + tok->links = new ForwardLinkT(next_tok, arc.ilabel, arc.olabel, + graph_cost, ac_cost, tok->links); + } + } // for all arcs + } // end of while loop + KALDI_VLOG(6) << "Number of tokens active on frame " << NumFramesDecoded() - 1 + << " is " << prev_toks_.size(); +} + + +template +void LatticeFasterDecoderCombineTpl::ProcessNonemitting( + std::unordered_map *token_orig_cost) { + int32 frame = active_toks_.size() - 1; + if (token_orig_cost) { // Build the elements which are used to recover + for (Token *tok = active_toks_[frame].toks; tok != NULL; tok = tok->next) { + (*token_orig_cost)[tok] = tok->tot_cost; + } + } + + StateIdToTokenMap *tmp_toks; + if (token_orig_cost) { // "token_orig_cost" isn't NULL. It means we need to + // recover active_toks_[last_frame] and "cur_toks_" + // will be used in the future. + tmp_toks = new StateIdToTokenMap(cur_toks_); + } else { + tmp_toks = &cur_toks_; + } + + // Build the queue to process non-emitting arcs. + for (Token *tok = active_toks_[frame].toks; tok != NULL; tok = tok->next) { + if (fst_->NumInputEpsilons(tok->state_id) != 0) { + cur_queue_.push(tok->state_id); + tok->in_current_queue = true; + } + } + + // "cur_cutoff" is used to constrain the epsilon emittion in current frame. + // It will not be updated. 
+ BaseFloat adaptive_beam; + BaseFloat cur_cutoff = GetCutoff(active_toks_[frame], &adaptive_beam, NULL, NULL); + + while (!cur_queue_.empty()) { + StateId state = cur_queue_.front(); + cur_queue_.pop(); + + KALDI_ASSERT(tmp_toks->find(state) != tmp_toks->end()); + Token *tok = (*tmp_toks)[state]; + BaseFloat cur_cost = tok->tot_cost; + if (cur_cost > cur_cutoff) // Don't bother processing successors. + continue; + // If "tok" has any existing forward links, delete them, + // because we're about to regenerate them. This is a kind + // of non-optimality (remember, this is the simple decoder), + DeleteForwardLinks(tok); // necessary when re-visiting + tok->links = NULL; + for (fst::ArcIterator aiter(*fst_, state); + !aiter.Done(); + aiter.Next()) { + const Arc &arc = aiter.Value(); + bool changed; + if (arc.ilabel == 0) { // propagate nonemitting + BaseFloat graph_cost = arc.weight.Value(); + BaseFloat tot_cost = cur_cost + graph_cost; + if (tot_cost < cur_cutoff) { + Token *new_tok = FindOrAddToken(arc.nextstate, frame, tot_cost, + tok, tmp_toks, &changed); + + // Add ForwardLink from tok to new_tok. Put it on the head of + // tok->link list + tok->links = new ForwardLinkT(new_tok, 0, arc.olabel, + graph_cost, 0, tok->links); + + // "changed" tells us whether the new token has a different + // cost from before, or is new. + if (changed && !new_tok->in_current_queue) { + cur_queue_.push(arc.nextstate); + new_tok->in_current_queue = true; + } + } + } + } // end of for loop + tok->in_current_queue = false; + } // end of while loop + if (token_orig_cost) delete tmp_toks; +} + + + +// static inline +template +void LatticeFasterDecoderCombineTpl::DeleteForwardLinks(Token *tok) { + ForwardLinkT *l = tok->links, *m; + while (l != NULL) { + m = l->next; + delete l; + l = m; + } + tok->links = NULL; +} + + +template +void LatticeFasterDecoderCombineTpl::ClearActiveTokens() { + // a cleanup routine, at utt end/begin + for (size_t i = 0; i < active_toks_.size(); i++) { + // Delete all tokens alive on this frame, and any forward + // links they may have. + for (Token *tok = active_toks_[i].toks; tok != NULL; ) { + DeleteForwardLinks(tok); + Token *next_tok = tok->next; + delete tok; + num_toks_--; + tok = next_tok; + } + } + active_toks_.clear(); + KALDI_ASSERT(num_toks_ == 0); +} + +// static +template +void LatticeFasterDecoderCombineTpl::TopSortTokens( + Token *tok_list, std::vector *topsorted_list) { + unordered_map token2pos; + typedef typename unordered_map::iterator IterType; + int32 num_toks = 0; + for (Token *tok = tok_list; tok != NULL; tok = tok->next) + num_toks++; + int32 cur_pos = 0; + // We assign the tokens numbers num_toks - 1, ... , 2, 1, 0. + // This is likely to be in closer to topological order than + // if we had given them ascending order, because of the way + // new tokens are put at the front of the list. + for (Token *tok = tok_list; tok != NULL; tok = tok->next) + token2pos[tok] = num_toks - ++cur_pos; + + unordered_set reprocess; + + for (IterType iter = token2pos.begin(); iter != token2pos.end(); ++iter) { + Token *tok = iter->first; + int32 pos = iter->second; + for (ForwardLinkT *link = tok->links; link != NULL; link = link->next) { + if (link->ilabel == 0) { + // We only need to consider epsilon links, since non-epsilon links + // transition between frames and this function only needs to sort a list + // of tokens from a single frame. 
+ IterType following_iter = token2pos.find(link->next_tok); + if (following_iter != token2pos.end()) { // another token on this frame, + // so must consider it. + int32 next_pos = following_iter->second; + if (next_pos < pos) { // reassign the position of the next Token. + following_iter->second = cur_pos++; + reprocess.insert(link->next_tok); + } + } + } + } + // In case we had previously assigned this token to be reprocessed, we can + // erase it from that set because it's "happy now" (we just processed it). + reprocess.erase(tok); + } + + size_t max_loop = 1000000, loop_count; // max_loop is to detect epsilon cycles. + for (loop_count = 0; + !reprocess.empty() && loop_count < max_loop; ++loop_count) { + std::vector reprocess_vec; + for (typename unordered_set::iterator iter = reprocess.begin(); + iter != reprocess.end(); ++iter) + reprocess_vec.push_back(*iter); + reprocess.clear(); + for (typename std::vector::iterator iter = reprocess_vec.begin(); + iter != reprocess_vec.end(); ++iter) { + Token *tok = *iter; + int32 pos = token2pos[tok]; + // Repeat the processing we did above (for comments, see above). + for (ForwardLinkT *link = tok->links; link != NULL; link = link->next) { + if (link->ilabel == 0) { + IterType following_iter = token2pos.find(link->next_tok); + if (following_iter != token2pos.end()) { + int32 next_pos = following_iter->second; + if (next_pos < pos) { + following_iter->second = cur_pos++; + reprocess.insert(link->next_tok); + } + } + } + } + } + } + KALDI_ASSERT(loop_count < max_loop && "Epsilon loops exist in your decoding " + "graph (this is not allowed!)"); + + topsorted_list->clear(); + topsorted_list->resize(cur_pos, NULL); // create a list with NULLs in between. + for (IterType iter = token2pos.begin(); iter != token2pos.end(); ++iter) + (*topsorted_list)[iter->second] = iter->first; +} + +// Instantiate the template for the combination of token types and FST types +// that we'll need. +template class LatticeFasterDecoderCombineTpl, + decodercombine::StdToken > >; +template class LatticeFasterDecoderCombineTpl, + decodercombine::StdToken > >; +template class LatticeFasterDecoderCombineTpl, + decodercombine::StdToken > >; +template class LatticeFasterDecoderCombineTpl >; + +template class LatticeFasterDecoderCombineTpl , + decodercombine::BackpointerToken > >; +template class LatticeFasterDecoderCombineTpl, + decodercombine::BackpointerToken > >; +template class LatticeFasterDecoderCombineTpl, + decodercombine::BackpointerToken > >; +template class LatticeFasterDecoderCombineTpl >; + + +} // end namespace kaldi. diff --git a/src/decoder/lattice-faster-decoder-combine-iterlist.h b/src/decoder/lattice-faster-decoder-combine-iterlist.h new file mode 100644 index 00000000000..900d03520e4 --- /dev/null +++ b/src/decoder/lattice-faster-decoder-combine-iterlist.h @@ -0,0 +1,574 @@ +// decoder/lattice-faster-decoder-combine.h + +// Copyright 2009-2013 Microsoft Corporation; Mirko Hannemann; +// 2013-2019 Johns Hopkins University (Author: Daniel Povey) +// 2014 Guoguo Chen +// 2018 Zhehuai Chen +// 2019 Hang Lyu + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef KALDI_DECODER_LATTICE_FASTER_DECODER_COMBINE_H_ +#define KALDI_DECODER_LATTICE_FASTER_DECODER_COMBINE_H_ + + +#include "util/stl-utils.h" +#include "fst/fstlib.h" +#include "itf/decodable-itf.h" +#include "fstext/fstext-lib.h" +#include "lat/determinize-lattice-pruned.h" +#include "lat/kaldi-lattice.h" +#include "decoder/grammar-fst.h" +#include "decoder/lattice-faster-decoder.h" + +namespace kaldi { + +struct LatticeFasterDecoderCombineConfig { + BaseFloat beam; + int32 max_active; + int32 min_active; + BaseFloat lattice_beam; + int32 prune_interval; + bool determinize_lattice; // not inspected by this class... used in + // command-line program. + BaseFloat beam_delta; // has nothing to do with beam_ratio + BaseFloat hash_ratio; + BaseFloat prune_scale; // Note: we don't make this configurable on the command line, + // it's not a very important parameter. It affects the + // algorithm that prunes the tokens as we go. + // Most of the options inside det_opts are not actually queried by the + // LatticeFasterDecoder class itself, but by the code that calls it, for + // example in the function DecodeUtteranceLatticeFaster. + fst::DeterminizeLatticePhonePrunedOptions det_opts; + + LatticeFasterDecoderCombineConfig(): beam(16.0), + max_active(std::numeric_limits::max()), + min_active(200), + lattice_beam(10.0), + prune_interval(25), + determinize_lattice(true), + beam_delta(0.5), + hash_ratio(2.0), + prune_scale(0.1) { } + void Register(OptionsItf *opts) { + det_opts.Register(opts); + opts->Register("beam", &beam, "Decoding beam. Larger->slower, more accurate."); + opts->Register("max-active", &max_active, "Decoder max active states. Larger->slower; " + "more accurate"); + opts->Register("min-active", &min_active, "Decoder minimum #active states."); + opts->Register("lattice-beam", &lattice_beam, "Lattice generation beam. Larger->slower, " + "and deeper lattices"); + opts->Register("prune-interval", &prune_interval, "Interval (in frames) at " + "which to prune tokens"); + opts->Register("determinize-lattice", &determinize_lattice, "If true, " + "determinize the lattice (lattice-determinization, keeping only " + "best pdf-sequence for each word-sequence)."); + opts->Register("beam-delta", &beam_delta, "Increment used in decoding-- this " + "parameter is obscure and relates to a speedup in the way the " + "max-active constraint is applied. 
Larger is more accurate."); + opts->Register("hash-ratio", &hash_ratio, "Setting used in decoder to " + "control hash behavior"); + } + void Check() const { + KALDI_ASSERT(beam > 0.0 && max_active > 1 && lattice_beam > 0.0 + && min_active <= max_active + && prune_interval > 0 && beam_delta > 0.0 && hash_ratio >= 1.0 + && prune_scale > 0.0 && prune_scale < 1.0); + } +}; + + +namespace decodercombine { +// We will template the decoder on the token type as well as the FST type; this +// is a mechanism so that we can use the same underlying decoder code for +// versions of the decoder that support quickly getting the best path +// (LatticeFasterOnlineDecoder, see lattice-faster-online-decoder.h) and also +// those that do not (LatticeFasterDecoder). + + +// ForwardLinks are the links from a token to a token on the next frame. +// or sometimes on the current frame (for input-epsilon links). +template +struct ForwardLink { + using Label = fst::StdArc::Label; + + Token *next_tok; // the next token [or NULL if represents final-state] + Label ilabel; // ilabel on arc + Label olabel; // olabel on arc + BaseFloat graph_cost; // graph cost of traversing arc (contains LM, etc.) + BaseFloat acoustic_cost; // acoustic cost (pre-scaled) of traversing arc + ForwardLink *next; // next in singly-linked list of forward arcs (arcs + // in the state-level lattice) from a token. + inline ForwardLink(Token *next_tok, Label ilabel, Label olabel, + BaseFloat graph_cost, BaseFloat acoustic_cost, + ForwardLink *next): + next_tok(next_tok), ilabel(ilabel), olabel(olabel), + graph_cost(graph_cost), acoustic_cost(acoustic_cost), + next(next) { } +}; + +template +struct StdToken { + using ForwardLinkT = ForwardLink; + using Token = StdToken; + using StateId = typename Fst::Arc::StateId; + + // Standard token type for LatticeFasterDecoder. Each active HCLG + // (decoding-graph) state on each frame has one token. + + // tot_cost is the total (LM + acoustic) cost from the beginning of the + // utterance up to this point. (but see cost_offset_, which is subtracted + // to keep it in a good numerical range). + BaseFloat tot_cost; + + // exta_cost is >= 0. After calling PruneForwardLinks, this equals + // the minimum difference between the cost of the best path, and the cost of + // this is on, and the cost of the absolute best path, under the assumption + // that any of the currently active states at the decoding front may + // eventually succeed (e.g. if you were to take the currently active states + // one by one and compute this difference, and then take the minimum). + BaseFloat extra_cost; + + // Record the state id of the token + StateId state_id; + + // 'links' is the head of singly-linked list of ForwardLinks, which is what we + // use for lattice generation. + ForwardLinkT *links; + + //'next' is the next in the singly-linked list of tokens for this frame. + Token *next; + + // identitfy the token is in current queue or not to prevent duplication in + // function ProcessOneFrame(). + bool in_current_queue; + + + // This function does nothing and should be optimized out; it's needed + // so we can share the regular LatticeFasterDecoderTpl code and the code + // for LatticeFasterOnlineDecoder that supports fast traceback. + inline void SetBackpointer (Token *backpointer) { } + + // This constructor just ignores the 'backpointer' argument. 
+  // That argument is needed so that we can use the same decoder code for
+  // LatticeFasterDecoderTpl and LatticeFasterOnlineDecoderTpl (which needs
+  // backpointers to support a fast way to obtain the best path).
+  inline StdToken(BaseFloat tot_cost, BaseFloat extra_cost, StateId state_id,
+                  ForwardLinkT *links, Token *next, Token *backpointer):
+      tot_cost(tot_cost), extra_cost(extra_cost), state_id(state_id),
+      links(links), next(next), in_current_queue(false) { }
+};
+
+template <typename Fst>
+struct BackpointerToken {
+  using ForwardLinkT = ForwardLink<BackpointerToken>;
+  using Token = BackpointerToken;
+  using StateId = typename Fst::Arc::StateId;
+
+  // BackpointerToken is like StdToken, but also records a backpointer to the
+  // best preceding token.  Each active HCLG (decoding-graph) state on each
+  // frame has one token.
+
+  // tot_cost is the total (LM + acoustic) cost from the beginning of the
+  // utterance up to this point.  (but see cost_offset_, which is subtracted
+  // to keep it in a good numerical range).
+  BaseFloat tot_cost;
+
+  // extra_cost is >= 0.  After calling PruneForwardLinks, this equals the
+  // minimum difference between the cost of the best path that this token is
+  // on, and the cost of the absolute best path, under the assumption that
+  // any of the currently active states at the decoding front may eventually
+  // succeed (e.g. if you were to take the currently active states one by one
+  // and compute this difference, and then take the minimum).
+  BaseFloat extra_cost;
+
+  // Record the state id of the token
+  StateId state_id;
+
+  // 'links' is the head of singly-linked list of ForwardLinks, which is what we
+  // use for lattice generation.
+  ForwardLinkT *links;
+
+  // 'next' is the next in the singly-linked list of tokens for this frame.
+  BackpointerToken *next;
+
+  // Best preceding BackpointerToken (could be on this frame, connected to
+  // this via an epsilon transition, or on a previous frame).  This is only
+  // required for an efficient GetBestPath function in
+  // LatticeFasterOnlineDecoderTpl; it plays no part in the lattice generation
+  // (the "links" list is what stores the forward links, for that).
+  Token *backpointer;
+
+  // Identifies whether the token is in the current queue or not, to prevent
+  // duplication in function ProcessForFrame().
+  bool in_current_queue;
+
+  inline void SetBackpointer (Token *backpointer) {
+    this->backpointer = backpointer;
+  }
+
+  inline BackpointerToken(BaseFloat tot_cost, BaseFloat extra_cost,
+                          StateId state_id, ForwardLinkT *links,
+                          Token *next, Token *backpointer):
+      tot_cost(tot_cost), extra_cost(extra_cost), state_id(state_id),
+      links(links), next(next), backpointer(backpointer),
+      in_current_queue(false) { }
+};
+
+}  // namespace decodercombine
+
+
+/** This is the "normal" lattice-generating decoder.
+    See \ref lattices_generation \ref decoders_faster and \ref decoders_simple
+    for more information.
+
+    The decoder is templated on the FST type and the token type.  The token
+    type will normally be StdToken, but may also be BackpointerToken, which
+    supports quick lookup of the current best path (see
+    lattice-faster-online-decoder.h).
+
+    The FST you invoke this decoder with is expected to be of type
+    fst::Fst<fst::StdArc>, a.k.a. StdFst, or GrammarFst.  If you invoke it
+    with FST == StdFst and it notices that the actual FST type is
+    fst::VectorFst<fst::StdArc> or fst::ConstFst<fst::StdArc>, the decoder
+    object will internally cast itself to one that is templated on those more
+    specific types; this is an optimization for speed.
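+
+    A minimal usage sketch (illustrative only; "decodable" stands for any
+    DecodableInterface implementation, and "hclg" for a decoding-graph
+    StdFst; neither is defined in this header):
+
+      LatticeFasterDecoderCombineConfig config;
+      config.beam = 13.0;
+      LatticeFasterDecoderCombine decoder(hclg, config);
+      if (decoder.Decode(&decodable)) {
+        Lattice best_path;
+        decoder.GetBestPath(&best_path);
+      }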
+ */ +template > +class LatticeFasterDecoderCombineTpl { + public: + using Arc = typename FST::Arc; + using Label = typename Arc::Label; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + using ForwardLinkT = decodercombine::ForwardLink; + + using StateIdToTokenMap = typename std::unordered_map; + //using StateIdToTokenMap = typename std::unordered_map, std::equal_to, + // fst::PoolAllocator > >; + using IterType = typename StateIdToTokenMap::const_iterator; + + // Instantiate this class once for each thing you have to decode. + // This version of the constructor does not take ownership of + // 'fst'. + LatticeFasterDecoderCombineTpl(const FST &fst, + const LatticeFasterDecoderCombineConfig &config); + + // This version of the constructor takes ownership of the fst, and will delete + // it when this object is destroyed. + LatticeFasterDecoderCombineTpl(const LatticeFasterDecoderCombineConfig &config, + FST *fst); + + void SetOptions(const LatticeFasterDecoderCombineConfig &config) { + config_ = config; + } + + const LatticeFasterDecoderCombineConfig &GetOptions() const { + return config_; + } + + ~LatticeFasterDecoderCombineTpl(); + + /// Decodes until there are no more frames left in the "decodable" object.. + /// note, this may block waiting for input if the "decodable" object blocks. + /// Returns true if any kind of traceback is available (not necessarily from a + /// final state). + bool Decode(DecodableInterface *decodable); + + + /// says whether a final-state was active on the last frame. If it was not, the + /// lattice (or traceback) will end with states that are not final-states. + bool ReachedFinal() const { + return FinalRelativeCost() != std::numeric_limits::infinity(); + } + + /// Outputs an FST corresponding to the single best path through the lattice. + /// Returns true if result is nonempty (using the return status is deprecated, + /// it will become void). If "use_final_probs" is true AND we reached the + /// final-state of the graph then it will include those as final-probs, else + /// it will treat all final-probs as one. Note: this just calls GetRawLattice() + /// and figures out the shortest path. + bool GetBestPath(Lattice *ofst, + bool use_final_probs = true); + + /// Outputs an FST corresponding to the raw, state-level + /// tracebacks. Returns true if result is nonempty. + /// If "use_final_probs" is true AND we reached the final-state + /// of the graph then it will include those as final-probs, else + /// it will treat all final-probs as one. + /// The raw lattice will be topologically sorted. + /// The function can be called during decoding, it will process non-emitting + /// arcs from "cur_toks_" map to get tokens from both non-emitting and + /// emitting arcs for getting raw lattice. Then recover it to ensure the + /// consistency of ProcessForFrame(). + /// + /// See also GetRawLatticePruned in lattice-faster-online-decoder.h, + /// which also supports a pruning beam, in case for some reason + /// you want it pruned tighter than the regular lattice beam. + /// We could put that here in future needed. + bool GetRawLattice(Lattice *ofst, bool use_final_probs = true); + + + + /// [Deprecated, users should now use GetRawLattice and determinize it + /// themselves, e.g. using DeterminizeLatticePhonePrunedWrapper]. + /// Outputs an FST corresponding to the lattice-determinized + /// lattice (one path per word sequence). Returns true if result is nonempty. 
+ /// If "use_final_probs" is true AND we reached the final-state of the graph + /// then it will include those as final-probs, else it will treat all + /// final-probs as one. + bool GetLattice(CompactLattice *ofst, + bool use_final_probs = true); + + /// InitDecoding initializes the decoding, and should only be used if you + /// intend to call AdvanceDecoding(). If you call Decode(), you don't need to + /// call this. You can also call InitDecoding if you have already decoded an + /// utterance and want to start with a new utterance. + void InitDecoding(); + + /// This will decode until there are no more frames ready in the decodable + /// object. You can keep calling it each time more frames become available. + /// If max_num_frames is specified, it specifies the maximum number of frames + /// the function will decode before returning. + void AdvanceDecoding(DecodableInterface *decodable, + int32 max_num_frames = -1); + + /// This function may be optionally called after AdvanceDecoding(), when you + /// do not plan to decode any further. It does an extra pruning step that + /// will help to prune the lattices output by GetLattice and (particularly) + /// GetRawLattice more accurately, particularly toward the end of the + /// utterance. It does this by using the final-probs in pruning (if any + /// final-state survived); it also does a final pruning step that visits all + /// states (the pruning that is done during decoding may fail to prune states + /// that are within kPruningScale = 0.1 outside of the beam). If you call + /// this, you cannot call AdvanceDecoding again (it will fail), and you + /// cannot call GetLattice() and related functions with use_final_probs = + /// false. + /// Used to be called PruneActiveTokensFinal(). + void FinalizeDecoding(); + + /// FinalRelativeCost() serves the same purpose as ReachedFinal(), but gives + /// more information. It returns the difference between the best (final-cost + /// plus cost) of any token on the final frame, and the best cost of any token + /// on the final frame. If it is infinity it means no final-states were + /// present on the final frame. It will usually be nonnegative. If it not + /// too positive (e.g. < 5 is my first guess, but this is not tested) you can + /// take it as a good indication that we reached the final-state with + /// reasonable likelihood. + BaseFloat FinalRelativeCost() const; + + + // Returns the number of frames decoded so far. The value returned changes + // whenever we call ProcessForFrame(). + inline int32 NumFramesDecoded() const { return active_toks_.size() - 1; } + + protected: + // we make things protected instead of private, as code in + // LatticeFasterOnlineDecoderTpl, which inherits from this, also uses the + // internals. + + // Deletes the elements of the singly linked list tok->links. + inline static void DeleteForwardLinks(Token *tok); + + // head of per-frame list of Tokens (list is in topological order), + // and something saying whether we ever pruned it using PruneForwardLinks. + struct TokenList { + Token *toks; + bool must_prune_forward_links; + bool must_prune_tokens; + TokenList(): toks(NULL), must_prune_forward_links(true), + must_prune_tokens(true) { } + }; + + // FindOrAddToken either locates a token in hash map "token_map", or if necessary + // inserts a new, empty token (i.e. with no forward links) for the current + // frame. 
[note: it's inserted if necessary into hash map and also into the + // singly linked list of tokens active on this frame (whose head is at + // active_toks_[frame]). The frame_plus_one argument is the acoustic frame + // index plus one, which is used to index into the active_toks_ array. + // Returns the Token pointer. Sets "changed" (if non-NULL) to true if the + // token was newly created or the cost changed. + // If Token == StdToken, the 'backpointer' argument has no purpose (and will + // hopefully be optimized out). + inline Token *FindOrAddToken(StateId state, int32 token_list_index, + BaseFloat tot_cost, Token *backpointer, + StateIdToTokenMap *token_map, + bool *changed); + + // prunes outgoing links for all tokens in active_toks_[frame] + // it's called by PruneActiveTokens + // all links, that have link_extra_cost > lattice_beam are pruned + // delta is the amount by which the extra_costs must change + // before we set *extra_costs_changed = true. + // If delta is larger, we'll tend to go back less far + // toward the beginning of the file. + // extra_costs_changed is set to true if extra_cost was changed for any token + // links_pruned is set to true if any link in any token was pruned + void PruneForwardLinks(int32 frame_plus_one, bool *extra_costs_changed, + bool *links_pruned, + BaseFloat delta); + + // This function computes the final-costs for tokens active on the final + // frame. It outputs to final-costs, if non-NULL, a map from the Token* + // pointer to the final-prob of the corresponding state, for all Tokens + // that correspond to states that have final-probs. This map will be + // empty if there were no final-probs. It outputs to + // final_relative_cost, if non-NULL, the difference between the best + // forward-cost including the final-prob cost, and the best forward-cost + // without including the final-prob cost (this will usually be positive), or + // infinity if there were no final-probs. [c.f. FinalRelativeCost(), which + // outputs this quanitity]. It outputs to final_best_cost, if + // non-NULL, the lowest for any token t active on the final frame, of + // forward-cost[t] + final-cost[t], where final-cost[t] is the final-cost in + // the graph of the state corresponding to token t, or the best of + // forward-cost[t] if there were no final-probs active on the final frame. + // You cannot call this after FinalizeDecoding() has been called; in that + // case you should get the answer from class-member variables. + void ComputeFinalCosts(unordered_map *final_costs, + BaseFloat *final_relative_cost, + BaseFloat *final_best_cost) const; + + // PruneForwardLinksFinal is a version of PruneForwardLinks that we call + // on the final frame. If there are final tokens active, it uses + // the final-probs for pruning, otherwise it treats all tokens as final. + void PruneForwardLinksFinal(); + + // Prune away any tokens on this frame that have no forward links. + // [we don't do this in PruneForwardLinks because it would give us + // a problem with dangling pointers]. + // It's called by PruneActiveTokens if any forward links have been pruned + void PruneTokensForFrame(int32 frame_plus_one); + + + // Go backwards through still-alive tokens, pruning them if the + // forward+backward cost is more than lat_beam away from the best path. It's + // possible to prove that this is "correct" in the sense that we won't lose + // anything outside of lat_beam, regardless of what happens in the future. 
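+  // For example, with lattice_beam = 8.0, a token whose best containing path
+  // costs 9.0 more than the overall best path has extra_cost 9.0 > 8.0, so
+  // pruning it provably cannot remove anything within the lattice beam.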
+  // delta controls when it considers a cost to have changed enough to continue
+  // going backward and propagating the change.  larger delta -> will recurse
+  // less far.
+  void PruneActiveTokens(BaseFloat delta);
+
+  /// Processes non-emitting (epsilon) arcs and emitting arcs for one frame
+  /// together.  It takes the emitting tokens in "prev_toks_" from the last
+  /// frame, and generates non-emitting tokens for the previous frame and
+  /// emitting tokens for the next frame.
+  /// Notice: "the emitting tokens for the current frame" means the tokens
+  /// that take the acoustic scores of the current frame (i.e. the
+  /// destinations of emitting arcs).
+  void ProcessForFrame(DecodableInterface *decodable);
+
+  /// Processes nonemitting (epsilon) arcs for one frame.
+  /// Call this function once after all frames have been processed, or call it
+  /// in GetRawLattice() to generate the complete token list for the last
+  /// frame.  [It deals with the tokens in the map "cur_toks_", which at that
+  /// point only contains the tokens reached by emitting arcs from the
+  /// previous frame.]
+  /// If the map "token_orig_cost" isn't NULL, we build the map which will
+  /// be used to recover the "active_toks_[last_frame]" token list for the
+  /// last frame.
+  void ProcessNonemitting(std::unordered_map<Token*, BaseFloat> *token_orig_cost);
+
+  /// When GetRawLattice() is called during decoding, the
+  /// active_toks_[last_frame] is changed.  To keep the consistency of function
+  /// ProcessForFrame(), recover it.
+  /// Notice: as a new token will be added to the head of the TokenList,
+  /// tok->next will not be affected.
+  /// "token_orig_cost" is a mapping from token pointer to the tot_cost of the
+  /// token before propagating non-emitting arcs.  It is used to recover the
+  /// change of original tokens in the last frame and to remove the new tokens
+  /// which come from propagating non-emitting arcs, so that we can guarantee
+  /// the consistency of function ProcessForFrame().
+  void RecoverLastTokenList(
+      const std::unordered_map<Token*, BaseFloat> &token_orig_cost);
+
+
+  /// "prev_toks_" and "cur_toks_" allow us to maintain the tokens of the
+  /// current and the next frame; they are indexed by StateId.  (By contrast,
+  /// active_toks_ is indexed by frame-index plus one, where the frame-index
+  /// is zero-based, as used in the decodable object; that is, the emitting
+  /// probs of frame t are accounted for in the tokens at active_toks_[t+1],
+  /// and the zeroth entry is for the nonemitting transitions at the start of
+  /// the graph.)
+  StateIdToTokenMap prev_toks_;
+  StateIdToTokenMap cur_toks_;
+
+  /// Gets the weight cutoff.
+  /// Notice: In the traditional version, the histogram pruning method is
+  /// applied to a complete token list for one frame.  But in this version it
+  /// is applied to a token list which only contains the emitting part, so the
+  /// max_active and min_active values might be narrowed.
+  BaseFloat GetCutoff(const TokenList &token_list,
+                      BaseFloat *adaptive_beam,
+                      StateId *best_state_id, Token **best_token);
+
+  std::vector<TokenList> active_toks_;  // Lists of tokens, indexed by
+  // frame (members of TokenList are toks, must_prune_forward_links,
+  // must_prune_tokens).
+  std::queue<StateId> cur_queue_;  // temp variable used in ProcessForFrame
+                                   // and ProcessNonemitting
+  std::vector<BaseFloat> tmp_array_;  // used in GetCutoff.
+
+  // fst_ is a pointer to the FST we are decoding from.
+  const FST *fst_;
+  // delete_fst_ is true if the pointer fst_ needs to be deleted when this
+  // object is destroyed.
+ bool delete_fst_; + + std::vector cost_offsets_; // This contains, for each + // frame, an offset that was added to the acoustic log-likelihoods on that + // frame in order to keep everything in a nice dynamic range i.e. close to + // zero, to reduce roundoff errors. + // Notice: It will only be added to emitting arcs (i.e. cost_offsets_[t] is + // added to arcs from "frame t" to "frame t+1"). + LatticeFasterDecoderCombineConfig config_; + int32 num_toks_; // current total #toks allocated... + bool warned_; + + /// decoding_finalized_ is true if someone called FinalizeDecoding(). [note, + /// calling this is optional]. If true, it's forbidden to decode more. Also, + /// if this is set, then the output of ComputeFinalCosts() is in the next + /// three variables. The reason we need to do this is that after + /// FinalizeDecoding() calls PruneTokensForFrame() for the final frame, some + /// of the tokens on the last frame are freed, so we free the list from toks_ + /// to avoid having dangling pointers hanging around. + bool decoding_finalized_; + /// For the meaning of the next 3 variables, see the comment for + /// decoding_finalized_ above., and ComputeFinalCosts(). + unordered_map final_costs_; + BaseFloat final_relative_cost_; + BaseFloat final_best_cost_; + + // This function takes a singly linked list of tokens for a single frame, and + // outputs a list of them in topological order (it will crash if no such order + // can be found, which will typically be due to decoding graphs with epsilon + // cycles, which are not allowed). Note: the output list may contain NULLs, + // which the caller should pass over; it just happens to be more efficient for + // the algorithm to output a list that contains NULLs. + static void TopSortTokens(Token *tok_list, + std::vector *topsorted_list); + + void ClearActiveTokens(); + + KALDI_DISALLOW_COPY_AND_ASSIGN(LatticeFasterDecoderCombineTpl); +}; + +typedef LatticeFasterDecoderCombineTpl > LatticeFasterDecoderCombine; + + + +} // end namespace kaldi. + +#endif diff --git a/src/decoder/lattice-faster-decoder-combine-itermap.cc b/src/decoder/lattice-faster-decoder-combine-itermap.cc new file mode 100644 index 00000000000..6c9d70bb9b3 --- /dev/null +++ b/src/decoder/lattice-faster-decoder-combine-itermap.cc @@ -0,0 +1,1100 @@ +// decoder/lattice-faster-decoder-combine.cc + +// Copyright 2009-2012 Microsoft Corporation Mirko Hannemann +// 2013-2019 Johns Hopkins University (Author: Daniel Povey) +// 2014 Guoguo Chen +// 2018 Zhehuai Chen +// 2019 Hang Lyu + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#include "decoder/lattice-faster-decoder-combine.h" +#include "lat/lattice-functions.h" + +namespace kaldi { + +// instantiate this class once for each thing you have to decode. 
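+// (Of the two constructors below, the first borrows the FST, so the caller
+// keeps ownership; the second takes ownership and deletes the FST in the
+// destructor, via the delete_fst_ flag.)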
+template +LatticeFasterDecoderCombineTpl::LatticeFasterDecoderCombineTpl( + const FST &fst, + const LatticeFasterDecoderCombineConfig &config): + fst_(&fst), delete_fst_(false), config_(config), num_toks_(0) { + config.Check(); +} + + +template +LatticeFasterDecoderCombineTpl::LatticeFasterDecoderCombineTpl( + const LatticeFasterDecoderCombineConfig &config, FST *fst): + fst_(fst), delete_fst_(true), config_(config), num_toks_(0) { + config.Check(); + prev_toks_.reserve(1000); + cur_toks_.reserve(1000); +} + + +template +LatticeFasterDecoderCombineTpl::~LatticeFasterDecoderCombineTpl() { + ClearActiveTokens(); + if (delete_fst_) delete fst_; +} + +template +void LatticeFasterDecoderCombineTpl::InitDecoding() { + // clean up from last time: + prev_toks_.clear(); + cur_toks_.clear(); + cost_offsets_.clear(); + ClearActiveTokens(); + + warned_ = false; + num_toks_ = 0; + decoding_finalized_ = false; + final_costs_.clear(); + StateId start_state = fst_->Start(); + KALDI_ASSERT(start_state != fst::kNoStateId); + active_toks_.resize(1); + Token *start_tok = new Token(0.0, 0.0, NULL, NULL, NULL); + active_toks_[0].toks = start_tok; + cur_toks_[start_state] = start_tok; // initialize current tokens map + num_toks_++; +} + +// Returns true if any kind of traceback is available (not necessarily from +// a final state). It should only very rarely return false; this indicates +// an unusual search error. +template +bool LatticeFasterDecoderCombineTpl::Decode(DecodableInterface *decodable) { + InitDecoding(); + + // We use 1-based indexing for frames in this decoder (if you view it in + // terms of features), but note that the decodable object uses zero-based + // numbering, which we have to correct for when we call it. + + while (!decodable->IsLastFrame(NumFramesDecoded() - 1)) { + if (NumFramesDecoded() % config_.prune_interval == 0) + PruneActiveTokens(config_.lattice_beam * config_.prune_scale); + ProcessForFrame(decodable); + } + // A complete token list of the last frame will be generated in FinalizeDecoding() + FinalizeDecoding(); + + // Returns true if we have any kind of traceback available (not necessarily + // to the end state; query ReachedFinal() for that). + return !active_toks_.empty() && active_toks_.back().toks != NULL; +} + + +// Outputs an FST corresponding to the single best path through the lattice. +template +bool LatticeFasterDecoderCombineTpl::GetBestPath( + Lattice *olat, + bool use_final_probs) { + Lattice raw_lat; + GetRawLattice(&raw_lat, use_final_probs); + ShortestPath(raw_lat, olat); + return (olat->NumStates() != 0); +} + + +// Outputs an FST corresponding to the raw, state-level lattice +template +bool LatticeFasterDecoderCombineTpl::GetRawLattice( + Lattice *ofst, + bool use_final_probs) { + typedef LatticeArc Arc; + typedef Arc::StateId StateId; + typedef Arc::Weight Weight; + typedef Arc::Label Label; + // Note: you can't use the old interface (Decode()) if you want to + // get the lattice with use_final_probs = false. You'd have to do + // InitDecoding() and then AdvanceDecoding(). + if (decoding_finalized_ && !use_final_probs) + KALDI_ERR << "You cannot call FinalizeDecoding() and then call " + << "GetRawLattice() with use_final_probs == false"; + + std::unordered_map token_orig_cost; + if (!decoding_finalized_) { + // Process the non-emitting arcs for the unfinished last frame. + ProcessNonemitting(&token_orig_cost); + } + + + unordered_map final_costs_local; + + const unordered_map &final_costs = + (decoding_finalized_ ? 
final_costs_ : final_costs_local); + if (!decoding_finalized_ && use_final_probs) + ComputeFinalCosts(&final_costs_local, NULL, NULL); + + ofst->DeleteStates(); + // num-frames plus one (since frames are one-based, and we have + // an extra frame for the start-state). + int32 num_frames = active_toks_.size() - 1; + KALDI_ASSERT(num_frames > 0); + const int32 bucket_count = num_toks_/2 + 3; + unordered_map tok_map(bucket_count); + // First create all states. + std::vector token_list; + for (int32 f = 0; f <= num_frames; f++) { + if (active_toks_[f].toks == NULL) { + KALDI_WARN << "GetRawLattice: no tokens active on frame " << f + << ": not producing lattice.\n"; + return false; + } + TopSortTokens(active_toks_[f].toks, &token_list); + for (size_t i = 0; i < token_list.size(); i++) + if (token_list[i] != NULL) + tok_map[token_list[i]] = ofst->AddState(); + } + // The next statement sets the start state of the output FST. Because we + // topologically sorted the tokens, state zero must be the start-state. + ofst->SetStart(0); + + KALDI_VLOG(4) << "init:" << num_toks_/2 + 3 << " buckets:" + << tok_map.bucket_count() << " load:" << tok_map.load_factor() + << " max:" << tok_map.max_load_factor(); + // Now create all arcs. + for (int32 f = 0; f <= num_frames; f++) { + for (Token *tok = active_toks_[f].toks; tok != NULL; tok = tok->next) { + StateId cur_state = tok_map[tok]; + for (ForwardLinkT *l = tok->links; + l != NULL; + l = l->next) { + typename unordered_map::const_iterator + iter = tok_map.find(l->next_tok); + StateId nextstate = iter->second; + KALDI_ASSERT(iter != tok_map.end()); + BaseFloat cost_offset = 0.0; + if (l->ilabel != 0) { // emitting.. + KALDI_ASSERT(f >= 0 && f < cost_offsets_.size()); + cost_offset = cost_offsets_[f]; + } + Arc arc(l->ilabel, l->olabel, + Weight(l->graph_cost, l->acoustic_cost - cost_offset), + nextstate); + ofst->AddArc(cur_state, arc); + } + if (f == num_frames) { + if (use_final_probs && !final_costs.empty()) { + typename unordered_map::const_iterator + iter = final_costs.find(tok); + if (iter != final_costs.end()) + ofst->SetFinal(cur_state, LatticeWeight(iter->second, 0)); + } else { + ofst->SetFinal(cur_state, LatticeWeight::One()); + } + } + } + } + + if (!decoding_finalized_) { // recover last token list + RecoverLastTokenList(token_orig_cost); + } + return (ofst->NumStates() > 0); +} + + +// When GetRawLattice() is called during decoding, the +// active_toks_[last_frame] is changed. To keep the consistency of function +// ProcessForFrame(), recover it. +// Notice: as new token will be added to the head of TokenList, tok->next +// will not be affacted. +template +void LatticeFasterDecoderCombineTpl::RecoverLastTokenList( + const std::unordered_map &token_orig_cost) { + if (!token_orig_cost.empty()) { + for (Token* tok = active_toks_[active_toks_.size() - 1].toks; + tok != NULL;) { + if (token_orig_cost.find(tok) != token_orig_cost.end()) { + DeleteForwardLinks(tok); + tok->tot_cost = token_orig_cost.find(tok)->second; + tok->in_current_queue = false; + tok = tok->next; + } else { + DeleteForwardLinks(tok); + Token *next_tok = tok->next; + delete tok; + num_toks_--; + tok = next_tok; + } + } + } +} + +// This function is now deprecated, since now we do determinization from outside +// the LatticeFasterDecoder class. Outputs an FST corresponding to the +// lattice-determinized lattice (one path per word sequence). 
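+// A sketch of the now-preferred external flow (the transition model and the
+// surrounding variable names here are the caller's, not defined in this
+// file):
+//   Lattice raw_lat;
+//   decoder.GetRawLattice(&raw_lat);
+//   CompactLattice det_lat;
+//   DeterminizeLatticePhonePrunedWrapper(trans_model, &raw_lat,
+//                                        config.lattice_beam, &det_lat,
+//                                        config.det_opts);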
+template +bool LatticeFasterDecoderCombineTpl::GetLattice( + CompactLattice *ofst, + bool use_final_probs) { + Lattice raw_fst; + GetRawLattice(&raw_fst, use_final_probs); + Invert(&raw_fst); // make it so word labels are on the input. + // (in phase where we get backward-costs). + fst::ILabelCompare ilabel_comp; + ArcSort(&raw_fst, ilabel_comp); // sort on ilabel; makes + // lattice-determinization more efficient. + + fst::DeterminizeLatticePrunedOptions lat_opts; + lat_opts.max_mem = config_.det_opts.max_mem; + + DeterminizeLatticePruned(raw_fst, config_.lattice_beam, ofst, lat_opts); + raw_fst.DeleteStates(); // Free memory-- raw_fst no longer needed. + Connect(ofst); // Remove unreachable states... there might be + // a small number of these, in some cases. + // Note: if something went wrong and the raw lattice was empty, + // we should still get to this point in the code without warnings or failures. + return (ofst->NumStates() != 0); +} + +/* + A note on the definition of extra_cost. + + extra_cost is used in pruning tokens, to save memory. + + Define the 'forward cost' of a token as zero for any token on the frame + we're currently decoding; and for other frames, as the shortest-path cost + between that token and a token on the frame we're currently decoding. + (by "currently decoding" I mean the most recently processed frame). + + Then define the extra_cost of a token (always >= 0) as the forward-cost of + the token minus the smallest forward-cost of any token on the same frame. + + We can use the extra_cost to accurately prune away tokens that we know will + never appear in the lattice. If the extra_cost is greater than the desired + lattice beam, the token would provably never appear in the lattice, so we can + prune away the token. + + The advantage of storing the extra_cost rather than the forward-cost, is that + it is less costly to keep the extra_cost up-to-date when we process new frames. + When we process a new frame, *all* the previous frames' forward-costs would change; + but in general the extra_cost will change only for a finite number of frames. + (Actually we don't update all the extra_costs every time we update a frame; we + only do it every 'config_.prune_interval' frames). + */ + +// FindOrAddToken either locates a token in hash map "token_map" +// or if necessary inserts a new, empty token (i.e. with no forward links) +// for the current frame. [note: it's inserted if necessary into hash toks_ +// and also into the singly linked list of tokens active on this frame +// (whose head is at active_toks_[frame]). +template +inline Token* LatticeFasterDecoderCombineTpl::FindOrAddToken( + StateId state, int32 token_list_index, BaseFloat tot_cost, + Token *backpointer, StateIdToTokenMap *token_map, bool *changed) { + // Returns the Token pointer. Sets "changed" (if non-NULL) to true + // if the token was newly created or the cost changed. + KALDI_ASSERT(token_list_index < active_toks_.size()); + Token *&toks = active_toks_[token_list_index].toks; + typename StateIdToTokenMap::iterator e_found = token_map->find(state); + if (e_found == token_map->end()) { // no such token presently. + const BaseFloat extra_cost = 0.0; + // tokens on the currently final frame have zero extra_cost + // as any of them could end up + // on the winning path. 
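+    // (The extra_cost of 0.0 set here is provisional; it only becomes
+    // meaningful after PruneForwardLinks has processed this frame.)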
+ Token *new_tok = new Token (tot_cost, extra_cost, NULL, toks, backpointer); + // NULL: no forward links yet + toks = new_tok; + num_toks_++; + // insert into the map + (*token_map)[state] = new_tok; + if (changed) *changed = true; + return new_tok; + } else { + Token *tok = e_found->second; // There is an existing Token for this state. + if (tok->tot_cost > tot_cost) { // replace old token + tok->tot_cost = tot_cost; + // SetBackpointer() just does tok->backpointer = backpointer in + // the case where Token == BackpointerToken, else nothing. + tok->SetBackpointer(backpointer); + // we don't allocate a new token, the old stays linked in active_toks_ + // we only replace the tot_cost + // in the current frame, there are no forward links (and no extra_cost) + // only in ProcessNonemitting we have to delete forward links + // in case we visit a state for the second time + // those forward links, that lead to this replaced token before: + // they remain and will hopefully be pruned later (PruneForwardLinks...) + if (changed) *changed = true; + } else { + if (changed) *changed = false; + } + return tok; + } +} + +// prunes outgoing links for all tokens in active_toks_[frame] +// it's called by PruneActiveTokens +// all links, that have link_extra_cost > lattice_beam are pruned +template +void LatticeFasterDecoderCombineTpl::PruneForwardLinks( + int32 frame_plus_one, bool *extra_costs_changed, + bool *links_pruned, BaseFloat delta) { + // delta is the amount by which the extra_costs must change + // If delta is larger, we'll tend to go back less far + // toward the beginning of the file. + // extra_costs_changed is set to true if extra_cost was changed for any token + // links_pruned is set to true if any link in any token was pruned + + *extra_costs_changed = false; + *links_pruned = false; + KALDI_ASSERT(frame_plus_one >= 0 && frame_plus_one < active_toks_.size()); + if (active_toks_[frame_plus_one].toks == NULL) { // empty list; should not happen. + if (!warned_) { + KALDI_WARN << "No tokens alive [doing pruning].. warning first " + "time only for each utterance\n"; + warned_ = true; + } + } + + // We have to iterate until there is no more change, because the links + // are not guaranteed to be in topological order. + bool changed = true; // difference new minus old extra cost >= delta ? + while (changed) { + changed = false; + for (Token *tok = active_toks_[frame_plus_one].toks; + tok != NULL; tok = tok->next) { + ForwardLinkT *link, *prev_link = NULL; + // will recompute tok_extra_cost for tok. + BaseFloat tok_extra_cost = std::numeric_limits::infinity(); + // tok_extra_cost is the best (min) of link_extra_cost of outgoing links + for (link = tok->links; link != NULL; ) { + // See if we need to excise this link... + Token *next_tok = link->next_tok; + BaseFloat link_extra_cost = next_tok->extra_cost + + ((tok->tot_cost + link->acoustic_cost + link->graph_cost) + - next_tok->tot_cost); // difference in brackets is >= 0 + // link_exta_cost is the difference in score between the best paths + // through link source state and through link destination state + KALDI_ASSERT(link_extra_cost == link_extra_cost); // check for NaN + if (link_extra_cost > config_.lattice_beam) { // excise link + ForwardLinkT *next_link = link->next; + if (prev_link != NULL) prev_link->next = next_link; + else tok->links = next_link; + delete link; + link = next_link; // advance link but leave prev_link the same. + *links_pruned = true; + } else { // keep the link and update the tok_extra_cost if needed. 
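+          // E.g. if next_tok->extra_cost is 2.0 and the path through this
+          // link costs 5.0 more than the best path into next_tok, then
+          // link_extra_cost = 2.0 + 5.0 = 7.0, which with lattice_beam = 8.0
+          // is why the link was kept rather than excised above.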
+ if (link_extra_cost < 0.0) { // this is just a precaution. + if (link_extra_cost < -0.01) + KALDI_WARN << "Negative extra_cost: " << link_extra_cost; + link_extra_cost = 0.0; + } + if (link_extra_cost < tok_extra_cost) + tok_extra_cost = link_extra_cost; + prev_link = link; // move to next link + link = link->next; + } + } // for all outgoing links + if (fabs(tok_extra_cost - tok->extra_cost) > delta) + changed = true; // difference new minus old is bigger than delta + tok->extra_cost = tok_extra_cost; + // will be +infinity or <= lattice_beam_. + // infinity indicates, that no forward link survived pruning + } // for all Token on active_toks_[frame] + if (changed) *extra_costs_changed = true; + + // Note: it's theoretically possible that aggressive compiler + // optimizations could cause an infinite loop here for small delta and + // high-dynamic-range scores. + } // while changed +} + +// PruneForwardLinksFinal is a version of PruneForwardLinks that we call +// on the final frame. If there are final tokens active, it uses +// the final-probs for pruning, otherwise it treats all tokens as final. +template +void LatticeFasterDecoderCombineTpl::PruneForwardLinksFinal() { + KALDI_ASSERT(!active_toks_.empty()); + int32 frame_plus_one = active_toks_.size() - 1; + + if (active_toks_[frame_plus_one].toks == NULL) // empty list; should not happen. + KALDI_WARN << "No tokens alive at end of file"; + + typedef typename unordered_map::const_iterator IterType; + ComputeFinalCosts(&final_costs_, &final_relative_cost_, &final_best_cost_); + decoding_finalized_ = true; + + // Now go through tokens on this frame, pruning forward links... may have to + // iterate a few times until there is no more change, because the list is not + // in topological order. This is a modified version of the code in + // PruneForwardLinks, but here we also take account of the final-probs. + bool changed = true; + BaseFloat delta = 1.0e-05; + while (changed) { + changed = false; + for (Token *tok = active_toks_[frame_plus_one].toks; + tok != NULL; tok = tok->next) { + ForwardLinkT *link, *prev_link = NULL; + // will recompute tok_extra_cost. It has a term in it that corresponds + // to the "final-prob", so instead of initializing tok_extra_cost to infinity + // below we set it to the difference between the (score+final_prob) of this token, + // and the best such (score+final_prob). + BaseFloat final_cost; + if (final_costs_.empty()) { + final_cost = 0.0; + } else { + IterType iter = final_costs_.find(tok); + if (iter != final_costs_.end()) + final_cost = iter->second; + else + final_cost = std::numeric_limits::infinity(); + } + BaseFloat tok_extra_cost = tok->tot_cost + final_cost - final_best_cost_; + // tok_extra_cost will be a "min" over either directly being final, or + // being indirectly final through other links, and the loop below may + // decrease its value: + for (link = tok->links; link != NULL; ) { + // See if we need to excise this link... + Token *next_tok = link->next_tok; + BaseFloat link_extra_cost = next_tok->extra_cost + + ((tok->tot_cost + link->acoustic_cost + link->graph_cost) + - next_tok->tot_cost); + if (link_extra_cost > config_.lattice_beam) { // excise link + ForwardLinkT *next_link = link->next; + if (prev_link != NULL) prev_link->next = next_link; + else tok->links = next_link; + delete link; + link = next_link; // advance link but leave prev_link the same. + } else { // keep the link and update the tok_extra_cost if needed. + if (link_extra_cost < 0.0) { // this is just a precaution. 
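+            // (Small negative values here presumably come from floating-point
+            // roundoff; anything below -0.01 is suspicious enough to warn.)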
+ if (link_extra_cost < -0.01) + KALDI_WARN << "Negative extra_cost: " << link_extra_cost; + link_extra_cost = 0.0; + } + if (link_extra_cost < tok_extra_cost) + tok_extra_cost = link_extra_cost; + prev_link = link; + link = link->next; + } + } + // prune away tokens worse than lattice_beam above best path. This step + // was not necessary in the non-final case because then, this case + // showed up as having no forward links. Here, the tok_extra_cost has + // an extra component relating to the final-prob. + if (tok_extra_cost > config_.lattice_beam) + tok_extra_cost = std::numeric_limits::infinity(); + // to be pruned in PruneTokensForFrame + + if (!ApproxEqual(tok->extra_cost, tok_extra_cost, delta)) + changed = true; + tok->extra_cost = tok_extra_cost; // will be +infinity or <= lattice_beam_. + } + } // while changed +} + +template +BaseFloat LatticeFasterDecoderCombineTpl::FinalRelativeCost() const { + if (!decoding_finalized_) { + BaseFloat relative_cost; + ComputeFinalCosts(NULL, &relative_cost, NULL); + return relative_cost; + } else { + // we're not allowed to call that function if FinalizeDecoding() has + // been called; return a cached value. + return final_relative_cost_; + } +} + + +// Prune away any tokens on this frame that have no forward links. +// [we don't do this in PruneForwardLinks because it would give us +// a problem with dangling pointers]. +// It's called by PruneActiveTokens if any forward links have been pruned +template +void LatticeFasterDecoderCombineTpl::PruneTokensForFrame( + int32 frame_plus_one) { + KALDI_ASSERT(frame_plus_one >= 0 && frame_plus_one < active_toks_.size()); + Token *&toks = active_toks_[frame_plus_one].toks; + if (toks == NULL) + KALDI_WARN << "No tokens alive [doing pruning]"; + Token *tok, *next_tok, *prev_tok = NULL; + for (tok = toks; tok != NULL; tok = next_tok) { + next_tok = tok->next; + if (tok->extra_cost == std::numeric_limits::infinity()) { + // token is unreachable from end of graph; (no forward links survived) + // excise tok from list and delete tok. + if (prev_tok != NULL) prev_tok->next = tok->next; + else toks = tok->next; + delete tok; + num_toks_--; + } else { // fetch next Token + prev_tok = tok; + } + } +} + +// Go backwards through still-alive tokens, pruning them, starting not from +// the current frame (where we want to keep all tokens) but from the frame before +// that. We go backwards through the frames and stop when we reach a point +// where the delta-costs are not changing (and the delta controls when we consider +// a cost to have "not changed"). +template +void LatticeFasterDecoderCombineTpl::PruneActiveTokens( + BaseFloat delta) { + int32 cur_frame_plus_one = NumFramesDecoded(); + int32 num_toks_begin = num_toks_; + // The index "f" below represents a "frame plus one", i.e. you'd have to subtract + // one to get the corresponding index for the decodable object. + for (int32 f = cur_frame_plus_one - 1; f >= 0; f--) { + // Reason why we need to prune forward links in this situation: + // (1) we have never pruned them (new TokenList) + // (2) we have not yet pruned the forward links to the next f, + // after any of those tokens have changed their extra_cost. 
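+    // The two flags checked below implement this: must_prune_forward_links
+    // marks frames whose links may still be affected by extra_cost changes
+    // downstream, and must_prune_tokens marks frames that may now hold
+    // tokens with no surviving forward links.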
+ if (active_toks_[f].must_prune_forward_links) { + bool extra_costs_changed = false, links_pruned = false; + PruneForwardLinks(f, &extra_costs_changed, &links_pruned, delta); + if (extra_costs_changed && f > 0) // any token has changed extra_cost + active_toks_[f-1].must_prune_forward_links = true; + if (links_pruned) // any link was pruned + active_toks_[f].must_prune_tokens = true; + active_toks_[f].must_prune_forward_links = false; // job done + } + if (f+1 < cur_frame_plus_one && // except for last f (no forward links) + active_toks_[f+1].must_prune_tokens) { + PruneTokensForFrame(f+1); + active_toks_[f+1].must_prune_tokens = false; + } + } + KALDI_VLOG(4) << "PruneActiveTokens: pruned tokens from " << num_toks_begin + << " to " << num_toks_; +} + +template +void LatticeFasterDecoderCombineTpl::ComputeFinalCosts( + unordered_map *final_costs, + BaseFloat *final_relative_cost, + BaseFloat *final_best_cost) const { + KALDI_ASSERT(!decoding_finalized_); + if (final_costs != NULL) + final_costs->clear(); + BaseFloat infinity = std::numeric_limits::infinity(); + BaseFloat best_cost = infinity, + best_cost_with_final = infinity; + + // The final tokens are recorded in unordered_map "cur_toks_". + for (IterType iter = cur_toks_.begin(); iter != cur_toks_.end(); iter++) { + StateId state = iter->first; + Token *tok = iter->second; + BaseFloat final_cost = fst_->Final(state).Value(); + BaseFloat cost = tok->tot_cost, + cost_with_final = cost + final_cost; + best_cost = std::min(cost, best_cost); + best_cost_with_final = std::min(cost_with_final, best_cost_with_final); + if (final_costs != NULL && final_cost != infinity) + (*final_costs)[tok] = final_cost; + } + if (final_relative_cost != NULL) { + if (best_cost == infinity && best_cost_with_final == infinity) { + // Likely this will only happen if there are no tokens surviving. + // This seems the least bad way to handle it. + *final_relative_cost = infinity; + } else { + *final_relative_cost = best_cost_with_final - best_cost; + } + } + if (final_best_cost != NULL) { + if (best_cost_with_final != infinity) { // final-state exists. + *final_best_cost = best_cost_with_final; + } else { // no final-state exists. + *final_best_cost = best_cost; + } + } +} + +template +void LatticeFasterDecoderCombineTpl::AdvanceDecoding( + DecodableInterface *decodable, + int32 max_num_frames) { + if (std::is_same >::value) { + // if the type 'FST' is the FST base-class, then see if the FST type of fst_ + // is actually VectorFst or ConstFst. If so, call the AdvanceDecoding() + // function after casting *this to the more specific type. + if (fst_->Type() == "const") { + LatticeFasterDecoderCombineTpl, Token> *this_cast = + reinterpret_cast, Token>* >(this); + this_cast->AdvanceDecoding(decodable, max_num_frames); + return; + } else if (fst_->Type() == "vector") { + LatticeFasterDecoderCombineTpl, Token> *this_cast = + reinterpret_cast, Token>* >(this); + this_cast->AdvanceDecoding(decodable, max_num_frames); + return; + } + } + + + KALDI_ASSERT(!active_toks_.empty() && !decoding_finalized_ && + "You must call InitDecoding() before AdvanceDecoding"); + int32 num_frames_ready = decodable->NumFramesReady(); + // num_frames_ready must be >= num_frames_decoded, or else + // the number of frames ready must have decreased (which doesn't + // make sense) or the decodable object changed between calls + // (which isn't allowed). 
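+  // A typical online pattern (sketch; "decodable" stands for any
+  // DecodableInterface that gains frames incrementally):
+  //   decoder.InitDecoding();
+  //   while (more audio arrives) {
+  //     ...append features to the decodable object...
+  //     decoder.AdvanceDecoding(&decodable);
+  //   }
+  //   decoder.FinalizeDecoding();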
+ KALDI_ASSERT(num_frames_ready >= NumFramesDecoded()); + int32 target_frames_decoded = num_frames_ready; + if (max_num_frames >= 0) + target_frames_decoded = std::min(target_frames_decoded, + NumFramesDecoded() + max_num_frames); + while (NumFramesDecoded() < target_frames_decoded) { + if (NumFramesDecoded() % config_.prune_interval == 0) { + PruneActiveTokens(config_.lattice_beam * config_.prune_scale); + } + ProcessForFrame(decodable); + } +} + +// FinalizeDecoding() is a version of PruneActiveTokens that we call +// (optionally) on the final frame. Takes into account the final-prob of +// tokens. This function used to be called PruneActiveTokensFinal(). +template +void LatticeFasterDecoderCombineTpl::FinalizeDecoding() { + ProcessNonemitting(NULL); + int32 final_frame_plus_one = NumFramesDecoded(); + int32 num_toks_begin = num_toks_; + // PruneForwardLinksFinal() prunes final frame (with final-probs), and + // sets decoding_finalized_. + PruneForwardLinksFinal(); + for (int32 f = final_frame_plus_one - 1; f >= 0; f--) { + bool b1, b2; // values not used. + BaseFloat dontcare = 0.0; // delta of zero means we must always update + PruneForwardLinks(f, &b1, &b2, dontcare); + PruneTokensForFrame(f + 1); + } + PruneTokensForFrame(0); + KALDI_VLOG(4) << "pruned tokens from " << num_toks_begin + << " to " << num_toks_; +} + +/// Gets the weight cutoff. +template +BaseFloat LatticeFasterDecoderCombineTpl::GetCutoff( + const StateIdToTokenMap &toks, BaseFloat *adaptive_beam, + StateId *best_state_id, Token **best_token) { + // positive == high cost == bad. + // best_weight is the minimum value. + BaseFloat best_weight = std::numeric_limits::infinity(); + if (config_.max_active == std::numeric_limits::max() && + config_.min_active == 0) { + for (IterType iter = toks.begin(); iter != toks.end(); iter++) { + BaseFloat w = static_cast(iter->second->tot_cost); + if (w < best_weight) { + best_weight = w; + if (best_token) { + *best_state_id = iter->first; + *best_token = iter->second; + } + } + } + if (adaptive_beam != NULL) *adaptive_beam = config_.beam; + return best_weight + config_.beam; + } else { + tmp_array_.clear(); + for (IterType iter = toks.begin(); iter != toks.end(); iter++) { + BaseFloat w = static_cast(iter->second->tot_cost); + tmp_array_.push_back(w); + if (w < best_weight) { + best_weight = w; + if (best_token) { + *best_state_id = iter->first; + *best_token = iter->second; + } + } + } + + BaseFloat beam_cutoff = best_weight + config_.beam, + min_active_cutoff = std::numeric_limits::infinity(), + max_active_cutoff = std::numeric_limits::infinity(); + + KALDI_VLOG(6) << "Number of emitting tokens on frame " + << NumFramesDecoded() - 1 << " is " << tmp_array_.size(); + + if (tmp_array_.size() > static_cast(config_.max_active)) { + std::nth_element(tmp_array_.begin(), + tmp_array_.begin() + config_.max_active, + tmp_array_.end()); + max_active_cutoff = tmp_array_[config_.max_active]; + } + if (max_active_cutoff < beam_cutoff) { // max_active is tighter than beam. + if (adaptive_beam) + *adaptive_beam = max_active_cutoff - best_weight + config_.beam_delta; + return max_active_cutoff; + } + if (tmp_array_.size() > static_cast(config_.min_active)) { + if (config_.min_active == 0) min_active_cutoff = best_weight; + else { + std::nth_element(tmp_array_.begin(), + tmp_array_.begin() + config_.min_active, + tmp_array_.size() > static_cast(config_.max_active) ? 
+                         tmp_array_.begin() + config_.max_active : tmp_array_.end());
+        min_active_cutoff = tmp_array_[config_.min_active];
+      }
+    }
+    if (min_active_cutoff > beam_cutoff) { // min_active is looser than beam.
+      if (adaptive_beam)
+        *adaptive_beam = min_active_cutoff - best_weight + config_.beam_delta;
+      return min_active_cutoff;
+    } else {
+      *adaptive_beam = config_.beam;
+      return beam_cutoff;
+    }
+  }
+}
+
+template <typename FST, typename Token>
+void LatticeFasterDecoderCombineTpl<FST, Token>::ProcessForFrame(
+    DecodableInterface *decodable) {
+  KALDI_ASSERT(active_toks_.size() > 0);
+  int32 frame = active_toks_.size() - 1; // frame is the frame-index
+                                         // (zero-based) used to get likelihoods
+                                         // from the decodable object.
+  active_toks_.resize(active_toks_.size() + 1);
+
+  prev_toks_.swap(cur_toks_);
+  cur_toks_.clear();
+  if (prev_toks_.empty()) {
+    if (!warned_) {
+      KALDI_WARN << "Error, no surviving tokens on frame " << frame;
+      warned_ = true;
+    }
+  }
+
+  BaseFloat adaptive_beam;
+  Token *best_tok = NULL;
+  StateId best_tok_state_id;
+  // "cur_cutoff" is used to constrain the epsilon (non-emitting) propagation
+  // in the current frame.  It will not be updated.
+  BaseFloat cur_cutoff = GetCutoff(prev_toks_, &adaptive_beam,
+                                   &best_tok_state_id, &best_tok);
+  KALDI_VLOG(6) << "Adaptive beam on frame " << NumFramesDecoded() << " is "
+                << adaptive_beam;
+
+  // "next_cutoff" is used to decide whether a new token on the next frame
+  // should be handled or not.  It will be updated as further tokens are
+  // processed, which amounts to pruning "online" before having seen all
+  // tokens.
+  BaseFloat next_cutoff = std::numeric_limits<BaseFloat>::infinity();
+  // "cost_offset" contains the acoustic log-likelihood offset on the current
+  // frame, in order to keep everything in a nice dynamic range and reduce
+  // roundoff errors.
+  BaseFloat cost_offset = 0.0;
+
+  // First process the best token to get a hopefully
+  // reasonably tight bound on the next cutoff.  The only
+  // products of the next block are "next_cutoff" and "cost_offset".
+  // Notice: unlike in the traditional version, in this combined version
+  // "best_tok" is chosen from the tokens reached by emitting arcs.  Normally
+  // the best token of a frame comes from an epsilon (non-emitting) arc, so
+  // this best token only gives a looser bound.  We use it to estimate the
+  // next cutoff, and "next_cutoff" will be tightened once we have better
+  // tokens in the further processing.
+  if (best_tok) {
+    cost_offset = - best_tok->tot_cost;
+    for (fst::ArcIterator<FST> aiter(*fst_, best_tok_state_id);
+         !aiter.Done();
+         aiter.Next()) {
+      const Arc &arc = aiter.Value();
+      if (arc.ilabel != 0) {  // propagate..
+        // ac_cost + graph_cost
+        BaseFloat new_weight = arc.weight.Value() + cost_offset -
+            decodable->LogLikelihood(frame, arc.ilabel) + best_tok->tot_cost;
+        if (new_weight + adaptive_beam < next_cutoff)
+          next_cutoff = new_weight + adaptive_beam;
+      }
+    }
+  }
+
+  // Store the offset on the acoustic likelihoods that we're applying.
+  // Could just do cost_offsets_.push_back(cost_offset), but we
+  // do it this way as it's more robust to future code changes.
+  cost_offsets_.resize(frame + 1, 0.0);
+  cost_offsets_[frame] = cost_offset;
+
+  // Build a queue which contains the tokens from the previous frame that were
+  // reached by emitting arcs.
+  for (IterType iter = prev_toks_.begin(); iter != prev_toks_.end(); iter++) {
+    cur_queue_.push(iter->first);
+    iter->second->in_current_queue = true;
+  }
+
+  // Iterate over "cur_queue_" to process both the non-emitting and the
+  // emitting arcs in the FST.
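+  // A state can enter the queue more than once: if an epsilon arc later
+  // reaches an already-processed token with a lower cost, FindOrAddToken
+  // reports a change and the token is re-queued, which is why its old
+  // forward links are deleted and regenerated on each visit.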
+ while (!cur_queue_.empty()) { + StateId state = cur_queue_.front(); + cur_queue_.pop(); + + KALDI_ASSERT(prev_toks_.find(state) != prev_toks_.end()); + Token *tok = prev_toks_[state]; + + BaseFloat cur_cost = tok->tot_cost; + tok->in_current_queue = false; // out of queue + if (cur_cost > cur_cutoff) // Don't bother processing successors. + continue; + // If "tok" has any existing forward links, delete them, + // because we're about to regenerate them. This is a kind + // of non-optimality (remember, this is the simple decoder), + DeleteForwardLinks(tok); // necessary when re-visiting + tok->links = NULL; + for (fst::ArcIterator aiter(*fst_, state); + !aiter.Done(); + aiter.Next()) { + const Arc &arc = aiter.Value(); + bool changed; + if (arc.ilabel == 0) { // propagate nonemitting + BaseFloat graph_cost = arc.weight.Value(); + BaseFloat tot_cost = cur_cost + graph_cost; + if (tot_cost < cur_cutoff) { + Token *new_tok = FindOrAddToken(arc.nextstate, frame, tot_cost, + tok, &prev_toks_, &changed); + + // Add ForwardLink from tok to new_tok. Put it on the head of + // tok->link list + tok->links = new ForwardLinkT(new_tok, 0, arc.olabel, + graph_cost, 0, tok->links); + + // "changed" tells us whether the new token has a different + // cost from before, or is new. + if (changed && !new_tok->in_current_queue) { + cur_queue_.push(arc.nextstate); + new_tok->in_current_queue = true; + } + } + } else { // propagate emitting + BaseFloat graph_cost = arc.weight.Value(), + ac_cost = cost_offset - decodable->LogLikelihood(frame, arc.ilabel), + cur_cost = tok->tot_cost, + tot_cost = cur_cost + ac_cost + graph_cost; + if (tot_cost > next_cutoff) continue; + else if (tot_cost + adaptive_beam < next_cutoff) + next_cutoff = tot_cost + adaptive_beam; // a tighter boundary for emitting + + // no change flag is needed + Token *next_tok = FindOrAddToken(arc.nextstate, frame + 1, tot_cost, + tok, &cur_toks_, NULL); + // Add ForwardLink from tok to next_tok. Put it on the head of tok->link + // list + tok->links = new ForwardLinkT(next_tok, arc.ilabel, arc.olabel, + graph_cost, ac_cost, tok->links); + } + } // for all arcs + } // end of while loop + KALDI_VLOG(6) << "Number of tokens active on frame " << NumFramesDecoded() - 1 + << " is " << prev_toks_.size(); +} + + +template +void LatticeFasterDecoderCombineTpl::ProcessNonemitting( + std::unordered_map *token_orig_cost) { + if (token_orig_cost) { // Build the elements which are used to recover + for (IterType iter = cur_toks_.begin(); iter != cur_toks_.end(); iter++) { + (*token_orig_cost)[iter->second] = iter->second->tot_cost; + } + } + + StateIdToTokenMap *tmp_toks; + if (token_orig_cost) { // "token_orig_cost" isn't NULL. It means we need to + // recover active_toks_[last_frame] and "cur_toks_" + // will be used in the future. + tmp_toks = new StateIdToTokenMap(cur_toks_); + } else { + tmp_toks = &cur_toks_; + } + + int32 frame = active_toks_.size() - 1; + // Build the queue to process non-emitting arcs. + for (IterType iter = tmp_toks->begin(); iter != tmp_toks->end(); iter++) { + if (fst_->NumInputEpsilons(iter->first) != 0) { + cur_queue_.push(iter->first); + iter->second->in_current_queue = true; + } + } + + // "cur_cutoff" is used to constrain the epsilon emittion in current frame. + // It will not be updated. 
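The GetCutoff() call below combines a plain beam cutoff with a histogram (max-active) cutoff. A small sketch of how the two interact, assuming a non-empty cost list (ToyCutoff is a hypothetical name; this mirrors the idea, not the decoder's API):

#include <algorithm>
#include <vector>

// Take best + beam, then tighten it so at most max_active costs survive.
float ToyCutoff(std::vector<float> costs, float beam, size_t max_active) {
  float best = *std::min_element(costs.begin(), costs.end());
  float cutoff = best + beam;
  if (costs.size() > max_active) {
    // nth_element places the (max_active)-th smallest cost at that index.
    std::nth_element(costs.begin(), costs.begin() + max_active, costs.end());
    cutoff = std::min(cutoff, costs[max_active]);
  }
  return cutoff;
}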
+ BaseFloat adaptive_beam; + BaseFloat cur_cutoff = GetCutoff(*tmp_toks, &adaptive_beam, NULL, NULL); + + while (!cur_queue_.empty()) { + StateId state = cur_queue_.front(); + cur_queue_.pop(); + + KALDI_ASSERT(tmp_toks->find(state) != tmp_toks->end()); + Token *tok = (*tmp_toks)[state]; + BaseFloat cur_cost = tok->tot_cost; + if (cur_cost > cur_cutoff) // Don't bother processing successors. + continue; + // If "tok" has any existing forward links, delete them, + // because we're about to regenerate them. This is a kind + // of non-optimality (remember, this is the simple decoder), + DeleteForwardLinks(tok); // necessary when re-visiting + tok->links = NULL; + for (fst::ArcIterator aiter(*fst_, state); + !aiter.Done(); + aiter.Next()) { + const Arc &arc = aiter.Value(); + bool changed; + if (arc.ilabel == 0) { // propagate nonemitting + BaseFloat graph_cost = arc.weight.Value(); + BaseFloat tot_cost = cur_cost + graph_cost; + if (tot_cost < cur_cutoff) { + Token *new_tok = FindOrAddToken(arc.nextstate, frame, tot_cost, + tok, tmp_toks, &changed); + + // Add ForwardLink from tok to new_tok. Put it on the head of + // tok->link list + tok->links = new ForwardLinkT(new_tok, 0, arc.olabel, + graph_cost, 0, tok->links); + + // "changed" tells us whether the new token has a different + // cost from before, or is new. + if (changed && !new_tok->in_current_queue) { + cur_queue_.push(arc.nextstate); + new_tok->in_current_queue = true; + } + } + } + } // end of for loop + tok->in_current_queue = false; + } // end of while loop + if (token_orig_cost) delete tmp_toks; +} + + + +// static inline +template +void LatticeFasterDecoderCombineTpl::DeleteForwardLinks(Token *tok) { + ForwardLinkT *l = tok->links, *m; + while (l != NULL) { + m = l->next; + delete l; + l = m; + } + tok->links = NULL; +} + + +template +void LatticeFasterDecoderCombineTpl::ClearActiveTokens() { + // a cleanup routine, at utt end/begin + for (size_t i = 0; i < active_toks_.size(); i++) { + // Delete all tokens alive on this frame, and any forward + // links they may have. + for (Token *tok = active_toks_[i].toks; tok != NULL; ) { + DeleteForwardLinks(tok); + Token *next_tok = tok->next; + delete tok; + num_toks_--; + tok = next_tok; + } + } + active_toks_.clear(); + KALDI_ASSERT(num_toks_ == 0); +} + +// static +template +void LatticeFasterDecoderCombineTpl::TopSortTokens( + Token *tok_list, std::vector *topsorted_list) { + unordered_map token2pos; + typedef typename unordered_map::iterator IterType; + int32 num_toks = 0; + for (Token *tok = tok_list; tok != NULL; tok = tok->next) + num_toks++; + int32 cur_pos = 0; + // We assign the tokens numbers num_toks - 1, ... , 2, 1, 0. + // This is likely to be in closer to topological order than + // if we had given them ascending order, because of the way + // new tokens are put at the front of the list. + for (Token *tok = tok_list; tok != NULL; tok = tok->next) + token2pos[tok] = num_toks - ++cur_pos; + + unordered_set reprocess; + + for (IterType iter = token2pos.begin(); iter != token2pos.end(); ++iter) { + Token *tok = iter->first; + int32 pos = iter->second; + for (ForwardLinkT *link = tok->links; link != NULL; link = link->next) { + if (link->ilabel == 0) { + // We only need to consider epsilon links, since non-epsilon links + // transition between frames and this function only needs to sort a list + // of tokens from a single frame. 
+          IterType following_iter = token2pos.find(link->next_tok);
+          if (following_iter != token2pos.end()) { // another token on this frame,
+                                                   // so must consider it.
+            int32 next_pos = following_iter->second;
+            if (next_pos < pos) { // reassign the position of the next Token.
+              following_iter->second = cur_pos++;
+              reprocess.insert(link->next_tok);
+            }
+          }
+        }
+      }
+    }
+    // In case we had previously assigned this token to be reprocessed, we can
+    // erase it from that set because it's "happy now" (we just processed it).
+    reprocess.erase(tok);
+  }
+
+  size_t max_loop = 1000000, loop_count; // max_loop is to detect epsilon cycles.
+  for (loop_count = 0;
+       !reprocess.empty() && loop_count < max_loop; ++loop_count) {
+    std::vector<Token*> reprocess_vec;
+    for (typename unordered_set<Token*>::iterator iter = reprocess.begin();
+         iter != reprocess.end(); ++iter)
+      reprocess_vec.push_back(*iter);
+    reprocess.clear();
+    for (typename std::vector<Token*>::iterator iter = reprocess_vec.begin();
+         iter != reprocess_vec.end(); ++iter) {
+      Token *tok = *iter;
+      int32 pos = token2pos[tok];
+      // Repeat the processing we did above (for comments, see above).
+      for (ForwardLinkT *link = tok->links; link != NULL; link = link->next) {
+        if (link->ilabel == 0) {
+          IterType following_iter = token2pos.find(link->next_tok);
+          if (following_iter != token2pos.end()) {
+            int32 next_pos = following_iter->second;
+            if (next_pos < pos) {
+              following_iter->second = cur_pos++;
+              reprocess.insert(link->next_tok);
+            }
+          }
+        }
+      }
+    }
+  }
+  KALDI_ASSERT(loop_count < max_loop && "Epsilon loops exist in your decoding "
+               "graph (this is not allowed!)");
+
+  topsorted_list->clear();
+  topsorted_list->resize(cur_pos, NULL);  // create a list with NULLs in between.
+  for (IterType iter = token2pos.begin(); iter != token2pos.end(); ++iter)
+    (*topsorted_list)[iter->second] = iter->first;
+}
+
+// Instantiate the template for the combination of token types and FST types
+// that we'll need.
+template class LatticeFasterDecoderCombineTpl<fst::Fst<fst::StdArc>,
+    decodercombine::StdToken>;
+template class LatticeFasterDecoderCombineTpl<fst::VectorFst<fst::StdArc>,
+    decodercombine::StdToken >;
+template class LatticeFasterDecoderCombineTpl<fst::ConstFst<fst::StdArc>,
+    decodercombine::StdToken >;
+template class LatticeFasterDecoderCombineTpl<fst::GrammarFst,
+    decodercombine::StdToken>;
+
+template class LatticeFasterDecoderCombineTpl<fst::Fst<fst::StdArc>,
+    decodercombine::BackpointerToken>;
+template class LatticeFasterDecoderCombineTpl<fst::VectorFst<fst::StdArc>,
+    decodercombine::BackpointerToken >;
+template class LatticeFasterDecoderCombineTpl<fst::ConstFst<fst::StdArc>,
+    decodercombine::BackpointerToken >;
+template class LatticeFasterDecoderCombineTpl<fst::GrammarFst,
+    decodercombine::BackpointerToken>;
+
+
+} // end namespace kaldi.
diff --git a/src/decoder/lattice-faster-decoder-combine-itermap.h b/src/decoder/lattice-faster-decoder-combine-itermap.h
new file mode 100644
index 00000000000..c0b76e4126d
--- /dev/null
+++ b/src/decoder/lattice-faster-decoder-combine-itermap.h
@@ -0,0 +1,561 @@
+// decoder/lattice-faster-decoder-combine.h
+
+// Copyright 2009-2013  Microsoft Corporation;  Mirko Hannemann;
+//           2013-2019  Johns Hopkins University (Author: Daniel Povey)
+//                2014  Guoguo Chen
+//                2018  Zhehuai Chen
+//                2019  Hang Lyu
+
+// See ../../COPYING for clarification regarding multiple authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef KALDI_DECODER_LATTICE_FASTER_DECODER_COMBINE_H_ +#define KALDI_DECODER_LATTICE_FASTER_DECODER_COMBINE_H_ + + +#include "util/stl-utils.h" +#include "fst/fstlib.h" +#include "itf/decodable-itf.h" +#include "fstext/fstext-lib.h" +#include "lat/determinize-lattice-pruned.h" +#include "lat/kaldi-lattice.h" +#include "decoder/grammar-fst.h" +#include "decoder/lattice-faster-decoder.h" + +namespace kaldi { + +struct LatticeFasterDecoderCombineConfig { + BaseFloat beam; + int32 max_active; + int32 min_active; + BaseFloat lattice_beam; + int32 prune_interval; + bool determinize_lattice; // not inspected by this class... used in + // command-line program. + BaseFloat beam_delta; // has nothing to do with beam_ratio + BaseFloat hash_ratio; + BaseFloat prune_scale; // Note: we don't make this configurable on the command line, + // it's not a very important parameter. It affects the + // algorithm that prunes the tokens as we go. + // Most of the options inside det_opts are not actually queried by the + // LatticeFasterDecoder class itself, but by the code that calls it, for + // example in the function DecodeUtteranceLatticeFaster. + fst::DeterminizeLatticePhonePrunedOptions det_opts; + + LatticeFasterDecoderCombineConfig(): beam(16.0), + max_active(std::numeric_limits::max()), + min_active(200), + lattice_beam(10.0), + prune_interval(25), + determinize_lattice(true), + beam_delta(0.5), + hash_ratio(2.0), + prune_scale(0.1) { } + void Register(OptionsItf *opts) { + det_opts.Register(opts); + opts->Register("beam", &beam, "Decoding beam. Larger->slower, more accurate."); + opts->Register("max-active", &max_active, "Decoder max active states. Larger->slower; " + "more accurate"); + opts->Register("min-active", &min_active, "Decoder minimum #active states."); + opts->Register("lattice-beam", &lattice_beam, "Lattice generation beam. Larger->slower, " + "and deeper lattices"); + opts->Register("prune-interval", &prune_interval, "Interval (in frames) at " + "which to prune tokens"); + opts->Register("determinize-lattice", &determinize_lattice, "If true, " + "determinize the lattice (lattice-determinization, keeping only " + "best pdf-sequence for each word-sequence)."); + opts->Register("beam-delta", &beam_delta, "Increment used in decoding-- this " + "parameter is obscure and relates to a speedup in the way the " + "max-active constraint is applied. 
Larger is more accurate."); + opts->Register("hash-ratio", &hash_ratio, "Setting used in decoder to " + "control hash behavior"); + } + void Check() const { + KALDI_ASSERT(beam > 0.0 && max_active > 1 && lattice_beam > 0.0 + && min_active <= max_active + && prune_interval > 0 && beam_delta > 0.0 && hash_ratio >= 1.0 + && prune_scale > 0.0 && prune_scale < 1.0); + } +}; + + +namespace decodercombine { +// We will template the decoder on the token type as well as the FST type; this +// is a mechanism so that we can use the same underlying decoder code for +// versions of the decoder that support quickly getting the best path +// (LatticeFasterOnlineDecoder, see lattice-faster-online-decoder.h) and also +// those that do not (LatticeFasterDecoder). + + +// ForwardLinks are the links from a token to a token on the next frame. +// or sometimes on the current frame (for input-epsilon links). +template +struct ForwardLink { + using Label = fst::StdArc::Label; + + Token *next_tok; // the next token [or NULL if represents final-state] + Label ilabel; // ilabel on arc + Label olabel; // olabel on arc + BaseFloat graph_cost; // graph cost of traversing arc (contains LM, etc.) + BaseFloat acoustic_cost; // acoustic cost (pre-scaled) of traversing arc + ForwardLink *next; // next in singly-linked list of forward arcs (arcs + // in the state-level lattice) from a token. + inline ForwardLink(Token *next_tok, Label ilabel, Label olabel, + BaseFloat graph_cost, BaseFloat acoustic_cost, + ForwardLink *next): + next_tok(next_tok), ilabel(ilabel), olabel(olabel), + graph_cost(graph_cost), acoustic_cost(acoustic_cost), + next(next) { } +}; + + +struct StdToken { + using ForwardLinkT = ForwardLink; + using Token = StdToken; + + // Standard token type for LatticeFasterDecoder. Each active HCLG + // (decoding-graph) state on each frame has one token. + + // tot_cost is the total (LM + acoustic) cost from the beginning of the + // utterance up to this point. (but see cost_offset_, which is subtracted + // to keep it in a good numerical range). + BaseFloat tot_cost; + + // exta_cost is >= 0. After calling PruneForwardLinks, this equals + // the minimum difference between the cost of the best path, and the cost of + // this is on, and the cost of the absolute best path, under the assumption + // that any of the currently active states at the decoding front may + // eventually succeed (e.g. if you were to take the currently active states + // one by one and compute this difference, and then take the minimum). + BaseFloat extra_cost; + + // 'links' is the head of singly-linked list of ForwardLinks, which is what we + // use for lattice generation. + ForwardLinkT *links; + + //'next' is the next in the singly-linked list of tokens for this frame. + Token *next; + + // identitfy the token is in current queue or not to prevent duplication in + // function ProcessOneFrame(). + bool in_current_queue; + + // This function does nothing and should be optimized out; it's needed + // so we can share the regular LatticeFasterDecoderTpl code and the code + // for LatticeFasterOnlineDecoder that supports fast traceback. + inline void SetBackpointer (Token *backpointer) { } + + // This constructor just ignores the 'backpointer' argument. That argument is + // needed so that we can use the same decoder code for LatticeFasterDecoderTpl + // and LatticeFasterOnlineDecoderTpl (which needs backpointers to support a + // fast way to obtain the best path). 
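The constructor comment above leans on a no-op SetBackpointer() that the compiler can optimize out, which is how one decoder body serves both token flavors. A minimal sketch of that pattern (LeanTok, TraceTok and Relax are hypothetical names, not decoder types):

#include <cstdio>

struct LeanTok {
  float cost;
  void SetBackpointer(LeanTok * /*bp*/) { }  // does nothing, compiles away
};
struct TraceTok {
  float cost;
  TraceTok *bp;
  void SetBackpointer(TraceTok *t) { bp = t; }  // keeps the traceback
};

// One relaxation routine shared by both token flavors.
template <typename Tok>
void Relax(Tok *dst, Tok *src, float new_cost) {
  if (new_cost < dst->cost) {
    dst->cost = new_cost;
    dst->SetBackpointer(src);  // no-op for LeanTok, real work for TraceTok
  }
}

int main() {
  TraceTok a = {1.0f, 0}, b = {5.0f, 0};
  Relax(&b, &a, 2.0f);
  std::printf("%.1f %p\n", b.cost, static_cast<void*>(b.bp));
  return 0;
}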
+ inline StdToken(BaseFloat tot_cost, BaseFloat extra_cost, ForwardLinkT *links, + Token *next, Token *backpointer): + tot_cost(tot_cost), extra_cost(extra_cost), links(links), next(next), + in_current_queue(false) { } +}; + +struct BackpointerToken { + using ForwardLinkT = ForwardLink; + using Token = BackpointerToken; + + // BackpointerToken is like Token but also + // Standard token type for LatticeFasterDecoder. Each active HCLG + // (decoding-graph) state on each frame has one token. + + // tot_cost is the total (LM + acoustic) cost from the beginning of the + // utterance up to this point. (but see cost_offset_, which is subtracted + // to keep it in a good numerical range). + BaseFloat tot_cost; + + // exta_cost is >= 0. After calling PruneForwardLinks, this equals + // the minimum difference between the cost of the best path, and the cost of + // this is on, and the cost of the absolute best path, under the assumption + // that any of the currently active states at the decoding front may + // eventually succeed (e.g. if you were to take the currently active states + // one by one and compute this difference, and then take the minimum). + BaseFloat extra_cost; + + // 'links' is the head of singly-linked list of ForwardLinks, which is what we + // use for lattice generation. + ForwardLinkT *links; + + //'next' is the next in the singly-linked list of tokens for this frame. + BackpointerToken *next; + + // Best preceding BackpointerToken (could be a on this frame, connected to + // this via an epsilon transition, or on a previous frame). This is only + // required for an efficient GetBestPath function in + // LatticeFasterOnlineDecoderTpl; it plays no part in the lattice generation + // (the "links" list is what stores the forward links, for that). + Token *backpointer; + + // identitfy the token is in current queue or not to prevent duplication in + // function ProcessOneFrame(). + bool in_current_queue; + + inline void SetBackpointer (Token *backpointer) { + this->backpointer = backpointer; + } + + inline BackpointerToken(BaseFloat tot_cost, BaseFloat extra_cost, ForwardLinkT *links, + Token *next, Token *backpointer): + tot_cost(tot_cost), extra_cost(extra_cost), links(links), next(next), + backpointer(backpointer), in_current_queue(false) { } +}; + +} // namespace decoder + + +/** This is the "normal" lattice-generating decoder. + See \ref lattices_generation \ref decoders_faster and \ref decoders_simple + for more information. + + The decoder is templated on the FST type and the token type. The token type + will normally be StdToken, but also may be BackpointerToken which is to support + quick lookup of the current best path (see lattice-faster-online-decoder.h) + + The FST you invoke this decoder with is expected to equal + Fst::Fst, a.k.a. StdFst, or GrammarFst. If you invoke it with + FST == StdFst and it notices that the actual FST type is + fst::VectorFst or fst::ConstFst, the decoder object + will internally cast itself to one that is templated on those more specific + types; this is an optimization for speed. 
+ */ +template +class LatticeFasterDecoderCombineTpl { + public: + using Arc = typename FST::Arc; + using Label = typename Arc::Label; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + using ForwardLinkT = decodercombine::ForwardLink; + + using StateIdToTokenMap = typename std::unordered_map; + //using StateIdToTokenMap = typename std::unordered_map, std::equal_to, + // fst::PoolAllocator > >; + using IterType = typename StateIdToTokenMap::const_iterator; + + // Instantiate this class once for each thing you have to decode. + // This version of the constructor does not take ownership of + // 'fst'. + LatticeFasterDecoderCombineTpl(const FST &fst, + const LatticeFasterDecoderCombineConfig &config); + + // This version of the constructor takes ownership of the fst, and will delete + // it when this object is destroyed. + LatticeFasterDecoderCombineTpl(const LatticeFasterDecoderCombineConfig &config, + FST *fst); + + void SetOptions(const LatticeFasterDecoderCombineConfig &config) { + config_ = config; + } + + const LatticeFasterDecoderCombineConfig &GetOptions() const { + return config_; + } + + ~LatticeFasterDecoderCombineTpl(); + + /// Decodes until there are no more frames left in the "decodable" object.. + /// note, this may block waiting for input if the "decodable" object blocks. + /// Returns true if any kind of traceback is available (not necessarily from a + /// final state). + bool Decode(DecodableInterface *decodable); + + + /// says whether a final-state was active on the last frame. If it was not, the + /// lattice (or traceback) will end with states that are not final-states. + bool ReachedFinal() const { + return FinalRelativeCost() != std::numeric_limits::infinity(); + } + + /// Outputs an FST corresponding to the single best path through the lattice. + /// Returns true if result is nonempty (using the return status is deprecated, + /// it will become void). If "use_final_probs" is true AND we reached the + /// final-state of the graph then it will include those as final-probs, else + /// it will treat all final-probs as one. Note: this just calls GetRawLattice() + /// and figures out the shortest path. + bool GetBestPath(Lattice *ofst, + bool use_final_probs = true); + + /// Outputs an FST corresponding to the raw, state-level + /// tracebacks. Returns true if result is nonempty. + /// If "use_final_probs" is true AND we reached the final-state + /// of the graph then it will include those as final-probs, else + /// it will treat all final-probs as one. + /// The raw lattice will be topologically sorted. + /// The function can be called during decoding, it will process non-emitting + /// arcs from "cur_toks_" map to get tokens from both non-emitting and + /// emitting arcs for getting raw lattice. Then recover it to ensure the + /// consistency of ProcessForFrame(). + /// + /// See also GetRawLatticePruned in lattice-faster-online-decoder.h, + /// which also supports a pruning beam, in case for some reason + /// you want it pruned tighter than the regular lattice beam. + /// We could put that here in future needed. + bool GetRawLattice(Lattice *ofst, bool use_final_probs = true); + + + + /// [Deprecated, users should now use GetRawLattice and determinize it + /// themselves, e.g. using DeterminizeLatticePhonePrunedWrapper]. + /// Outputs an FST corresponding to the lattice-determinized + /// lattice (one path per word sequence). Returns true if result is nonempty. 
+ /// If "use_final_probs" is true AND we reached the final-state of the graph + /// then it will include those as final-probs, else it will treat all + /// final-probs as one. + bool GetLattice(CompactLattice *ofst, + bool use_final_probs = true); + + /// InitDecoding initializes the decoding, and should only be used if you + /// intend to call AdvanceDecoding(). If you call Decode(), you don't need to + /// call this. You can also call InitDecoding if you have already decoded an + /// utterance and want to start with a new utterance. + void InitDecoding(); + + /// This will decode until there are no more frames ready in the decodable + /// object. You can keep calling it each time more frames become available. + /// If max_num_frames is specified, it specifies the maximum number of frames + /// the function will decode before returning. + void AdvanceDecoding(DecodableInterface *decodable, + int32 max_num_frames = -1); + + /// This function may be optionally called after AdvanceDecoding(), when you + /// do not plan to decode any further. It does an extra pruning step that + /// will help to prune the lattices output by GetLattice and (particularly) + /// GetRawLattice more accurately, particularly toward the end of the + /// utterance. It does this by using the final-probs in pruning (if any + /// final-state survived); it also does a final pruning step that visits all + /// states (the pruning that is done during decoding may fail to prune states + /// that are within kPruningScale = 0.1 outside of the beam). If you call + /// this, you cannot call AdvanceDecoding again (it will fail), and you + /// cannot call GetLattice() and related functions with use_final_probs = + /// false. + /// Used to be called PruneActiveTokensFinal(). + void FinalizeDecoding(); + + /// FinalRelativeCost() serves the same purpose as ReachedFinal(), but gives + /// more information. It returns the difference between the best (final-cost + /// plus cost) of any token on the final frame, and the best cost of any token + /// on the final frame. If it is infinity it means no final-states were + /// present on the final frame. It will usually be nonnegative. If it not + /// too positive (e.g. < 5 is my first guess, but this is not tested) you can + /// take it as a good indication that we reached the final-state with + /// reasonable likelihood. + BaseFloat FinalRelativeCost() const; + + + // Returns the number of frames decoded so far. The value returned changes + // whenever we call ProcessForFrame(). + inline int32 NumFramesDecoded() const { return active_toks_.size() - 1; } + + protected: + // we make things protected instead of private, as code in + // LatticeFasterOnlineDecoderTpl, which inherits from this, also uses the + // internals. + + // Deletes the elements of the singly linked list tok->links. + inline static void DeleteForwardLinks(Token *tok); + + // head of per-frame list of Tokens (list is in topological order), + // and something saying whether we ever pruned it using PruneForwardLinks. + struct TokenList { + Token *toks; + bool must_prune_forward_links; + bool must_prune_tokens; + TokenList(): toks(NULL), must_prune_forward_links(true), + must_prune_tokens(true) { } + }; + + // FindOrAddToken either locates a token in hash map "token_map", or if necessary + // inserts a new, empty token (i.e. with no forward links) for the current + // frame. 
[note: it's inserted if necessary into hash map and also into the + // singly linked list of tokens active on this frame (whose head is at + // active_toks_[frame]). The frame_plus_one argument is the acoustic frame + // index plus one, which is used to index into the active_toks_ array. + // Returns the Token pointer. Sets "changed" (if non-NULL) to true if the + // token was newly created or the cost changed. + // If Token == StdToken, the 'backpointer' argument has no purpose (and will + // hopefully be optimized out). + inline Token *FindOrAddToken(StateId state, int32 token_list_index, + BaseFloat tot_cost, Token *backpointer, + StateIdToTokenMap *token_map, + bool *changed); + + // prunes outgoing links for all tokens in active_toks_[frame] + // it's called by PruneActiveTokens + // all links, that have link_extra_cost > lattice_beam are pruned + // delta is the amount by which the extra_costs must change + // before we set *extra_costs_changed = true. + // If delta is larger, we'll tend to go back less far + // toward the beginning of the file. + // extra_costs_changed is set to true if extra_cost was changed for any token + // links_pruned is set to true if any link in any token was pruned + void PruneForwardLinks(int32 frame_plus_one, bool *extra_costs_changed, + bool *links_pruned, + BaseFloat delta); + + // This function computes the final-costs for tokens active on the final + // frame. It outputs to final-costs, if non-NULL, a map from the Token* + // pointer to the final-prob of the corresponding state, for all Tokens + // that correspond to states that have final-probs. This map will be + // empty if there were no final-probs. It outputs to + // final_relative_cost, if non-NULL, the difference between the best + // forward-cost including the final-prob cost, and the best forward-cost + // without including the final-prob cost (this will usually be positive), or + // infinity if there were no final-probs. [c.f. FinalRelativeCost(), which + // outputs this quanitity]. It outputs to final_best_cost, if + // non-NULL, the lowest for any token t active on the final frame, of + // forward-cost[t] + final-cost[t], where final-cost[t] is the final-cost in + // the graph of the state corresponding to token t, or the best of + // forward-cost[t] if there were no final-probs active on the final frame. + // You cannot call this after FinalizeDecoding() has been called; in that + // case you should get the answer from class-member variables. + void ComputeFinalCosts(unordered_map *final_costs, + BaseFloat *final_relative_cost, + BaseFloat *final_best_cost) const; + + // PruneForwardLinksFinal is a version of PruneForwardLinks that we call + // on the final frame. If there are final tokens active, it uses + // the final-probs for pruning, otherwise it treats all tokens as final. + void PruneForwardLinksFinal(); + + // Prune away any tokens on this frame that have no forward links. + // [we don't do this in PruneForwardLinks because it would give us + // a problem with dangling pointers]. + // It's called by PruneActiveTokens if any forward links have been pruned + void PruneTokensForFrame(int32 frame_plus_one); + + + // Go backwards through still-alive tokens, pruning them if the + // forward+backward cost is more than lat_beam away from the best path. It's + // possible to prove that this is "correct" in the sense that we won't lose + // anything outside of lat_beam, regardless of what happens in the future. 
+  // delta controls when it considers a cost to have changed enough to continue
+  // going backward and propagating the change. Larger delta -> will recurse
+  // less far.
+  void PruneActiveTokens(BaseFloat delta);
+
+  /// Processes non-emitting (epsilon) arcs and emitting arcs for one frame
+  /// together. It takes the emitting tokens in "prev_toks_" from the last
+  /// frame, and generates non-emitting tokens for the previous frame and
+  /// emitting tokens for the next frame.
+  /// Notice: "the emitting tokens for the current frame" means the tokens
+  /// that take the acoustic scores of the current frame (i.e. the
+  /// destinations of emitting arcs).
+  void ProcessForFrame(DecodableInterface *decodable);
+
+  /// Processes nonemitting (epsilon) arcs for one frame.
+  /// Call this function once after all frames have been processed, or call
+  /// it in GetRawLattice() to generate the complete token list for the last
+  /// frame. [It deals with the tokens in the map "cur_toks_", which would
+  /// otherwise only contain the emitting tokens from the previous frame.]
+  /// If the map "token_orig_cost" isn't NULL, we build the map which will
+  /// be used to recover the "active_toks_[last_frame]" token list for the
+  /// last frame.
+  void ProcessNonemitting(std::unordered_map<Token*, BaseFloat> *token_orig_cost);
+
+  /// When GetRawLattice() is called during decoding, the
+  /// active_toks_[last_frame] is changed. To keep the consistency of function
+  /// ProcessForFrame(), recover it.
+  /// Notice: as a new token will be added to the head of the TokenList,
+  /// tok->next will not be affected.
+  /// "token_orig_cost" is a mapping from token pointer to the tot_cost of the
+  /// token before propagating non-emitting arcs. It is used to recover the
+  /// change of original tokens in the last frame and remove the new tokens
+  /// which come from propagating non-emitting arcs, so that we can guarantee
+  /// the consistency of function ProcessForFrame().
+  void RecoverLastTokenList(
+      const std::unordered_map<Token*, BaseFloat> &token_orig_cost);
+
+
+  /// The maps "prev_toks_" and "cur_toks_" hold the active tokens for the
+  /// current and the next frame; they are indexed by StateId. The frame
+  /// indexing convention is frame-index plus one, where the frame-index is
+  /// zero-based as used in the decodable object: the emitting probs of frame
+  /// t are accounted for in tokens at toks_[t+1], and the zeroth slot is for
+  /// the non-emitting transitions at the start of the graph.
+  StateIdToTokenMap prev_toks_;
+  StateIdToTokenMap cur_toks_;
+
+  /// Gets the weight cutoff.
+  /// Notice: in the traditional version, the histogram pruning method is
+  /// applied to a complete token list for one frame. In this version it is
+  /// applied to a token list which only contains the emitting part, so the
+  /// max_active and min_active values might need to be narrowed.
+  BaseFloat GetCutoff(const StateIdToTokenMap& toks,
+                      BaseFloat *adaptive_beam,
+                      StateId *best_state_id, Token **best_token);
+
+  std::vector<TokenList> active_toks_; // Lists of tokens, indexed by
+  // frame (members of TokenList are toks, must_prune_forward_links,
+  // must_prune_tokens).
+  std::queue<StateId> cur_queue_;  // temp variable used in ProcessForFrame
+                                   // and ProcessNonemitting
+  std::vector<BaseFloat> tmp_array_;  // used in GetCutoff.
+
+  // fst_ is a pointer to the FST we are decoding from.
+  const FST *fst_;
+  // delete_fst_ is true if the pointer fst_ needs to be deleted when this
+  // object is destroyed.
+ bool delete_fst_; + + std::vector cost_offsets_; // This contains, for each + // frame, an offset that was added to the acoustic log-likelihoods on that + // frame in order to keep everything in a nice dynamic range i.e. close to + // zero, to reduce roundoff errors. + // Notice: It will only be added to emitting arcs (i.e. cost_offsets_[t] is + // added to arcs from "frame t" to "frame t+1"). + LatticeFasterDecoderCombineConfig config_; + int32 num_toks_; // current total #toks allocated... + bool warned_; + + /// decoding_finalized_ is true if someone called FinalizeDecoding(). [note, + /// calling this is optional]. If true, it's forbidden to decode more. Also, + /// if this is set, then the output of ComputeFinalCosts() is in the next + /// three variables. The reason we need to do this is that after + /// FinalizeDecoding() calls PruneTokensForFrame() for the final frame, some + /// of the tokens on the last frame are freed, so we free the list from toks_ + /// to avoid having dangling pointers hanging around. + bool decoding_finalized_; + /// For the meaning of the next 3 variables, see the comment for + /// decoding_finalized_ above., and ComputeFinalCosts(). + unordered_map final_costs_; + BaseFloat final_relative_cost_; + BaseFloat final_best_cost_; + + // This function takes a singly linked list of tokens for a single frame, and + // outputs a list of them in topological order (it will crash if no such order + // can be found, which will typically be due to decoding graphs with epsilon + // cycles, which are not allowed). Note: the output list may contain NULLs, + // which the caller should pass over; it just happens to be more efficient for + // the algorithm to output a list that contains NULLs. + static void TopSortTokens(Token *tok_list, + std::vector *topsorted_list); + + void ClearActiveTokens(); + + KALDI_DISALLOW_COPY_AND_ASSIGN(LatticeFasterDecoderCombineTpl); +}; + +typedef LatticeFasterDecoderCombineTpl LatticeFasterDecoderCombine; + + + +} // end namespace kaldi. + +#endif diff --git a/src/decoder/lattice-faster-decoder-combine.cc b/src/decoder/lattice-faster-decoder-combine.cc new file mode 100644 index 00000000000..f30fc36b872 --- /dev/null +++ b/src/decoder/lattice-faster-decoder-combine.cc @@ -0,0 +1,1088 @@ +// decoder/lattice-faster-decoder-combine.cc + +// Copyright 2009-2012 Microsoft Corporation Mirko Hannemann +// 2013-2019 Johns Hopkins University (Author: Daniel Povey) +// 2014 Guoguo Chen +// 2018 Zhehuai Chen +// 2019 Hang Lyu + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#include "decoder/lattice-faster-decoder-combine.h" +#include "lat/lattice-functions.h" + +namespace kaldi { + +template +BucketQueue::BucketQueue(BaseFloat cost_scale) : + cost_scale_(cost_scale) { + // NOTE: we reserve plenty of elements to avoid expensive reallocations + // later on. 
Normally, the size is a little bigger than (adaptive_beam +
+  // 15) * cost_scale.
+  int32 bucket_size = (15 + 20) * cost_scale_;
+  buckets_.resize(bucket_size);
+  bucket_offset_ = 15 * cost_scale_;
+  first_nonempty_bucket_index_ = bucket_size - 1;
+  first_nonempty_bucket_ = &buckets_[first_nonempty_bucket_index_];
+  bucket_size_tolerance_ = 1.2 * bucket_size;
+}
+
+template <typename Token>
+void BucketQueue<Token>::Push(Token *tok) {
+  size_t bucket_index = std::floor(tok->tot_cost * cost_scale_) +
+                        bucket_offset_;
+  if (bucket_index >= buckets_.size()) {
+    int32 margin = 10;  // a margin used to avoid reallocating space too
+                        // frequently
+    if (static_cast<int32>(bucket_index) > 0) {
+      buckets_.resize(bucket_index + margin);
+    } else {  // less than 0
+      int32 increase_size = - static_cast<int32>(bucket_index) + margin;
+      buckets_.resize(buckets_.size() + increase_size);
+      // translation: shift the existing buckets up by increase_size
+      for (size_t i = buckets_.size() - 1; i >= increase_size; i--) {
+        buckets_[i].swap(buckets_[i - increase_size]);
+      }
+      bucket_offset_ = bucket_offset_ + increase_size;
+      bucket_index += increase_size;
+      first_nonempty_bucket_index_ = bucket_index;
+    }
+    first_nonempty_bucket_ = &buckets_[first_nonempty_bucket_index_];
+  }
+  tok->in_queue = true;
+  buckets_[bucket_index].push_back(tok);
+  if (bucket_index < first_nonempty_bucket_index_) {
+    first_nonempty_bucket_index_ = bucket_index;
+    first_nonempty_bucket_ = &buckets_[first_nonempty_bucket_index_];
+  }
+}
+
+template <typename Token>
+Token* BucketQueue<Token>::Pop() {
+  while (true) {
+    if (!first_nonempty_bucket_->empty()) {
+      Token *ans = first_nonempty_bucket_->back();
+      first_nonempty_bucket_->pop_back();
+      if (ans->in_queue) {  // If ans->in_queue is false, this means it is a
+                            // duplicate instance of this Token that was left
+                            // over when a Token's best_cost changed, and the
+                            // Token has already been processed (so
+                            // conceptually, it is not in the queue).
+        ans->in_queue = false;
+        return ans;
+      }
+    }
+    if (first_nonempty_bucket_->empty()) {
+      for (; first_nonempty_bucket_index_ + 1 < buckets_.size();
+           first_nonempty_bucket_index_++) {
+        if (!buckets_[first_nonempty_bucket_index_].empty()) break;
+      }
+      first_nonempty_bucket_ = &buckets_[first_nonempty_bucket_index_];
+      if (first_nonempty_bucket_->empty()) return NULL;
+    }
+  }
+}
+
+template <typename Token>
+void BucketQueue<Token>::Clear() {
+  for (size_t i = first_nonempty_bucket_index_; i < buckets_.size(); i++) {
+    buckets_[i].clear();
+  }
+  if (buckets_.size() > bucket_size_tolerance_) {
+    buckets_.resize(bucket_size_tolerance_);
+    bucket_offset_ = 15 * cost_scale_;
+  }
+  first_nonempty_bucket_index_ = buckets_.size() - 1;
+  first_nonempty_bucket_ = &buckets_[first_nonempty_bucket_index_];
+}
+
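A self-contained sketch of the bucketing idea above (ToyBucketQueue is an invented name; the real class additionally offsets indices so negative costs work): costs are binned by floor(cost * scale), so pops come out best-first across buckets but unordered (here LIFO) within a bucket, which is the queue's deliberate approximation.

#include <cmath>
#include <cstdio>
#include <vector>

struct ToyBucketQueue {
  std::vector<std::vector<float> > buckets;
  float scale;
  explicit ToyBucketQueue(float s) : buckets(8), scale(s) {}
  void Push(float cost) {  // assumes cost >= 0 in this toy version
    size_t idx = static_cast<size_t>(std::floor(cost * scale));
    if (idx >= buckets.size()) buckets.resize(idx + 1);
    buckets[idx].push_back(cost);
  }
  bool Pop(float *cost) {  // cheapest bucket first, LIFO inside a bucket
    for (size_t i = 0; i < buckets.size(); i++) {
      if (!buckets[i].empty()) {
        *cost = buckets[i].back();
        buckets[i].pop_back();
        return true;
      }
    }
    return false;
  }
};

int main() {
  ToyBucketQueue q(1.0f);
  for (float c : {3.7f, 0.4f, 1.2f, 0.9f}) q.Push(c);
  float c;
  while (q.Pop(&c)) std::printf("%.1f ", c);  // prints: 0.9 0.4 1.2 3.7
  return 0;
}

+
+// instantiate this class once for each thing you have to decode.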
+template +LatticeFasterDecoderCombineTpl::LatticeFasterDecoderCombineTpl( + const FST &fst, + const LatticeFasterDecoderCombineConfig &config): + fst_(&fst), delete_fst_(false), config_(config), num_toks_(0), + cur_queue_(config_.cost_scale) { + config.Check(); + cur_toks_.reserve(1000); + next_toks_.reserve(1000); +} + + +template +LatticeFasterDecoderCombineTpl::LatticeFasterDecoderCombineTpl( + const LatticeFasterDecoderCombineConfig &config, FST *fst): + fst_(fst), delete_fst_(true), config_(config), num_toks_(0), + cur_queue_(config_.cost_scale) { + config.Check(); + cur_toks_.reserve(1000); + next_toks_.reserve(1000); +} + + +template +LatticeFasterDecoderCombineTpl::~LatticeFasterDecoderCombineTpl() { + ClearActiveTokens(); + if (delete_fst_) delete fst_; +} + +template +void LatticeFasterDecoderCombineTpl::InitDecoding() { + // clean up from last time: + cur_toks_.clear(); + next_toks_.clear(); + cost_offsets_.clear(); + ClearActiveTokens(); + + warned_ = false; + num_toks_ = 0; + decoding_finalized_ = false; + final_costs_.clear(); + StateId start_state = fst_->Start(); + KALDI_ASSERT(start_state != fst::kNoStateId); + active_toks_.resize(1); + Token *start_tok = new Token(0.0, 0.0, start_state, NULL, NULL, NULL); + active_toks_[0].toks = start_tok; + next_toks_[start_state] = start_tok; // initialize current tokens map + num_toks_++; + adaptive_beam_ = config_.beam; + cost_offsets_.resize(1); + cost_offsets_[0] = 0.0; + +} + +// Returns true if any kind of traceback is available (not necessarily from +// a final state). It should only very rarely return false; this indicates +// an unusual search error. +template +bool LatticeFasterDecoderCombineTpl::Decode(DecodableInterface *decodable) { + InitDecoding(); + + // We use 1-based indexing for frames in this decoder (if you view it in + // terms of features), but note that the decodable object uses zero-based + // numbering, which we have to correct for when we call it. + + while (!decodable->IsLastFrame(NumFramesDecoded() - 1)) { + if (NumFramesDecoded() % config_.prune_interval == 0) + PruneActiveTokens(config_.lattice_beam * config_.prune_scale); + ProcessForFrame(decodable); + } + // A complete token list of the last frame will be generated in FinalizeDecoding() + FinalizeDecoding(); + + // Returns true if we have any kind of traceback available (not necessarily + // to the end state; query ReachedFinal() for that). + return !active_toks_.empty() && active_toks_.back().toks != NULL; +} + + +// Outputs an FST corresponding to the single best path through the lattice. +template +bool LatticeFasterDecoderCombineTpl::GetBestPath( + Lattice *olat, + bool use_final_probs) { + Lattice raw_lat; + GetRawLattice(&raw_lat, use_final_probs); + ShortestPath(raw_lat, olat); + return (olat->NumStates() != 0); +} + + +// Outputs an FST corresponding to the raw, state-level lattice +template +bool LatticeFasterDecoderCombineTpl::GetRawLattice( + Lattice *ofst, + bool use_final_probs) { + typedef LatticeArc Arc; + typedef Arc::StateId StateId; + typedef Arc::Weight Weight; + typedef Arc::Label Label; + // Note: you can't use the old interface (Decode()) if you want to + // get the lattice with use_final_probs = false. You'd have to do + // InitDecoding() and then AdvanceDecoding(). 
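The comment above describes the incremental calling pattern. A hedged sketch of a caller, assuming the public API declared in this PR's header (ToyDecodeUtterance is an invented name, and the setup of the decodable object is omitted):

#include "decoder/lattice-faster-decoder-combine.h"

// Drive the decoder incrementally and fetch the raw lattice at the end.
void ToyDecodeUtterance(const fst::StdFst &hclg,
                        kaldi::DecodableInterface *decodable,
                        const kaldi::LatticeFasterDecoderCombineConfig &cfg,
                        kaldi::Lattice *raw_lat) {
  kaldi::LatticeFasterDecoderCombine decoder(hclg, cfg);
  decoder.InitDecoding();
  // Call AdvanceDecoding() repeatedly as more frames become ready;
  // each call decodes every frame the decodable object can serve.
  decoder.AdvanceDecoding(decodable);
  decoder.FinalizeDecoding();  // optional: extra pruning on the final frame
  decoder.GetRawLattice(raw_lat, true /* use_final_probs */);
}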
+ if (decoding_finalized_ && !use_final_probs) + KALDI_ERR << "You cannot call FinalizeDecoding() and then call " + << "GetRawLattice() with use_final_probs == false"; + + if (!decoding_finalized_ && use_final_probs) { + // Process the non-emitting arcs for the unfinished last frame. + ProcessNonemitting(); + } + + + unordered_map final_costs_local; + + const unordered_map &final_costs = + (decoding_finalized_ ? final_costs_ : final_costs_local); + if (!decoding_finalized_ && use_final_probs) + ComputeFinalCosts(&final_costs_local, NULL, NULL); + + ofst->DeleteStates(); + // num-frames plus one (since frames are one-based, and we have + // an extra frame for the start-state). + int32 num_frames = active_toks_.size() - 1; + KALDI_ASSERT(num_frames > 0); + const int32 bucket_count = num_toks_/2 + 3; + unordered_map tok_map(bucket_count); + // First create all states. + std::vector token_list; + for (int32 f = 0; f <= num_frames; f++) { + if (active_toks_[f].toks == NULL) { + KALDI_WARN << "GetRawLattice: no tokens active on frame " << f + << ": not producing lattice.\n"; + return false; + } + TopSortTokens(active_toks_[f].toks, &token_list); + for (size_t i = 0; i < token_list.size(); i++) + if (token_list[i] != NULL) + tok_map[token_list[i]] = ofst->AddState(); + } + // The next statement sets the start state of the output FST. Because we + // topologically sorted the tokens, state zero must be the start-state. + ofst->SetStart(0); + + KALDI_VLOG(4) << "init:" << num_toks_/2 + 3 << " buckets:" + << tok_map.bucket_count() << " load:" << tok_map.load_factor() + << " max:" << tok_map.max_load_factor(); + // Now create all arcs. + for (int32 f = 0; f <= num_frames; f++) { + for (Token *tok = active_toks_[f].toks; tok != NULL; tok = tok->next) { + StateId cur_state = tok_map[tok]; + for (ForwardLinkT *l = tok->links; + l != NULL; + l = l->next) { + typename unordered_map::const_iterator + iter = tok_map.find(l->next_tok); + StateId nextstate = iter->second; + KALDI_ASSERT(iter != tok_map.end()); + BaseFloat cost_offset = 0.0; + if (l->ilabel != 0) { // emitting.. + KALDI_ASSERT(f >= 0 && f < cost_offsets_.size()); + cost_offset = cost_offsets_[f]; + } + Arc arc(l->ilabel, l->olabel, + Weight(l->graph_cost, l->acoustic_cost - cost_offset), + nextstate); + ofst->AddArc(cur_state, arc); + } + if (f == num_frames) { + if (use_final_probs && !final_costs.empty()) { + typename unordered_map::const_iterator + iter = final_costs.find(tok); + if (iter != final_costs.end()) + ofst->SetFinal(cur_state, LatticeWeight(iter->second, 0)); + } else { + ofst->SetFinal(cur_state, LatticeWeight::One()); + } + } + } + } + + return (ofst->NumStates() > 0); +} + +// This function is now deprecated, since now we do determinization from outside +// the LatticeFasterDecoder class. Outputs an FST corresponding to the +// lattice-determinized lattice (one path per word sequence). +template +bool LatticeFasterDecoderCombineTpl::GetLattice( + CompactLattice *ofst, + bool use_final_probs) { + Lattice raw_fst; + GetRawLattice(&raw_fst, use_final_probs); + Invert(&raw_fst); // make it so word labels are on the input. + // (in phase where we get backward-costs). + fst::ILabelCompare ilabel_comp; + ArcSort(&raw_fst, ilabel_comp); // sort on ilabel; makes + // lattice-determinization more efficient. 
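Since GetLattice() is deprecated in favor of determinizing outside the decoder, here is a sketch of the recommended route using the phone-pruned wrapper from lat/determinize-lattice-pruned.h (treat the exact signature as approximate; ToyDeterminize is an invented name):

#include "hmm/transition-model.h"
#include "lat/determinize-lattice-pruned.h"
#include "lat/kaldi-lattice.h"

// Determinize a raw state-level lattice into a CompactLattice outside the
// decoder. Note: the wrapper consumes (destroys) the input raw lattice.
bool ToyDeterminize(const kaldi::TransitionModel &trans_model,
                    kaldi::Lattice *raw_lat,
                    kaldi::BaseFloat lattice_beam,
                    const fst::DeterminizeLatticePhonePrunedOptions &opts,
                    kaldi::CompactLattice *clat) {
  return fst::DeterminizeLatticePhonePrunedWrapper(
      trans_model, raw_lat, lattice_beam, clat, opts);
}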
+ + fst::DeterminizeLatticePrunedOptions lat_opts; + lat_opts.max_mem = config_.det_opts.max_mem; + + DeterminizeLatticePruned(raw_fst, config_.lattice_beam, ofst, lat_opts); + raw_fst.DeleteStates(); // Free memory-- raw_fst no longer needed. + Connect(ofst); // Remove unreachable states... there might be + // a small number of these, in some cases. + // Note: if something went wrong and the raw lattice was empty, + // we should still get to this point in the code without warnings or failures. + return (ofst->NumStates() != 0); +} + +/* + A note on the definition of extra_cost. + + extra_cost is used in pruning tokens, to save memory. + + Define the 'forward cost' of a token as zero for any token on the frame + we're currently decoding; and for other frames, as the shortest-path cost + between that token and a token on the frame we're currently decoding. + (by "currently decoding" I mean the most recently processed frame). + + Then define the extra_cost of a token (always >= 0) as the forward-cost of + the token minus the smallest forward-cost of any token on the same frame. + + We can use the extra_cost to accurately prune away tokens that we know will + never appear in the lattice. If the extra_cost is greater than the desired + lattice beam, the token would provably never appear in the lattice, so we can + prune away the token. + + The advantage of storing the extra_cost rather than the forward-cost, is that + it is less costly to keep the extra_cost up-to-date when we process new frames. + When we process a new frame, *all* the previous frames' forward-costs would change; + but in general the extra_cost will change only for a finite number of frames. + (Actually we don't update all the extra_costs every time we update a frame; we + only do it every 'config_.prune_interval' frames). + */ + +// FindOrAddToken either locates a token in hash map "token_map" +// or if necessary inserts a new, empty token (i.e. with no forward links) +// for the current frame. [note: it's inserted if necessary into hash toks_ +// and also into the singly linked list of tokens active on this frame +// (whose head is at active_toks_[frame]). +template +inline Token* LatticeFasterDecoderCombineTpl::FindOrAddToken( + StateId state, int32 token_list_index, BaseFloat tot_cost, + Token *backpointer, StateIdToTokenMap *token_map, bool *changed) { + // Returns the Token pointer. Sets "changed" (if non-NULL) to true + // if the token was newly created or the cost changed. + KALDI_ASSERT(token_list_index < active_toks_.size()); + Token *&toks = active_toks_[token_list_index].toks; + typename StateIdToTokenMap::iterator e_found = token_map->find(state); + if (e_found == token_map->end()) { // no such token presently. + const BaseFloat extra_cost = 0.0; + // tokens on the currently final frame have zero extra_cost + // as any of them could end up + // on the winning path. + Token *new_tok = new Token (tot_cost, extra_cost, state, + NULL, toks, backpointer); + // NULL: no forward links yet + toks = new_tok; + num_toks_++; + // insert into the map + (*token_map)[state] = new_tok; + if (changed) *changed = true; + return new_tok; + } else { + Token *tok = e_found->second; // There is an existing Token for this state. + if (tok->tot_cost > tot_cost) { // replace old token + tok->tot_cost = tot_cost; + // SetBackpointer() just does tok->backpointer = backpointer in + // the case where Token == BackpointerToken, else nothing. 
+ tok->SetBackpointer(backpointer); + // we don't allocate a new token, the old stays linked in active_toks_ + // we only replace the tot_cost + // in the current frame, there are no forward links (and no extra_cost) + // only in ProcessNonemitting we have to delete forward links + // in case we visit a state for the second time + // those forward links, that lead to this replaced token before: + // they remain and will hopefully be pruned later (PruneForwardLinks...) + if (changed) *changed = true; + } else { + if (changed) *changed = false; + } + return tok; + } +} + +// prunes outgoing links for all tokens in active_toks_[frame] +// it's called by PruneActiveTokens +// all links, that have link_extra_cost > lattice_beam are pruned +template +void LatticeFasterDecoderCombineTpl::PruneForwardLinks( + int32 frame_plus_one, bool *extra_costs_changed, + bool *links_pruned, BaseFloat delta) { + // delta is the amount by which the extra_costs must change + // If delta is larger, we'll tend to go back less far + // toward the beginning of the file. + // extra_costs_changed is set to true if extra_cost was changed for any token + // links_pruned is set to true if any link in any token was pruned + + *extra_costs_changed = false; + *links_pruned = false; + KALDI_ASSERT(frame_plus_one >= 0 && frame_plus_one < active_toks_.size()); + if (active_toks_[frame_plus_one].toks == NULL) { // empty list; should not happen. + if (!warned_) { + KALDI_WARN << "No tokens alive [doing pruning].. warning first " + "time only for each utterance\n"; + warned_ = true; + } + } + + // We have to iterate until there is no more change, because the links + // are not guaranteed to be in topological order. + bool changed = true; // difference new minus old extra cost >= delta ? + while (changed) { + changed = false; + for (Token *tok = active_toks_[frame_plus_one].toks; + tok != NULL; tok = tok->next) { + ForwardLinkT *link, *prev_link = NULL; + // will recompute tok_extra_cost for tok. + BaseFloat tok_extra_cost = std::numeric_limits::infinity(); + // tok_extra_cost is the best (min) of link_extra_cost of outgoing links + for (link = tok->links; link != NULL; ) { + // See if we need to excise this link... + Token *next_tok = link->next_tok; + BaseFloat link_extra_cost = next_tok->extra_cost + + ((tok->tot_cost + link->acoustic_cost + link->graph_cost) + - next_tok->tot_cost); // difference in brackets is >= 0 + // link_exta_cost is the difference in score between the best paths + // through link source state and through link destination state + KALDI_ASSERT(link_extra_cost == link_extra_cost); // check for NaN + if (link_extra_cost > config_.lattice_beam) { // excise link + ForwardLinkT *next_link = link->next; + if (prev_link != NULL) prev_link->next = next_link; + else tok->links = next_link; + delete link; + link = next_link; // advance link but leave prev_link the same. + *links_pruned = true; + } else { // keep the link and update the tok_extra_cost if needed. + if (link_extra_cost < 0.0) { // this is just a precaution. + if (link_extra_cost < -0.01) + KALDI_WARN << "Negative extra_cost: " << link_extra_cost; + link_extra_cost = 0.0; + } + if (link_extra_cost < tok_extra_cost) + tok_extra_cost = link_extra_cost; + prev_link = link; // move to next link + link = link->next; + } + } // for all outgoing links + if (fabs(tok_extra_cost - tok->extra_cost) > delta) + changed = true; // difference new minus old is bigger than delta + tok->extra_cost = tok_extra_cost; + // will be +infinity or <= lattice_beam_. 
+ // infinity indicates, that no forward link survived pruning + } // for all Token on active_toks_[frame] + if (changed) *extra_costs_changed = true; + + // Note: it's theoretically possible that aggressive compiler + // optimizations could cause an infinite loop here for small delta and + // high-dynamic-range scores. + } // while changed +} + +// PruneForwardLinksFinal is a version of PruneForwardLinks that we call +// on the final frame. If there are final tokens active, it uses +// the final-probs for pruning, otherwise it treats all tokens as final. +template +void LatticeFasterDecoderCombineTpl::PruneForwardLinksFinal() { + KALDI_ASSERT(!active_toks_.empty()); + int32 frame_plus_one = active_toks_.size() - 1; + + if (active_toks_[frame_plus_one].toks == NULL) // empty list; should not happen. + KALDI_WARN << "No tokens alive at end of file"; + + typedef typename unordered_map::const_iterator IterType; + ComputeFinalCosts(&final_costs_, &final_relative_cost_, &final_best_cost_); + decoding_finalized_ = true; + + // Now go through tokens on this frame, pruning forward links... may have to + // iterate a few times until there is no more change, because the list is not + // in topological order. This is a modified version of the code in + // PruneForwardLinks, but here we also take account of the final-probs. + bool changed = true; + BaseFloat delta = 1.0e-05; + while (changed) { + changed = false; + for (Token *tok = active_toks_[frame_plus_one].toks; + tok != NULL; tok = tok->next) { + ForwardLinkT *link, *prev_link = NULL; + // will recompute tok_extra_cost. It has a term in it that corresponds + // to the "final-prob", so instead of initializing tok_extra_cost to infinity + // below we set it to the difference between the (score+final_prob) of this token, + // and the best such (score+final_prob). + BaseFloat final_cost; + if (final_costs_.empty()) { + final_cost = 0.0; + } else { + IterType iter = final_costs_.find(tok); + if (iter != final_costs_.end()) + final_cost = iter->second; + else + final_cost = std::numeric_limits::infinity(); + } + BaseFloat tok_extra_cost = tok->tot_cost + final_cost - final_best_cost_; + // tok_extra_cost will be a "min" over either directly being final, or + // being indirectly final through other links, and the loop below may + // decrease its value: + for (link = tok->links; link != NULL; ) { + // See if we need to excise this link... + Token *next_tok = link->next_tok; + BaseFloat link_extra_cost = next_tok->extra_cost + + ((tok->tot_cost + link->acoustic_cost + link->graph_cost) + - next_tok->tot_cost); + if (link_extra_cost > config_.lattice_beam) { // excise link + ForwardLinkT *next_link = link->next; + if (prev_link != NULL) prev_link->next = next_link; + else tok->links = next_link; + delete link; + link = next_link; // advance link but leave prev_link the same. + } else { // keep the link and update the tok_extra_cost if needed. + if (link_extra_cost < 0.0) { // this is just a precaution. + if (link_extra_cost < -0.01) + KALDI_WARN << "Negative extra_cost: " << link_extra_cost; + link_extra_cost = 0.0; + } + if (link_extra_cost < tok_extra_cost) + tok_extra_cost = link_extra_cost; + prev_link = link; + link = link->next; + } + } + // prune away tokens worse than lattice_beam above best path. This step + // was not necessary in the non-final case because then, this case + // showed up as having no forward links. Here, the tok_extra_cost has + // an extra component relating to the final-prob. 
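For the link-pruning rule used both here and in PruneForwardLinks(), a worked instance with invented numbers: if tok->tot_cost = 10.0, link->acoustic_cost = 2.0, link->graph_cost = 1.0, next_tok->tot_cost = 12.5 and next_tok->extra_cost = 0.3, then link_extra_cost = 0.3 + ((10.0 + 2.0 + 1.0) - 12.5) = 0.8, so the link survives exactly when config_.lattice_beam >= 0.8.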
+ if (tok_extra_cost > config_.lattice_beam) + tok_extra_cost = std::numeric_limits::infinity(); + // to be pruned in PruneTokensForFrame + + if (!ApproxEqual(tok->extra_cost, tok_extra_cost, delta)) + changed = true; + tok->extra_cost = tok_extra_cost; // will be +infinity or <= lattice_beam_. + } + } // while changed +} + +template +BaseFloat LatticeFasterDecoderCombineTpl::FinalRelativeCost() const { + if (!decoding_finalized_) { + BaseFloat relative_cost; + ComputeFinalCosts(NULL, &relative_cost, NULL); + return relative_cost; + } else { + // we're not allowed to call that function if FinalizeDecoding() has + // been called; return a cached value. + return final_relative_cost_; + } +} + + +// Prune away any tokens on this frame that have no forward links. +// [we don't do this in PruneForwardLinks because it would give us +// a problem with dangling pointers]. +// It's called by PruneActiveTokens if any forward links have been pruned +template +void LatticeFasterDecoderCombineTpl::PruneTokensForFrame( + int32 frame_plus_one) { + KALDI_ASSERT(frame_plus_one >= 0 && frame_plus_one < active_toks_.size()); + Token *&toks = active_toks_[frame_plus_one].toks; + if (toks == NULL) + KALDI_WARN << "No tokens alive [doing pruning]"; + Token *tok, *next_tok, *prev_tok = NULL; + for (tok = toks; tok != NULL; tok = next_tok) { + next_tok = tok->next; + if (tok->extra_cost == std::numeric_limits::infinity()) { + // token is unreachable from end of graph; (no forward links survived) + // excise tok from list and delete tok. + if (prev_tok != NULL) prev_tok->next = tok->next; + else toks = tok->next; + delete tok; + num_toks_--; + } else { // fetch next Token + prev_tok = tok; + } + } +} + +// Go backwards through still-alive tokens, pruning them, starting not from +// the current frame (where we want to keep all tokens) but from the frame before +// that. We go backwards through the frames and stop when we reach a point +// where the delta-costs are not changing (and the delta controls when we consider +// a cost to have "not changed"). +template +void LatticeFasterDecoderCombineTpl::PruneActiveTokens( + BaseFloat delta) { + int32 cur_frame_plus_one = NumFramesDecoded(); + int32 num_toks_begin = num_toks_; + // The index "f" below represents a "frame plus one", i.e. you'd have to subtract + // one to get the corresponding index for the decodable object. + for (int32 f = cur_frame_plus_one - 1; f >= 0; f--) { + // Reason why we need to prune forward links in this situation: + // (1) we have never pruned them (new TokenList) + // (2) we have not yet pruned the forward links to the next f, + // after any of those tokens have changed their extra_cost. 
+    if (active_toks_[f].must_prune_forward_links) {
+      bool extra_costs_changed = false, links_pruned = false;
+      PruneForwardLinks(f, &extra_costs_changed, &links_pruned, delta);
+      if (extra_costs_changed && f > 0)  // any token has changed extra_cost
+        active_toks_[f-1].must_prune_forward_links = true;
+      if (links_pruned)  // any link was pruned
+        active_toks_[f].must_prune_tokens = true;
+      active_toks_[f].must_prune_forward_links = false;  // job done
+    }
+    if (f+1 < cur_frame_plus_one &&  // except for last f (no forward links)
+        active_toks_[f+1].must_prune_tokens) {
+      PruneTokensForFrame(f+1);
+      active_toks_[f+1].must_prune_tokens = false;
+    }
+  }
+  KALDI_VLOG(4) << "PruneActiveTokens: pruned tokens from " << num_toks_begin
+                << " to " << num_toks_;
+}
+
+template <typename FST, typename Token>
+void LatticeFasterDecoderCombineTpl<FST, Token>::ComputeFinalCosts(
+    unordered_map<Token*, BaseFloat> *final_costs,
+    BaseFloat *final_relative_cost,
+    BaseFloat *final_best_cost) const {
+  KALDI_ASSERT(!decoding_finalized_);
+  if (final_costs != NULL)
+    final_costs->clear();
+  BaseFloat infinity = std::numeric_limits<BaseFloat>::infinity();
+  BaseFloat best_cost = infinity,
+      best_cost_with_final = infinity;
+
+  // The final tokens are recorded in active_toks_[last_frame].
+  for (Token *tok = active_toks_[active_toks_.size() - 1].toks; tok != NULL;
+       tok = tok->next) {
+    StateId state = tok->state_id;
+    BaseFloat final_cost = fst_->Final(state).Value();
+    BaseFloat cost = tok->tot_cost,
+        cost_with_final = cost + final_cost;
+    best_cost = std::min(cost, best_cost);
+    best_cost_with_final = std::min(cost_with_final, best_cost_with_final);
+    if (final_costs != NULL && final_cost != infinity)
+      (*final_costs)[tok] = final_cost;
+  }
+  if (final_relative_cost != NULL) {
+    if (best_cost == infinity && best_cost_with_final == infinity) {
+      // Likely this will only happen if there are no tokens surviving.
+      // This seems the least bad way to handle it.
+      *final_relative_cost = infinity;
+    } else {
+      *final_relative_cost = best_cost_with_final - best_cost;
+    }
+  }
+  if (final_best_cost != NULL) {
+    if (best_cost_with_final != infinity) {  // final-state exists.
+      *final_best_cost = best_cost_with_final;
+    } else {  // no final-state exists.
+      *final_best_cost = best_cost;
+    }
+  }
+}
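+// For intuition (illustrative numbers only): if the best token on the last
+// frame has tot_cost = 50.0, and the best token that is in a final state has
+// tot_cost = 51.0 with a final-prob cost of 1.5, then
+//   best_cost            = 50.0,
+//   best_cost_with_final = 52.5,
+//   final_relative_cost  = 2.5.
+// A small value suggests we reached a final state with reasonable likelihood;
+// infinity means no surviving token was in a final state at all.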
+template <typename FST, typename Token>
+void LatticeFasterDecoderCombineTpl<FST, Token>::AdvanceDecoding(
+    DecodableInterface *decodable,
+    int32 max_num_frames) {
+  if (std::is_same<FST, fst::Fst<fst::StdArc> >::value) {
+    // if the type 'FST' is the FST base-class, then see if the FST type of fst_
+    // is actually VectorFst or ConstFst.  If so, call the AdvanceDecoding()
+    // function after casting *this to the more specific type.
+    if (fst_->Type() == "const") {
+      LatticeFasterDecoderCombineTpl<fst::ConstFst<fst::StdArc>, Token> *this_cast =
+          reinterpret_cast<LatticeFasterDecoderCombineTpl<fst::ConstFst<fst::StdArc>, Token>* >(this);
+      this_cast->AdvanceDecoding(decodable, max_num_frames);
+      return;
+    } else if (fst_->Type() == "vector") {
+      LatticeFasterDecoderCombineTpl<fst::VectorFst<fst::StdArc>, Token> *this_cast =
+          reinterpret_cast<LatticeFasterDecoderCombineTpl<fst::VectorFst<fst::StdArc>, Token>* >(this);
+      this_cast->AdvanceDecoding(decodable, max_num_frames);
+      return;
+    }
+  }
+
+  KALDI_ASSERT(!active_toks_.empty() && !decoding_finalized_ &&
+               "You must call InitDecoding() before AdvanceDecoding");
+  int32 num_frames_ready = decodable->NumFramesReady();
+  // num_frames_ready must be >= num_frames_decoded, or else
+  // the number of frames ready must have decreased (which doesn't
+  // make sense) or the decodable object changed between calls
+  // (which isn't allowed).
+  KALDI_ASSERT(num_frames_ready >= NumFramesDecoded());
+  int32 target_frames_decoded = num_frames_ready;
+  if (max_num_frames >= 0)
+    target_frames_decoded = std::min(target_frames_decoded,
+                                     NumFramesDecoded() + max_num_frames);
+  while (NumFramesDecoded() < target_frames_decoded) {
+    if (NumFramesDecoded() % config_.prune_interval == 0) {
+      PruneActiveTokens(config_.lattice_beam * config_.prune_scale);
+    }
+    ProcessForFrame(decodable);
+  }
+}
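+// A minimal usage sketch of this decoding loop (illustrative only; 'hclg_fst',
+// 'decodable' and the loop condition are assumed to be set up by the caller,
+// they are not part of this file):
+//
+//   LatticeFasterDecoderCombineConfig config;
+//   LatticeFasterDecoderCombine decoder(hclg_fst, config);
+//   decoder.InitDecoding();
+//   while (more_frames_arriving) {          // e.g. online decoding
+//     decoder.AdvanceDecoding(&decodable);  // decodes all frames ready so far
+//   }
+//   decoder.FinalizeDecoding();             // optional extra final pruning
+//   Lattice raw_lat;
+//   decoder.GetRawLattice(&raw_lat, true);  // then determinize if needed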
+// FinalizeDecoding() is a version of PruneActiveTokens that we call
+// (optionally) on the final frame.  Takes into account the final-prob of
+// tokens.  This function used to be called PruneActiveTokensFinal().
+template <typename FST, typename Token>
+void LatticeFasterDecoderCombineTpl<FST, Token>::FinalizeDecoding() {
+  ProcessNonemitting();
+  int32 final_frame_plus_one = NumFramesDecoded();
+  int32 num_toks_begin = num_toks_;
+  // PruneForwardLinksFinal() prunes the final frame (with final-probs), and
+  // sets decoding_finalized_.
+  PruneForwardLinksFinal();
+  for (int32 f = final_frame_plus_one - 1; f >= 0; f--) {
+    bool b1, b2;  // values not used.
+    BaseFloat dontcare = 0.0;  // delta of zero means we must always update
+    PruneForwardLinks(f, &b1, &b2, dontcare);
+    PruneTokensForFrame(f + 1);
+  }
+  PruneTokensForFrame(0);
+  KALDI_VLOG(4) << "pruned tokens from " << num_toks_begin
+                << " to " << num_toks_;
+}
+
+
+template <typename FST, typename Token>
+void LatticeFasterDecoderCombineTpl<FST, Token>::ProcessForFrame(
+    DecodableInterface *decodable) {
+  KALDI_ASSERT(active_toks_.size() > 0);
+  int32 cur_frame = active_toks_.size() - 1,  // frame is the frame-index (zero-
+                                              // based) used to get likelihoods
+                                              // from the decodable object.
+      next_frame = cur_frame + 1;
+
+  active_toks_.resize(active_toks_.size() + 1);
+
+  cur_toks_.swap(next_toks_);
+  next_toks_.clear();
+  if (cur_toks_.empty()) {
+    if (!warned_) {
+      KALDI_WARN << "Error, no surviving tokens on frame " << cur_frame;
+      warned_ = true;
+    }
+  }
+
+  cur_queue_.Clear();
+  // Add tokens to the queue.
+  for (Token* tok = active_toks_[cur_frame].toks; tok != NULL; tok = tok->next)
+    cur_queue_.Push(tok);
+
+  // Declare a local variable so the compiler can put it in a register, since
+  // C++ assumes other threads could be modifying class members.
+  BaseFloat adaptive_beam = adaptive_beam_;
+  // "cur_cutoff" will be kept updated to the best-seen-so-far token on this
+  // frame + adaptive_beam.
+  BaseFloat cur_cutoff = std::numeric_limits<BaseFloat>::infinity();
+  // "next_cutoff" is used to decide whether a new token on the next frame
+  // should be handled or not.  It is updated as we go, and will be kept equal
+  // to the best-seen-so-far token "on the next frame" + adaptive_beam.
+  BaseFloat next_cutoff = std::numeric_limits<BaseFloat>::infinity();
+  // "cost_offset" contains the acoustic log-likelihood offset for the current
+  // frame, used to keep everything in a nice dynamic range and to reduce
+  // roundoff errors.
+  BaseFloat cost_offset = cost_offsets_[cur_frame];
+
+  // Iterate over "cur_queue_" to process both non-emitting and emitting arcs
+  // in the FST.
+  Token *tok = NULL;
+  int32 num_toks_processed = 0;
+  int32 max_active = config_.max_active;
+  for (; num_toks_processed < max_active && (tok = cur_queue_.Pop()) != NULL;
+       num_toks_processed++) {
+    BaseFloat cur_cost = tok->tot_cost;
+    StateId state = tok->state_id;
+    if (cur_cost > cur_cutoff &&
+        num_toks_processed > config_.min_active) {  // Don't bother processing
+                                                    // successors.
+      break;  // This is a priority queue, so the following tokens will be
+              // worse.
+    } else if (cur_cost + adaptive_beam < cur_cutoff) {
+      cur_cutoff = cur_cost + adaptive_beam;  // a tighter boundary
+    }
+    // If "tok" has any existing forward links, delete them, because we're
+    // about to regenerate them.  This is a kind of non-optimality (remember,
+    // this is the simple decoder).
+    DeleteForwardLinks(tok);  // necessary when re-visiting
+    for (fst::ArcIterator<FST> aiter(*fst_, state);
+         !aiter.Done();
+         aiter.Next()) {
+      const Arc &arc = aiter.Value();
+      bool changed;
+      if (arc.ilabel == 0) {  // propagate nonemitting
+        BaseFloat graph_cost = arc.weight.Value();
+        BaseFloat tot_cost = cur_cost + graph_cost;
+        if (tot_cost < cur_cutoff) {
+          Token *new_tok = FindOrAddToken(arc.nextstate, cur_frame, tot_cost,
+                                          tok, &cur_toks_, &changed);
+
+          // Add ForwardLink from tok to new_tok.  Put it at the head of the
+          // tok->links list.
+          tok->links = new ForwardLinkT(new_tok, 0, arc.olabel,
+                                        graph_cost, 0, tok->links);
+
+          // "changed" tells us whether the new token has a different
+          // cost from before, or is new.
+          if (changed) {
+            cur_queue_.Push(new_tok);
+          }
+        }
+      } else {  // propagate emitting
+        BaseFloat graph_cost = arc.weight.Value(),
+            ac_cost = cost_offset - decodable->LogLikelihood(cur_frame,
+                                                             arc.ilabel),
+            cur_cost = tok->tot_cost,
+            tot_cost = cur_cost + ac_cost + graph_cost;
+        if (tot_cost > next_cutoff) continue;
+        else if (tot_cost + adaptive_beam < next_cutoff) {
+          next_cutoff = tot_cost + adaptive_beam;  // a tighter boundary for
+                                                   // emitting
+        }
+
+        // no "changed" flag is needed here
+        Token *next_tok = FindOrAddToken(arc.nextstate, next_frame, tot_cost,
+                                         tok, &next_toks_, NULL);
+        // Add ForwardLink from tok to next_tok.  Put it at the head of the
+        // tok->links list.
+        tok->links = new ForwardLinkT(next_tok, arc.ilabel, arc.olabel,
+                                      graph_cost, ac_cost, tok->links);
+      }
+    }  // for all arcs
+  }  // end of for loop over the queue
+
+  // Store the offset on the acoustic likelihoods that we're applying.
+  // Could just do cost_offsets_.push_back(cost_offset), but we
+  // do it this way as it's more robust to future code changes.
+  // Set the cost_offset for the next frame; it equals
+  // "- best_cost_on_next_frame".
+  cost_offsets_.resize(cur_frame + 2, 0.0);
+  cost_offsets_[next_frame] = adaptive_beam - next_cutoff;
+
+  {  // This block updates adaptive_beam_.
+    BaseFloat beam_used_this_frame = adaptive_beam;
+    Token *tok = cur_queue_.Pop();
+    if (tok != NULL) {
+      // We hit the max-active constraint, meaning we effectively pruned to a
+      // beam tighter than 'beam'.  Work out what that was; it will be used to
+      // update 'adaptive_beam'.
+      BaseFloat best_cost_this_frame = cur_cutoff - adaptive_beam;
+      beam_used_this_frame = tok->tot_cost - best_cost_this_frame;
+    }
+    if (num_toks_processed <= config_.min_active) {
+      // The number of active tokens is dangerously low; increase the beam
+      // even if it already exceeds the user-specified beam.
+      adaptive_beam_ = std::max<BaseFloat>(
+          config_.beam, beam_used_this_frame + 2.0 * config_.beam_delta);
+    } else {
+      // have adaptive_beam_ approach beam_ in intervals of config_.beam_delta
+      BaseFloat diff_from_beam = beam_used_this_frame - config_.beam;
+      if (std::abs(diff_from_beam) < config_.beam_delta) {
+        adaptive_beam_ = config_.beam;
+      } else {
+        // make it close to beam_
+        adaptive_beam_ = beam_used_this_frame -
+            config_.beam_delta * (diff_from_beam > 0 ? 1 : -1);
+      }
+    }
+  }
+}
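+// For intuition about the cost offsets (illustrative numbers only): suppose
+// on some frame the best emitting token found has tot_cost = 250.0 and
+// adaptive_beam = 16.0, so next_cutoff ends up at 266.0.  Then
+//   cost_offsets_[next_frame] = 16.0 - 266.0 = -250.0,
+// and on the next frame each acoustic cost is computed as
+// cost_offset - LogLikelihood(...), so the best token's tot_cost stays near
+// zero.  This only affects the dynamic range of the stored costs, to reduce
+// roundoff error; it is intended to be removed again at lattice generation.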
+template <typename FST, typename Token>
+void LatticeFasterDecoderCombineTpl<FST, Token>::ProcessNonemitting() {
+  int32 cur_frame = active_toks_.size() - 1;
+  StateIdToTokenMap &cur_toks = next_toks_;
+
+  cur_queue_.Clear();
+  for (Token* tok = active_toks_[cur_frame].toks; tok != NULL; tok = tok->next)
+    cur_queue_.Push(tok);
+
+  // Declare a local variable so the compiler can put it in a register, since
+  // C++ assumes other threads could be modifying class members.
+  BaseFloat adaptive_beam = adaptive_beam_;
+  // "cur_cutoff" will be kept updated to the best-seen-so-far token on this
+  // frame + adaptive_beam.
+  BaseFloat cur_cutoff = std::numeric_limits<BaseFloat>::infinity();
+
+  Token *tok = NULL;
+  int32 num_toks_processed = 0;
+  int32 max_active = config_.max_active;
+
+  for (; num_toks_processed < max_active && (tok = cur_queue_.Pop()) != NULL;
+       num_toks_processed++) {
+    BaseFloat cur_cost = tok->tot_cost;
+    StateId state = tok->state_id;
+    if (cur_cost > cur_cutoff &&
+        num_toks_processed > config_.min_active) {  // Don't bother processing
+                                                    // successors.
+      break;  // This is a priority queue, so the following tokens will be
+              // worse.
+    } else if (cur_cost + adaptive_beam < cur_cutoff) {
+      cur_cutoff = cur_cost + adaptive_beam;  // a tighter boundary
+    }
+    // If "tok" has any existing forward links, delete them, because we're
+    // about to regenerate them.  This is a kind of non-optimality (remember,
+    // this is the simple decoder).
+    DeleteForwardLinks(tok);  // necessary when re-visiting
+    for (fst::ArcIterator<FST> aiter(*fst_, state);
+         !aiter.Done();
+         aiter.Next()) {
+      const Arc &arc = aiter.Value();
+      bool changed;
+      if (arc.ilabel == 0) {  // propagate nonemitting
+        BaseFloat graph_cost = arc.weight.Value();
+        BaseFloat tot_cost = cur_cost + graph_cost;
+        if (tot_cost < cur_cutoff) {
+          Token *new_tok = FindOrAddToken(arc.nextstate, cur_frame, tot_cost,
+                                          tok, &cur_toks, &changed);
+
+          // Add ForwardLink from tok to new_tok.  Put it at the head of the
+          // tok->links list.
+          tok->links = new ForwardLinkT(new_tok, 0, arc.olabel,
+                                        graph_cost, 0, tok->links);
+
+          // "changed" tells us whether the new token has a different
+          // cost from before, or is new.
+          if (changed) {
+            cur_queue_.Push(new_tok);
+          }
+        }
+      }
+    }  // end of for loop over arcs
+  }  // end of for loop over the queue
+  if (!decoding_finalized_) {
+    // Update cost_offsets_; it equals "- best_cost".
+    cost_offsets_[cur_frame] = adaptive_beam - cur_cutoff;
+    // No need to update adaptive_beam_, since this frame will still be
+    // processed in ProcessForFrame.
+  }
+}
+
+
+
+// static inline
+template <typename FST, typename Token>
+void LatticeFasterDecoderCombineTpl<FST, Token>::DeleteForwardLinks(Token *tok) {
+  ForwardLinkT *l = tok->links, *m;
+  while (l != NULL) {
+    m = l->next;
+    delete l;
+    l = m;
+  }
+  tok->links = NULL;
+}
+
+
+template <typename FST, typename Token>
+void LatticeFasterDecoderCombineTpl<FST, Token>::ClearActiveTokens() {
+  // a cleanup routine, at utt end/begin
+  for (size_t i = 0; i < active_toks_.size(); i++) {
+    // Delete all tokens alive on this frame, and any forward
+    // links they may have.
+    for (Token *tok = active_toks_[i].toks; tok != NULL; ) {
+      DeleteForwardLinks(tok);
+      Token *next_tok = tok->next;
+      delete tok;
+      num_toks_--;
+      tok = next_tok;
+    }
+  }
+  active_toks_.clear();
+  KALDI_ASSERT(num_toks_ == 0);
+}
+
+// static
+template <typename FST, typename Token>
+void LatticeFasterDecoderCombineTpl<FST, Token>::TopSortTokens(
+    Token *tok_list, std::vector<Token*> *topsorted_list) {
+  unordered_map<Token*, int32> token2pos;
+  typedef typename unordered_map<Token*, int32>::iterator IterType;
+  int32 num_toks = 0;
+  for (Token *tok = tok_list; tok != NULL; tok = tok->next)
+    num_toks++;
+  int32 cur_pos = 0;
+  // We assign the tokens numbers num_toks - 1, ... , 2, 1, 0.
+  // This is likely to be closer to topological order than
+  // if we had given them ascending order, because of the way
+  // new tokens are put at the front of the list.
+  for (Token *tok = tok_list; tok != NULL; tok = tok->next)
+    token2pos[tok] = num_toks - ++cur_pos;
+
+  unordered_set<Token*> reprocess;
+
+  for (IterType iter = token2pos.begin(); iter != token2pos.end(); ++iter) {
+    Token *tok = iter->first;
+    int32 pos = iter->second;
+    for (ForwardLinkT *link = tok->links; link != NULL; link = link->next) {
+      if (link->ilabel == 0) {
+        // We only need to consider epsilon links, since non-epsilon links
+        // transition between frames and this function only needs to sort a
+        // list of tokens from a single frame.
+        IterType following_iter = token2pos.find(link->next_tok);
+        if (following_iter != token2pos.end()) {  // another token on this
+                                                  // frame, so must consider it.
+          int32 next_pos = following_iter->second;
+          if (next_pos < pos) {  // reassign the position of the next Token.
+            following_iter->second = cur_pos++;
+            reprocess.insert(link->next_tok);
+          }
+        }
+      }
+    }
+    // In case we had previously assigned this token to be reprocessed, we can
+    // erase it from that set because it's "happy now" (we just processed it).
+    reprocess.erase(tok);
+  }
+
+  size_t max_loop = 1000000, loop_count;  // max_loop is to detect epsilon cycles.
+  for (loop_count = 0;
+       !reprocess.empty() && loop_count < max_loop; ++loop_count) {
+    std::vector<Token*> reprocess_vec;
+    for (typename unordered_set<Token*>::iterator iter = reprocess.begin();
+         iter != reprocess.end(); ++iter)
+      reprocess_vec.push_back(*iter);
+    reprocess.clear();
+    for (typename std::vector<Token*>::iterator iter = reprocess_vec.begin();
+         iter != reprocess_vec.end(); ++iter) {
+      Token *tok = *iter;
+      int32 pos = token2pos[tok];
+      // Repeat the processing we did above (for comments, see above).
+      for (ForwardLinkT *link = tok->links; link != NULL; link = link->next) {
+        if (link->ilabel == 0) {
+          IterType following_iter = token2pos.find(link->next_tok);
+          if (following_iter != token2pos.end()) {
+            int32 next_pos = following_iter->second;
+            if (next_pos < pos) {
+              following_iter->second = cur_pos++;
+              reprocess.insert(link->next_tok);
+            }
+          }
+        }
+      }
+    }
+  }
+  KALDI_ASSERT(loop_count < max_loop && "Epsilon loops exist in your decoding "
+               "graph (this is not allowed!)");
+
+  topsorted_list->clear();
+  topsorted_list->resize(cur_pos, NULL);  // create a list with NULLs in between.
+  for (IterType iter = token2pos.begin(); iter != token2pos.end(); ++iter)
+    (*topsorted_list)[iter->second] = iter->first;
+}
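+// For intuition, a small worked example of the renumbering above (tokens and
+// link are made up for illustration): suppose the frame's token list, head
+// first, is C -> B -> A, so the first pass assigns positions C=2, B=1, A=0,
+// and there is an epsilon link C -> A.  When C is processed we find that its
+// epsilon successor A has a smaller position (0 < 2), so A is moved to a
+// fresh position (3) and marked for reprocessing.  After reprocessing
+// converges, every epsilon link goes from a smaller position to a larger one,
+// which is the topological order we need; position 0 is then a NULL hole in
+// the output list, which callers are expected to skip.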
+// Instantiate the template for the combination of token types and FST types
+// that we'll need.
+template class LatticeFasterDecoderCombineTpl<fst::Fst<fst::StdArc>,
+    decodercombine::StdToken<fst::Fst<fst::StdArc> > >;
+template class LatticeFasterDecoderCombineTpl<fst::VectorFst<fst::StdArc>,
+    decodercombine::StdToken<fst::VectorFst<fst::StdArc> > >;
+template class LatticeFasterDecoderCombineTpl<fst::ConstFst<fst::StdArc>,
+    decodercombine::StdToken<fst::ConstFst<fst::StdArc> > >;
+template class LatticeFasterDecoderCombineTpl<fst::GrammarFst,
+    decodercombine::StdToken<fst::GrammarFst> >;
+
+template class LatticeFasterDecoderCombineTpl<fst::Fst<fst::StdArc>,
+    decodercombine::BackpointerToken<fst::Fst<fst::StdArc> > >;
+template class LatticeFasterDecoderCombineTpl<fst::VectorFst<fst::StdArc>,
+    decodercombine::BackpointerToken<fst::VectorFst<fst::StdArc> > >;
+template class LatticeFasterDecoderCombineTpl<fst::ConstFst<fst::StdArc>,
+    decodercombine::BackpointerToken<fst::ConstFst<fst::StdArc> > >;
+template class LatticeFasterDecoderCombineTpl<fst::GrammarFst,
+    decodercombine::BackpointerToken<fst::GrammarFst> >;
+
+
+} // end namespace kaldi.
diff --git a/src/decoder/lattice-faster-decoder-combine.h b/src/decoder/lattice-faster-decoder-combine.h
new file mode 100644
index 00000000000..094e9765d73
--- /dev/null
+++ b/src/decoder/lattice-faster-decoder-combine.h
@@ -0,0 +1,624 @@
+// decoder/lattice-faster-decoder-combine.h
+
+// Copyright 2009-2013  Microsoft Corporation;  Mirko Hannemann;
+//           2013-2019  Johns Hopkins University (Author: Daniel Povey)
+//                2014  Guoguo Chen
+//                2018  Zhehuai Chen
+//                2019  Hang Lyu
+
+// See ../../COPYING for clarification regarding multiple authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//  http://www.apache.org/licenses/LICENSE-2.0
+//
+// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
+// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
+// MERCHANTABLITY OR NON-INFRINGEMENT.
+// See the Apache 2 License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef KALDI_DECODER_LATTICE_FASTER_DECODER_COMBINE_H_
+#define KALDI_DECODER_LATTICE_FASTER_DECODER_COMBINE_H_
+
+
+#include "util/stl-utils.h"
+#include "fst/fstlib.h"
+#include "itf/decodable-itf.h"
+#include "fstext/fstext-lib.h"
+#include "lat/determinize-lattice-pruned.h"
+#include "lat/kaldi-lattice.h"
+#include "decoder/grammar-fst.h"
+#include "decoder/lattice-faster-decoder.h"
+
+namespace kaldi {
+
+struct LatticeFasterDecoderCombineConfig {
+  BaseFloat beam;
+  int32 max_active;
+  int32 min_active;
+  BaseFloat lattice_beam;
+  int32 prune_interval;
+  bool determinize_lattice;  // not inspected by this class... used in
+                             // command-line program.
+  BaseFloat beam_delta;  // has nothing to do with beam_ratio
+  BaseFloat hash_ratio;
+  BaseFloat cost_scale;
+  BaseFloat prune_scale;  // Note: we don't make this configurable on the
+                          // command line, it's not a very important parameter.
+                          // It affects the algorithm that prunes the tokens
+                          // as we go.
+  // Most of the options inside det_opts are not actually queried by the
+  // LatticeFasterDecoder class itself, but by the code that calls it, for
+  // example in the function DecodeUtteranceLatticeFaster.
+  fst::DeterminizeLatticePhonePrunedOptions det_opts;
+
+  LatticeFasterDecoderCombineConfig(): beam(16.0),
+                                       max_active(std::numeric_limits<int32>::max()),
+                                       min_active(200),
+                                       lattice_beam(10.0),
+                                       prune_interval(25),
+                                       determinize_lattice(true),
+                                       beam_delta(0.5),
+                                       hash_ratio(2.0),
+                                       cost_scale(1.0),
+                                       prune_scale(0.1) { }
+  void Register(OptionsItf *opts) {
+    det_opts.Register(opts);
+    opts->Register("beam", &beam, "Decoding beam.  Larger->slower, more accurate.");
+    opts->Register("max-active", &max_active, "Decoder max active states.  "
+                   "Larger->slower; more accurate");
+    opts->Register("min-active", &min_active, "Decoder minimum #active states.");
+    opts->Register("lattice-beam", &lattice_beam, "Lattice generation beam.  "
+                   "Larger->slower, and deeper lattices");
+    opts->Register("prune-interval", &prune_interval, "Interval (in frames) at "
+                   "which to prune tokens");
+    opts->Register("determinize-lattice", &determinize_lattice, "If true, "
+                   "determinize the lattice (lattice-determinization, keeping only "
+                   "best pdf-sequence for each word-sequence).");
+    opts->Register("beam-delta", &beam_delta, "Increment used in decoding-- this "
+                   "parameter is obscure and relates to a speedup in the way the "
+                   "max-active constraint is applied.  Larger is more accurate.");
+    opts->Register("hash-ratio", &hash_ratio, "Setting used in decoder to "
+                   "control hash behavior");
+    opts->Register("cost-scale", &cost_scale, "A scale that we multiply the "
+                   "token costs by before integerizing; a larger value means "
+                   "more buckets and more precision.");
+  }
+  void Check() const {
+    KALDI_ASSERT(beam > 0.0 && max_active > 1 && lattice_beam > 0.0
+                 && min_active <= max_active
+                 && prune_interval > 0 && beam_delta > 0.0 && hash_ratio >= 1.0
+                 && prune_scale > 0.0 && prune_scale < 1.0);
+  }
+};
+
+
+namespace decodercombine {
+// We will template the decoder on the token type as well as the FST type; this
+// is a mechanism so that we can use the same underlying decoder code for
+// versions of the decoder that support quickly getting the best path
+// (LatticeFasterOnlineDecoder, see lattice-faster-online-decoder.h) and also
+// those that do not (LatticeFasterDecoder).
+
+
+// ForwardLinks are the links from a token to a token on the next frame,
+// or sometimes on the current frame (for input-epsilon links).
+template <typename Token>
+struct ForwardLink {
+  using Label = fst::StdArc::Label;
+
+  Token *next_tok;  // the next token [or NULL if represents final-state]
+  Label ilabel;  // ilabel on arc
+  Label olabel;  // olabel on arc
+  BaseFloat graph_cost;  // graph cost of traversing arc (contains LM, etc.)
+  BaseFloat acoustic_cost;  // acoustic cost (pre-scaled) of traversing arc
+  ForwardLink *next;  // next in singly-linked list of forward arcs (arcs
+                      // in the state-level lattice) from a token.
+  inline ForwardLink(Token *next_tok, Label ilabel, Label olabel,
+                     BaseFloat graph_cost, BaseFloat acoustic_cost,
+                     ForwardLink *next):
+      next_tok(next_tok), ilabel(ilabel), olabel(olabel),
+      graph_cost(graph_cost), acoustic_cost(acoustic_cost),
+      next(next) { }
+};
+
+template <typename Fst>
+struct StdToken {
+  using ForwardLinkT = ForwardLink<StdToken<Fst> >;
+  using Token = StdToken<Fst>;
+  using StateId = typename Fst::Arc::StateId;
+
+  // Standard token type for LatticeFasterDecoder.  Each active HCLG
+  // (decoding-graph) state on each frame has one token.
+
+  // tot_cost is the total (LM + acoustic) cost from the beginning of the
+  // utterance up to this point.  (but see cost_offset_, which is subtracted
+  // to keep it in a good numerical range).
+  BaseFloat tot_cost;
+
+  // extra_cost is >= 0.  After calling PruneForwardLinks, this equals the
+  // minimum difference between the cost of the best path that this token is
+  // on, and the cost of the absolute best path, under the assumption that any
+  // of the currently active states at the decoding front may eventually
+  // succeed (e.g. if you were to take the currently active states one by one
+  // and compute this difference, and then take the minimum).
+  BaseFloat extra_cost;
+
+  // Records the state id of the token.
+  StateId state_id;
+
+  // 'links' is the head of the singly-linked list of ForwardLinks, which is
+  // what we use for lattice generation.
+  ForwardLinkT *links;
+
+  // 'next' is the next in the singly-linked list of tokens for this frame.
+  Token *next;
+
+  // Identifies whether the token is in the current queue or not, to prevent
+  // duplication in ProcessForFrame().
+  bool in_queue;
+
+
+  // This function does nothing and should be optimized out; it's needed
+  // so we can share the regular LatticeFasterDecoderTpl code and the code
+  // for LatticeFasterOnlineDecoder that supports fast traceback.
+  inline void SetBackpointer (Token *backpointer) { }
+
+  // This constructor just ignores the 'backpointer' argument.  That argument is
+  // needed so that we can use the same decoder code for LatticeFasterDecoderTpl
+  // and LatticeFasterOnlineDecoderTpl (which needs backpointers to support a
+  // fast way to obtain the best path).
+  inline StdToken(BaseFloat tot_cost, BaseFloat extra_cost, StateId state_id,
+                  ForwardLinkT *links, Token *next, Token *backpointer):
+      tot_cost(tot_cost), extra_cost(extra_cost), state_id(state_id),
+      links(links), next(next), in_queue(false) { }
+};
+
+template <typename Fst>
+struct BackpointerToken {
+  using ForwardLinkT = ForwardLink<BackpointerToken<Fst> >;
+  using Token = BackpointerToken<Fst>;
+  using StateId = typename Fst::Arc::StateId;
+
+  // BackpointerToken is like StdToken but also stores a backpointer to the
+  // best preceding token.  Each active HCLG (decoding-graph) state on each
+  // frame has one token.
+
+  // tot_cost is the total (LM + acoustic) cost from the beginning of the
+  // utterance up to this point.  (but see cost_offset_, which is subtracted
+  // to keep it in a good numerical range).
+  BaseFloat tot_cost;
+
+  // extra_cost is >= 0.  After calling PruneForwardLinks, this equals the
+  // minimum difference between the cost of the best path that this token is
+  // on, and the cost of the absolute best path, under the assumption that any
+  // of the currently active states at the decoding front may eventually
+  // succeed (e.g. if you were to take the currently active states one by one
+  // and compute this difference, and then take the minimum).
+  BaseFloat extra_cost;
+
+  // Records the state id of the token.
+  StateId state_id;
+
+  // 'links' is the head of the singly-linked list of ForwardLinks, which is
+  // what we use for lattice generation.
+  ForwardLinkT *links;
+
+  // 'next' is the next in the singly-linked list of tokens for this frame.
+  BackpointerToken *next;
+
+  // Best preceding BackpointerToken (could be one on this frame, connected to
+  // this one via an epsilon transition, or on a previous frame).  This is only
+  // required for an efficient GetBestPath function in
+  // LatticeFasterOnlineDecoderTpl; it plays no part in the lattice generation
+  // (the "links" list is what stores the forward links, for that).
+  Token *backpointer;
+
+  // Identifies whether the token is in the current queue or not, to prevent
+  // duplication in ProcessForFrame().
+  bool in_queue;
+
+  inline void SetBackpointer (Token *backpointer) {
+    this->backpointer = backpointer;
+  }
+
+  inline BackpointerToken(BaseFloat tot_cost, BaseFloat extra_cost,
+                          StateId state_id, ForwardLinkT *links,
+                          Token *next, Token *backpointer):
+      tot_cost(tot_cost), extra_cost(extra_cost), state_id(state_id),
+      links(links), next(next), backpointer(backpointer),
+      in_queue(false) { }
+};
+
+}  // namespace decodercombine
+
+
+template <typename Token>
+class BucketQueue {
+ public:
+  // Constructor.  'cost_scale' is a scale that we multiply the token costs by
+  // before integerizing; a larger value means more buckets.
+  // 'bucket_offset_' is initialized to "15 * cost_scale_".  It is an empirical
+  // value that makes it unlikely we will trigger a re-allocation in the normal
+  // case, since we do in fact normalize costs to be not far from zero on each
+  // frame.
+  BucketQueue(BaseFloat cost_scale = 1.0);
+
+  // Adds Token to the queue; sets the field tok->in_queue to true (it is not
+  // an error if it was already true).
+  // If a Token was already in the queue but its cost improves, you should
+  // just Push it again.  It will be added to (possibly) a different bucket,
+  // but the old entry will remain.  We use "tok->in_queue" to decide whether
+  // an entry is live or not.  When a Token is popped off, the field
+  // 'tok->in_queue' is set to false, so the old entry in the queue will be
+  // considered nonexistent when we try to pop it.
+  void Push(Token *tok);
+
+  // Removes and returns the next Token 'tok' in the queue, or NULL if there
+  // were no Tokens left.  Sets tok->in_queue to false for the returned Token.
+  Token* Pop();
+
+  // Clears all the individual buckets.  Sets 'first_nonempty_bucket_index_' to
+  // the end of buckets_.
+  void Clear();
+
+ private:
+  // Configuration value that is multiplied by tokens' costs before integerizing
+  // them to determine the bucket index.
+  BaseFloat cost_scale_;
+
+  // buckets_ is a list of Tokens 'tok' for each bucket.
+  // If tok->in_queue is false, then the item is considered as not
+  // existing (this is to avoid having to explicitly remove Tokens when their
+  // costs change).  The index into buckets_ is determined as follows:
+  //   bucket_index = std::floor(tok->tot_cost * cost_scale_);
+  //   vec_index = bucket_index + bucket_offset_;
+  // then we access buckets_[vec_index].
+  std::vector<std::vector<Token*> > buckets_;
+
+  // An offset that determines how we index into the buckets_ vector;
+  // in the constructor this will be initialized to something like
+  // "15 * cost_scale_" which will make it unlikely that we have to change this
+  // value in future if we get a much better Token (this is expensive because it
+  // involves reallocating 'buckets_').
+  int32 bucket_offset_;
+
+  // first_nonempty_bucket_index_ is an integer in the range [0,
+  // buckets_.size() - 1] which is not larger than the index of the first
+  // nonempty element of buckets_.
+  int32 first_nonempty_bucket_index_;
+
+  // Synchronizes with first_nonempty_bucket_index_.
+  std::vector<Token*> *first_nonempty_bucket_;
+
+  // If the size of the BucketQueue becomes larger than
+  // "bucket_size_tolerance_", we will resize it to "bucket_size_tolerance_"
+  // in Clear.  An unusually long BucketQueue can occur when the min-active
+  // constraint was triggered and an unusually large log-likelihood range was
+  // encountered.
+  size_t bucket_size_tolerance_;
+};
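+// For intuition, a worked example of the bucket indexing described above
+// (illustrative numbers only, assuming the indexing formula given in the
+// buckets_ comment): with cost_scale_ = 1.0 we get bucket_offset_ = 15, so a
+// token with tot_cost = -3.2 goes to
+//   vec_index = floor(-3.2 * 1.0) + 15 = 11,
+// and a token with tot_cost = 0.5 goes to vec_index = 15.  Pop() scans
+// forward from first_nonempty_bucket_index_, so tokens come out in
+// approximately best-cost-first order; the approximation is the bucket
+// width, 1 / cost_scale_.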
+
+/** This is the "normal" lattice-generating decoder.
+    See \ref lattices_generation \ref decoders_faster and \ref decoders_simple
+    for more information.
+
+    The decoder is templated on the FST type and the token type.  The token
+    type will normally be StdToken, but also may be BackpointerToken, which is
+    to support quick lookup of the current best path (see
+    lattice-faster-online-decoder.h).
+
+    The FST you invoke this decoder with is expected to be of type
+    fst::Fst<fst::StdArc>, a.k.a. StdFst, or GrammarFst.  If you invoke it
+    with FST == StdFst and it notices that the actual FST type is
+    fst::VectorFst<fst::StdArc> or fst::ConstFst<fst::StdArc>, the decoder
+    object will internally cast itself to one that is templated on those more
+    specific types; this is an optimization for speed.
+ */
+template <typename FST, typename Token = decodercombine::StdToken<FST> >
+class LatticeFasterDecoderCombineTpl {
+ public:
+  using Arc = typename FST::Arc;
+  using Label = typename Arc::Label;
+  using StateId = typename Arc::StateId;
+  using Weight = typename Arc::Weight;
+  using ForwardLinkT = decodercombine::ForwardLink<Token>;
+
+  using StateIdToTokenMap = typename std::unordered_map<StateId, Token*>;
+  //using StateIdToTokenMap = typename std::unordered_map<StateId, Token*,
+  //    std::hash<StateId>, std::equal_to<StateId>,
+  //    fst::PoolAllocator<std::pair<const StateId, Token*> > >;
+  using IterType = typename StateIdToTokenMap::const_iterator;
+
+  using BucketQueue = typename kaldi::BucketQueue<Token>;
+
+  // Instantiate this class once for each thing you have to decode.
+  // This version of the constructor does not take ownership of
+  // 'fst'.
+  LatticeFasterDecoderCombineTpl(const FST &fst,
+                                 const LatticeFasterDecoderCombineConfig &config);
+
+  // This version of the constructor takes ownership of the fst, and will delete
+  // it when this object is destroyed.
+  LatticeFasterDecoderCombineTpl(const LatticeFasterDecoderCombineConfig &config,
+                                 FST *fst);
+
+  void SetOptions(const LatticeFasterDecoderCombineConfig &config) {
+    config_ = config;
+  }
+
+  const LatticeFasterDecoderCombineConfig &GetOptions() const {
+    return config_;
+  }
+
+  ~LatticeFasterDecoderCombineTpl();
+
+  /// Decodes until there are no more frames left in the "decodable" object.
+  /// Note: this may block waiting for input if the "decodable" object blocks.
+  /// Returns true if any kind of traceback is available (not necessarily from a
+  /// final state).
+  bool Decode(DecodableInterface *decodable);
+
+
+  /// Says whether a final-state was active on the last frame.  If it was not,
+  /// the lattice (or traceback) will end with states that are not final-states.
+  bool ReachedFinal() const {
+    return FinalRelativeCost() != std::numeric_limits<BaseFloat>::infinity();
+  }
+
+  /// Outputs an FST corresponding to the single best path through the lattice.
+  /// Returns true if the result is nonempty (using the return status is
+  /// deprecated, it will become void).  If "use_final_probs" is true AND we
+  /// reached the final-state of the graph, then it will include those as
+  /// final-probs, else it will treat all final-probs as one.  Note: this just
+  /// calls GetRawLattice() and figures out the shortest path.
+  bool GetBestPath(Lattice *ofst,
+                   bool use_final_probs = true);
+
+  /// Outputs an FST corresponding to the raw, state-level
+  /// tracebacks.  Returns true if the result is nonempty.
+  /// If "use_final_probs" is true AND we reached the final-state
+  /// of the graph, then it will include those as final-probs, else
+  /// it will treat all final-probs as one.
+  /// The raw lattice will be topologically sorted.
+  /// The function can be called during decoding; it will process non-emitting
+  /// arcs from the "next_toks_" map to get tokens from both non-emitting and
+  /// emitting arcs for generating the raw lattice.
+  ///
+  /// See also GetRawLatticePruned in lattice-faster-online-decoder.h,
+  /// which also supports a pruning beam, in case for some reason
+  /// you want it pruned tighter than the regular lattice beam.
+  /// We could add that here in future if needed.
+  bool GetRawLattice(Lattice *ofst, bool use_final_probs = true);
+
+
+  /// [Deprecated, users should now use GetRawLattice and determinize it
+  /// themselves, e.g. using DeterminizeLatticePhonePrunedWrapper].
+  /// Outputs an FST corresponding to the lattice-determinized
+  /// lattice (one path per word sequence).  Returns true if the result is
+  /// nonempty.  If "use_final_probs" is true AND we reached the final-state
+  /// of the graph, then it will include those as final-probs, else it will
+  /// treat all final-probs as one.
+  bool GetLattice(CompactLattice *ofst,
+                  bool use_final_probs = true);
+
+  /// InitDecoding initializes the decoding, and should only be used if you
+  /// intend to call AdvanceDecoding().  If you call Decode(), you don't need to
+  /// call this.  You can also call InitDecoding if you have already decoded an
+  /// utterance and want to start with a new utterance.
+  void InitDecoding();
+
+  /// This will decode until there are no more frames ready in the decodable
+  /// object.  You can keep calling it each time more frames become available.
+  /// If max_num_frames is specified, it specifies the maximum number of frames
+  /// the function will decode before returning.
+  void AdvanceDecoding(DecodableInterface *decodable,
+                       int32 max_num_frames = -1);
+
+  /// This function may be optionally called after AdvanceDecoding(), when you
+  /// do not plan to decode any further.  It does an extra pruning step that
+  /// will help to prune the lattices output by GetLattice and (particularly)
+  /// GetRawLattice more accurately, particularly toward the end of the
+  /// utterance.  It does this by using the final-probs in pruning (if any
+  /// final-state survived); it also does a final pruning step that visits all
+  /// states (the pruning that is done during decoding may fail to prune states
+  /// that are within kPruningScale = 0.1 outside of the beam).  If you call
+  /// this, you cannot call AdvanceDecoding again (it will fail), and you
+  /// cannot call GetLattice() and related functions with use_final_probs =
+  /// false.
+  /// Used to be called PruneActiveTokensFinal().
+  void FinalizeDecoding();
+
+  /// FinalRelativeCost() serves the same purpose as ReachedFinal(), but gives
+  /// more information.  It returns the difference between the best (final-cost
+  /// plus cost) of any token on the final frame, and the best cost of any token
+  /// on the final frame.  If it is infinity it means no final-states were
+  /// present on the final frame.  It will usually be nonnegative.  If it is not
+  /// too positive (e.g. < 5 is my first guess, but this is not tested) you can
+  /// take it as a good indication that we reached the final-state with
+  /// reasonable likelihood.
+  BaseFloat FinalRelativeCost() const;
+
+
+  // Returns the number of frames decoded so far.  The value returned changes
+  // whenever we call ProcessForFrame().
+  inline int32 NumFramesDecoded() const { return active_toks_.size() - 1; }
+
+ protected:
+  // We make things protected instead of private, as code in
+  // LatticeFasterOnlineDecoderTpl, which inherits from this, also uses the
+  // internals.
+
+  // Deletes the elements of the singly linked list tok->links.
+  inline static void DeleteForwardLinks(Token *tok);
+
+  // head of per-frame list of Tokens (list is in topological order),
+  // and something saying whether we ever pruned it using PruneForwardLinks.
+  struct TokenList {
+    Token *toks;
+    bool must_prune_forward_links;
+    bool must_prune_tokens;
+    TokenList(): toks(NULL), must_prune_forward_links(true),
+                 must_prune_tokens(true) { }
+  };
+
+  // FindOrAddToken either locates a token in the hash map "token_map", or if
+  // necessary inserts a new, empty token (i.e. with no forward links) for the
+  // current frame.  [note: it's inserted if necessary into the hash map and
+  // also into the singly linked list of tokens active on this frame (whose
+  // head is at active_toks_[frame])].  The frame_plus_one argument is the
+  // acoustic frame index plus one, which is used to index into the
+  // active_toks_ array.
+  // Returns the Token pointer.  Sets "changed" (if non-NULL) to true if the
+  // token was newly created or the cost changed.
+  // If Token == StdToken, the 'backpointer' argument has no purpose (and will
+  // hopefully be optimized out).
+  inline Token *FindOrAddToken(StateId state, int32 token_list_index,
+                               BaseFloat tot_cost, Token *backpointer,
+                               StateIdToTokenMap *token_map,
+                               bool *changed);
+
+  // Prunes outgoing links for all tokens in active_toks_[frame].
+  // It's called by PruneActiveTokens.
+  // All links that have link_extra_cost > lattice_beam are pruned.
+  // delta is the amount by which the extra_costs must change
+  // before we set *extra_costs_changed = true.
+  // If delta is larger, we'll tend to go back less far
+  // toward the beginning of the file.
+  // extra_costs_changed is set to true if extra_cost was changed for any token.
+  // links_pruned is set to true if any link in any token was pruned.
+  void PruneForwardLinks(int32 frame_plus_one, bool *extra_costs_changed,
+                         bool *links_pruned,
+                         BaseFloat delta);
+
+  // This function computes the final-costs for tokens active on the final
+  // frame.  It outputs to final-costs, if non-NULL, a map from the Token*
+  // pointer to the final-prob of the corresponding state, for all Tokens
+  // that correspond to states that have final-probs.  This map will be
+  // empty if there were no final-probs.  It outputs to
+  // final_relative_cost, if non-NULL, the difference between the best
+  // forward-cost including the final-prob cost, and the best forward-cost
+  // without including the final-prob cost (this will usually be positive), or
+  // infinity if there were no final-probs.  [c.f. FinalRelativeCost(), which
+  // outputs this quantity].  It outputs to final_best_cost, if
+  // non-NULL, the lowest for any token t active on the final frame, of
+  // forward-cost[t] + final-cost[t], where final-cost[t] is the final-cost in
+  // the graph of the state corresponding to token t, or the best of
+  // forward-cost[t] if there were no final-probs active on the final frame.
+  // You cannot call this after FinalizeDecoding() has been called; in that
+  // case you should get the answer from class-member variables.
+  void ComputeFinalCosts(unordered_map<Token*, BaseFloat> *final_costs,
+                         BaseFloat *final_relative_cost,
+                         BaseFloat *final_best_cost) const;
+
+  // PruneForwardLinksFinal is a version of PruneForwardLinks that we call
+  // on the final frame.  If there are final tokens active, it uses
+  // the final-probs for pruning, otherwise it treats all tokens as final.
+  void PruneForwardLinksFinal();
+
+  // Prune away any tokens on this frame that have no forward links.
+  // [we don't do this in PruneForwardLinks because it would give us
+  // a problem with dangling pointers].
+  // It's called by PruneActiveTokens if any forward links have been pruned.
+  void PruneTokensForFrame(int32 frame_plus_one);
+
+
+  // Go backwards through still-alive tokens, pruning them if the
+  // forward+backward cost is more than lat_beam away from the best path.  It's
+  // possible to prove that this is "correct" in the sense that we won't lose
+  // anything outside of lat_beam, regardless of what happens in the future.
+  // delta controls when it considers a cost to have changed enough to continue
+  // going backward and propagating the change.  Larger delta -> will recurse
+  // less far.
+  void PruneActiveTokens(BaseFloat delta);
+
+  /// Processes non-emitting (epsilon) arcs and emitting arcs for one frame
+  /// together.  It takes the emitting tokens in "cur_toks_" from the last
+  /// frame, and generates non-emitting tokens for the current frame and
+  /// emitting tokens for the next frame.
+  /// Notice: the emitting tokens for the current frame means the tokens that
+  /// take the acoustic scores of the current frame (i.e. the destinations of
+  /// emitting arcs).
+  void ProcessForFrame(DecodableInterface *decodable);
+
+  /// Processes nonemitting (epsilon) arcs for one frame.
+  /// This function is called from FinalizeDecoding(), and also from
+  /// GetRawLattice() if GetRawLattice() is called before FinalizeDecoding() is
+  /// called.
+  void ProcessNonemitting();
+
+  /// The "cur_toks_" and "next_toks_" maps allow us to maintain the tokens of
+  /// the current and the next frame.  They are indexed by StateId.
+  StateIdToTokenMap cur_toks_;
+  StateIdToTokenMap next_toks_;
+
+  /// Note on the weight cutoff: in the traditional version, the histogram
+  /// pruning method is applied to the complete token list of one frame.  In
+  /// this version, it is applied to a token list which only contains the
+  /// emitting part, so the effective max_active and min_active values might
+  /// be narrowed.
+
+  /// Lists of tokens, indexed by frame-index plus one, where the frame-index
+  /// is zero-based, as used in the decodable object.  That is, the emitting
+  /// probs of frame t are accounted for in tokens at active_toks_[t+1].  The
+  /// zeroth entry is for the nonemitting transitions at the start of the
+  /// graph.  (Members of TokenList are toks, must_prune_forward_links, and
+  /// must_prune_tokens.)
+  std::vector<TokenList> active_toks_;
+
+  // fst_ is a pointer to the FST we are decoding from.
+  const FST *fst_;
+  // delete_fst_ is true if the pointer fst_ needs to be deleted when this
+  // object is destroyed.
+  bool delete_fst_;
+
+  std::vector<BaseFloat> cost_offsets_;  // This contains, for each
+  // frame, an offset that was added to the acoustic log-likelihoods on that
+  // frame in order to keep everything in a nice dynamic range, i.e. close to
+  // zero, to reduce roundoff errors.
+  // Notice: it will only be added to emitting arcs (i.e. cost_offsets_[t] is
+  // added to arcs from "frame t" to "frame t+1").
+  LatticeFasterDecoderCombineConfig config_;
+  int32 num_toks_;  // current total #toks allocated...
+  bool warned_;
+
+  /// decoding_finalized_ is true if someone called FinalizeDecoding().  [note,
+  /// calling this is optional].  If true, it's forbidden to decode more.  Also,
+  /// if this is set, then the output of ComputeFinalCosts() is in the next
+  /// three variables.  The reason we need to do this is that after
+  /// FinalizeDecoding() calls PruneTokensForFrame() for the final frame, some
+  /// of the tokens on the last frame are freed, so we free the list from toks_
+  /// to avoid having dangling pointers hanging around.
+  bool decoding_finalized_;
+  /// For the meaning of the next 3 variables, see the comment for
+  /// decoding_finalized_ above, and ComputeFinalCosts().
+  unordered_map<Token*, BaseFloat> final_costs_;
+  BaseFloat final_relative_cost_;
+  BaseFloat final_best_cost_;
+
+  BaseFloat adaptive_beam_;  // will be set to beam_ when we start
+  BucketQueue cur_queue_;  // temp variable used in
+                           // ProcessForFrame/ProcessNonemitting
+
+  // This function takes a singly linked list of tokens for a single frame, and
+  // outputs a list of them in topological order (it will crash if no such order
+  // can be found, which will typically be due to decoding graphs with epsilon
+  // cycles, which are not allowed).  Note: the output list may contain NULLs,
+  // which the caller should pass over; it just happens to be more efficient for
+  // the algorithm to output a list that contains NULLs.
+  static void TopSortTokens(Token *tok_list,
+                            std::vector<Token*> *topsorted_list);
+
+  void ClearActiveTokens();
+
+  KALDI_DISALLOW_COPY_AND_ASSIGN(LatticeFasterDecoderCombineTpl);
+};
+
+typedef LatticeFasterDecoderCombineTpl<fst::StdFst,
+    decodercombine::StdToken<fst::StdFst> > LatticeFasterDecoderCombine;
+
+
+
+}  // end namespace kaldi.
+
+#endif