Support contextual-biasing for streaming model (#184)
* Support contextual-biasing for streaming model * The whole pipeline runs normally * Fix comments
This commit is contained in:
@@ -88,6 +88,9 @@ void OnlineRecognizerConfig::Register(ParseOptions *po) {
|
||||
"True to enable endpoint detection. False to disable it.");
|
||||
po->Register("max-active-paths", &max_active_paths,
|
||||
"beam size used in modified beam search.");
|
||||
po->Register("context-score", &context_score,
|
||||
"The bonus score for each token in context word/phrase. "
|
||||
"Used only when decoding_method is modified_beam_search");
|
||||
po->Register("decoding-method", &decoding_method,
|
||||
"decoding method,"
|
||||
"now support greedy_search and modified_beam_search.");
|
||||
@@ -115,6 +118,7 @@ std::string OnlineRecognizerConfig::ToString() const {
|
||||
os << "endpoint_config=" << endpoint_config.ToString() << ", ";
|
||||
os << "enable_endpoint=" << (enable_endpoint ? "True" : "False") << ", ";
|
||||
os << "max_active_paths=" << max_active_paths << ", ";
|
||||
os << "context_score=" << context_score << ", ";
|
||||
os << "decoding_method=\"" << decoding_method << "\")";
|
||||
|
||||
return os.str();
|
||||
@@ -166,10 +170,37 @@ class OnlineRecognizer::Impl {
|
||||
}
|
||||
#endif
|
||||
|
||||
void InitOnlineStream(OnlineStream *stream) const {
|
||||
auto r = decoder_->GetEmptyResult();
|
||||
|
||||
if (config_.decoding_method == "modified_beam_search" &&
|
||||
nullptr != stream->GetContextGraph()) {
|
||||
// r.hyps has only one element.
|
||||
for (auto it = r.hyps.begin(); it != r.hyps.end(); ++it) {
|
||||
it->second.context_state = stream->GetContextGraph()->Root();
|
||||
}
|
||||
}
|
||||
|
||||
stream->SetResult(r);
|
||||
stream->SetStates(model_->GetEncoderInitStates());
|
||||
}
|
||||
|
||||
std::unique_ptr<OnlineStream> CreateStream() const {
|
||||
auto stream = std::make_unique<OnlineStream>(config_.feat_config);
|
||||
stream->SetResult(decoder_->GetEmptyResult());
|
||||
stream->SetStates(model_->GetEncoderInitStates());
|
||||
InitOnlineStream(stream.get());
|
||||
return stream;
|
||||
}
|
||||
|
||||
std::unique_ptr<OnlineStream> CreateStream(
|
||||
const std::vector<std::vector<int32_t>> &contexts) const {
|
||||
// We create context_graph at this level, because we might have default
|
||||
// context_graph(will be added later if needed) that belongs to the whole
|
||||
// model rather than each stream.
|
||||
auto context_graph =
|
||||
std::make_shared<ContextGraph>(contexts, config_.context_score);
|
||||
auto stream =
|
||||
std::make_unique<OnlineStream>(config_.feat_config, context_graph);
|
||||
InitOnlineStream(stream.get());
|
||||
return stream;
|
||||
}
|
||||
|
||||
@@ -188,8 +219,12 @@ class OnlineRecognizer::Impl {
|
||||
std::vector<float> features_vec(n * chunk_size * feature_dim);
|
||||
std::vector<std::vector<Ort::Value>> states_vec(n);
|
||||
std::vector<int64_t> all_processed_frames(n);
|
||||
bool has_context_graph = false;
|
||||
|
||||
for (int32_t i = 0; i != n; ++i) {
|
||||
if (!has_context_graph && ss[i]->GetContextGraph())
|
||||
has_context_graph = true;
|
||||
|
||||
const auto num_processed_frames = ss[i]->GetNumProcessedFrames();
|
||||
std::vector<float> features =
|
||||
ss[i]->GetFrames(num_processed_frames, chunk_size);
|
||||
@@ -226,7 +261,11 @@ class OnlineRecognizer::Impl {
|
||||
auto pair = model_->RunEncoder(std::move(x), std::move(states),
|
||||
std::move(processed_frames));
|
||||
|
||||
decoder_->Decode(std::move(pair.first), &results);
|
||||
if (has_context_graph) {
|
||||
decoder_->Decode(std::move(pair.first), ss, &results);
|
||||
} else {
|
||||
decoder_->Decode(std::move(pair.first), &results);
|
||||
}
|
||||
|
||||
std::vector<std::vector<Ort::Value>> next_states =
|
||||
model_->UnStackStates(pair.second);
|
||||
@@ -297,6 +336,11 @@ std::unique_ptr<OnlineStream> OnlineRecognizer::CreateStream() const {
|
||||
return impl_->CreateStream();
|
||||
}
|
||||
|
||||
std::unique_ptr<OnlineStream> OnlineRecognizer::CreateStream(
|
||||
const std::vector<std::vector<int32_t>> &context_list) const {
|
||||
return impl_->CreateStream(context_list);
|
||||
}
|
||||
|
||||
bool OnlineRecognizer::IsReady(OnlineStream *s) const {
|
||||
return impl_->IsReady(s);
|
||||
}
|
||||
|
||||
@@ -75,7 +75,10 @@ struct OnlineRecognizerConfig {
|
||||
std::string decoding_method = "greedy_search";
|
||||
// now support modified_beam_search and greedy_search
|
||||
|
||||
int32_t max_active_paths = 4; // used only for modified_beam_search
|
||||
// used only for modified_beam_search
|
||||
int32_t max_active_paths = 4;
|
||||
/// used only for modified_beam_search
|
||||
float context_score = 1.5;
|
||||
|
||||
OnlineRecognizerConfig() = default;
|
||||
|
||||
@@ -85,13 +88,14 @@ struct OnlineRecognizerConfig {
|
||||
const EndpointConfig &endpoint_config,
|
||||
bool enable_endpoint,
|
||||
const std::string &decoding_method,
|
||||
int32_t max_active_paths)
|
||||
int32_t max_active_paths, float context_score)
|
||||
: feat_config(feat_config),
|
||||
model_config(model_config),
|
||||
endpoint_config(endpoint_config),
|
||||
enable_endpoint(enable_endpoint),
|
||||
decoding_method(decoding_method),
|
||||
max_active_paths(max_active_paths) {}
|
||||
max_active_paths(max_active_paths),
|
||||
context_score(context_score) {}
|
||||
|
||||
void Register(ParseOptions *po);
|
||||
bool Validate() const;
|
||||
@@ -112,6 +116,10 @@ class OnlineRecognizer {
|
||||
/// Create a stream for decoding.
|
||||
std::unique_ptr<OnlineStream> CreateStream() const;
|
||||
|
||||
// Create a stream with context phrases
|
||||
std::unique_ptr<OnlineStream> CreateStream(
|
||||
const std::vector<std::vector<int32_t>> &context_list) const;
|
||||
|
||||
/**
|
||||
* Return true if the given stream has enough frames for decoding.
|
||||
* Return false otherwise
|
||||
|
||||
@@ -13,8 +13,9 @@ namespace sherpa_onnx {
|
||||
|
||||
class OnlineStream::Impl {
|
||||
public:
|
||||
explicit Impl(const FeatureExtractorConfig &config)
|
||||
: feat_extractor_(config) {}
|
||||
explicit Impl(const FeatureExtractorConfig &config,
|
||||
ContextGraphPtr context_graph)
|
||||
: feat_extractor_(config), context_graph_(context_graph) {}
|
||||
|
||||
void AcceptWaveform(int32_t sampling_rate, const float *waveform, int32_t n) {
|
||||
feat_extractor_.AcceptWaveform(sampling_rate, waveform, n);
|
||||
@@ -54,16 +55,21 @@ class OnlineStream::Impl {
|
||||
|
||||
std::vector<Ort::Value> &GetStates() { return states_; }
|
||||
|
||||
const ContextGraphPtr &GetContextGraph() const { return context_graph_; }
|
||||
|
||||
private:
|
||||
FeatureExtractor feat_extractor_;
|
||||
/// For contextual-biasing
|
||||
ContextGraphPtr context_graph_;
|
||||
int32_t num_processed_frames_ = 0; // before subsampling
|
||||
int32_t start_frame_index_ = 0; // never reset
|
||||
OnlineTransducerDecoderResult result_;
|
||||
std::vector<Ort::Value> states_;
|
||||
};
|
||||
|
||||
OnlineStream::OnlineStream(const FeatureExtractorConfig &config /*= {}*/)
|
||||
: impl_(std::make_unique<Impl>(config)) {}
|
||||
OnlineStream::OnlineStream(const FeatureExtractorConfig &config /*= {}*/,
|
||||
ContextGraphPtr context_graph /*= nullptr */)
|
||||
: impl_(std::make_unique<Impl>(config, context_graph)) {}
|
||||
|
||||
OnlineStream::~OnlineStream() = default;
|
||||
|
||||
@@ -109,4 +115,8 @@ std::vector<Ort::Value> &OnlineStream::GetStates() {
|
||||
return impl_->GetStates();
|
||||
}
|
||||
|
||||
const ContextGraphPtr &OnlineStream::GetContextGraph() const {
|
||||
return impl_->GetContextGraph();
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
@@ -9,6 +9,7 @@
|
||||
#include <vector>
|
||||
|
||||
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||
#include "sherpa-onnx/csrc/context-graph.h"
|
||||
#include "sherpa-onnx/csrc/features.h"
|
||||
#include "sherpa-onnx/csrc/online-transducer-decoder.h"
|
||||
|
||||
@@ -16,7 +17,8 @@ namespace sherpa_onnx {
|
||||
|
||||
class OnlineStream {
|
||||
public:
|
||||
explicit OnlineStream(const FeatureExtractorConfig &config = {});
|
||||
explicit OnlineStream(const FeatureExtractorConfig &config = {},
|
||||
ContextGraphPtr context_graph = nullptr);
|
||||
~OnlineStream();
|
||||
|
||||
/**
|
||||
@@ -71,6 +73,13 @@ class OnlineStream {
|
||||
void SetStates(std::vector<Ort::Value> states);
|
||||
std::vector<Ort::Value> &GetStates();
|
||||
|
||||
/**
|
||||
* Get the context graph corresponding to this stream.
|
||||
*
|
||||
* @return Return the context graph for this stream.
|
||||
*/
|
||||
const ContextGraphPtr &GetContextGraph() const;
|
||||
|
||||
private:
|
||||
class Impl;
|
||||
std::unique_ptr<Impl> impl_;
|
||||
|
||||
@@ -9,6 +9,7 @@
|
||||
|
||||
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||
#include "sherpa-onnx/csrc/hypothesis.h"
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
@@ -45,6 +46,7 @@ struct OnlineTransducerDecoderResult {
|
||||
OnlineTransducerDecoderResult &&other);
|
||||
};
|
||||
|
||||
class OnlineStream;
|
||||
class OnlineTransducerDecoder {
|
||||
public:
|
||||
virtual ~OnlineTransducerDecoder() = default;
|
||||
@@ -76,6 +78,26 @@ class OnlineTransducerDecoder {
|
||||
virtual void Decode(Ort::Value encoder_out,
|
||||
std::vector<OnlineTransducerDecoderResult> *result) = 0;
|
||||
|
||||
/** Run transducer beam search given the output from the encoder model.
|
||||
*
|
||||
* Note: Currently this interface is for contextual-biasing feature which
|
||||
* needs a ContextGraph owned by the OnlineStream.
|
||||
*
|
||||
* @param encoder_out A 3-D tensor of shape (N, T, joiner_dim)
|
||||
* @param ss A list of OnlineStreams.
|
||||
* @param result It is modified in-place.
|
||||
*
|
||||
* @note There is no need to pass encoder_out_length here since for the
|
||||
* online decoding case, each utterance has the same number of frames
|
||||
* and there are no paddings.
|
||||
*/
|
||||
virtual void Decode(Ort::Value encoder_out, OnlineStream **ss,
|
||||
std::vector<OnlineTransducerDecoderResult> *result) {
|
||||
SHERPA_ONNX_LOGE(
|
||||
"This interface is for OnlineTransducerModifiedBeamSearchDecoder.");
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
// used for endpointing. We need to keep decoder_out after reset
|
||||
virtual void UpdateDecoderOut(OnlineTransducerDecoderResult *result) {}
|
||||
};
|
||||
|
||||
@@ -9,6 +9,7 @@
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "sherpa-onnx/csrc/log.h"
|
||||
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
@@ -62,6 +63,12 @@ void OnlineTransducerModifiedBeamSearchDecoder::StripLeadingBlanks(
|
||||
void OnlineTransducerModifiedBeamSearchDecoder::Decode(
|
||||
Ort::Value encoder_out,
|
||||
std::vector<OnlineTransducerDecoderResult> *result) {
|
||||
Decode(std::move(encoder_out), nullptr, result);
|
||||
}
|
||||
|
||||
void OnlineTransducerModifiedBeamSearchDecoder::Decode(
|
||||
Ort::Value encoder_out, OnlineStream **ss,
|
||||
std::vector<OnlineTransducerDecoderResult> *result) {
|
||||
std::vector<int64_t> encoder_out_shape =
|
||||
encoder_out.GetTensorTypeAndShapeInfo().GetShape();
|
||||
|
||||
@@ -74,6 +81,7 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
|
||||
}
|
||||
|
||||
int32_t batch_size = static_cast<int32_t>(encoder_out_shape[0]);
|
||||
|
||||
int32_t num_frames = static_cast<int32_t>(encoder_out_shape[1]);
|
||||
int32_t vocab_size = model_->VocabSize();
|
||||
|
||||
@@ -142,18 +150,27 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
|
||||
|
||||
Hypothesis new_hyp = prev[hyp_index];
|
||||
const float prev_lm_log_prob = new_hyp.lm_log_prob;
|
||||
float context_score = 0;
|
||||
auto context_state = new_hyp.context_state;
|
||||
|
||||
if (new_token != 0) {
|
||||
new_hyp.ys.push_back(new_token);
|
||||
new_hyp.timestamps.push_back(t + frame_offset);
|
||||
new_hyp.num_trailing_blanks = 0;
|
||||
if (ss != nullptr && ss[b]->GetContextGraph() != nullptr) {
|
||||
auto context_res = ss[b]->GetContextGraph()->ForwardOneStep(
|
||||
context_state, new_token);
|
||||
context_score = context_res.first;
|
||||
new_hyp.context_state = context_res.second;
|
||||
}
|
||||
if (lm_) {
|
||||
lm_->ComputeLMScore(lm_scale_, &new_hyp);
|
||||
}
|
||||
} else {
|
||||
++new_hyp.num_trailing_blanks;
|
||||
}
|
||||
new_hyp.log_prob =
|
||||
p_logprob[k] - prev_lm_log_prob; // log_prob only includes the
|
||||
new_hyp.log_prob = p_logprob[k] + context_score -
|
||||
prev_lm_log_prob; // log_prob only includes the
|
||||
// score of the transducer
|
||||
hyps.Add(std::move(new_hyp));
|
||||
} // for (auto k : topk)
|
||||
|
||||
@@ -9,6 +9,7 @@
|
||||
#include <vector>
|
||||
|
||||
#include "sherpa-onnx/csrc/online-lm.h"
|
||||
#include "sherpa-onnx/csrc/online-stream.h"
|
||||
#include "sherpa-onnx/csrc/online-transducer-decoder.h"
|
||||
#include "sherpa-onnx/csrc/online-transducer-model.h"
|
||||
|
||||
@@ -33,6 +34,9 @@ class OnlineTransducerModifiedBeamSearchDecoder
|
||||
void Decode(Ort::Value encoder_out,
|
||||
std::vector<OnlineTransducerDecoderResult> *result) override;
|
||||
|
||||
void Decode(Ort::Value encoder_out, OnlineStream **ss,
|
||||
std::vector<OnlineTransducerDecoderResult> *result) override;
|
||||
|
||||
void UpdateDecoderOut(OnlineTransducerDecoderResult *result) override;
|
||||
|
||||
private:
|
||||
|
||||
Reference in New Issue
Block a user