diff --git a/python-api-examples/online-decode-files.py b/python-api-examples/online-decode-files.py
index fff8bf94..44fda24e 100755
--- a/python-api-examples/online-decode-files.py
+++ b/python-api-examples/online-decode-files.py
@@ -20,9 +20,10 @@ import argparse
 import time
 import wave
 from pathlib import Path
-from typing import Tuple
+from typing import List, Tuple
 
 import numpy as np
+import sentencepiece as spm
 import sherpa_onnx
@@ -69,6 +70,59 @@ def get_args():
         help="Valid values are greedy_search and modified_beam_search",
     )
 
+    parser.add_argument(
+        "--max-active-paths",
+        type=int,
+        default=4,
+        help="""Used only when --decoding-method is modified_beam_search.
+        It specifies the number of active paths to keep during decoding.
+        """,
+    )
+
+    parser.add_argument(
+        "--bpe-model",
+        type=str,
+        default="",
+        help="""
+        Path to bpe.model; it is used to tokenize context biasing phrases.
+        Used only when --decoding-method=modified_beam_search
+        """,
+    )
+
+    parser.add_argument(
+        "--modeling-unit",
+        type=str,
+        default="char",
+        help="""
+        The type of modeling unit; it is used to tokenize context biasing phrases.
+        Valid values are bpe, bpe+char, char.
+        Note: char here means characters in CJK languages.
+        Used only when --decoding-method=modified_beam_search
+        """,
+    )
+
+    parser.add_argument(
+        "--contexts",
+        type=str,
+        default="",
+        help="""
+        The context list: a string containing words/phrases separated
+        by /, for example, 'HELLO WORLD/I LOVE YOU/GO AWAY'.
+        Used only when --decoding-method=modified_beam_search
+        """,
+    )
+
+    parser.add_argument(
+        "--context-score",
+        type=float,
+        default=1.5,
+        help="""
+        The context score of each token for biasing words/phrases. Used only
+        when --contexts is given.
+        Used only when --decoding-method=modified_beam_search
+        """,
+    )
+
     parser.add_argument(
         "sound_files",
         type=str,
@@ -116,6 +170,27 @@ def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]:
     return samples_float32, f.getframerate()
 
 
+def encode_contexts(args, contexts: List[str]) -> List[List[int]]:
+    sp = None
+    if "bpe" in args.modeling_unit:
+        assert_file_exists(args.bpe_model)
+        sp = spm.SentencePieceProcessor()
+        sp.load(args.bpe_model)
+    tokens = {}
+    with open(args.tokens, "r", encoding="utf-8") as f:
+        for line in f:
+            toks = line.strip().split()
+            assert len(toks) == 2, len(toks)
+            assert toks[0] not in tokens, f"Duplicate token: {toks}"
+            tokens[toks[0]] = int(toks[1])
+    return sherpa_onnx.encode_contexts(
+        modeling_unit=args.modeling_unit,
+        contexts=contexts,
+        sp=sp,
+        tokens_table=tokens,
+    )
+
+
 def main():
     args = get_args()
     assert_file_exists(args.encoder)
@@ -132,11 +207,20 @@ def main():
         sample_rate=16000,
         feature_dim=80,
         decoding_method=args.decoding_method,
+        max_active_paths=args.max_active_paths,
+        context_score=args.context_score,
     )
 
     print("Started!")
     start_time = time.time()
+
+    contexts_list = []
+    contexts = [x.strip().upper() for x in args.contexts.split("/") if x.strip()]
+    if contexts:
+        print(f"Contexts list: {contexts}")
+        contexts_list = encode_contexts(args, contexts)
+
     streams = []
     total_duration = 0
     for wave_filename in args.sound_files:
@@ -145,7 +229,11 @@ def main():
         duration = len(samples) / sample_rate
         total_duration += duration
 
-        s = recognizer.create_stream()
+        if contexts_list:
+            s = recognizer.create_stream(contexts_list=contexts_list)
+        else:
+            s = recognizer.create_stream()
+
         s.accept_waveform(sample_rate, samples)
 
         tail_paddings = np.zeros(int(0.2 * sample_rate), dtype=np.float32)
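The helper above delegates the actual phrase-to-token-ID mapping to sherpa_onnx.encode_contexts. As a rough, purely illustrative sketch of what that mapping produces in the char case (the helper name and the IDs below are made up; the real logic lives in the library):

    # Illustration only; the real mapping is done by sherpa_onnx.encode_contexts().
    # `tokens` is a symbol table like the one built in encode_contexts() above.
    def encode_char_contexts(contexts, tokens):
        ids = []
        for phrase in contexts:
            # Skip characters that are missing from the symbol table.
            ids.append([tokens[c] for c in phrase if c in tokens])
        return ids

    # encode_char_contexts(["你好"], {"你": 10, "好": 25}) == [[10, 25]]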
diff --git a/sherpa-onnx/csrc/online-recognizer.cc b/sherpa-onnx/csrc/online-recognizer.cc
index 06aab880..39d3c177 100644
--- a/sherpa-onnx/csrc/online-recognizer.cc
+++ b/sherpa-onnx/csrc/online-recognizer.cc
@@ -88,6 +88,9 @@ void OnlineRecognizerConfig::Register(ParseOptions *po) {
                "True to enable endpoint detection. False to disable it.");
   po->Register("max-active-paths", &max_active_paths,
                "beam size used in modified beam search.");
+  po->Register("context-score", &context_score,
+               "The bonus score for each token in a context word/phrase. "
+               "Used only when decoding_method is modified_beam_search");
   po->Register("decoding-method", &decoding_method,
                "decoding method, "
                "now support greedy_search and modified_beam_search.");
@@ -115,6 +118,7 @@ std::string OnlineRecognizerConfig::ToString() const {
   os << "endpoint_config=" << endpoint_config.ToString() << ", ";
   os << "enable_endpoint=" << (enable_endpoint ? "True" : "False") << ", ";
   os << "max_active_paths=" << max_active_paths << ", ";
+  os << "context_score=" << context_score << ", ";
   os << "decoding_method=\"" << decoding_method << "\")";
 
   return os.str();
@@ -166,10 +170,37 @@ class OnlineRecognizer::Impl {
   }
 #endif
 
+  void InitOnlineStream(OnlineStream *stream) const {
+    auto r = decoder_->GetEmptyResult();
+
+    if (config_.decoding_method == "modified_beam_search" &&
+        nullptr != stream->GetContextGraph()) {
+      // r.hyps has only one element.
+      for (auto it = r.hyps.begin(); it != r.hyps.end(); ++it) {
+        it->second.context_state = stream->GetContextGraph()->Root();
+      }
+    }
+
+    stream->SetResult(r);
+    stream->SetStates(model_->GetEncoderInitStates());
+  }
+
   std::unique_ptr<OnlineStream> CreateStream() const {
     auto stream = std::make_unique<OnlineStream>(config_.feat_config);
-    stream->SetResult(decoder_->GetEmptyResult());
-    stream->SetStates(model_->GetEncoderInitStates());
+    InitOnlineStream(stream.get());
+    return stream;
+  }
+
+  std::unique_ptr<OnlineStream> CreateStream(
+      const std::vector<std::vector<int32_t>> &contexts) const {
+    // We create the context_graph at this level because we may later add a
+    // default context_graph that belongs to the whole model rather than to
+    // each stream.
+    auto context_graph =
+        std::make_shared<ContextGraph>(contexts, config_.context_score);
+    auto stream =
+        std::make_unique<OnlineStream>(config_.feat_config, context_graph);
+    InitOnlineStream(stream.get());
     return stream;
   }
@@ -188,8 +219,12 @@ class OnlineRecognizer::Impl {
     std::vector<float> features_vec(n * chunk_size * feature_dim);
     std::vector<std::vector<Ort::Value>> states_vec(n);
     std::vector<int64_t> all_processed_frames(n);
+    bool has_context_graph = false;
 
     for (int32_t i = 0; i != n; ++i) {
+      if (!has_context_graph && ss[i]->GetContextGraph())
+        has_context_graph = true;
+
       const auto num_processed_frames = ss[i]->GetNumProcessedFrames();
       std::vector<float> features =
           ss[i]->GetFrames(num_processed_frames, chunk_size);
@@ -226,7 +261,11 @@ class OnlineRecognizer::Impl {
     auto pair = model_->RunEncoder(std::move(x), std::move(states),
                                    std::move(processed_frames));
 
-    decoder_->Decode(std::move(pair.first), &results);
+    if (has_context_graph) {
+      decoder_->Decode(std::move(pair.first), ss, &results);
+    } else {
+      decoder_->Decode(std::move(pair.first), &results);
+    }
 
     std::vector<std::vector<Ort::Value>> next_states =
         model_->UnStackStates(pair.second);
@@ -297,6 +336,11 @@ std::unique_ptr<OnlineStream> OnlineRecognizer::CreateStream() const {
   return impl_->CreateStream();
 }
 
+std::unique_ptr<OnlineStream> OnlineRecognizer::CreateStream(
+    const std::vector<std::vector<int32_t>> &context_list) const {
+  return impl_->CreateStream(context_list);
+}
+
 bool OnlineRecognizer::IsReady(OnlineStream *s) const {
   return impl_->IsReady(s);
 }
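For intuition, the ContextGraph consulted in InitOnlineStream and CreateStream behaves like a trie over token IDs: each token that extends a biasing phrase earns context_score, and matching restarts once a phrase completes. A much-simplified Python sketch follows; the real implementation also has Aho-Corasick-style fallback transitions and reclaims the accumulated bonus when a partial match dies, and every name below is illustrative:

    class TrieNode:
        def __init__(self):
            self.next = {}       # token id -> TrieNode
            self.is_end = False  # True at the last token of a phrase

    def build_context_graph(contexts):
        # contexts: list of phrases, each a list of token ids.
        root = TrieNode()
        for phrase in contexts:
            node = root
            for tok in phrase:
                node = node.next.setdefault(tok, TrieNode())
            node.is_end = True
        return root

    def forward_one_step(root, state, token, context_score=1.5):
        # Rough analogue of ContextGraph::ForwardOneStep: (bonus, next state).
        if token in state.next:
            nxt = state.next[token]
            return context_score, (root if nxt.is_end else nxt)
        return 0.0, root  # simplified: no fail links, no score rollback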
diff --git a/sherpa-onnx/csrc/online-recognizer.h b/sherpa-onnx/csrc/online-recognizer.h
index 136fb820..bd8321c1 100644
--- a/sherpa-onnx/csrc/online-recognizer.h
+++ b/sherpa-onnx/csrc/online-recognizer.h
@@ -75,7 +75,10 @@ struct OnlineRecognizerConfig {
   std::string decoding_method = "greedy_search";
   // now support modified_beam_search and greedy_search
 
-  int32_t max_active_paths = 4;  // used only for modified_beam_search
+  // used only for modified_beam_search
+  int32_t max_active_paths = 4;
+  /// used only for modified_beam_search
+  float context_score = 1.5;
 
   OnlineRecognizerConfig() = default;
@@ -85,13 +88,14 @@ struct OnlineRecognizerConfig {
                          const EndpointConfig &endpoint_config,
                          bool enable_endpoint,
                          const std::string &decoding_method,
-                         int32_t max_active_paths)
+                         int32_t max_active_paths, float context_score)
       : feat_config(feat_config),
         model_config(model_config),
        endpoint_config(endpoint_config),
         enable_endpoint(enable_endpoint),
         decoding_method(decoding_method),
-        max_active_paths(max_active_paths) {}
+        max_active_paths(max_active_paths),
+        context_score(context_score) {}
 
   void Register(ParseOptions *po);
   bool Validate() const;
@@ -112,6 +116,10 @@ class OnlineRecognizer {
   /// Create a stream for decoding.
   std::unique_ptr<OnlineStream> CreateStream() const;
 
+  // Create a stream with context phrases for biasing.
+  std::unique_ptr<OnlineStream> CreateStream(
+      const std::vector<std::vector<int32_t>> &context_list) const;
+
   /**
    * Return true if the given stream has enough frames for decoding.
   * Return false otherwise.
diff --git a/sherpa-onnx/csrc/online-stream.cc b/sherpa-onnx/csrc/online-stream.cc
index 27c49462..e0593ff6 100644
--- a/sherpa-onnx/csrc/online-stream.cc
+++ b/sherpa-onnx/csrc/online-stream.cc
@@ -13,8 +13,9 @@ namespace sherpa_onnx {
 
 class OnlineStream::Impl {
  public:
-  explicit Impl(const FeatureExtractorConfig &config)
-      : feat_extractor_(config) {}
+  explicit Impl(const FeatureExtractorConfig &config,
+                ContextGraphPtr context_graph)
+      : feat_extractor_(config), context_graph_(context_graph) {}
 
   void AcceptWaveform(int32_t sampling_rate, const float *waveform, int32_t n) {
     feat_extractor_.AcceptWaveform(sampling_rate, waveform, n);
@@ -54,16 +55,21 @@ class OnlineStream::Impl {
 
   std::vector<Ort::Value> &GetStates() { return states_; }
 
+  const ContextGraphPtr &GetContextGraph() const { return context_graph_; }
+
  private:
   FeatureExtractor feat_extractor_;
+  /// For contextual biasing
+  ContextGraphPtr context_graph_;
   int32_t num_processed_frames_ = 0;  // before subsampling
   int32_t start_frame_index_ = 0;     // never reset
   OnlineTransducerDecoderResult result_;
   std::vector<Ort::Value> states_;
 };
 
-OnlineStream::OnlineStream(const FeatureExtractorConfig &config /*= {}*/)
-    : impl_(std::make_unique<Impl>(config)) {}
+OnlineStream::OnlineStream(const FeatureExtractorConfig &config /*= {}*/,
+                           ContextGraphPtr context_graph /*= nullptr*/)
+    : impl_(std::make_unique<Impl>(config, context_graph)) {}
 
 OnlineStream::~OnlineStream() = default;
@@ -109,4 +115,8 @@ std::vector<Ort::Value> &OnlineStream::GetStates() {
   return impl_->GetStates();
 }
 
+const ContextGraphPtr &OnlineStream::GetContextGraph() const {
+  return impl_->GetContextGraph();
+}
+
 }  // namespace sherpa_onnx
diff --git a/sherpa-onnx/csrc/online-stream.h b/sherpa-onnx/csrc/online-stream.h
index bc1935da..60dce950 100644
--- a/sherpa-onnx/csrc/online-stream.h
+++ b/sherpa-onnx/csrc/online-stream.h
@@ -9,6 +9,7 @@
 #include <memory>
 
 #include "onnxruntime_cxx_api.h"  // NOLINT
+#include "sherpa-onnx/csrc/context-graph.h"
 #include "sherpa-onnx/csrc/features.h"
 #include "sherpa-onnx/csrc/online-transducer-decoder.h"
@@ -16,7 +17,8 @@ namespace sherpa_onnx {
 
 class OnlineStream {
  public:
-  explicit OnlineStream(const FeatureExtractorConfig &config = {});
+  explicit OnlineStream(const FeatureExtractorConfig &config = {},
+                        ContextGraphPtr context_graph = nullptr);
   ~OnlineStream();
 
   /**
@@ -71,6 +73,13 @@ class OnlineStream {
   void SetStates(std::vector<Ort::Value> states);
   std::vector<Ort::Value> &GetStates();
 
+  /**
+   * Get the context graph corresponding to this stream.
+   *
+   * @return Return the context graph for this stream.
+   */
+  const ContextGraphPtr &GetContextGraph() const;
+
  private:
   class Impl;
   std::unique_ptr<Impl> impl_;
diff --git a/sherpa-onnx/csrc/online-transducer-decoder.h b/sherpa-onnx/csrc/online-transducer-decoder.h
index dcfa363b..68a8fae4 100644
--- a/sherpa-onnx/csrc/online-transducer-decoder.h
+++ b/sherpa-onnx/csrc/online-transducer-decoder.h
@@ -9,6 +9,7 @@
 #include "onnxruntime_cxx_api.h"  // NOLINT
 #include "sherpa-onnx/csrc/hypothesis.h"
+#include "sherpa-onnx/csrc/macros.h"
 
 namespace sherpa_onnx {
@@ -45,6 +46,7 @@ struct OnlineTransducerDecoderResult {
       OnlineTransducerDecoderResult &&other);
 };
 
+class OnlineStream;
 class OnlineTransducerDecoder {
  public:
   virtual ~OnlineTransducerDecoder() = default;
@@ -76,6 +78,26 @@ class OnlineTransducerDecoder {
   virtual void Decode(Ort::Value encoder_out,
                       std::vector<OnlineTransducerDecoderResult> *result) = 0;
 
+  /** Run transducer beam search given the output from the encoder model.
+   *
+   * Note: Currently this interface is for the contextual-biasing feature,
+   * which needs a ContextGraph owned by the OnlineStream.
+   *
+   * @param encoder_out A 3-D tensor of shape (N, T, joiner_dim)
+   * @param ss A list of OnlineStreams.
+   * @param result It is modified in-place.
+   *
+   * @note There is no need to pass encoder_out_length here since for the
+   *       online decoding case, each utterance has the same number of frames
+   *       and there are no paddings.
+   */
+  virtual void Decode(Ort::Value encoder_out, OnlineStream **ss,
+                      std::vector<OnlineTransducerDecoderResult> *result) {
+    SHERPA_ONNX_LOGE(
+        "This interface is for OnlineTransducerModifiedBeamSearchDecoder.");
+    exit(-1);
+  }
+
   // used for endpointing. We need to keep decoder_out after reset
   virtual void UpdateDecoderOut(OnlineTransducerDecoderResult *result) {}
 };
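The default implementation of the new Decode overload deliberately aborts: only the modified-beam-search decoder overrides it, since greedy search keeps no hypothesis list to bias. The same pattern, expressed in Python purely as an illustration:

    class OnlineTransducerDecoder:
        def decode(self, encoder_out, result):
            raise NotImplementedError  # every decoder must implement this

        def decode_with_streams(self, encoder_out, ss, result):
            # Mirrors the C++ base class: decoders that do not support
            # per-stream contextual biasing reject this entry point.
            raise RuntimeError(
                "This interface is for OnlineTransducerModifiedBeamSearchDecoder."
            )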
diff --git a/sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.cc b/sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.cc
index cbf01a0a..7e2a4a97 100644
--- a/sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.cc
+++ b/sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.cc
@@ -9,6 +9,7 @@
 #include <utility>
 #include <vector>
 
+#include "sherpa-onnx/csrc/log.h"
 #include "sherpa-onnx/csrc/onnx-utils.h"
 
 namespace sherpa_onnx {
@@ -62,6 +63,12 @@ void OnlineTransducerModifiedBeamSearchDecoder::StripLeadingBlanks(
 
 void OnlineTransducerModifiedBeamSearchDecoder::Decode(
     Ort::Value encoder_out,
     std::vector<OnlineTransducerDecoderResult> *result) {
+  Decode(std::move(encoder_out), nullptr, result);
+}
+
+void OnlineTransducerModifiedBeamSearchDecoder::Decode(
+    Ort::Value encoder_out, OnlineStream **ss,
+    std::vector<OnlineTransducerDecoderResult> *result) {
   std::vector<int64_t> encoder_out_shape =
       encoder_out.GetTensorTypeAndShapeInfo().GetShape();
@@ -74,6 +81,7 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
   }
 
   int32_t batch_size = static_cast<int32_t>(encoder_out_shape[0]);
+  int32_t num_frames = static_cast<int32_t>(encoder_out_shape[1]);
   int32_t vocab_size = model_->VocabSize();
@@ -142,18 +150,27 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
       Hypothesis new_hyp = prev[hyp_index];
       const float prev_lm_log_prob = new_hyp.lm_log_prob;
 
+      float context_score = 0;
+      auto context_state = new_hyp.context_state;
+
       if (new_token != 0) {
         new_hyp.ys.push_back(new_token);
         new_hyp.timestamps.push_back(t + frame_offset);
         new_hyp.num_trailing_blanks = 0;
+        if (ss != nullptr && ss[b]->GetContextGraph() != nullptr) {
+          auto context_res = ss[b]->GetContextGraph()->ForwardOneStep(
+              context_state, new_token);
+          context_score = context_res.first;
+          new_hyp.context_state = context_res.second;
+        }
         if (lm_) {
           lm_->ComputeLMScore(lm_scale_, &new_hyp);
         }
       } else {
         ++new_hyp.num_trailing_blanks;
       }
-      new_hyp.log_prob =
-          p_logprob[k] - prev_lm_log_prob;  // log_prob only includes the
+      new_hyp.log_prob = p_logprob[k] + context_score -
+                         prev_lm_log_prob;  // log_prob only includes the
                                             // score of the transducer
       hyps.Add(std::move(new_hyp));
     }  // for (auto k : topk)
diff --git a/sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.h b/sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.h
index 5fbf6a31..d05c5167 100644
--- a/sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.h
+++ b/sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.h
@@ -9,6 +9,7 @@
 #include <vector>
 
 #include "sherpa-onnx/csrc/online-lm.h"
+#include "sherpa-onnx/csrc/online-stream.h"
 #include "sherpa-onnx/csrc/online-transducer-decoder.h"
 #include "sherpa-onnx/csrc/online-transducer-model.h"
@@ -33,6 +34,9 @@ class OnlineTransducerModifiedBeamSearchDecoder
   void Decode(Ort::Value encoder_out,
               std::vector<OnlineTransducerDecoderResult> *result) override;
 
+  void Decode(Ort::Value encoder_out, OnlineStream **ss,
+              std::vector<OnlineTransducerDecoderResult> *result) override;
+
   void UpdateDecoderOut(OnlineTransducerDecoderResult *result) override;
 
  private:
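In the scoring hunk above, the context bonus is added to each candidate's log-probability before beam pruning, so biased phrases survive pruning more often. A self-contained sketch of that per-hypothesis update (forward_one_step stands in for ContextGraph::ForwardOneStep; all names are illustrative):

    def rescore_candidates(topk, context_state, forward_one_step=None, blank=0):
        """topk: list of (token_id, log_prob) pairs for one hypothesis.

        forward_one_step(state, token) -> (bonus, next_state); pass None to
        disable biasing. Returns (token_id, biased_log_prob, next_state).
        """
        out = []
        for token, log_prob in topk:
            bonus, next_state = 0.0, context_state
            if token != blank and forward_one_step is not None:
                # Blanks never advance the context graph, as in the C++ code.
                bonus, next_state = forward_one_step(context_state, token)
            out.append((token, log_prob + bonus, next_state))
        return out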
diff --git a/sherpa-onnx/python/csrc/online-recognizer.cc b/sherpa-onnx/python/csrc/online-recognizer.cc
index 54d97e80..02ab95e6 100644
--- a/sherpa-onnx/python/csrc/online-recognizer.cc
+++ b/sherpa-onnx/python/csrc/online-recognizer.cc
@@ -22,18 +22,19 @@ static void PybindOnlineRecognizerConfig(py::module *m) {
   py::class_<PyClass>(*m, "OnlineRecognizerConfig")
       .def(py::init<const FeatureExtractorConfig &, const OnlineModelConfig &,
                     const OnlineLMConfig &,
-                    const EndpointConfig &, bool, const std::string &,
-                    int32_t>(),
+                    const EndpointConfig &, bool, const std::string &, int32_t,
+                    float>(),
           py::arg("feat_config"), py::arg("model_config"),
           py::arg("lm_config") = OnlineLMConfig(),
           py::arg("endpoint_config"), py::arg("enable_endpoint"),
-          py::arg("decoding_method"), py::arg("max_active_paths"))
+          py::arg("decoding_method"), py::arg("max_active_paths"),
+          py::arg("context_score"))
       .def_readwrite("feat_config", &PyClass::feat_config)
       .def_readwrite("model_config", &PyClass::model_config)
       .def_readwrite("endpoint_config", &PyClass::endpoint_config)
       .def_readwrite("enable_endpoint", &PyClass::enable_endpoint)
       .def_readwrite("decoding_method", &PyClass::decoding_method)
       .def_readwrite("max_active_paths", &PyClass::max_active_paths)
+      .def_readwrite("context_score", &PyClass::context_score)
       .def("__str__", &PyClass::ToString);
 }
@@ -44,7 +45,15 @@ void PybindOnlineRecognizer(py::module *m) {
   using PyClass = OnlineRecognizer;
   py::class_<PyClass>(*m, "OnlineRecognizer")
       .def(py::init<const OnlineRecognizerConfig &>(), py::arg("config"))
-      .def("create_stream", &PyClass::CreateStream)
+      .def("create_stream",
+           [](const PyClass &self) { return self.CreateStream(); })
+      .def(
+          "create_stream",
+          [](PyClass &self,
+             const std::vector<std::vector<int32_t>> &contexts_list) {
+            return self.CreateStream(contexts_list);
+          },
+          py::arg("contexts_list"))
       .def("is_ready", &PyClass::IsReady)
       .def("decode_stream", &PyClass::DecodeStream)
       .def("decode_streams",
diff --git a/sherpa-onnx/python/sherpa_onnx/online_recognizer.py b/sherpa-onnx/python/sherpa_onnx/online_recognizer.py
index 48dc74e7..c981bc04 100644
--- a/sherpa-onnx/python/sherpa_onnx/online_recognizer.py
+++ b/sherpa-onnx/python/sherpa_onnx/online_recognizer.py
@@ -1,6 +1,6 @@
 # Copyright (c) 2023 Xiaomi Corporation
 from pathlib import Path
-from typing import List
+from typing import List, Optional
 
 from _sherpa_onnx import (
     EndpointConfig,
@@ -39,6 +39,7 @@ class OnlineRecognizer(object):
         rule3_min_utterance_length: float = 20.0,
         decoding_method: str = "greedy_search",
         max_active_paths: int = 4,
+        context_score: float = 1.5,
         provider: str = "cpu",
     ):
         """
@@ -124,13 +125,17 @@ class OnlineRecognizer(object):
             enable_endpoint=enable_endpoint_detection,
             decoding_method=decoding_method,
             max_active_paths=max_active_paths,
+            context_score=context_score,
         )
 
         self.recognizer = _Recognizer(recognizer_config)
         self.config = recognizer_config
 
-    def create_stream(self):
-        return self.recognizer.create_stream()
+    def create_stream(self, contexts_list: Optional[List[List[int]]] = None):
+        if contexts_list is None:
+            return self.recognizer.create_stream()
+        else:
+            return self.recognizer.create_stream(contexts_list)
 
     def decode_stream(self, s: OnlineStream):
         self.recognizer.decode_stream(s)
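Taken together, user-side usage of the new Python API might look as follows. Only contexts_list, max_active_paths, and context_score come from this diff; the remaining constructor keyword arguments are assumptions about the existing OnlineRecognizer wrapper, and the token IDs are made up:

    import sherpa_onnx

    recognizer = sherpa_onnx.OnlineRecognizer(
        tokens="tokens.txt",    # assumed existing arguments
        encoder="encoder.onnx",
        decoder="decoder.onnx",
        joiner="joiner.onnx",
        decoding_method="modified_beam_search",
        max_active_paths=4,
        context_score=1.5,      # bonus per matched context token
    )

    # Token IDs of the biasing phrases, e.g. from encode_contexts() above.
    contexts_list = [[22, 58, 24], [100, 7]]  # made-up IDs
    s = recognizer.create_stream(contexts_list=contexts_list)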