Add HLG decoding for streaming CTC models (#731)

This commit is contained in:
Fangjun Kuang
2024-04-03 21:31:42 +08:00
committed by GitHub
parent f8832cb5f2
commit db67e00c77
28 changed files with 668 additions and 82 deletions

View File

@@ -16,6 +16,7 @@
#include "sherpa-onnx/csrc/file-utils.h"
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/online-ctc-decoder.h"
#include "sherpa-onnx/csrc/online-ctc-fst-decoder.h"
#include "sherpa-onnx/csrc/online-ctc-greedy-search-decoder.h"
#include "sherpa-onnx/csrc/online-ctc-model.h"
#include "sherpa-onnx/csrc/online-recognizer-impl.h"
@@ -99,6 +100,7 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl {
std::unique_ptr<OnlineStream> CreateStream() const override {
auto stream = std::make_unique<OnlineStream>(config_.feat_config);
stream->SetStates(model_->GetInitStates());
stream->SetFasterDecoder(decoder_->CreateFasterDecoder());
return stream;
}
@@ -165,7 +167,7 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl {
std::vector<std::vector<Ort::Value>> next_states =
model_->UnStackStates(std::move(out_states));
decoder_->Decode(std::move(out[0]), &results);
decoder_->Decode(std::move(out[0]), &results, ss, n);
for (int32_t k = 0; k != n; ++k) {
ss[k]->SetCtcResult(results[k]);
@@ -221,30 +223,34 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl {
private:
void InitDecoder() {
if (config_.decoding_method == "greedy_search") {
if (!sym_.contains("<blk>") && !sym_.contains("<eps>") &&
!sym_.contains("<blank>")) {
SHERPA_ONNX_LOGE(
"We expect that tokens.txt contains "
"the symbol <blk> or <eps> or <blank> and its ID.");
exit(-1);
}
if (!sym_.contains("<blk>") && !sym_.contains("<eps>") &&
!sym_.contains("<blank>")) {
SHERPA_ONNX_LOGE(
"We expect that tokens.txt contains "
"the symbol <blk> or <eps> or <blank> and its ID.");
exit(-1);
}
int32_t blank_id = 0;
if (sym_.contains("<blk>")) {
blank_id = sym_["<blk>"];
} else if (sym_.contains("<eps>")) {
// for tdnn models of the yesno recipe from icefall
blank_id = sym_["<eps>"];
} else if (sym_.contains("<blank>")) {
// for WeNet CTC models
blank_id = sym_["<blank>"];
}
int32_t blank_id = 0;
if (sym_.contains("<blk>")) {
blank_id = sym_["<blk>"];
} else if (sym_.contains("<eps>")) {
// for tdnn models of the yesno recipe from icefall
blank_id = sym_["<eps>"];
} else if (sym_.contains("<blank>")) {
// for WeNet CTC models
blank_id = sym_["<blank>"];
}
if (!config_.ctc_fst_decoder_config.graph.empty()) {
decoder_ = std::make_unique<OnlineCtcFstDecoder>(
config_.ctc_fst_decoder_config, blank_id);
} else if (config_.decoding_method == "greedy_search") {
decoder_ = std::make_unique<OnlineCtcGreedySearchDecoder>(blank_id);
} else {
SHERPA_ONNX_LOGE("Unsupported decoding method: %s",
config_.decoding_method.c_str());
SHERPA_ONNX_LOGE(
"Unsupported decoding method: %s for streaming CTC models",
config_.decoding_method.c_str());
exit(-1);
}
}
@@ -281,7 +287,7 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl {
std::vector<OnlineCtcDecoderResult> results(1);
results[0] = std::move(s->GetCtcResult());
decoder_->Decode(std::move(out[0]), &results);
decoder_->Decode(std::move(out[0]), &results, &s, 1);
s->SetCtcResult(results[0]);
}