online-transducer: reset the encoder toghter with 2 previous output symbols (non-blank) (#2129)

* online-transducer: reset the encoder toghter with 2 previous output symbols (non-blank)

- added `reset_encoder` boolean member into the OnlineRecognizerConfig class
- by default the encoder is not reset

* pybind11, adding empty symbols for disabled modules (tts, diarization)

* reset_encoder, add default value (false) [pybind11]
This commit is contained in:
Karel Vesely
2025-04-24 02:18:11 +02:00
committed by GitHub
parent 921c4370e6
commit 6a1efd8ac2
6 changed files with 53 additions and 10 deletions

View File

@@ -58,7 +58,7 @@ static void PybindOnlineRecognizerConfig(py::module *m) {
const OnlineLMConfig &, const EndpointConfig &,
const OnlineCtcFstDecoderConfig &, bool,
const std::string &, int32_t, const std::string &, float,
float, float, const std::string &, const std::string &>(),
float, float, const std::string &, const std::string &, bool>(),
py::arg("feat_config"), py::arg("model_config"),
py::arg("lm_config") = OnlineLMConfig(),
py::arg("endpoint_config") = EndpointConfig(),
@@ -67,7 +67,7 @@ static void PybindOnlineRecognizerConfig(py::module *m) {
py::arg("max_active_paths") = 4, py::arg("hotwords_file") = "",
py::arg("hotwords_score") = 0, py::arg("blank_penalty") = 0.0,
py::arg("temperature_scale") = 2.0, py::arg("rule_fsts") = "",
py::arg("rule_fars") = "")
py::arg("rule_fars") = "", py::arg("reset_encoder") = false)
.def_readwrite("feat_config", &PyClass::feat_config)
.def_readwrite("model_config", &PyClass::model_config)
.def_readwrite("lm_config", &PyClass::lm_config)
@@ -82,6 +82,7 @@ static void PybindOnlineRecognizerConfig(py::module *m) {
.def_readwrite("temperature_scale", &PyClass::temperature_scale)
.def_readwrite("rule_fsts", &PyClass::rule_fsts)
.def_readwrite("rule_fars", &PyClass::rule_fars)
.def_readwrite("reset_encoder", &PyClass::reset_encoder)
.def("__str__", &PyClass::ToString);
}

View File

@@ -75,6 +75,15 @@ PYBIND11_MODULE(_sherpa_onnx, m) {
#if SHERPA_ONNX_ENABLE_TTS == 1
PybindOfflineTts(&m);
#else
/* Define "empty" TTS sybmbols */
m.attr("OfflineTtsKokoroModelConfig") = py::none();
m.attr("OfflineTtsMatchaModelConfig") = py::none();
m.attr("OfflineTtsModelConfig") = py::none();
m.attr("OfflineTtsVitsModelConfig") = py::none();
m.attr("GeneratedAudio") = py::none();
m.attr("OfflineTtsConfig") = py::none();
m.attr("OfflineTts") = py::none();
#endif
PybindSpeakerEmbeddingExtractor(&m);
@@ -85,6 +94,16 @@ PYBIND11_MODULE(_sherpa_onnx, m) {
PybindFastClustering(&m);
PybindOfflineSpeakerDiarizationResult(&m);
PybindOfflineSpeakerDiarization(&m);
#else
/* Define "empty" diarization sybmbols */
m.attr("FastClusteringConfig") = py::none();
m.attr("FastClustering") = py::none();
m.attr("OfflineSpeakerDiarizationSegment") = py::none();
m.attr("OfflineSpeakerDiarizationResult") = py::none();
m.attr("OfflineSpeakerSegmentationPyannoteModelConfig") = py::none();
m.attr("OfflineSpeakerSegmentationModelConfig") = py::none();
m.attr("OfflineSpeakerDiarizationConfig") = py::none();
m.attr("OfflineSpeakerDiarization") = py::none();
#endif
PybindAlsa(&m);