Support contextual-biasing for streaming model (#184)
* Support contextual-biasing for streaming model * The whole pipeline runs normally * Fix comments
This commit is contained in:
@@ -9,6 +9,7 @@
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "sherpa-onnx/csrc/log.h"
|
||||
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
@@ -62,6 +63,12 @@ void OnlineTransducerModifiedBeamSearchDecoder::StripLeadingBlanks(
|
||||
// Convenience overload without per-stream state: moves `encoder_out` into the
// three-argument Decode() and passes nullptr for the `OnlineStream **ss`
// argument. In that overload the context graph is only consulted when
// `ss != nullptr`, so calling this version decodes without contextual biasing.
// Results are written through `result`, exactly as in the forwarded call.
void OnlineTransducerModifiedBeamSearchDecoder::Decode(
|
||||
Ort::Value encoder_out,
|
||||
std::vector<OnlineTransducerDecoderResult> *result) {
|
||||
Decode(std::move(encoder_out), nullptr, result);
|
||||
}
|
||||
|
||||
void OnlineTransducerModifiedBeamSearchDecoder::Decode(
|
||||
Ort::Value encoder_out, OnlineStream **ss,
|
||||
std::vector<OnlineTransducerDecoderResult> *result) {
|
||||
std::vector<int64_t> encoder_out_shape =
|
||||
encoder_out.GetTensorTypeAndShapeInfo().GetShape();
|
||||
|
||||
@@ -74,6 +81,7 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
|
||||
}
|
||||
|
||||
int32_t batch_size = static_cast<int32_t>(encoder_out_shape[0]);
|
||||
|
||||
int32_t num_frames = static_cast<int32_t>(encoder_out_shape[1]);
|
||||
int32_t vocab_size = model_->VocabSize();
|
||||
|
||||
@@ -142,18 +150,27 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
|
||||
|
||||
Hypothesis new_hyp = prev[hyp_index];
|
||||
const float prev_lm_log_prob = new_hyp.lm_log_prob;
|
||||
float context_score = 0;
|
||||
auto context_state = new_hyp.context_state;
|
||||
|
||||
if (new_token != 0) {
|
||||
new_hyp.ys.push_back(new_token);
|
||||
new_hyp.timestamps.push_back(t + frame_offset);
|
||||
new_hyp.num_trailing_blanks = 0;
|
||||
if (ss != nullptr && ss[b]->GetContextGraph() != nullptr) {
|
||||
auto context_res = ss[b]->GetContextGraph()->ForwardOneStep(
|
||||
context_state, new_token);
|
||||
context_score = context_res.first;
|
||||
new_hyp.context_state = context_res.second;
|
||||
}
|
||||
if (lm_) {
|
||||
lm_->ComputeLMScore(lm_scale_, &new_hyp);
|
||||
}
|
||||
} else {
|
||||
++new_hyp.num_trailing_blanks;
|
||||
}
|
||||
new_hyp.log_prob =
|
||||
p_logprob[k] - prev_lm_log_prob; // log_prob only includes the
|
||||
new_hyp.log_prob = p_logprob[k] + context_score -
|
||||
prev_lm_log_prob; // log_prob only includes the
|
||||
// score of the transducer plus the context-graph score
|
||||
hyps.Add(std::move(new_hyp));
|
||||
} // for (auto k : topk)
|
||||
|
||||
Reference in New Issue
Block a user