Support contextual-biasing for streaming model (#184)

* Support contextual-biasing for streaming model

* The whole pipeline runs normally

* Fix comments
This commit is contained in:
Wei Kang
2023-06-30 16:46:24 +08:00
committed by GitHub
parent b2e0c4c9c2
commit 513dfaa552
10 changed files with 238 additions and 22 deletions

View File

@@ -9,6 +9,7 @@
#include <utility>
#include <vector>
#include "sherpa-onnx/csrc/log.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
namespace sherpa_onnx {
@@ -62,6 +63,12 @@ void OnlineTransducerModifiedBeamSearchDecoder::StripLeadingBlanks(
// Convenience overload for callers that have no per-stream state to pass.
//
// Forwards to the three-argument Decode() with a null stream array. That
// overload only consults a stream's context graph when the array is non-null,
// so decoding through this entry point applies no contextual-biasing score.
//
// @param encoder_out  Encoder output tensor; ownership is moved into the
//                     three-argument overload.
// @param result       Passed through unchanged to the three-argument overload.
void OnlineTransducerModifiedBeamSearchDecoder::Decode(
    Ort::Value encoder_out,
    std::vector<OnlineTransducerDecoderResult> *result) {
  Decode(std::move(encoder_out), /*ss=*/nullptr, result);
}
void OnlineTransducerModifiedBeamSearchDecoder::Decode(
Ort::Value encoder_out, OnlineStream **ss,
std::vector<OnlineTransducerDecoderResult> *result) {
std::vector<int64_t> encoder_out_shape =
encoder_out.GetTensorTypeAndShapeInfo().GetShape();
@@ -74,6 +81,7 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
}
int32_t batch_size = static_cast<int32_t>(encoder_out_shape[0]);
int32_t num_frames = static_cast<int32_t>(encoder_out_shape[1]);
int32_t vocab_size = model_->VocabSize();
@@ -142,18 +150,27 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
Hypothesis new_hyp = prev[hyp_index];
const float prev_lm_log_prob = new_hyp.lm_log_prob;
float context_score = 0;
auto context_state = new_hyp.context_state;
if (new_token != 0) {
new_hyp.ys.push_back(new_token);
new_hyp.timestamps.push_back(t + frame_offset);
new_hyp.num_trailing_blanks = 0;
if (ss != nullptr && ss[b]->GetContextGraph() != nullptr) {
auto context_res = ss[b]->GetContextGraph()->ForwardOneStep(
context_state, new_token);
context_score = context_res.first;
new_hyp.context_state = context_res.second;
}
if (lm_) {
lm_->ComputeLMScore(lm_scale_, &new_hyp);
}
} else {
++new_hyp.num_trailing_blanks;
}
new_hyp.log_prob =
p_logprob[k] - prev_lm_log_prob; // log_prob only includes the
new_hyp.log_prob = p_logprob[k] + context_score -
prev_lm_log_prob; // log_prob only includes the
// score of the transducer
hyps.Add(std::move(new_hyp));
} // for (auto k : topk)