Re-implement LM rescore for online transducer (#1231)
Co-authored-by: Martins Kronis <martins.kuznecovs@tilde.lv>
This commit is contained in:
@@ -51,8 +51,13 @@ struct Hypothesis {
|
|||||||
// LM log prob if any.
|
// LM log prob if any.
|
||||||
double lm_log_prob = 0;
|
double lm_log_prob = 0;
|
||||||
|
|
||||||
// the nn lm score for next token given the current ys
|
// the nn lm score for next token given the current ys,
|
||||||
|
// when using shallow fusion
|
||||||
CopyableOrtValue nn_lm_scores;
|
CopyableOrtValue nn_lm_scores;
|
||||||
|
|
||||||
|
// cur scored tokens by RNN LM, when rescoring
|
||||||
|
int32_t cur_scored_pos = 0;
|
||||||
|
|
||||||
// the nn lm states
|
// the nn lm states
|
||||||
std::vector<CopyableOrtValue> nn_lm_states;
|
std::vector<CopyableOrtValue> nn_lm_states;
|
||||||
|
|
||||||
|
|||||||
@@ -18,6 +18,8 @@ void OnlineLMConfig::Register(ParseOptions *po) {
|
|||||||
"Number of threads to run the neural network of LM model");
|
"Number of threads to run the neural network of LM model");
|
||||||
po->Register("lm-provider", &lm_provider,
|
po->Register("lm-provider", &lm_provider,
|
||||||
"Specify a provider to LM model use: cpu, cuda, coreml");
|
"Specify a provider to LM model use: cpu, cuda, coreml");
|
||||||
|
po->Register("lm-shallow-fusion", &shallow_fusion,
|
||||||
|
"Boolean whether to use shallow fusion or rescore.");
|
||||||
}
|
}
|
||||||
|
|
||||||
bool OnlineLMConfig::Validate() const {
|
bool OnlineLMConfig::Validate() const {
|
||||||
@@ -34,7 +36,8 @@ std::string OnlineLMConfig::ToString() const {
|
|||||||
|
|
||||||
os << "OnlineLMConfig(";
|
os << "OnlineLMConfig(";
|
||||||
os << "model=\"" << model << "\", ";
|
os << "model=\"" << model << "\", ";
|
||||||
os << "scale=" << scale << ")";
|
os << "scale=" << scale << ", ";
|
||||||
|
os << "shallow_fusion=" << (shallow_fusion ? "True" : "False") << ")";
|
||||||
|
|
||||||
return os.str();
|
return os.str();
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -18,15 +18,18 @@ struct OnlineLMConfig {
|
|||||||
float scale = 0.5;
|
float scale = 0.5;
|
||||||
int32_t lm_num_threads = 1;
|
int32_t lm_num_threads = 1;
|
||||||
std::string lm_provider = "cpu";
|
std::string lm_provider = "cpu";
|
||||||
|
// enable shallow fusion
|
||||||
|
bool shallow_fusion = true;
|
||||||
|
|
||||||
OnlineLMConfig() = default;
|
OnlineLMConfig() = default;
|
||||||
|
|
||||||
OnlineLMConfig(const std::string &model, float scale, int32_t lm_num_threads,
|
OnlineLMConfig(const std::string &model, float scale, int32_t lm_num_threads,
|
||||||
const std::string &lm_provider)
|
const std::string &lm_provider, bool shallow_fusion)
|
||||||
: model(model),
|
: model(model),
|
||||||
scale(scale),
|
scale(scale),
|
||||||
lm_num_threads(lm_num_threads),
|
lm_num_threads(lm_num_threads),
|
||||||
lm_provider(lm_provider) {}
|
lm_provider(lm_provider),
|
||||||
|
shallow_fusion(shallow_fusion) {}
|
||||||
|
|
||||||
void Register(ParseOptions *po);
|
void Register(ParseOptions *po);
|
||||||
bool Validate() const;
|
bool Validate() const;
|
||||||
|
|||||||
@@ -21,13 +21,17 @@ class OnlineLM {
|
|||||||
|
|
||||||
static std::unique_ptr<OnlineLM> Create(const OnlineLMConfig &config);
|
static std::unique_ptr<OnlineLM> Create(const OnlineLMConfig &config);
|
||||||
|
|
||||||
virtual std::pair<Ort::Value, std::vector<Ort::Value>> GetInitStates() = 0;
|
// init states for classic rescore
|
||||||
|
virtual std::vector<Ort::Value> GetInitStates() = 0;
|
||||||
|
|
||||||
/** ScoreToken a batch of sentences.
|
// init states for shallow fusion
|
||||||
|
virtual std::pair<Ort::Value, std::vector<Ort::Value>> GetInitStatesSF() = 0;
|
||||||
|
|
||||||
|
/** ScoreToken a batch of sentences (shallow fusion).
|
||||||
*
|
*
|
||||||
* @param x A 2-D tensor of shape (N, 1) with data type int64.
|
* @param x A 2-D tensor of shape (N, 1) with data type int64.
|
||||||
* @param states It contains the states for the LM model
|
* @param states It contains the states for the LM model
|
||||||
* @return Return a pair containingo
|
* @return Return a pair containing
|
||||||
* - log_prob of NN LM
|
* - log_prob of NN LM
|
||||||
* - updated states
|
* - updated states
|
||||||
*
|
*
|
||||||
@@ -35,13 +39,23 @@ class OnlineLM {
|
|||||||
virtual std::pair<Ort::Value, std::vector<Ort::Value>> ScoreToken(
|
virtual std::pair<Ort::Value, std::vector<Ort::Value>> ScoreToken(
|
||||||
Ort::Value x, std::vector<Ort::Value> states) = 0;
|
Ort::Value x, std::vector<Ort::Value> states) = 0;
|
||||||
|
|
||||||
/** This function updates lm_lob_prob and nn_lm_scores of hyp
|
/** This function updates hyp.lm_log_prob of hyps (classic rescore).
|
||||||
|
*
|
||||||
|
* @param scale LM score
|
||||||
|
* @param context_size Context size of the transducer decoder model
|
||||||
|
* @param hyps It is changed in-place.
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
virtual void ComputeLMScore(float scale, int32_t context_size,
|
||||||
|
std::vector<Hypotheses> *hyps) = 0;
|
||||||
|
|
||||||
|
/** This function updates lm_log_prob and nn_lm_scores of hyp (shallow fusion).
|
||||||
*
|
*
|
||||||
* @param scale LM score
|
* @param scale LM score
|
||||||
* @param hyps It is changed in-place.
|
* @param hyps It is changed in-place.
|
||||||
*
|
*
|
||||||
*/
|
*/
|
||||||
virtual void ComputeLMScore(float scale, Hypothesis *hyp) = 0;
|
virtual void ComputeLMScoreSF(float scale, Hypothesis *hyp) = 0;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace sherpa_onnx
|
} // namespace sherpa_onnx
|
||||||
|
|||||||
@@ -107,7 +107,8 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
|
|||||||
|
|
||||||
decoder_ = std::make_unique<OnlineTransducerModifiedBeamSearchDecoder>(
|
decoder_ = std::make_unique<OnlineTransducerModifiedBeamSearchDecoder>(
|
||||||
model_.get(), lm_.get(), config_.max_active_paths,
|
model_.get(), lm_.get(), config_.max_active_paths,
|
||||||
config_.lm_config.scale, unk_id_, config_.blank_penalty,
|
config_.lm_config.scale, config_.lm_config.shallow_fusion, unk_id_,
|
||||||
|
config_.blank_penalty,
|
||||||
config_.temperature_scale);
|
config_.temperature_scale);
|
||||||
|
|
||||||
} else if (config.decoding_method == "greedy_search") {
|
} else if (config.decoding_method == "greedy_search") {
|
||||||
@@ -156,7 +157,8 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
|
|||||||
|
|
||||||
decoder_ = std::make_unique<OnlineTransducerModifiedBeamSearchDecoder>(
|
decoder_ = std::make_unique<OnlineTransducerModifiedBeamSearchDecoder>(
|
||||||
model_.get(), lm_.get(), config_.max_active_paths,
|
model_.get(), lm_.get(), config_.max_active_paths,
|
||||||
config_.lm_config.scale, unk_id_, config_.blank_penalty,
|
config_.lm_config.scale, config_.lm_config.shallow_fusion, unk_id_,
|
||||||
|
config_.blank_penalty,
|
||||||
config_.temperature_scale);
|
config_.temperature_scale);
|
||||||
|
|
||||||
} else if (config.decoding_method == "greedy_search") {
|
} else if (config.decoding_method == "greedy_search") {
|
||||||
|
|||||||
@@ -8,6 +8,7 @@
|
|||||||
#include <string>
|
#include <string>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
#include <algorithm>
|
||||||
|
|
||||||
#include "onnxruntime_cxx_api.h" // NOLINT
|
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||||
#include "sherpa-onnx/csrc/macros.h"
|
#include "sherpa-onnx/csrc/macros.h"
|
||||||
@@ -27,9 +28,10 @@ class OnlineRnnLM::Impl {
|
|||||||
Init(config);
|
Init(config);
|
||||||
}
|
}
|
||||||
|
|
||||||
void ComputeLMScore(float scale, Hypothesis *hyp) {
|
// shallow fusion scoring function
|
||||||
|
void ComputeLMScoreSF(float scale, Hypothesis *hyp) {
|
||||||
if (hyp->nn_lm_states.empty()) {
|
if (hyp->nn_lm_states.empty()) {
|
||||||
auto init_states = GetInitStates();
|
auto init_states = GetInitStatesSF();
|
||||||
hyp->nn_lm_scores.value = std::move(init_states.first);
|
hyp->nn_lm_scores.value = std::move(init_states.first);
|
||||||
hyp->nn_lm_states = Convert(std::move(init_states.second));
|
hyp->nn_lm_states = Convert(std::move(init_states.second));
|
||||||
}
|
}
|
||||||
@@ -49,6 +51,52 @@ class OnlineRnnLM::Impl {
|
|||||||
hyp->nn_lm_states = Convert(std::move(lm_out.second));
|
hyp->nn_lm_states = Convert(std::move(lm_out.second));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// classic rescore function
|
||||||
|
void ComputeLMScore(float scale, int32_t context_size,
|
||||||
|
std::vector<Hypotheses> *hyps) {
|
||||||
|
Ort::AllocatorWithDefaultOptions allocator;
|
||||||
|
|
||||||
|
for (auto &hyp : *hyps) {
|
||||||
|
for (auto &h_m : hyp) {
|
||||||
|
auto &h = h_m.second;
|
||||||
|
auto &ys = h.ys;
|
||||||
|
const int32_t token_num_in_chunk =
|
||||||
|
ys.size() - context_size - h.cur_scored_pos - 1;
|
||||||
|
|
||||||
|
if (token_num_in_chunk < 1) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (h.nn_lm_states.empty()) {
|
||||||
|
h.nn_lm_states = Convert(GetInitStates());
|
||||||
|
}
|
||||||
|
|
||||||
|
if (token_num_in_chunk >= h.lm_rescore_min_chunk) {
|
||||||
|
std::array<int64_t, 2> x_shape{1, token_num_in_chunk};
|
||||||
|
|
||||||
|
Ort::Value x = Ort::Value::CreateTensor<int64_t>(
|
||||||
|
allocator, x_shape.data(), x_shape.size());
|
||||||
|
int64_t *p_x = x.GetTensorMutableData<int64_t>();
|
||||||
|
std::copy(ys.begin() + context_size + h.cur_scored_pos,
|
||||||
|
ys.end() - 1, p_x);
|
||||||
|
|
||||||
|
// streaming forward by NN LM
|
||||||
|
auto out = ScoreToken(std::move(x),
|
||||||
|
Convert(std::move(h.nn_lm_states)));
|
||||||
|
|
||||||
|
// update NN LM score in hyp
|
||||||
|
const float *p_nll = out.first.GetTensorData<float>();
|
||||||
|
h.lm_log_prob = -scale * (*p_nll);
|
||||||
|
|
||||||
|
// update NN LM states in hyp
|
||||||
|
h.nn_lm_states = Convert(std::move(out.second));
|
||||||
|
|
||||||
|
h.cur_scored_pos += token_num_in_chunk;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
std::pair<Ort::Value, std::vector<Ort::Value>> ScoreToken(
|
std::pair<Ort::Value, std::vector<Ort::Value>> ScoreToken(
|
||||||
Ort::Value x, std::vector<Ort::Value> states) {
|
Ort::Value x, std::vector<Ort::Value> states) {
|
||||||
std::array<Ort::Value, 3> inputs = {std::move(x), std::move(states[0]),
|
std::array<Ort::Value, 3> inputs = {std::move(x), std::move(states[0]),
|
||||||
@@ -66,7 +114,8 @@ class OnlineRnnLM::Impl {
|
|||||||
return {std::move(out[0]), std::move(next_states)};
|
return {std::move(out[0]), std::move(next_states)};
|
||||||
}
|
}
|
||||||
|
|
||||||
std::pair<Ort::Value, std::vector<Ort::Value>> GetInitStates() {
|
// get init states for shallow fusion
|
||||||
|
std::pair<Ort::Value, std::vector<Ort::Value>> GetInitStatesSF() {
|
||||||
std::vector<Ort::Value> ans;
|
std::vector<Ort::Value> ans;
|
||||||
ans.reserve(init_states_.size());
|
ans.reserve(init_states_.size());
|
||||||
for (auto &s : init_states_) {
|
for (auto &s : init_states_) {
|
||||||
@@ -75,6 +124,18 @@ class OnlineRnnLM::Impl {
|
|||||||
return {View(&init_scores_.value), std::move(ans)};
|
return {View(&init_scores_.value), std::move(ans)};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// get init states for classic rescore
|
||||||
|
std::vector<Ort::Value> GetInitStates() const {
|
||||||
|
std::vector<Ort::Value> ans;
|
||||||
|
ans.reserve(init_states_.size());
|
||||||
|
|
||||||
|
for (const auto &s : init_states_) {
|
||||||
|
ans.emplace_back(Clone(allocator_, &s));
|
||||||
|
}
|
||||||
|
|
||||||
|
return ans;
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void Init(const OnlineLMConfig &config) {
|
void Init(const OnlineLMConfig &config) {
|
||||||
auto buf = ReadFile(config_.model);
|
auto buf = ReadFile(config_.model);
|
||||||
@@ -116,7 +177,8 @@ class OnlineRnnLM::Impl {
|
|||||||
states.push_back(std::move(c));
|
states.push_back(std::move(c));
|
||||||
auto pair = ScoreToken(std::move(x), std::move(states));
|
auto pair = ScoreToken(std::move(x), std::move(states));
|
||||||
|
|
||||||
init_scores_.value = std::move(pair.first);
|
init_scores_.value = std::move(pair.first); // only used during
|
||||||
|
// shallow fusion
|
||||||
init_states_ = std::move(pair.second);
|
init_states_ = std::move(pair.second);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -147,17 +209,31 @@ OnlineRnnLM::OnlineRnnLM(const OnlineLMConfig &config)
|
|||||||
|
|
||||||
OnlineRnnLM::~OnlineRnnLM() = default;
|
OnlineRnnLM::~OnlineRnnLM() = default;
|
||||||
|
|
||||||
std::pair<Ort::Value, std::vector<Ort::Value>> OnlineRnnLM::GetInitStates() {
|
// classic rescore state init
|
||||||
|
std::vector<Ort::Value> OnlineRnnLM::GetInitStates() {
|
||||||
return impl_->GetInitStates();
|
return impl_->GetInitStates();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// shallow fusion state init
|
||||||
|
std::pair<Ort::Value, std::vector<Ort::Value>> OnlineRnnLM::GetInitStatesSF() {
|
||||||
|
return impl_->GetInitStatesSF();
|
||||||
|
}
|
||||||
|
|
||||||
std::pair<Ort::Value, std::vector<Ort::Value>> OnlineRnnLM::ScoreToken(
|
std::pair<Ort::Value, std::vector<Ort::Value>> OnlineRnnLM::ScoreToken(
|
||||||
Ort::Value x, std::vector<Ort::Value> states) {
|
Ort::Value x, std::vector<Ort::Value> states) {
|
||||||
return impl_->ScoreToken(std::move(x), std::move(states));
|
return impl_->ScoreToken(std::move(x), std::move(states));
|
||||||
}
|
}
|
||||||
|
|
||||||
void OnlineRnnLM::ComputeLMScore(float scale, Hypothesis *hyp) {
|
// classic rescore scores
|
||||||
return impl_->ComputeLMScore(scale, hyp);
|
void OnlineRnnLM::ComputeLMScore(float scale, int32_t context_size,
|
||||||
|
std::vector<Hypotheses> *hyps) {
|
||||||
|
return impl_->ComputeLMScore(scale, context_size, hyps);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// shallow fusion scores
|
||||||
|
void OnlineRnnLM::ComputeLMScoreSF(float scale, Hypothesis *hyp) {
|
||||||
|
return impl_->ComputeLMScoreSF(scale, hyp);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
} // namespace sherpa_onnx
|
} // namespace sherpa_onnx
|
||||||
|
|||||||
@@ -22,13 +22,17 @@ class OnlineRnnLM : public OnlineLM {
|
|||||||
|
|
||||||
explicit OnlineRnnLM(const OnlineLMConfig &config);
|
explicit OnlineRnnLM(const OnlineLMConfig &config);
|
||||||
|
|
||||||
std::pair<Ort::Value, std::vector<Ort::Value>> GetInitStates() override;
|
// init scores for classic rescore
|
||||||
|
std::vector<Ort::Value> GetInitStates() override;
|
||||||
|
|
||||||
/** ScoreToken a batch of sentences.
|
// init scores for shallow fusion
|
||||||
|
std::pair<Ort::Value, std::vector<Ort::Value>> GetInitStatesSF() override;
|
||||||
|
|
||||||
|
/** ScoreToken a batch of sentences (shallow fusion).
|
||||||
*
|
*
|
||||||
* @param x A 2-D tensor of shape (N, L) with data type int64.
|
* @param x A 2-D tensor of shape (N, L) with data type int64.
|
||||||
* @param states It contains the states for the LM model
|
* @param states It contains the states for the LM model
|
||||||
* @return Return a pair containingo
|
* @return Return a pair containing
|
||||||
* - log_prob of NN LM
|
* - log_prob of NN LM
|
||||||
* - updated states
|
* - updated states
|
||||||
*
|
*
|
||||||
@@ -36,13 +40,23 @@ class OnlineRnnLM : public OnlineLM {
|
|||||||
std::pair<Ort::Value, std::vector<Ort::Value>> ScoreToken(
|
std::pair<Ort::Value, std::vector<Ort::Value>> ScoreToken(
|
||||||
Ort::Value x, std::vector<Ort::Value> states) override;
|
Ort::Value x, std::vector<Ort::Value> states) override;
|
||||||
|
|
||||||
/** This function updates lm_lob_prob and nn_lm_scores of hyp
|
/** This function updates hyp.lm_lob_prob of hyps (classic rescore).
|
||||||
|
*
|
||||||
|
* @param scale LM score
|
||||||
|
* @param context_size Context size of the transducer decoder model
|
||||||
|
* @param hyps It is changed in-place.
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
void ComputeLMScore(float scale, int32_t context_size,
|
||||||
|
std::vector<Hypotheses> *hyps) override;
|
||||||
|
|
||||||
|
/** This function updates lm_lob_prob and nn_lm_scores of hyp (shallow fusion).
|
||||||
*
|
*
|
||||||
* @param scale LM score
|
* @param scale LM score
|
||||||
* @param hyps It is changed in-place.
|
* @param hyps It is changed in-place.
|
||||||
*
|
*
|
||||||
*/
|
*/
|
||||||
void ComputeLMScore(float scale, Hypothesis *hyp) override;
|
void ComputeLMScoreSF(float scale, Hypothesis *hyp) override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
class Impl;
|
class Impl;
|
||||||
|
|||||||
@@ -156,7 +156,11 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
|
|||||||
|
|
||||||
// add log_prob of each hypothesis to p_logprob before taking top_k
|
// add log_prob of each hypothesis to p_logprob before taking top_k
|
||||||
for (int32_t i = 0; i != num_hyps; ++i) {
|
for (int32_t i = 0; i != num_hyps; ++i) {
|
||||||
float log_prob = prev[i].log_prob + prev[i].lm_log_prob;
|
float log_prob = prev[i].log_prob;
|
||||||
|
if (lm_ && shallow_fusion_) {
|
||||||
|
log_prob += prev[i].lm_log_prob;
|
||||||
|
}
|
||||||
|
|
||||||
for (int32_t k = 0; k != vocab_size; ++k, ++p_logprob) {
|
for (int32_t k = 0; k != vocab_size; ++k, ++p_logprob) {
|
||||||
*p_logprob += log_prob;
|
*p_logprob += log_prob;
|
||||||
}
|
}
|
||||||
@@ -192,22 +196,31 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
|
|||||||
context_score = std::get<0>(context_res);
|
context_score = std::get<0>(context_res);
|
||||||
new_hyp.context_state = std::get<1>(context_res);
|
new_hyp.context_state = std::get<1>(context_res);
|
||||||
}
|
}
|
||||||
if (lm_) {
|
if (lm_ && shallow_fusion_) {
|
||||||
lm_->ComputeLMScore(lm_scale_, &new_hyp);
|
lm_->ComputeLMScoreSF(lm_scale_, &new_hyp);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
++new_hyp.num_trailing_blanks;
|
++new_hyp.num_trailing_blanks;
|
||||||
}
|
}
|
||||||
|
if (lm_ && shallow_fusion_) {
|
||||||
new_hyp.log_prob = p_logprob[k] + context_score -
|
new_hyp.log_prob = p_logprob[k] + context_score -
|
||||||
prev_lm_log_prob; // log_prob only includes the
|
prev_lm_log_prob; // log_prob only includes the
|
||||||
// score of the transducer
|
// score of the transducer
|
||||||
|
} else {
|
||||||
|
new_hyp.log_prob = p_logprob[k] + context_score; // rescore or no LM
|
||||||
|
// previous token
|
||||||
|
// score is ignored
|
||||||
|
}
|
||||||
|
|
||||||
// export the per-token log scores
|
// export the per-token log scores
|
||||||
if (new_token != 0 && new_token != unk_id_) {
|
if (new_token != 0 && new_token != unk_id_) {
|
||||||
float y_prob = logit_with_temperature[start * vocab_size + k];
|
float y_prob = logit_with_temperature[start * vocab_size + k];
|
||||||
new_hyp.ys_probs.push_back(y_prob);
|
new_hyp.ys_probs.push_back(y_prob);
|
||||||
|
|
||||||
if (lm_) { // export only when LM is used
|
if (lm_ && shallow_fusion_) { // export only if
|
||||||
|
// LM shallow fusion is used
|
||||||
float lm_prob = new_hyp.lm_log_prob - prev_lm_log_prob;
|
float lm_prob = new_hyp.lm_log_prob - prev_lm_log_prob;
|
||||||
|
|
||||||
if (lm_scale_ != 0.0) {
|
if (lm_scale_ != 0.0) {
|
||||||
lm_prob /= lm_scale_; // remove lm-scale
|
lm_prob /= lm_scale_; // remove lm-scale
|
||||||
}
|
}
|
||||||
@@ -227,6 +240,11 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
|
|||||||
} // for (int32_t b = 0; b != batch_size; ++b)
|
} // for (int32_t b = 0; b != batch_size; ++b)
|
||||||
} // for (int32_t t = 0; t != num_frames; ++t)
|
} // for (int32_t t = 0; t != num_frames; ++t)
|
||||||
|
|
||||||
|
// classic lm rescore
|
||||||
|
if (lm_ && !shallow_fusion_) {
|
||||||
|
lm_->ComputeLMScore(lm_scale_, model_->ContextSize(), &cur);
|
||||||
|
}
|
||||||
|
|
||||||
for (int32_t b = 0; b != batch_size; ++b) {
|
for (int32_t b = 0; b != batch_size; ++b) {
|
||||||
auto &hyps = cur[b];
|
auto &hyps = cur[b];
|
||||||
auto best_hyp = hyps.GetMostProbable(true);
|
auto best_hyp = hyps.GetMostProbable(true);
|
||||||
|
|||||||
@@ -21,13 +21,16 @@ class OnlineTransducerModifiedBeamSearchDecoder
|
|||||||
OnlineTransducerModifiedBeamSearchDecoder(OnlineTransducerModel *model,
|
OnlineTransducerModifiedBeamSearchDecoder(OnlineTransducerModel *model,
|
||||||
OnlineLM *lm,
|
OnlineLM *lm,
|
||||||
int32_t max_active_paths,
|
int32_t max_active_paths,
|
||||||
float lm_scale, int32_t unk_id,
|
float lm_scale,
|
||||||
|
bool shallow_fusion,
|
||||||
|
int32_t unk_id,
|
||||||
float blank_penalty,
|
float blank_penalty,
|
||||||
float temperature_scale)
|
float temperature_scale)
|
||||||
: model_(model),
|
: model_(model),
|
||||||
lm_(lm),
|
lm_(lm),
|
||||||
max_active_paths_(max_active_paths),
|
max_active_paths_(max_active_paths),
|
||||||
lm_scale_(lm_scale),
|
lm_scale_(lm_scale),
|
||||||
|
shallow_fusion_(shallow_fusion),
|
||||||
unk_id_(unk_id),
|
unk_id_(unk_id),
|
||||||
blank_penalty_(blank_penalty),
|
blank_penalty_(blank_penalty),
|
||||||
temperature_scale_(temperature_scale) {}
|
temperature_scale_(temperature_scale) {}
|
||||||
@@ -50,6 +53,7 @@ class OnlineTransducerModifiedBeamSearchDecoder
|
|||||||
|
|
||||||
int32_t max_active_paths_;
|
int32_t max_active_paths_;
|
||||||
float lm_scale_; // used only when lm_ is not nullptr
|
float lm_scale_; // used only when lm_ is not nullptr
|
||||||
|
bool shallow_fusion_; // used only when lm_ is not nullptr
|
||||||
int32_t unk_id_;
|
int32_t unk_id_;
|
||||||
float blank_penalty_;
|
float blank_penalty_;
|
||||||
float temperature_scale_;
|
float temperature_scale_;
|
||||||
|
|||||||
@@ -13,13 +13,16 @@ namespace sherpa_onnx {
|
|||||||
void PybindOnlineLMConfig(py::module *m) {
|
void PybindOnlineLMConfig(py::module *m) {
|
||||||
using PyClass = OnlineLMConfig;
|
using PyClass = OnlineLMConfig;
|
||||||
py::class_<PyClass>(*m, "OnlineLMConfig")
|
py::class_<PyClass>(*m, "OnlineLMConfig")
|
||||||
.def(py::init<const std::string &, float, int32_t, const std::string &>(),
|
.def(py::init<const std::string &, float, int32_t,
|
||||||
|
const std::string &, bool>(),
|
||||||
py::arg("model") = "", py::arg("scale") = 0.5f,
|
py::arg("model") = "", py::arg("scale") = 0.5f,
|
||||||
py::arg("lm_num_threads") = 1, py::arg("lm_provider") = "cpu")
|
py::arg("lm_num_threads") = 1, py::arg("lm_provider") = "cpu",
|
||||||
|
py::arg("shallow_fusion") = true)
|
||||||
.def_readwrite("model", &PyClass::model)
|
.def_readwrite("model", &PyClass::model)
|
||||||
.def_readwrite("scale", &PyClass::scale)
|
.def_readwrite("scale", &PyClass::scale)
|
||||||
.def_readwrite("lm_provider", &PyClass::lm_provider)
|
.def_readwrite("lm_provider", &PyClass::lm_provider)
|
||||||
.def_readwrite("lm_num_threads", &PyClass::lm_num_threads)
|
.def_readwrite("lm_num_threads", &PyClass::lm_num_threads)
|
||||||
|
.def_readwrite("shallow_fusion", &PyClass::shallow_fusion)
|
||||||
.def("__str__", &PyClass::ToString);
|
.def("__str__", &PyClass::ToString);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -64,6 +64,7 @@ class OnlineRecognizer(object):
|
|||||||
bpe_vocab: str = "",
|
bpe_vocab: str = "",
|
||||||
lm: str = "",
|
lm: str = "",
|
||||||
lm_scale: float = 0.1,
|
lm_scale: float = 0.1,
|
||||||
|
lm_shallow_fusion: bool = True,
|
||||||
temperature_scale: float = 2.0,
|
temperature_scale: float = 2.0,
|
||||||
debug: bool = False,
|
debug: bool = False,
|
||||||
rule_fsts: str = "",
|
rule_fsts: str = "",
|
||||||
@@ -274,6 +275,7 @@ class OnlineRecognizer(object):
|
|||||||
lm_config = OnlineLMConfig(
|
lm_config = OnlineLMConfig(
|
||||||
model=lm,
|
model=lm,
|
||||||
scale=lm_scale,
|
scale=lm_scale,
|
||||||
|
shallow_fusion=lm_shallow_fusion,
|
||||||
)
|
)
|
||||||
|
|
||||||
recognizer_config = OnlineRecognizerConfig(
|
recognizer_config = OnlineRecognizerConfig(
|
||||||
|
|||||||
Reference in New Issue
Block a user