Support contextual-biasing for streaming model (#184)

* Support contextual-biasing for streaming model

* The whole pipeline runs normally

* Fix comments
This commit is contained in:
Wei Kang
2023-06-30 16:46:24 +08:00
committed by GitHub
parent b2e0c4c9c2
commit 513dfaa552
10 changed files with 238 additions and 22 deletions

View File

@@ -75,7 +75,10 @@ struct OnlineRecognizerConfig {
std::string decoding_method = "greedy_search";
// now support modified_beam_search and greedy_search
int32_t max_active_paths = 4; // used only for modified_beam_search
// used only for modified_beam_search
int32_t max_active_paths = 4;
/// used only for modified_beam_search
float context_score = 1.5;
OnlineRecognizerConfig() = default;
@@ -85,13 +88,14 @@ struct OnlineRecognizerConfig {
const EndpointConfig &endpoint_config,
bool enable_endpoint,
const std::string &decoding_method,
int32_t max_active_paths)
int32_t max_active_paths, float context_score)
: feat_config(feat_config),
model_config(model_config),
endpoint_config(endpoint_config),
enable_endpoint(enable_endpoint),
decoding_method(decoding_method),
max_active_paths(max_active_paths) {}
max_active_paths(max_active_paths),
context_score(context_score) {}
void Register(ParseOptions *po);
bool Validate() const;
@@ -112,6 +116,10 @@ class OnlineRecognizer {
/// Create a stream for decoding.
std::unique_ptr<OnlineStream> CreateStream() const;
// Create a stream with context phrases
std::unique_ptr<OnlineStream> CreateStream(
const std::vector<std::vector<int32_t>> &context_list) const;
/**
* Return true if the given stream has enough frames for decoding.
* Return false otherwise