Support contextual-biasing for streaming model (#184)
* Support contextual-biasing for streaming model * The whole pipeline runs normally * Fix comments
This commit is contained in:
@@ -20,9 +20,10 @@ import argparse
|
||||
import time
|
||||
import wave
|
||||
from pathlib import Path
|
||||
from typing import Tuple
|
||||
from typing import List, Tuple
|
||||
|
||||
import numpy as np
|
||||
import sentencepiece as spm
|
||||
import sherpa_onnx
|
||||
|
||||
|
||||
@@ -69,6 +70,59 @@ def get_args():
|
||||
help="Valid values are greedy_search and modified_beam_search",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--max-active-paths",
|
||||
type=int,
|
||||
default=4,
|
||||
help="""Used only when --decoding-method is modified_beam_search.
|
||||
It specifies number of active paths to keep during decoding.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--bpe-model",
|
||||
type=str,
|
||||
default="",
|
||||
help="""
|
||||
Path to bpe.model, it will be used to tokenize contexts biasing phrases.
|
||||
Used only when --decoding-method=modified_beam_search
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--modeling-unit",
|
||||
type=str,
|
||||
default="char",
|
||||
help="""
|
||||
The type of modeling unit, it will be used to tokenize contexts biasing phrases.
|
||||
Valid values are bpe, bpe+char, char.
|
||||
Note: the char here means characters in CJK languages.
|
||||
Used only when --decoding-method=modified_beam_search
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--contexts",
|
||||
type=str,
|
||||
default="",
|
||||
help="""
|
||||
The context list, it is a string containing some words/phrases separated
|
||||
with /, for example, 'HELLO WORLD/I LOVE YOU/GO AWAY".
|
||||
Used only when --decoding-method=modified_beam_search
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--context-score",
|
||||
type=float,
|
||||
default=1.5,
|
||||
help="""
|
||||
The context score of each token for biasing word/phrase. Used only if
|
||||
--contexts is given.
|
||||
Used only when --decoding-method=modified_beam_search
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"sound_files",
|
||||
type=str,
|
||||
@@ -116,6 +170,27 @@ def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]:
|
||||
return samples_float32, f.getframerate()
|
||||
|
||||
|
||||
def encode_contexts(args, contexts: List[str]) -> List[List[int]]:
|
||||
sp = None
|
||||
if "bpe" in args.modeling_unit:
|
||||
assert_file_exists(args.bpe_model)
|
||||
sp = spm.SentencePieceProcessor()
|
||||
sp.load(args.bpe_model)
|
||||
tokens = {}
|
||||
with open(args.tokens, "r", encoding="utf-8") as f:
|
||||
for line in f:
|
||||
toks = line.strip().split()
|
||||
assert len(toks) == 2, len(toks)
|
||||
assert toks[0] not in tokens, f"Duplicate token: {toks} "
|
||||
tokens[toks[0]] = int(toks[1])
|
||||
return sherpa_onnx.encode_contexts(
|
||||
modeling_unit=args.modeling_unit,
|
||||
contexts=contexts,
|
||||
sp=sp,
|
||||
tokens_table=tokens,
|
||||
)
|
||||
|
||||
|
||||
def main():
|
||||
args = get_args()
|
||||
assert_file_exists(args.encoder)
|
||||
@@ -132,11 +207,20 @@ def main():
|
||||
sample_rate=16000,
|
||||
feature_dim=80,
|
||||
decoding_method=args.decoding_method,
|
||||
max_active_paths=args.max_active_paths,
|
||||
context_score=args.context_score,
|
||||
)
|
||||
|
||||
print("Started!")
|
||||
start_time = time.time()
|
||||
|
||||
contexts_list = []
|
||||
contexts = [x.strip().upper() for x in args.contexts.split("/") if x.strip()]
|
||||
if contexts:
|
||||
print(f"Contexts list: {contexts}")
|
||||
contexts_list = encode_contexts(args, contexts)
|
||||
|
||||
|
||||
streams = []
|
||||
total_duration = 0
|
||||
for wave_filename in args.sound_files:
|
||||
@@ -145,7 +229,11 @@ def main():
|
||||
duration = len(samples) / sample_rate
|
||||
total_duration += duration
|
||||
|
||||
s = recognizer.create_stream()
|
||||
if contexts_list:
|
||||
s = recognizer.create_stream(contexts_list=contexts_list)
|
||||
else:
|
||||
s = recognizer.create_stream()
|
||||
|
||||
s.accept_waveform(sample_rate, samples)
|
||||
|
||||
tail_paddings = np.zeros(int(0.2 * sample_rate), dtype=np.float32)
|
||||
|
||||
@@ -88,6 +88,9 @@ void OnlineRecognizerConfig::Register(ParseOptions *po) {
|
||||
"True to enable endpoint detection. False to disable it.");
|
||||
po->Register("max-active-paths", &max_active_paths,
|
||||
"beam size used in modified beam search.");
|
||||
po->Register("context-score", &context_score,
|
||||
"The bonus score for each token in context word/phrase. "
|
||||
"Used only when decoding_method is modified_beam_search");
|
||||
po->Register("decoding-method", &decoding_method,
|
||||
"decoding method,"
|
||||
"now support greedy_search and modified_beam_search.");
|
||||
@@ -115,6 +118,7 @@ std::string OnlineRecognizerConfig::ToString() const {
|
||||
os << "endpoint_config=" << endpoint_config.ToString() << ", ";
|
||||
os << "enable_endpoint=" << (enable_endpoint ? "True" : "False") << ", ";
|
||||
os << "max_active_paths=" << max_active_paths << ", ";
|
||||
os << "context_score=" << context_score << ", ";
|
||||
os << "decoding_method=\"" << decoding_method << "\")";
|
||||
|
||||
return os.str();
|
||||
@@ -166,10 +170,37 @@ class OnlineRecognizer::Impl {
|
||||
}
|
||||
#endif
|
||||
|
||||
void InitOnlineStream(OnlineStream *stream) const {
|
||||
auto r = decoder_->GetEmptyResult();
|
||||
|
||||
if (config_.decoding_method == "modified_beam_search" &&
|
||||
nullptr != stream->GetContextGraph()) {
|
||||
// r.hyps has only one element.
|
||||
for (auto it = r.hyps.begin(); it != r.hyps.end(); ++it) {
|
||||
it->second.context_state = stream->GetContextGraph()->Root();
|
||||
}
|
||||
}
|
||||
|
||||
stream->SetResult(r);
|
||||
stream->SetStates(model_->GetEncoderInitStates());
|
||||
}
|
||||
|
||||
std::unique_ptr<OnlineStream> CreateStream() const {
|
||||
auto stream = std::make_unique<OnlineStream>(config_.feat_config);
|
||||
stream->SetResult(decoder_->GetEmptyResult());
|
||||
stream->SetStates(model_->GetEncoderInitStates());
|
||||
InitOnlineStream(stream.get());
|
||||
return stream;
|
||||
}
|
||||
|
||||
std::unique_ptr<OnlineStream> CreateStream(
|
||||
const std::vector<std::vector<int32_t>> &contexts) const {
|
||||
// We create context_graph at this level, because we might have default
|
||||
// context_graph(will be added later if needed) that belongs to the whole
|
||||
// model rather than each stream.
|
||||
auto context_graph =
|
||||
std::make_shared<ContextGraph>(contexts, config_.context_score);
|
||||
auto stream =
|
||||
std::make_unique<OnlineStream>(config_.feat_config, context_graph);
|
||||
InitOnlineStream(stream.get());
|
||||
return stream;
|
||||
}
|
||||
|
||||
@@ -188,8 +219,12 @@ class OnlineRecognizer::Impl {
|
||||
std::vector<float> features_vec(n * chunk_size * feature_dim);
|
||||
std::vector<std::vector<Ort::Value>> states_vec(n);
|
||||
std::vector<int64_t> all_processed_frames(n);
|
||||
bool has_context_graph = false;
|
||||
|
||||
for (int32_t i = 0; i != n; ++i) {
|
||||
if (!has_context_graph && ss[i]->GetContextGraph())
|
||||
has_context_graph = true;
|
||||
|
||||
const auto num_processed_frames = ss[i]->GetNumProcessedFrames();
|
||||
std::vector<float> features =
|
||||
ss[i]->GetFrames(num_processed_frames, chunk_size);
|
||||
@@ -226,7 +261,11 @@ class OnlineRecognizer::Impl {
|
||||
auto pair = model_->RunEncoder(std::move(x), std::move(states),
|
||||
std::move(processed_frames));
|
||||
|
||||
decoder_->Decode(std::move(pair.first), &results);
|
||||
if (has_context_graph) {
|
||||
decoder_->Decode(std::move(pair.first), ss, &results);
|
||||
} else {
|
||||
decoder_->Decode(std::move(pair.first), &results);
|
||||
}
|
||||
|
||||
std::vector<std::vector<Ort::Value>> next_states =
|
||||
model_->UnStackStates(pair.second);
|
||||
@@ -297,6 +336,11 @@ std::unique_ptr<OnlineStream> OnlineRecognizer::CreateStream() const {
|
||||
return impl_->CreateStream();
|
||||
}
|
||||
|
||||
std::unique_ptr<OnlineStream> OnlineRecognizer::CreateStream(
|
||||
const std::vector<std::vector<int32_t>> &context_list) const {
|
||||
return impl_->CreateStream(context_list);
|
||||
}
|
||||
|
||||
bool OnlineRecognizer::IsReady(OnlineStream *s) const {
|
||||
return impl_->IsReady(s);
|
||||
}
|
||||
|
||||
@@ -75,7 +75,10 @@ struct OnlineRecognizerConfig {
|
||||
std::string decoding_method = "greedy_search";
|
||||
// now support modified_beam_search and greedy_search
|
||||
|
||||
int32_t max_active_paths = 4; // used only for modified_beam_search
|
||||
// used only for modified_beam_search
|
||||
int32_t max_active_paths = 4;
|
||||
/// used only for modified_beam_search
|
||||
float context_score = 1.5;
|
||||
|
||||
OnlineRecognizerConfig() = default;
|
||||
|
||||
@@ -85,13 +88,14 @@ struct OnlineRecognizerConfig {
|
||||
const EndpointConfig &endpoint_config,
|
||||
bool enable_endpoint,
|
||||
const std::string &decoding_method,
|
||||
int32_t max_active_paths)
|
||||
int32_t max_active_paths, float context_score)
|
||||
: feat_config(feat_config),
|
||||
model_config(model_config),
|
||||
endpoint_config(endpoint_config),
|
||||
enable_endpoint(enable_endpoint),
|
||||
decoding_method(decoding_method),
|
||||
max_active_paths(max_active_paths) {}
|
||||
max_active_paths(max_active_paths),
|
||||
context_score(context_score) {}
|
||||
|
||||
void Register(ParseOptions *po);
|
||||
bool Validate() const;
|
||||
@@ -112,6 +116,10 @@ class OnlineRecognizer {
|
||||
/// Create a stream for decoding.
|
||||
std::unique_ptr<OnlineStream> CreateStream() const;
|
||||
|
||||
// Create a stream with context phrases
|
||||
std::unique_ptr<OnlineStream> CreateStream(
|
||||
const std::vector<std::vector<int32_t>> &context_list) const;
|
||||
|
||||
/**
|
||||
* Return true if the given stream has enough frames for decoding.
|
||||
* Return false otherwise
|
||||
|
||||
@@ -13,8 +13,9 @@ namespace sherpa_onnx {
|
||||
|
||||
class OnlineStream::Impl {
|
||||
public:
|
||||
explicit Impl(const FeatureExtractorConfig &config)
|
||||
: feat_extractor_(config) {}
|
||||
explicit Impl(const FeatureExtractorConfig &config,
|
||||
ContextGraphPtr context_graph)
|
||||
: feat_extractor_(config), context_graph_(context_graph) {}
|
||||
|
||||
void AcceptWaveform(int32_t sampling_rate, const float *waveform, int32_t n) {
|
||||
feat_extractor_.AcceptWaveform(sampling_rate, waveform, n);
|
||||
@@ -54,16 +55,21 @@ class OnlineStream::Impl {
|
||||
|
||||
std::vector<Ort::Value> &GetStates() { return states_; }
|
||||
|
||||
const ContextGraphPtr &GetContextGraph() const { return context_graph_; }
|
||||
|
||||
private:
|
||||
FeatureExtractor feat_extractor_;
|
||||
/// For contextual-biasing
|
||||
ContextGraphPtr context_graph_;
|
||||
int32_t num_processed_frames_ = 0; // before subsampling
|
||||
int32_t start_frame_index_ = 0; // never reset
|
||||
OnlineTransducerDecoderResult result_;
|
||||
std::vector<Ort::Value> states_;
|
||||
};
|
||||
|
||||
OnlineStream::OnlineStream(const FeatureExtractorConfig &config /*= {}*/)
|
||||
: impl_(std::make_unique<Impl>(config)) {}
|
||||
OnlineStream::OnlineStream(const FeatureExtractorConfig &config /*= {}*/,
|
||||
ContextGraphPtr context_graph /*= nullptr */)
|
||||
: impl_(std::make_unique<Impl>(config, context_graph)) {}
|
||||
|
||||
OnlineStream::~OnlineStream() = default;
|
||||
|
||||
@@ -109,4 +115,8 @@ std::vector<Ort::Value> &OnlineStream::GetStates() {
|
||||
return impl_->GetStates();
|
||||
}
|
||||
|
||||
const ContextGraphPtr &OnlineStream::GetContextGraph() const {
|
||||
return impl_->GetContextGraph();
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
@@ -9,6 +9,7 @@
|
||||
#include <vector>
|
||||
|
||||
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||
#include "sherpa-onnx/csrc/context-graph.h"
|
||||
#include "sherpa-onnx/csrc/features.h"
|
||||
#include "sherpa-onnx/csrc/online-transducer-decoder.h"
|
||||
|
||||
@@ -16,7 +17,8 @@ namespace sherpa_onnx {
|
||||
|
||||
class OnlineStream {
|
||||
public:
|
||||
explicit OnlineStream(const FeatureExtractorConfig &config = {});
|
||||
explicit OnlineStream(const FeatureExtractorConfig &config = {},
|
||||
ContextGraphPtr context_graph = nullptr);
|
||||
~OnlineStream();
|
||||
|
||||
/**
|
||||
@@ -71,6 +73,13 @@ class OnlineStream {
|
||||
void SetStates(std::vector<Ort::Value> states);
|
||||
std::vector<Ort::Value> &GetStates();
|
||||
|
||||
/**
|
||||
* Get the context graph corresponding to this stream.
|
||||
*
|
||||
* @return Return the context graph for this stream.
|
||||
*/
|
||||
const ContextGraphPtr &GetContextGraph() const;
|
||||
|
||||
private:
|
||||
class Impl;
|
||||
std::unique_ptr<Impl> impl_;
|
||||
|
||||
@@ -9,6 +9,7 @@
|
||||
|
||||
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||
#include "sherpa-onnx/csrc/hypothesis.h"
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
@@ -45,6 +46,7 @@ struct OnlineTransducerDecoderResult {
|
||||
OnlineTransducerDecoderResult &&other);
|
||||
};
|
||||
|
||||
class OnlineStream;
|
||||
class OnlineTransducerDecoder {
|
||||
public:
|
||||
virtual ~OnlineTransducerDecoder() = default;
|
||||
@@ -76,6 +78,26 @@ class OnlineTransducerDecoder {
|
||||
virtual void Decode(Ort::Value encoder_out,
|
||||
std::vector<OnlineTransducerDecoderResult> *result) = 0;
|
||||
|
||||
/** Run transducer beam search given the output from the encoder model.
|
||||
*
|
||||
* Note: Currently this interface is for contextual-biasing feature which
|
||||
* needs a ContextGraph owned by the OnlineStream.
|
||||
*
|
||||
* @param encoder_out A 3-D tensor of shape (N, T, joiner_dim)
|
||||
* @param ss A list of OnlineStreams.
|
||||
* @param result It is modified in-place.
|
||||
*
|
||||
* @note There is no need to pass encoder_out_length here since for the
|
||||
* online decoding case, each utterance has the same number of frames
|
||||
* and there are no paddings.
|
||||
*/
|
||||
virtual void Decode(Ort::Value encoder_out, OnlineStream **ss,
|
||||
std::vector<OnlineTransducerDecoderResult> *result) {
|
||||
SHERPA_ONNX_LOGE(
|
||||
"This interface is for OnlineTransducerModifiedBeamSearchDecoder.");
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
// used for endpointing. We need to keep decoder_out after reset
|
||||
virtual void UpdateDecoderOut(OnlineTransducerDecoderResult *result) {}
|
||||
};
|
||||
|
||||
@@ -9,6 +9,7 @@
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "sherpa-onnx/csrc/log.h"
|
||||
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
@@ -62,6 +63,12 @@ void OnlineTransducerModifiedBeamSearchDecoder::StripLeadingBlanks(
|
||||
void OnlineTransducerModifiedBeamSearchDecoder::Decode(
|
||||
Ort::Value encoder_out,
|
||||
std::vector<OnlineTransducerDecoderResult> *result) {
|
||||
Decode(std::move(encoder_out), nullptr, result);
|
||||
}
|
||||
|
||||
void OnlineTransducerModifiedBeamSearchDecoder::Decode(
|
||||
Ort::Value encoder_out, OnlineStream **ss,
|
||||
std::vector<OnlineTransducerDecoderResult> *result) {
|
||||
std::vector<int64_t> encoder_out_shape =
|
||||
encoder_out.GetTensorTypeAndShapeInfo().GetShape();
|
||||
|
||||
@@ -74,6 +81,7 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
|
||||
}
|
||||
|
||||
int32_t batch_size = static_cast<int32_t>(encoder_out_shape[0]);
|
||||
|
||||
int32_t num_frames = static_cast<int32_t>(encoder_out_shape[1]);
|
||||
int32_t vocab_size = model_->VocabSize();
|
||||
|
||||
@@ -142,18 +150,27 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
|
||||
|
||||
Hypothesis new_hyp = prev[hyp_index];
|
||||
const float prev_lm_log_prob = new_hyp.lm_log_prob;
|
||||
float context_score = 0;
|
||||
auto context_state = new_hyp.context_state;
|
||||
|
||||
if (new_token != 0) {
|
||||
new_hyp.ys.push_back(new_token);
|
||||
new_hyp.timestamps.push_back(t + frame_offset);
|
||||
new_hyp.num_trailing_blanks = 0;
|
||||
if (ss != nullptr && ss[b]->GetContextGraph() != nullptr) {
|
||||
auto context_res = ss[b]->GetContextGraph()->ForwardOneStep(
|
||||
context_state, new_token);
|
||||
context_score = context_res.first;
|
||||
new_hyp.context_state = context_res.second;
|
||||
}
|
||||
if (lm_) {
|
||||
lm_->ComputeLMScore(lm_scale_, &new_hyp);
|
||||
}
|
||||
} else {
|
||||
++new_hyp.num_trailing_blanks;
|
||||
}
|
||||
new_hyp.log_prob =
|
||||
p_logprob[k] - prev_lm_log_prob; // log_prob only includes the
|
||||
new_hyp.log_prob = p_logprob[k] + context_score -
|
||||
prev_lm_log_prob; // log_prob only includes the
|
||||
// score of the transducer
|
||||
hyps.Add(std::move(new_hyp));
|
||||
} // for (auto k : topk)
|
||||
|
||||
@@ -9,6 +9,7 @@
|
||||
#include <vector>
|
||||
|
||||
#include "sherpa-onnx/csrc/online-lm.h"
|
||||
#include "sherpa-onnx/csrc/online-stream.h"
|
||||
#include "sherpa-onnx/csrc/online-transducer-decoder.h"
|
||||
#include "sherpa-onnx/csrc/online-transducer-model.h"
|
||||
|
||||
@@ -33,6 +34,9 @@ class OnlineTransducerModifiedBeamSearchDecoder
|
||||
void Decode(Ort::Value encoder_out,
|
||||
std::vector<OnlineTransducerDecoderResult> *result) override;
|
||||
|
||||
void Decode(Ort::Value encoder_out, OnlineStream **ss,
|
||||
std::vector<OnlineTransducerDecoderResult> *result) override;
|
||||
|
||||
void UpdateDecoderOut(OnlineTransducerDecoderResult *result) override;
|
||||
|
||||
private:
|
||||
|
||||
@@ -22,18 +22,19 @@ static void PybindOnlineRecognizerConfig(py::module *m) {
|
||||
py::class_<PyClass>(*m, "OnlineRecognizerConfig")
|
||||
.def(py::init<const FeatureExtractorConfig &,
|
||||
const OnlineTransducerModelConfig &, const OnlineLMConfig &,
|
||||
const EndpointConfig &, bool, const std::string &,
|
||||
int32_t>(),
|
||||
const EndpointConfig &, bool, const std::string &, int32_t,
|
||||
float>(),
|
||||
py::arg("feat_config"), py::arg("model_config"),
|
||||
py::arg("lm_config") = OnlineLMConfig(), py::arg("endpoint_config"),
|
||||
py::arg("enable_endpoint"), py::arg("decoding_method"),
|
||||
py::arg("max_active_paths"))
|
||||
py::arg("max_active_paths"), py::arg("context_score"))
|
||||
.def_readwrite("feat_config", &PyClass::feat_config)
|
||||
.def_readwrite("model_config", &PyClass::model_config)
|
||||
.def_readwrite("endpoint_config", &PyClass::endpoint_config)
|
||||
.def_readwrite("enable_endpoint", &PyClass::enable_endpoint)
|
||||
.def_readwrite("decoding_method", &PyClass::decoding_method)
|
||||
.def_readwrite("max_active_paths", &PyClass::max_active_paths)
|
||||
.def_readwrite("context_score", &PyClass::context_score)
|
||||
.def("__str__", &PyClass::ToString);
|
||||
}
|
||||
|
||||
@@ -44,7 +45,15 @@ void PybindOnlineRecognizer(py::module *m) {
|
||||
using PyClass = OnlineRecognizer;
|
||||
py::class_<PyClass>(*m, "OnlineRecognizer")
|
||||
.def(py::init<const OnlineRecognizerConfig &>(), py::arg("config"))
|
||||
.def("create_stream", &PyClass::CreateStream)
|
||||
.def("create_stream",
|
||||
[](const PyClass &self) { return self.CreateStream(); })
|
||||
.def(
|
||||
"create_stream",
|
||||
[](PyClass &self,
|
||||
const std::vector<std::vector<int32_t>> &contexts_list) {
|
||||
return self.CreateStream(contexts_list);
|
||||
},
|
||||
py::arg("contexts_list"))
|
||||
.def("is_ready", &PyClass::IsReady)
|
||||
.def("decode_stream", &PyClass::DecodeStream)
|
||||
.def("decode_streams",
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# Copyright (c) 2023 Xiaomi Corporation
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
from typing import List, Optional
|
||||
|
||||
from _sherpa_onnx import (
|
||||
EndpointConfig,
|
||||
@@ -39,6 +39,7 @@ class OnlineRecognizer(object):
|
||||
rule3_min_utterance_length: float = 20.0,
|
||||
decoding_method: str = "greedy_search",
|
||||
max_active_paths: int = 4,
|
||||
context_score: float = 1.5,
|
||||
provider: str = "cpu",
|
||||
):
|
||||
"""
|
||||
@@ -124,13 +125,17 @@ class OnlineRecognizer(object):
|
||||
enable_endpoint=enable_endpoint_detection,
|
||||
decoding_method=decoding_method,
|
||||
max_active_paths=max_active_paths,
|
||||
context_score=context_score,
|
||||
)
|
||||
|
||||
self.recognizer = _Recognizer(recognizer_config)
|
||||
self.config = recognizer_config
|
||||
|
||||
def create_stream(self):
|
||||
return self.recognizer.create_stream()
|
||||
def create_stream(self, contexts_list : Optional[List[List[int]]] = None):
|
||||
if contexts_list is None:
|
||||
return self.recognizer.create_stream()
|
||||
else:
|
||||
return self.recognizer.create_stream(contexts_list)
|
||||
|
||||
def decode_stream(self, s: OnlineStream):
|
||||
self.recognizer.decode_stream(s)
|
||||
|
||||
Reference in New Issue
Block a user