Add HLG decoding for streaming CTC models (#731)
This commit is contained in:
@@ -16,6 +16,7 @@
|
||||
#include "sherpa-onnx/csrc/file-utils.h"
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/csrc/online-ctc-decoder.h"
|
||||
#include "sherpa-onnx/csrc/online-ctc-fst-decoder.h"
|
||||
#include "sherpa-onnx/csrc/online-ctc-greedy-search-decoder.h"
|
||||
#include "sherpa-onnx/csrc/online-ctc-model.h"
|
||||
#include "sherpa-onnx/csrc/online-recognizer-impl.h"
|
||||
@@ -99,6 +100,7 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl {
|
||||
std::unique_ptr<OnlineStream> CreateStream() const override {
|
||||
auto stream = std::make_unique<OnlineStream>(config_.feat_config);
|
||||
stream->SetStates(model_->GetInitStates());
|
||||
stream->SetFasterDecoder(decoder_->CreateFasterDecoder());
|
||||
|
||||
return stream;
|
||||
}
|
||||
@@ -165,7 +167,7 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl {
|
||||
std::vector<std::vector<Ort::Value>> next_states =
|
||||
model_->UnStackStates(std::move(out_states));
|
||||
|
||||
decoder_->Decode(std::move(out[0]), &results);
|
||||
decoder_->Decode(std::move(out[0]), &results, ss, n);
|
||||
|
||||
for (int32_t k = 0; k != n; ++k) {
|
||||
ss[k]->SetCtcResult(results[k]);
|
||||
@@ -221,30 +223,34 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl {
|
||||
|
||||
private:
|
||||
void InitDecoder() {
|
||||
if (config_.decoding_method == "greedy_search") {
|
||||
if (!sym_.contains("<blk>") && !sym_.contains("<eps>") &&
|
||||
!sym_.contains("<blank>")) {
|
||||
SHERPA_ONNX_LOGE(
|
||||
"We expect that tokens.txt contains "
|
||||
"the symbol <blk> or <eps> or <blank> and its ID.");
|
||||
exit(-1);
|
||||
}
|
||||
if (!sym_.contains("<blk>") && !sym_.contains("<eps>") &&
|
||||
!sym_.contains("<blank>")) {
|
||||
SHERPA_ONNX_LOGE(
|
||||
"We expect that tokens.txt contains "
|
||||
"the symbol <blk> or <eps> or <blank> and its ID.");
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
int32_t blank_id = 0;
|
||||
if (sym_.contains("<blk>")) {
|
||||
blank_id = sym_["<blk>"];
|
||||
} else if (sym_.contains("<eps>")) {
|
||||
// for tdnn models of the yesno recipe from icefall
|
||||
blank_id = sym_["<eps>"];
|
||||
} else if (sym_.contains("<blank>")) {
|
||||
// for WeNet CTC models
|
||||
blank_id = sym_["<blank>"];
|
||||
}
|
||||
int32_t blank_id = 0;
|
||||
if (sym_.contains("<blk>")) {
|
||||
blank_id = sym_["<blk>"];
|
||||
} else if (sym_.contains("<eps>")) {
|
||||
// for tdnn models of the yesno recipe from icefall
|
||||
blank_id = sym_["<eps>"];
|
||||
} else if (sym_.contains("<blank>")) {
|
||||
// for WeNet CTC models
|
||||
blank_id = sym_["<blank>"];
|
||||
}
|
||||
|
||||
if (!config_.ctc_fst_decoder_config.graph.empty()) {
|
||||
decoder_ = std::make_unique<OnlineCtcFstDecoder>(
|
||||
config_.ctc_fst_decoder_config, blank_id);
|
||||
} else if (config_.decoding_method == "greedy_search") {
|
||||
decoder_ = std::make_unique<OnlineCtcGreedySearchDecoder>(blank_id);
|
||||
} else {
|
||||
SHERPA_ONNX_LOGE("Unsupported decoding method: %s",
|
||||
config_.decoding_method.c_str());
|
||||
SHERPA_ONNX_LOGE(
|
||||
"Unsupported decoding method: %s for streaming CTC models",
|
||||
config_.decoding_method.c_str());
|
||||
exit(-1);
|
||||
}
|
||||
}
|
||||
@@ -281,7 +287,7 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl {
|
||||
std::vector<OnlineCtcDecoderResult> results(1);
|
||||
results[0] = std::move(s->GetCtcResult());
|
||||
|
||||
decoder_->Decode(std::move(out[0]), &results);
|
||||
decoder_->Decode(std::move(out[0]), &results, &s, 1);
|
||||
s->SetCtcResult(results[0]);
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user