diff --git a/sherpa-onnx/csrc/online-recognizer-transducer-impl.h b/sherpa-onnx/csrc/online-recognizer-transducer-impl.h index 4ce51979..8370397b 100644 --- a/sherpa-onnx/csrc/online-recognizer-transducer-impl.h +++ b/sherpa-onnx/csrc/online-recognizer-transducer-impl.h @@ -382,14 +382,13 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { } } - // reset encoder states - // s->SetStates(model_->GetEncoderInitStates()); - auto r = decoder_->GetEmptyResult(); auto last_result = s->GetResult(); - // if last result is not empty, then - // truncate all last hyps and save as the context for next result + if (static_cast(last_result.tokens.size()) > context_size) { + // if last result is not empty, then + // truncate all last hyps and save as the 'ys' context for next result + // (the encoder state buffers are kept) for (const auto &it : last_result.hyps) { auto h = it.second; r.hyps.Add({std::vector(h.ys.end() - context_size, @@ -399,6 +398,11 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { r.tokens = std::vector (last_result.tokens.end() - context_size, last_result.tokens.end()); + } else { + if(config_.reset_encoder) { + // reset encoder states, use blanks as 'ys' context + s->SetStates(model_->GetEncoderInitStates()); + } } // but reset all contextual biasing graph states to root diff --git a/sherpa-onnx/csrc/online-recognizer.cc b/sherpa-onnx/csrc/online-recognizer.cc index 4ccc939d..a2c76d70 100644 --- a/sherpa-onnx/csrc/online-recognizer.cc +++ b/sherpa-onnx/csrc/online-recognizer.cc @@ -121,6 +121,10 @@ void OnlineRecognizerConfig::Register(ParseOptions *po) { "rule-fars", &rule_fars, "If not empty, it specifies fst archives for inverse text normalization. " "If there are multiple archives, they are separated by a comma."); + + po->Register("reset-encoder", &reset_encoder, + "True to reset encoder_state on an endpoint after empty segment." + "Done in `Reset()` method, after an endpoint was detected."); } bool OnlineRecognizerConfig::Validate() const { @@ -198,7 +202,8 @@ std::string OnlineRecognizerConfig::ToString() const { os << "blank_penalty=" << blank_penalty << ", "; os << "temperature_scale=" << temperature_scale << ", "; os << "rule_fsts=\"" << rule_fsts << "\", "; - os << "rule_fars=\"" << rule_fars << "\")"; + os << "rule_fars=\"" << rule_fars << "\", "; + os << "reset_encoder=\"" << (reset_encoder ? "True" : "False") << "\")"; return os.str(); } diff --git a/sherpa-onnx/csrc/online-recognizer.h b/sherpa-onnx/csrc/online-recognizer.h index 8854fbd2..5936e0df 100644 --- a/sherpa-onnx/csrc/online-recognizer.h +++ b/sherpa-onnx/csrc/online-recognizer.h @@ -79,6 +79,7 @@ struct OnlineRecognizerConfig { OnlineLMConfig lm_config; EndpointConfig endpoint_config; OnlineCtcFstDecoderConfig ctc_fst_decoder_config; + bool enable_endpoint = true; std::string decoding_method = "greedy_search"; @@ -101,6 +102,11 @@ struct OnlineRecognizerConfig { // If there are multiple FST archives, they are applied from left to right. std::string rule_fars; + // True to reset encoder_state on an endpoint after empty segment. + // Done in `Reset()` method, after an endpoint was detected, + // currently only in `OnlineRecognizerTransducerImpl`. + bool reset_encoder = false; + /// used only for modified_beam_search, if hotwords_buf is non-empty, /// the hotwords will be loaded from the buffered string instead of from the /// "hotwords_file" @@ -116,7 +122,8 @@ struct OnlineRecognizerConfig { bool enable_endpoint, const std::string &decoding_method, int32_t max_active_paths, const std::string &hotwords_file, float hotwords_score, float blank_penalty, float temperature_scale, - const std::string &rule_fsts, const std::string &rule_fars) + const std::string &rule_fsts, const std::string &rule_fars, + bool reset_encoder) : feat_config(feat_config), model_config(model_config), lm_config(lm_config), @@ -130,7 +137,8 @@ struct OnlineRecognizerConfig { blank_penalty(blank_penalty), temperature_scale(temperature_scale), rule_fsts(rule_fsts), - rule_fars(rule_fars) {} + rule_fars(rule_fars), + reset_encoder(reset_encoder) {} void Register(ParseOptions *po); bool Validate() const; diff --git a/sherpa-onnx/python/csrc/online-recognizer.cc b/sherpa-onnx/python/csrc/online-recognizer.cc index fe6cd454..38d4e776 100644 --- a/sherpa-onnx/python/csrc/online-recognizer.cc +++ b/sherpa-onnx/python/csrc/online-recognizer.cc @@ -58,7 +58,7 @@ static void PybindOnlineRecognizerConfig(py::module *m) { const OnlineLMConfig &, const EndpointConfig &, const OnlineCtcFstDecoderConfig &, bool, const std::string &, int32_t, const std::string &, float, - float, float, const std::string &, const std::string &>(), + float, float, const std::string &, const std::string &, bool>(), py::arg("feat_config"), py::arg("model_config"), py::arg("lm_config") = OnlineLMConfig(), py::arg("endpoint_config") = EndpointConfig(), @@ -67,7 +67,7 @@ static void PybindOnlineRecognizerConfig(py::module *m) { py::arg("max_active_paths") = 4, py::arg("hotwords_file") = "", py::arg("hotwords_score") = 0, py::arg("blank_penalty") = 0.0, py::arg("temperature_scale") = 2.0, py::arg("rule_fsts") = "", - py::arg("rule_fars") = "") + py::arg("rule_fars") = "", py::arg("reset_encoder") = false) .def_readwrite("feat_config", &PyClass::feat_config) .def_readwrite("model_config", &PyClass::model_config) .def_readwrite("lm_config", &PyClass::lm_config) @@ -82,6 +82,7 @@ static void PybindOnlineRecognizerConfig(py::module *m) { .def_readwrite("temperature_scale", &PyClass::temperature_scale) .def_readwrite("rule_fsts", &PyClass::rule_fsts) .def_readwrite("rule_fars", &PyClass::rule_fars) + .def_readwrite("reset_encoder", &PyClass::reset_encoder) .def("__str__", &PyClass::ToString); } diff --git a/sherpa-onnx/python/csrc/sherpa-onnx.cc b/sherpa-onnx/python/csrc/sherpa-onnx.cc index 4552bdfa..c00b4644 100644 --- a/sherpa-onnx/python/csrc/sherpa-onnx.cc +++ b/sherpa-onnx/python/csrc/sherpa-onnx.cc @@ -75,6 +75,15 @@ PYBIND11_MODULE(_sherpa_onnx, m) { #if SHERPA_ONNX_ENABLE_TTS == 1 PybindOfflineTts(&m); +#else + /* Define "empty" TTS sybmbols */ + m.attr("OfflineTtsKokoroModelConfig") = py::none(); + m.attr("OfflineTtsMatchaModelConfig") = py::none(); + m.attr("OfflineTtsModelConfig") = py::none(); + m.attr("OfflineTtsVitsModelConfig") = py::none(); + m.attr("GeneratedAudio") = py::none(); + m.attr("OfflineTtsConfig") = py::none(); + m.attr("OfflineTts") = py::none(); #endif PybindSpeakerEmbeddingExtractor(&m); @@ -85,6 +94,16 @@ PYBIND11_MODULE(_sherpa_onnx, m) { PybindFastClustering(&m); PybindOfflineSpeakerDiarizationResult(&m); PybindOfflineSpeakerDiarization(&m); +#else + /* Define "empty" diarization sybmbols */ + m.attr("FastClusteringConfig") = py::none(); + m.attr("FastClustering") = py::none(); + m.attr("OfflineSpeakerDiarizationSegment") = py::none(); + m.attr("OfflineSpeakerDiarizationResult") = py::none(); + m.attr("OfflineSpeakerSegmentationPyannoteModelConfig") = py::none(); + m.attr("OfflineSpeakerSegmentationModelConfig") = py::none(); + m.attr("OfflineSpeakerDiarizationConfig") = py::none(); + m.attr("OfflineSpeakerDiarization") = py::none(); #endif PybindAlsa(&m); diff --git a/sherpa-onnx/python/sherpa_onnx/online_recognizer.py b/sherpa-onnx/python/sherpa_onnx/online_recognizer.py index 77831de7..78e383cc 100644 --- a/sherpa-onnx/python/sherpa_onnx/online_recognizer.py +++ b/sherpa-onnx/python/sherpa_onnx/online_recognizer.py @@ -68,6 +68,7 @@ class OnlineRecognizer(object): lm_scale: float = 0.1, lm_shallow_fusion: bool = True, temperature_scale: float = 2.0, + reset_encoder: bool = False, debug: bool = False, rule_fsts: str = "", rule_fars: str = "", @@ -162,6 +163,10 @@ class OnlineRecognizer(object): Temperature scaling for output symbol confidence estiamation. It affects only confidence values, the decoding uses the original logits without temperature. + reset_encoder: + True to reset `encoder_state` on an endpoint after empty segment. + Done in `Reset()` method, after an endpoint was detected, + currently only in `OnlineRecognizerTransducerImpl`. model_type: Online transducer model type. Valid values are: conformer, lstm, zipformer, zipformer2. All other values lead to loading the model twice. @@ -305,6 +310,7 @@ class OnlineRecognizer(object): temperature_scale=temperature_scale, rule_fsts=rule_fsts, rule_fars=rule_fars, + reset_encoder=reset_encoder, ) self.recognizer = _Recognizer(recognizer_config)