Add HLG decoding for streaming CTC models (#731)
This commit is contained in:
@@ -51,6 +51,8 @@ set(sources
|
||||
offline-zipformer-ctc-model-config.cc
|
||||
offline-zipformer-ctc-model.cc
|
||||
online-conformer-transducer-model.cc
|
||||
online-ctc-fst-decoder-config.cc
|
||||
online-ctc-fst-decoder.cc
|
||||
online-ctc-greedy-search-decoder.cc
|
||||
online-ctc-model.cc
|
||||
online-lm-config.cc
|
||||
|
||||
@@ -7,6 +7,9 @@
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
|
||||
#include "sherpa-onnx/csrc/file-utils.h"
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
std::string OfflineCtcFstDecoderConfig::ToString() const {
|
||||
@@ -29,4 +32,12 @@ void OfflineCtcFstDecoderConfig::Register(ParseOptions *po) {
|
||||
"Decoder max active states. Larger->slower; more accurate");
|
||||
}
|
||||
|
||||
bool OfflineCtcFstDecoderConfig::Validate() const {
|
||||
if (!graph.empty() && !FileExists(graph)) {
|
||||
SHERPA_ONNX_LOGE("graph: %s does not exist", graph.c_str());
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
@@ -24,6 +24,7 @@ struct OfflineCtcFstDecoderConfig {
|
||||
std::string ToString() const;
|
||||
|
||||
void Register(ParseOptions *po);
|
||||
bool Validate() const;
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
@@ -20,7 +20,7 @@ namespace sherpa_onnx {
|
||||
// @param filename Path to a StdVectorFst or StdConstFst graph
|
||||
// @return The caller should free the returned pointer using `delete` to
|
||||
// avoid memory leak.
|
||||
static fst::Fst<fst::StdArc> *ReadGraph(const std::string &filename) {
|
||||
fst::Fst<fst::StdArc> *ReadGraph(const std::string &filename) {
|
||||
// read decoding network FST
|
||||
std::ifstream is(filename, std::ios::binary);
|
||||
if (!is.good()) {
|
||||
|
||||
@@ -67,6 +67,12 @@ bool OfflineRecognizerConfig::Validate() const {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!ctc_fst_decoder_config.graph.empty() &&
|
||||
!ctc_fst_decoder_config.Validate()) {
|
||||
SHERPA_ONNX_LOGE("Errors in fst_decoder");
|
||||
return false;
|
||||
}
|
||||
|
||||
return model_config.Validate();
|
||||
}
|
||||
|
||||
|
||||
@@ -5,12 +5,16 @@
|
||||
#ifndef SHERPA_ONNX_CSRC_ONLINE_CTC_DECODER_H_
|
||||
#define SHERPA_ONNX_CSRC_ONLINE_CTC_DECODER_H_
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "kaldi-decoder/csrc/faster-decoder.h"
|
||||
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
class OnlineStream;
|
||||
|
||||
struct OnlineCtcDecoderResult {
|
||||
/// Number of frames after subsampling we have decoded so far
|
||||
int32_t frame_offset = 0;
|
||||
@@ -37,7 +41,13 @@ class OnlineCtcDecoder {
|
||||
* @param results Input & Output parameters..
|
||||
*/
|
||||
virtual void Decode(Ort::Value log_probs,
|
||||
std::vector<OnlineCtcDecoderResult> *results) = 0;
|
||||
std::vector<OnlineCtcDecoderResult> *results,
|
||||
OnlineStream **ss = nullptr, int32_t n = 0) = 0;
|
||||
|
||||
virtual std::unique_ptr<kaldi_decoder::FasterDecoder> CreateFasterDecoder()
|
||||
const {
|
||||
return nullptr;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
40
sherpa-onnx/csrc/online-ctc-fst-decoder-config.cc
Normal file
40
sherpa-onnx/csrc/online-ctc-fst-decoder-config.cc
Normal file
@@ -0,0 +1,40 @@
|
||||
// sherpa-onnx/csrc/online-ctc-fst-decoder-config.cc
|
||||
//
|
||||
// Copyright (c) 2024 Xiaomi Corporation
|
||||
|
||||
#include "sherpa-onnx/csrc/online-ctc-fst-decoder-config.h"
|
||||
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
|
||||
#include "sherpa-onnx/csrc/file-utils.h"
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
std::string OnlineCtcFstDecoderConfig::ToString() const {
|
||||
std::ostringstream os;
|
||||
|
||||
os << "OnlineCtcFstDecoderConfig(";
|
||||
os << "graph=\"" << graph << "\", ";
|
||||
os << "max_active=" << max_active << ")";
|
||||
|
||||
return os.str();
|
||||
}
|
||||
|
||||
void OnlineCtcFstDecoderConfig::Register(ParseOptions *po) {
|
||||
po->Register("ctc-graph", &graph, "Path to H.fst, HL.fst, or HLG.fst");
|
||||
|
||||
po->Register("ctc-max-active", &max_active,
|
||||
"Decoder max active states. Larger->slower; more accurate");
|
||||
}
|
||||
|
||||
bool OnlineCtcFstDecoderConfig::Validate() const {
|
||||
if (!graph.empty() && !FileExists(graph)) {
|
||||
SHERPA_ONNX_LOGE("graph: %s does not exist", graph.c_str());
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
32
sherpa-onnx/csrc/online-ctc-fst-decoder-config.h
Normal file
32
sherpa-onnx/csrc/online-ctc-fst-decoder-config.h
Normal file
@@ -0,0 +1,32 @@
|
||||
// sherpa-onnx/csrc/online-ctc-fst-decoder-config.h
|
||||
//
|
||||
// Copyright (c) 2024 Xiaomi Corporation
|
||||
|
||||
#ifndef SHERPA_ONNX_CSRC_ONLINE_CTC_FST_DECODER_CONFIG_H_
|
||||
#define SHERPA_ONNX_CSRC_ONLINE_CTC_FST_DECODER_CONFIG_H_
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "sherpa-onnx/csrc/parse-options.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
struct OnlineCtcFstDecoderConfig {
|
||||
// Path to H.fst, HL.fst or HLG.fst
|
||||
std::string graph;
|
||||
int32_t max_active = 3000;
|
||||
|
||||
OnlineCtcFstDecoderConfig() = default;
|
||||
|
||||
OnlineCtcFstDecoderConfig(const std::string &graph, int32_t max_active)
|
||||
: graph(graph), max_active(max_active) {}
|
||||
|
||||
std::string ToString() const;
|
||||
|
||||
void Register(ParseOptions *po);
|
||||
bool Validate() const;
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_ONLINE_CTC_FST_DECODER_CONFIG_H_
|
||||
125
sherpa-onnx/csrc/online-ctc-fst-decoder.cc
Normal file
125
sherpa-onnx/csrc/online-ctc-fst-decoder.cc
Normal file
@@ -0,0 +1,125 @@
|
||||
// sherpa-onnx/csrc/online-ctc-fst-decoder.cc
|
||||
//
|
||||
// Copyright (c) 2024 Xiaomi Corporation
|
||||
|
||||
#include "sherpa-onnx/csrc/online-ctc-fst-decoder.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "fst/fstlib.h"
|
||||
#include "kaldi-decoder/csrc/decodable-ctc.h"
|
||||
#include "kaldifst/csrc/fstext-utils.h"
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/csrc/online-stream.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
// defined in ./offline-ctc-fst-decoder.cc
|
||||
fst::Fst<fst::StdArc> *ReadGraph(const std::string &filename);
|
||||
|
||||
OnlineCtcFstDecoder::OnlineCtcFstDecoder(
|
||||
const OnlineCtcFstDecoderConfig &config, int32_t blank_id)
|
||||
: config_(config), fst_(ReadGraph(config.graph)), blank_id_(blank_id) {
|
||||
options_.max_active = config_.max_active;
|
||||
}
|
||||
|
||||
std::unique_ptr<kaldi_decoder::FasterDecoder>
|
||||
OnlineCtcFstDecoder::CreateFasterDecoder() const {
|
||||
return std::make_unique<kaldi_decoder::FasterDecoder>(*fst_, options_);
|
||||
}
|
||||
|
||||
static void DecodeOne(const float *log_probs, int32_t num_rows,
|
||||
int32_t num_cols, OnlineCtcDecoderResult *result,
|
||||
OnlineStream *s, int32_t blank_id) {
|
||||
int32_t &processed_frames = s->GetFasterDecoderProcessedFrames();
|
||||
kaldi_decoder::DecodableCtc decodable(log_probs, num_rows, num_cols,
|
||||
processed_frames);
|
||||
|
||||
kaldi_decoder::FasterDecoder *decoder = s->GetFasterDecoder();
|
||||
if (processed_frames == 0) {
|
||||
decoder->InitDecoding();
|
||||
}
|
||||
|
||||
decoder->AdvanceDecoding(&decodable);
|
||||
|
||||
if (decoder->ReachedFinal()) {
|
||||
fst::VectorFst<fst::LatticeArc> fst_out;
|
||||
bool ok = decoder->GetBestPath(&fst_out);
|
||||
if (ok) {
|
||||
std::vector<int32_t> isymbols_out;
|
||||
std::vector<int32_t> osymbols_out_unused;
|
||||
ok = fst::GetLinearSymbolSequence(fst_out, &isymbols_out,
|
||||
&osymbols_out_unused, nullptr);
|
||||
std::vector<int64_t> tokens;
|
||||
tokens.reserve(isymbols_out.size());
|
||||
|
||||
std::vector<int32_t> timestamps;
|
||||
timestamps.reserve(isymbols_out.size());
|
||||
|
||||
std::ostringstream os;
|
||||
int32_t prev_id = -1;
|
||||
int32_t num_trailing_blanks = 0;
|
||||
int32_t f = 0; // frame number
|
||||
|
||||
for (auto i : isymbols_out) {
|
||||
i -= 1;
|
||||
|
||||
if (i == blank_id) {
|
||||
num_trailing_blanks += 1;
|
||||
} else {
|
||||
num_trailing_blanks = 0;
|
||||
}
|
||||
|
||||
if (i != blank_id && i != prev_id) {
|
||||
tokens.push_back(i);
|
||||
timestamps.push_back(f);
|
||||
}
|
||||
prev_id = i;
|
||||
f += 1;
|
||||
}
|
||||
|
||||
result->tokens = std::move(tokens);
|
||||
result->timestamps = std::move(timestamps);
|
||||
// no need to set frame_offset
|
||||
}
|
||||
}
|
||||
|
||||
processed_frames += num_rows;
|
||||
}
|
||||
|
||||
void OnlineCtcFstDecoder::Decode(Ort::Value log_probs,
|
||||
std::vector<OnlineCtcDecoderResult> *results,
|
||||
OnlineStream **ss, int32_t n) {
|
||||
std::vector<int64_t> log_probs_shape =
|
||||
log_probs.GetTensorTypeAndShapeInfo().GetShape();
|
||||
|
||||
if (log_probs_shape[0] != results->size()) {
|
||||
SHERPA_ONNX_LOGE("Size mismatch! log_probs.size(0) %d, results.size(0): %d",
|
||||
static_cast<int32_t>(log_probs_shape[0]),
|
||||
static_cast<int32_t>(results->size()));
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
if (log_probs_shape[0] != n) {
|
||||
SHERPA_ONNX_LOGE("Size mismatch! log_probs.size(0) %d, n: %d",
|
||||
static_cast<int32_t>(log_probs_shape[0]), n);
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
int32_t batch_size = static_cast<int32_t>(log_probs_shape[0]);
|
||||
int32_t num_frames = static_cast<int32_t>(log_probs_shape[1]);
|
||||
int32_t vocab_size = static_cast<int32_t>(log_probs_shape[2]);
|
||||
|
||||
const float *p = log_probs.GetTensorData<float>();
|
||||
|
||||
for (int32_t i = 0; i != batch_size; ++i) {
|
||||
DecodeOne(p + i * num_frames * vocab_size, num_frames, vocab_size,
|
||||
&(*results)[i], ss[i], blank_id_);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
39
sherpa-onnx/csrc/online-ctc-fst-decoder.h
Normal file
39
sherpa-onnx/csrc/online-ctc-fst-decoder.h
Normal file
@@ -0,0 +1,39 @@
|
||||
// sherpa-onnx/csrc/online-ctc-fst-decoder.h
|
||||
//
|
||||
// Copyright (c) 2024 Xiaomi Corporation
|
||||
|
||||
#ifndef SHERPA_ONNX_CSRC_ONLINE_CTC_FST_DECODER_H_
|
||||
#define SHERPA_ONNX_CSRC_ONLINE_CTC_FST_DECODER_H_
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "fst/fst.h"
|
||||
#include "sherpa-onnx/csrc/online-ctc-decoder.h"
|
||||
#include "sherpa-onnx/csrc/online-ctc-fst-decoder-config.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
class OnlineCtcFstDecoder : public OnlineCtcDecoder {
|
||||
public:
|
||||
OnlineCtcFstDecoder(const OnlineCtcFstDecoderConfig &config,
|
||||
int32_t blank_id);
|
||||
|
||||
void Decode(Ort::Value log_probs,
|
||||
std::vector<OnlineCtcDecoderResult> *results,
|
||||
OnlineStream **ss = nullptr, int32_t n = 0) override;
|
||||
|
||||
std::unique_ptr<kaldi_decoder::FasterDecoder> CreateFasterDecoder()
|
||||
const override;
|
||||
|
||||
private:
|
||||
OnlineCtcFstDecoderConfig config_;
|
||||
kaldi_decoder::FasterDecoderOptions options_;
|
||||
|
||||
std::unique_ptr<fst::Fst<fst::StdArc>> fst_;
|
||||
int32_t blank_id_ = 0;
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_ONLINE_CTC_FST_DECODER_H_
|
||||
@@ -13,7 +13,8 @@
|
||||
namespace sherpa_onnx {
|
||||
|
||||
void OnlineCtcGreedySearchDecoder::Decode(
|
||||
Ort::Value log_probs, std::vector<OnlineCtcDecoderResult> *results) {
|
||||
Ort::Value log_probs, std::vector<OnlineCtcDecoderResult> *results,
|
||||
OnlineStream ** /*ss=nullptr*/, int32_t /*n = 0*/) {
|
||||
std::vector<int64_t> log_probs_shape =
|
||||
log_probs.GetTensorTypeAndShapeInfo().GetShape();
|
||||
|
||||
|
||||
@@ -17,7 +17,8 @@ class OnlineCtcGreedySearchDecoder : public OnlineCtcDecoder {
|
||||
: blank_id_(blank_id) {}
|
||||
|
||||
void Decode(Ort::Value log_probs,
|
||||
std::vector<OnlineCtcDecoderResult> *results) override;
|
||||
std::vector<OnlineCtcDecoderResult> *results,
|
||||
OnlineStream **ss = nullptr, int32_t n = 0) override;
|
||||
|
||||
private:
|
||||
int32_t blank_id_;
|
||||
|
||||
@@ -16,6 +16,7 @@
|
||||
#include "sherpa-onnx/csrc/file-utils.h"
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/csrc/online-ctc-decoder.h"
|
||||
#include "sherpa-onnx/csrc/online-ctc-fst-decoder.h"
|
||||
#include "sherpa-onnx/csrc/online-ctc-greedy-search-decoder.h"
|
||||
#include "sherpa-onnx/csrc/online-ctc-model.h"
|
||||
#include "sherpa-onnx/csrc/online-recognizer-impl.h"
|
||||
@@ -99,6 +100,7 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl {
|
||||
std::unique_ptr<OnlineStream> CreateStream() const override {
|
||||
auto stream = std::make_unique<OnlineStream>(config_.feat_config);
|
||||
stream->SetStates(model_->GetInitStates());
|
||||
stream->SetFasterDecoder(decoder_->CreateFasterDecoder());
|
||||
|
||||
return stream;
|
||||
}
|
||||
@@ -165,7 +167,7 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl {
|
||||
std::vector<std::vector<Ort::Value>> next_states =
|
||||
model_->UnStackStates(std::move(out_states));
|
||||
|
||||
decoder_->Decode(std::move(out[0]), &results);
|
||||
decoder_->Decode(std::move(out[0]), &results, ss, n);
|
||||
|
||||
for (int32_t k = 0; k != n; ++k) {
|
||||
ss[k]->SetCtcResult(results[k]);
|
||||
@@ -221,30 +223,34 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl {
|
||||
|
||||
private:
|
||||
void InitDecoder() {
|
||||
if (config_.decoding_method == "greedy_search") {
|
||||
if (!sym_.contains("<blk>") && !sym_.contains("<eps>") &&
|
||||
!sym_.contains("<blank>")) {
|
||||
SHERPA_ONNX_LOGE(
|
||||
"We expect that tokens.txt contains "
|
||||
"the symbol <blk> or <eps> or <blank> and its ID.");
|
||||
exit(-1);
|
||||
}
|
||||
if (!sym_.contains("<blk>") && !sym_.contains("<eps>") &&
|
||||
!sym_.contains("<blank>")) {
|
||||
SHERPA_ONNX_LOGE(
|
||||
"We expect that tokens.txt contains "
|
||||
"the symbol <blk> or <eps> or <blank> and its ID.");
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
int32_t blank_id = 0;
|
||||
if (sym_.contains("<blk>")) {
|
||||
blank_id = sym_["<blk>"];
|
||||
} else if (sym_.contains("<eps>")) {
|
||||
// for tdnn models of the yesno recipe from icefall
|
||||
blank_id = sym_["<eps>"];
|
||||
} else if (sym_.contains("<blank>")) {
|
||||
// for WeNet CTC models
|
||||
blank_id = sym_["<blank>"];
|
||||
}
|
||||
int32_t blank_id = 0;
|
||||
if (sym_.contains("<blk>")) {
|
||||
blank_id = sym_["<blk>"];
|
||||
} else if (sym_.contains("<eps>")) {
|
||||
// for tdnn models of the yesno recipe from icefall
|
||||
blank_id = sym_["<eps>"];
|
||||
} else if (sym_.contains("<blank>")) {
|
||||
// for WeNet CTC models
|
||||
blank_id = sym_["<blank>"];
|
||||
}
|
||||
|
||||
if (!config_.ctc_fst_decoder_config.graph.empty()) {
|
||||
decoder_ = std::make_unique<OnlineCtcFstDecoder>(
|
||||
config_.ctc_fst_decoder_config, blank_id);
|
||||
} else if (config_.decoding_method == "greedy_search") {
|
||||
decoder_ = std::make_unique<OnlineCtcGreedySearchDecoder>(blank_id);
|
||||
} else {
|
||||
SHERPA_ONNX_LOGE("Unsupported decoding method: %s",
|
||||
config_.decoding_method.c_str());
|
||||
SHERPA_ONNX_LOGE(
|
||||
"Unsupported decoding method: %s for streaming CTC models",
|
||||
config_.decoding_method.c_str());
|
||||
exit(-1);
|
||||
}
|
||||
}
|
||||
@@ -281,7 +287,7 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl {
|
||||
std::vector<OnlineCtcDecoderResult> results(1);
|
||||
results[0] = std::move(s->GetCtcResult());
|
||||
|
||||
decoder_->Decode(std::move(out[0]), &results);
|
||||
decoder_->Decode(std::move(out[0]), &results, &s, 1);
|
||||
s->SetCtcResult(results[0]);
|
||||
}
|
||||
|
||||
|
||||
@@ -19,13 +19,13 @@
|
||||
namespace sherpa_onnx {
|
||||
|
||||
/// Helper for `OnlineRecognizerResult::AsJsonString()`
|
||||
template<typename T>
|
||||
std::string VecToString(const std::vector<T>& vec, int32_t precision = 6) {
|
||||
template <typename T>
|
||||
std::string VecToString(const std::vector<T> &vec, int32_t precision = 6) {
|
||||
std::ostringstream oss;
|
||||
oss << std::fixed << std::setprecision(precision);
|
||||
oss << "[ ";
|
||||
std::string sep = "";
|
||||
for (const auto& item : vec) {
|
||||
for (const auto &item : vec) {
|
||||
oss << sep << item;
|
||||
sep = ", ";
|
||||
}
|
||||
@@ -34,13 +34,13 @@ std::string VecToString(const std::vector<T>& vec, int32_t precision = 6) {
|
||||
}
|
||||
|
||||
/// Helper for `OnlineRecognizerResult::AsJsonString()`
|
||||
template<> // explicit specialization for T = std::string
|
||||
std::string VecToString<std::string>(const std::vector<std::string>& vec,
|
||||
template <> // explicit specialization for T = std::string
|
||||
std::string VecToString<std::string>(const std::vector<std::string> &vec,
|
||||
int32_t) { // ignore 2nd arg
|
||||
std::ostringstream oss;
|
||||
oss << "[ ";
|
||||
std::string sep = "";
|
||||
for (const auto& item : vec) {
|
||||
for (const auto &item : vec) {
|
||||
oss << sep << "\"" << item << "\"";
|
||||
sep = ", ";
|
||||
}
|
||||
@@ -51,15 +51,17 @@ std::string VecToString<std::string>(const std::vector<std::string>& vec,
|
||||
std::string OnlineRecognizerResult::AsJsonString() const {
|
||||
std::ostringstream os;
|
||||
os << "{ ";
|
||||
os << "\"text\": " << "\"" << text << "\"" << ", ";
|
||||
os << "\"text\": "
|
||||
<< "\"" << text << "\""
|
||||
<< ", ";
|
||||
os << "\"tokens\": " << VecToString(tokens) << ", ";
|
||||
os << "\"timestamps\": " << VecToString(timestamps, 2) << ", ";
|
||||
os << "\"ys_probs\": " << VecToString(ys_probs, 6) << ", ";
|
||||
os << "\"lm_probs\": " << VecToString(lm_probs, 6) << ", ";
|
||||
os << "\"context_scores\": " << VecToString(context_scores, 6) << ", ";
|
||||
os << "\"segment\": " << segment << ", ";
|
||||
os << "\"start_time\": " << std::fixed << std::setprecision(2)
|
||||
<< start_time << ", ";
|
||||
os << "\"start_time\": " << std::fixed << std::setprecision(2) << start_time
|
||||
<< ", ";
|
||||
os << "\"is_final\": " << (is_final ? "true" : "false");
|
||||
os << "}";
|
||||
return os.str();
|
||||
@@ -70,6 +72,7 @@ void OnlineRecognizerConfig::Register(ParseOptions *po) {
|
||||
model_config.Register(po);
|
||||
endpoint_config.Register(po);
|
||||
lm_config.Register(po);
|
||||
ctc_fst_decoder_config.Register(po);
|
||||
|
||||
po->Register("enable-endpoint", &enable_endpoint,
|
||||
"True to enable endpoint detection. False to disable it.");
|
||||
@@ -116,6 +119,12 @@ bool OnlineRecognizerConfig::Validate() const {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!ctc_fst_decoder_config.graph.empty() &&
|
||||
!ctc_fst_decoder_config.Validate()) {
|
||||
SHERPA_ONNX_LOGE("Errors in ctc_fst_decoder_config");
|
||||
return false;
|
||||
}
|
||||
|
||||
return model_config.Validate();
|
||||
}
|
||||
|
||||
@@ -127,6 +136,7 @@ std::string OnlineRecognizerConfig::ToString() const {
|
||||
os << "model_config=" << model_config.ToString() << ", ";
|
||||
os << "lm_config=" << lm_config.ToString() << ", ";
|
||||
os << "endpoint_config=" << endpoint_config.ToString() << ", ";
|
||||
os << "ctc_fst_decoder_config=" << ctc_fst_decoder_config.ToString() << ", ";
|
||||
os << "enable_endpoint=" << (enable_endpoint ? "True" : "False") << ", ";
|
||||
os << "max_active_paths=" << max_active_paths << ", ";
|
||||
os << "hotwords_score=" << hotwords_score << ", ";
|
||||
|
||||
@@ -16,6 +16,7 @@
|
||||
|
||||
#include "sherpa-onnx/csrc/endpoint.h"
|
||||
#include "sherpa-onnx/csrc/features.h"
|
||||
#include "sherpa-onnx/csrc/online-ctc-fst-decoder-config.h"
|
||||
#include "sherpa-onnx/csrc/online-lm-config.h"
|
||||
#include "sherpa-onnx/csrc/online-model-config.h"
|
||||
#include "sherpa-onnx/csrc/online-stream.h"
|
||||
@@ -80,6 +81,7 @@ struct OnlineRecognizerConfig {
|
||||
OnlineModelConfig model_config;
|
||||
OnlineLMConfig lm_config;
|
||||
EndpointConfig endpoint_config;
|
||||
OnlineCtcFstDecoderConfig ctc_fst_decoder_config;
|
||||
bool enable_endpoint = true;
|
||||
|
||||
std::string decoding_method = "greedy_search";
|
||||
@@ -96,19 +98,19 @@ struct OnlineRecognizerConfig {
|
||||
|
||||
OnlineRecognizerConfig() = default;
|
||||
|
||||
OnlineRecognizerConfig(const FeatureExtractorConfig &feat_config,
|
||||
const OnlineModelConfig &model_config,
|
||||
const OnlineLMConfig &lm_config,
|
||||
const EndpointConfig &endpoint_config,
|
||||
bool enable_endpoint,
|
||||
const std::string &decoding_method,
|
||||
int32_t max_active_paths,
|
||||
const std::string &hotwords_file, float hotwords_score,
|
||||
float blank_penalty)
|
||||
OnlineRecognizerConfig(
|
||||
const FeatureExtractorConfig &feat_config,
|
||||
const OnlineModelConfig &model_config, const OnlineLMConfig &lm_config,
|
||||
const EndpointConfig &endpoint_config,
|
||||
const OnlineCtcFstDecoderConfig &ctc_fst_decoder_config,
|
||||
bool enable_endpoint, const std::string &decoding_method,
|
||||
int32_t max_active_paths, const std::string &hotwords_file,
|
||||
float hotwords_score, float blank_penalty)
|
||||
: feat_config(feat_config),
|
||||
model_config(model_config),
|
||||
lm_config(lm_config),
|
||||
endpoint_config(endpoint_config),
|
||||
ctc_fst_decoder_config(ctc_fst_decoder_config),
|
||||
enable_endpoint(enable_endpoint),
|
||||
decoding_method(decoding_method),
|
||||
max_active_paths(max_active_paths),
|
||||
|
||||
@@ -104,6 +104,18 @@ class OnlineStream::Impl {
|
||||
return paraformer_alpha_cache_;
|
||||
}
|
||||
|
||||
void SetFasterDecoder(std::unique_ptr<kaldi_decoder::FasterDecoder> decoder) {
|
||||
faster_decoder_ = std::move(decoder);
|
||||
}
|
||||
|
||||
kaldi_decoder::FasterDecoder *GetFasterDecoder() const {
|
||||
return faster_decoder_.get();
|
||||
}
|
||||
|
||||
int32_t &GetFasterDecoderProcessedFrames() {
|
||||
return faster_decoder_processed_frames_;
|
||||
}
|
||||
|
||||
private:
|
||||
FeatureExtractor feat_extractor_;
|
||||
/// For contextual-biasing
|
||||
@@ -121,6 +133,8 @@ class OnlineStream::Impl {
|
||||
std::vector<float> paraformer_encoder_out_cache_;
|
||||
std::vector<float> paraformer_alpha_cache_;
|
||||
OnlineParaformerDecoderResult paraformer_result_;
|
||||
std::unique_ptr<kaldi_decoder::FasterDecoder> faster_decoder_;
|
||||
int32_t faster_decoder_processed_frames_ = 0;
|
||||
};
|
||||
|
||||
OnlineStream::OnlineStream(const FeatureExtractorConfig &config /*= {}*/,
|
||||
@@ -208,6 +222,19 @@ const ContextGraphPtr &OnlineStream::GetContextGraph() const {
|
||||
return impl_->GetContextGraph();
|
||||
}
|
||||
|
||||
void OnlineStream::SetFasterDecoder(
|
||||
std::unique_ptr<kaldi_decoder::FasterDecoder> decoder) {
|
||||
impl_->SetFasterDecoder(std::move(decoder));
|
||||
}
|
||||
|
||||
kaldi_decoder::FasterDecoder *OnlineStream::GetFasterDecoder() const {
|
||||
return impl_->GetFasterDecoder();
|
||||
}
|
||||
|
||||
int32_t &OnlineStream::GetFasterDecoderProcessedFrames() {
|
||||
return impl_->GetFasterDecoderProcessedFrames();
|
||||
}
|
||||
|
||||
std::vector<float> &OnlineStream::GetParaformerFeatCache() {
|
||||
return impl_->GetParaformerFeatCache();
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "kaldi-decoder/csrc/faster-decoder.h"
|
||||
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||
#include "sherpa-onnx/csrc/context-graph.h"
|
||||
#include "sherpa-onnx/csrc/features.h"
|
||||
@@ -97,6 +98,11 @@ class OnlineStream {
|
||||
*/
|
||||
const ContextGraphPtr &GetContextGraph() const;
|
||||
|
||||
// for online ctc decoder
|
||||
void SetFasterDecoder(std::unique_ptr<kaldi_decoder::FasterDecoder> decoder);
|
||||
kaldi_decoder::FasterDecoder *GetFasterDecoder() const;
|
||||
int32_t &GetFasterDecoderProcessedFrames();
|
||||
|
||||
// for streaming paraformer
|
||||
std::vector<float> &GetParaformerFeatCache();
|
||||
std::vector<float> &GetParaformerEncoderOutCache();
|
||||
|
||||
Reference in New Issue
Block a user