diff --git a/.github/workflows/run-python-test.yaml b/.github/workflows/run-python-test.yaml
index 223db1f8..4d06fa50 100644
--- a/.github/workflows/run-python-test.yaml
+++ b/.github/workflows/run-python-test.yaml
@@ -54,7 +54,7 @@ jobs:
      - name: Install Python dependencies
        shell: bash
        run: |
-          python3 -m pip install --upgrade pip numpy
+          python3 -m pip install --upgrade pip numpy sentencepiece==0.1.96

      - name: Install sherpa-onnx
        shell: bash
diff --git a/python-api-examples/offline-decode-files.py b/python-api-examples/offline-decode-files.py
index 670fbde6..98ead3f9 100755
--- a/python-api-examples/offline-decode-files.py
+++ b/python-api-examples/offline-decode-files.py
@@ -43,9 +43,10 @@ import argparse
 import time
 import wave
 from pathlib import Path
-from typing import Tuple
+from typing import List, Tuple

 import numpy as np
+import sentencepiece as spm
 import sherpa_onnx


@@ -60,6 +61,47 @@ def get_args():
         help="Path to tokens.txt",
     )

+    parser.add_argument(
+        "--bpe-model",
+        type=str,
+        default="",
+        help="""
+        Path to bpe.model.
+        Used only when --decoding-method=modified_beam_search.
+        """,
+    )
+
+    parser.add_argument(
+        "--modeling-unit",
+        type=str,
+        default="char",
+        help="""
+        The type of modeling unit.
+        Valid values are bpe, bpe+char, and char.
+        Note: char here means characters in CJK languages.
+        """,
+    )
+
+    parser.add_argument(
+        "--contexts",
+        type=str,
+        default="",
+        help="""
+        The context list: a string containing words/phrases separated
+        by /, for example, 'HELLO WORLD/I LOVE YOU/GO AWAY'.
+        """,
+    )
+
+    parser.add_argument(
+        "--context-score",
+        type=float,
+        default=1.5,
+        help="""
+        The context score of each token for biasing words/phrases. Used only
+        when --contexts is given.
+        """,
+    )
+
     parser.add_argument(
         "--encoder",
         default="",
@@ -153,6 +195,24 @@ def assert_file_exists(filename: str):
     )


+def encode_contexts(args, contexts: List[str]) -> List[List[int]]:
+    sp = None
+    if "bpe" in args.modeling_unit:
+        assert_file_exists(args.bpe_model)
+        sp = spm.SentencePieceProcessor()
+        sp.load(args.bpe_model)
+    tokens = {}
+    with open(args.tokens, "r", encoding="utf-8") as f:
+        for line in f:
+            toks = line.strip().split()
+            assert len(toks) == 2, len(toks)
+            assert toks[0] not in tokens, f"Duplicate token: {toks}"
+            tokens[toks[0]] = int(toks[1])
+    return sherpa_onnx.encode_contexts(
+        modeling_unit=args.modeling_unit, contexts=contexts, sp=sp, tokens_table=tokens
+    )
+
+
 def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]:
     """
     Args:
@@ -182,10 +242,17 @@ def main():
     args = get_args()
     assert_file_exists(args.tokens)
     assert args.num_threads > 0, args.num_threads
+
+    contexts_list = []
     if args.encoder:
         assert len(args.paraformer) == 0, args.paraformer
         assert len(args.nemo_ctc) == 0, args.nemo_ctc

+        contexts = [x.strip().upper() for x in args.contexts.split("/") if x.strip()]
+        if contexts:
+            print(f"Contexts list: {contexts}")
+            contexts_list = encode_contexts(args, contexts)
+
         assert_file_exists(args.encoder)
         assert_file_exists(args.decoder)
         assert_file_exists(args.joiner)
@@ -199,6 +266,7 @@ def main():
             sample_rate=args.sample_rate,
             feature_dim=args.feature_dim,
             decoding_method=args.decoding_method,
+            context_score=args.context_score,
             debug=args.debug,
         )
     elif args.paraformer:
@@ -238,8 +306,12 @@ def main():
         samples, sample_rate = read_wave(wave_filename)
         duration = len(samples) / sample_rate
         total_duration += duration
-
-        s = recognizer.create_stream()
+        if contexts_list:
+            assert len(args.paraformer) == 0, args.paraformer
+            assert len(args.nemo_ctc) == 0, args.nemo_ctc
+            s = recognizer.create_stream(contexts_list=contexts_list)
+        else:
+            s = recognizer.create_stream()
         s.accept_waveform(sample_rate, samples)
         streams.append(s)
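For reference, the new --contexts flag takes one /-separated string; main() splits it and upper-cases each entry before encoding, as in this minimal sketch of the parsing step (the flag value here is a made-up example):

    raw = "hello world/i love you"  # example value passed via --contexts
    contexts = [x.strip().upper() for x in raw.split("/") if x.strip()]
    # -> ['HELLO WORLD', 'I LOVE YOU']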
diff --git a/setup.py b/setup.py
index 99f3fa6b..a4d74039 100644
--- a/setup.py
+++ b/setup.py
@@ -37,6 +37,7 @@ with open("sherpa-onnx/python/sherpa_onnx/__init__.py", "a") as f:

 install_requires = [
     "numpy",
+    "sentencepiece==0.1.96",
 ]

diff --git a/sherpa-onnx/csrc/CMakeLists.txt b/sherpa-onnx/csrc/CMakeLists.txt
index 82a831fe..877c31ed 100644
--- a/sherpa-onnx/csrc/CMakeLists.txt
+++ b/sherpa-onnx/csrc/CMakeLists.txt
@@ -12,6 +12,7 @@ endif()

 set(sources
   cat.cc
+  context-graph.cc
   endpoint.cc
   features.cc
   file-utils.cc
@@ -248,6 +249,7 @@ endif()
 if(SHERPA_ONNX_ENABLE_TESTS)
   set(sherpa_onnx_test_srcs
     cat-test.cc
+    context-graph-test.cc
     packed-sequence-test.cc
     pad-sequence-test.cc
     slice-test.cc
diff --git a/sherpa-onnx/csrc/context-graph-test.cc b/sherpa-onnx/csrc/context-graph-test.cc
new file mode 100644
index 00000000..97d03443
--- /dev/null
+++ b/sherpa-onnx/csrc/context-graph-test.cc
@@ -0,0 +1,43 @@
+// sherpa-onnx/csrc/context-graph-test.cc
+//
+// Copyright (c) 2023 Xiaomi Corporation
+
+#include "sherpa-onnx/csrc/context-graph.h"
+
+#include <map>
+#include <string>
+#include <vector>
+
+#include "gtest/gtest.h"
+
+namespace sherpa_onnx {
+
+TEST(ContextGraph, TestBasic) {
+  std::vector<std::string> contexts_str(
+      {"S", "HE", "SHE", "SHELL", "HIS", "HERS", "HELLO", "THIS", "THEM"});
+  std::vector<std::vector<int32_t>> contexts;
+  for (int32_t i = 0; i < contexts_str.size(); ++i) {
+    contexts.emplace_back(contexts_str[i].begin(), contexts_str[i].end());
+  }
+  auto context_graph = ContextGraph(contexts, 1);
+
+  auto queries = std::map<std::string, float>{
+      {"HEHERSHE", 14}, {"HERSHE", 12}, {"HISHE", 9},   {"SHED", 6},
+      {"HELL", 2},      {"HELLO", 7},   {"DHRHISQ", 4}, {"THEN", 2}};
+
+  for (const auto &iter : queries) {
+    float total_scores = 0;
+    auto state = context_graph.Root();
+    for (auto q : iter.first) {
+      auto res = context_graph.ForwardOneStep(state, q);
+      total_scores += res.first;
+      state = res.second;
+    }
+    auto res = context_graph.Finalize(state);
+    EXPECT_EQ(res.second->token, -1);
+    total_scores += res.first;
+    EXPECT_EQ(total_scores, iter.second);
+  }
+}
+
+}  // namespace sherpa_onnx
diff --git a/sherpa-onnx/csrc/context-graph.cc b/sherpa-onnx/csrc/context-graph.cc
new file mode 100644
index 00000000..bc3a1e3e
--- /dev/null
+++ b/sherpa-onnx/csrc/context-graph.cc
@@ -0,0 +1,105 @@
+// sherpa-onnx/csrc/context-graph.cc
+//
+// Copyright (c) 2023 Xiaomi Corporation
+
+#include "sherpa-onnx/csrc/context-graph.h"
+
+#include <cassert>
+#include <queue>
+#include <utility>
+
+namespace sherpa_onnx {
+void ContextGraph::Build(
+    const std::vector<std::vector<int32_t>> &token_ids) const {
+  for (int32_t i = 0; i < token_ids.size(); ++i) {
+    auto node = root_.get();
+    for (int32_t j = 0; j < token_ids[i].size(); ++j) {
+      int32_t token = token_ids[i][j];
+      if (0 == node->next.count(token)) {
+        bool is_end = j == token_ids[i].size() - 1;
+        node->next[token] = std::make_unique<ContextState>(
+            token, context_score_, node->node_score + context_score_,
+            is_end ? 0 : node->local_node_score + context_score_, is_end);
+      }
+      node = node->next[token].get();
+    }
+  }
+  FillFailOutput();
+}
+
+std::pair<float, const ContextState *> ContextGraph::ForwardOneStep(
+    const ContextState *state, int32_t token) const {
+  const ContextState *node;
+  float score;
+  if (1 == state->next.count(token)) {
+    node = state->next.at(token).get();
+    score = node->token_score;
+    if (state->is_end) score += state->node_score;
+  } else {
+    node = state->fail;
+    while (0 == node->next.count(token)) {
+      node = node->fail;
+      if (-1 == node->token) break;  // root
+    }
+    if (1 == node->next.count(token)) {
+      node = node->next.at(token).get();
+    }
+    score = node->node_score - state->local_node_score;
+  }
+  SHERPA_ONNX_CHECK(nullptr != node);
+  float matched_score = 0;
+  auto output = node->output;
+  while (nullptr != output) {
+    matched_score += output->node_score;
+    output = output->output;
+  }
+  return std::make_pair(score + matched_score, node);
+}
+
+std::pair<float, const ContextState *> ContextGraph::Finalize(
+    const ContextState *state) const {
+  float score = -state->node_score;
+  if (state->is_end) {
+    score = 0;
+  }
+  return std::make_pair(score, root_.get());
+}
+
+void ContextGraph::FillFailOutput() const {
+  std::queue<const ContextState *> node_queue;
+  for (auto &kv : root_->next) {
+    kv.second->fail = root_.get();
+    node_queue.push(kv.second.get());
+  }
+  while (!node_queue.empty()) {
+    auto current_node = node_queue.front();
+    node_queue.pop();
+    for (auto &kv : current_node->next) {
+      auto fail = current_node->fail;
+      if (1 == fail->next.count(kv.first)) {
+        fail = fail->next.at(kv.first).get();
+      } else {
+        fail = fail->fail;
+        while (0 == fail->next.count(kv.first)) {
+          fail = fail->fail;
+          if (-1 == fail->token) break;
+        }
+        if (1 == fail->next.count(kv.first))
+          fail = fail->next.at(kv.first).get();
+      }
+      kv.second->fail = fail;
+      // fill the output arc
+      auto output = fail;
+      while (!output->is_end) {
+        output = output->fail;
+        if (-1 == output->token) {
+          output = nullptr;
+          break;
+        }
+      }
+      kv.second->output = output;
+      node_queue.push(kv.second.get());
+    }
+  }
+}
+}  // namespace sherpa_onnx
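A worked example for the scores asserted in context-graph-test.cc above, with context_score = 1: scanning the query "HELLO" earns +1 token_score per matched token (H, E, L, L, O = 5); when the walk steps off the end node of the shorter context "HE", ForwardOneStep additionally banks that node's node_score (+2); and Finalize contributes 0 because the final state is itself a word end. The total is 5 + 2 = 7, matching the expected value in {"HELLO", 7}.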
diff --git a/sherpa-onnx/csrc/context-graph.h b/sherpa-onnx/csrc/context-graph.h
new file mode 100644
index 00000000..db16ce66
--- /dev/null
+++ b/sherpa-onnx/csrc/context-graph.h
@@ -0,0 +1,66 @@
+// sherpa-onnx/csrc/context-graph.h
+//
+// Copyright (c) 2023 Xiaomi Corporation
+
+#ifndef SHERPA_ONNX_CSRC_CONTEXT_GRAPH_H_
+#define SHERPA_ONNX_CSRC_CONTEXT_GRAPH_H_
+
+#include <memory>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+#include "sherpa-onnx/csrc/log.h"
+
+namespace sherpa_onnx {
+
+class ContextGraph;
+using ContextGraphPtr = std::shared_ptr<ContextGraph>;
+
+struct ContextState {
+  int32_t token;
+  float token_score;
+  float node_score;
+  float local_node_score;
+  bool is_end;
+  std::unordered_map<int32_t, std::unique_ptr<ContextState>> next;
+  const ContextState *fail = nullptr;
+  const ContextState *output = nullptr;
+
+  ContextState() = default;
+  ContextState(int32_t token, float token_score, float node_score,
+               float local_node_score, bool is_end)
+      : token(token),
+        token_score(token_score),
+        node_score(node_score),
+        local_node_score(local_node_score),
+        is_end(is_end) {}
+};
+
+class ContextGraph {
+ public:
+  ContextGraph() = default;
+  ContextGraph(const std::vector<std::vector<int32_t>> &token_ids,
+               float context_score)
+      : context_score_(context_score) {
+    root_ = std::make_unique<ContextState>(-1, 0, 0, 0, false);
+    root_->fail = root_.get();
+    Build(token_ids);
+  }
+
+  std::pair<float, const ContextState *> ForwardOneStep(
+      const ContextState *state, int32_t token_id) const;
+  std::pair<float, const ContextState *> Finalize(
+      const ContextState *state) const;
+
+  const ContextState *Root() const { return root_.get(); }
+
+ private:
+  float context_score_;
+  std::unique_ptr<ContextState> root_;
+  void Build(const std::vector<std::vector<int32_t>> &token_ids) const;
+  void FillFailOutput() const;
+};
+
+}  // namespace sherpa_onnx
+#endif  // SHERPA_ONNX_CSRC_CONTEXT_GRAPH_H_
diff --git a/sherpa-onnx/csrc/hypothesis.h b/sherpa-onnx/csrc/hypothesis.h
index 98cc50f2..29a64340 100644
--- a/sherpa-onnx/csrc/hypothesis.h
+++ b/sherpa-onnx/csrc/hypothesis.h
@@ -14,6 +14,7 @@
 #include <vector>

 #include "onnxruntime_cxx_api.h"  // NOLINT
+#include "sherpa-onnx/csrc/context-graph.h"
 #include "sherpa-onnx/csrc/math.h"
 #include "sherpa-onnx/csrc/onnx-utils.h"

@@ -39,11 +40,18 @@ struct Hypothesis {
   // the nn lm states
   std::vector<Ort::Value> nn_lm_states;

+  const ContextState *context_state = nullptr;
+
+  // TODO(fangjun): Make it configurable
+  // the minimum number of tokens in a chunk for streaming RNN LM
+  int32_t lm_rescore_min_chunk = 2;  // a const
+
   int32_t num_trailing_blanks = 0;

   Hypothesis() = default;
-  Hypothesis(const std::vector<int64_t> &ys, double log_prob)
-      : ys(ys), log_prob(log_prob) {}
+  Hypothesis(const std::vector<int64_t> &ys, double log_prob,
+             const ContextState *context_state = nullptr)
+      : ys(ys), log_prob(log_prob), context_state(context_state) {}

   double TotalLogProb() const { return log_prob + lm_log_prob; }
diff --git a/sherpa-onnx/csrc/offline-recognizer-impl.h b/sherpa-onnx/csrc/offline-recognizer-impl.h
index 065be58e..43a44abf 100644
--- a/sherpa-onnx/csrc/offline-recognizer-impl.h
+++ b/sherpa-onnx/csrc/offline-recognizer-impl.h
@@ -6,7 +6,9 @@
 #define SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_IMPL_H_

 #include <memory>
+#include <vector>

+#include "sherpa-onnx/csrc/macros.h"
 #include "sherpa-onnx/csrc/offline-recognizer.h"
 #include "sherpa-onnx/csrc/offline-stream.h"

@@ -19,6 +21,12 @@ class OfflineRecognizerImpl {

   virtual ~OfflineRecognizerImpl() = default;

+  virtual std::unique_ptr<OfflineStream> CreateStream(
+      const std::vector<std::vector<int32_t>> &context_list) const {
+    SHERPA_ONNX_LOGE("Only transducer models support contextual biasing.");
+    exit(-1);
+  }
+
   virtual std::unique_ptr<OfflineStream> CreateStream() const = 0;

   virtual void DecodeStreams(OfflineStream **ss, int32_t n) const = 0;
diff --git a/sherpa-onnx/csrc/offline-recognizer-transducer-impl.h b/sherpa-onnx/csrc/offline-recognizer-transducer-impl.h
index d8360dce..7245497d 100644
--- a/sherpa-onnx/csrc/offline-recognizer-transducer-impl.h
+++ b/sherpa-onnx/csrc/offline-recognizer-transducer-impl.h
@@ -10,6 +10,7 @@
 #include <utility>
 #include <vector>

+#include "sherpa-onnx/csrc/context-graph.h"
 #include "sherpa-onnx/csrc/macros.h"
 #include "sherpa-onnx/csrc/offline-recognizer-impl.h"
 #include "sherpa-onnx/csrc/offline-recognizer.h"
@@ -72,6 +73,16 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
     }
   }

+  std::unique_ptr<OfflineStream> CreateStream(
+      const std::vector<std::vector<int32_t>> &context_list) const override {
+    // We create the context_graph at this level because we might later add a
+    // default context_graph that belongs to the whole model rather than to
+    // each stream.
+    auto context_graph =
+        std::make_shared<ContextGraph>(context_list, config_.context_score);
+    return std::make_unique<OfflineStream>(config_.feat_config, context_graph);
+  }
+
   std::unique_ptr<OfflineStream> CreateStream() const override {
     return std::make_unique<OfflineStream>(config_.feat_config);
   }
@@ -117,7 +128,8 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
                          -23.025850929940457f);

     auto t = model_->RunEncoder(std::move(x), std::move(x_length));
-    auto results = decoder_->Decode(std::move(t.first), std::move(t.second));
+    auto results =
+        decoder_->Decode(std::move(t.first), std::move(t.second), ss, n);

     int32_t frame_shift_ms = 10;
     for (int32_t i = 0; i != n; ++i) {
diff --git a/sherpa-onnx/csrc/offline-recognizer.cc b/sherpa-onnx/csrc/offline-recognizer.cc
index c5daa17e..3fe53c93 100644
--- a/sherpa-onnx/csrc/offline-recognizer.cc
+++ b/sherpa-onnx/csrc/offline-recognizer.cc
@@ -26,6 +26,9 @@ void OfflineRecognizerConfig::Register(ParseOptions *po) {

   po->Register("max-active-paths", &max_active_paths,
                "Used only when decoding_method is modified_beam_search");
+  po->Register("context-score", &context_score,
+               "The bonus score for each token in a context word/phrase. "
+               "Used only when decoding_method is modified_beam_search");
 }

 bool OfflineRecognizerConfig::Validate() const {
@@ -49,7 +52,8 @@ std::string OfflineRecognizerConfig::ToString() const {
   os << "model_config=" << model_config.ToString() << ", ";
   os << "lm_config=" << lm_config.ToString() << ", ";
   os << "decoding_method=\"" << decoding_method << "\", ";
-  os << "max_active_paths=" << max_active_paths << ")";
+  os << "max_active_paths=" << max_active_paths << ", ";
+  os << "context_score=" << context_score << ")";

   return os.str();
 }
@@ -59,6 +63,11 @@ OfflineRecognizer::OfflineRecognizer(const OfflineRecognizerConfig &config)

 OfflineRecognizer::~OfflineRecognizer() = default;

+std::unique_ptr<OfflineStream> OfflineRecognizer::CreateStream(
+    const std::vector<std::vector<int32_t>> &context_list) const {
+  return impl_->CreateStream(context_list);
+}
+
 std::unique_ptr<OfflineStream> OfflineRecognizer::CreateStream() const {
   return impl_->CreateStream();
 }
diff --git a/sherpa-onnx/csrc/offline-recognizer.h b/sherpa-onnx/csrc/offline-recognizer.h
index d6fcb390..6dfa4e6f 100644
--- a/sherpa-onnx/csrc/offline-recognizer.h
+++ b/sherpa-onnx/csrc/offline-recognizer.h
@@ -26,6 +26,7 @@ struct OfflineRecognizerConfig {

   std::string decoding_method = "greedy_search";
   int32_t max_active_paths = 4;
+  float context_score = 1.5;

   // only greedy_search is implemented
   // TODO(fangjun): Implement modified_beam_search
@@ -34,12 +35,13 @@ struct OfflineRecognizerConfig {
                           const OfflineModelConfig &model_config,
                           const OfflineLMConfig &lm_config,
                           const std::string &decoding_method,
-                          int32_t max_active_paths)
+                          int32_t max_active_paths, float context_score)
       : feat_config(feat_config),
         model_config(model_config),
         lm_config(lm_config),
         decoding_method(decoding_method),
-        max_active_paths(max_active_paths) {}
+        max_active_paths(max_active_paths),
+        context_score(context_score) {}

   void Register(ParseOptions *po);
   bool Validate() const;
@@ -58,6 +60,10 @@ class OfflineRecognizer {
   /// Create a stream for decoding.
   std::unique_ptr<OfflineStream> CreateStream() const;

+  /// Create a stream for decoding with the given context list.
+  std::unique_ptr<OfflineStream> CreateStream(
+      const std::vector<std::vector<int32_t>> &context_list) const;
+
   /** Decode a single stream
    *
    * @param s The stream to decode.
diff --git a/sherpa-onnx/csrc/offline-stream.cc b/sherpa-onnx/csrc/offline-stream.cc
index a84a15a9..15ed0389 100644
--- a/sherpa-onnx/csrc/offline-stream.cc
+++ b/sherpa-onnx/csrc/offline-stream.cc
@@ -75,7 +75,9 @@ std::string OfflineFeatureExtractorConfig::ToString() const {

 class OfflineStream::Impl {
  public:
-  explicit Impl(const OfflineFeatureExtractorConfig &config) : config_(config) {
+  explicit Impl(const OfflineFeatureExtractorConfig &config,
+                ContextGraphPtr context_graph)
+      : config_(config), context_graph_(context_graph) {
     opts_.frame_opts.dither = 0;
     opts_.frame_opts.snip_edges = false;
     opts_.frame_opts.samp_freq = config.sampling_rate;
@@ -152,6 +154,8 @@ class OfflineStream::Impl {

   const OfflineRecognitionResult &GetResult() const { return r_; }

+  const ContextGraphPtr &GetContextGraph() const { return context_graph_; }
+
  private:
   void NemoNormalizeFeatures(float *p, int32_t num_frames,
                              int32_t feature_dim) const {
@@ -189,11 +193,13 @@ class OfflineStream::Impl {
   std::unique_ptr<knf::OnlineFbank> fbank_;
   knf::FbankOptions opts_;
   OfflineRecognitionResult r_;
+  ContextGraphPtr context_graph_;
 };

 OfflineStream::OfflineStream(
-    const OfflineFeatureExtractorConfig &config /*= {}*/)
-    : impl_(std::make_unique<Impl>(config)) {}
+    const OfflineFeatureExtractorConfig &config /*= {}*/,
+    ContextGraphPtr context_graph /*= nullptr*/)
+    : impl_(std::make_unique<Impl>(config, context_graph)) {}

 OfflineStream::~OfflineStream() = default;

@@ -212,6 +218,10 @@ void OfflineStream::SetResult(const OfflineRecognitionResult &r) {
   impl_->SetResult(r);
 }

+const ContextGraphPtr &OfflineStream::GetContextGraph() const {
+  return impl_->GetContextGraph();
+}
+
 const OfflineRecognitionResult &OfflineStream::GetResult() const {
   return impl_->GetResult();
 }
diff --git a/sherpa-onnx/csrc/offline-stream.h b/sherpa-onnx/csrc/offline-stream.h
index b1bed47e..a21496bd 100644
--- a/sherpa-onnx/csrc/offline-stream.h
+++ b/sherpa-onnx/csrc/offline-stream.h
@@ -10,6 +10,7 @@
 #include <string>
 #include <vector>

+#include "sherpa-onnx/csrc/context-graph.h"
 #include "sherpa-onnx/csrc/parse-options.h"

 namespace sherpa_onnx {
@@ -66,7 +67,8 @@ struct OfflineFeatureExtractorConfig {

 class OfflineStream {
  public:
-  explicit OfflineStream(const OfflineFeatureExtractorConfig &config = {});
+  explicit OfflineStream(const OfflineFeatureExtractorConfig &config = {},
+                         ContextGraphPtr context_graph = nullptr);
   ~OfflineStream();

   /**
@@ -96,6 +98,9 @@ class OfflineStream {
   /** Get the recognition result of this stream */
   const OfflineRecognitionResult &GetResult() const;

+  /** Get the ContextGraph of this stream */
+  const ContextGraphPtr &GetContextGraph() const;
+
  private:
   class Impl;
   std::unique_ptr<Impl> impl_;
diff --git a/sherpa-onnx/csrc/offline-transducer-decoder.h b/sherpa-onnx/csrc/offline-transducer-decoder.h
index 898fc29c..36d93e44 100644
--- a/sherpa-onnx/csrc/offline-transducer-decoder.h
+++ b/sherpa-onnx/csrc/offline-transducer-decoder.h
@@ -8,6 +8,7 @@
 #include <vector>

 #include "onnxruntime_cxx_api.h"  // NOLINT
+#include "sherpa-onnx/csrc/offline-stream.h"

 namespace sherpa_onnx {

@@ -33,7 +34,8 @@ class OfflineTransducerDecoder {
    * @return Return a vector of size `N` containing the decoded results.
    */
   virtual std::vector<OfflineTransducerDecoderResult> Decode(
-      Ort::Value encoder_out, Ort::Value encoder_out_length) = 0;
+      Ort::Value encoder_out, Ort::Value encoder_out_length,
+      OfflineStream **ss = nullptr, int32_t n = 0) = 0;
 };

 }  // namespace sherpa_onnx
diff --git a/sherpa-onnx/csrc/offline-transducer-greedy-search-decoder.cc b/sherpa-onnx/csrc/offline-transducer-greedy-search-decoder.cc
index 6432ff94..d9ef5f8d 100644
--- a/sherpa-onnx/csrc/offline-transducer-greedy-search-decoder.cc
+++ b/sherpa-onnx/csrc/offline-transducer-greedy-search-decoder.cc
@@ -16,7 +16,9 @@ namespace sherpa_onnx {

 std::vector<OfflineTransducerDecoderResult>
 OfflineTransducerGreedySearchDecoder::Decode(Ort::Value encoder_out,
-                                             Ort::Value encoder_out_length) {
+                                             Ort::Value encoder_out_length,
+                                             OfflineStream **ss /*= nullptr*/,
+                                             int32_t n /*= 0*/) {
   PackedSequence packed_encoder_out = PackPaddedSequence(
       model_->Allocator(), &encoder_out, &encoder_out_length);
diff --git a/sherpa-onnx/csrc/offline-transducer-greedy-search-decoder.h b/sherpa-onnx/csrc/offline-transducer-greedy-search-decoder.h
index a0175d5c..ff172250 100644
--- a/sherpa-onnx/csrc/offline-transducer-greedy-search-decoder.h
+++ b/sherpa-onnx/csrc/offline-transducer-greedy-search-decoder.h
@@ -18,7 +18,8 @@ class OfflineTransducerGreedySearchDecoder : public OfflineTransducerDecoder {
       : model_(model) {}

   std::vector<OfflineTransducerDecoderResult> Decode(
-      Ort::Value encoder_out, Ort::Value encoder_out_length) override;
+      Ort::Value encoder_out, Ort::Value encoder_out_length,
+      OfflineStream **ss = nullptr, int32_t n = 0) override;

  private:
   OfflineTransducerModel *model_;  // Not owned
diff --git a/sherpa-onnx/csrc/offline-transducer-modified-beam-search-decoder.cc b/sherpa-onnx/csrc/offline-transducer-modified-beam-search-decoder.cc
index d9660684..1401a839 100644
--- a/sherpa-onnx/csrc/offline-transducer-modified-beam-search-decoder.cc
+++ b/sherpa-onnx/csrc/offline-transducer-modified-beam-search-decoder.cc
@@ -8,7 +8,9 @@
 #include <deque>
 #include <utility>

+#include "sherpa-onnx/csrc/context-graph.h"
 #include "sherpa-onnx/csrc/hypothesis.h"
+#include "sherpa-onnx/csrc/log.h"
 #include "sherpa-onnx/csrc/onnx-utils.h"
 #include "sherpa-onnx/csrc/packed-sequence.h"
 #include "sherpa-onnx/csrc/slice.h"
@@ -17,23 +19,39 @@ namespace sherpa_onnx {

 std::vector<OfflineTransducerDecoderResult>
 OfflineTransducerModifiedBeamSearchDecoder::Decode(
-    Ort::Value encoder_out, Ort::Value encoder_out_length) {
+    Ort::Value encoder_out, Ort::Value encoder_out_length,
+    OfflineStream **ss /*= nullptr*/, int32_t n /*= 0*/) {
   PackedSequence packed_encoder_out = PackPaddedSequence(
       model_->Allocator(), &encoder_out, &encoder_out_length);

   int32_t batch_size =
       static_cast<int32_t>(packed_encoder_out.sorted_indexes.size());

+  if (ss != nullptr) SHERPA_ONNX_CHECK_EQ(batch_size, n);
+
   int32_t vocab_size = model_->VocabSize();
   int32_t context_size = model_->ContextSize();

   std::vector<int64_t> blanks(context_size, 0);
-  Hypotheses blank_hyp({{blanks, 0}});

   std::deque<Hypotheses> finalized;
-  std::vector<Hypotheses> cur(batch_size, blank_hyp);
+  std::vector<Hypotheses> cur;
   std::vector<Hypothesis> prev;

+  std::vector<ContextGraphPtr> context_graphs(batch_size, nullptr);
+
+  for (int32_t i = 0; i < batch_size; ++i) {
+    const ContextState *context_state = nullptr;
+    if (ss != nullptr) {
+      context_graphs[i] =
+          ss[packed_encoder_out.sorted_indexes[i]]->GetContextGraph();
+      if (context_graphs[i] != nullptr)
+        context_state = context_graphs[i]->Root();
+    }
+    Hypotheses blank_hyp({{blanks, 0, context_state}});
+    cur.emplace_back(std::move(blank_hyp));
+  }
+
   int32_t start = 0;
   int32_t t = 0;
   for (auto n : packed_encoder_out.batch_sizes) {
@@ -106,13 +124,21 @@ OfflineTransducerModifiedBeamSearchDecoder::Decode(
         int32_t new_token = k % vocab_size;

         Hypothesis new_hyp = prev[hyp_index];
+        float context_score = 0;
+        auto context_state = new_hyp.context_state;
         if (new_token != 0) {
           // blank id is fixed to 0
           new_hyp.ys.push_back(new_token);
           new_hyp.timestamps.push_back(t);
+          if (context_graphs[i] != nullptr) {
+            auto context_res =
+                context_graphs[i]->ForwardOneStep(context_state, new_token);
+            context_score = context_res.first;
+            new_hyp.context_state = context_res.second;
+          }
         }

-        new_hyp.log_prob = p_logprob[k];
+        new_hyp.log_prob = p_logprob[k] + context_score;

         hyps.Add(std::move(new_hyp));
       }  // for (auto k : topk)
       p_logprob += (end - start) * vocab_size;
@@ -126,6 +152,18 @@ OfflineTransducerModifiedBeamSearchDecoder::Decode(
     cur.push_back(std::move(h));
   }

+  // Finalize the context biasing matching.
+  for (int32_t i = 0; i < cur.size(); ++i) {
+    for (auto iter = cur[i].begin(); iter != cur[i].end(); ++iter) {
+      if (context_graphs[i] != nullptr) {
+        auto context_res =
+            context_graphs[i]->Finalize(iter->second.context_state);
+        iter->second.log_prob += context_res.first;
+        iter->second.context_state = context_res.second;
+      }
+    }
+  }
+
   if (lm_) {
     // use LM for rescoring
     lm_->ComputeLMScore(lm_scale_, context_size, &cur);
diff --git a/sherpa-onnx/csrc/offline-transducer-modified-beam-search-decoder.h b/sherpa-onnx/csrc/offline-transducer-modified-beam-search-decoder.h
index 5f40dc29..89b277c6 100644
--- a/sherpa-onnx/csrc/offline-transducer-modified-beam-search-decoder.h
+++ b/sherpa-onnx/csrc/offline-transducer-modified-beam-search-decoder.h
@@ -26,7 +26,8 @@ class OfflineTransducerModifiedBeamSearchDecoder
         lm_scale_(lm_scale) {}

   std::vector<OfflineTransducerDecoderResult> Decode(
-      Ort::Value encoder_out, Ort::Value encoder_out_length) override;
+      Ort::Value encoder_out, Ort::Value encoder_out_length,
+      OfflineStream **ss = nullptr, int32_t n = 0) override;

  private:
   OfflineTransducerModel *model_;  // Not owned
diff --git a/sherpa-onnx/python/csrc/offline-recognizer.cc b/sherpa-onnx/python/csrc/offline-recognizer.cc
index 7458181c..462d8ba3 100644
--- a/sherpa-onnx/python/csrc/offline-recognizer.cc
+++ b/sherpa-onnx/python/csrc/offline-recognizer.cc
@@ -16,16 +16,17 @@ static void PybindOfflineRecognizerConfig(py::module *m) {
   py::class_<PyClass>(*m, "OfflineRecognizerConfig")
       .def(py::init<const OfflineFeatureExtractorConfig &,
                     const OfflineModelConfig &, const OfflineLMConfig &,
-                    const std::string &, int32_t>(),
+                    const std::string &, int32_t, float>(),
           py::arg("feat_config"), py::arg("model_config"),
           py::arg("lm_config") = OfflineLMConfig(),
           py::arg("decoding_method") = "greedy_search",
-          py::arg("max_active_paths") = 4)
+          py::arg("max_active_paths") = 4, py::arg("context_score") = 1.5)
       .def_readwrite("feat_config", &PyClass::feat_config)
       .def_readwrite("model_config", &PyClass::model_config)
       .def_readwrite("lm_config", &PyClass::lm_config)
       .def_readwrite("decoding_method", &PyClass::decoding_method)
       .def_readwrite("max_active_paths", &PyClass::max_active_paths)
+      .def_readwrite("context_score", &PyClass::context_score)
       .def("__str__", &PyClass::ToString);
 }

@@ -35,10 +36,18 @@ void PybindOfflineRecognizer(py::module *m) {
   using PyClass = OfflineRecognizer;
   py::class_<PyClass>(*m, "OfflineRecognizer")
       .def(py::init<const OfflineRecognizerConfig &>(), py::arg("config"))
-      .def("create_stream", &PyClass::CreateStream)
+      .def("create_stream",
+           [](const PyClass &self) { return self.CreateStream(); })
+      .def(
+          "create_stream",
+          [](PyClass &self,
+             const std::vector<std::vector<int32_t>> &contexts_list) {
+            return self.CreateStream(contexts_list);
+          },
+          py::arg("contexts_list"))
       .def("decode_stream", &PyClass::DecodeStream)
       .def("decode_streams",
-           [](PyClass &self, std::vector<OfflineStream *> ss) {
+           [](const PyClass &self, std::vector<OfflineStream *> ss) {
             self.DecodeStreams(ss.data(), ss.size());
           });
 }
diff --git a/sherpa-onnx/python/sherpa_onnx/__init__.py b/sherpa-onnx/python/sherpa_onnx/__init__.py
index 9865fb87..27c3e549 100644
--- a/sherpa-onnx/python/sherpa_onnx/__init__.py
+++ b/sherpa-onnx/python/sherpa_onnx/__init__.py
@@ -1,5 +1,9 @@
+from typing import Dict, List, Optional
+
 from _sherpa_onnx import Display

 from .online_recognizer import OnlineRecognizer
 from .online_recognizer import OnlineStream
 from .offline_recognizer import OfflineRecognizer
+
+from .utils import encode_contexts
diff --git a/sherpa-onnx/python/sherpa_onnx/offline_recognizer.py b/sherpa-onnx/python/sherpa_onnx/offline_recognizer.py
index 5b3f2a3e..0e1a0494 100644
--- a/sherpa-onnx/python/sherpa_onnx/offline_recognizer.py
+++ b/sherpa-onnx/python/sherpa_onnx/offline_recognizer.py
@@ -1,6 +1,6 @@
 # Copyright (c) 2023 by manyeyes
 from pathlib import Path
-from typing import List
+from typing import List, Optional

 from _sherpa_onnx import (
     OfflineFeatureExtractorConfig,
@@ -39,6 +39,7 @@ class OfflineRecognizer(object):
         sample_rate: int = 16000,
         feature_dim: int = 80,
         decoding_method: str = "greedy_search",
+        context_score: float = 1.5,
         debug: bool = False,
         provider: str = "cpu",
     ):
@@ -96,6 +97,7 @@ class OfflineRecognizer(object):
             feat_config=feat_config,
             model_config=model_config,
             decoding_method=decoding_method,
+            context_score=context_score,
         )
         self.recognizer = _Recognizer(recognizer_config)
         return self
@@ -216,8 +218,11 @@ class OfflineRecognizer(object):
         self.recognizer = _Recognizer(recognizer_config)
         return self

-    def create_stream(self):
-        return self.recognizer.create_stream()
+    def create_stream(self, contexts_list: Optional[List[List[int]]] = None):
+        if contexts_list is None:
+            return self.recognizer.create_stream()
+        else:
+            return self.recognizer.create_stream(contexts_list)

     def decode_stream(self, s: OfflineStream):
         self.recognizer.decode_stream(s)
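End to end, the Python-side flow added in this patch looks roughly like the sketch below. The model paths and token ids are hypothetical placeholders, and from_transducer is assumed to be the classmethod whose signature this hunk extends:

    import numpy as np
    import sherpa_onnx

    recognizer = sherpa_onnx.OfflineRecognizer.from_transducer(
        encoder="encoder.onnx",  # hypothetical model files
        decoder="decoder.onnx",
        joiner="joiner.onnx",
        tokens="tokens.txt",
        decoding_method="modified_beam_search",  # required for biasing
        context_score=1.5,
    )

    tokens_table = {"你": 10, "好": 20, "<unk>": 1}  # toy token ids
    contexts_list = sherpa_onnx.encode_contexts(
        modeling_unit="char", contexts=["你好"], tokens_table=tokens_table
    )

    s = recognizer.create_stream(contexts_list=contexts_list)
    s.accept_waveform(16000, np.zeros(16000, dtype=np.float32))  # 1 s of audio
    recognizer.decode_stream(s)  # the result is then available on the stream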
+ """ + contexts_list = [] + if "bpe" in modeling_unit: + assert sp is not None + if "char" in modeling_unit: + assert tokens_table is not None + assert len(tokens_table) > 0, len(tokens_table) + + if "char" == modeling_unit: + for context in contexts: + assert ' ' not in context + ids = [ + tokens_table[txt] if txt in tokens_table else tokens_table[""] + for txt in context + ] + contexts_list.append(ids) + elif "bpe" == modeling_unit: + contexts_list = sp.encode(contexts, out_type=int) + else: + assert modeling_unit == "bpe+char", modeling_unit + + # CJK(China Japan Korea) unicode range is [U+4E00, U+9FFF], ref: + # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) + pattern = re.compile(r"([\u4e00-\u9fff])") + for context in contexts: + # Example: + # txt = "你好 ITS'S OKAY 的" + # chars = ["你", "好", " ITS'S OKAY ", "的"] + chars = pattern.split(context.upper()) + mix_chars = [w for w in chars if len(w.strip()) > 0] + ids = [] + for ch_or_w in mix_chars: + # ch_or_w is a single CJK charater(i.e., "你"), do nothing. + if pattern.fullmatch(ch_or_w) is not None: + ids.append( + tokens_table[ch_or_w] + if ch_or_w in tokens_table + else tokens_table[""] + ) + # ch_or_w contains non-CJK charaters(i.e., " IT'S OKAY "), + # encode ch_or_w using bpe_model. + else: + for p in sp.encode_as_pieces(ch_or_w): + ids.append( + tokens_table[p] + if p in tokens_table + else tokens_table[""] + ) + contexts_list.append(ids) + return contexts_list