online-transducer: reset the encoder together with 2 previous output symbols (non-blank) (#2129)
* online-transducer: reset the encoder together with 2 previous output symbols (non-blank) - added `reset_encoder` boolean member into the OnlineRecognizerConfig class - by default the encoder is not reset * pybind11, adding empty symbols for disabled modules (tts, diarization) * reset_encoder, add default value (false) [pybind11]
This commit is contained in:
@@ -382,14 +382,13 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
|
||||
}
|
||||
}
|
||||
|
||||
// reset encoder states
|
||||
// s->SetStates(model_->GetEncoderInitStates());
|
||||
|
||||
auto r = decoder_->GetEmptyResult();
|
||||
auto last_result = s->GetResult();
|
||||
// if last result is not empty, then
|
||||
// truncate all last hyps and save as the context for next result
|
||||
|
||||
if (static_cast<int32_t>(last_result.tokens.size()) > context_size) {
|
||||
// if last result is not empty, then
|
||||
// truncate all last hyps and save as the 'ys' context for next result
|
||||
// (the encoder state buffers are kept)
|
||||
for (const auto &it : last_result.hyps) {
|
||||
auto h = it.second;
|
||||
r.hyps.Add({std::vector<int64_t>(h.ys.end() - context_size,
|
||||
@@ -399,6 +398,11 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
|
||||
|
||||
r.tokens = std::vector<int64_t> (last_result.tokens.end() - context_size,
|
||||
last_result.tokens.end());
|
||||
} else {
|
||||
if(config_.reset_encoder) {
|
||||
// reset encoder states, use blanks as 'ys' context
|
||||
s->SetStates(model_->GetEncoderInitStates());
|
||||
}
|
||||
}
|
||||
|
||||
// but reset all contextual biasing graph states to root
|
||||
|
||||
@@ -121,6 +121,10 @@ void OnlineRecognizerConfig::Register(ParseOptions *po) {
|
||||
"rule-fars", &rule_fars,
|
||||
"If not empty, it specifies fst archives for inverse text normalization. "
|
||||
"If there are multiple archives, they are separated by a comma.");
|
||||
|
||||
po->Register("reset-encoder", &reset_encoder,
|
||||
"True to reset encoder_state on an endpoint after empty segment."
|
||||
"Done in `Reset()` method, after an endpoint was detected.");
|
||||
}
|
||||
|
||||
bool OnlineRecognizerConfig::Validate() const {
|
||||
@@ -198,7 +202,8 @@ std::string OnlineRecognizerConfig::ToString() const {
|
||||
os << "blank_penalty=" << blank_penalty << ", ";
|
||||
os << "temperature_scale=" << temperature_scale << ", ";
|
||||
os << "rule_fsts=\"" << rule_fsts << "\", ";
|
||||
os << "rule_fars=\"" << rule_fars << "\")";
|
||||
os << "rule_fars=\"" << rule_fars << "\", ";
|
||||
os << "reset_encoder=\"" << (reset_encoder ? "True" : "False") << "\")";
|
||||
|
||||
return os.str();
|
||||
}
|
||||
|
||||
@@ -79,6 +79,7 @@ struct OnlineRecognizerConfig {
|
||||
OnlineLMConfig lm_config;
|
||||
EndpointConfig endpoint_config;
|
||||
OnlineCtcFstDecoderConfig ctc_fst_decoder_config;
|
||||
|
||||
bool enable_endpoint = true;
|
||||
|
||||
std::string decoding_method = "greedy_search";
|
||||
@@ -101,6 +102,11 @@ struct OnlineRecognizerConfig {
|
||||
// If there are multiple FST archives, they are applied from left to right.
|
||||
std::string rule_fars;
|
||||
|
||||
// True to reset encoder_state on an endpoint after empty segment.
|
||||
// Done in `Reset()` method, after an endpoint was detected,
|
||||
// currently only in `OnlineRecognizerTransducerImpl`.
|
||||
bool reset_encoder = false;
|
||||
|
||||
/// used only for modified_beam_search, if hotwords_buf is non-empty,
|
||||
/// the hotwords will be loaded from the buffered string instead of from the
|
||||
/// "hotwords_file"
|
||||
@@ -116,7 +122,8 @@ struct OnlineRecognizerConfig {
|
||||
bool enable_endpoint, const std::string &decoding_method,
|
||||
int32_t max_active_paths, const std::string &hotwords_file,
|
||||
float hotwords_score, float blank_penalty, float temperature_scale,
|
||||
const std::string &rule_fsts, const std::string &rule_fars)
|
||||
const std::string &rule_fsts, const std::string &rule_fars,
|
||||
bool reset_encoder)
|
||||
: feat_config(feat_config),
|
||||
model_config(model_config),
|
||||
lm_config(lm_config),
|
||||
@@ -130,7 +137,8 @@ struct OnlineRecognizerConfig {
|
||||
blank_penalty(blank_penalty),
|
||||
temperature_scale(temperature_scale),
|
||||
rule_fsts(rule_fsts),
|
||||
rule_fars(rule_fars) {}
|
||||
rule_fars(rule_fars),
|
||||
reset_encoder(reset_encoder) {}
|
||||
|
||||
void Register(ParseOptions *po);
|
||||
bool Validate() const;
|
||||
|
||||
Reference in New Issue
Block a user