diff --git a/sherpa-onnx/csrc/hypothesis.h b/sherpa-onnx/csrc/hypothesis.h
index b7eb9f26..6a49bad3 100644
--- a/sherpa-onnx/csrc/hypothesis.h
+++ b/sherpa-onnx/csrc/hypothesis.h
@@ -51,8 +51,13 @@ struct Hypothesis {
   // LM log prob if any.
   double lm_log_prob = 0;
 
-  // the nn lm score for next token given the current ys
+  // the nn lm score for next token given the current ys,
+  // when using shallow fusion
   CopyableOrtValue nn_lm_scores;
+
+  // cur scored tokens by RNN LM, when rescoring
+  int32_t cur_scored_pos = 0;
+
   // the nn lm states
   std::vector<Ort::Value> nn_lm_states;
 
diff --git a/sherpa-onnx/csrc/online-lm-config.cc b/sherpa-onnx/csrc/online-lm-config.cc
index 42990f72..9611c7f3 100644
--- a/sherpa-onnx/csrc/online-lm-config.cc
+++ b/sherpa-onnx/csrc/online-lm-config.cc
@@ -18,6 +18,8 @@ void OnlineLMConfig::Register(ParseOptions *po) {
                "Number of threads to run the neural network of LM model");
   po->Register("lm-provider", &lm_provider,
                "Specify a provider to LM model use: cpu, cuda, coreml");
+  po->Register("lm-shallow-fusion", &shallow_fusion,
+               "Boolean whether to use shallow fusion or rescore.");
 }
 
 bool OnlineLMConfig::Validate() const {
@@ -34,7 +36,8 @@ std::string OnlineLMConfig::ToString() const {
 
   os << "OnlineLMConfig(";
   os << "model=\"" << model << "\", ";
-  os << "scale=" << scale << ")";
+  os << "scale=" << scale << ", ";
+  os << "shallow_fusion=" << (shallow_fusion ? "True" : "False") << ")";
 
   return os.str();
 }
diff --git a/sherpa-onnx/csrc/online-lm-config.h b/sherpa-onnx/csrc/online-lm-config.h
index 16d7b088..8d5b1670 100644
--- a/sherpa-onnx/csrc/online-lm-config.h
+++ b/sherpa-onnx/csrc/online-lm-config.h
@@ -18,15 +18,18 @@ struct OnlineLMConfig {
   float scale = 0.5;
   int32_t lm_num_threads = 1;
   std::string lm_provider = "cpu";
+  // enable shallow fusion
+  bool shallow_fusion = true;
 
   OnlineLMConfig() = default;
 
   OnlineLMConfig(const std::string &model, float scale, int32_t lm_num_threads,
-                 const std::string &lm_provider)
+                 const std::string &lm_provider, bool shallow_fusion)
       : model(model),
         scale(scale),
         lm_num_threads(lm_num_threads),
-        lm_provider(lm_provider) {}
+        lm_provider(lm_provider),
+        shallow_fusion(shallow_fusion) {}
 
   void Register(ParseOptions *po);
   bool Validate() const;
diff --git a/sherpa-onnx/csrc/online-lm.h b/sherpa-onnx/csrc/online-lm.h
index 6c73f46c..ffd48c56 100644
--- a/sherpa-onnx/csrc/online-lm.h
+++ b/sherpa-onnx/csrc/online-lm.h
@@ -21,13 +21,17 @@ class OnlineLM {
 
   static std::unique_ptr<OnlineLM> Create(const OnlineLMConfig &config);
 
-  virtual std::pair<Ort::Value, std::vector<Ort::Value>> GetInitStates() = 0;
+  // init states for classic rescore
+  virtual std::vector<Ort::Value> GetInitStates() = 0;
 
-  /** ScoreToken a batch of sentences.
+  // init states for shallow fusion
+  virtual std::pair<Ort::Value, std::vector<Ort::Value>> GetInitStatesSF() = 0;
+
+  /** ScoreToken a batch of sentences (shallow fusion).
    *
    * @param x A 2-D tensor of shape (N, 1) with data type int64.
    * @param states It contains the states for the LM model
-   * @return Return a pair containingo
+   * @return Return a pair containing
    *         - log_prob of NN LM
    *         - updated states
    *
@@ -35,13 +39,23 @@ class OnlineLM {
   virtual std::pair<Ort::Value, std::vector<Ort::Value>> ScoreToken(
       Ort::Value x, std::vector<Ort::Value> states) = 0;
 
-  /** This function updates lm_lob_prob and nn_lm_scores of hyp
+  /** This function updates hyp.lm_log_prob of hyps (classic rescore).
+   *
+   * @param scale LM score
+   * @param context_size Context size of the transducer decoder model
+   * @param hyps It is changed in-place.
+   *
+   */
+  virtual void ComputeLMScore(float scale, int32_t context_size,
+                              std::vector<Hypotheses> *hyps) = 0;
+
+  /** This function updates lm_log_prob and nn_lm_scores of hyp (shallow fusion).
    *
    * @param scale LM score
    * @param hyps It is changed in-place.
    *
    */
-  virtual void ComputeLMScore(float scale, Hypothesis *hyp) = 0;
+  virtual void ComputeLMScoreSF(float scale, Hypothesis *hyp) = 0;
 };
 
 }  // namespace sherpa_onnx
diff --git a/sherpa-onnx/csrc/online-recognizer-transducer-impl.h b/sherpa-onnx/csrc/online-recognizer-transducer-impl.h
index 2bea765c..ab1e165f 100644
--- a/sherpa-onnx/csrc/online-recognizer-transducer-impl.h
+++ b/sherpa-onnx/csrc/online-recognizer-transducer-impl.h
@@ -107,7 +107,8 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
       decoder_ = std::make_unique<OnlineTransducerModifiedBeamSearchDecoder>(
           model_.get(), lm_.get(), config_.max_active_paths,
-          config_.lm_config.scale, unk_id_, config_.blank_penalty,
+          config_.lm_config.scale, config_.lm_config.shallow_fusion, unk_id_,
+          config_.blank_penalty,
           config_.temperature_scale);
     } else if (config.decoding_method == "greedy_search") {
@@ -156,7 +157,8 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
       decoder_ = std::make_unique<OnlineTransducerModifiedBeamSearchDecoder>(
           model_.get(), lm_.get(), config_.max_active_paths,
-          config_.lm_config.scale, unk_id_, config_.blank_penalty,
+          config_.lm_config.scale, config_.lm_config.shallow_fusion, unk_id_,
+          config_.blank_penalty,
           config_.temperature_scale);
     } else if (config.decoding_method == "greedy_search") {
diff --git a/sherpa-onnx/csrc/online-rnn-lm.cc b/sherpa-onnx/csrc/online-rnn-lm.cc
index 5f938529..1b13d3a2 100644
--- a/sherpa-onnx/csrc/online-rnn-lm.cc
+++ b/sherpa-onnx/csrc/online-rnn-lm.cc
@@ -8,6 +8,7 @@
 #include
 #include
 #include
+#include
 
 #include "onnxruntime_cxx_api.h"  // NOLINT
 #include "sherpa-onnx/csrc/macros.h"
@@ -27,9 +28,10 @@ class OnlineRnnLM::Impl {
     Init(config);
   }
 
-  void ComputeLMScore(float scale, Hypothesis *hyp) {
+  // shallow fusion scoring function
+  void ComputeLMScoreSF(float scale, Hypothesis *hyp) {
     if (hyp->nn_lm_states.empty()) {
-      auto init_states = GetInitStates();
+      auto init_states = GetInitStatesSF();
       hyp->nn_lm_scores.value = std::move(init_states.first);
       hyp->nn_lm_states = Convert(std::move(init_states.second));
     }
@@ -49,6 +51,52 @@ class OnlineRnnLM::Impl {
     hyp->nn_lm_states = Convert(std::move(lm_out.second));
   }
 
+  // classic rescore function
+  void ComputeLMScore(float scale, int32_t context_size,
+                      std::vector<Hypotheses> *hyps) {
+    Ort::AllocatorWithDefaultOptions allocator;
+
+    for (auto &hyp : *hyps) {
+      for (auto &h_m : hyp) {
+        auto &h = h_m.second;
+        auto &ys = h.ys;
+        const int32_t token_num_in_chunk =
+            ys.size() - context_size - h.cur_scored_pos - 1;
+
+        if (token_num_in_chunk < 1) {
+          continue;
+        }
+
+        if (h.nn_lm_states.empty()) {
+          h.nn_lm_states = Convert(GetInitStates());
+        }
+
+        if (token_num_in_chunk >= h.lm_rescore_min_chunk) {
+          std::array<int64_t, 2> x_shape{1, token_num_in_chunk};
+
+          Ort::Value x = Ort::Value::CreateTensor<int64_t>(
+              allocator, x_shape.data(), x_shape.size());
+          int64_t *p_x = x.GetTensorMutableData<int64_t>();
+          std::copy(ys.begin() + context_size + h.cur_scored_pos,
+                    ys.end() - 1, p_x);
+
+          // streaming forward by NN LM
+          auto out = ScoreToken(std::move(x),
+                                Convert(std::move(h.nn_lm_states)));
+
+          // update NN LM score in hyp
+          const float *p_nll = out.first.GetTensorData<float>();
+          h.lm_log_prob = -scale * (*p_nll);
+
+          // update NN LM states in hyp
+          h.nn_lm_states = Convert(std::move(out.second));
+
+          h.cur_scored_pos += token_num_in_chunk;
+        }
+      }
+    }
+  }
+
   std::pair<Ort::Value, std::vector<Ort::Value>> ScoreToken(
       Ort::Value x, std::vector<Ort::Value> states) {
     std::array<Ort::Value, 3> inputs = {std::move(x), std::move(states[0]),
@@ -66,7 +114,8 @@ class OnlineRnnLM::Impl {
     return {std::move(out[0]), std::move(next_states)};
   }
 
-  std::pair<Ort::Value, std::vector<Ort::Value>> GetInitStates() {
+  // get init states for shallow fusion
+  std::pair<Ort::Value, std::vector<Ort::Value>> GetInitStatesSF() {
     std::vector<Ort::Value> ans;
     ans.reserve(init_states_.size());
     for (auto &s : init_states_) {
@@ -75,6 +124,18 @@ class OnlineRnnLM::Impl {
     return {View(&init_scores_.value), std::move(ans)};
   }
 
+  // get init states for classic rescore
+  std::vector<Ort::Value> GetInitStates() const {
+    std::vector<Ort::Value> ans;
+    ans.reserve(init_states_.size());
+
+    for (const auto &s : init_states_) {
+      ans.emplace_back(Clone(allocator_, &s));
+    }
+
+    return ans;
+  }
+
  private:
   void Init(const OnlineLMConfig &config) {
     auto buf = ReadFile(config_.model);
@@ -116,7 +177,8 @@ class OnlineRnnLM::Impl {
     states.push_back(std::move(c));
 
     auto pair = ScoreToken(std::move(x), std::move(states));
-    init_scores_.value = std::move(pair.first);
+    init_scores_.value = std::move(pair.first);  // only used during
+                                                 // shallow fusion
     init_states_ = std::move(pair.second);
   }
 
@@ -147,17 +209,31 @@ OnlineRnnLM::OnlineRnnLM(const OnlineLMConfig &config)
 
 OnlineRnnLM::~OnlineRnnLM() = default;
 
-std::pair<Ort::Value, std::vector<Ort::Value>> OnlineRnnLM::GetInitStates() {
+// classic rescore state init
+std::vector<Ort::Value> OnlineRnnLM::GetInitStates() {
   return impl_->GetInitStates();
 }
 
+// shallow fusion state init
+std::pair<Ort::Value, std::vector<Ort::Value>> OnlineRnnLM::GetInitStatesSF() {
+  return impl_->GetInitStatesSF();
+}
+
 std::pair<Ort::Value, std::vector<Ort::Value>> OnlineRnnLM::ScoreToken(
     Ort::Value x, std::vector<Ort::Value> states) {
   return impl_->ScoreToken(std::move(x), std::move(states));
 }
 
-void OnlineRnnLM::ComputeLMScore(float scale, Hypothesis *hyp) {
-  return impl_->ComputeLMScore(scale, hyp);
+// classic rescore scores
+void OnlineRnnLM::ComputeLMScore(float scale, int32_t context_size,
+                                 std::vector<Hypotheses> *hyps) {
+  return impl_->ComputeLMScore(scale, context_size, hyps);
 }
 
+// shallow fusion scores
+void OnlineRnnLM::ComputeLMScoreSF(float scale, Hypothesis *hyp) {
+  return impl_->ComputeLMScoreSF(scale, hyp);
+}
+
+
 }  // namespace sherpa_onnx
diff --git a/sherpa-onnx/csrc/online-rnn-lm.h b/sherpa-onnx/csrc/online-rnn-lm.h
index dee17f73..2e7ad70e 100644
--- a/sherpa-onnx/csrc/online-rnn-lm.h
+++ b/sherpa-onnx/csrc/online-rnn-lm.h
@@ -22,13 +22,17 @@ class OnlineRnnLM : public OnlineLM {
 
   explicit OnlineRnnLM(const OnlineLMConfig &config);
 
-  std::pair<Ort::Value, std::vector<Ort::Value>> GetInitStates() override;
+  // init scores for classic rescore
+  std::vector<Ort::Value> GetInitStates() override;
 
-  /** ScoreToken a batch of sentences.
+  // init scores for shallow fusion
+  std::pair<Ort::Value, std::vector<Ort::Value>> GetInitStatesSF() override;
+
+  /** ScoreToken a batch of sentences (shallow fusion).
    *
    * @param x A 2-D tensor of shape (N, L) with data type int64.
    * @param states It contains the states for the LM model
-   * @return Return a pair containingo
+   * @return Return a pair containing
    *         - log_prob of NN LM
    *         - updated states
    *
@@ -36,13 +40,23 @@ class OnlineRnnLM : public OnlineLM {
   std::pair<Ort::Value, std::vector<Ort::Value>> ScoreToken(
       Ort::Value x, std::vector<Ort::Value> states) override;
 
-  /** This function updates lm_lob_prob and nn_lm_scores of hyp
+  /** This function updates hyp.lm_log_prob of hyps (classic rescore).
+   *
+   * @param scale LM score
+   * @param context_size Context size of the transducer decoder model
+   * @param hyps It is changed in-place.
+   *
+   */
+  void ComputeLMScore(float scale, int32_t context_size,
+                      std::vector<Hypotheses> *hyps) override;
+
+  /** This function updates lm_log_prob and nn_lm_scores of hyp (shallow fusion).
    *
    * @param scale LM score
    * @param hyps It is changed in-place.
    *
    */
-  void ComputeLMScore(float scale, Hypothesis *hyp) override;
+  void ComputeLMScoreSF(float scale, Hypothesis *hyp) override;
 
  private:
   class Impl;
diff --git a/sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.cc b/sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.cc
index a49a8bbd..5ad11f77 100644
--- a/sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.cc
+++ b/sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.cc
@@ -156,7 +156,11 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
 
     // add log_prob of each hypothesis to p_logprob before taking top_k
     for (int32_t i = 0; i != num_hyps; ++i) {
-      float log_prob = prev[i].log_prob + prev[i].lm_log_prob;
+      float log_prob = prev[i].log_prob;
+      if (lm_ && shallow_fusion_) {
+        log_prob += prev[i].lm_log_prob;
+      }
+
       for (int32_t k = 0; k != vocab_size; ++k, ++p_logprob) {
         *p_logprob += log_prob;
       }
@@ -192,22 +196,31 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
          context_score = std::get<0>(context_res);
          new_hyp.context_state = std::get<1>(context_res);
        }
-        if (lm_) {
-          lm_->ComputeLMScore(lm_scale_, &new_hyp);
+        if (lm_ && shallow_fusion_) {
+          lm_->ComputeLMScoreSF(lm_scale_, &new_hyp);
        }
      } else {
        ++new_hyp.num_trailing_blanks;
      }
-      new_hyp.log_prob = p_logprob[k] + context_score -
+      if (lm_ && shallow_fusion_) {
+        new_hyp.log_prob = p_logprob[k] + context_score -
                          prev_lm_log_prob;  // log_prob only includes the
                                             // score of the transducer
+      } else {
+        new_hyp.log_prob = p_logprob[k] + context_score;  // rescore or no LM
+                                                          // previous token
+                                                          // score is ignored
+      }
+
 
      // export the per-token log scores
      if (new_token != 0 && new_token != unk_id_) {
        float y_prob = logit_with_temperature[start * vocab_size + k];
        new_hyp.ys_probs.push_back(y_prob);
 
-        if (lm_) {  // export only when LM is used
+        if (lm_ && shallow_fusion_) {  // export only if
+                                       // LM shallow fusion is used
          float lm_prob = new_hyp.lm_log_prob - prev_lm_log_prob;
+          if (lm_scale_ != 0.0) {
            lm_prob /= lm_scale_;  // remove lm-scale
          }
@@ -227,6 +240,11 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
    }  // for (int32_t b = 0; b != batch_size; ++b)
  }    // for (int32_t t = 0; t != num_frames; ++t)
 
+  // classic lm rescore
+  if (lm_ && !shallow_fusion_) {
+    lm_->ComputeLMScore(lm_scale_, model_->ContextSize(), &cur);
+  }
+
  for (int32_t b = 0; b != batch_size; ++b) {
    auto &hyps = cur[b];
    auto best_hyp = hyps.GetMostProbable(true);
diff --git a/sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.h b/sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.h
index 839aa768..6dea71be 100644
--- a/sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.h
+++ b/sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.h
@@ -21,13 +21,16 @@ class OnlineTransducerModifiedBeamSearchDecoder
   OnlineTransducerModifiedBeamSearchDecoder(OnlineTransducerModel *model,
                                             OnlineLM *lm,
                                             int32_t max_active_paths,
-                                            float lm_scale, int32_t unk_id,
+                                            float lm_scale,
+                                            bool shallow_fusion,
+                                            int32_t unk_id,
                                             float blank_penalty,
                                             float temperature_scale)
       : model_(model),
         lm_(lm),
         max_active_paths_(max_active_paths),
         lm_scale_(lm_scale),
+        shallow_fusion_(shallow_fusion),
         unk_id_(unk_id),
         blank_penalty_(blank_penalty),
         temperature_scale_(temperature_scale) {}
@@ -50,6 +53,7 @@ class OnlineTransducerModifiedBeamSearchDecoder
 
   int32_t max_active_paths_;
   float lm_scale_;  // used only when lm_ is not nullptr
+  bool shallow_fusion_;  // used only when lm_ is not nullptr
   int32_t unk_id_;
   float blank_penalty_;
   float temperature_scale_;
diff --git a/sherpa-onnx/python/csrc/online-lm-config.cc b/sherpa-onnx/python/csrc/online-lm-config.cc
index 56da7399..0e9a0385 100644
--- a/sherpa-onnx/python/csrc/online-lm-config.cc
+++ b/sherpa-onnx/python/csrc/online-lm-config.cc
@@ -13,13 +13,16 @@
 void PybindOnlineLMConfig(py::module *m) {
   using PyClass = OnlineLMConfig;
   py::class_<PyClass>(*m, "OnlineLMConfig")
-      .def(py::init<const std::string &, float, int32_t, const std::string &>(),
+      .def(py::init<const std::string &, float, int32_t, const std::string &,
+                    bool>(),
           py::arg("model") = "", py::arg("scale") = 0.5f,
-          py::arg("lm_num_threads") = 1, py::arg("lm_provider") = "cpu")
+          py::arg("lm_num_threads") = 1, py::arg("lm_provider") = "cpu",
+          py::arg("shallow_fusion") = true)
      .def_readwrite("model", &PyClass::model)
      .def_readwrite("scale", &PyClass::scale)
      .def_readwrite("lm_provider", &PyClass::lm_provider)
      .def_readwrite("lm_num_threads", &PyClass::lm_num_threads)
+     .def_readwrite("shallow_fusion", &PyClass::shallow_fusion)
      .def("__str__", &PyClass::ToString);
 }
diff --git a/sherpa-onnx/python/sherpa_onnx/online_recognizer.py b/sherpa-onnx/python/sherpa_onnx/online_recognizer.py
index 779ba6e7..321f1cdf 100644
--- a/sherpa-onnx/python/sherpa_onnx/online_recognizer.py
+++ b/sherpa-onnx/python/sherpa_onnx/online_recognizer.py
@@ -64,6 +64,7 @@ class OnlineRecognizer(object):
         bpe_vocab: str = "",
         lm: str = "",
         lm_scale: float = 0.1,
+        lm_shallow_fusion: bool = True,
         temperature_scale: float = 2.0,
         debug: bool = False,
         rule_fsts: str = "",
@@ -274,6 +275,7 @@ class OnlineRecognizer(object):
         lm_config = OnlineLMConfig(
             model=lm,
             scale=lm_scale,
+            shallow_fusion=lm_shallow_fusion,
         )
 
         recognizer_config = OnlineRecognizerConfig(
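
Usage note (not part of the patch): the sketch below shows how the new flag is intended to be driven from Python, assuming the usual OnlineRecognizer.from_transducer() factory whose keyword arguments are extended at the end of this diff; the model file paths are placeholders.

import sherpa_onnx

# lm_shallow_fusion=True  -> LM log-probs are fused into the beam search at
#                            every emitted token (ComputeLMScoreSF path)
# lm_shallow_fusion=False -> hypotheses are rescored chunk-by-chunk by the
#                            RNN LM after the frame loop (ComputeLMScore path)
recognizer = sherpa_onnx.OnlineRecognizer.from_transducer(
    tokens="tokens.txt",
    encoder="encoder.onnx",
    decoder="decoder.onnx",
    joiner="joiner.onnx",
    decoding_method="modified_beam_search",  # the LM is only consulted here
    lm="rnn-lm.onnx",
    lm_scale=0.1,
    lm_shallow_fusion=False,  # forwarded to OnlineLMConfig.shallow_fusion
)

With shallow fusion disabled, the decoder no longer adds per-token LM scores during the search; instead it calls lm_->ComputeLMScore() once after the frame loop, which scores whole chunks of newly emitted tokens and tracks progress via Hypothesis::cur_scored_pos.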