Add HLG decoding for streaming CTC models (#731)
This commit is contained in:
@@ -51,6 +51,8 @@ set(sources
|
||||
offline-zipformer-ctc-model-config.cc
|
||||
offline-zipformer-ctc-model.cc
|
||||
online-conformer-transducer-model.cc
|
||||
online-ctc-fst-decoder-config.cc
|
||||
online-ctc-fst-decoder.cc
|
||||
online-ctc-greedy-search-decoder.cc
|
||||
online-ctc-model.cc
|
||||
online-lm-config.cc
|
||||
|
||||
@@ -7,6 +7,9 @@
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
|
||||
#include "sherpa-onnx/csrc/file-utils.h"
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
std::string OfflineCtcFstDecoderConfig::ToString() const {
|
||||
@@ -29,4 +32,12 @@ void OfflineCtcFstDecoderConfig::Register(ParseOptions *po) {
|
||||
"Decoder max active states. Larger->slower; more accurate");
|
||||
}
|
||||
|
||||
bool OfflineCtcFstDecoderConfig::Validate() const {
|
||||
if (!graph.empty() && !FileExists(graph)) {
|
||||
SHERPA_ONNX_LOGE("graph: %s does not exist", graph.c_str());
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
@@ -24,6 +24,7 @@ struct OfflineCtcFstDecoderConfig {
|
||||
std::string ToString() const;
|
||||
|
||||
void Register(ParseOptions *po);
|
||||
bool Validate() const;
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
@@ -20,7 +20,7 @@ namespace sherpa_onnx {
|
||||
// @param filename Path to a StdVectorFst or StdConstFst graph
|
||||
// @return The caller should free the returned pointer using `delete` to
|
||||
// avoid memory leak.
|
||||
static fst::Fst<fst::StdArc> *ReadGraph(const std::string &filename) {
|
||||
fst::Fst<fst::StdArc> *ReadGraph(const std::string &filename) {
|
||||
// read decoding network FST
|
||||
std::ifstream is(filename, std::ios::binary);
|
||||
if (!is.good()) {
|
||||
|
||||
@@ -67,6 +67,12 @@ bool OfflineRecognizerConfig::Validate() const {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!ctc_fst_decoder_config.graph.empty() &&
|
||||
!ctc_fst_decoder_config.Validate()) {
|
||||
SHERPA_ONNX_LOGE("Errors in fst_decoder");
|
||||
return false;
|
||||
}
|
||||
|
||||
return model_config.Validate();
|
||||
}
|
||||
|
||||
|
||||
@@ -5,12 +5,16 @@
|
||||
#ifndef SHERPA_ONNX_CSRC_ONLINE_CTC_DECODER_H_
|
||||
#define SHERPA_ONNX_CSRC_ONLINE_CTC_DECODER_H_
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "kaldi-decoder/csrc/faster-decoder.h"
|
||||
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
class OnlineStream;
|
||||
|
||||
struct OnlineCtcDecoderResult {
|
||||
/// Number of frames after subsampling we have decoded so far
|
||||
int32_t frame_offset = 0;
|
||||
@@ -37,7 +41,13 @@ class OnlineCtcDecoder {
|
||||
* @param results Input & Output parameters..
|
||||
*/
|
||||
virtual void Decode(Ort::Value log_probs,
|
||||
std::vector<OnlineCtcDecoderResult> *results) = 0;
|
||||
std::vector<OnlineCtcDecoderResult> *results,
|
||||
OnlineStream **ss = nullptr, int32_t n = 0) = 0;
|
||||
|
||||
virtual std::unique_ptr<kaldi_decoder::FasterDecoder> CreateFasterDecoder()
|
||||
const {
|
||||
return nullptr;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
40
sherpa-onnx/csrc/online-ctc-fst-decoder-config.cc
Normal file
40
sherpa-onnx/csrc/online-ctc-fst-decoder-config.cc
Normal file
@@ -0,0 +1,40 @@
|
||||
// sherpa-onnx/csrc/online-ctc-fst-decoder-config.cc
|
||||
//
|
||||
// Copyright (c) 2024 Xiaomi Corporation
|
||||
|
||||
#include "sherpa-onnx/csrc/online-ctc-fst-decoder-config.h"
|
||||
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
|
||||
#include "sherpa-onnx/csrc/file-utils.h"
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
std::string OnlineCtcFstDecoderConfig::ToString() const {
|
||||
std::ostringstream os;
|
||||
|
||||
os << "OnlineCtcFstDecoderConfig(";
|
||||
os << "graph=\"" << graph << "\", ";
|
||||
os << "max_active=" << max_active << ")";
|
||||
|
||||
return os.str();
|
||||
}
|
||||
|
||||
void OnlineCtcFstDecoderConfig::Register(ParseOptions *po) {
|
||||
po->Register("ctc-graph", &graph, "Path to H.fst, HL.fst, or HLG.fst");
|
||||
|
||||
po->Register("ctc-max-active", &max_active,
|
||||
"Decoder max active states. Larger->slower; more accurate");
|
||||
}
|
||||
|
||||
bool OnlineCtcFstDecoderConfig::Validate() const {
|
||||
if (!graph.empty() && !FileExists(graph)) {
|
||||
SHERPA_ONNX_LOGE("graph: %s does not exist", graph.c_str());
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
32
sherpa-onnx/csrc/online-ctc-fst-decoder-config.h
Normal file
32
sherpa-onnx/csrc/online-ctc-fst-decoder-config.h
Normal file
@@ -0,0 +1,32 @@
|
||||
// sherpa-onnx/csrc/online-ctc-fst-decoder-config.h
|
||||
//
|
||||
// Copyright (c) 2024 Xiaomi Corporation
|
||||
|
||||
#ifndef SHERPA_ONNX_CSRC_ONLINE_CTC_FST_DECODER_CONFIG_H_
|
||||
#define SHERPA_ONNX_CSRC_ONLINE_CTC_FST_DECODER_CONFIG_H_
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "sherpa-onnx/csrc/parse-options.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
struct OnlineCtcFstDecoderConfig {
|
||||
// Path to H.fst, HL.fst or HLG.fst
|
||||
std::string graph;
|
||||
int32_t max_active = 3000;
|
||||
|
||||
OnlineCtcFstDecoderConfig() = default;
|
||||
|
||||
OnlineCtcFstDecoderConfig(const std::string &graph, int32_t max_active)
|
||||
: graph(graph), max_active(max_active) {}
|
||||
|
||||
std::string ToString() const;
|
||||
|
||||
void Register(ParseOptions *po);
|
||||
bool Validate() const;
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_ONLINE_CTC_FST_DECODER_CONFIG_H_
|
||||
125
sherpa-onnx/csrc/online-ctc-fst-decoder.cc
Normal file
125
sherpa-onnx/csrc/online-ctc-fst-decoder.cc
Normal file
@@ -0,0 +1,125 @@
|
||||
// sherpa-onnx/csrc/online-ctc-fst-decoder.cc
|
||||
//
|
||||
// Copyright (c) 2024 Xiaomi Corporation
|
||||
|
||||
#include "sherpa-onnx/csrc/online-ctc-fst-decoder.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "fst/fstlib.h"
|
||||
#include "kaldi-decoder/csrc/decodable-ctc.h"
|
||||
#include "kaldifst/csrc/fstext-utils.h"
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/csrc/online-stream.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
// defined in ./offline-ctc-fst-decoder.cc
|
||||
fst::Fst<fst::StdArc> *ReadGraph(const std::string &filename);
|
||||
|
||||
OnlineCtcFstDecoder::OnlineCtcFstDecoder(
|
||||
const OnlineCtcFstDecoderConfig &config, int32_t blank_id)
|
||||
: config_(config), fst_(ReadGraph(config.graph)), blank_id_(blank_id) {
|
||||
options_.max_active = config_.max_active;
|
||||
}
|
||||
|
||||
std::unique_ptr<kaldi_decoder::FasterDecoder>
|
||||
OnlineCtcFstDecoder::CreateFasterDecoder() const {
|
||||
return std::make_unique<kaldi_decoder::FasterDecoder>(*fst_, options_);
|
||||
}
|
||||
|
||||
static void DecodeOne(const float *log_probs, int32_t num_rows,
|
||||
int32_t num_cols, OnlineCtcDecoderResult *result,
|
||||
OnlineStream *s, int32_t blank_id) {
|
||||
int32_t &processed_frames = s->GetFasterDecoderProcessedFrames();
|
||||
kaldi_decoder::DecodableCtc decodable(log_probs, num_rows, num_cols,
|
||||
processed_frames);
|
||||
|
||||
kaldi_decoder::FasterDecoder *decoder = s->GetFasterDecoder();
|
||||
if (processed_frames == 0) {
|
||||
decoder->InitDecoding();
|
||||
}
|
||||
|
||||
decoder->AdvanceDecoding(&decodable);
|
||||
|
||||
if (decoder->ReachedFinal()) {
|
||||
fst::VectorFst<fst::LatticeArc> fst_out;
|
||||
bool ok = decoder->GetBestPath(&fst_out);
|
||||
if (ok) {
|
||||
std::vector<int32_t> isymbols_out;
|
||||
std::vector<int32_t> osymbols_out_unused;
|
||||
ok = fst::GetLinearSymbolSequence(fst_out, &isymbols_out,
|
||||
&osymbols_out_unused, nullptr);
|
||||
std::vector<int64_t> tokens;
|
||||
tokens.reserve(isymbols_out.size());
|
||||
|
||||
std::vector<int32_t> timestamps;
|
||||
timestamps.reserve(isymbols_out.size());
|
||||
|
||||
std::ostringstream os;
|
||||
int32_t prev_id = -1;
|
||||
int32_t num_trailing_blanks = 0;
|
||||
int32_t f = 0; // frame number
|
||||
|
||||
for (auto i : isymbols_out) {
|
||||
i -= 1;
|
||||
|
||||
if (i == blank_id) {
|
||||
num_trailing_blanks += 1;
|
||||
} else {
|
||||
num_trailing_blanks = 0;
|
||||
}
|
||||
|
||||
if (i != blank_id && i != prev_id) {
|
||||
tokens.push_back(i);
|
||||
timestamps.push_back(f);
|
||||
}
|
||||
prev_id = i;
|
||||
f += 1;
|
||||
}
|
||||
|
||||
result->tokens = std::move(tokens);
|
||||
result->timestamps = std::move(timestamps);
|
||||
// no need to set frame_offset
|
||||
}
|
||||
}
|
||||
|
||||
processed_frames += num_rows;
|
||||
}
|
||||
|
||||
void OnlineCtcFstDecoder::Decode(Ort::Value log_probs,
|
||||
std::vector<OnlineCtcDecoderResult> *results,
|
||||
OnlineStream **ss, int32_t n) {
|
||||
std::vector<int64_t> log_probs_shape =
|
||||
log_probs.GetTensorTypeAndShapeInfo().GetShape();
|
||||
|
||||
if (log_probs_shape[0] != results->size()) {
|
||||
SHERPA_ONNX_LOGE("Size mismatch! log_probs.size(0) %d, results.size(0): %d",
|
||||
static_cast<int32_t>(log_probs_shape[0]),
|
||||
static_cast<int32_t>(results->size()));
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
if (log_probs_shape[0] != n) {
|
||||
SHERPA_ONNX_LOGE("Size mismatch! log_probs.size(0) %d, n: %d",
|
||||
static_cast<int32_t>(log_probs_shape[0]), n);
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
int32_t batch_size = static_cast<int32_t>(log_probs_shape[0]);
|
||||
int32_t num_frames = static_cast<int32_t>(log_probs_shape[1]);
|
||||
int32_t vocab_size = static_cast<int32_t>(log_probs_shape[2]);
|
||||
|
||||
const float *p = log_probs.GetTensorData<float>();
|
||||
|
||||
for (int32_t i = 0; i != batch_size; ++i) {
|
||||
DecodeOne(p + i * num_frames * vocab_size, num_frames, vocab_size,
|
||||
&(*results)[i], ss[i], blank_id_);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
39
sherpa-onnx/csrc/online-ctc-fst-decoder.h
Normal file
39
sherpa-onnx/csrc/online-ctc-fst-decoder.h
Normal file
@@ -0,0 +1,39 @@
|
||||
// sherpa-onnx/csrc/online-ctc-fst-decoder.h
|
||||
//
|
||||
// Copyright (c) 2024 Xiaomi Corporation
|
||||
|
||||
#ifndef SHERPA_ONNX_CSRC_ONLINE_CTC_FST_DECODER_H_
|
||||
#define SHERPA_ONNX_CSRC_ONLINE_CTC_FST_DECODER_H_
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "fst/fst.h"
|
||||
#include "sherpa-onnx/csrc/online-ctc-decoder.h"
|
||||
#include "sherpa-onnx/csrc/online-ctc-fst-decoder-config.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
class OnlineCtcFstDecoder : public OnlineCtcDecoder {
|
||||
public:
|
||||
OnlineCtcFstDecoder(const OnlineCtcFstDecoderConfig &config,
|
||||
int32_t blank_id);
|
||||
|
||||
void Decode(Ort::Value log_probs,
|
||||
std::vector<OnlineCtcDecoderResult> *results,
|
||||
OnlineStream **ss = nullptr, int32_t n = 0) override;
|
||||
|
||||
std::unique_ptr<kaldi_decoder::FasterDecoder> CreateFasterDecoder()
|
||||
const override;
|
||||
|
||||
private:
|
||||
OnlineCtcFstDecoderConfig config_;
|
||||
kaldi_decoder::FasterDecoderOptions options_;
|
||||
|
||||
std::unique_ptr<fst::Fst<fst::StdArc>> fst_;
|
||||
int32_t blank_id_ = 0;
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_ONLINE_CTC_FST_DECODER_H_
|
||||
@@ -13,7 +13,8 @@
|
||||
namespace sherpa_onnx {
|
||||
|
||||
void OnlineCtcGreedySearchDecoder::Decode(
|
||||
Ort::Value log_probs, std::vector<OnlineCtcDecoderResult> *results) {
|
||||
Ort::Value log_probs, std::vector<OnlineCtcDecoderResult> *results,
|
||||
OnlineStream ** /*ss=nullptr*/, int32_t /*n = 0*/) {
|
||||
std::vector<int64_t> log_probs_shape =
|
||||
log_probs.GetTensorTypeAndShapeInfo().GetShape();
|
||||
|
||||
|
||||
@@ -17,7 +17,8 @@ class OnlineCtcGreedySearchDecoder : public OnlineCtcDecoder {
|
||||
: blank_id_(blank_id) {}
|
||||
|
||||
void Decode(Ort::Value log_probs,
|
||||
std::vector<OnlineCtcDecoderResult> *results) override;
|
||||
std::vector<OnlineCtcDecoderResult> *results,
|
||||
OnlineStream **ss = nullptr, int32_t n = 0) override;
|
||||
|
||||
private:
|
||||
int32_t blank_id_;
|
||||
|
||||
@@ -16,6 +16,7 @@
|
||||
#include "sherpa-onnx/csrc/file-utils.h"
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/csrc/online-ctc-decoder.h"
|
||||
#include "sherpa-onnx/csrc/online-ctc-fst-decoder.h"
|
||||
#include "sherpa-onnx/csrc/online-ctc-greedy-search-decoder.h"
|
||||
#include "sherpa-onnx/csrc/online-ctc-model.h"
|
||||
#include "sherpa-onnx/csrc/online-recognizer-impl.h"
|
||||
@@ -99,6 +100,7 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl {
|
||||
std::unique_ptr<OnlineStream> CreateStream() const override {
|
||||
auto stream = std::make_unique<OnlineStream>(config_.feat_config);
|
||||
stream->SetStates(model_->GetInitStates());
|
||||
stream->SetFasterDecoder(decoder_->CreateFasterDecoder());
|
||||
|
||||
return stream;
|
||||
}
|
||||
@@ -165,7 +167,7 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl {
|
||||
std::vector<std::vector<Ort::Value>> next_states =
|
||||
model_->UnStackStates(std::move(out_states));
|
||||
|
||||
decoder_->Decode(std::move(out[0]), &results);
|
||||
decoder_->Decode(std::move(out[0]), &results, ss, n);
|
||||
|
||||
for (int32_t k = 0; k != n; ++k) {
|
||||
ss[k]->SetCtcResult(results[k]);
|
||||
@@ -221,30 +223,34 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl {
|
||||
|
||||
private:
|
||||
void InitDecoder() {
|
||||
if (config_.decoding_method == "greedy_search") {
|
||||
if (!sym_.contains("<blk>") && !sym_.contains("<eps>") &&
|
||||
!sym_.contains("<blank>")) {
|
||||
SHERPA_ONNX_LOGE(
|
||||
"We expect that tokens.txt contains "
|
||||
"the symbol <blk> or <eps> or <blank> and its ID.");
|
||||
exit(-1);
|
||||
}
|
||||
if (!sym_.contains("<blk>") && !sym_.contains("<eps>") &&
|
||||
!sym_.contains("<blank>")) {
|
||||
SHERPA_ONNX_LOGE(
|
||||
"We expect that tokens.txt contains "
|
||||
"the symbol <blk> or <eps> or <blank> and its ID.");
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
int32_t blank_id = 0;
|
||||
if (sym_.contains("<blk>")) {
|
||||
blank_id = sym_["<blk>"];
|
||||
} else if (sym_.contains("<eps>")) {
|
||||
// for tdnn models of the yesno recipe from icefall
|
||||
blank_id = sym_["<eps>"];
|
||||
} else if (sym_.contains("<blank>")) {
|
||||
// for WeNet CTC models
|
||||
blank_id = sym_["<blank>"];
|
||||
}
|
||||
int32_t blank_id = 0;
|
||||
if (sym_.contains("<blk>")) {
|
||||
blank_id = sym_["<blk>"];
|
||||
} else if (sym_.contains("<eps>")) {
|
||||
// for tdnn models of the yesno recipe from icefall
|
||||
blank_id = sym_["<eps>"];
|
||||
} else if (sym_.contains("<blank>")) {
|
||||
// for WeNet CTC models
|
||||
blank_id = sym_["<blank>"];
|
||||
}
|
||||
|
||||
if (!config_.ctc_fst_decoder_config.graph.empty()) {
|
||||
decoder_ = std::make_unique<OnlineCtcFstDecoder>(
|
||||
config_.ctc_fst_decoder_config, blank_id);
|
||||
} else if (config_.decoding_method == "greedy_search") {
|
||||
decoder_ = std::make_unique<OnlineCtcGreedySearchDecoder>(blank_id);
|
||||
} else {
|
||||
SHERPA_ONNX_LOGE("Unsupported decoding method: %s",
|
||||
config_.decoding_method.c_str());
|
||||
SHERPA_ONNX_LOGE(
|
||||
"Unsupported decoding method: %s for streaming CTC models",
|
||||
config_.decoding_method.c_str());
|
||||
exit(-1);
|
||||
}
|
||||
}
|
||||
@@ -281,7 +287,7 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl {
|
||||
std::vector<OnlineCtcDecoderResult> results(1);
|
||||
results[0] = std::move(s->GetCtcResult());
|
||||
|
||||
decoder_->Decode(std::move(out[0]), &results);
|
||||
decoder_->Decode(std::move(out[0]), &results, &s, 1);
|
||||
s->SetCtcResult(results[0]);
|
||||
}
|
||||
|
||||
|
||||
@@ -19,13 +19,13 @@
|
||||
namespace sherpa_onnx {
|
||||
|
||||
/// Helper for `OnlineRecognizerResult::AsJsonString()`
|
||||
template<typename T>
|
||||
std::string VecToString(const std::vector<T>& vec, int32_t precision = 6) {
|
||||
template <typename T>
|
||||
std::string VecToString(const std::vector<T> &vec, int32_t precision = 6) {
|
||||
std::ostringstream oss;
|
||||
oss << std::fixed << std::setprecision(precision);
|
||||
oss << "[ ";
|
||||
std::string sep = "";
|
||||
for (const auto& item : vec) {
|
||||
for (const auto &item : vec) {
|
||||
oss << sep << item;
|
||||
sep = ", ";
|
||||
}
|
||||
@@ -34,13 +34,13 @@ std::string VecToString(const std::vector<T>& vec, int32_t precision = 6) {
|
||||
}
|
||||
|
||||
/// Helper for `OnlineRecognizerResult::AsJsonString()`
|
||||
template<> // explicit specialization for T = std::string
|
||||
std::string VecToString<std::string>(const std::vector<std::string>& vec,
|
||||
template <> // explicit specialization for T = std::string
|
||||
std::string VecToString<std::string>(const std::vector<std::string> &vec,
|
||||
int32_t) { // ignore 2nd arg
|
||||
std::ostringstream oss;
|
||||
oss << "[ ";
|
||||
std::string sep = "";
|
||||
for (const auto& item : vec) {
|
||||
for (const auto &item : vec) {
|
||||
oss << sep << "\"" << item << "\"";
|
||||
sep = ", ";
|
||||
}
|
||||
@@ -51,15 +51,17 @@ std::string VecToString<std::string>(const std::vector<std::string>& vec,
|
||||
std::string OnlineRecognizerResult::AsJsonString() const {
|
||||
std::ostringstream os;
|
||||
os << "{ ";
|
||||
os << "\"text\": " << "\"" << text << "\"" << ", ";
|
||||
os << "\"text\": "
|
||||
<< "\"" << text << "\""
|
||||
<< ", ";
|
||||
os << "\"tokens\": " << VecToString(tokens) << ", ";
|
||||
os << "\"timestamps\": " << VecToString(timestamps, 2) << ", ";
|
||||
os << "\"ys_probs\": " << VecToString(ys_probs, 6) << ", ";
|
||||
os << "\"lm_probs\": " << VecToString(lm_probs, 6) << ", ";
|
||||
os << "\"context_scores\": " << VecToString(context_scores, 6) << ", ";
|
||||
os << "\"segment\": " << segment << ", ";
|
||||
os << "\"start_time\": " << std::fixed << std::setprecision(2)
|
||||
<< start_time << ", ";
|
||||
os << "\"start_time\": " << std::fixed << std::setprecision(2) << start_time
|
||||
<< ", ";
|
||||
os << "\"is_final\": " << (is_final ? "true" : "false");
|
||||
os << "}";
|
||||
return os.str();
|
||||
@@ -70,6 +72,7 @@ void OnlineRecognizerConfig::Register(ParseOptions *po) {
|
||||
model_config.Register(po);
|
||||
endpoint_config.Register(po);
|
||||
lm_config.Register(po);
|
||||
ctc_fst_decoder_config.Register(po);
|
||||
|
||||
po->Register("enable-endpoint", &enable_endpoint,
|
||||
"True to enable endpoint detection. False to disable it.");
|
||||
@@ -116,6 +119,12 @@ bool OnlineRecognizerConfig::Validate() const {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!ctc_fst_decoder_config.graph.empty() &&
|
||||
!ctc_fst_decoder_config.Validate()) {
|
||||
SHERPA_ONNX_LOGE("Errors in ctc_fst_decoder_config");
|
||||
return false;
|
||||
}
|
||||
|
||||
return model_config.Validate();
|
||||
}
|
||||
|
||||
@@ -127,6 +136,7 @@ std::string OnlineRecognizerConfig::ToString() const {
|
||||
os << "model_config=" << model_config.ToString() << ", ";
|
||||
os << "lm_config=" << lm_config.ToString() << ", ";
|
||||
os << "endpoint_config=" << endpoint_config.ToString() << ", ";
|
||||
os << "ctc_fst_decoder_config=" << ctc_fst_decoder_config.ToString() << ", ";
|
||||
os << "enable_endpoint=" << (enable_endpoint ? "True" : "False") << ", ";
|
||||
os << "max_active_paths=" << max_active_paths << ", ";
|
||||
os << "hotwords_score=" << hotwords_score << ", ";
|
||||
|
||||
@@ -16,6 +16,7 @@
|
||||
|
||||
#include "sherpa-onnx/csrc/endpoint.h"
|
||||
#include "sherpa-onnx/csrc/features.h"
|
||||
#include "sherpa-onnx/csrc/online-ctc-fst-decoder-config.h"
|
||||
#include "sherpa-onnx/csrc/online-lm-config.h"
|
||||
#include "sherpa-onnx/csrc/online-model-config.h"
|
||||
#include "sherpa-onnx/csrc/online-stream.h"
|
||||
@@ -80,6 +81,7 @@ struct OnlineRecognizerConfig {
|
||||
OnlineModelConfig model_config;
|
||||
OnlineLMConfig lm_config;
|
||||
EndpointConfig endpoint_config;
|
||||
OnlineCtcFstDecoderConfig ctc_fst_decoder_config;
|
||||
bool enable_endpoint = true;
|
||||
|
||||
std::string decoding_method = "greedy_search";
|
||||
@@ -96,19 +98,19 @@ struct OnlineRecognizerConfig {
|
||||
|
||||
OnlineRecognizerConfig() = default;
|
||||
|
||||
OnlineRecognizerConfig(const FeatureExtractorConfig &feat_config,
|
||||
const OnlineModelConfig &model_config,
|
||||
const OnlineLMConfig &lm_config,
|
||||
const EndpointConfig &endpoint_config,
|
||||
bool enable_endpoint,
|
||||
const std::string &decoding_method,
|
||||
int32_t max_active_paths,
|
||||
const std::string &hotwords_file, float hotwords_score,
|
||||
float blank_penalty)
|
||||
OnlineRecognizerConfig(
|
||||
const FeatureExtractorConfig &feat_config,
|
||||
const OnlineModelConfig &model_config, const OnlineLMConfig &lm_config,
|
||||
const EndpointConfig &endpoint_config,
|
||||
const OnlineCtcFstDecoderConfig &ctc_fst_decoder_config,
|
||||
bool enable_endpoint, const std::string &decoding_method,
|
||||
int32_t max_active_paths, const std::string &hotwords_file,
|
||||
float hotwords_score, float blank_penalty)
|
||||
: feat_config(feat_config),
|
||||
model_config(model_config),
|
||||
lm_config(lm_config),
|
||||
endpoint_config(endpoint_config),
|
||||
ctc_fst_decoder_config(ctc_fst_decoder_config),
|
||||
enable_endpoint(enable_endpoint),
|
||||
decoding_method(decoding_method),
|
||||
max_active_paths(max_active_paths),
|
||||
|
||||
@@ -104,6 +104,18 @@ class OnlineStream::Impl {
|
||||
return paraformer_alpha_cache_;
|
||||
}
|
||||
|
||||
void SetFasterDecoder(std::unique_ptr<kaldi_decoder::FasterDecoder> decoder) {
|
||||
faster_decoder_ = std::move(decoder);
|
||||
}
|
||||
|
||||
kaldi_decoder::FasterDecoder *GetFasterDecoder() const {
|
||||
return faster_decoder_.get();
|
||||
}
|
||||
|
||||
int32_t &GetFasterDecoderProcessedFrames() {
|
||||
return faster_decoder_processed_frames_;
|
||||
}
|
||||
|
||||
private:
|
||||
FeatureExtractor feat_extractor_;
|
||||
/// For contextual-biasing
|
||||
@@ -121,6 +133,8 @@ class OnlineStream::Impl {
|
||||
std::vector<float> paraformer_encoder_out_cache_;
|
||||
std::vector<float> paraformer_alpha_cache_;
|
||||
OnlineParaformerDecoderResult paraformer_result_;
|
||||
std::unique_ptr<kaldi_decoder::FasterDecoder> faster_decoder_;
|
||||
int32_t faster_decoder_processed_frames_ = 0;
|
||||
};
|
||||
|
||||
OnlineStream::OnlineStream(const FeatureExtractorConfig &config /*= {}*/,
|
||||
@@ -208,6 +222,19 @@ const ContextGraphPtr &OnlineStream::GetContextGraph() const {
|
||||
return impl_->GetContextGraph();
|
||||
}
|
||||
|
||||
void OnlineStream::SetFasterDecoder(
|
||||
std::unique_ptr<kaldi_decoder::FasterDecoder> decoder) {
|
||||
impl_->SetFasterDecoder(std::move(decoder));
|
||||
}
|
||||
|
||||
kaldi_decoder::FasterDecoder *OnlineStream::GetFasterDecoder() const {
|
||||
return impl_->GetFasterDecoder();
|
||||
}
|
||||
|
||||
int32_t &OnlineStream::GetFasterDecoderProcessedFrames() {
|
||||
return impl_->GetFasterDecoderProcessedFrames();
|
||||
}
|
||||
|
||||
std::vector<float> &OnlineStream::GetParaformerFeatCache() {
|
||||
return impl_->GetParaformerFeatCache();
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "kaldi-decoder/csrc/faster-decoder.h"
|
||||
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||
#include "sherpa-onnx/csrc/context-graph.h"
|
||||
#include "sherpa-onnx/csrc/features.h"
|
||||
@@ -97,6 +98,11 @@ class OnlineStream {
|
||||
*/
|
||||
const ContextGraphPtr &GetContextGraph() const;
|
||||
|
||||
// for online ctc decoder
|
||||
void SetFasterDecoder(std::unique_ptr<kaldi_decoder::FasterDecoder> decoder);
|
||||
kaldi_decoder::FasterDecoder *GetFasterDecoder() const;
|
||||
int32_t &GetFasterDecoderProcessedFrames();
|
||||
|
||||
// for streaming paraformer
|
||||
std::vector<float> &GetParaformerFeatCache();
|
||||
std::vector<float> &GetParaformerEncoderOutCache();
|
||||
|
||||
@@ -18,6 +18,7 @@ set(srcs
|
||||
offline-wenet-ctc-model-config.cc
|
||||
offline-whisper-model-config.cc
|
||||
offline-zipformer-ctc-model-config.cc
|
||||
online-ctc-fst-decoder-config.cc
|
||||
online-lm-config.cc
|
||||
online-model-config.cc
|
||||
online-paraformer-model-config.cc
|
||||
|
||||
23
sherpa-onnx/python/csrc/online-ctc-fst-decoder-config.cc
Normal file
23
sherpa-onnx/python/csrc/online-ctc-fst-decoder-config.cc
Normal file
@@ -0,0 +1,23 @@
|
||||
// sherpa-onnx/python/csrc/online-ctc-fst-decoder-config.cc
|
||||
//
|
||||
// Copyright (c) 2024 Xiaomi Corporation
|
||||
|
||||
#include "sherpa-onnx/python/csrc/online-ctc-fst-decoder-config.h"
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "sherpa-onnx/csrc/online-ctc-fst-decoder-config.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
void PybindOnlineCtcFstDecoderConfig(py::module *m) {
|
||||
using PyClass = OnlineCtcFstDecoderConfig;
|
||||
py::class_<PyClass>(*m, "OnlineCtcFstDecoderConfig")
|
||||
.def(py::init<const std::string &, int32_t>(), py::arg("graph") = "",
|
||||
py::arg("max_active") = 3000)
|
||||
.def_readwrite("graph", &PyClass::graph)
|
||||
.def_readwrite("max_active", &PyClass::max_active)
|
||||
.def("__str__", &PyClass::ToString);
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
16
sherpa-onnx/python/csrc/online-ctc-fst-decoder-config.h
Normal file
16
sherpa-onnx/python/csrc/online-ctc-fst-decoder-config.h
Normal file
@@ -0,0 +1,16 @@
|
||||
// sherpa-onnx/python/csrc/online-ctc-fst-decoder-config.h
|
||||
//
|
||||
// Copyright (c) 2024 Xiaomi Corporation
|
||||
|
||||
#ifndef SHERPA_ONNX_PYTHON_CSRC_ONLINE_CTC_FST_DECODER_CONFIG_H_
|
||||
#define SHERPA_ONNX_PYTHON_CSRC_ONLINE_CTC_FST_DECODER_CONFIG_H_
|
||||
|
||||
#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
void PybindOnlineCtcFstDecoderConfig(py::module *m);
|
||||
|
||||
}
|
||||
|
||||
#endif // SHERPA_ONNX_PYTHON_CSRC_ONLINE_CTC_FST_DECODER_CONFIG_H_
|
||||
@@ -24,8 +24,7 @@ static void PybindOnlineRecognizerResult(py::module *m) {
|
||||
"tokens",
|
||||
[](PyClass &self) -> std::vector<std::string> { return self.tokens; })
|
||||
.def_property_readonly(
|
||||
"start_time",
|
||||
[](PyClass &self) -> float { return self.start_time; })
|
||||
"start_time", [](PyClass &self) -> float { return self.start_time; })
|
||||
.def_property_readonly(
|
||||
"timestamps",
|
||||
[](PyClass &self) -> std::vector<float> { return self.timestamps; })
|
||||
@@ -35,37 +34,38 @@ static void PybindOnlineRecognizerResult(py::module *m) {
|
||||
.def_property_readonly(
|
||||
"lm_probs",
|
||||
[](PyClass &self) -> std::vector<float> { return self.lm_probs; })
|
||||
.def_property_readonly("context_scores",
|
||||
[](PyClass &self) -> std::vector<float> {
|
||||
return self.context_scores;
|
||||
})
|
||||
.def_property_readonly(
|
||||
"context_scores",
|
||||
[](PyClass &self) -> std::vector<float> {
|
||||
return self.context_scores;
|
||||
})
|
||||
"segment", [](PyClass &self) -> int32_t { return self.segment; })
|
||||
.def_property_readonly(
|
||||
"segment",
|
||||
[](PyClass &self) -> int32_t { return self.segment; })
|
||||
.def_property_readonly(
|
||||
"is_final",
|
||||
[](PyClass &self) -> bool { return self.is_final; })
|
||||
"is_final", [](PyClass &self) -> bool { return self.is_final; })
|
||||
.def("as_json_string", &PyClass::AsJsonString,
|
||||
py::call_guard<py::gil_scoped_release>());
|
||||
py::call_guard<py::gil_scoped_release>());
|
||||
}
|
||||
|
||||
static void PybindOnlineRecognizerConfig(py::module *m) {
|
||||
using PyClass = OnlineRecognizerConfig;
|
||||
py::class_<PyClass>(*m, "OnlineRecognizerConfig")
|
||||
.def(py::init<const FeatureExtractorConfig &, const OnlineModelConfig &,
|
||||
const OnlineLMConfig &, const EndpointConfig &, bool,
|
||||
const std::string &, int32_t, const std::string &, float,
|
||||
float>(),
|
||||
py::arg("feat_config"), py::arg("model_config"),
|
||||
py::arg("lm_config") = OnlineLMConfig(), py::arg("endpoint_config"),
|
||||
py::arg("enable_endpoint"), py::arg("decoding_method"),
|
||||
py::arg("max_active_paths") = 4, py::arg("hotwords_file") = "",
|
||||
py::arg("hotwords_score") = 0, py::arg("blank_penalty") = 0.0)
|
||||
.def(
|
||||
py::init<const FeatureExtractorConfig &, const OnlineModelConfig &,
|
||||
const OnlineLMConfig &, const EndpointConfig &,
|
||||
const OnlineCtcFstDecoderConfig &, bool, const std::string &,
|
||||
int32_t, const std::string &, float, float>(),
|
||||
py::arg("feat_config"), py::arg("model_config"),
|
||||
py::arg("lm_config") = OnlineLMConfig(),
|
||||
py::arg("endpoint_config") = EndpointConfig(),
|
||||
py::arg("ctc_fst_decoder_config") = OnlineCtcFstDecoderConfig(),
|
||||
py::arg("enable_endpoint"), py::arg("decoding_method"),
|
||||
py::arg("max_active_paths") = 4, py::arg("hotwords_file") = "",
|
||||
py::arg("hotwords_score") = 0, py::arg("blank_penalty") = 0.0)
|
||||
.def_readwrite("feat_config", &PyClass::feat_config)
|
||||
.def_readwrite("model_config", &PyClass::model_config)
|
||||
.def_readwrite("lm_config", &PyClass::lm_config)
|
||||
.def_readwrite("endpoint_config", &PyClass::endpoint_config)
|
||||
.def_readwrite("ctc_fst_decoder_config", &PyClass::ctc_fst_decoder_config)
|
||||
.def_readwrite("enable_endpoint", &PyClass::enable_endpoint)
|
||||
.def_readwrite("decoding_method", &PyClass::decoding_method)
|
||||
.def_readwrite("max_active_paths", &PyClass::max_active_paths)
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
#include "sherpa-onnx/python/csrc/offline-model-config.h"
|
||||
#include "sherpa-onnx/python/csrc/offline-recognizer.h"
|
||||
#include "sherpa-onnx/python/csrc/offline-stream.h"
|
||||
#include "sherpa-onnx/python/csrc/online-ctc-fst-decoder-config.h"
|
||||
#include "sherpa-onnx/python/csrc/online-lm-config.h"
|
||||
#include "sherpa-onnx/python/csrc/online-model-config.h"
|
||||
#include "sherpa-onnx/python/csrc/online-recognizer.h"
|
||||
@@ -36,6 +37,7 @@ PYBIND11_MODULE(_sherpa_onnx, m) {
|
||||
m.doc() = "pybind11 binding of sherpa-onnx";
|
||||
|
||||
PybindFeatures(&m);
|
||||
PybindOnlineCtcFstDecoderConfig(&m);
|
||||
PybindOnlineModelConfig(&m);
|
||||
PybindOnlineLMConfig(&m);
|
||||
PybindOnlineStream(&m);
|
||||
|
||||
@@ -16,6 +16,7 @@ from _sherpa_onnx import (
|
||||
OnlineTransducerModelConfig,
|
||||
OnlineWenetCtcModelConfig,
|
||||
OnlineZipformer2CtcModelConfig,
|
||||
OnlineCtcFstDecoderConfig,
|
||||
)
|
||||
|
||||
|
||||
@@ -314,6 +315,8 @@ class OnlineRecognizer(object):
|
||||
rule2_min_trailing_silence: float = 1.2,
|
||||
rule3_min_utterance_length: float = 20.0,
|
||||
decoding_method: str = "greedy_search",
|
||||
ctc_graph: str = "",
|
||||
ctc_max_active: int = 3000,
|
||||
provider: str = "cpu",
|
||||
):
|
||||
"""
|
||||
@@ -355,6 +358,12 @@ class OnlineRecognizer(object):
|
||||
is detected.
|
||||
decoding_method:
|
||||
The only valid value is greedy_search.
|
||||
ctc_graph:
|
||||
If not empty, decoding_method is ignored. It contains the path to
|
||||
H.fst, HL.fst, or HLG.fst
|
||||
ctc_max_active:
|
||||
Used only when ctc_graph is not empty. It specifies the maximum
|
||||
active paths at a time.
|
||||
provider:
|
||||
onnxruntime execution providers. Valid values are: cpu, cuda, coreml.
|
||||
"""
|
||||
@@ -384,10 +393,16 @@ class OnlineRecognizer(object):
|
||||
rule3_min_utterance_length=rule3_min_utterance_length,
|
||||
)
|
||||
|
||||
ctc_fst_decoder_config = OnlineCtcFstDecoderConfig(
|
||||
graph=ctc_graph,
|
||||
max_active=ctc_max_active,
|
||||
)
|
||||
|
||||
recognizer_config = OnlineRecognizerConfig(
|
||||
feat_config=feat_config,
|
||||
model_config=model_config,
|
||||
endpoint_config=endpoint_config,
|
||||
ctc_fst_decoder_config=ctc_fst_decoder_config,
|
||||
enable_endpoint=enable_endpoint_detection,
|
||||
decoding_method=decoding_method,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user