From 824b0809a47a412a2b879c8970fb999840eed728 Mon Sep 17 00:00:00 2001 From: PF Luo Date: Wed, 10 May 2023 22:30:57 +0800 Subject: [PATCH] add shallow fusion (#147) --- sherpa-onnx/csrc/hypothesis.h | 8 +-- sherpa-onnx/csrc/online-lm.cc | 72 ------------------- sherpa-onnx/csrc/online-lm.h | 28 ++++---- sherpa-onnx/csrc/online-rnn-lm.cc | 57 ++++++++++----- sherpa-onnx/csrc/online-rnn-lm.h | 20 ++++-- ...transducer-modified-beam-search-decoder.cc | 12 ++-- sherpa-onnx/csrc/onnx-utils.cc | 22 ++++++ sherpa-onnx/csrc/onnx-utils.h | 8 ++- sherpa-onnx/csrc/sherpa-onnx.cc | 2 +- 9 files changed, 104 insertions(+), 125 deletions(-) diff --git a/sherpa-onnx/csrc/hypothesis.h b/sherpa-onnx/csrc/hypothesis.h index a0097f52..98cc50f2 100644 --- a/sherpa-onnx/csrc/hypothesis.h +++ b/sherpa-onnx/csrc/hypothesis.h @@ -34,13 +34,11 @@ struct Hypothesis { // LM log prob if any. double lm_log_prob = 0; - int32_t cur_scored_pos = 0; // cur scored tokens by RNN LM + // the nn lm score for next token given the current ys + CopyableOrtValue nn_lm_scores; + // the nn lm states std::vector nn_lm_states; - // TODO(fangjun): Make it configurable - // the minimum of tokens in a chunk for streaming RNN LM - int32_t lm_rescore_min_chunk = 2; // a const - int32_t num_trailing_blanks = 0; Hypothesis() = default; diff --git a/sherpa-onnx/csrc/online-lm.cc b/sherpa-onnx/csrc/online-lm.cc index 11283e11..dfec00cc 100644 --- a/sherpa-onnx/csrc/online-lm.cc +++ b/sherpa-onnx/csrc/online-lm.cc @@ -13,80 +13,8 @@ namespace sherpa_onnx { -static std::vector Convert(std::vector values) { - std::vector ans; - ans.reserve(values.size()); - - for (auto &v : values) { - ans.emplace_back(std::move(v)); - } - - return ans; -} - -static std::vector Convert(std::vector values) { - std::vector ans; - ans.reserve(values.size()); - - for (auto &v : values) { - ans.emplace_back(std::move(v.value)); - } - - return ans; -} - std::unique_ptr OnlineLM::Create(const OnlineLMConfig &config) { return 
std::make_unique(config); } -void OnlineLM::ComputeLMScore(float scale, int32_t context_size, - std::vector *hyps) { - Ort::AllocatorWithDefaultOptions allocator; - - for (auto &hyp : *hyps) { - for (auto &h_m : hyp) { - auto &h = h_m.second; - auto &ys = h.ys; - const int32_t token_num_in_chunk = - ys.size() - context_size - h.cur_scored_pos - 1; - - if (token_num_in_chunk < 1) { - continue; - } - - if (h.nn_lm_states.empty()) { - h.nn_lm_states = Convert(GetInitStates()); - } - - if (token_num_in_chunk >= h.lm_rescore_min_chunk) { - std::array x_shape{1, token_num_in_chunk}; - // shape of x and y are same - Ort::Value x = Ort::Value::CreateTensor( - allocator, x_shape.data(), x_shape.size()); - Ort::Value y = Ort::Value::CreateTensor( - allocator, x_shape.data(), x_shape.size()); - int64_t *p_x = x.GetTensorMutableData(); - int64_t *p_y = y.GetTensorMutableData(); - std::copy(ys.begin() + context_size + h.cur_scored_pos, ys.end() - 1, - p_x); - std::copy(ys.begin() + context_size + h.cur_scored_pos + 1, ys.end(), - p_y); - - // streaming forward by NN LM - auto out = Rescore(std::move(x), std::move(y), - Convert(std::move(h.nn_lm_states))); - - // update NN LM score in hyp - const float *p_nll = out.first.GetTensorData(); - h.lm_log_prob = -scale * (*p_nll); - - // update NN LM states in hyp - h.nn_lm_states = Convert(std::move(out.second)); - - h.cur_scored_pos += token_num_in_chunk; - } - } - } -} - } // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/online-lm.h b/sherpa-onnx/csrc/online-lm.h index cc4a7de6..6c73f46c 100644 --- a/sherpa-onnx/csrc/online-lm.h +++ b/sherpa-onnx/csrc/online-lm.h @@ -21,29 +21,27 @@ class OnlineLM { static std::unique_ptr Create(const OnlineLMConfig &config); - virtual std::vector GetInitStates() = 0; + virtual std::pair> GetInitStates() = 0; - /** Rescore a batch of sentences. + /** ScoreToken a batch of sentences. * - * @param x A 2-D tensor of shape (N, L) with data type int64. 
- * @param y A 2-D tensor of shape (N, L) with data type int64. + * @param x A 2-D tensor of shape (N, 1) with data type int64. * @param states It contains the states for the LM model * @return Return a pair containingo - - negative loglike + - log_prob of NN LM * - updated states * - * Caution: It returns negative log likelihood (nll), not log likelihood */ - virtual std::pair> Rescore( - Ort::Value x, Ort::Value y, std::vector states) = 0; + virtual std::pair> ScoreToken( + Ort::Value x, std::vector states) = 0; - // This function updates hyp.lm_lob_prob of hyps. - // - // @param scale LM score - // @param context_size Context size of the transducer decoder model - // @param hyps It is changed in-place. - void ComputeLMScore(float scale, int32_t context_size, - std::vector *hyps); + /** This function updates lm_log_prob and nn_lm_scores of hyp + * + * @param scale LM score + * @param hyp It is changed in-place. + * + */ + virtual void ComputeLMScore(float scale, Hypothesis *hyp) = 0; }; } // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/online-rnn-lm.cc b/sherpa-onnx/csrc/online-rnn-lm.cc index 611e0c40..ad8426fb 100644 --- a/sherpa-onnx/csrc/online-rnn-lm.cc +++ b/sherpa-onnx/csrc/online-rnn-lm.cc @@ -26,10 +26,33 @@ class OnlineRnnLM::Impl { Init(config); } - std::pair> Rescore( - Ort::Value x, Ort::Value y, std::vector states) { - std::array inputs = { - std::move(x), std::move(y), std::move(states[0]), std::move(states[1])}; + void ComputeLMScore(float scale, Hypothesis *hyp) { + if (hyp->nn_lm_states.empty()) { + auto init_states = GetInitStates(); + hyp->nn_lm_scores.value = std::move(init_states.first); + hyp->nn_lm_states = Convert(std::move(init_states.second)); + } + + // get lm score for cur token given the hyp->ys[:-1] and save to lm_log_prob + const float *nn_lm_scores = hyp->nn_lm_scores.value.GetTensorData(); + hyp->lm_log_prob = nn_lm_scores[hyp->ys.back()] * scale; + + // get lm scores for next tokens given the hyp->ys[:] and save to + 
// nn_lm_scores + std::array x_shape{1, 1}; + Ort::Value x = Ort::Value::CreateTensor(allocator_, x_shape.data(), + x_shape.size()); + *x.GetTensorMutableData() = hyp->ys.back(); + auto lm_out = + ScoreToken(std::move(x), Convert(hyp->nn_lm_states)); + hyp->nn_lm_scores.value = std::move(lm_out.first); + hyp->nn_lm_states = Convert(std::move(lm_out.second)); + } + + std::pair> ScoreToken( + Ort::Value x, std::vector states) { + std::array inputs = {std::move(x), std::move(states[0]), + std::move(states[1])}; auto out = sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(), @@ -43,15 +66,13 @@ class OnlineRnnLM::Impl { return {std::move(out[0]), std::move(next_states)}; } - std::vector GetInitStates() const { + std::pair> GetInitStates() const { std::vector ans; ans.reserve(init_states_.size()); - for (const auto &s : init_states_) { ans.emplace_back(Clone(allocator_, &s)); } - - return ans; + return {std::move(Clone(allocator_, &init_scores_.value)), std::move(ans)}; } private: @@ -86,19 +107,16 @@ class OnlineRnnLM::Impl { Fill(&h, 0); Fill(&c, 0); std::array x_shape{1, 1}; - // shape of x and y are same Ort::Value x = Ort::Value::CreateTensor(allocator_, x_shape.data(), x_shape.size()); - Ort::Value y = Ort::Value::CreateTensor(allocator_, x_shape.data(), - x_shape.size()); *x.GetTensorMutableData() = sos_id_; - *y.GetTensorMutableData() = sos_id_; std::vector states; states.push_back(std::move(h)); states.push_back(std::move(c)); - auto pair = Rescore(std::move(x), std::move(y), std::move(states)); + auto pair = ScoreToken(std::move(x), std::move(states)); + init_scores_.value = std::move(pair.first); init_states_ = std::move(pair.second); } @@ -116,6 +134,7 @@ class OnlineRnnLM::Impl { std::vector output_names_; std::vector output_names_ptr_; + CopyableOrtValue init_scores_; std::vector init_states_; int32_t rnn_num_layers_ = 2; @@ -128,13 +147,17 @@ OnlineRnnLM::OnlineRnnLM(const OnlineLMConfig &config) OnlineRnnLM::~OnlineRnnLM() = default; 
-std::vector OnlineRnnLM::GetInitStates() { +std::pair> OnlineRnnLM::GetInitStates() { return impl_->GetInitStates(); } -std::pair> OnlineRnnLM::Rescore( - Ort::Value x, Ort::Value y, std::vector states) { - return impl_->Rescore(std::move(x), std::move(y), std::move(states)); +std::pair> OnlineRnnLM::ScoreToken( + Ort::Value x, std::vector states) { + return impl_->ScoreToken(std::move(x), std::move(states)); +} + +void OnlineRnnLM::ComputeLMScore(float scale, Hypothesis *hyp) { + return impl_->ComputeLMScore(scale, hyp); } } // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/online-rnn-lm.h b/sherpa-onnx/csrc/online-rnn-lm.h index fcb2b17e..dee17f73 100644 --- a/sherpa-onnx/csrc/online-rnn-lm.h +++ b/sherpa-onnx/csrc/online-rnn-lm.h @@ -22,21 +22,27 @@ class OnlineRnnLM : public OnlineLM { explicit OnlineRnnLM(const OnlineLMConfig &config); - std::vector GetInitStates() override; + std::pair> GetInitStates() override; - /** Rescore a batch of sentences. + /** ScoreToken a batch of sentences. * * @param x A 2-D tensor of shape (N, L) with data type int64. - * @param y A 2-D tensor of shape (N, L) with data type int64. * @param states It contains the states for the LM model * @return Return a pair containingo - - negative loglike + - log_prob of NN LM * - updated states * - * Caution: It returns negative log likelihood (nll), not log likelihood */ - std::pair> Rescore( - Ort::Value x, Ort::Value y, std::vector states) override; + std::pair> ScoreToken( + Ort::Value x, std::vector states) override; + + /** This function updates lm_log_prob and nn_lm_scores of hyp + * + * @param scale LM score + * @param hyp It is changed in-place. 
+ * + */ + void ComputeLMScore(float scale, Hypothesis *hyp) override; private: class Impl; diff --git a/sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.cc b/sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.cc index 1a0cf760..dc599cec 100644 --- a/sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.cc +++ b/sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.cc @@ -121,7 +121,7 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode( // add log_prob of each hypothesis to p_logprob before taking top_k for (int32_t i = 0; i != num_hyps; ++i) { - float log_prob = prev[i].log_prob; + float log_prob = prev[i].log_prob + prev[i].lm_log_prob; for (int32_t k = 0; k != vocab_size; ++k, ++p_logprob) { *p_logprob += log_prob; } @@ -141,14 +141,18 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode( int32_t new_token = k % vocab_size; Hypothesis new_hyp = prev[hyp_index]; + const float prev_lm_log_prob = new_hyp.lm_log_prob; if (new_token != 0) { new_hyp.ys.push_back(new_token); new_hyp.timestamps.push_back(t + frame_offset); new_hyp.num_trailing_blanks = 0; + if (lm_) { + lm_->ComputeLMScore(lm_scale_, &new_hyp); + } } else { ++new_hyp.num_trailing_blanks; } - new_hyp.log_prob = p_logprob[k]; + new_hyp.log_prob = p_logprob[k] - prev_lm_log_prob; hyps.Add(std::move(new_hyp)); } // for (auto k : topk) cur.push_back(std::move(hyps)); @@ -156,10 +160,6 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode( } // for (int32_t b = 0; b != batch_size; ++b) } - if (lm_) { - lm_->ComputeLMScore(lm_scale_, model_->ContextSize(), &cur); - } - for (int32_t b = 0; b != batch_size; ++b) { auto &hyps = cur[b]; auto best_hyp = hyps.GetMostProbable(true); diff --git a/sherpa-onnx/csrc/onnx-utils.cc b/sherpa-onnx/csrc/onnx-utils.cc index 883f5afd..99ca4416 100644 --- a/sherpa-onnx/csrc/onnx-utils.cc +++ b/sherpa-onnx/csrc/onnx-utils.cc @@ -245,4 +245,26 @@ CopyableOrtValue &CopyableOrtValue::operator=(CopyableOrtValue &&other) { 
return *this; } +std::vector Convert(std::vector values) { + std::vector ans; + ans.reserve(values.size()); + + for (auto &v : values) { + ans.emplace_back(std::move(v)); + } + + return ans; +} + +std::vector Convert(std::vector values) { + std::vector ans; + ans.reserve(values.size()); + + for (auto &v : values) { + ans.emplace_back(std::move(v.value)); + } + + return ans; +} + } // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/onnx-utils.h b/sherpa-onnx/csrc/onnx-utils.h index 113fa5de..34ebc92e 100644 --- a/sherpa-onnx/csrc/onnx-utils.h +++ b/sherpa-onnx/csrc/onnx-utils.h @@ -97,8 +97,8 @@ struct CopyableOrtValue { CopyableOrtValue() = default; - /*explicit*/ CopyableOrtValue(Ort::Value v) // NOLINT - : value(std::move(v)) {} + /*explicit*/ CopyableOrtValue(Ort::Value v) // NOLINT + : value(std::move(v)) {} CopyableOrtValue(const CopyableOrtValue &other); @@ -109,6 +109,10 @@ struct CopyableOrtValue { CopyableOrtValue &operator=(CopyableOrtValue &&other); }; +std::vector Convert(std::vector values); + +std::vector Convert(std::vector values); + } // namespace sherpa_onnx #endif // SHERPA_ONNX_CSRC_ONNX_UTILS_H_ diff --git a/sherpa-onnx/csrc/sherpa-onnx.cc b/sherpa-onnx/csrc/sherpa-onnx.cc index 5cb6ca3e..0fca4a25 100644 --- a/sherpa-onnx/csrc/sherpa-onnx.cc +++ b/sherpa-onnx/csrc/sherpa-onnx.cc @@ -94,7 +94,7 @@ for a list of pre-trained models to download. auto s = recognizer.CreateStream(); s->AcceptWaveform(sampling_rate, samples.data(), samples.size()); - std::vector tail_paddings(static_cast(0.2 * sampling_rate)); + std::vector tail_paddings(static_cast(0.5 * sampling_rate)); // Note: We can call AcceptWaveform() multiple times. s->AcceptWaveform(sampling_rate, tail_paddings.data(), tail_paddings.size());