Support contextual-biasing for streaming model (#184)

* Support contextual-biasing for streaming model

* The whole pipeline runs normally

* Fix comments
Wei Kang
2023-06-30 16:46:24 +08:00
committed by GitHub
parent b2e0c4c9c2
commit 513dfaa552
10 changed files with 238 additions and 22 deletions

View File

@@ -20,9 +20,10 @@ import argparse
 import time
 import wave
 from pathlib import Path
-from typing import Tuple
+from typing import List, Tuple

 import numpy as np
+import sentencepiece as spm

 import sherpa_onnx
@@ -69,6 +70,59 @@ def get_args():
help="Valid values are greedy_search and modified_beam_search", help="Valid values are greedy_search and modified_beam_search",
) )
parser.add_argument(
"--max-active-paths",
type=int,
default=4,
help="""Used only when --decoding-method is modified_beam_search.
It specifies number of active paths to keep during decoding.
""",
)
parser.add_argument(
"--bpe-model",
type=str,
default="",
help="""
Path to bpe.model, it will be used to tokenize contexts biasing phrases.
Used only when --decoding-method=modified_beam_search
""",
)
parser.add_argument(
"--modeling-unit",
type=str,
default="char",
help="""
The type of modeling unit, it will be used to tokenize contexts biasing phrases.
Valid values are bpe, bpe+char, char.
Note: the char here means characters in CJK languages.
Used only when --decoding-method=modified_beam_search
""",
)
parser.add_argument(
"--contexts",
type=str,
default="",
help="""
The context list, it is a string containing some words/phrases separated
with /, for example, 'HELLO WORLD/I LOVE YOU/GO AWAY".
Used only when --decoding-method=modified_beam_search
""",
)
parser.add_argument(
"--context-score",
type=float,
default=1.5,
help="""
The context score of each token for biasing word/phrase. Used only if
--contexts is given.
Used only when --decoding-method=modified_beam_search
""",
)
parser.add_argument( parser.add_argument(
"sound_files", "sound_files",
type=str, type=str,
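For reference, the /-separated --contexts value is split into phrases exactly as main() does further down in this file; a minimal sketch with a made-up phrase string:

# Hypothetical example of the "/"-separated format accepted by --contexts.
contexts_arg = "HELLO WORLD/I LOVE YOU/GO AWAY"
phrases = [x.strip().upper() for x in contexts_arg.split("/") if x.strip()]
# phrases == ["HELLO WORLD", "I LOVE YOU", "GO AWAY"]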
@@ -116,6 +170,27 @@ def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]:
     return samples_float32, f.getframerate()


+def encode_contexts(args, contexts: List[str]) -> List[List[int]]:
+    sp = None
+    if "bpe" in args.modeling_unit:
+        assert_file_exists(args.bpe_model)
+        sp = spm.SentencePieceProcessor()
+        sp.load(args.bpe_model)
+    tokens = {}
+    with open(args.tokens, "r", encoding="utf-8") as f:
+        for line in f:
+            toks = line.strip().split()
+            assert len(toks) == 2, len(toks)
+            assert toks[0] not in tokens, f"Duplicate token: {toks}"
+            tokens[toks[0]] = int(toks[1])
+    return sherpa_onnx.encode_contexts(
+        modeling_unit=args.modeling_unit,
+        contexts=contexts,
+        sp=sp,
+        tokens_table=tokens,
+    )
+
+
 def main():
     args = get_args()
     assert_file_exists(args.encoder)
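To make the tokenization concrete, here is a rough sketch of what sherpa_onnx.encode_contexts plausibly does for the pure char modeling unit: each phrase is split into characters and mapped through the tokens table to token IDs. This is an illustration under stated assumptions, not the library's exact implementation; for bpe units the phrase would first be segmented by the SentencePiece model.

from typing import Dict, List

def encode_contexts_char_sketch(
    contexts: List[str], tokens_table: Dict[str, int]
) -> List[List[int]]:
    # "char" modeling unit: one token ID per character. Characters missing
    # from the table are skipped here; the real helper may handle
    # out-of-vocabulary symbols differently.
    encoded = []
    for phrase in contexts:
        ids = [tokens_table[c] for c in phrase if c in tokens_table]
        encoded.append(ids)
    return encoded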
@@ -132,11 +207,20 @@ def main():
         sample_rate=16000,
         feature_dim=80,
         decoding_method=args.decoding_method,
+        max_active_paths=args.max_active_paths,
+        context_score=args.context_score,
     )

     print("Started!")
     start_time = time.time()
+
+    contexts_list = []
+    contexts = [x.strip().upper() for x in args.contexts.split("/") if x.strip()]
+    if contexts:
+        print(f"Contexts list: {contexts}")
+        contexts_list = encode_contexts(args, contexts)
+
     streams = []
     total_duration = 0
     for wave_filename in args.sound_files:
@@ -145,7 +229,11 @@ def main():
         duration = len(samples) / sample_rate
         total_duration += duration

-        s = recognizer.create_stream()
+        if contexts_list:
+            s = recognizer.create_stream(contexts_list=contexts_list)
+        else:
+            s = recognizer.create_stream()

         s.accept_waveform(sample_rate, samples)

         tail_paddings = np.zeros(int(0.2 * sample_rate), dtype=np.float32)
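With these changes the example script can be invoked with, for example, --decoding-method modified_beam_search --modeling-unit char --contexts 'HELLO WORLD/I LOVE YOU' --context-score 1.5 (flag values here are illustrative): the phrases are upper-cased, tokenized by encode_contexts(), and passed to create_stream() for each stream, while runs without --contexts behave as before.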

View File

@@ -88,6 +88,9 @@ void OnlineRecognizerConfig::Register(ParseOptions *po) {
"True to enable endpoint detection. False to disable it."); "True to enable endpoint detection. False to disable it.");
po->Register("max-active-paths", &max_active_paths, po->Register("max-active-paths", &max_active_paths,
"beam size used in modified beam search."); "beam size used in modified beam search.");
po->Register("context-score", &context_score,
"The bonus score for each token in context word/phrase. "
"Used only when decoding_method is modified_beam_search");
po->Register("decoding-method", &decoding_method, po->Register("decoding-method", &decoding_method,
"decoding method," "decoding method,"
"now support greedy_search and modified_beam_search."); "now support greedy_search and modified_beam_search.");
@@ -115,6 +118,7 @@ std::string OnlineRecognizerConfig::ToString() const {
os << "endpoint_config=" << endpoint_config.ToString() << ", "; os << "endpoint_config=" << endpoint_config.ToString() << ", ";
os << "enable_endpoint=" << (enable_endpoint ? "True" : "False") << ", "; os << "enable_endpoint=" << (enable_endpoint ? "True" : "False") << ", ";
os << "max_active_paths=" << max_active_paths << ", "; os << "max_active_paths=" << max_active_paths << ", ";
os << "context_score=" << context_score << ", ";
os << "decoding_method=\"" << decoding_method << "\")"; os << "decoding_method=\"" << decoding_method << "\")";
return os.str(); return os.str();
@@ -166,10 +170,37 @@ class OnlineRecognizer::Impl {
   }
 #endif

+  void InitOnlineStream(OnlineStream *stream) const {
+    auto r = decoder_->GetEmptyResult();
+    if (config_.decoding_method == "modified_beam_search" &&
+        nullptr != stream->GetContextGraph()) {
+      // r.hyps has only one element.
+      for (auto it = r.hyps.begin(); it != r.hyps.end(); ++it) {
+        it->second.context_state = stream->GetContextGraph()->Root();
+      }
+    }
+    stream->SetResult(r);
+    stream->SetStates(model_->GetEncoderInitStates());
+  }
+
   std::unique_ptr<OnlineStream> CreateStream() const {
     auto stream = std::make_unique<OnlineStream>(config_.feat_config);
-    stream->SetResult(decoder_->GetEmptyResult());
-    stream->SetStates(model_->GetEncoderInitStates());
+    InitOnlineStream(stream.get());
+    return stream;
+  }
+
+  std::unique_ptr<OnlineStream> CreateStream(
+      const std::vector<std::vector<int32_t>> &contexts) const {
+    // We create the context_graph at this level because we might later add
+    // a default context_graph that belongs to the whole model rather than
+    // to each stream.
+    auto context_graph =
+        std::make_shared<ContextGraph>(contexts, config_.context_score);
+    auto stream =
+        std::make_unique<OnlineStream>(config_.feat_config, context_graph);
+    InitOnlineStream(stream.get());
     return stream;
   }
@@ -188,8 +219,12 @@ class OnlineRecognizer::Impl {
     std::vector<float> features_vec(n * chunk_size * feature_dim);
     std::vector<std::vector<Ort::Value>> states_vec(n);
     std::vector<int64_t> all_processed_frames(n);
+    bool has_context_graph = false;

     for (int32_t i = 0; i != n; ++i) {
+      if (!has_context_graph && ss[i]->GetContextGraph())
+        has_context_graph = true;
+
       const auto num_processed_frames = ss[i]->GetNumProcessedFrames();
       std::vector<float> features =
           ss[i]->GetFrames(num_processed_frames, chunk_size);
@@ -226,7 +261,11 @@ class OnlineRecognizer::Impl {
     auto pair = model_->RunEncoder(std::move(x), std::move(states),
                                    std::move(processed_frames));

-    decoder_->Decode(std::move(pair.first), &results);
+    if (has_context_graph) {
+      decoder_->Decode(std::move(pair.first), ss, &results);
+    } else {
+      decoder_->Decode(std::move(pair.first), &results);
+    }

     std::vector<std::vector<Ort::Value>> next_states =
         model_->UnStackStates(pair.second);
@@ -297,6 +336,11 @@ std::unique_ptr<OnlineStream> OnlineRecognizer::CreateStream() const {
   return impl_->CreateStream();
 }

+std::unique_ptr<OnlineStream> OnlineRecognizer::CreateStream(
+    const std::vector<std::vector<int32_t>> &context_list) const {
+  return impl_->CreateStream(context_list);
+}
+
 bool OnlineRecognizer::IsReady(OnlineStream *s) const {
   return impl_->IsReady(s);
 }
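The ContextGraph referenced above behaves like a trie over the token sequences of the biasing phrases: Root() returns the start state, and ForwardOneStep(state, token) returns a (score, next_state) pair, paying out context_score for each token that advances a phrase. The sketch below illustrates that contract on a plain prefix trie; it is an analogy for intuition, assuming no fallback/failure arcs, which the real C++ implementation may handle differently.

from typing import Dict, List, Tuple

class ContextGraphSketch:
    # Toy prefix trie over token-ID sequences, illustrating the
    # Root()/ForwardOneStep() contract used by the decoder.
    def __init__(self, contexts: List[List[int]], context_score: float):
        self.context_score = context_score
        self.trie: List[Dict[int, int]] = [{}]  # node -> {token: child id}
        for phrase in contexts:
            node = 0
            for token in phrase:
                if token not in self.trie[node]:
                    self.trie[node][token] = len(self.trie)
                    self.trie.append({})
                node = self.trie[node][token]

    def root(self) -> int:
        return 0

    def forward_one_step(self, state: int, token: int) -> Tuple[float, int]:
        # Advancing along a biasing phrase earns +context_score per token.
        # On a mismatch this sketch simply falls back to the root with no
        # penalty; the real ContextGraph also has to cancel the bonus
        # accumulated by a partial match that ultimately fails.
        if token in self.trie[state]:
            return self.context_score, self.trie[state][token]
        return 0.0, 0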

View File

@@ -75,7 +75,10 @@ struct OnlineRecognizerConfig {
   std::string decoding_method = "greedy_search";
   // now support modified_beam_search and greedy_search

-  int32_t max_active_paths = 4;  // used only for modified_beam_search
+  // used only for modified_beam_search
+  int32_t max_active_paths = 4;
+
+  /// used only for modified_beam_search
+  float context_score = 1.5;

   OnlineRecognizerConfig() = default;
@@ -85,13 +88,14 @@ struct OnlineRecognizerConfig {
                          const EndpointConfig &endpoint_config,
                          bool enable_endpoint,
                          const std::string &decoding_method,
-                         int32_t max_active_paths)
+                         int32_t max_active_paths, float context_score)
       : feat_config(feat_config),
         model_config(model_config),
         endpoint_config(endpoint_config),
         enable_endpoint(enable_endpoint),
         decoding_method(decoding_method),
-        max_active_paths(max_active_paths) {}
+        max_active_paths(max_active_paths),
+        context_score(context_score) {}

   void Register(ParseOptions *po);
   bool Validate() const;
@@ -112,6 +116,10 @@ class OnlineRecognizer {
   /// Create a stream for decoding.
   std::unique_ptr<OnlineStream> CreateStream() const;

+  // Create a stream with context phrases
+  std::unique_ptr<OnlineStream> CreateStream(
+      const std::vector<std::vector<int32_t>> &context_list) const;
+
   /**
    * Return true if the given stream has enough frames for decoding.
    * Return false otherwise

View File

@@ -13,8 +13,9 @@ namespace sherpa_onnx {
 class OnlineStream::Impl {
  public:
-  explicit Impl(const FeatureExtractorConfig &config)
-      : feat_extractor_(config) {}
+  explicit Impl(const FeatureExtractorConfig &config,
+                ContextGraphPtr context_graph)
+      : feat_extractor_(config), context_graph_(context_graph) {}

   void AcceptWaveform(int32_t sampling_rate, const float *waveform, int32_t n) {
     feat_extractor_.AcceptWaveform(sampling_rate, waveform, n);
@@ -54,16 +55,21 @@ class OnlineStream::Impl {
   std::vector<Ort::Value> &GetStates() { return states_; }

+  const ContextGraphPtr &GetContextGraph() const { return context_graph_; }
+
  private:
   FeatureExtractor feat_extractor_;
+  /// For contextual biasing
+  ContextGraphPtr context_graph_;
   int32_t num_processed_frames_ = 0;  // before subsampling
   int32_t start_frame_index_ = 0;     // never reset
   OnlineTransducerDecoderResult result_;
   std::vector<Ort::Value> states_;
 };

-OnlineStream::OnlineStream(const FeatureExtractorConfig &config /*= {}*/)
-    : impl_(std::make_unique<Impl>(config)) {}
+OnlineStream::OnlineStream(const FeatureExtractorConfig &config /*= {}*/,
+                           ContextGraphPtr context_graph /*= nullptr*/)
+    : impl_(std::make_unique<Impl>(config, context_graph)) {}

 OnlineStream::~OnlineStream() = default;
@@ -109,4 +115,8 @@ std::vector<Ort::Value> &OnlineStream::GetStates() {
   return impl_->GetStates();
 }

+const ContextGraphPtr &OnlineStream::GetContextGraph() const {
+  return impl_->GetContextGraph();
+}
+
 }  // namespace sherpa_onnx

View File

@@ -9,6 +9,7 @@
 #include <vector>

 #include "onnxruntime_cxx_api.h"  // NOLINT
+#include "sherpa-onnx/csrc/context-graph.h"
 #include "sherpa-onnx/csrc/features.h"
 #include "sherpa-onnx/csrc/online-transducer-decoder.h"
@@ -16,7 +17,8 @@ namespace sherpa_onnx {
 class OnlineStream {
  public:
-  explicit OnlineStream(const FeatureExtractorConfig &config = {});
+  explicit OnlineStream(const FeatureExtractorConfig &config = {},
+                        ContextGraphPtr context_graph = nullptr);
   ~OnlineStream();

   /**
@@ -71,6 +73,13 @@ class OnlineStream {
   void SetStates(std::vector<Ort::Value> states);
   std::vector<Ort::Value> &GetStates();

+  /**
+   * Get the context graph corresponding to this stream.
+   *
+   * @return Return the context graph for this stream.
+   */
+  const ContextGraphPtr &GetContextGraph() const;
+
  private:
   class Impl;
   std::unique_ptr<Impl> impl_;

View File

@@ -9,6 +9,7 @@
#include "onnxruntime_cxx_api.h" // NOLINT #include "onnxruntime_cxx_api.h" // NOLINT
#include "sherpa-onnx/csrc/hypothesis.h" #include "sherpa-onnx/csrc/hypothesis.h"
#include "sherpa-onnx/csrc/macros.h"
namespace sherpa_onnx { namespace sherpa_onnx {
@@ -45,6 +46,7 @@ struct OnlineTransducerDecoderResult {
       OnlineTransducerDecoderResult &&other);
 };

+class OnlineStream;
 class OnlineTransducerDecoder {
  public:
   virtual ~OnlineTransducerDecoder() = default;
@@ -76,6 +78,26 @@ class OnlineTransducerDecoder {
   virtual void Decode(Ort::Value encoder_out,
                       std::vector<OnlineTransducerDecoderResult> *result) = 0;

+  /** Run transducer beam search given the output from the encoder model.
+   *
+   * Note: Currently this interface exists for the contextual-biasing
+   * feature, which needs a ContextGraph owned by the OnlineStream.
+   *
+   * @param encoder_out A 3-D tensor of shape (N, T, joiner_dim)
+   * @param ss A list of OnlineStreams.
+   * @param result It is modified in-place.
+   *
+   * @note There is no need to pass encoder_out_length here since in the
+   * online decoding case each utterance has the same number of frames and
+   * there are no paddings.
+   */
+  virtual void Decode(Ort::Value encoder_out, OnlineStream **ss,
+                      std::vector<OnlineTransducerDecoderResult> *result) {
+    SHERPA_ONNX_LOGE(
+        "This interface is for OnlineTransducerModifiedBeamSearchDecoder.");
+    exit(-1);
+  }
+
   // used for endpointing. We need to keep decoder_out after reset
   virtual void UpdateDecoderOut(OnlineTransducerDecoderResult *result) {}
 };

View File

@@ -9,6 +9,7 @@
 #include <utility>
 #include <vector>

+#include "sherpa-onnx/csrc/log.h"
 #include "sherpa-onnx/csrc/onnx-utils.h"

 namespace sherpa_onnx {
@@ -62,6 +63,12 @@ void OnlineTransducerModifiedBeamSearchDecoder::StripLeadingBlanks(
 void OnlineTransducerModifiedBeamSearchDecoder::Decode(
     Ort::Value encoder_out,
     std::vector<OnlineTransducerDecoderResult> *result) {
+  Decode(std::move(encoder_out), nullptr, result);
+}
+
+void OnlineTransducerModifiedBeamSearchDecoder::Decode(
+    Ort::Value encoder_out, OnlineStream **ss,
+    std::vector<OnlineTransducerDecoderResult> *result) {
   std::vector<int64_t> encoder_out_shape =
       encoder_out.GetTensorTypeAndShapeInfo().GetShape();
@@ -74,6 +81,7 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
   }

   int32_t batch_size = static_cast<int32_t>(encoder_out_shape[0]);
+
   int32_t num_frames = static_cast<int32_t>(encoder_out_shape[1]);
   int32_t vocab_size = model_->VocabSize();
@@ -142,18 +150,27 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
       Hypothesis new_hyp = prev[hyp_index];
       const float prev_lm_log_prob = new_hyp.lm_log_prob;
+      float context_score = 0;
+      auto context_state = new_hyp.context_state;
+
       if (new_token != 0) {
         new_hyp.ys.push_back(new_token);
         new_hyp.timestamps.push_back(t + frame_offset);
         new_hyp.num_trailing_blanks = 0;
+        if (ss != nullptr && ss[b]->GetContextGraph() != nullptr) {
+          auto context_res = ss[b]->GetContextGraph()->ForwardOneStep(
+              context_state, new_token);
+          context_score = context_res.first;
+          new_hyp.context_state = context_res.second;
+        }
         if (lm_) {
           lm_->ComputeLMScore(lm_scale_, &new_hyp);
         }
       } else {
         ++new_hyp.num_trailing_blanks;
       }
-      new_hyp.log_prob =
-          p_logprob[k] - prev_lm_log_prob;  // log_prob only includes the
-                                            // score of the transducer
+      new_hyp.log_prob = p_logprob[k] + context_score -
+                         prev_lm_log_prob;  // log_prob only includes the
+                                            // score of the transducer
       hyps.Add(std::move(new_hyp));
     }  // for (auto k : topk)

View File

@@ -9,6 +9,7 @@
 #include <vector>

 #include "sherpa-onnx/csrc/online-lm.h"
+#include "sherpa-onnx/csrc/online-stream.h"
 #include "sherpa-onnx/csrc/online-transducer-decoder.h"
 #include "sherpa-onnx/csrc/online-transducer-model.h"
@@ -33,6 +34,9 @@ class OnlineTransducerModifiedBeamSearchDecoder
   void Decode(Ort::Value encoder_out,
               std::vector<OnlineTransducerDecoderResult> *result) override;

+  void Decode(Ort::Value encoder_out, OnlineStream **ss,
+              std::vector<OnlineTransducerDecoderResult> *result) override;
+
   void UpdateDecoderOut(OnlineTransducerDecoderResult *result) override;

  private:

View File

@@ -22,18 +22,19 @@ static void PybindOnlineRecognizerConfig(py::module *m) {
py::class_<PyClass>(*m, "OnlineRecognizerConfig") py::class_<PyClass>(*m, "OnlineRecognizerConfig")
.def(py::init<const FeatureExtractorConfig &, .def(py::init<const FeatureExtractorConfig &,
const OnlineTransducerModelConfig &, const OnlineLMConfig &, const OnlineTransducerModelConfig &, const OnlineLMConfig &,
const EndpointConfig &, bool, const std::string &, const EndpointConfig &, bool, const std::string &, int32_t,
int32_t>(), float>(),
py::arg("feat_config"), py::arg("model_config"), py::arg("feat_config"), py::arg("model_config"),
py::arg("lm_config") = OnlineLMConfig(), py::arg("endpoint_config"), py::arg("lm_config") = OnlineLMConfig(), py::arg("endpoint_config"),
py::arg("enable_endpoint"), py::arg("decoding_method"), py::arg("enable_endpoint"), py::arg("decoding_method"),
py::arg("max_active_paths")) py::arg("max_active_paths"), py::arg("context_score"))
.def_readwrite("feat_config", &PyClass::feat_config) .def_readwrite("feat_config", &PyClass::feat_config)
.def_readwrite("model_config", &PyClass::model_config) .def_readwrite("model_config", &PyClass::model_config)
.def_readwrite("endpoint_config", &PyClass::endpoint_config) .def_readwrite("endpoint_config", &PyClass::endpoint_config)
.def_readwrite("enable_endpoint", &PyClass::enable_endpoint) .def_readwrite("enable_endpoint", &PyClass::enable_endpoint)
.def_readwrite("decoding_method", &PyClass::decoding_method) .def_readwrite("decoding_method", &PyClass::decoding_method)
.def_readwrite("max_active_paths", &PyClass::max_active_paths) .def_readwrite("max_active_paths", &PyClass::max_active_paths)
.def_readwrite("context_score", &PyClass::context_score)
.def("__str__", &PyClass::ToString); .def("__str__", &PyClass::ToString);
} }
@@ -44,7 +45,15 @@ void PybindOnlineRecognizer(py::module *m) {
   using PyClass = OnlineRecognizer;
   py::class_<PyClass>(*m, "OnlineRecognizer")
       .def(py::init<const OnlineRecognizerConfig &>(), py::arg("config"))
-      .def("create_stream", &PyClass::CreateStream)
+      .def("create_stream",
+           [](const PyClass &self) { return self.CreateStream(); })
+      .def(
+          "create_stream",
+          [](PyClass &self,
+             const std::vector<std::vector<int32_t>> &contexts_list) {
+            return self.CreateStream(contexts_list);
+          },
+          py::arg("contexts_list"))
       .def("is_ready", &PyClass::IsReady)
       .def("decode_stream", &PyClass::DecodeStream)
       .def("decode_streams",

View File

@@ -1,6 +1,6 @@
 # Copyright (c)  2023  Xiaomi Corporation
 from pathlib import Path
-from typing import List
+from typing import List, Optional

 from _sherpa_onnx import (
     EndpointConfig,
@@ -39,6 +39,7 @@ class OnlineRecognizer(object):
         rule3_min_utterance_length: float = 20.0,
         decoding_method: str = "greedy_search",
         max_active_paths: int = 4,
+        context_score: float = 1.5,
         provider: str = "cpu",
     ):
         """
@@ -124,13 +125,17 @@ class OnlineRecognizer(object):
             enable_endpoint=enable_endpoint_detection,
             decoding_method=decoding_method,
             max_active_paths=max_active_paths,
+            context_score=context_score,
         )

         self.recognizer = _Recognizer(recognizer_config)
         self.config = recognizer_config

-    def create_stream(self):
-        return self.recognizer.create_stream()
+    def create_stream(self, contexts_list: Optional[List[List[int]]] = None):
+        if contexts_list is None:
+            return self.recognizer.create_stream()
+        else:
+            return self.recognizer.create_stream(contexts_list)

     def decode_stream(self, s: OnlineStream):
         self.recognizer.decode_stream(s)
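Putting the pieces together, usage from Python looks roughly like the following. This is a sketch based on the diff above: the model file names are placeholders, the token IDs are made up, and constructor arguments other than the ones added in this diff are assumed from the surrounding code.

import sherpa_onnx

# Paths below are placeholders; substitute your own streaming transducer.
recognizer = sherpa_onnx.OnlineRecognizer(
    tokens="tokens.txt",
    encoder="encoder.onnx",
    decoder="decoder.onnx",
    joiner="joiner.onnx",
    decoding_method="modified_beam_search",
    max_active_paths=4,
    context_score=1.5,  # bonus per matched context token
)

# Token IDs for the biasing phrases, e.g. produced by encode_contexts()
# from the example script (made-up IDs for illustration).
contexts_list = [[23, 9, 17], [41, 5]]

s = recognizer.create_stream(contexts_list=contexts_list)

Per the diff, streams created without a contexts_list take the overload that builds no ContextGraph, so unbiased decoding is unchanged.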