Re-implement LM rescore for online transducer (#1231)
Co-authored-by: Martins Kronis <martins.kuznecovs@tilde.lv>
This commit is contained in:
@@ -8,6 +8,7 @@
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
#include <algorithm>
|
||||
|
||||
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
@@ -27,9 +28,10 @@ class OnlineRnnLM::Impl {
|
||||
Init(config);
|
||||
}
|
||||
|
||||
void ComputeLMScore(float scale, Hypothesis *hyp) {
|
||||
// shallow fusion scoring function
|
||||
void ComputeLMScoreSF(float scale, Hypothesis *hyp) {
|
||||
if (hyp->nn_lm_states.empty()) {
|
||||
auto init_states = GetInitStates();
|
||||
auto init_states = GetInitStatesSF();
|
||||
hyp->nn_lm_scores.value = std::move(init_states.first);
|
||||
hyp->nn_lm_states = Convert(std::move(init_states.second));
|
||||
}
|
||||
@@ -49,6 +51,52 @@ class OnlineRnnLM::Impl {
|
||||
hyp->nn_lm_states = Convert(std::move(lm_out.second));
|
||||
}
|
||||
|
||||
// classic rescore function
|
||||
void ComputeLMScore(float scale, int32_t context_size,
|
||||
std::vector<Hypotheses> *hyps) {
|
||||
Ort::AllocatorWithDefaultOptions allocator;
|
||||
|
||||
for (auto &hyp : *hyps) {
|
||||
for (auto &h_m : hyp) {
|
||||
auto &h = h_m.second;
|
||||
auto &ys = h.ys;
|
||||
const int32_t token_num_in_chunk =
|
||||
ys.size() - context_size - h.cur_scored_pos - 1;
|
||||
|
||||
if (token_num_in_chunk < 1) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (h.nn_lm_states.empty()) {
|
||||
h.nn_lm_states = Convert(GetInitStates());
|
||||
}
|
||||
|
||||
if (token_num_in_chunk >= h.lm_rescore_min_chunk) {
|
||||
std::array<int64_t, 2> x_shape{1, token_num_in_chunk};
|
||||
|
||||
Ort::Value x = Ort::Value::CreateTensor<int64_t>(
|
||||
allocator, x_shape.data(), x_shape.size());
|
||||
int64_t *p_x = x.GetTensorMutableData<int64_t>();
|
||||
std::copy(ys.begin() + context_size + h.cur_scored_pos,
|
||||
ys.end() - 1, p_x);
|
||||
|
||||
// streaming forward by NN LM
|
||||
auto out = ScoreToken(std::move(x),
|
||||
Convert(std::move(h.nn_lm_states)));
|
||||
|
||||
// update NN LM score in hyp
|
||||
const float *p_nll = out.first.GetTensorData<float>();
|
||||
h.lm_log_prob = -scale * (*p_nll);
|
||||
|
||||
// update NN LM states in hyp
|
||||
h.nn_lm_states = Convert(std::move(out.second));
|
||||
|
||||
h.cur_scored_pos += token_num_in_chunk;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::pair<Ort::Value, std::vector<Ort::Value>> ScoreToken(
|
||||
Ort::Value x, std::vector<Ort::Value> states) {
|
||||
std::array<Ort::Value, 3> inputs = {std::move(x), std::move(states[0]),
|
||||
@@ -66,7 +114,8 @@ class OnlineRnnLM::Impl {
|
||||
return {std::move(out[0]), std::move(next_states)};
|
||||
}
|
||||
|
||||
std::pair<Ort::Value, std::vector<Ort::Value>> GetInitStates() {
|
||||
// get init states for shallow fusion
|
||||
std::pair<Ort::Value, std::vector<Ort::Value>> GetInitStatesSF() {
|
||||
std::vector<Ort::Value> ans;
|
||||
ans.reserve(init_states_.size());
|
||||
for (auto &s : init_states_) {
|
||||
@@ -75,6 +124,18 @@ class OnlineRnnLM::Impl {
|
||||
return {View(&init_scores_.value), std::move(ans)};
|
||||
}
|
||||
|
||||
// get init states for classic rescore
|
||||
std::vector<Ort::Value> GetInitStates() const {
|
||||
std::vector<Ort::Value> ans;
|
||||
ans.reserve(init_states_.size());
|
||||
|
||||
for (const auto &s : init_states_) {
|
||||
ans.emplace_back(Clone(allocator_, &s));
|
||||
}
|
||||
|
||||
return ans;
|
||||
}
|
||||
|
||||
private:
|
||||
void Init(const OnlineLMConfig &config) {
|
||||
auto buf = ReadFile(config_.model);
|
||||
@@ -116,7 +177,8 @@ class OnlineRnnLM::Impl {
|
||||
states.push_back(std::move(c));
|
||||
auto pair = ScoreToken(std::move(x), std::move(states));
|
||||
|
||||
init_scores_.value = std::move(pair.first);
|
||||
init_scores_.value = std::move(pair.first); // only used during
|
||||
// shallow fusion
|
||||
init_states_ = std::move(pair.second);
|
||||
}
|
||||
|
||||
@@ -147,17 +209,31 @@ OnlineRnnLM::OnlineRnnLM(const OnlineLMConfig &config)
|
||||
|
||||
OnlineRnnLM::~OnlineRnnLM() = default;
|
||||
|
||||
std::pair<Ort::Value, std::vector<Ort::Value>> OnlineRnnLM::GetInitStates() {
|
||||
// classic rescore state init
|
||||
std::vector<Ort::Value> OnlineRnnLM::GetInitStates() {
|
||||
return impl_->GetInitStates();
|
||||
}
|
||||
|
||||
// shallow fusion state init
|
||||
std::pair<Ort::Value, std::vector<Ort::Value>> OnlineRnnLM::GetInitStatesSF() {
|
||||
return impl_->GetInitStatesSF();
|
||||
}
|
||||
|
||||
std::pair<Ort::Value, std::vector<Ort::Value>> OnlineRnnLM::ScoreToken(
|
||||
Ort::Value x, std::vector<Ort::Value> states) {
|
||||
return impl_->ScoreToken(std::move(x), std::move(states));
|
||||
}
|
||||
|
||||
void OnlineRnnLM::ComputeLMScore(float scale, Hypothesis *hyp) {
|
||||
return impl_->ComputeLMScore(scale, hyp);
|
||||
// classic rescore scores
|
||||
void OnlineRnnLM::ComputeLMScore(float scale, int32_t context_size,
|
||||
std::vector<Hypotheses> *hyps) {
|
||||
return impl_->ComputeLMScore(scale, context_size, hyps);
|
||||
}
|
||||
|
||||
// shallow fusion scores
|
||||
void OnlineRnnLM::ComputeLMScoreSF(float scale, Hypothesis *hyp) {
|
||||
return impl_->ComputeLMScoreSF(scale, hyp);
|
||||
}
|
||||
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
Reference in New Issue
Block a user