add shallow fusion (#147)
This commit is contained in:
@@ -34,13 +34,11 @@ struct Hypothesis {
|
|||||||
// LM log prob if any.
|
// LM log prob if any.
|
||||||
double lm_log_prob = 0;
|
double lm_log_prob = 0;
|
||||||
|
|
||||||
int32_t cur_scored_pos = 0; // cur scored tokens by RNN LM
|
// the nn lm score for next token given the current ys
|
||||||
|
CopyableOrtValue nn_lm_scores;
|
||||||
|
// the nn lm states
|
||||||
std::vector<CopyableOrtValue> nn_lm_states;
|
std::vector<CopyableOrtValue> nn_lm_states;
|
||||||
|
|
||||||
// TODO(fangjun): Make it configurable
|
|
||||||
// the minimum of tokens in a chunk for streaming RNN LM
|
|
||||||
int32_t lm_rescore_min_chunk = 2; // a const
|
|
||||||
|
|
||||||
int32_t num_trailing_blanks = 0;
|
int32_t num_trailing_blanks = 0;
|
||||||
|
|
||||||
Hypothesis() = default;
|
Hypothesis() = default;
|
||||||
|
|||||||
@@ -13,80 +13,8 @@
|
|||||||
|
|
||||||
namespace sherpa_onnx {
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
static std::vector<CopyableOrtValue> Convert(std::vector<Ort::Value> values) {
|
|
||||||
std::vector<CopyableOrtValue> ans;
|
|
||||||
ans.reserve(values.size());
|
|
||||||
|
|
||||||
for (auto &v : values) {
|
|
||||||
ans.emplace_back(std::move(v));
|
|
||||||
}
|
|
||||||
|
|
||||||
return ans;
|
|
||||||
}
|
|
||||||
|
|
||||||
static std::vector<Ort::Value> Convert(std::vector<CopyableOrtValue> values) {
|
|
||||||
std::vector<Ort::Value> ans;
|
|
||||||
ans.reserve(values.size());
|
|
||||||
|
|
||||||
for (auto &v : values) {
|
|
||||||
ans.emplace_back(std::move(v.value));
|
|
||||||
}
|
|
||||||
|
|
||||||
return ans;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::unique_ptr<OnlineLM> OnlineLM::Create(const OnlineLMConfig &config) {
|
std::unique_ptr<OnlineLM> OnlineLM::Create(const OnlineLMConfig &config) {
|
||||||
return std::make_unique<OnlineRnnLM>(config);
|
return std::make_unique<OnlineRnnLM>(config);
|
||||||
}
|
}
|
||||||
|
|
||||||
void OnlineLM::ComputeLMScore(float scale, int32_t context_size,
|
|
||||||
std::vector<Hypotheses> *hyps) {
|
|
||||||
Ort::AllocatorWithDefaultOptions allocator;
|
|
||||||
|
|
||||||
for (auto &hyp : *hyps) {
|
|
||||||
for (auto &h_m : hyp) {
|
|
||||||
auto &h = h_m.second;
|
|
||||||
auto &ys = h.ys;
|
|
||||||
const int32_t token_num_in_chunk =
|
|
||||||
ys.size() - context_size - h.cur_scored_pos - 1;
|
|
||||||
|
|
||||||
if (token_num_in_chunk < 1) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (h.nn_lm_states.empty()) {
|
|
||||||
h.nn_lm_states = Convert(GetInitStates());
|
|
||||||
}
|
|
||||||
|
|
||||||
if (token_num_in_chunk >= h.lm_rescore_min_chunk) {
|
|
||||||
std::array<int64_t, 2> x_shape{1, token_num_in_chunk};
|
|
||||||
// shape of x and y are same
|
|
||||||
Ort::Value x = Ort::Value::CreateTensor<int64_t>(
|
|
||||||
allocator, x_shape.data(), x_shape.size());
|
|
||||||
Ort::Value y = Ort::Value::CreateTensor<int64_t>(
|
|
||||||
allocator, x_shape.data(), x_shape.size());
|
|
||||||
int64_t *p_x = x.GetTensorMutableData<int64_t>();
|
|
||||||
int64_t *p_y = y.GetTensorMutableData<int64_t>();
|
|
||||||
std::copy(ys.begin() + context_size + h.cur_scored_pos, ys.end() - 1,
|
|
||||||
p_x);
|
|
||||||
std::copy(ys.begin() + context_size + h.cur_scored_pos + 1, ys.end(),
|
|
||||||
p_y);
|
|
||||||
|
|
||||||
// streaming forward by NN LM
|
|
||||||
auto out = Rescore(std::move(x), std::move(y),
|
|
||||||
Convert(std::move(h.nn_lm_states)));
|
|
||||||
|
|
||||||
// update NN LM score in hyp
|
|
||||||
const float *p_nll = out.first.GetTensorData<float>();
|
|
||||||
h.lm_log_prob = -scale * (*p_nll);
|
|
||||||
|
|
||||||
// update NN LM states in hyp
|
|
||||||
h.nn_lm_states = Convert(std::move(out.second));
|
|
||||||
|
|
||||||
h.cur_scored_pos += token_num_in_chunk;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace sherpa_onnx
|
} // namespace sherpa_onnx
|
||||||
|
|||||||
@@ -21,29 +21,27 @@ class OnlineLM {
|
|||||||
|
|
||||||
static std::unique_ptr<OnlineLM> Create(const OnlineLMConfig &config);
|
static std::unique_ptr<OnlineLM> Create(const OnlineLMConfig &config);
|
||||||
|
|
||||||
virtual std::vector<Ort::Value> GetInitStates() = 0;
|
virtual std::pair<Ort::Value, std::vector<Ort::Value>> GetInitStates() = 0;
|
||||||
|
|
||||||
/** Rescore a batch of sentences.
|
/** ScoreToken a batch of sentences.
|
||||||
*
|
*
|
||||||
* @param x A 2-D tensor of shape (N, L) with data type int64.
|
* @param x A 2-D tensor of shape (N, 1) with data type int64.
|
||||||
* @param y A 2-D tensor of shape (N, L) with data type int64.
|
|
||||||
* @param states It contains the states for the LM model
|
* @param states It contains the states for the LM model
|
||||||
* @return Return a pair containingo
|
* @return Return a pair containingo
|
||||||
* - negative loglike
|
* - log_prob of NN LM
|
||||||
* - updated states
|
* - updated states
|
||||||
*
|
*
|
||||||
* Caution: It returns negative log likelihood (nll), not log likelihood
|
|
||||||
*/
|
*/
|
||||||
virtual std::pair<Ort::Value, std::vector<Ort::Value>> Rescore(
|
virtual std::pair<Ort::Value, std::vector<Ort::Value>> ScoreToken(
|
||||||
Ort::Value x, Ort::Value y, std::vector<Ort::Value> states) = 0;
|
Ort::Value x, std::vector<Ort::Value> states) = 0;
|
||||||
|
|
||||||
// This function updates hyp.lm_lob_prob of hyps.
|
/** This function updates lm_lob_prob and nn_lm_scores of hyp
|
||||||
//
|
*
|
||||||
// @param scale LM score
|
* @param scale LM score
|
||||||
// @param context_size Context size of the transducer decoder model
|
* @param hyps It is changed in-place.
|
||||||
// @param hyps It is changed in-place.
|
*
|
||||||
void ComputeLMScore(float scale, int32_t context_size,
|
*/
|
||||||
std::vector<Hypotheses> *hyps);
|
virtual void ComputeLMScore(float scale, Hypothesis *hyp) = 0;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace sherpa_onnx
|
} // namespace sherpa_onnx
|
||||||
|
|||||||
@@ -26,10 +26,33 @@ class OnlineRnnLM::Impl {
|
|||||||
Init(config);
|
Init(config);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::pair<Ort::Value, std::vector<Ort::Value>> Rescore(
|
void ComputeLMScore(float scale, Hypothesis *hyp) {
|
||||||
Ort::Value x, Ort::Value y, std::vector<Ort::Value> states) {
|
if (hyp->nn_lm_states.empty()) {
|
||||||
std::array<Ort::Value, 4> inputs = {
|
auto init_states = GetInitStates();
|
||||||
std::move(x), std::move(y), std::move(states[0]), std::move(states[1])};
|
hyp->nn_lm_scores.value = std::move(init_states.first);
|
||||||
|
hyp->nn_lm_states = Convert(std::move(init_states.second));
|
||||||
|
}
|
||||||
|
|
||||||
|
// get lm score for cur token given the hyp->ys[:-1] and save to lm_log_prob
|
||||||
|
const float *nn_lm_scores = hyp->nn_lm_scores.value.GetTensorData<float>();
|
||||||
|
hyp->lm_log_prob = nn_lm_scores[hyp->ys.back()] * scale;
|
||||||
|
|
||||||
|
// get lm scores for next tokens given the hyp->ys[:] and save to
|
||||||
|
// nn_lm_scores
|
||||||
|
std::array<int64_t, 2> x_shape{1, 1};
|
||||||
|
Ort::Value x = Ort::Value::CreateTensor<int64_t>(allocator_, x_shape.data(),
|
||||||
|
x_shape.size());
|
||||||
|
*x.GetTensorMutableData<int64_t>() = hyp->ys.back();
|
||||||
|
auto lm_out =
|
||||||
|
ScoreToken(std::move(x), Convert(hyp->nn_lm_states));
|
||||||
|
hyp->nn_lm_scores.value = std::move(lm_out.first);
|
||||||
|
hyp->nn_lm_states = Convert(std::move(lm_out.second));
|
||||||
|
}
|
||||||
|
|
||||||
|
std::pair<Ort::Value, std::vector<Ort::Value>> ScoreToken(
|
||||||
|
Ort::Value x, std::vector<Ort::Value> states) {
|
||||||
|
std::array<Ort::Value, 3> inputs = {std::move(x), std::move(states[0]),
|
||||||
|
std::move(states[1])};
|
||||||
|
|
||||||
auto out =
|
auto out =
|
||||||
sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(),
|
sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(),
|
||||||
@@ -43,15 +66,13 @@ class OnlineRnnLM::Impl {
|
|||||||
return {std::move(out[0]), std::move(next_states)};
|
return {std::move(out[0]), std::move(next_states)};
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<Ort::Value> GetInitStates() const {
|
std::pair<Ort::Value, std::vector<Ort::Value>> GetInitStates() const {
|
||||||
std::vector<Ort::Value> ans;
|
std::vector<Ort::Value> ans;
|
||||||
ans.reserve(init_states_.size());
|
ans.reserve(init_states_.size());
|
||||||
|
|
||||||
for (const auto &s : init_states_) {
|
for (const auto &s : init_states_) {
|
||||||
ans.emplace_back(Clone(allocator_, &s));
|
ans.emplace_back(Clone(allocator_, &s));
|
||||||
}
|
}
|
||||||
|
return {std::move(Clone(allocator_, &init_scores_.value)), std::move(ans)};
|
||||||
return ans;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
@@ -86,19 +107,16 @@ class OnlineRnnLM::Impl {
|
|||||||
Fill<float>(&h, 0);
|
Fill<float>(&h, 0);
|
||||||
Fill<float>(&c, 0);
|
Fill<float>(&c, 0);
|
||||||
std::array<int64_t, 2> x_shape{1, 1};
|
std::array<int64_t, 2> x_shape{1, 1};
|
||||||
// shape of x and y are same
|
|
||||||
Ort::Value x = Ort::Value::CreateTensor<int64_t>(allocator_, x_shape.data(),
|
Ort::Value x = Ort::Value::CreateTensor<int64_t>(allocator_, x_shape.data(),
|
||||||
x_shape.size());
|
x_shape.size());
|
||||||
Ort::Value y = Ort::Value::CreateTensor<int64_t>(allocator_, x_shape.data(),
|
|
||||||
x_shape.size());
|
|
||||||
*x.GetTensorMutableData<int64_t>() = sos_id_;
|
*x.GetTensorMutableData<int64_t>() = sos_id_;
|
||||||
*y.GetTensorMutableData<int64_t>() = sos_id_;
|
|
||||||
|
|
||||||
std::vector<Ort::Value> states;
|
std::vector<Ort::Value> states;
|
||||||
states.push_back(std::move(h));
|
states.push_back(std::move(h));
|
||||||
states.push_back(std::move(c));
|
states.push_back(std::move(c));
|
||||||
auto pair = Rescore(std::move(x), std::move(y), std::move(states));
|
auto pair = ScoreToken(std::move(x), std::move(states));
|
||||||
|
|
||||||
|
init_scores_.value = std::move(pair.first);
|
||||||
init_states_ = std::move(pair.second);
|
init_states_ = std::move(pair.second);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -116,6 +134,7 @@ class OnlineRnnLM::Impl {
|
|||||||
std::vector<std::string> output_names_;
|
std::vector<std::string> output_names_;
|
||||||
std::vector<const char *> output_names_ptr_;
|
std::vector<const char *> output_names_ptr_;
|
||||||
|
|
||||||
|
CopyableOrtValue init_scores_;
|
||||||
std::vector<Ort::Value> init_states_;
|
std::vector<Ort::Value> init_states_;
|
||||||
|
|
||||||
int32_t rnn_num_layers_ = 2;
|
int32_t rnn_num_layers_ = 2;
|
||||||
@@ -128,13 +147,17 @@ OnlineRnnLM::OnlineRnnLM(const OnlineLMConfig &config)
|
|||||||
|
|
||||||
OnlineRnnLM::~OnlineRnnLM() = default;
|
OnlineRnnLM::~OnlineRnnLM() = default;
|
||||||
|
|
||||||
std::vector<Ort::Value> OnlineRnnLM::GetInitStates() {
|
std::pair<Ort::Value, std::vector<Ort::Value>> OnlineRnnLM::GetInitStates() {
|
||||||
return impl_->GetInitStates();
|
return impl_->GetInitStates();
|
||||||
}
|
}
|
||||||
|
|
||||||
std::pair<Ort::Value, std::vector<Ort::Value>> OnlineRnnLM::Rescore(
|
std::pair<Ort::Value, std::vector<Ort::Value>> OnlineRnnLM::ScoreToken(
|
||||||
Ort::Value x, Ort::Value y, std::vector<Ort::Value> states) {
|
Ort::Value x, std::vector<Ort::Value> states) {
|
||||||
return impl_->Rescore(std::move(x), std::move(y), std::move(states));
|
return impl_->ScoreToken(std::move(x), std::move(states));
|
||||||
|
}
|
||||||
|
|
||||||
|
void OnlineRnnLM::ComputeLMScore(float scale, Hypothesis *hyp) {
|
||||||
|
return impl_->ComputeLMScore(scale, hyp);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace sherpa_onnx
|
} // namespace sherpa_onnx
|
||||||
|
|||||||
@@ -22,21 +22,27 @@ class OnlineRnnLM : public OnlineLM {
|
|||||||
|
|
||||||
explicit OnlineRnnLM(const OnlineLMConfig &config);
|
explicit OnlineRnnLM(const OnlineLMConfig &config);
|
||||||
|
|
||||||
std::vector<Ort::Value> GetInitStates() override;
|
std::pair<Ort::Value, std::vector<Ort::Value>> GetInitStates() override;
|
||||||
|
|
||||||
/** Rescore a batch of sentences.
|
/** ScoreToken a batch of sentences.
|
||||||
*
|
*
|
||||||
* @param x A 2-D tensor of shape (N, L) with data type int64.
|
* @param x A 2-D tensor of shape (N, L) with data type int64.
|
||||||
* @param y A 2-D tensor of shape (N, L) with data type int64.
|
|
||||||
* @param states It contains the states for the LM model
|
* @param states It contains the states for the LM model
|
||||||
* @return Return a pair containingo
|
* @return Return a pair containingo
|
||||||
* - negative loglike
|
* - log_prob of NN LM
|
||||||
* - updated states
|
* - updated states
|
||||||
*
|
*
|
||||||
* Caution: It returns negative log likelihood (nll), not log likelihood
|
|
||||||
*/
|
*/
|
||||||
std::pair<Ort::Value, std::vector<Ort::Value>> Rescore(
|
std::pair<Ort::Value, std::vector<Ort::Value>> ScoreToken(
|
||||||
Ort::Value x, Ort::Value y, std::vector<Ort::Value> states) override;
|
Ort::Value x, std::vector<Ort::Value> states) override;
|
||||||
|
|
||||||
|
/** This function updates lm_lob_prob and nn_lm_scores of hyp
|
||||||
|
*
|
||||||
|
* @param scale LM score
|
||||||
|
* @param hyps It is changed in-place.
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
void ComputeLMScore(float scale, Hypothesis *hyp) override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
class Impl;
|
class Impl;
|
||||||
|
|||||||
@@ -121,7 +121,7 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
|
|||||||
|
|
||||||
// add log_prob of each hypothesis to p_logprob before taking top_k
|
// add log_prob of each hypothesis to p_logprob before taking top_k
|
||||||
for (int32_t i = 0; i != num_hyps; ++i) {
|
for (int32_t i = 0; i != num_hyps; ++i) {
|
||||||
float log_prob = prev[i].log_prob;
|
float log_prob = prev[i].log_prob + prev[i].lm_log_prob;
|
||||||
for (int32_t k = 0; k != vocab_size; ++k, ++p_logprob) {
|
for (int32_t k = 0; k != vocab_size; ++k, ++p_logprob) {
|
||||||
*p_logprob += log_prob;
|
*p_logprob += log_prob;
|
||||||
}
|
}
|
||||||
@@ -141,14 +141,18 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
|
|||||||
int32_t new_token = k % vocab_size;
|
int32_t new_token = k % vocab_size;
|
||||||
|
|
||||||
Hypothesis new_hyp = prev[hyp_index];
|
Hypothesis new_hyp = prev[hyp_index];
|
||||||
|
const float prev_lm_log_prob = new_hyp.lm_log_prob;
|
||||||
if (new_token != 0) {
|
if (new_token != 0) {
|
||||||
new_hyp.ys.push_back(new_token);
|
new_hyp.ys.push_back(new_token);
|
||||||
new_hyp.timestamps.push_back(t + frame_offset);
|
new_hyp.timestamps.push_back(t + frame_offset);
|
||||||
new_hyp.num_trailing_blanks = 0;
|
new_hyp.num_trailing_blanks = 0;
|
||||||
|
if (lm_) {
|
||||||
|
lm_->ComputeLMScore(lm_scale_, &new_hyp);
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
++new_hyp.num_trailing_blanks;
|
++new_hyp.num_trailing_blanks;
|
||||||
}
|
}
|
||||||
new_hyp.log_prob = p_logprob[k];
|
new_hyp.log_prob = p_logprob[k] - prev_lm_log_prob;
|
||||||
hyps.Add(std::move(new_hyp));
|
hyps.Add(std::move(new_hyp));
|
||||||
} // for (auto k : topk)
|
} // for (auto k : topk)
|
||||||
cur.push_back(std::move(hyps));
|
cur.push_back(std::move(hyps));
|
||||||
@@ -156,10 +160,6 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
|
|||||||
} // for (int32_t b = 0; b != batch_size; ++b)
|
} // for (int32_t b = 0; b != batch_size; ++b)
|
||||||
}
|
}
|
||||||
|
|
||||||
if (lm_) {
|
|
||||||
lm_->ComputeLMScore(lm_scale_, model_->ContextSize(), &cur);
|
|
||||||
}
|
|
||||||
|
|
||||||
for (int32_t b = 0; b != batch_size; ++b) {
|
for (int32_t b = 0; b != batch_size; ++b) {
|
||||||
auto &hyps = cur[b];
|
auto &hyps = cur[b];
|
||||||
auto best_hyp = hyps.GetMostProbable(true);
|
auto best_hyp = hyps.GetMostProbable(true);
|
||||||
|
|||||||
@@ -245,4 +245,26 @@ CopyableOrtValue &CopyableOrtValue::operator=(CopyableOrtValue &&other) {
|
|||||||
return *this;
|
return *this;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::vector<CopyableOrtValue> Convert(std::vector<Ort::Value> values) {
|
||||||
|
std::vector<CopyableOrtValue> ans;
|
||||||
|
ans.reserve(values.size());
|
||||||
|
|
||||||
|
for (auto &v : values) {
|
||||||
|
ans.emplace_back(std::move(v));
|
||||||
|
}
|
||||||
|
|
||||||
|
return ans;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<Ort::Value> Convert(std::vector<CopyableOrtValue> values) {
|
||||||
|
std::vector<Ort::Value> ans;
|
||||||
|
ans.reserve(values.size());
|
||||||
|
|
||||||
|
for (auto &v : values) {
|
||||||
|
ans.emplace_back(std::move(v.value));
|
||||||
|
}
|
||||||
|
|
||||||
|
return ans;
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace sherpa_onnx
|
} // namespace sherpa_onnx
|
||||||
|
|||||||
@@ -109,6 +109,10 @@ struct CopyableOrtValue {
|
|||||||
CopyableOrtValue &operator=(CopyableOrtValue &&other);
|
CopyableOrtValue &operator=(CopyableOrtValue &&other);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
std::vector<CopyableOrtValue> Convert(std::vector<Ort::Value> values);
|
||||||
|
|
||||||
|
std::vector<Ort::Value> Convert(std::vector<CopyableOrtValue> values);
|
||||||
|
|
||||||
} // namespace sherpa_onnx
|
} // namespace sherpa_onnx
|
||||||
|
|
||||||
#endif // SHERPA_ONNX_CSRC_ONNX_UTILS_H_
|
#endif // SHERPA_ONNX_CSRC_ONNX_UTILS_H_
|
||||||
|
|||||||
@@ -94,7 +94,7 @@ for a list of pre-trained models to download.
|
|||||||
auto s = recognizer.CreateStream();
|
auto s = recognizer.CreateStream();
|
||||||
s->AcceptWaveform(sampling_rate, samples.data(), samples.size());
|
s->AcceptWaveform(sampling_rate, samples.data(), samples.size());
|
||||||
|
|
||||||
std::vector<float> tail_paddings(static_cast<int>(0.2 * sampling_rate));
|
std::vector<float> tail_paddings(static_cast<int>(0.5 * sampling_rate));
|
||||||
// Note: We can call AcceptWaveform() multiple times.
|
// Note: We can call AcceptWaveform() multiple times.
|
||||||
s->AcceptWaveform(sampling_rate, tail_paddings.data(), tail_paddings.size());
|
s->AcceptWaveform(sampling_rate, tail_paddings.data(), tail_paddings.size());
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user