diff --git a/.github/scripts/test-offline-ctc.sh b/.github/scripts/test-offline-ctc.sh index 160478f2..747f190d 100755 --- a/.github/scripts/test-offline-ctc.sh +++ b/.github/scripts/test-offline-ctc.sh @@ -98,6 +98,29 @@ for m in model.onnx model.int8.onnx; do done done +curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/hr-files/dict.tar.bz2 +tar xf dict.tar.bz2 +rm dict.tar.bz2 + +curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/hr-files/replace.fst +curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/hr-files/test-hr.wav +curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/hr-files/lexicon.txt + +for m in model.onnx model.int8.onnx; do + for use_itn in 0 1; do + echo "$m $w $use_itn" + time $EXE \ + --tokens=$repo/tokens.txt \ + --sense-voice-model=$repo/$m \ + --sense-voice-use-itn=$use_itn \ + --hr-lexicon=./lexicon.txt \ + --hr-dict-dir=./dict \ + --hr-rule-fsts=./replace.fst \ + ./test-hr.wav + done +done + +rm -rf dict replace.fst test-hr.wav lexicon.txt # test wav reader for non-standard wav files waves=( diff --git a/.github/scripts/test-python.sh b/.github/scripts/test-python.sh index 3704e7fb..b9cbe8b9 100755 --- a/.github/scripts/test-python.sh +++ b/.github/scripts/test-python.sh @@ -95,6 +95,18 @@ rm $name ls -lh $repo python3 ./python-api-examples/offline-sense-voice-ctc-decode-files.py +curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/hr-files/dict.tar.bz2 +tar xf dict.tar.bz2 +rm dict.tar.bz2 + +curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/hr-files/replace.fst +curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/hr-files/test-hr.wav +curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/hr-files/lexicon.txt + +python3 ./python-api-examples/offline-sense-voice-ctc-decode-files-with-hr.py + +rm -rf dict replace.fst test-hr.wav lexicon.txt + if [[ $(uname) == Linux ]]; then # It needs ffmpeg log "generate subtitles (Chinese)" diff --git a/cmake/kaldifst.cmake b/cmake/kaldifst.cmake index 034d8c44..b1de1c69 100644 --- a/cmake/kaldifst.cmake +++ b/cmake/kaldifst.cmake @@ -1,18 +1,18 @@ function(download_kaldifst) include(FetchContent) - set(kaldifst_URL "https://github.com/k2-fsa/kaldifst/archive/refs/tags/v1.7.11.tar.gz") - set(kaldifst_URL2 "https://hf-mirror.com/csukuangfj/sherpa-onnx-cmake-deps/resolve/main/kaldifst-1.7.11.tar.gz") - set(kaldifst_HASH "SHA256=b43b3332faa2961edc730e47995a58cd4e22ead21905d55b0c4a41375b4a525f") + set(kaldifst_URL "https://github.com/k2-fsa/kaldifst/archive/refs/tags/v1.7.13.tar.gz") + set(kaldifst_URL2 "https://hf-mirror.com/csukuangfj/sherpa-onnx-cmake-deps/resolve/main/kaldifst-1.7.13.tar.gz") + set(kaldifst_HASH "SHA256=f8dc15fdaf314d7c9c3551ad8c11ed15da0f34de36446798bbd1b90fa7946eb2") # If you don't have access to the Internet, # please pre-download kaldifst set(possible_file_locations - $ENV{HOME}/Downloads/kaldifst-1.7.11.tar.gz - ${CMAKE_SOURCE_DIR}/kaldifst-1.7.11.tar.gz - ${CMAKE_BINARY_DIR}/kaldifst-1.7.11.tar.gz - /tmp/kaldifst-1.7.11.tar.gz - /star-fj/fangjun/download/github/kaldifst-1.7.11.tar.gz + $ENV{HOME}/Downloads/kaldifst-1.7.13.tar.gz + ${CMAKE_SOURCE_DIR}/kaldifst-1.7.13.tar.gz + ${CMAKE_BINARY_DIR}/kaldifst-1.7.13.tar.gz + /tmp/kaldifst-1.7.13.tar.gz + /star-fj/fangjun/download/github/kaldifst-1.7.13.tar.gz ) foreach(f IN LISTS possible_file_locations) diff --git a/python-api-examples/offline-sense-voice-ctc-decode-files-with-hr.py b/python-api-examples/offline-sense-voice-ctc-decode-files-with-hr.py new file mode 100755 index 00000000..f4952931 --- /dev/null +++ b/python-api-examples/offline-sense-voice-ctc-decode-files-with-hr.py @@ -0,0 +1,75 @@ +#!/usr/bin/env python3 + +""" +This file shows how to use a non-streaming SenseVoice CTC model from +https://github.com/FunAudioLLM/SenseVoice +to decode files. + +Please download model files from +https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models + +For instance, + +wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17.tar.bz2 +tar xvf sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17.tar.bz2 +rm sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17.tar.bz2 + +wget https://github.com/k2-fsa/sherpa-onnx/releases/download/hr-files/dict.tar.bz2 +tar xf dict.tar.bz2 + +wget https://github.com/k2-fsa/sherpa-onnx/releases/download/hr-files/replace.fst +wget https://github.com/k2-fsa/sherpa-onnx/releases/download/hr-files/test-hr.wav +wget https://github.com/k2-fsa/sherpa-onnx/releases/download/hr-files/lexicon.txt +""" + +from pathlib import Path + +import sherpa_onnx +import soundfile as sf + + +def create_recognizer(): + model = "./sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17/model.onnx" + tokens = "./sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17/tokens.txt" + test_wav = "./test-hr.wav" + + if not Path(model).is_file() or not Path(test_wav).is_file(): + raise ValueError( + """Please download model files from + https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models + and + https://github.com/k2-fsa/sherpa-onnx/releases/tag/hr-files + """ + ) + return ( + sherpa_onnx.OfflineRecognizer.from_sense_voice( + model=model, + tokens=tokens, + use_itn=True, + debug=True, + hr_lexicon="./lexicon.txt", + hr_dict_dir="./dict", + hr_rule_fsts="./replace.fst", + ), + test_wav, + ) + + +def main(): + recognizer, wave_filename = create_recognizer() + + audio, sample_rate = sf.read(wave_filename, dtype="float32", always_2d=True) + audio = audio[:, 0] # only use the first channel + + # audio is a 1-D float32 numpy array normalized to the range [-1, 1] + # sample_rate does not need to be 16000 Hz + + stream = recognizer.create_stream() + stream.accept_waveform(sample_rate, audio) + recognizer.decode_stream(stream) + print(wave_filename) + print(stream.result) + + +if __name__ == "__main__": + main() diff --git a/sherpa-onnx/csrc/CMakeLists.txt b/sherpa-onnx/csrc/CMakeLists.txt index a84a9f4a..b9def8c8 100644 --- a/sherpa-onnx/csrc/CMakeLists.txt +++ b/sherpa-onnx/csrc/CMakeLists.txt @@ -20,7 +20,9 @@ set(sources features.cc file-utils.cc fst-utils.cc + homophone-replacer.cc hypothesis.cc + jieba.cc keyword-spotter-impl.cc keyword-spotter.cc offline-ctc-fst-decoder-config.cc diff --git a/sherpa-onnx/csrc/homophone-replacer.cc b/sherpa-onnx/csrc/homophone-replacer.cc new file mode 100644 index 00000000..dc938032 --- /dev/null +++ b/sherpa-onnx/csrc/homophone-replacer.cc @@ -0,0 +1,278 @@ +// sherpa-onnx/csrc/homophone-replacer.cc +// +// Copyright (c) 2025 Xiaomi Corporation + +#include "sherpa-onnx/csrc/homophone-replacer.h" + +#include +#include +#include +#include +#include +#include +#include + +#if __ANDROID_API__ >= 9 +#include "android/asset_manager.h" +#include "android/asset_manager_jni.h" +#endif + +#if __OHOS__ +#include "rawfile/raw_file_manager.h" +#endif + +#include "kaldifst/csrc/text-normalizer.h" +#include "sherpa-onnx/csrc/file-utils.h" +#include "sherpa-onnx/csrc/jieba.h" +#include "sherpa-onnx/csrc/macros.h" +#include "sherpa-onnx/csrc/text-utils.h" + +namespace sherpa_onnx { + +void HomophoneReplacerConfig::Register(ParseOptions *po) { + po->Register("hr-dict-dir", &dict_dir, + "The dict directory for jieba used by HomophoneReplacer"); + + po->Register("hr-lexicon", &lexicon, + "Path to lexicon.txt used by HomophoneReplacer."); + + po->Register("hr-rule-fsts", &rule_fsts, + "Fst files for HomophoneReplacer. If there are multiple, they " + "are separated by a comma. E.g., a.fst,b.fst,c.fst"); +} + +bool HomophoneReplacerConfig::Validate() const { + if (!dict_dir.empty()) { + std::vector required_files = { + "jieba.dict.utf8", "hmm_model.utf8", "user.dict.utf8", + "idf.utf8", "stop_words.utf8", + }; + + for (const auto &f : required_files) { + if (!FileExists(dict_dir + "/" + f)) { + SHERPA_ONNX_LOGE("'%s/%s' does not exist. Please check kokoro-dict-dir", + dict_dir.c_str(), f.c_str()); + return false; + } + } + } + + if (!lexicon.empty() && !FileExists(lexicon)) { + SHERPA_ONNX_LOGE("--hr-lexicon: '%s' does not exist", lexicon.c_str()); + return false; + } + + if (!rule_fsts.empty()) { + std::vector files; + SplitStringToVector(rule_fsts, ",", false, &files); + + if (files.size() > 1) { + SHERPA_ONNX_LOGE("Only 1 file is supported now."); + SHERPA_ONNX_EXIT(-1); + } + + for (const auto &f : files) { + if (!FileExists(f)) { + SHERPA_ONNX_LOGE("Rule fst '%s' does not exist. ", f.c_str()); + return false; + } + } + } + + return true; +} + +std::string HomophoneReplacerConfig::ToString() const { + std::ostringstream os; + + os << "HomophoneReplacerConfig("; + os << "dict_dir=\"" << dict_dir << "\", "; + os << "lexicon=\"" << lexicon << "\", "; + os << "rule_fsts=\"" << rule_fsts << "\")"; + + return os.str(); +} + +class HomophoneReplacer::Impl { + public: + explicit Impl(const HomophoneReplacerConfig &config) : config_(config) { + jieba_ = InitJieba(config.dict_dir); + + { + std::ifstream is(config.lexicon); + InitLexicon(is); + } + + if (!config.rule_fsts.empty()) { + std::vector files; + SplitStringToVector(config.rule_fsts, ",", false, &files); + replacer_list_.reserve(files.size()); + for (const auto &f : files) { + if (config.debug) { + SHERPA_ONNX_LOGE("hr rule fst: %s", f.c_str()); + } + replacer_list_.push_back(std::make_unique(f)); + } + } + } + + template + Impl(Manager *mgr, const HomophoneReplacerConfig &config) : config_(config) { + jieba_ = InitJieba(config.dict_dir); + { + auto buf = ReadFile(mgr, config.lexicon); + + std::istrstream is(buf.data(), buf.size()); + InitLexicon(is); + } + + if (!config.rule_fsts.empty()) { + std::vector files; + SplitStringToVector(config.rule_fsts, ",", false, &files); + replacer_list_.reserve(files.size()); + for (const auto &f : files) { + if (config.debug) { + SHERPA_ONNX_LOGE("hr rule fst: %s", f.c_str()); + } + auto buf = ReadFile(mgr, f); + std::istrstream is(buf.data(), buf.size()); + replacer_list_.push_back( + std::make_unique(is)); + } + } + } + + std::string Apply(const std::string &text) const { + bool is_hmm = true; + + std::vector words; + jieba_->Cut(text, words, is_hmm); + if (config_.debug) { + SHERPA_ONNX_LOGE("Input text: '%s'", text.c_str()); + std::ostringstream os; + os << "After jieba: "; + std::string sep; + for (const auto &w : words) { + os << sep << w; + sep = "_"; + } + SHERPA_ONNX_LOGE("%s", os.str().c_str()); + } + + // convert words to pronunciations + std::vector pronunciations; + + for (const auto &w : words) { + auto p = ConvertWordToPronunciation(w); + if (config_.debug) { + SHERPA_ONNX_LOGE("%s %s", w.c_str(), p.c_str()); + } + pronunciations.push_back(std::move(p)); + } + + std::string ans; + for (const auto &r : replacer_list_) { + ans = r->Normalize(words, pronunciations); + // TODO(fangjun): We support only 1 rule fst at present. + break; + } + + return ans; + } + + private: + std::string ConvertWordToPronunciation(const std::string &word) const { + if (word2pron_.count(word)) { + return word2pron_.at(word); + } + + if (word.size() <= 3) { + // not a Chinese character + return word; + } + + std::vector words = SplitUtf8(word); + std::string ans; + for (const auto &w : words) { + if (word2pron_.count(w)) { + ans.append(word2pron_.at(w)); + } else { + ans.append(w); + } + } + + return ans; + } + + void InitLexicon(std::istream &is) { + std::string word; + std::string pron; + std::string p; + + std::string line; + int32_t line_num = 0; + int32_t num_warn = 0; + while (std::getline(is, line)) { + ++line_num; + std::istringstream iss(line); + + pron.clear(); + iss >> word; + ToLowerCase(&word); + + if (word2pron_.count(word)) { + num_warn += 1; + if (num_warn < 10) { + SHERPA_ONNX_LOGE("Duplicated word: %s at line %d:%s. Ignore it.", + word.c_str(), line_num, line.c_str()); + } + continue; + } + + while (iss >> p) { + pron.append(std::move(p)); + } + + if (pron.empty()) { + SHERPA_ONNX_LOGE( + "Empty pronunciation for word '%s' at line %d:%s. Ignore it.", + word.c_str(), line_num, line.c_str()); + continue; + } + + word2pron_.insert({std::move(word), std::move(pron)}); + } + } + + private: + HomophoneReplacerConfig config_; + std::unique_ptr jieba_; + std::vector> replacer_list_; + std::unordered_map word2pron_; +}; + +HomophoneReplacer::HomophoneReplacer(const HomophoneReplacerConfig &config) + : impl_(std::make_unique(config)) {} + +template +HomophoneReplacer::HomophoneReplacer(Manager *mgr, + const HomophoneReplacerConfig &config) + : impl_(std::make_unique(mgr, config)) {} + +HomophoneReplacer::~HomophoneReplacer() = default; + +std::string HomophoneReplacer::Apply(const std::string &text) const { + return impl_->Apply(text); +} + +#if __ANDROID_API__ >= 9 +template HomophoneReplacer::HomophoneReplacer( + AAssetManager *mgr, const HomophoneReplacerConfig &config); +#endif + +#if __OHOS__ +template HomophoneReplacer::HomophoneReplacer( + NativeResourceManager *mgr, const HomophoneReplacerConfig &config); +#endif + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/homophone-replacer.h b/sherpa-onnx/csrc/homophone-replacer.h new file mode 100644 index 00000000..1f036419 --- /dev/null +++ b/sherpa-onnx/csrc/homophone-replacer.h @@ -0,0 +1,58 @@ +// sherpa-onnx/csrc/homophone-replacer.h +// +// Copyright (c) 2025 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_HOMOPHONE_REPLACER_H_ +#define SHERPA_ONNX_CSRC_HOMOPHONE_REPLACER_H_ + +#include +#include + +#include "sherpa-onnx/csrc/parse-options.h" + +namespace sherpa_onnx { + +struct HomophoneReplacerConfig { + std::string dict_dir; + std::string lexicon; + + // comma separated fst files, e.g. a.fst,b.fst,c.fst + std::string rule_fsts; + + bool debug; + + HomophoneReplacerConfig() = default; + + HomophoneReplacerConfig(const std::string &dict_dir, + const std::string &lexicon, + const std::string &rule_fsts, bool debug) + : dict_dir(dict_dir), + lexicon(lexicon), + rule_fsts(rule_fsts), + debug(debug) {} + + void Register(ParseOptions *po); + bool Validate() const; + + std::string ToString() const; +}; + +class HomophoneReplacer { + public: + explicit HomophoneReplacer(const HomophoneReplacerConfig &config); + + template + HomophoneReplacer(Manager *mgr, const HomophoneReplacerConfig &config); + + ~HomophoneReplacer(); + + std::string Apply(const std::string &text) const; + + private: + class Impl; + std::unique_ptr impl_; +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_HOMOPHONE_REPLACER_H_ diff --git a/sherpa-onnx/csrc/jieba-lexicon.cc b/sherpa-onnx/csrc/jieba-lexicon.cc index 189520c4..4d012f22 100644 --- a/sherpa-onnx/csrc/jieba-lexicon.cc +++ b/sherpa-onnx/csrc/jieba-lexicon.cc @@ -19,8 +19,8 @@ #include "rawfile/raw_file_manager.h" #endif -#include "cppjieba/Jieba.hpp" #include "sherpa-onnx/csrc/file-utils.h" +#include "sherpa-onnx/csrc/jieba.h" #include "sherpa-onnx/csrc/macros.h" #include "sherpa-onnx/csrc/onnx-utils.h" #include "sherpa-onnx/csrc/symbol-table.h" @@ -41,20 +41,7 @@ class JiebaLexicon::Impl { Impl(const std::string &lexicon, const std::string &tokens, const std::string &dict_dir, bool debug) : debug_(debug) { - std::string dict = dict_dir + "/jieba.dict.utf8"; - std::string hmm = dict_dir + "/hmm_model.utf8"; - std::string user_dict = dict_dir + "/user.dict.utf8"; - std::string idf = dict_dir + "/idf.utf8"; - std::string stop_word = dict_dir + "/stop_words.utf8"; - - AssertFileExists(dict); - AssertFileExists(hmm); - AssertFileExists(user_dict); - AssertFileExists(idf); - AssertFileExists(stop_word); - - jieba_ = - std::make_unique(dict, hmm, user_dict, idf, stop_word); + jieba_ = InitJieba(dict_dir); { std::ifstream is(tokens); @@ -71,20 +58,7 @@ class JiebaLexicon::Impl { Impl(Manager *mgr, const std::string &lexicon, const std::string &tokens, const std::string &dict_dir, bool debug) : debug_(debug) { - std::string dict = dict_dir + "/jieba.dict.utf8"; - std::string hmm = dict_dir + "/hmm_model.utf8"; - std::string user_dict = dict_dir + "/user.dict.utf8"; - std::string idf = dict_dir + "/idf.utf8"; - std::string stop_word = dict_dir + "/stop_words.utf8"; - - AssertFileExists(dict); - AssertFileExists(hmm); - AssertFileExists(user_dict); - AssertFileExists(idf); - AssertFileExists(stop_word); - - jieba_ = - std::make_unique(dict, hmm, user_dict, idf, stop_word); + jieba_ = InitJieba(dict_dir); { auto buf = ReadFile(mgr, tokens); diff --git a/sherpa-onnx/csrc/jieba.cc b/sherpa-onnx/csrc/jieba.cc new file mode 100644 index 00000000..655341ec --- /dev/null +++ b/sherpa-onnx/csrc/jieba.cc @@ -0,0 +1,32 @@ +// sherpa-onnx/csrc/jieba.cc +// +// Copyright (c) 2025 Xiaomi Corporation + +#include "sherpa-onnx/csrc/jieba.h" + +#include "sherpa-onnx/csrc/file-utils.h" + +namespace sherpa_onnx { + +std::unique_ptr InitJieba(const std::string &dict_dir) { + if (dict_dir.empty()) { + return {}; + } + + std::string dict = dict_dir + "/jieba.dict.utf8"; + std::string hmm = dict_dir + "/hmm_model.utf8"; + std::string user_dict = dict_dir + "/user.dict.utf8"; + std::string idf = dict_dir + "/idf.utf8"; + std::string stop_word = dict_dir + "/stop_words.utf8"; + + AssertFileExists(dict); + AssertFileExists(hmm); + AssertFileExists(user_dict); + AssertFileExists(idf); + AssertFileExists(stop_word); + + return std::make_unique(dict, hmm, user_dict, idf, + stop_word); +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/jieba.h b/sherpa-onnx/csrc/jieba.h new file mode 100644 index 00000000..69a9418c --- /dev/null +++ b/sherpa-onnx/csrc/jieba.h @@ -0,0 +1,18 @@ +// sherpa-onnx/csrc/jieba.h +// +// Copyright (c) 2025 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_JIEBA_H_ +#define SHERPA_ONNX_CSRC_JIEBA_H_ + +#include +#include + +#include "cppjieba/Jieba.hpp" + +namespace sherpa_onnx { + +std::unique_ptr InitJieba(const std::string &dict_dir); +} + +#endif // SHERPA_ONNX_CSRC_JIEBA_H_ diff --git a/sherpa-onnx/csrc/kokoro-multi-lang-lexicon.cc b/sherpa-onnx/csrc/kokoro-multi-lang-lexicon.cc index 15c415a2..707e68ac 100644 --- a/sherpa-onnx/csrc/kokoro-multi-lang-lexicon.cc +++ b/sherpa-onnx/csrc/kokoro-multi-lang-lexicon.cc @@ -22,11 +22,11 @@ #include -#include "cppjieba/Jieba.hpp" #include "espeak-ng/speak_lib.h" #include "phoneme_ids.hpp" #include "phonemize.hpp" #include "sherpa-onnx/csrc/file-utils.h" +#include "sherpa-onnx/csrc/jieba.h" #include "sherpa-onnx/csrc/onnx-utils.h" #include "sherpa-onnx/csrc/symbol-table.h" #include "sherpa-onnx/csrc/text-utils.h" @@ -47,7 +47,7 @@ class KokoroMultiLangLexicon::Impl { InitLexicon(lexicon); - InitJieba(dict_dir); + jieba_ = InitJieba(dict_dir); InitEspeak(data_dir); // See ./piper-phonemize-lexicon.cc } @@ -62,7 +62,7 @@ class KokoroMultiLangLexicon::Impl { InitLexicon(mgr, lexicon); // we assume you have copied dict_dir and data_dir from assets to some path - InitJieba(dict_dir); + jieba_ = InitJieba(dict_dir); InitEspeak(data_dir); // See ./piper-phonemize-lexicon.cc } @@ -456,23 +456,6 @@ class KokoroMultiLangLexicon::Impl { } } - void InitJieba(const std::string &dict_dir) { - std::string dict = dict_dir + "/jieba.dict.utf8"; - std::string hmm = dict_dir + "/hmm_model.utf8"; - std::string user_dict = dict_dir + "/user.dict.utf8"; - std::string idf = dict_dir + "/idf.utf8"; - std::string stop_word = dict_dir + "/stop_words.utf8"; - - AssertFileExists(dict); - AssertFileExists(hmm); - AssertFileExists(user_dict); - AssertFileExists(idf); - AssertFileExists(stop_word); - - jieba_ = - std::make_unique(dict, hmm, user_dict, idf, stop_word); - } - private: OfflineTtsKokoroModelMetaData meta_data_; diff --git a/sherpa-onnx/csrc/melo-tts-lexicon.cc b/sherpa-onnx/csrc/melo-tts-lexicon.cc index 48b854f8..0a70788e 100644 --- a/sherpa-onnx/csrc/melo-tts-lexicon.cc +++ b/sherpa-onnx/csrc/melo-tts-lexicon.cc @@ -19,8 +19,8 @@ #include "rawfile/raw_file_manager.h" #endif -#include "cppjieba/Jieba.hpp" #include "sherpa-onnx/csrc/file-utils.h" +#include "sherpa-onnx/csrc/jieba.h" #include "sherpa-onnx/csrc/macros.h" #include "sherpa-onnx/csrc/onnx-utils.h" #include "sherpa-onnx/csrc/symbol-table.h" @@ -34,20 +34,7 @@ class MeloTtsLexicon::Impl { const std::string &dict_dir, const OfflineTtsVitsModelMetaData &meta_data, bool debug) : meta_data_(meta_data), debug_(debug) { - std::string dict = dict_dir + "/jieba.dict.utf8"; - std::string hmm = dict_dir + "/hmm_model.utf8"; - std::string user_dict = dict_dir + "/user.dict.utf8"; - std::string idf = dict_dir + "/idf.utf8"; - std::string stop_word = dict_dir + "/stop_words.utf8"; - - AssertFileExists(dict); - AssertFileExists(hmm); - AssertFileExists(user_dict); - AssertFileExists(idf); - AssertFileExists(stop_word); - - jieba_ = - std::make_unique(dict, hmm, user_dict, idf, stop_word); + jieba_ = InitJieba(dict_dir); { std::ifstream is(tokens); @@ -79,20 +66,7 @@ class MeloTtsLexicon::Impl { const std::string &dict_dir, const OfflineTtsVitsModelMetaData &meta_data, bool debug) : meta_data_(meta_data), debug_(debug) { - std::string dict = dict_dir + "/jieba.dict.utf8"; - std::string hmm = dict_dir + "/hmm_model.utf8"; - std::string user_dict = dict_dir + "/user.dict.utf8"; - std::string idf = dict_dir + "/idf.utf8"; - std::string stop_word = dict_dir + "/stop_words.utf8"; - - AssertFileExists(dict); - AssertFileExists(hmm); - AssertFileExists(user_dict); - AssertFileExists(idf); - AssertFileExists(stop_word); - - jieba_ = - std::make_unique(dict, hmm, user_dict, idf, stop_word); + jieba_ = InitJieba(dict_dir); { auto buf = ReadFile(mgr, tokens); diff --git a/sherpa-onnx/csrc/offline-recognizer-ctc-impl.h b/sherpa-onnx/csrc/offline-recognizer-ctc-impl.h index 30df001f..b2a1884e 100644 --- a/sherpa-onnx/csrc/offline-recognizer-ctc-impl.h +++ b/sherpa-onnx/csrc/offline-recognizer-ctc-impl.h @@ -239,6 +239,7 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl { auto r = Convert(results[i], symbol_table_, frame_shift_ms, model_->SubsamplingFactor()); r.text = ApplyInverseTextNormalization(std::move(r.text)); + r.text = ApplyHomophoneReplacer(std::move(r.text)); ss[i]->SetResult(r); } } @@ -277,6 +278,7 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl { auto r = Convert(results[0], symbol_table_, frame_shift_ms, model_->SubsamplingFactor()); r.text = ApplyInverseTextNormalization(std::move(r.text)); + r.text = ApplyHomophoneReplacer(std::move(r.text)); s->SetResult(r); } diff --git a/sherpa-onnx/csrc/offline-recognizer-fire-red-asr-impl.h b/sherpa-onnx/csrc/offline-recognizer-fire-red-asr-impl.h index f206e314..159a48b8 100644 --- a/sherpa-onnx/csrc/offline-recognizer-fire-red-asr-impl.h +++ b/sherpa-onnx/csrc/offline-recognizer-fire-red-asr-impl.h @@ -125,6 +125,7 @@ class OfflineRecognizerFireRedAsrImpl : public OfflineRecognizerImpl { auto r = Convert(results[0], symbol_table_); r.text = ApplyInverseTextNormalization(std::move(r.text)); + r.text = ApplyHomophoneReplacer(std::move(r.text)); s->SetResult(r); } diff --git a/sherpa-onnx/csrc/offline-recognizer-impl.cc b/sherpa-onnx/csrc/offline-recognizer-impl.cc index c4dba8aa..9e19dfd2 100644 --- a/sherpa-onnx/csrc/offline-recognizer-impl.cc +++ b/sherpa-onnx/csrc/offline-recognizer-impl.cc @@ -408,6 +408,8 @@ std::unique_ptr OfflineRecognizerImpl::Create( OfflineRecognizerImpl::OfflineRecognizerImpl( const OfflineRecognizerConfig &config) : config_(config) { + // TODO(fangjun): Refactor this function + if (!config.rule_fsts.empty()) { std::vector files; SplitStringToVector(config.rule_fsts, ",", false, &files); @@ -448,6 +450,13 @@ OfflineRecognizerImpl::OfflineRecognizerImpl( SHERPA_ONNX_LOGE("FST archives loaded!"); } } + + if (!config.hr.dict_dir.empty() && !config.hr.lexicon.empty() && + !config.hr.rule_fsts.empty()) { + auto hr_config = config.hr; + hr_config.debug = config.model_config.debug; + hr_ = std::make_unique(hr_config); + } } template @@ -495,6 +504,13 @@ OfflineRecognizerImpl::OfflineRecognizerImpl( } // for (; !reader->Done(); reader->Next()) } // for (const auto &f : files) } // if (!config.rule_fars.empty()) + + if (!config.hr.dict_dir.empty() && !config.hr.lexicon.empty() && + !config.hr.rule_fsts.empty()) { + auto hr_config = config.hr; + hr_config.debug = config.model_config.debug; + hr_ = std::make_unique(mgr, hr_config); + } } std::string OfflineRecognizerImpl::ApplyInverseTextNormalization( @@ -510,6 +526,15 @@ std::string OfflineRecognizerImpl::ApplyInverseTextNormalization( return text; } +std::string OfflineRecognizerImpl::ApplyHomophoneReplacer( + std::string text) const { + if (hr_) { + text = hr_->Apply(text); + } + + return text; +} + void OfflineRecognizerImpl::SetConfig(const OfflineRecognizerConfig &config) { config_ = config; } diff --git a/sherpa-onnx/csrc/offline-recognizer-impl.h b/sherpa-onnx/csrc/offline-recognizer-impl.h index 8a6e6fcc..c749ee82 100644 --- a/sherpa-onnx/csrc/offline-recognizer-impl.h +++ b/sherpa-onnx/csrc/offline-recognizer-impl.h @@ -10,6 +10,7 @@ #include #include "kaldifst/csrc/text-normalizer.h" +#include "sherpa-onnx/csrc/homophone-replacer.h" #include "sherpa-onnx/csrc/macros.h" #include "sherpa-onnx/csrc/offline-recognizer.h" #include "sherpa-onnx/csrc/offline-stream.h" @@ -48,12 +49,15 @@ class OfflineRecognizerImpl { std::string ApplyInverseTextNormalization(std::string text) const; + std::string ApplyHomophoneReplacer(std::string text) const; + private: OfflineRecognizerConfig config_; // for inverse text normalization. Used only if // config.rule_fsts is not empty or // config.rule_fars is not empty std::vector> itn_list_; + std::unique_ptr hr_; }; } // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/offline-recognizer-moonshine-impl.h b/sherpa-onnx/csrc/offline-recognizer-moonshine-impl.h index deec9852..49aefe70 100644 --- a/sherpa-onnx/csrc/offline-recognizer-moonshine-impl.h +++ b/sherpa-onnx/csrc/offline-recognizer-moonshine-impl.h @@ -121,6 +121,7 @@ class OfflineRecognizerMoonshineImpl : public OfflineRecognizerImpl { auto r = Convert(results[0], symbol_table_); r.text = ApplyInverseTextNormalization(std::move(r.text)); + r.text = ApplyHomophoneReplacer(std::move(r.text)); s->SetResult(r); } catch (const Ort::Exception &ex) { SHERPA_ONNX_LOGE( diff --git a/sherpa-onnx/csrc/offline-recognizer-paraformer-impl.h b/sherpa-onnx/csrc/offline-recognizer-paraformer-impl.h index 5b80c99e..40f2041f 100644 --- a/sherpa-onnx/csrc/offline-recognizer-paraformer-impl.h +++ b/sherpa-onnx/csrc/offline-recognizer-paraformer-impl.h @@ -197,6 +197,7 @@ class OfflineRecognizerParaformerImpl : public OfflineRecognizerImpl { for (int32_t i = 0; i != n; ++i) { auto r = Convert(results[i], symbol_table_); r.text = ApplyInverseTextNormalization(std::move(r.text)); + r.text = ApplyHomophoneReplacer(std::move(r.text)); ss[i]->SetResult(r); } } diff --git a/sherpa-onnx/csrc/offline-recognizer-sense-voice-impl.h b/sherpa-onnx/csrc/offline-recognizer-sense-voice-impl.h index 7ee5e41c..82266cd7 100644 --- a/sherpa-onnx/csrc/offline-recognizer-sense-voice-impl.h +++ b/sherpa-onnx/csrc/offline-recognizer-sense-voice-impl.h @@ -222,6 +222,7 @@ class OfflineRecognizerSenseVoiceImpl : public OfflineRecognizerImpl { auto r = ConvertSenseVoiceResult(results[i], symbol_table_, frame_shift_ms, subsampling_factor); r.text = ApplyInverseTextNormalization(std::move(r.text)); + r.text = ApplyHomophoneReplacer(std::move(r.text)); ss[i]->SetResult(r); } } @@ -295,6 +296,7 @@ class OfflineRecognizerSenseVoiceImpl : public OfflineRecognizerImpl { subsampling_factor); r.text = ApplyInverseTextNormalization(std::move(r.text)); + r.text = ApplyHomophoneReplacer(std::move(r.text)); s->SetResult(r); } diff --git a/sherpa-onnx/csrc/offline-recognizer-transducer-impl.h b/sherpa-onnx/csrc/offline-recognizer-transducer-impl.h index 158be562..2f7c5bed 100644 --- a/sherpa-onnx/csrc/offline-recognizer-transducer-impl.h +++ b/sherpa-onnx/csrc/offline-recognizer-transducer-impl.h @@ -239,6 +239,7 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl { auto r = Convert(results[i], symbol_table_, frame_shift_ms, model_->SubsamplingFactor()); r.text = ApplyInverseTextNormalization(std::move(r.text)); + r.text = ApplyHomophoneReplacer(std::move(r.text)); ss[i]->SetResult(r); } diff --git a/sherpa-onnx/csrc/offline-recognizer-transducer-nemo-impl.h b/sherpa-onnx/csrc/offline-recognizer-transducer-nemo-impl.h index 167d1021..cf8e18da 100644 --- a/sherpa-onnx/csrc/offline-recognizer-transducer-nemo-impl.h +++ b/sherpa-onnx/csrc/offline-recognizer-transducer-nemo-impl.h @@ -128,6 +128,7 @@ class OfflineRecognizerTransducerNeMoImpl : public OfflineRecognizerImpl { auto r = Convert(results[i], symbol_table_, frame_shift_ms, model_->SubsamplingFactor()); r.text = ApplyInverseTextNormalization(std::move(r.text)); + r.text = ApplyHomophoneReplacer(std::move(r.text)); ss[i]->SetResult(r); } diff --git a/sherpa-onnx/csrc/offline-recognizer-whisper-impl.h b/sherpa-onnx/csrc/offline-recognizer-whisper-impl.h index 7c928b09..97f84339 100644 --- a/sherpa-onnx/csrc/offline-recognizer-whisper-impl.h +++ b/sherpa-onnx/csrc/offline-recognizer-whisper-impl.h @@ -160,6 +160,7 @@ class OfflineRecognizerWhisperImpl : public OfflineRecognizerImpl { std::string s = sym_table[i]; s = ApplyInverseTextNormalization(s); + s = ApplyHomophoneReplacer(std::move(s)); text += s; r.tokens.push_back(s); diff --git a/sherpa-onnx/csrc/offline-recognizer.cc b/sherpa-onnx/csrc/offline-recognizer.cc index 80bc2090..4b11683e 100644 --- a/sherpa-onnx/csrc/offline-recognizer.cc +++ b/sherpa-onnx/csrc/offline-recognizer.cc @@ -28,6 +28,7 @@ void OfflineRecognizerConfig::Register(ParseOptions *po) { model_config.Register(po); lm_config.Register(po); ctc_fst_decoder_config.Register(po); + hr.Register(po); po->Register( "decoding-method", &decoding_method, @@ -120,6 +121,11 @@ bool OfflineRecognizerConfig::Validate() const { } } + if (!hr.dict_dir.empty() && !hr.lexicon.empty() && !hr.rule_fsts.empty() && + !hr.Validate()) { + return false; + } + return model_config.Validate(); } @@ -137,7 +143,8 @@ std::string OfflineRecognizerConfig::ToString() const { os << "hotwords_score=" << hotwords_score << ", "; os << "blank_penalty=" << blank_penalty << ", "; os << "rule_fsts=\"" << rule_fsts << "\", "; - os << "rule_fars=\"" << rule_fars << "\")"; + os << "rule_fars=\"" << rule_fars << "\", "; + os << "hr=" << hr.ToString() << ")"; return os.str(); } diff --git a/sherpa-onnx/csrc/offline-recognizer.h b/sherpa-onnx/csrc/offline-recognizer.h index 3c78ea9b..1fcc1016 100644 --- a/sherpa-onnx/csrc/offline-recognizer.h +++ b/sherpa-onnx/csrc/offline-recognizer.h @@ -10,6 +10,7 @@ #include #include "sherpa-onnx/csrc/features.h" +#include "sherpa-onnx/csrc/homophone-replacer.h" #include "sherpa-onnx/csrc/offline-ctc-fst-decoder-config.h" #include "sherpa-onnx/csrc/offline-lm-config.h" #include "sherpa-onnx/csrc/offline-model-config.h" @@ -40,6 +41,7 @@ struct OfflineRecognizerConfig { // If there are multiple FST archives, they are applied from left to right. std::string rule_fars; + HomophoneReplacerConfig hr; // only greedy_search is implemented // TODO(fangjun): Implement modified_beam_search @@ -52,7 +54,7 @@ struct OfflineRecognizerConfig { const std::string &decoding_method, int32_t max_active_paths, const std::string &hotwords_file, float hotwords_score, float blank_penalty, const std::string &rule_fsts, - const std::string &rule_fars) + const std::string &rule_fars, const HomophoneReplacerConfig &hr) : feat_config(feat_config), model_config(model_config), lm_config(lm_config), @@ -63,7 +65,8 @@ struct OfflineRecognizerConfig { hotwords_score(hotwords_score), blank_penalty(blank_penalty), rule_fsts(rule_fsts), - rule_fars(rule_fars) {} + rule_fars(rule_fars), + hr(hr) {} void Register(ParseOptions *po); bool Validate() const; diff --git a/sherpa-onnx/csrc/online-recognizer-ctc-impl.h b/sherpa-onnx/csrc/online-recognizer-ctc-impl.h index 32f6ac4d..a675fd72 100644 --- a/sherpa-onnx/csrc/online-recognizer-ctc-impl.h +++ b/sherpa-onnx/csrc/online-recognizer-ctc-impl.h @@ -201,7 +201,8 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl { auto r = ConvertCtc(decoder_result, sym_, frame_shift_ms, subsampling_factor, s->GetCurrentSegment(), s->GetNumFramesSinceStart()); - r.text = ApplyInverseTextNormalization(r.text); + r.text = ApplyInverseTextNormalization(std::move(r.text)); + r.text = ApplyHomophoneReplacer(std::move(r.text)); return r; } diff --git a/sherpa-onnx/csrc/online-recognizer-impl.cc b/sherpa-onnx/csrc/online-recognizer-impl.cc index 4401591d..a3f3ef3e 100644 --- a/sherpa-onnx/csrc/online-recognizer-impl.cc +++ b/sherpa-onnx/csrc/online-recognizer-impl.cc @@ -192,6 +192,13 @@ OnlineRecognizerImpl::OnlineRecognizerImpl(const OnlineRecognizerConfig &config) SHERPA_ONNX_LOGE("FST archives loaded!"); } } + + if (!config.hr.dict_dir.empty() && !config.hr.lexicon.empty() && + !config.hr.rule_fsts.empty()) { + auto hr_config = config.hr; + hr_config.debug = config.model_config.debug; + hr_ = std::make_unique(hr_config); + } } template @@ -239,6 +246,12 @@ OnlineRecognizerImpl::OnlineRecognizerImpl(Manager *mgr, } // for (; !reader->Done(); reader->Next()) } // for (const auto &f : files) } // if (!config.rule_fars.empty()) + if (!config.hr.dict_dir.empty() && !config.hr.lexicon.empty() && + !config.hr.rule_fsts.empty()) { + auto hr_config = config.hr; + hr_config.debug = config.model_config.debug; + hr_ = std::make_unique(mgr, hr_config); + } } std::string OnlineRecognizerImpl::ApplyInverseTextNormalization( @@ -254,6 +267,15 @@ std::string OnlineRecognizerImpl::ApplyInverseTextNormalization( return text; } +std::string OnlineRecognizerImpl::ApplyHomophoneReplacer( + std::string text) const { + if (hr_) { + text = hr_->Apply(text); + } + + return text; +} + #if __ANDROID_API__ >= 9 template OnlineRecognizerImpl::OnlineRecognizerImpl( AAssetManager *mgr, const OnlineRecognizerConfig &config); diff --git a/sherpa-onnx/csrc/online-recognizer-impl.h b/sherpa-onnx/csrc/online-recognizer-impl.h index b7bda786..d752bde6 100644 --- a/sherpa-onnx/csrc/online-recognizer-impl.h +++ b/sherpa-onnx/csrc/online-recognizer-impl.h @@ -10,6 +10,7 @@ #include #include "kaldifst/csrc/text-normalizer.h" +#include "sherpa-onnx/csrc/homophone-replacer.h" #include "sherpa-onnx/csrc/macros.h" #include "sherpa-onnx/csrc/online-recognizer.h" #include "sherpa-onnx/csrc/online-stream.h" @@ -57,6 +58,7 @@ class OnlineRecognizerImpl { virtual void Reset(OnlineStream *s) const = 0; std::string ApplyInverseTextNormalization(std::string text) const; + std::string ApplyHomophoneReplacer(std::string text) const; private: OnlineRecognizerConfig config_; @@ -64,6 +66,7 @@ class OnlineRecognizerImpl { // config.rule_fsts is not empty or // config.rule_fars is not empty std::vector> itn_list_; + std::unique_ptr hr_; }; } // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/online-recognizer-paraformer-impl.h b/sherpa-onnx/csrc/online-recognizer-paraformer-impl.h index 1e02fe51..2f4b3c47 100644 --- a/sherpa-onnx/csrc/online-recognizer-paraformer-impl.h +++ b/sherpa-onnx/csrc/online-recognizer-paraformer-impl.h @@ -169,7 +169,8 @@ class OnlineRecognizerParaformerImpl : public OnlineRecognizerImpl { auto decoder_result = s->GetParaformerResult(); auto r = Convert(decoder_result, sym_); - r.text = ApplyInverseTextNormalization(r.text); + r.text = ApplyInverseTextNormalization(std::move(r.text)); + r.text = ApplyHomophoneReplacer(std::move(r.text)); return r; } diff --git a/sherpa-onnx/csrc/online-recognizer-transducer-impl.h b/sherpa-onnx/csrc/online-recognizer-transducer-impl.h index 8370397b..c136c666 100644 --- a/sherpa-onnx/csrc/online-recognizer-transducer-impl.h +++ b/sherpa-onnx/csrc/online-recognizer-transducer-impl.h @@ -349,6 +349,7 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { auto r = Convert(decoder_result, sym_, frame_shift_ms, subsampling_factor, s->GetCurrentSegment(), s->GetNumFramesSinceStart()); r.text = ApplyInverseTextNormalization(std::move(r.text)); + r.text = ApplyHomophoneReplacer(std::move(r.text)); return r; } @@ -391,15 +392,14 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { // (the encoder state buffers are kept) for (const auto &it : last_result.hyps) { auto h = it.second; - r.hyps.Add({std::vector(h.ys.end() - context_size, - h.ys.end()), + r.hyps.Add({std::vector(h.ys.end() - context_size, h.ys.end()), h.log_prob}); } - r.tokens = std::vector (last_result.tokens.end() - context_size, - last_result.tokens.end()); + r.tokens = std::vector(last_result.tokens.end() - context_size, + last_result.tokens.end()); } else { - if(config_.reset_encoder) { + if (config_.reset_encoder) { // reset encoder states, use blanks as 'ys' context s->SetStates(model_->GetEncoderInitStates()); } diff --git a/sherpa-onnx/csrc/online-recognizer-transducer-nemo-impl.h b/sherpa-onnx/csrc/online-recognizer-transducer-nemo-impl.h index a3f2756c..6087c2b5 100644 --- a/sherpa-onnx/csrc/online-recognizer-transducer-nemo-impl.h +++ b/sherpa-onnx/csrc/online-recognizer-transducer-nemo-impl.h @@ -100,6 +100,7 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl { subsampling_factor, s->GetCurrentSegment(), s->GetNumFramesSinceStart()); r.text = ApplyInverseTextNormalization(std::move(r.text)); + r.text = ApplyHomophoneReplacer(std::move(r.text)); return r; } diff --git a/sherpa-onnx/csrc/online-recognizer.cc b/sherpa-onnx/csrc/online-recognizer.cc index a2c76d70..3b5b30ec 100644 --- a/sherpa-onnx/csrc/online-recognizer.cc +++ b/sherpa-onnx/csrc/online-recognizer.cc @@ -88,6 +88,7 @@ void OnlineRecognizerConfig::Register(ParseOptions *po) { endpoint_config.Register(po); lm_config.Register(po); ctc_fst_decoder_config.Register(po); + hr.Register(po); po->Register("enable-endpoint", &enable_endpoint, "True to enable endpoint detection. False to disable it."); @@ -182,6 +183,11 @@ bool OnlineRecognizerConfig::Validate() const { } } + if (!hr.dict_dir.empty() && !hr.lexicon.empty() && !hr.rule_fsts.empty() && + !hr.Validate()) { + return false; + } + return model_config.Validate(); } @@ -203,7 +209,8 @@ std::string OnlineRecognizerConfig::ToString() const { os << "temperature_scale=" << temperature_scale << ", "; os << "rule_fsts=\"" << rule_fsts << "\", "; os << "rule_fars=\"" << rule_fars << "\", "; - os << "reset_encoder=\"" << (reset_encoder ? "True" : "False") << "\")"; + os << "reset_encoder=" << (reset_encoder ? "True" : "False") << ", "; + os << "hr=" << hr.ToString() << ")"; return os.str(); } diff --git a/sherpa-onnx/csrc/online-recognizer.h b/sherpa-onnx/csrc/online-recognizer.h index 5936e0df..d52b877c 100644 --- a/sherpa-onnx/csrc/online-recognizer.h +++ b/sherpa-onnx/csrc/online-recognizer.h @@ -11,6 +11,7 @@ #include "sherpa-onnx/csrc/endpoint.h" #include "sherpa-onnx/csrc/features.h" +#include "sherpa-onnx/csrc/homophone-replacer.h" #include "sherpa-onnx/csrc/online-ctc-fst-decoder-config.h" #include "sherpa-onnx/csrc/online-lm-config.h" #include "sherpa-onnx/csrc/online-model-config.h" @@ -107,6 +108,8 @@ struct OnlineRecognizerConfig { // currently only in `OnlineRecognizerTransducerImpl`. bool reset_encoder = false; + HomophoneReplacerConfig hr; + /// used only for modified_beam_search, if hotwords_buf is non-empty, /// the hotwords will be loaded from the buffered string instead of from the /// "hotwords_file" @@ -123,7 +126,7 @@ struct OnlineRecognizerConfig { int32_t max_active_paths, const std::string &hotwords_file, float hotwords_score, float blank_penalty, float temperature_scale, const std::string &rule_fsts, const std::string &rule_fars, - bool reset_encoder) + bool reset_encoder, const HomophoneReplacerConfig &hr) : feat_config(feat_config), model_config(model_config), lm_config(lm_config), @@ -138,7 +141,8 @@ struct OnlineRecognizerConfig { temperature_scale(temperature_scale), rule_fsts(rule_fsts), rule_fars(rule_fars), - reset_encoder(reset_encoder) {} + reset_encoder(reset_encoder), + hr(hr) {} void Register(ParseOptions *po); bool Validate() const; diff --git a/sherpa-onnx/csrc/rknn/online-recognizer-ctc-rknn-impl.h b/sherpa-onnx/csrc/rknn/online-recognizer-ctc-rknn-impl.h index 32fc6b8d..7ecbf26a 100644 --- a/sherpa-onnx/csrc/rknn/online-recognizer-ctc-rknn-impl.h +++ b/sherpa-onnx/csrc/rknn/online-recognizer-ctc-rknn-impl.h @@ -89,7 +89,8 @@ class OnlineRecognizerCtcRknnImpl : public OnlineRecognizerImpl { auto r = ConvertCtc(decoder_result, sym_, frame_shift_ms, subsampling_factor, s->GetCurrentSegment(), s->GetNumFramesSinceStart()); - r.text = ApplyInverseTextNormalization(r.text); + r.text = ApplyInverseTextNormalization(std::move(r.text)); + r.text = ApplyHomophoneReplacer(std::move(r.text)); return r; } diff --git a/sherpa-onnx/csrc/rknn/online-recognizer-transducer-rknn-impl.h b/sherpa-onnx/csrc/rknn/online-recognizer-transducer-rknn-impl.h index 8336ed21..7f4a7bf9 100644 --- a/sherpa-onnx/csrc/rknn/online-recognizer-transducer-rknn-impl.h +++ b/sherpa-onnx/csrc/rknn/online-recognizer-transducer-rknn-impl.h @@ -177,6 +177,7 @@ class OnlineRecognizerTransducerRknnImpl : public OnlineRecognizerImpl { auto r = Convert(decoder_result, sym_, frame_shift_ms, subsampling_factor, s->GetCurrentSegment(), s->GetNumFramesSinceStart()); r.text = ApplyInverseTextNormalization(std::move(r.text)); + r.text = ApplyHomophoneReplacer(std::move(r.text)); return r; } diff --git a/sherpa-onnx/python/csrc/CMakeLists.txt b/sherpa-onnx/python/csrc/CMakeLists.txt index c0b5c01c..2626a251 100644 --- a/sherpa-onnx/python/csrc/CMakeLists.txt +++ b/sherpa-onnx/python/csrc/CMakeLists.txt @@ -7,6 +7,7 @@ set(srcs display.cc endpoint.cc features.cc + homophone-replacer.cc keyword-spotter.cc offline-ctc-fst-decoder-config.cc offline-dolphin-model-config.cc diff --git a/sherpa-onnx/python/csrc/homophone-replacer.cc b/sherpa-onnx/python/csrc/homophone-replacer.cc new file mode 100644 index 00000000..8f658a06 --- /dev/null +++ b/sherpa-onnx/python/csrc/homophone-replacer.cc @@ -0,0 +1,28 @@ +// sherpa-onnx/python/csrc/homophone-replacer.cc +// +// Copyright (c) 2025 Xiaomi Corporation + +#include "sherpa-onnx/python/csrc/homophone-replacer.h" + +#include + +#include "sherpa-onnx/csrc/homophone-replacer.h" + +namespace sherpa_onnx { + +void PybindHomophoneReplacer(py::module *m) { + using PyClass = HomophoneReplacerConfig; + py::class_(*m, "HomophoneReplacerConfig") + .def(py::init<>()) + .def(py::init(), + py::arg("dict_dir"), py::arg("lexicon"), py::arg("rule_fsts"), + py::arg("debug") = false) + .def_readwrite("dict_dir", &PyClass::dict_dir) + .def_readwrite("lexicon", &PyClass::lexicon) + .def_readwrite("rule_fsts", &PyClass::rule_fsts) + .def_readwrite("debug", &PyClass::debug) + .def("__str__", &PyClass::ToString); +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/python/csrc/homophone-replacer.h b/sherpa-onnx/python/csrc/homophone-replacer.h new file mode 100644 index 00000000..2d4528fa --- /dev/null +++ b/sherpa-onnx/python/csrc/homophone-replacer.h @@ -0,0 +1,16 @@ +// sherpa-onnx/python/csrc/homophone-replacer.h +// +// Copyright (c) 2025 Xiaomi Corporation + +#ifndef SHERPA_ONNX_PYTHON_CSRC_HOMOPHONE_REPLACER_H_ +#define SHERPA_ONNX_PYTHON_CSRC_HOMOPHONE_REPLACER_H_ + +#include "sherpa-onnx/python/csrc/sherpa-onnx.h" + +namespace sherpa_onnx { + +void PybindHomophoneReplacer(py::module *m); + +} + +#endif // SHERPA_ONNX_PYTHON_CSRC_HOMOPHONE_REPLACER_H_ diff --git a/sherpa-onnx/python/csrc/offline-recognizer.cc b/sherpa-onnx/python/csrc/offline-recognizer.cc index 2a603e08..13b44d7f 100644 --- a/sherpa-onnx/python/csrc/offline-recognizer.cc +++ b/sherpa-onnx/python/csrc/offline-recognizer.cc @@ -17,14 +17,16 @@ static void PybindOfflineRecognizerConfig(py::module *m) { .def(py::init(), + float, const std::string &, const std::string &, + const HomophoneReplacerConfig &>(), py::arg("feat_config"), py::arg("model_config"), py::arg("lm_config") = OfflineLMConfig(), py::arg("ctc_fst_decoder_config") = OfflineCtcFstDecoderConfig(), py::arg("decoding_method") = "greedy_search", py::arg("max_active_paths") = 4, py::arg("hotwords_file") = "", py::arg("hotwords_score") = 1.5, py::arg("blank_penalty") = 0.0, - py::arg("rule_fsts") = "", py::arg("rule_fars") = "") + py::arg("rule_fsts") = "", py::arg("rule_fars") = "", + py::arg("hr") = HomophoneReplacerConfig{}) .def_readwrite("feat_config", &PyClass::feat_config) .def_readwrite("model_config", &PyClass::model_config) .def_readwrite("lm_config", &PyClass::lm_config) @@ -36,6 +38,7 @@ static void PybindOfflineRecognizerConfig(py::module *m) { .def_readwrite("blank_penalty", &PyClass::blank_penalty) .def_readwrite("rule_fsts", &PyClass::rule_fsts) .def_readwrite("rule_fars", &PyClass::rule_fars) + .def_readwrite("hr", &PyClass::hr) .def("__str__", &PyClass::ToString); } diff --git a/sherpa-onnx/python/csrc/online-recognizer.cc b/sherpa-onnx/python/csrc/online-recognizer.cc index 38d4e776..b9c74d54 100644 --- a/sherpa-onnx/python/csrc/online-recognizer.cc +++ b/sherpa-onnx/python/csrc/online-recognizer.cc @@ -58,7 +58,8 @@ static void PybindOnlineRecognizerConfig(py::module *m) { const OnlineLMConfig &, const EndpointConfig &, const OnlineCtcFstDecoderConfig &, bool, const std::string &, int32_t, const std::string &, float, - float, float, const std::string &, const std::string &, bool>(), + float, float, const std::string &, const std::string &, + bool, const HomophoneReplacerConfig &>(), py::arg("feat_config"), py::arg("model_config"), py::arg("lm_config") = OnlineLMConfig(), py::arg("endpoint_config") = EndpointConfig(), @@ -67,7 +68,8 @@ static void PybindOnlineRecognizerConfig(py::module *m) { py::arg("max_active_paths") = 4, py::arg("hotwords_file") = "", py::arg("hotwords_score") = 0, py::arg("blank_penalty") = 0.0, py::arg("temperature_scale") = 2.0, py::arg("rule_fsts") = "", - py::arg("rule_fars") = "", py::arg("reset_encoder") = false) + py::arg("rule_fars") = "", py::arg("reset_encoder") = false, + py::arg("hr") = HomophoneReplacerConfig{}) .def_readwrite("feat_config", &PyClass::feat_config) .def_readwrite("model_config", &PyClass::model_config) .def_readwrite("lm_config", &PyClass::lm_config) @@ -83,6 +85,7 @@ static void PybindOnlineRecognizerConfig(py::module *m) { .def_readwrite("rule_fsts", &PyClass::rule_fsts) .def_readwrite("rule_fars", &PyClass::rule_fars) .def_readwrite("reset_encoder", &PyClass::reset_encoder) + .def_readwrite("hr", &PyClass::hr) .def("__str__", &PyClass::ToString); } diff --git a/sherpa-onnx/python/csrc/sherpa-onnx.cc b/sherpa-onnx/python/csrc/sherpa-onnx.cc index c00b4644..a6fa8cba 100644 --- a/sherpa-onnx/python/csrc/sherpa-onnx.cc +++ b/sherpa-onnx/python/csrc/sherpa-onnx.cc @@ -10,6 +10,7 @@ #include "sherpa-onnx/python/csrc/display.h" #include "sherpa-onnx/python/csrc/endpoint.h" #include "sherpa-onnx/python/csrc/features.h" +#include "sherpa-onnx/python/csrc/homophone-replacer.h" #include "sherpa-onnx/python/csrc/keyword-spotter.h" #include "sherpa-onnx/python/csrc/offline-ctc-fst-decoder-config.h" #include "sherpa-onnx/python/csrc/offline-lm-config.h" @@ -51,6 +52,7 @@ PYBIND11_MODULE(_sherpa_onnx, m) { PybindAudioTagging(&m); PybindOfflinePunctuation(&m); PybindOnlinePunctuation(&m); + PybindHomophoneReplacer(&m); PybindFeatures(&m); PybindOnlineCtcFstDecoderConfig(&m); diff --git a/sherpa-onnx/python/sherpa_onnx/offline_recognizer.py b/sherpa-onnx/python/sherpa_onnx/offline_recognizer.py index 50572366..eae50c4e 100644 --- a/sherpa-onnx/python/sherpa_onnx/offline_recognizer.py +++ b/sherpa-onnx/python/sherpa_onnx/offline_recognizer.py @@ -5,6 +5,7 @@ from typing import List, Optional from _sherpa_onnx import ( FeatureExtractorConfig, + HomophoneReplacerConfig, OfflineCtcFstDecoderConfig, OfflineDolphinModelConfig, OfflineFireRedAsrModelConfig, @@ -64,6 +65,9 @@ class OfflineRecognizer(object): rule_fars: str = "", lm: str = "", lm_scale: float = 0.1, + hr_dict_dir: str = "", + hr_rule_fsts: str = "", + hr_lexicon: str = "", ): """ Please refer to @@ -181,6 +185,11 @@ class OfflineRecognizer(object): blank_penalty=blank_penalty, rule_fsts=rule_fsts, rule_fars=rule_fars, + hr=HomophoneReplacerConfig( + dict_dir=hr_dict_dir, + lexicon=hr_lexicon, + rule_fsts=hr_rule_fsts, + ), ) self.recognizer = _Recognizer(recognizer_config) self.config = recognizer_config @@ -201,6 +210,9 @@ class OfflineRecognizer(object): use_itn: bool = False, rule_fsts: str = "", rule_fars: str = "", + hr_dict_dir: str = "", + hr_rule_fsts: str = "", + hr_lexicon: str = "", ): """ Please refer to @@ -263,6 +275,11 @@ class OfflineRecognizer(object): decoding_method=decoding_method, rule_fsts=rule_fsts, rule_fars=rule_fars, + hr=HomophoneReplacerConfig( + dict_dir=hr_dict_dir, + lexicon=hr_lexicon, + rule_fsts=hr_rule_fsts, + ), ) self.recognizer = _Recognizer(recognizer_config) self.config = recognizer_config @@ -281,6 +298,9 @@ class OfflineRecognizer(object): provider: str = "cpu", rule_fsts: str = "", rule_fars: str = "", + hr_dict_dir: str = "", + hr_rule_fsts: str = "", + hr_lexicon: str = "", ): """ Please refer to @@ -336,6 +356,11 @@ class OfflineRecognizer(object): decoding_method=decoding_method, rule_fsts=rule_fsts, rule_fars=rule_fars, + hr=HomophoneReplacerConfig( + dict_dir=hr_dict_dir, + lexicon=hr_lexicon, + rule_fsts=hr_rule_fsts, + ), ) self.recognizer = _Recognizer(recognizer_config) self.config = recognizer_config @@ -354,6 +379,9 @@ class OfflineRecognizer(object): provider: str = "cpu", rule_fsts: str = "", rule_fars: str = "", + hr_dict_dir: str = "", + hr_rule_fsts: str = "", + hr_lexicon: str = "", ): """ Please refer to @@ -411,6 +439,9 @@ class OfflineRecognizer(object): decoding_method=decoding_method, rule_fsts=rule_fsts, rule_fars=rule_fars, + hr=HomophoneReplacerConfig( + dict_dir=hr_dict_dir, lexicon=hr_lexicon, rule_fsts=hr_rule_fsts + ), ) self.recognizer = _Recognizer(recognizer_config) self.config = recognizer_config @@ -429,6 +460,9 @@ class OfflineRecognizer(object): provider: str = "cpu", rule_fsts: str = "", rule_fars: str = "", + hr_dict_dir: str = "", + hr_rule_fsts: str = "", + hr_lexicon: str = "", ): """ Please refer to @@ -483,6 +517,11 @@ class OfflineRecognizer(object): decoding_method=decoding_method, rule_fsts=rule_fsts, rule_fars=rule_fars, + hr=HomophoneReplacerConfig( + dict_dir=hr_dict_dir, + lexicon=hr_lexicon, + rule_fsts=hr_rule_fsts, + ), ) self.recognizer = _Recognizer(recognizer_config) self.config = recognizer_config @@ -501,6 +540,9 @@ class OfflineRecognizer(object): provider: str = "cpu", rule_fsts: str = "", rule_fars: str = "", + hr_dict_dir: str = "", + hr_rule_fsts: str = "", + hr_lexicon: str = "", ): """ Please refer to @@ -557,6 +599,11 @@ class OfflineRecognizer(object): decoding_method=decoding_method, rule_fsts=rule_fsts, rule_fars=rule_fars, + hr=HomophoneReplacerConfig( + dict_dir=hr_dict_dir, + lexicon=hr_lexicon, + rule_fsts=hr_rule_fsts, + ), ) self.recognizer = _Recognizer(recognizer_config) self.config = recognizer_config @@ -577,6 +624,9 @@ class OfflineRecognizer(object): tail_paddings: int = -1, rule_fsts: str = "", rule_fars: str = "", + hr_dict_dir: str = "", + hr_rule_fsts: str = "", + hr_lexicon: str = "", ): """ Please refer to @@ -647,6 +697,11 @@ class OfflineRecognizer(object): decoding_method=decoding_method, rule_fsts=rule_fsts, rule_fars=rule_fars, + hr=HomophoneReplacerConfig( + dict_dir=hr_dict_dir, + lexicon=hr_lexicon, + rule_fsts=hr_rule_fsts, + ), ) self.recognizer = _Recognizer(recognizer_config) self.config = recognizer_config @@ -664,6 +719,9 @@ class OfflineRecognizer(object): provider: str = "cpu", rule_fsts: str = "", rule_fars: str = "", + hr_dict_dir: str = "", + hr_rule_fsts: str = "", + hr_lexicon: str = "", ): """ Please refer to @@ -719,6 +777,11 @@ class OfflineRecognizer(object): decoding_method=decoding_method, rule_fsts=rule_fsts, rule_fars=rule_fars, + hr=HomophoneReplacerConfig( + dict_dir=hr_dict_dir, + lexicon=hr_lexicon, + rule_fsts=hr_rule_fsts, + ), ) self.recognizer = _Recognizer(recognizer_config) self.config = recognizer_config @@ -738,6 +801,9 @@ class OfflineRecognizer(object): provider: str = "cpu", rule_fsts: str = "", rule_fars: str = "", + hr_dict_dir: str = "", + hr_rule_fsts: str = "", + hr_lexicon: str = "", ): """ Please refer to @@ -800,6 +866,11 @@ class OfflineRecognizer(object): decoding_method=decoding_method, rule_fsts=rule_fsts, rule_fars=rule_fars, + hr=HomophoneReplacerConfig( + dict_dir=hr_dict_dir, + lexicon=hr_lexicon, + rule_fsts=hr_rule_fsts, + ), ) self.recognizer = _Recognizer(recognizer_config) self.config = recognizer_config @@ -818,6 +889,9 @@ class OfflineRecognizer(object): provider: str = "cpu", rule_fsts: str = "", rule_fars: str = "", + hr_dict_dir: str = "", + hr_rule_fsts: str = "", + hr_lexicon: str = "", ): """ Please refer to @@ -873,6 +947,11 @@ class OfflineRecognizer(object): decoding_method=decoding_method, rule_fsts=rule_fsts, rule_fars=rule_fars, + hr=HomophoneReplacerConfig( + dict_dir=hr_dict_dir, + lexicon=hr_lexicon, + rule_fsts=hr_rule_fsts, + ), ) self.recognizer = _Recognizer(recognizer_config) self.config = recognizer_config @@ -891,6 +970,9 @@ class OfflineRecognizer(object): provider: str = "cpu", rule_fsts: str = "", rule_fars: str = "", + hr_dict_dir: str = "", + hr_rule_fsts: str = "", + hr_lexicon: str = "", ): """ Please refer to @@ -947,6 +1029,11 @@ class OfflineRecognizer(object): decoding_method=decoding_method, rule_fsts=rule_fsts, rule_fars=rule_fars, + hr=HomophoneReplacerConfig( + dict_dir=hr_dict_dir, + lexicon=hr_lexicon, + rule_fsts=hr_rule_fsts, + ), ) self.recognizer = _Recognizer(recognizer_config) self.config = recognizer_config diff --git a/sherpa-onnx/python/sherpa_onnx/online_recognizer.py b/sherpa-onnx/python/sherpa_onnx/online_recognizer.py index 78e383cc..747e4e50 100644 --- a/sherpa-onnx/python/sherpa_onnx/online_recognizer.py +++ b/sherpa-onnx/python/sherpa_onnx/online_recognizer.py @@ -3,25 +3,26 @@ from pathlib import Path from typing import List, Optional from _sherpa_onnx import ( + CudaConfig, EndpointConfig, FeatureExtractorConfig, + HomophoneReplacerConfig, + OnlineCtcFstDecoderConfig, OnlineLMConfig, OnlineModelConfig, + OnlineNeMoCtcModelConfig, OnlineParaformerModelConfig, ) from _sherpa_onnx import OnlineRecognizer as _Recognizer from _sherpa_onnx import ( - CudaConfig, - TensorrtConfig, - ProviderConfig, OnlineRecognizerConfig, OnlineRecognizerResult, OnlineStream, OnlineTransducerModelConfig, OnlineWenetCtcModelConfig, - OnlineNeMoCtcModelConfig, OnlineZipformer2CtcModelConfig, - OnlineCtcFstDecoderConfig, + ProviderConfig, + TensorrtConfig, ) @@ -82,9 +83,12 @@ class OnlineRecognizer(object): trt_detailed_build_log: bool = False, trt_engine_cache_enable: bool = True, trt_timing_cache_enable: bool = True, - trt_engine_cache_path: str ="", - trt_timing_cache_path: str ="", + trt_engine_cache_path: str = "", + trt_timing_cache_path: str = "", trt_dump_subgraphs: bool = False, + hr_dict_dir: str = "", + hr_rule_fsts: str = "", + hr_lexicon: str = "", ): """ Please refer to @@ -228,27 +232,27 @@ class OnlineRecognizer(object): ) cuda_config = CudaConfig( - cudnn_conv_algo_search=cudnn_conv_algo_search, + cudnn_conv_algo_search=cudnn_conv_algo_search, ) trt_config = TensorrtConfig( - trt_max_workspace_size=trt_max_workspace_size, - trt_max_partition_iterations=trt_max_partition_iterations, - trt_min_subgraph_size=trt_min_subgraph_size, - trt_fp16_enable=trt_fp16_enable, - trt_detailed_build_log=trt_detailed_build_log, - trt_engine_cache_enable=trt_engine_cache_enable, - trt_timing_cache_enable=trt_timing_cache_enable, - trt_engine_cache_path=trt_engine_cache_path, - trt_timing_cache_path=trt_timing_cache_path, - trt_dump_subgraphs=trt_dump_subgraphs, + trt_max_workspace_size=trt_max_workspace_size, + trt_max_partition_iterations=trt_max_partition_iterations, + trt_min_subgraph_size=trt_min_subgraph_size, + trt_fp16_enable=trt_fp16_enable, + trt_detailed_build_log=trt_detailed_build_log, + trt_engine_cache_enable=trt_engine_cache_enable, + trt_timing_cache_enable=trt_timing_cache_enable, + trt_engine_cache_path=trt_engine_cache_path, + trt_timing_cache_path=trt_timing_cache_path, + trt_dump_subgraphs=trt_dump_subgraphs, ) provider_config = ProviderConfig( - trt_config=trt_config, - cuda_config=cuda_config, - provider=provider, - device=device, + trt_config=trt_config, + cuda_config=cuda_config, + provider=provider, + device=device, ) model_config = OnlineModelConfig( @@ -311,6 +315,11 @@ class OnlineRecognizer(object): rule_fsts=rule_fsts, rule_fars=rule_fars, reset_encoder=reset_encoder, + hr=HomophoneReplacerConfig( + dict_dir=hr_dict_dir, + lexicon=hr_lexicon, + rule_fsts=hr_rule_fsts, + ), ) self.recognizer = _Recognizer(recognizer_config) @@ -336,6 +345,9 @@ class OnlineRecognizer(object): rule_fsts: str = "", rule_fars: str = "", device: int = 0, + hr_dict_dir: str = "", + hr_rule_fsts: str = "", + hr_lexicon: str = "", ): """ Please refer to @@ -402,8 +414,8 @@ class OnlineRecognizer(object): ) provider_config = ProviderConfig( - provider=provider, - device=device, + provider=provider, + device=device, ) model_config = OnlineModelConfig( @@ -434,6 +446,11 @@ class OnlineRecognizer(object): decoding_method=decoding_method, rule_fsts=rule_fsts, rule_fars=rule_fars, + hr=HomophoneReplacerConfig( + dict_dir=hr_dict_dir, + lexicon=hr_lexicon, + rule_fsts=hr_rule_fsts, + ), ) self.recognizer = _Recognizer(recognizer_config) @@ -460,6 +477,9 @@ class OnlineRecognizer(object): rule_fsts: str = "", rule_fars: str = "", device: int = 0, + hr_dict_dir: str = "", + hr_rule_fsts: str = "", + hr_lexicon: str = "", ): """ Please refer to @@ -526,8 +546,8 @@ class OnlineRecognizer(object): zipformer2_ctc_config = OnlineZipformer2CtcModelConfig(model=model) provider_config = ProviderConfig( - provider=provider, - device=device, + provider=provider, + device=device, ) model_config = OnlineModelConfig( @@ -563,6 +583,11 @@ class OnlineRecognizer(object): decoding_method=decoding_method, rule_fsts=rule_fsts, rule_fars=rule_fars, + hr=HomophoneReplacerConfig( + dict_dir=hr_dict_dir, + lexicon=hr_lexicon, + rule_fsts=hr_rule_fsts, + ), ) self.recognizer = _Recognizer(recognizer_config) @@ -587,6 +612,9 @@ class OnlineRecognizer(object): rule_fsts: str = "", rule_fars: str = "", device: int = 0, + hr_dict_dir: str = "", + hr_rule_fsts: str = "", + hr_lexicon: str = "", ): """ Please refer to @@ -650,8 +678,8 @@ class OnlineRecognizer(object): ) provider_config = ProviderConfig( - provider=provider, - device=device, + provider=provider, + device=device, ) model_config = OnlineModelConfig( @@ -681,6 +709,11 @@ class OnlineRecognizer(object): decoding_method=decoding_method, rule_fsts=rule_fsts, rule_fars=rule_fars, + hr=HomophoneReplacerConfig( + dict_dir=hr_dict_dir, + lexicon=hr_lexicon, + rule_fsts=hr_rule_fsts, + ), ) self.recognizer = _Recognizer(recognizer_config) @@ -707,6 +740,9 @@ class OnlineRecognizer(object): rule_fsts: str = "", rule_fars: str = "", device: int = 0, + hr_dict_dir: str = "", + hr_rule_fsts: str = "", + hr_lexicon: str = "", ): """ Please refer to @@ -775,8 +811,8 @@ class OnlineRecognizer(object): ) provider_config = ProviderConfig( - provider=provider, - device=device, + provider=provider, + device=device, ) model_config = OnlineModelConfig( @@ -806,6 +842,11 @@ class OnlineRecognizer(object): decoding_method=decoding_method, rule_fsts=rule_fsts, rule_fars=rule_fars, + hr=HomophoneReplacerConfig( + dict_dir=hr_dict_dir, + lexicon=hr_lexicon, + rule_fsts=hr_rule_fsts, + ), ) self.recognizer = _Recognizer(recognizer_config)