Add lm decode for the Python API. (#353)

* Add lm decode for the Python API.

* fix style.

* Fix LogAdd:

	We shouldn't double lm_log_prob when merging paths with the same prefix.

* sort the imports alphabetically
This commit is contained in:
Peng He
2023-10-13 11:15:16 +08:00
committed by GitHub
parent 323f532ad2
commit 4771c9275c
4 changed files with 36 additions and 5 deletions

View File

@@ -115,6 +115,24 @@ def get_args():
""", """,
) )
parser.add_argument(
"--lm",
type=str,
default="",
help="""Used only when --decoding-method is modified_beam_search.
path of language model.
""",
)
parser.add_argument(
"--lm-scale",
type=float,
default=0.1,
help="""Used only when --decoding-method is modified_beam_search.
scale of language model.
""",
)
parser.add_argument( parser.add_argument(
"--provider", "--provider",
type=str, type=str,
@@ -215,6 +233,8 @@ def main():
feature_dim=80, feature_dim=80,
decoding_method=args.decoding_method, decoding_method=args.decoding_method,
max_active_paths=args.max_active_paths, max_active_paths=args.max_active_paths,
lm=args.lm,
lm_scale=args.lm_scale,
hotwords_file=args.hotwords_file, hotwords_file=args.hotwords_file,
hotwords_score=args.hotwords_score, hotwords_score=args.hotwords_score,
) )

View File

@@ -17,11 +17,6 @@ void Hypotheses::Add(Hypothesis hyp) {
hyps_dict_[key] = std::move(hyp); hyps_dict_[key] = std::move(hyp);
} else { } else {
it->second.log_prob = LogAdd<double>()(it->second.log_prob, hyp.log_prob); it->second.log_prob = LogAdd<double>()(it->second.log_prob, hyp.log_prob);
if (it->second.lm_log_prob != 0 && hyp.lm_log_prob != 0) {
it->second.lm_log_prob =
LogAdd<double>()(it->second.lm_log_prob, hyp.lm_log_prob);
}
} }
} }

View File

@@ -37,6 +37,7 @@ static void PybindOnlineRecognizerConfig(py::module *m) {
py::arg("hotwords_score") = 0) py::arg("hotwords_score") = 0)
.def_readwrite("feat_config", &PyClass::feat_config) .def_readwrite("feat_config", &PyClass::feat_config)
.def_readwrite("model_config", &PyClass::model_config) .def_readwrite("model_config", &PyClass::model_config)
.def_readwrite("lm_config", &PyClass::lm_config)
.def_readwrite("endpoint_config", &PyClass::endpoint_config) .def_readwrite("endpoint_config", &PyClass::endpoint_config)
.def_readwrite("enable_endpoint", &PyClass::enable_endpoint) .def_readwrite("enable_endpoint", &PyClass::enable_endpoint)
.def_readwrite("decoding_method", &PyClass::decoding_method) .def_readwrite("decoding_method", &PyClass::decoding_method)

View File

@@ -5,6 +5,7 @@ from typing import List, Optional
from _sherpa_onnx import ( from _sherpa_onnx import (
EndpointConfig, EndpointConfig,
FeatureExtractorConfig, FeatureExtractorConfig,
OnlineLMConfig,
OnlineModelConfig, OnlineModelConfig,
OnlineParaformerModelConfig, OnlineParaformerModelConfig,
OnlineRecognizer as _Recognizer, OnlineRecognizer as _Recognizer,
@@ -46,6 +47,8 @@ class OnlineRecognizer(object):
hotwords_file: str = "", hotwords_file: str = "",
provider: str = "cpu", provider: str = "cpu",
model_type: str = "", model_type: str = "",
lm: str = "",
lm_scale: float = 0.1,
): ):
""" """
Please refer to Please refer to
@@ -137,10 +140,22 @@ class OnlineRecognizer(object):
"Please use --decoding-method=modified_beam_search when using " "Please use --decoding-method=modified_beam_search when using "
f"--hotwords-file. Currently given: {decoding_method}" f"--hotwords-file. Currently given: {decoding_method}"
) )
if lm and decoding_method != "modified_beam_search":
raise ValueError(
"Please use --decoding-method=modified_beam_search when using "
f"--lm. Currently given: {decoding_method}"
)
lm_config = OnlineLMConfig(
model=lm,
scale=lm_scale,
)
recognizer_config = OnlineRecognizerConfig( recognizer_config = OnlineRecognizerConfig(
feat_config=feat_config, feat_config=feat_config,
model_config=model_config, model_config=model_config,
lm_config=lm_config,
endpoint_config=endpoint_config, endpoint_config=endpoint_config,
enable_endpoint=enable_endpoint_detection, enable_endpoint=enable_endpoint_detection,
decoding_method=decoding_method, decoding_method=decoding_method,