Support contextual-biasing for streaming model (#184)

* Support contextual-biasing for streaming model

* The whole pipeline runs normally

* Fix comments
This commit is contained in:
Wei Kang
2023-06-30 16:46:24 +08:00
committed by GitHub
parent b2e0c4c9c2
commit 513dfaa552
10 changed files with 238 additions and 22 deletions

View File

@@ -88,6 +88,9 @@ void OnlineRecognizerConfig::Register(ParseOptions *po) {
"True to enable endpoint detection. False to disable it.");
po->Register("max-active-paths", &max_active_paths,
"beam size used in modified beam search.");
po->Register("context-score", &context_score,
"The bonus score for each token in context word/phrase. "
"Used only when decoding_method is modified_beam_search");
po->Register("decoding-method", &decoding_method,
"decoding method,"
"now support greedy_search and modified_beam_search.");
@@ -115,6 +118,7 @@ std::string OnlineRecognizerConfig::ToString() const {
os << "endpoint_config=" << endpoint_config.ToString() << ", ";
os << "enable_endpoint=" << (enable_endpoint ? "True" : "False") << ", ";
os << "max_active_paths=" << max_active_paths << ", ";
os << "context_score=" << context_score << ", ";
os << "decoding_method=\"" << decoding_method << "\")";
return os.str();
@@ -166,10 +170,37 @@ class OnlineRecognizer::Impl {
}
#endif
void InitOnlineStream(OnlineStream *stream) const {
auto r = decoder_->GetEmptyResult();
if (config_.decoding_method == "modified_beam_search" &&
nullptr != stream->GetContextGraph()) {
// r.hyps has only one element.
for (auto it = r.hyps.begin(); it != r.hyps.end(); ++it) {
it->second.context_state = stream->GetContextGraph()->Root();
}
}
stream->SetResult(r);
stream->SetStates(model_->GetEncoderInitStates());
}
std::unique_ptr<OnlineStream> CreateStream() const {
auto stream = std::make_unique<OnlineStream>(config_.feat_config);
stream->SetResult(decoder_->GetEmptyResult());
stream->SetStates(model_->GetEncoderInitStates());
InitOnlineStream(stream.get());
return stream;
}
std::unique_ptr<OnlineStream> CreateStream(
const std::vector<std::vector<int32_t>> &contexts) const {
// We create context_graph at this level, because we might have default
// context_graph(will be added later if needed) that belongs to the whole
// model rather than each stream.
auto context_graph =
std::make_shared<ContextGraph>(contexts, config_.context_score);
auto stream =
std::make_unique<OnlineStream>(config_.feat_config, context_graph);
InitOnlineStream(stream.get());
return stream;
}
@@ -188,8 +219,12 @@ class OnlineRecognizer::Impl {
std::vector<float> features_vec(n * chunk_size * feature_dim);
std::vector<std::vector<Ort::Value>> states_vec(n);
std::vector<int64_t> all_processed_frames(n);
bool has_context_graph = false;
for (int32_t i = 0; i != n; ++i) {
if (!has_context_graph && ss[i]->GetContextGraph())
has_context_graph = true;
const auto num_processed_frames = ss[i]->GetNumProcessedFrames();
std::vector<float> features =
ss[i]->GetFrames(num_processed_frames, chunk_size);
@@ -226,7 +261,11 @@ class OnlineRecognizer::Impl {
auto pair = model_->RunEncoder(std::move(x), std::move(states),
std::move(processed_frames));
decoder_->Decode(std::move(pair.first), &results);
if (has_context_graph) {
decoder_->Decode(std::move(pair.first), ss, &results);
} else {
decoder_->Decode(std::move(pair.first), &results);
}
std::vector<std::vector<Ort::Value>> next_states =
model_->UnStackStates(pair.second);
@@ -297,6 +336,11 @@ std::unique_ptr<OnlineStream> OnlineRecognizer::CreateStream() const {
return impl_->CreateStream();
}
std::unique_ptr<OnlineStream> OnlineRecognizer::CreateStream(
const std::vector<std::vector<int32_t>> &context_list) const {
return impl_->CreateStream(context_list);
}
bool OnlineRecognizer::IsReady(OnlineStream *s) const {
return impl_->IsReady(s);
}