Encode hotwords in C++ side (#828)

* Encode hotwords in C++ side
This commit is contained in:
Wei Kang
2024-05-20 19:41:36 +08:00
committed by GitHub
parent 8af2af8466
commit b012b78ceb
43 changed files with 714 additions and 102 deletions

View File

@@ -36,7 +36,8 @@ void PybindOfflineModelConfig(py::module *m) {
const OfflineTdnnModelConfig &,
const OfflineZipformerCtcModelConfig &,
const OfflineWenetCtcModelConfig &, const std::string &,
int32_t, bool, const std::string &, const std::string &>(),
int32_t, bool, const std::string &, const std::string &,
const std::string &, const std::string &>(),
py::arg("transducer") = OfflineTransducerModelConfig(),
py::arg("paraformer") = OfflineParaformerModelConfig(),
py::arg("nemo_ctc") = OfflineNemoEncDecCtcModelConfig(),
@@ -45,7 +46,8 @@ void PybindOfflineModelConfig(py::module *m) {
py::arg("zipformer_ctc") = OfflineZipformerCtcModelConfig(),
py::arg("wenet_ctc") = OfflineWenetCtcModelConfig(),
py::arg("tokens"), py::arg("num_threads"), py::arg("debug") = false,
py::arg("provider") = "cpu", py::arg("model_type") = "")
py::arg("provider") = "cpu", py::arg("model_type") = "",
py::arg("modeling_unit") = "cjkchar", py::arg("bpe_vocab") = "")
.def_readwrite("transducer", &PyClass::transducer)
.def_readwrite("paraformer", &PyClass::paraformer)
.def_readwrite("nemo_ctc", &PyClass::nemo_ctc)
@@ -58,6 +60,8 @@ void PybindOfflineModelConfig(py::module *m) {
.def_readwrite("debug", &PyClass::debug)
.def_readwrite("provider", &PyClass::provider)
.def_readwrite("model_type", &PyClass::model_type)
.def_readwrite("modeling_unit", &PyClass::modeling_unit)
.def_readwrite("bpe_vocab", &PyClass::bpe_vocab)
.def("validate", &PyClass::Validate)
.def("__str__", &PyClass::ToString);
}

View File

@@ -32,6 +32,7 @@ void PybindOnlineModelConfig(py::module *m) {
const OnlineZipformer2CtcModelConfig &,
const OnlineNeMoCtcModelConfig &, const std::string &,
int32_t, int32_t, bool, const std::string &,
const std::string &, const std::string &,
const std::string &>(),
py::arg("transducer") = OnlineTransducerModelConfig(),
py::arg("paraformer") = OnlineParaformerModelConfig(),
@@ -40,7 +41,8 @@ void PybindOnlineModelConfig(py::module *m) {
py::arg("nemo_ctc") = OnlineNeMoCtcModelConfig(), py::arg("tokens"),
py::arg("num_threads"), py::arg("warm_up") = 0,
py::arg("debug") = false, py::arg("provider") = "cpu",
py::arg("model_type") = "")
py::arg("model_type") = "", py::arg("modeling_unit") = "",
py::arg("bpe_vocab") = "")
.def_readwrite("transducer", &PyClass::transducer)
.def_readwrite("paraformer", &PyClass::paraformer)
.def_readwrite("wenet_ctc", &PyClass::wenet_ctc)
@@ -51,6 +53,8 @@ void PybindOnlineModelConfig(py::module *m) {
.def_readwrite("debug", &PyClass::debug)
.def_readwrite("provider", &PyClass::provider)
.def_readwrite("model_type", &PyClass::model_type)
.def_readwrite("modeling_unit", &PyClass::modeling_unit)
.def_readwrite("bpe_vocab", &PyClass::bpe_vocab)
.def("validate", &PyClass::Validate)
.def("__str__", &PyClass::ToString);
}

View File

@@ -49,6 +49,8 @@ class OfflineRecognizer(object):
hotwords_file: str = "",
hotwords_score: float = 1.5,
blank_penalty: float = 0.0,
modeling_unit: str = "cjkchar",
bpe_vocab: str = "",
debug: bool = False,
provider: str = "cpu",
model_type: str = "transducer",
@@ -91,6 +93,16 @@ class OfflineRecognizer(object):
hotwords_file is given with modified_beam_search as decoding method.
blank_penalty:
The penalty applied on blank symbol during decoding.
modeling_unit:
The modeling unit of the model, commonly used units are bpe, cjkchar,
cjkchar+bpe, etc. Currently, it is needed only when hotwords are
provided, we need it to encode the hotwords into token sequence.
and the modeling unit is bpe or cjkchar+bpe.
bpe_vocab:
The vocabulary generated by google's sentencepiece program.
It is a file has two columns, one is the token, the other is
the log probability, you can get it from the directory where
your bpe model is generated. Only used when hotwords provided
debug:
True to show debug messages.
provider:
@@ -107,6 +119,8 @@ class OfflineRecognizer(object):
num_threads=num_threads,
debug=debug,
provider=provider,
modeling_unit=modeling_unit,
bpe_vocab=bpe_vocab,
model_type=model_type,
)

View File

@@ -58,6 +58,8 @@ class OnlineRecognizer(object):
hotwords_file: str = "",
provider: str = "cpu",
model_type: str = "",
modeling_unit: str = "cjkchar",
bpe_vocab: str = "",
lm: str = "",
lm_scale: float = 0.1,
temperature_scale: float = 2.0,
@@ -136,6 +138,16 @@ class OnlineRecognizer(object):
model_type:
Online transducer model type. Valid values are: conformer, lstm,
zipformer, zipformer2. All other values lead to loading the model twice.
modeling_unit:
The modeling unit of the model, commonly used units are bpe, cjkchar,
cjkchar+bpe, etc. Currently, it is needed only when hotwords are
provided, we need it to encode the hotwords into token sequence.
bpe_vocab:
The vocabulary generated by google's sentencepiece program.
It is a file has two columns, one is the token, the other is
the log probability, you can get it from the directory where
your bpe model is generated. Only used when hotwords provided
and the modeling unit is bpe or cjkchar+bpe.
"""
self = cls.__new__(cls)
_assert_file_exists(tokens)
@@ -157,6 +169,8 @@ class OnlineRecognizer(object):
num_threads=num_threads,
provider=provider,
model_type=model_type,
modeling_unit=modeling_unit,
bpe_vocab=bpe_vocab,
debug=debug,
)