diff --git a/sherpa-onnx/csrc/offline-recognizer.cc b/sherpa-onnx/csrc/offline-recognizer.cc index c42c2687..5e4835b8 100644 --- a/sherpa-onnx/csrc/offline-recognizer.cc +++ b/sherpa-onnx/csrc/offline-recognizer.cc @@ -46,7 +46,17 @@ bool OfflineRecognizerConfig::Validate() const { max_active_paths); return false; } - if (!lm_config.Validate()) return false; + if (!lm_config.Validate()) { + return false; + } + } + + if (!hotwords_file.empty() && decoding_method != "modified_beam_search") { + SHERPA_ONNX_LOGE( + "Please use --decoding-method=modified_beam_search if you" + " provide --hotwords-file. Given --decoding-method=%s", + decoding_method.c_str()); + return false; } return model_config.Validate(); diff --git a/sherpa-onnx/csrc/online-recognizer-transducer-impl.h b/sherpa-onnx/csrc/online-recognizer-transducer-impl.h index 31d54fe8..7e9f6e29 100644 --- a/sherpa-onnx/csrc/online-recognizer-transducer-impl.h +++ b/sherpa-onnx/csrc/online-recognizer-transducer-impl.h @@ -156,8 +156,9 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { bool has_context_graph = false; for (int32_t i = 0; i != n; ++i) { - if (!has_context_graph && ss[i]->GetContextGraph()) + if (!has_context_graph && ss[i]->GetContextGraph()) { has_context_graph = true; + } const auto num_processed_frames = ss[i]->GetNumProcessedFrames(); std::vector features = diff --git a/sherpa-onnx/csrc/online-recognizer.cc b/sherpa-onnx/csrc/online-recognizer.cc index 687b0211..c3fb728d 100644 --- a/sherpa-onnx/csrc/online-recognizer.cc +++ b/sherpa-onnx/csrc/online-recognizer.cc @@ -102,8 +102,20 @@ bool OnlineRecognizerConfig::Validate() const { max_active_paths); return false; } - if (!lm_config.Validate()) return false; + + if (!lm_config.Validate()) { + return false; + } } + + if (!hotwords_file.empty() && decoding_method != "modified_beam_search") { + SHERPA_ONNX_LOGE( + "Please use --decoding-method=modified_beam_search if you" + " provide --hotwords-file. Given --decoding-method=%s", + decoding_method.c_str()); + return false; + } + return model_config.Validate(); } diff --git a/sherpa-onnx/python/sherpa_onnx/__init__.py b/sherpa-onnx/python/sherpa_onnx/__init__.py index 57a2302e..61158d36 100644 --- a/sherpa-onnx/python/sherpa_onnx/__init__.py +++ b/sherpa-onnx/python/sherpa_onnx/__init__.py @@ -1,5 +1,3 @@ -from typing import Dict, List, Optional - from _sherpa_onnx import ( CircularBuffer, Display, diff --git a/sherpa-onnx/python/sherpa_onnx/offline_recognizer.py b/sherpa-onnx/python/sherpa_onnx/offline_recognizer.py index 6b737be9..e1c82279 100644 --- a/sherpa-onnx/python/sherpa_onnx/offline_recognizer.py +++ b/sherpa-onnx/python/sherpa_onnx/offline_recognizer.py @@ -102,6 +102,12 @@ class OfflineRecognizer(object): feature_dim=feature_dim, ) + if len(hotwords_file) > 0 and decoding_method != "modified_beam_search": + raise ValueError( + "Please use --decoding-method=modified_beam_search when using " + f"--hotwords-file. Currently given: {decoding_method}" + ) + recognizer_config = OfflineRecognizerConfig( feat_config=feat_config, model_config=model_config, diff --git a/sherpa-onnx/python/sherpa_onnx/online_recognizer.py b/sherpa-onnx/python/sherpa_onnx/online_recognizer.py index e4f991a0..eabf99ec 100644 --- a/sherpa-onnx/python/sherpa_onnx/online_recognizer.py +++ b/sherpa-onnx/python/sherpa_onnx/online_recognizer.py @@ -132,6 +132,12 @@ class OnlineRecognizer(object): rule3_min_utterance_length=rule3_min_utterance_length, ) + if len(hotwords_file) > 0 and decoding_method != "modified_beam_search": + raise ValueError( + "Please use --decoding-method=modified_beam_search when using " + f"--hotwords-file. Currently given: {decoding_method}" + ) + recognizer_config = OnlineRecognizerConfig( feat_config=feat_config, model_config=model_config,