Support contextual-biasing for streaming model (#184)
* Support contextual-biasing for streaming model
* The whole pipeline runs normally
* Fix comments
This commit is contained in:
@@ -88,6 +88,9 @@ void OnlineRecognizerConfig::Register(ParseOptions *po) {
|
||||
"True to enable endpoint detection. False to disable it.");
|
||||
po->Register("max-active-paths", &max_active_paths,
|
||||
"beam size used in modified beam search.");
|
||||
po->Register("context-score", &context_score,
|
||||
"The bonus score for each token in context word/phrase. "
|
||||
"Used only when decoding_method is modified_beam_search");
|
||||
po->Register("decoding-method", &decoding_method,
|
||||
"decoding method,"
|
||||
"now support greedy_search and modified_beam_search.");
|
||||
@@ -115,6 +118,7 @@ std::string OnlineRecognizerConfig::ToString() const {
|
||||
os << "endpoint_config=" << endpoint_config.ToString() << ", ";
|
||||
os << "enable_endpoint=" << (enable_endpoint ? "True" : "False") << ", ";
|
||||
os << "max_active_paths=" << max_active_paths << ", ";
|
||||
os << "context_score=" << context_score << ", ";
|
||||
os << "decoding_method=\"" << decoding_method << "\")";
|
||||
|
||||
return os.str();
|
||||
@@ -166,10 +170,37 @@ class OnlineRecognizer::Impl {
|
||||
}
|
||||
#endif
|
||||
|
||||
// Puts `stream` into its initial decoding state: an empty decoder result and
// the encoder's initial states.
//
// When the configured decoding method is modified_beam_search and the stream
// has a context graph, each hypothesis in the empty result gets its
// context_state set to the root of that graph before the result is stored.
void InitOnlineStream(OnlineStream *stream) const {
  auto empty_result = decoder_->GetEmptyResult();

  const bool needs_context_state =
      config_.decoding_method == "modified_beam_search" &&
      stream->GetContextGraph() != nullptr;

  if (needs_context_state) {
    // The empty result holds exactly one hypothesis.
    for (auto &entry : empty_result.hyps) {
      entry.second.context_state = stream->GetContextGraph()->Root();
    }
  }

  stream->SetResult(empty_result);
  stream->SetStates(model_->GetEncoderInitStates());
}
|
||||
|
||||
std::unique_ptr<OnlineStream> CreateStream() const {
|
||||
auto stream = std::make_unique<OnlineStream>(config_.feat_config);
|
||||
stream->SetResult(decoder_->GetEmptyResult());
|
||||
stream->SetStates(model_->GetEncoderInitStates());
|
||||
InitOnlineStream(stream.get());
|
||||
return stream;
|
||||
}
|
||||
|
||||
std::unique_ptr<OnlineStream> CreateStream(
|
||||
const std::vector<std::vector<int32_t>> &contexts) const {
|
||||
// We create context_graph at this level, because we might have default
|
||||
// context_graph(will be added later if needed) that belongs to the whole
|
||||
// model rather than each stream.
|
||||
auto context_graph =
|
||||
std::make_shared<ContextGraph>(contexts, config_.context_score);
|
||||
auto stream =
|
||||
std::make_unique<OnlineStream>(config_.feat_config, context_graph);
|
||||
InitOnlineStream(stream.get());
|
||||
return stream;
|
||||
}
|
||||
|
||||
@@ -188,8 +219,12 @@ class OnlineRecognizer::Impl {
|
||||
std::vector<float> features_vec(n * chunk_size * feature_dim);
|
||||
std::vector<std::vector<Ort::Value>> states_vec(n);
|
||||
std::vector<int64_t> all_processed_frames(n);
|
||||
bool has_context_graph = false;
|
||||
|
||||
for (int32_t i = 0; i != n; ++i) {
|
||||
if (!has_context_graph && ss[i]->GetContextGraph())
|
||||
has_context_graph = true;
|
||||
|
||||
const auto num_processed_frames = ss[i]->GetNumProcessedFrames();
|
||||
std::vector<float> features =
|
||||
ss[i]->GetFrames(num_processed_frames, chunk_size);
|
||||
@@ -226,7 +261,11 @@ class OnlineRecognizer::Impl {
|
||||
auto pair = model_->RunEncoder(std::move(x), std::move(states),
|
||||
std::move(processed_frames));
|
||||
|
||||
decoder_->Decode(std::move(pair.first), &results);
|
||||
if (has_context_graph) {
|
||||
decoder_->Decode(std::move(pair.first), ss, &results);
|
||||
} else {
|
||||
decoder_->Decode(std::move(pair.first), &results);
|
||||
}
|
||||
|
||||
std::vector<std::vector<Ort::Value>> next_states =
|
||||
model_->UnStackStates(pair.second);
|
||||
@@ -297,6 +336,11 @@ std::unique_ptr<OnlineStream> OnlineRecognizer::CreateStream() const {
|
||||
return impl_->CreateStream();
|
||||
}
|
||||
|
||||
std::unique_ptr<OnlineStream> OnlineRecognizer::CreateStream(
|
||||
const std::vector<std::vector<int32_t>> &context_list) const {
|
||||
return impl_->CreateStream(context_list);
|
||||
}
|
||||
|
||||
// Forwards the readiness query for stream `s` to the implementation object
// and returns its answer.
bool OnlineRecognizer::IsReady(OnlineStream *s) const {
  const bool ready = impl_->IsReady(s);
  return ready;
}
|
||||
|
||||
Reference in New Issue
Block a user