diff --git a/sherpa-onnx/python/sherpa_onnx/offline_recognizer.py b/sherpa-onnx/python/sherpa_onnx/offline_recognizer.py index 2fade069..f0e9a45f 100644 --- a/sherpa-onnx/python/sherpa_onnx/offline_recognizer.py +++ b/sherpa-onnx/python/sherpa_onnx/offline_recognizer.py @@ -6,6 +6,7 @@ from typing import List, Optional from _sherpa_onnx import ( FeatureExtractorConfig, OfflineCtcFstDecoderConfig, + OfflineLMConfig, OfflineModelConfig, OfflineNemoEncDecCtcModelConfig, OfflineParaformerModelConfig, @@ -56,6 +57,8 @@ class OfflineRecognizer(object): model_type: str = "transducer", rule_fsts: str = "", rule_fars: str = "", + lm: str = "", + lm_scale: float = 0.1, ): """ Please refer to @@ -143,9 +146,21 @@ class OfflineRecognizer(object): f"--hotwords-file. Currently given: {decoding_method}" ) + if lm and decoding_method != "modified_beam_search": + raise ValueError( + "Please use --decoding-method=modified_beam_search when using " + f"--lm. Currently given: {decoding_method}" + ) + + lm_config = OfflineLMConfig( + model=lm, + scale=lm_scale, + ) + recognizer_config = OfflineRecognizerConfig( feat_config=feat_config, model_config=model_config, + lm_config=lm_config, decoding_method=decoding_method, max_active_paths=max_active_paths, hotwords_file=hotwords_file,