diff --git a/.github/scripts/test-python.sh b/.github/scripts/test-python.sh index a52b5b91..c03b9542 100755 --- a/.github/scripts/test-python.sh +++ b/.github/scripts/test-python.sh @@ -248,7 +248,7 @@ if [[ x$OS != x'windows-latest' ]]; then python3 ./python-api-examples/online-decode-files.py \ --tokens=$repo/tokens.txt \ --encoder=$repo/encoder-epoch-99-avg-1.int8.onnx \ - --decoder=$repo/decoder-epoch-99-avg-1.int8.onnx \ + --decoder=$repo/decoder-epoch-99-avg-1.onnx \ --joiner=$repo/joiner-epoch-99-avg-1.int8.onnx \ $repo/test_wavs/0.wav \ $repo/test_wavs/1.wav \ @@ -286,7 +286,7 @@ python3 ./python-api-examples/offline-decode-files.py \ python3 ./python-api-examples/offline-decode-files.py \ --tokens=$repo/tokens.txt \ --encoder=$repo/encoder-epoch-99-avg-1.int8.onnx \ - --decoder=$repo/decoder-epoch-99-avg-1.int8.onnx \ + --decoder=$repo/decoder-epoch-99-avg-1.onnx \ --joiner=$repo/joiner-epoch-99-avg-1.int8.onnx \ $repo/test_wavs/0.wav \ $repo/test_wavs/1.wav \ @@ -330,6 +330,15 @@ if [[ x$OS != x'windows-latest' ]]; then python3 sherpa-onnx/python/tests/test_offline_recognizer.py --verbose + ln -s $repo $PWD/ + + curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/itn_zh_number.fst + curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/itn-zh-number.wav + + python3 ./python-api-examples/inverse-text-normalization-offline-asr.py + + rm -rfv sherpa-onnx-paraformer-zh-2023-03-28 + rm -rf $repo fi diff --git a/python-api-examples/inverse-text-normalization-offline-asr.py b/python-api-examples/inverse-text-normalization-offline-asr.py new file mode 100755 index 00000000..3228e01b --- /dev/null +++ b/python-api-examples/inverse-text-normalization-offline-asr.py @@ -0,0 +1,81 @@ +#!/usr/bin/env python3 +# +# Copyright (c) 2024 Xiaomi Corporation + +""" +This script shows how to use inverse text normalization with non-streaming ASR. + +Usage: + +(1) Download the test model + +wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-paraformer-zh-2023-03-28.tar.bz2 +tar xvf sherpa-onnx-paraformer-zh-2023-03-28.tar.bz2 +rm sherpa-onnx-paraformer-zh-2023-03-28.tar.bz2 + +(2) Download rule fst + +wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/itn_zh_number.fst + +Please refer to +https://github.com/k2-fsa/colab/blob/master/sherpa-onnx/itn_zh_number.ipynb +for how itn_zh_number.fst is generated. + +(3) Download test wave + +wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/itn-zh-number.wav + +(4) Run this script + +python3 ./python-api-examples/inverse-text-normalization-offline-asr.py +""" +from pathlib import Path + +import sherpa_onnx +import soundfile as sf + + +def create_recognizer(): + model = "./sherpa-onnx-paraformer-zh-2023-03-28/model.int8.onnx" + tokens = "./sherpa-onnx-paraformer-zh-2023-03-28/tokens.txt" + rule_fsts = "./itn_zh_number.fst" + + if ( + not Path(model).is_file() + or not Path(tokens).is_file() + or not Path(rule_fsts).is_file() + ): + raise ValueError( + """Please download model files from + https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models + """ + ) + return sherpa_onnx.OfflineRecognizer.from_paraformer( + paraformer=model, + tokens=tokens, + debug=True, + rule_fsts=rule_fsts, + ) + + +def main(): + recognizer = create_recognizer() + wave_filename = "./itn-zh-number.wav" + if not Path(wave_filename).is_file(): + raise ValueError( + """Please download model files from + https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models + """ + ) + audio, sample_rate = sf.read(wave_filename, dtype="float32", always_2d=True) + audio = audio[:, 0] # only use the first channel + + stream = recognizer.create_stream() + stream.accept_waveform(sample_rate, audio) + recognizer.decode_stream(stream) + print(wave_filename) + print(stream.result) + + +if __name__ == "__main__": + main() diff --git a/sherpa-onnx/csrc/offline-recognizer-ctc-impl.h b/sherpa-onnx/csrc/offline-recognizer-ctc-impl.h index c64da12a..cbe9a9e8 100644 --- a/sherpa-onnx/csrc/offline-recognizer-ctc-impl.h +++ b/sherpa-onnx/csrc/offline-recognizer-ctc-impl.h @@ -73,7 +73,8 @@ static OfflineRecognitionResult Convert(const OfflineCtcDecoderResult &src, class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl { public: explicit OfflineRecognizerCtcImpl(const OfflineRecognizerConfig &config) - : config_(config), + : OfflineRecognizerImpl(config), + config_(config), symbol_table_(config_.model_config.tokens), model_(OfflineCtcModel::Create(config_.model_config)) { Init(); @@ -82,7 +83,8 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl { #if __ANDROID_API__ >= 9 OfflineRecognizerCtcImpl(AAssetManager *mgr, const OfflineRecognizerConfig &config) - : config_(config), + : OfflineRecognizerImpl(mgr, config), + config_(config), symbol_table_(mgr, config_.model_config.tokens), model_(OfflineCtcModel::Create(mgr, config_.model_config)) { Init(); @@ -205,6 +207,7 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl { for (int32_t i = 0; i != n; ++i) { auto r = Convert(results[i], symbol_table_, frame_shift_ms, model_->SubsamplingFactor()); + r.text = ApplyInverseTextNormalization(std::move(r.text)); ss[i]->SetResult(r); } } @@ -238,6 +241,7 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl { auto r = Convert(results[0], symbol_table_, frame_shift_ms, model_->SubsamplingFactor()); + r.text = ApplyInverseTextNormalization(std::move(r.text)); s->SetResult(r); } diff --git a/sherpa-onnx/csrc/offline-recognizer-impl.cc b/sherpa-onnx/csrc/offline-recognizer-impl.cc index 65642577..546d0f9b 100644 --- a/sherpa-onnx/csrc/offline-recognizer-impl.cc +++ b/sherpa-onnx/csrc/offline-recognizer-impl.cc @@ -5,7 +5,18 @@ #include "sherpa-onnx/csrc/offline-recognizer-impl.h" #include +#include +#include +#if __ANDROID_API__ >= 9 +#include + +#include "android/asset_manager.h" +#include "android/asset_manager_jni.h" +#endif + +#include "fst/extensions/far/far.h" +#include "kaldifst/csrc/kaldi-fst-io.h" #include "onnxruntime_cxx_api.h" // NOLINT #include "sherpa-onnx/csrc/macros.h" #include "sherpa-onnx/csrc/offline-recognizer-ctc-impl.h" @@ -316,4 +327,111 @@ std::unique_ptr OfflineRecognizerImpl::Create( } #endif +OfflineRecognizerImpl::OfflineRecognizerImpl( + const OfflineRecognizerConfig &config) + : config_(config) { + if (!config.rule_fsts.empty()) { + std::vector files; + SplitStringToVector(config.rule_fsts, ",", false, &files); + itn_list_.reserve(files.size()); + for (const auto &f : files) { + if (config.model_config.debug) { + SHERPA_ONNX_LOGE("rule fst: %s", f.c_str()); + } + itn_list_.push_back(std::make_unique(f)); + } + } + + if (!config.rule_fars.empty()) { + if (config.model_config.debug) { + SHERPA_ONNX_LOGE("Loading FST archives"); + } + std::vector files; + SplitStringToVector(config.rule_fars, ",", false, &files); + + itn_list_.reserve(files.size() + itn_list_.size()); + + for (const auto &f : files) { + if (config.model_config.debug) { + SHERPA_ONNX_LOGE("rule far: %s", f.c_str()); + } + std::unique_ptr> reader( + fst::FarReader::Open(f)); + for (; !reader->Done(); reader->Next()) { + std::unique_ptr r( + fst::CastOrConvertToConstFst(reader->GetFst()->Copy())); + + itn_list_.push_back( + std::make_unique(std::move(r))); + } + } + + if (config.model_config.debug) { + SHERPA_ONNX_LOGE("FST archives loaded!"); + } + } +} + +#if __ANDROID_API__ >= 9 +OfflineRecognizerImpl::OfflineRecognizerImpl( + AAssetManager *mgr, const OfflineRecognizerConfig &config) + : config_(config) { + if (!config.rule_fsts.empty()) { + std::vector files; + SplitStringToVector(config.rule_fsts, ",", false, &files); + itn_list_.reserve(files.size()); + for (const auto &f : files) { + if (config.model_config.debug) { + SHERPA_ONNX_LOGE("rule fst: %s", f.c_str()); + } + auto buf = ReadFile(mgr, f); + std::istrstream is(buf.data(), buf.size()); + itn_list_.push_back(std::make_unique(is)); + } + } + + if (!config.rule_fars.empty()) { + std::vector files; + SplitStringToVector(config.rule_fars, ",", false, &files); + itn_list_.reserve(files.size() + itn_list_.size()); + + for (const auto &f : files) { + if (config.model_config.debug) { + SHERPA_ONNX_LOGE("rule far: %s", f.c_str()); + } + + auto buf = ReadFile(mgr, f); + + std::unique_ptr s( + new std::istrstream(buf.data(), buf.size())); + + std::unique_ptr> reader( + fst::FarReader::Open(std::move(s))); + + for (; !reader->Done(); reader->Next()) { + std::unique_ptr r( + fst::CastOrConvertToConstFst(reader->GetFst()->Copy())); + + itn_list_.push_back( + std::make_unique(std::move(r))); + } // for (; !reader->Done(); reader->Next()) + } // for (const auto &f : files) + } // if (!config.rule_fars.empty()) +} +#endif + +std::string OfflineRecognizerImpl::ApplyInverseTextNormalization( + std::string text) const { + if (!itn_list_.empty()) { + for (const auto &tn : itn_list_) { + text = tn->Normalize(text); + if (config_.model_config.debug) { + SHERPA_ONNX_LOGE("After inverse text normalization: %s", text.c_str()); + } + } + } + + return text; +} + } // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/offline-recognizer-impl.h b/sherpa-onnx/csrc/offline-recognizer-impl.h index b849de65..1ba268c1 100644 --- a/sherpa-onnx/csrc/offline-recognizer-impl.h +++ b/sherpa-onnx/csrc/offline-recognizer-impl.h @@ -14,6 +14,7 @@ #include "android/asset_manager_jni.h" #endif +#include "kaldifst/csrc/text-normalizer.h" #include "sherpa-onnx/csrc/macros.h" #include "sherpa-onnx/csrc/offline-recognizer.h" #include "sherpa-onnx/csrc/offline-stream.h" @@ -22,10 +23,15 @@ namespace sherpa_onnx { class OfflineRecognizerImpl { public: + explicit OfflineRecognizerImpl(const OfflineRecognizerConfig &config); + static std::unique_ptr Create( const OfflineRecognizerConfig &config); #if __ANDROID_API__ >= 9 + OfflineRecognizerImpl(AAssetManager *mgr, + const OfflineRecognizerConfig &config); + static std::unique_ptr Create( AAssetManager *mgr, const OfflineRecognizerConfig &config); #endif @@ -41,6 +47,15 @@ class OfflineRecognizerImpl { virtual std::unique_ptr CreateStream() const = 0; virtual void DecodeStreams(OfflineStream **ss, int32_t n) const = 0; + + std::string ApplyInverseTextNormalization(std::string text) const; + + private: + OfflineRecognizerConfig config_; + // for inverse text normalization. Used only if + // config.rule_fsts is not empty or + // config.rule_fars is not empty + std::vector> itn_list_; }; } // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/offline-recognizer-paraformer-impl.h b/sherpa-onnx/csrc/offline-recognizer-paraformer-impl.h index 3bcaf390..a0d4af3b 100644 --- a/sherpa-onnx/csrc/offline-recognizer-paraformer-impl.h +++ b/sherpa-onnx/csrc/offline-recognizer-paraformer-impl.h @@ -89,7 +89,8 @@ class OfflineRecognizerParaformerImpl : public OfflineRecognizerImpl { public: explicit OfflineRecognizerParaformerImpl( const OfflineRecognizerConfig &config) - : config_(config), + : OfflineRecognizerImpl(config), + config_(config), symbol_table_(config_.model_config.tokens), model_(std::make_unique(config.model_config)) { if (config.decoding_method == "greedy_search") { @@ -109,7 +110,8 @@ class OfflineRecognizerParaformerImpl : public OfflineRecognizerImpl { #if __ANDROID_API__ >= 9 OfflineRecognizerParaformerImpl(AAssetManager *mgr, const OfflineRecognizerConfig &config) - : config_(config), + : OfflineRecognizerImpl(mgr, config), + config_(config), symbol_table_(mgr, config_.model_config.tokens), model_(std::make_unique(mgr, config.model_config)) { @@ -204,6 +206,7 @@ class OfflineRecognizerParaformerImpl : public OfflineRecognizerImpl { for (int32_t i = 0; i != n; ++i) { auto r = Convert(results[i], symbol_table_); + r.text = ApplyInverseTextNormalization(std::move(r.text)); ss[i]->SetResult(r); } } diff --git a/sherpa-onnx/csrc/offline-recognizer-transducer-impl.h b/sherpa-onnx/csrc/offline-recognizer-transducer-impl.h index 265f42bb..13357f79 100644 --- a/sherpa-onnx/csrc/offline-recognizer-transducer-impl.h +++ b/sherpa-onnx/csrc/offline-recognizer-transducer-impl.h @@ -74,7 +74,8 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl { public: explicit OfflineRecognizerTransducerImpl( const OfflineRecognizerConfig &config) - : config_(config), + : OfflineRecognizerImpl(config), + config_(config), symbol_table_(config_.model_config.tokens), model_(std::make_unique(config_.model_config)) { if (config_.decoding_method == "greedy_search") { @@ -107,7 +108,8 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl { #if __ANDROID_API__ >= 9 explicit OfflineRecognizerTransducerImpl( AAssetManager *mgr, const OfflineRecognizerConfig &config) - : config_(config), + : OfflineRecognizerImpl(mgr, config), + config_(config), symbol_table_(mgr, config_.model_config.tokens), model_(std::make_unique(mgr, config_.model_config)) { @@ -230,6 +232,7 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl { for (int32_t i = 0; i != n; ++i) { auto r = Convert(results[i], symbol_table_, frame_shift_ms, model_->SubsamplingFactor()); + r.text = ApplyInverseTextNormalization(std::move(r.text)); ss[i]->SetResult(r); } diff --git a/sherpa-onnx/csrc/offline-recognizer-transducer-nemo-impl.h b/sherpa-onnx/csrc/offline-recognizer-transducer-nemo-impl.h index 127fe343..d5902b05 100644 --- a/sherpa-onnx/csrc/offline-recognizer-transducer-nemo-impl.h +++ b/sherpa-onnx/csrc/offline-recognizer-transducer-nemo-impl.h @@ -41,7 +41,8 @@ class OfflineRecognizerTransducerNeMoImpl : public OfflineRecognizerImpl { public: explicit OfflineRecognizerTransducerNeMoImpl( const OfflineRecognizerConfig &config) - : config_(config), + : OfflineRecognizerImpl(config), + config_(config), symbol_table_(config_.model_config.tokens), model_(std::make_unique( config_.model_config)) { @@ -59,7 +60,8 @@ class OfflineRecognizerTransducerNeMoImpl : public OfflineRecognizerImpl { #if __ANDROID_API__ >= 9 explicit OfflineRecognizerTransducerNeMoImpl( AAssetManager *mgr, const OfflineRecognizerConfig &config) - : config_(config), + : OfflineRecognizerImpl(mgr, config), + config_(config), symbol_table_(mgr, config_.model_config.tokens), model_(std::make_unique( mgr, config_.model_config)) { @@ -131,6 +133,7 @@ class OfflineRecognizerTransducerNeMoImpl : public OfflineRecognizerImpl { for (int32_t i = 0; i != n; ++i) { auto r = Convert(results[i], symbol_table_, frame_shift_ms, model_->SubsamplingFactor()); + r.text = ApplyInverseTextNormalization(std::move(r.text)); ss[i]->SetResult(r); } diff --git a/sherpa-onnx/csrc/offline-recognizer-whisper-impl.h b/sherpa-onnx/csrc/offline-recognizer-whisper-impl.h index d224c860..35891760 100644 --- a/sherpa-onnx/csrc/offline-recognizer-whisper-impl.h +++ b/sherpa-onnx/csrc/offline-recognizer-whisper-impl.h @@ -52,7 +52,8 @@ static OfflineRecognitionResult Convert(const OfflineWhisperDecoderResult &src, class OfflineRecognizerWhisperImpl : public OfflineRecognizerImpl { public: explicit OfflineRecognizerWhisperImpl(const OfflineRecognizerConfig &config) - : config_(config), + : OfflineRecognizerImpl(config), + config_(config), symbol_table_(config_.model_config.tokens), model_(std::make_unique(config.model_config)) { Init(); @@ -61,7 +62,8 @@ class OfflineRecognizerWhisperImpl : public OfflineRecognizerImpl { #if __ANDROID_API__ >= 9 OfflineRecognizerWhisperImpl(AAssetManager *mgr, const OfflineRecognizerConfig &config) - : config_(config), + : OfflineRecognizerImpl(mgr, config), + config_(config), symbol_table_(mgr, config_.model_config.tokens), model_( std::make_unique(mgr, config.model_config)) { @@ -150,6 +152,7 @@ class OfflineRecognizerWhisperImpl : public OfflineRecognizerImpl { std::move(cross_kv.second)); auto r = Convert(results[0], symbol_table_); + r.text = ApplyInverseTextNormalization(std::move(r.text)); s->SetResult(r); } catch (const Ort::Exception &ex) { SHERPA_ONNX_LOGE( diff --git a/sherpa-onnx/csrc/offline-recognizer.cc b/sherpa-onnx/csrc/offline-recognizer.cc index d6ba4905..1285a5cd 100644 --- a/sherpa-onnx/csrc/offline-recognizer.cc +++ b/sherpa-onnx/csrc/offline-recognizer.cc @@ -10,7 +10,7 @@ #include "sherpa-onnx/csrc/macros.h" #include "sherpa-onnx/csrc/offline-lm-config.h" #include "sherpa-onnx/csrc/offline-recognizer-impl.h" - +#include "sherpa-onnx/csrc/text-utils.h" namespace sherpa_onnx { void OfflineRecognizerConfig::Register(ParseOptions *po) { @@ -44,6 +44,16 @@ void OfflineRecognizerConfig::Register(ParseOptions *po) { po->Register("hotwords-score", &hotwords_score, "The bonus score for each token in context word/phrase. " "Used only when decoding_method is modified_beam_search"); + + po->Register( + "rule-fsts", &rule_fsts, + "If not empty, it specifies fsts for inverse text normalization. " + "If there are multiple fsts, they are separated by a comma."); + + po->Register( + "rule-fars", &rule_fars, + "If not empty, it specifies fst archives for inverse text normalization. " + "If there are multiple archives, they are separated by a comma."); } bool OfflineRecognizerConfig::Validate() const { @@ -61,7 +71,7 @@ bool OfflineRecognizerConfig::Validate() const { if (!hotwords_file.empty() && decoding_method != "modified_beam_search") { SHERPA_ONNX_LOGE( "Please use --decoding-method=modified_beam_search if you" - " provide --hotwords-file. Given --decoding-method=%s", + " provide --hotwords-file. Given --decoding-method='%s'", decoding_method.c_str()); return false; } @@ -72,6 +82,34 @@ bool OfflineRecognizerConfig::Validate() const { return false; } + if (!hotwords_file.empty() && !FileExists(hotwords_file)) { + SHERPA_ONNX_LOGE("--hotwords-file: '%s' does not exist", + hotwords_file.c_str()); + return false; + } + + if (!rule_fsts.empty()) { + std::vector files; + SplitStringToVector(rule_fsts, ",", false, &files); + for (const auto &f : files) { + if (!FileExists(f)) { + SHERPA_ONNX_LOGE("Rule fst '%s' does not exist. ", f.c_str()); + return false; + } + } + } + + if (!rule_fars.empty()) { + std::vector files; + SplitStringToVector(rule_fars, ",", false, &files); + for (const auto &f : files) { + if (!FileExists(f)) { + SHERPA_ONNX_LOGE("Rule far '%s' does not exist. ", f.c_str()); + return false; + } + } + } + return model_config.Validate(); } @@ -87,7 +125,9 @@ std::string OfflineRecognizerConfig::ToString() const { os << "max_active_paths=" << max_active_paths << ", "; os << "hotwords_file=\"" << hotwords_file << "\", "; os << "hotwords_score=" << hotwords_score << ", "; - os << "blank_penalty=" << blank_penalty << ")"; + os << "blank_penalty=" << blank_penalty << ", "; + os << "rule_fsts=\"" << rule_fsts << "\", "; + os << "rule_fars=\"" << rule_fars << "\")"; return os.str(); } diff --git a/sherpa-onnx/csrc/offline-recognizer.h b/sherpa-onnx/csrc/offline-recognizer.h index e93d7edc..9290a53b 100644 --- a/sherpa-onnx/csrc/offline-recognizer.h +++ b/sherpa-onnx/csrc/offline-recognizer.h @@ -40,6 +40,12 @@ struct OfflineRecognizerConfig { float blank_penalty = 0.0; + // If there are multiple rules, they are applied from left to right. + std::string rule_fsts; + + // If there are multiple FST archives, they are applied from left to right. + std::string rule_fars; + // only greedy_search is implemented // TODO(fangjun): Implement modified_beam_search @@ -50,7 +56,8 @@ struct OfflineRecognizerConfig { const OfflineCtcFstDecoderConfig &ctc_fst_decoder_config, const std::string &decoding_method, int32_t max_active_paths, const std::string &hotwords_file, float hotwords_score, - float blank_penalty) + float blank_penalty, const std::string &rule_fsts, + const std::string &rule_fars) : feat_config(feat_config), model_config(model_config), lm_config(lm_config), @@ -59,7 +66,9 @@ struct OfflineRecognizerConfig { max_active_paths(max_active_paths), hotwords_file(hotwords_file), hotwords_score(hotwords_score), - blank_penalty(blank_penalty) {} + blank_penalty(blank_penalty), + rule_fsts(rule_fsts), + rule_fars(rule_fars) {} void Register(ParseOptions *po); bool Validate() const; diff --git a/sherpa-onnx/python/csrc/offline-recognizer.cc b/sherpa-onnx/python/csrc/offline-recognizer.cc index 5ef9d4f2..2a603e08 100644 --- a/sherpa-onnx/python/csrc/offline-recognizer.cc +++ b/sherpa-onnx/python/csrc/offline-recognizer.cc @@ -17,13 +17,14 @@ static void PybindOfflineRecognizerConfig(py::module *m) { .def(py::init(), + float, const std::string &, const std::string &>(), py::arg("feat_config"), py::arg("model_config"), py::arg("lm_config") = OfflineLMConfig(), py::arg("ctc_fst_decoder_config") = OfflineCtcFstDecoderConfig(), py::arg("decoding_method") = "greedy_search", py::arg("max_active_paths") = 4, py::arg("hotwords_file") = "", - py::arg("hotwords_score") = 1.5, py::arg("blank_penalty") = 0.0) + py::arg("hotwords_score") = 1.5, py::arg("blank_penalty") = 0.0, + py::arg("rule_fsts") = "", py::arg("rule_fars") = "") .def_readwrite("feat_config", &PyClass::feat_config) .def_readwrite("model_config", &PyClass::model_config) .def_readwrite("lm_config", &PyClass::lm_config) @@ -33,6 +34,8 @@ static void PybindOfflineRecognizerConfig(py::module *m) { .def_readwrite("hotwords_file", &PyClass::hotwords_file) .def_readwrite("hotwords_score", &PyClass::hotwords_score) .def_readwrite("blank_penalty", &PyClass::blank_penalty) + .def_readwrite("rule_fsts", &PyClass::rule_fsts) + .def_readwrite("rule_fars", &PyClass::rule_fars) .def("__str__", &PyClass::ToString); } diff --git a/sherpa-onnx/python/sherpa_onnx/offline_recognizer.py b/sherpa-onnx/python/sherpa_onnx/offline_recognizer.py index 480ea23c..2fade069 100644 --- a/sherpa-onnx/python/sherpa_onnx/offline_recognizer.py +++ b/sherpa-onnx/python/sherpa_onnx/offline_recognizer.py @@ -54,6 +54,8 @@ class OfflineRecognizer(object): debug: bool = False, provider: str = "cpu", model_type: str = "transducer", + rule_fsts: str = "", + rule_fars: str = "", ): """ Please refer to @@ -107,6 +109,12 @@ class OfflineRecognizer(object): True to show debug messages. provider: onnxruntime execution providers. Valid values are: cpu, cuda, coreml. + rule_fsts: + If not empty, it specifies fsts for inverse text normalization. + If there are multiple fsts, they are separated by a comma. + rule_fars: + If not empty, it specifies fst archives for inverse text normalization. + If there are multiple archives, they are separated by a comma. """ self = cls.__new__(cls) model_config = OfflineModelConfig( @@ -143,6 +151,8 @@ class OfflineRecognizer(object): hotwords_file=hotwords_file, hotwords_score=hotwords_score, blank_penalty=blank_penalty, + rule_fsts=rule_fsts, + rule_fars=rule_fars, ) self.recognizer = _Recognizer(recognizer_config) self.config = recognizer_config @@ -159,6 +169,8 @@ class OfflineRecognizer(object): decoding_method: str = "greedy_search", debug: bool = False, provider: str = "cpu", + rule_fsts: str = "", + rule_fars: str = "", ): """ Please refer to @@ -186,6 +198,12 @@ class OfflineRecognizer(object): True to show debug messages. provider: onnxruntime execution providers. Valid values are: cpu, cuda, coreml. + rule_fsts: + If not empty, it specifies fsts for inverse text normalization. + If there are multiple fsts, they are separated by a comma. + rule_fars: + If not empty, it specifies fst archives for inverse text normalization. + If there are multiple archives, they are separated by a comma. """ self = cls.__new__(cls) model_config = OfflineModelConfig( @@ -206,6 +224,8 @@ class OfflineRecognizer(object): feat_config=feat_config, model_config=model_config, decoding_method=decoding_method, + rule_fsts=rule_fsts, + rule_fars=rule_fars, ) self.recognizer = _Recognizer(recognizer_config) self.config = recognizer_config @@ -222,6 +242,8 @@ class OfflineRecognizer(object): decoding_method: str = "greedy_search", debug: bool = False, provider: str = "cpu", + rule_fsts: str = "", + rule_fars: str = "", ): """ Please refer to @@ -251,6 +273,12 @@ class OfflineRecognizer(object): True to show debug messages. provider: onnxruntime execution providers. Valid values are: cpu, cuda, coreml. + rule_fsts: + If not empty, it specifies fsts for inverse text normalization. + If there are multiple fsts, they are separated by a comma. + rule_fars: + If not empty, it specifies fst archives for inverse text normalization. + If there are multiple archives, they are separated by a comma. """ self = cls.__new__(cls) model_config = OfflineModelConfig( @@ -271,6 +299,8 @@ class OfflineRecognizer(object): feat_config=feat_config, model_config=model_config, decoding_method=decoding_method, + rule_fsts=rule_fsts, + rule_fars=rule_fars, ) self.recognizer = _Recognizer(recognizer_config) self.config = recognizer_config @@ -287,6 +317,8 @@ class OfflineRecognizer(object): decoding_method: str = "greedy_search", debug: bool = False, provider: str = "cpu", + rule_fsts: str = "", + rule_fars: str = "", ): """ Please refer to @@ -315,6 +347,12 @@ class OfflineRecognizer(object): True to show debug messages. provider: onnxruntime execution providers. Valid values are: cpu, cuda, coreml. + rule_fsts: + If not empty, it specifies fsts for inverse text normalization. + If there are multiple fsts, they are separated by a comma. + rule_fars: + If not empty, it specifies fst archives for inverse text normalization. + If there are multiple archives, they are separated by a comma. """ self = cls.__new__(cls) model_config = OfflineModelConfig( @@ -335,6 +373,8 @@ class OfflineRecognizer(object): feat_config=feat_config, model_config=model_config, decoding_method=decoding_method, + rule_fsts=rule_fsts, + rule_fars=rule_fars, ) self.recognizer = _Recognizer(recognizer_config) self.config = recognizer_config @@ -353,6 +393,8 @@ class OfflineRecognizer(object): debug: bool = False, provider: str = "cpu", tail_paddings: int = -1, + rule_fsts: str = "", + rule_fars: str = "", ): """ Please refer to @@ -389,6 +431,12 @@ class OfflineRecognizer(object): True to show debug messages. provider: onnxruntime execution providers. Valid values are: cpu, cuda, coreml. + rule_fsts: + If not empty, it specifies fsts for inverse text normalization. + If there are multiple fsts, they are separated by a comma. + rule_fars: + If not empty, it specifies fst archives for inverse text normalization. + If there are multiple archives, they are separated by a comma. """ self = cls.__new__(cls) model_config = OfflineModelConfig( @@ -415,6 +463,8 @@ class OfflineRecognizer(object): feat_config=feat_config, model_config=model_config, decoding_method=decoding_method, + rule_fsts=rule_fsts, + rule_fars=rule_fars, ) self.recognizer = _Recognizer(recognizer_config) self.config = recognizer_config @@ -431,6 +481,8 @@ class OfflineRecognizer(object): decoding_method: str = "greedy_search", debug: bool = False, provider: str = "cpu", + rule_fsts: str = "", + rule_fars: str = "", ): """ Please refer to @@ -458,6 +510,12 @@ class OfflineRecognizer(object): True to show debug messages. provider: onnxruntime execution providers. Valid values are: cpu, cuda, coreml. + rule_fsts: + If not empty, it specifies fsts for inverse text normalization. + If there are multiple fsts, they are separated by a comma. + rule_fars: + If not empty, it specifies fst archives for inverse text normalization. + If there are multiple archives, they are separated by a comma. """ self = cls.__new__(cls) model_config = OfflineModelConfig( @@ -478,6 +536,8 @@ class OfflineRecognizer(object): feat_config=feat_config, model_config=model_config, decoding_method=decoding_method, + rule_fsts=rule_fsts, + rule_fars=rule_fars, ) self.recognizer = _Recognizer(recognizer_config) self.config = recognizer_config @@ -494,6 +554,8 @@ class OfflineRecognizer(object): decoding_method: str = "greedy_search", debug: bool = False, provider: str = "cpu", + rule_fsts: str = "", + rule_fars: str = "", ): """ Please refer to @@ -522,6 +584,12 @@ class OfflineRecognizer(object): True to show debug messages. provider: onnxruntime execution providers. Valid values are: cpu, cuda, coreml. + rule_fsts: + If not empty, it specifies fsts for inverse text normalization. + If there are multiple fsts, they are separated by a comma. + rule_fars: + If not empty, it specifies fst archives for inverse text normalization. + If there are multiple archives, they are separated by a comma. """ self = cls.__new__(cls) model_config = OfflineModelConfig( @@ -542,6 +610,8 @@ class OfflineRecognizer(object): feat_config=feat_config, model_config=model_config, decoding_method=decoding_method, + rule_fsts=rule_fsts, + rule_fars=rule_fars, ) self.recognizer = _Recognizer(recognizer_config) self.config = recognizer_config