add blank_penalty for online transducer (#548)

This commit is contained in:
chiiyeh
2024-01-26 12:12:13 +08:00
committed by GitHub
parent 466a6855c8
commit e7b18a2139
13 changed files with 94 additions and 14 deletions

View File

@@ -33,12 +33,13 @@ static void PybindOnlineRecognizerConfig(py::module *m) {
py::class_<PyClass>(*m, "OnlineRecognizerConfig")
.def(py::init<const FeatureExtractorConfig &, const OnlineModelConfig &,
const OnlineLMConfig &, const EndpointConfig &, bool,
const std::string &, int32_t, const std::string &, float>(),
const std::string &, int32_t, const std::string &, float,
float>(),
py::arg("feat_config"), py::arg("model_config"),
py::arg("lm_config") = OnlineLMConfig(), py::arg("endpoint_config"),
py::arg("enable_endpoint"), py::arg("decoding_method"),
py::arg("max_active_paths") = 4, py::arg("hotwords_file") = "",
py::arg("hotwords_score") = 0)
py::arg("hotwords_score") = 0, py::arg("blank_penalty") = 0.0)
.def_readwrite("feat_config", &PyClass::feat_config)
.def_readwrite("model_config", &PyClass::model_config)
.def_readwrite("lm_config", &PyClass::lm_config)
@@ -48,6 +49,7 @@ static void PybindOnlineRecognizerConfig(py::module *m) {
.def_readwrite("max_active_paths", &PyClass::max_active_paths)
.def_readwrite("hotwords_file", &PyClass::hotwords_file)
.def_readwrite("hotwords_score", &PyClass::hotwords_score)
.def_readwrite("blank_penalty", &PyClass::blank_penalty)
.def("__str__", &PyClass::ToString);
}

View File

@@ -48,6 +48,7 @@ class OnlineRecognizer(object):
decoding_method: str = "greedy_search",
max_active_paths: int = 4,
hotwords_score: float = 1.5,
blank_penalty: float = 0.0,
hotwords_file: str = "",
provider: str = "cpu",
model_type: str = "",
@@ -100,6 +101,8 @@ class OnlineRecognizer(object):
max_active_paths:
Use only when decoding_method is modified_beam_search. It specifies
the maximum number of active paths during beam search.
blank_penalty:
The penalty applied to the blank symbol during decoding.
hotwords_file:
The file containing hotwords, one word/phrase per line; within each
phrase, the bpe/cjkchar tokens are separated by a space.
@@ -172,6 +175,7 @@ class OnlineRecognizer(object):
max_active_paths=max_active_paths,
hotwords_score=hotwords_score,
hotwords_file=hotwords_file,
blank_penalty=blank_penalty,
)
self.recognizer = _Recognizer(recognizer_config)