Skip to content

Commit

Permalink
Support getting word IDs for CTC HLG decoding. (#978)
Browse files Browse the repository at this point in the history
  • Loading branch information
csukuangfj authored Jun 6, 2024
1 parent 69347ff commit 1a43d1e
Show file tree
Hide file tree
Showing 13 changed files with 60 additions and 13 deletions.
8 changes: 8 additions & 0 deletions sherpa-onnx/csrc/offline-ctc-decoder.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,16 @@ struct OfflineCtcDecoderResult {
/// The decoded token IDs
std::vector<int64_t> tokens;

/// The decoded word IDs
/// Note: tokens.size() is usually not equal to words.size()
/// words is empty for greedy search decoding.
/// it is not empty when an HLG graph or an HLG graph is used.
std::vector<int32_t> words;

/// timestamps[i] contains the output frame index where tokens[i] is decoded.
/// Note: The index is after subsampling
///
/// tokens.size() == timestamps.size()
std::vector<int32_t> timestamps;
};

Expand Down
3 changes: 3 additions & 0 deletions sherpa-onnx/csrc/offline-ctc-fst-decoder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,9 @@ static OfflineCtcDecoderResult DecodeOne(kaldi_decoder::FasterDecoder *decoder,
// -1 here since the input labels are incremented during graph
// construction
r.tokens.push_back(arc.ilabel - 1);
if (arc.olabel != 0) {
r.words.push_back(arc.olabel);
}

r.timestamps.push_back(t);
prev = arc.ilabel;
Expand Down
4 changes: 0 additions & 4 deletions sherpa-onnx/csrc/offline-paraformer-greedy-search-decoder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -64,10 +64,6 @@ OfflineParaformerGreedySearchDecoder::Decode(

if (timestamps.size() == results[i].tokens.size()) {
results[i].timestamps = std::move(timestamps);
} else {
SHERPA_ONNX_LOGE("time stamp for batch: %d, %d vs %d", i,
static_cast<int32_t>(results[i].tokens.size()),
static_cast<int32_t>(timestamps.size()));
}
}
}
Expand Down
2 changes: 2 additions & 0 deletions sherpa-onnx/csrc/offline-recognizer-ctc-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@ static OfflineRecognitionResult Convert(const OfflineCtcDecoderResult &src,
r.timestamps.push_back(time);
}

r.words = std::move(src.words);

return r;
}

Expand Down
14 changes: 14 additions & 0 deletions sherpa-onnx/csrc/offline-stream.cc
Original file line number Diff line number Diff line change
Expand Up @@ -339,6 +339,20 @@ std::string OfflineRecognitionResult::AsJsonString() const {
}
sep = ", ";
}
os << "], ";

sep = "";

os << "\""
<< "words"
<< "\""
<< ": ";
os << "[";
for (int32_t w : words) {
os << sep << w;
sep = ", ";
}

os << "]";
os << "}";

Expand Down
2 changes: 2 additions & 0 deletions sherpa-onnx/csrc/offline-stream.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ struct OfflineRecognitionResult {
/// timestamps[i] records the time in seconds when tokens[i] is decoded.
std::vector<float> timestamps;

std::vector<int32_t> words;

std::string AsJsonString() const;
};

