Add lm decode for the Python API. (#353)
* Add lm decode for the Python API. * fix style. * Fix LogAdd, Shouldn't double lm_log_prob when merge same prefix path * sort the import alphabetically
This commit is contained in:
@@ -115,6 +115,24 @@ def get_args():
|
|||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--lm",
|
||||||
|
type=str,
|
||||||
|
default="",
|
||||||
|
help="""Used only when --decoding-method is modified_beam_search.
|
||||||
|
path of language model.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--lm-scale",
|
||||||
|
type=float,
|
||||||
|
default=0.1,
|
||||||
|
help="""Used only when --decoding-method is modified_beam_search.
|
||||||
|
scale of language model.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--provider",
|
"--provider",
|
||||||
type=str,
|
type=str,
|
||||||
@@ -215,6 +233,8 @@ def main():
|
|||||||
feature_dim=80,
|
feature_dim=80,
|
||||||
decoding_method=args.decoding_method,
|
decoding_method=args.decoding_method,
|
||||||
max_active_paths=args.max_active_paths,
|
max_active_paths=args.max_active_paths,
|
||||||
|
lm=args.lm,
|
||||||
|
lm_scale=args.lm_scale,
|
||||||
hotwords_file=args.hotwords_file,
|
hotwords_file=args.hotwords_file,
|
||||||
hotwords_score=args.hotwords_score,
|
hotwords_score=args.hotwords_score,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -17,11 +17,6 @@ void Hypotheses::Add(Hypothesis hyp) {
|
|||||||
hyps_dict_[key] = std::move(hyp);
|
hyps_dict_[key] = std::move(hyp);
|
||||||
} else {
|
} else {
|
||||||
it->second.log_prob = LogAdd<double>()(it->second.log_prob, hyp.log_prob);
|
it->second.log_prob = LogAdd<double>()(it->second.log_prob, hyp.log_prob);
|
||||||
|
|
||||||
if (it->second.lm_log_prob != 0 && hyp.lm_log_prob != 0) {
|
|
||||||
it->second.lm_log_prob =
|
|
||||||
LogAdd<double>()(it->second.lm_log_prob, hyp.lm_log_prob);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -37,6 +37,7 @@ static void PybindOnlineRecognizerConfig(py::module *m) {
|
|||||||
py::arg("hotwords_score") = 0)
|
py::arg("hotwords_score") = 0)
|
||||||
.def_readwrite("feat_config", &PyClass::feat_config)
|
.def_readwrite("feat_config", &PyClass::feat_config)
|
||||||
.def_readwrite("model_config", &PyClass::model_config)
|
.def_readwrite("model_config", &PyClass::model_config)
|
||||||
|
.def_readwrite("lm_config", &PyClass::lm_config)
|
||||||
.def_readwrite("endpoint_config", &PyClass::endpoint_config)
|
.def_readwrite("endpoint_config", &PyClass::endpoint_config)
|
||||||
.def_readwrite("enable_endpoint", &PyClass::enable_endpoint)
|
.def_readwrite("enable_endpoint", &PyClass::enable_endpoint)
|
||||||
.def_readwrite("decoding_method", &PyClass::decoding_method)
|
.def_readwrite("decoding_method", &PyClass::decoding_method)
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ from typing import List, Optional
|
|||||||
from _sherpa_onnx import (
|
from _sherpa_onnx import (
|
||||||
EndpointConfig,
|
EndpointConfig,
|
||||||
FeatureExtractorConfig,
|
FeatureExtractorConfig,
|
||||||
|
OnlineLMConfig,
|
||||||
OnlineModelConfig,
|
OnlineModelConfig,
|
||||||
OnlineParaformerModelConfig,
|
OnlineParaformerModelConfig,
|
||||||
OnlineRecognizer as _Recognizer,
|
OnlineRecognizer as _Recognizer,
|
||||||
@@ -46,6 +47,8 @@ class OnlineRecognizer(object):
|
|||||||
hotwords_file: str = "",
|
hotwords_file: str = "",
|
||||||
provider: str = "cpu",
|
provider: str = "cpu",
|
||||||
model_type: str = "",
|
model_type: str = "",
|
||||||
|
lm: str = "",
|
||||||
|
lm_scale: float = 0.1,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Please refer to
|
Please refer to
|
||||||
@@ -137,10 +140,22 @@ class OnlineRecognizer(object):
|
|||||||
"Please use --decoding-method=modified_beam_search when using "
|
"Please use --decoding-method=modified_beam_search when using "
|
||||||
f"--hotwords-file. Currently given: {decoding_method}"
|
f"--hotwords-file. Currently given: {decoding_method}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if lm and decoding_method != "modified_beam_search":
|
||||||
|
raise ValueError(
|
||||||
|
"Please use --decoding-method=modified_beam_search when using "
|
||||||
|
f"--lm. Currently given: {decoding_method}"
|
||||||
|
)
|
||||||
|
|
||||||
|
lm_config = OnlineLMConfig(
|
||||||
|
model=lm,
|
||||||
|
scale=lm_scale,
|
||||||
|
)
|
||||||
|
|
||||||
recognizer_config = OnlineRecognizerConfig(
|
recognizer_config = OnlineRecognizerConfig(
|
||||||
feat_config=feat_config,
|
feat_config=feat_config,
|
||||||
model_config=model_config,
|
model_config=model_config,
|
||||||
|
lm_config=lm_config,
|
||||||
endpoint_config=endpoint_config,
|
endpoint_config=endpoint_config,
|
||||||
enable_endpoint=enable_endpoint_detection,
|
enable_endpoint=enable_endpoint_detection,
|
||||||
decoding_method=decoding_method,
|
decoding_method=decoding_method,
|
||||||
|
|||||||
Reference in New Issue
Block a user