online-transducer: reset the encoder toghter with 2 previous output symbols (non-blank) (#2129)

* online-transducer: reset the encoder toghter with 2 previous output symbols (non-blank)

- added `reset_encoder` boolean member into the OnlineRecognizerConfig class
- by default the encoder is not reset

* pybind11, adding empty symbols for disabled modules (tts, diarization)

* reset_encoder, add default value (false) [pybind11]
This commit is contained in:
Karel Vesely
2025-04-24 02:18:11 +02:00
committed by GitHub
parent 921c4370e6
commit 6a1efd8ac2
6 changed files with 53 additions and 10 deletions

View File

@@ -382,14 +382,13 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
} }
} }
// reset encoder states
// s->SetStates(model_->GetEncoderInitStates());
auto r = decoder_->GetEmptyResult(); auto r = decoder_->GetEmptyResult();
auto last_result = s->GetResult(); auto last_result = s->GetResult();
// if last result is not empty, then
// truncate all last hyps and save as the context for next result
if (static_cast<int32_t>(last_result.tokens.size()) > context_size) { if (static_cast<int32_t>(last_result.tokens.size()) > context_size) {
// if last result is not empty, then
// truncate all last hyps and save as the 'ys' context for next result
// (the encoder state buffers are kept)
for (const auto &it : last_result.hyps) { for (const auto &it : last_result.hyps) {
auto h = it.second; auto h = it.second;
r.hyps.Add({std::vector<int64_t>(h.ys.end() - context_size, r.hyps.Add({std::vector<int64_t>(h.ys.end() - context_size,
@@ -399,6 +398,11 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
r.tokens = std::vector<int64_t> (last_result.tokens.end() - context_size, r.tokens = std::vector<int64_t> (last_result.tokens.end() - context_size,
last_result.tokens.end()); last_result.tokens.end());
} else {
if(config_.reset_encoder) {
// reset encoder states, use blanks as 'ys' context
s->SetStates(model_->GetEncoderInitStates());
}
} }
// but reset all contextual biasing graph states to root // but reset all contextual biasing graph states to root

View File

@@ -121,6 +121,10 @@ void OnlineRecognizerConfig::Register(ParseOptions *po) {
"rule-fars", &rule_fars, "rule-fars", &rule_fars,
"If not empty, it specifies fst archives for inverse text normalization. " "If not empty, it specifies fst archives for inverse text normalization. "
"If there are multiple archives, they are separated by a comma."); "If there are multiple archives, they are separated by a comma.");
po->Register("reset-encoder", &reset_encoder,
"True to reset encoder_state on an endpoint after empty segment."
"Done in `Reset()` method, after an endpoint was detected.");
} }
bool OnlineRecognizerConfig::Validate() const { bool OnlineRecognizerConfig::Validate() const {
@@ -198,7 +202,8 @@ std::string OnlineRecognizerConfig::ToString() const {
os << "blank_penalty=" << blank_penalty << ", "; os << "blank_penalty=" << blank_penalty << ", ";
os << "temperature_scale=" << temperature_scale << ", "; os << "temperature_scale=" << temperature_scale << ", ";
os << "rule_fsts=\"" << rule_fsts << "\", "; os << "rule_fsts=\"" << rule_fsts << "\", ";
os << "rule_fars=\"" << rule_fars << "\")"; os << "rule_fars=\"" << rule_fars << "\", ";
os << "reset_encoder=\"" << (reset_encoder ? "True" : "False") << "\")";
return os.str(); return os.str();
} }

View File

@@ -79,6 +79,7 @@ struct OnlineRecognizerConfig {
OnlineLMConfig lm_config; OnlineLMConfig lm_config;
EndpointConfig endpoint_config; EndpointConfig endpoint_config;
OnlineCtcFstDecoderConfig ctc_fst_decoder_config; OnlineCtcFstDecoderConfig ctc_fst_decoder_config;
bool enable_endpoint = true; bool enable_endpoint = true;
std::string decoding_method = "greedy_search"; std::string decoding_method = "greedy_search";
@@ -101,6 +102,11 @@ struct OnlineRecognizerConfig {
// If there are multiple FST archives, they are applied from left to right. // If there are multiple FST archives, they are applied from left to right.
std::string rule_fars; std::string rule_fars;
// True to reset encoder_state on an endpoint after empty segment.
// Done in `Reset()` method, after an endpoint was detected,
// currently only in `OnlineRecognizerTransducerImpl`.
bool reset_encoder = false;
/// used only for modified_beam_search, if hotwords_buf is non-empty, /// used only for modified_beam_search, if hotwords_buf is non-empty,
/// the hotwords will be loaded from the buffered string instead of from the /// the hotwords will be loaded from the buffered string instead of from the
/// "hotwords_file" /// "hotwords_file"
@@ -116,7 +122,8 @@ struct OnlineRecognizerConfig {
bool enable_endpoint, const std::string &decoding_method, bool enable_endpoint, const std::string &decoding_method,
int32_t max_active_paths, const std::string &hotwords_file, int32_t max_active_paths, const std::string &hotwords_file,
float hotwords_score, float blank_penalty, float temperature_scale, float hotwords_score, float blank_penalty, float temperature_scale,
const std::string &rule_fsts, const std::string &rule_fars) const std::string &rule_fsts, const std::string &rule_fars,
bool reset_encoder)
: feat_config(feat_config), : feat_config(feat_config),
model_config(model_config), model_config(model_config),
lm_config(lm_config), lm_config(lm_config),
@@ -130,7 +137,8 @@ struct OnlineRecognizerConfig {
blank_penalty(blank_penalty), blank_penalty(blank_penalty),
temperature_scale(temperature_scale), temperature_scale(temperature_scale),
rule_fsts(rule_fsts), rule_fsts(rule_fsts),
rule_fars(rule_fars) {} rule_fars(rule_fars),
reset_encoder(reset_encoder) {}
void Register(ParseOptions *po); void Register(ParseOptions *po);
bool Validate() const; bool Validate() const;

View File

@@ -58,7 +58,7 @@ static void PybindOnlineRecognizerConfig(py::module *m) {
const OnlineLMConfig &, const EndpointConfig &, const OnlineLMConfig &, const EndpointConfig &,
const OnlineCtcFstDecoderConfig &, bool, const OnlineCtcFstDecoderConfig &, bool,
const std::string &, int32_t, const std::string &, float, const std::string &, int32_t, const std::string &, float,
float, float, const std::string &, const std::string &>(), float, float, const std::string &, const std::string &, bool>(),
py::arg("feat_config"), py::arg("model_config"), py::arg("feat_config"), py::arg("model_config"),
py::arg("lm_config") = OnlineLMConfig(), py::arg("lm_config") = OnlineLMConfig(),
py::arg("endpoint_config") = EndpointConfig(), py::arg("endpoint_config") = EndpointConfig(),
@@ -67,7 +67,7 @@ static void PybindOnlineRecognizerConfig(py::module *m) {
py::arg("max_active_paths") = 4, py::arg("hotwords_file") = "", py::arg("max_active_paths") = 4, py::arg("hotwords_file") = "",
py::arg("hotwords_score") = 0, py::arg("blank_penalty") = 0.0, py::arg("hotwords_score") = 0, py::arg("blank_penalty") = 0.0,
py::arg("temperature_scale") = 2.0, py::arg("rule_fsts") = "", py::arg("temperature_scale") = 2.0, py::arg("rule_fsts") = "",
py::arg("rule_fars") = "") py::arg("rule_fars") = "", py::arg("reset_encoder") = false)
.def_readwrite("feat_config", &PyClass::feat_config) .def_readwrite("feat_config", &PyClass::feat_config)
.def_readwrite("model_config", &PyClass::model_config) .def_readwrite("model_config", &PyClass::model_config)
.def_readwrite("lm_config", &PyClass::lm_config) .def_readwrite("lm_config", &PyClass::lm_config)
@@ -82,6 +82,7 @@ static void PybindOnlineRecognizerConfig(py::module *m) {
.def_readwrite("temperature_scale", &PyClass::temperature_scale) .def_readwrite("temperature_scale", &PyClass::temperature_scale)
.def_readwrite("rule_fsts", &PyClass::rule_fsts) .def_readwrite("rule_fsts", &PyClass::rule_fsts)
.def_readwrite("rule_fars", &PyClass::rule_fars) .def_readwrite("rule_fars", &PyClass::rule_fars)
.def_readwrite("reset_encoder", &PyClass::reset_encoder)
.def("__str__", &PyClass::ToString); .def("__str__", &PyClass::ToString);
} }

View File

@@ -75,6 +75,15 @@ PYBIND11_MODULE(_sherpa_onnx, m) {
#if SHERPA_ONNX_ENABLE_TTS == 1 #if SHERPA_ONNX_ENABLE_TTS == 1
PybindOfflineTts(&m); PybindOfflineTts(&m);
#else
/* Define "empty" TTS sybmbols */
m.attr("OfflineTtsKokoroModelConfig") = py::none();
m.attr("OfflineTtsMatchaModelConfig") = py::none();
m.attr("OfflineTtsModelConfig") = py::none();
m.attr("OfflineTtsVitsModelConfig") = py::none();
m.attr("GeneratedAudio") = py::none();
m.attr("OfflineTtsConfig") = py::none();
m.attr("OfflineTts") = py::none();
#endif #endif
PybindSpeakerEmbeddingExtractor(&m); PybindSpeakerEmbeddingExtractor(&m);
@@ -85,6 +94,16 @@ PYBIND11_MODULE(_sherpa_onnx, m) {
PybindFastClustering(&m); PybindFastClustering(&m);
PybindOfflineSpeakerDiarizationResult(&m); PybindOfflineSpeakerDiarizationResult(&m);
PybindOfflineSpeakerDiarization(&m); PybindOfflineSpeakerDiarization(&m);
#else
/* Define "empty" diarization sybmbols */
m.attr("FastClusteringConfig") = py::none();
m.attr("FastClustering") = py::none();
m.attr("OfflineSpeakerDiarizationSegment") = py::none();
m.attr("OfflineSpeakerDiarizationResult") = py::none();
m.attr("OfflineSpeakerSegmentationPyannoteModelConfig") = py::none();
m.attr("OfflineSpeakerSegmentationModelConfig") = py::none();
m.attr("OfflineSpeakerDiarizationConfig") = py::none();
m.attr("OfflineSpeakerDiarization") = py::none();
#endif #endif
PybindAlsa(&m); PybindAlsa(&m);

View File

@@ -68,6 +68,7 @@ class OnlineRecognizer(object):
lm_scale: float = 0.1, lm_scale: float = 0.1,
lm_shallow_fusion: bool = True, lm_shallow_fusion: bool = True,
temperature_scale: float = 2.0, temperature_scale: float = 2.0,
reset_encoder: bool = False,
debug: bool = False, debug: bool = False,
rule_fsts: str = "", rule_fsts: str = "",
rule_fars: str = "", rule_fars: str = "",
@@ -162,6 +163,10 @@ class OnlineRecognizer(object):
Temperature scaling for output symbol confidence estiamation. Temperature scaling for output symbol confidence estiamation.
It affects only confidence values, the decoding uses the original It affects only confidence values, the decoding uses the original
logits without temperature. logits without temperature.
reset_encoder:
True to reset `encoder_state` on an endpoint after empty segment.
Done in `Reset()` method, after an endpoint was detected,
currently only in `OnlineRecognizerTransducerImpl`.
model_type: model_type:
Online transducer model type. Valid values are: conformer, lstm, Online transducer model type. Valid values are: conformer, lstm,
zipformer, zipformer2. All other values lead to loading the model twice. zipformer, zipformer2. All other values lead to loading the model twice.
@@ -305,6 +310,7 @@ class OnlineRecognizer(object):
temperature_scale=temperature_scale, temperature_scale=temperature_scale,
rule_fsts=rule_fsts, rule_fsts=rule_fsts,
rule_fars=rule_fars, rule_fars=rule_fars,
reset_encoder=reset_encoder,
) )
self.recognizer = _Recognizer(recognizer_config) self.recognizer = _Recognizer(recognizer_config)