Expand Down
8 changes: 8 additions & 0 deletions sherpa-onnx/csrc/online-ctc-decoder.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,16 @@ struct OnlineCtcDecoderResult {
/// The decoded token IDs
std::vector<int64_t> tokens;

/// The decoded word IDs
/// Note: tokens.size() is usually not equal to words.size()
/// words is empty for greedy search decoding.
/// it is not empty when an HLG graph or an HLG graph is used.
std::vector<int32_t> words;

/// timestamps[i] contains the output frame index where tokens[i] is decoded.
/// Note: The index is after subsampling
///
/// tokens.size() == timestamps.size()
std::vector<int32_t> timestamps;

int32_t num_trailing_blanks = 0;
Expand Down
7 changes: 4 additions & 3 deletions sherpa-onnx/csrc/online-ctc-fst-decoder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,9 @@ static void DecodeOne(const float *log_probs, int32_t num_rows,
bool ok = decoder->GetBestPath(&fst_out);
if (ok) {
std::vector<int32_t> isymbols_out;
std::vector<int32_t> osymbols_out_unused;
ok = fst::GetLinearSymbolSequence(fst_out, &isymbols_out,
&osymbols_out_unused, nullptr);
std::vector<int32_t> osymbols_out;
ok = fst::GetLinearSymbolSequence(fst_out, &isymbols_out, &osymbols_out,
nullptr);
std::vector<int64_t> tokens;
tokens.reserve(isymbols_out.size());

Expand Down Expand Up @@ -83,6 +83,7 @@ static void DecodeOne(const float *log_probs, int32_t num_rows,
}

result->tokens = std::move(tokens);
result->words = std::move(osymbols_out);
result->timestamps = std::move(timestamps);
// no need to set frame_offset
}
Expand Down
1 change: 1 addition & 0 deletions sherpa-onnx/csrc/online-recognizer-ctc-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ static OnlineRecognizerResult Convert(const OnlineCtcDecoderResult &src,
}

r.segment = segment;
r.words = std::move(src.words);
r.start_time = frames_since_start * frame_shift_ms / 1000.;

return r;
Expand Down
17 changes: 11 additions & 6 deletions sherpa-onnx/csrc/online-recognizer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,16 @@ namespace sherpa_onnx {
template <typename T>
std::string VecToString(const std::vector<T> &vec, int32_t precision = 6) {
std::ostringstream oss;
oss << std::fixed << std::setprecision(precision);
oss << "[ ";
if (precision != 0) {
oss << std::fixed << std::setprecision(precision);
}
oss << "[";
std::string sep = "";
for (const auto &item : vec) {
oss << sep << item;
sep = ", ";
}
oss << " ]";
oss << "]";
return oss.str();
}

Expand All @@ -38,26 +40,29 @@ template <> // explicit specialization for T = std::string
std::string VecToString<std::string>(const std::vector<std::string> &vec,
int32_t) { // ignore 2nd arg
std::ostringstream oss;
oss << "[ ";
oss << "[";
std::string sep = "";
for (const auto &item : vec) {
oss << sep << "\"" << item << "\"";
sep = ", ";
}
oss << " ]";
oss << "]";
return oss.str();
}

std::string OnlineRecognizerResult::AsJsonString() const {
std::ostringstream os;
os << "{ ";
os << "\"text\": " << "\"" << text << "\"" << ", ";
os << "\"text\": "
<< "\"" << text << "\""
<< ", ";
os << "\"tokens\": " << VecToString(tokens) << ", ";
os << "\"timestamps\": " << VecToString(timestamps, 2) << ", ";
os << "\"ys_probs\": " << VecToString(ys_probs, 6) << ", ";
os << "\"lm_probs\": " << VecToString(lm_probs, 6) << ", ";
os << "\"context_scores\": " << VecToString(context_scores, 6) << ", ";
os << "\"segment\": " << segment << ", ";
os << "\"words\": " << VecToString(words, 0) << ", ";
os << "\"start_time\": " << std::fixed << std::setprecision(2) << start_time
<< ", ";
os << "\"is_final\": " << (is_final ? "true" : "false");
Expand Down
2 changes: 2 additions & 0 deletions sherpa-onnx/csrc/online-recognizer.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ struct OnlineRecognizerResult {
/// log-domain scores from "hot-phrase" contextual boosting
std::vector<float> context_scores;

std::vector<int32_t> words;

/// ID of this segment
/// When an endpoint is detected, it is incremented
int32_t segment = 0;
Expand Down
2 changes: 2 additions & 0 deletions sherpa-onnx/python/csrc/offline-stream.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ static void PybindOfflineRecognitionResult(py::module *m) { // NOLINT
})
.def_property_readonly("tokens",
[](const PyClass &self) { return self.tokens; })
.def_property_readonly("words",
[](const PyClass &self) { return self.words; })
.def_property_readonly(
"timestamps", [](const PyClass &self) { return self.timestamps; });
}
Expand Down
3 changes: 3 additions & 0 deletions sherpa-onnx/python/csrc/online-recognizer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,9 @@ static void PybindOnlineRecognizerResult(py::module *m) {
})
.def_property_readonly(
"segment", [](PyClass &self) -> int32_t { return self.segment; })
.def_property_readonly(
"words",
[](PyClass &self) -> std::vector<int32_t> { return self.words; })
.def_property_readonly(
"is_final", [](PyClass &self) -> bool { return self.is_final; })
.def("__str__", &PyClass::AsJsonString,
Expand Down

0 comments on commit 1a43d1e

Please sign in to comment.