diff --git a/.github/scripts/test-python.sh b/.github/scripts/test-python.sh index fe665501..08eb11de 100755 --- a/.github/scripts/test-python.sh +++ b/.github/scripts/test-python.sh @@ -8,6 +8,13 @@ log() { echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" } +log "test nemo canary" +curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-nemo-canary-180m-flash-en-es-de-fr-int8.tar.bz2 +tar xvf sherpa-onnx-nemo-canary-180m-flash-en-es-de-fr-int8.tar.bz2 +rm sherpa-onnx-nemo-canary-180m-flash-en-es-de-fr-int8.tar.bz2 +python3 ./python-api-examples/offline-nemo-canary-decode-files.py +rm -rf sherpa-onnx-nemo-canary-180m-flash-en-es-de-fr-int8 + log "test spleeter" curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/source-separation-models/sherpa-onnx-spleeter-2stems-fp16.tar.bz2 diff --git a/python-api-examples/offline-nemo-canary-decode-files.py b/python-api-examples/offline-nemo-canary-decode-files.py new file mode 100644 index 00000000..f4ab2025 --- /dev/null +++ b/python-api-examples/offline-nemo-canary-decode-files.py @@ -0,0 +1,110 @@ +#!/usr/bin/env python3 + +""" +This file shows how to use a non-streaming Canary model from NeMo +to decode files. + +Please download model files from +https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models + + +The example model supports 4 languages and it is converted from +https://huggingface.co/nvidia/canary-180m-flash + +It supports automatic speech-to-text recognition (ASR) in 4 languages +(English, German, French, Spanish) and translation from English to +German/French/Spanish and from German/French/Spanish to English with or +without punctuation and capitalization (PnC). 
+""" + +from pathlib import Path + +import sherpa_onnx +import soundfile as sf + + +def create_recognizer(): + encoder = "./sherpa-onnx-nemo-canary-180m-flash-en-es-de-fr-int8/encoder.int8.onnx" + decoder = "./sherpa-onnx-nemo-canary-180m-flash-en-es-de-fr-int8/decoder.int8.onnx" + tokens = "./sherpa-onnx-nemo-canary-180m-flash-en-es-de-fr-int8/tokens.txt" + + en_wav = "./sherpa-onnx-nemo-canary-180m-flash-en-es-de-fr-int8/test_wavs/en.wav" + de_wav = "./sherpa-onnx-nemo-canary-180m-flash-en-es-de-fr-int8/test_wavs/de.wav" + + if not Path(encoder).is_file() or not Path(en_wav).is_file(): + raise ValueError( + """Please download model files from + https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models + """ + ) + return ( + sherpa_onnx.OfflineRecognizer.from_nemo_canary( + encoder=encoder, + decoder=decoder, + tokens=tokens, + debug=True, + ), + en_wav, + de_wav, + ) + + +def decode(recognizer, samples, sample_rate, src_lang, tgt_lang): + stream = recognizer.create_stream() + stream.accept_waveform(sample_rate, samples) + + recognizer.recognizer.set_config( + config=sherpa_onnx.OfflineRecognizerConfig( + model_config=sherpa_onnx.OfflineModelConfig( + canary=sherpa_onnx.OfflineCanaryModelConfig( + src_lang=src_lang, + tgt_lang=tgt_lang, + ) + ) + ) + ) + + recognizer.decode_stream(stream) + return stream.result.text + + +def main(): + recognizer, en_wav, de_wav = create_recognizer() + + en_audio, en_sample_rate = sf.read(en_wav, dtype="float32", always_2d=True) + en_audio = en_audio[:, 0] # only use the first channel + + de_audio, de_sample_rate = sf.read(de_wav, dtype="float32", always_2d=True) + de_audio = de_audio[:, 0] # only use the first channel + + en_wav_en_result = decode( + recognizer, en_audio, en_sample_rate, src_lang="en", tgt_lang="en" + ) + en_wav_es_result = decode( + recognizer, en_audio, en_sample_rate, src_lang="en", tgt_lang="es" + ) + en_wav_de_result = decode( + recognizer, en_audio, en_sample_rate, src_lang="en", tgt_lang="de" + ) + 
en_wav_fr_result = decode( + recognizer, en_audio, en_sample_rate, src_lang="en", tgt_lang="fr" + ) + + de_wav_en_result = decode( + recognizer, de_audio, de_sample_rate, src_lang="de", tgt_lang="en" + ) + de_wav_de_result = decode( + recognizer, de_audio, de_sample_rate, src_lang="de", tgt_lang="de" + ) + + print("en_wav_en_result", en_wav_en_result) + print("en_wav_es_result", en_wav_es_result) + print("en_wav_de_result", en_wav_de_result) + print("en_wav_fr_result", en_wav_fr_result) + print("-" * 10) + print("de_wav_en_result", de_wav_en_result) + print("de_wav_de_result", de_wav_de_result) + + +if __name__ == "__main__": + main() diff --git a/scripts/nemo/canary/export_onnx_180m_flash.py b/scripts/nemo/canary/export_onnx_180m_flash.py index 416acd5b..aeba4e82 100755 --- a/scripts/nemo/canary/export_onnx_180m_flash.py +++ b/scripts/nemo/canary/export_onnx_180m_flash.py @@ -281,9 +281,14 @@ def export_decoder(canary_model): def export_tokens(canary_model): + underline = "▁" with open("./tokens.txt", "w", encoding="utf-8") as f: for i in range(canary_model.tokenizer.vocab_size): s = canary_model.tokenizer.ids_to_text([i]) + + if s[0] == " ": + s = underline + s[1:] + f.write(f"{s} {i}\n") print("Saved to tokens.txt") diff --git a/scripts/nemo/canary/test_180m_flash.py b/scripts/nemo/canary/test_180m_flash.py index 9331f41a..8654d13d 100755 --- a/scripts/nemo/canary/test_180m_flash.py +++ b/scripts/nemo/canary/test_180m_flash.py @@ -289,7 +289,13 @@ def main(): tokens.append(t) print("len(tokens)", len(tokens)) print("tokens", tokens) + text = "".join([id2token[i] for i in tokens]) + + underline = "▁" + # underline = b"\xe2\x96\x81".decode() + + text = text.replace(underline, " ").strip() print("text:", text) diff --git a/sherpa-onnx/c-api/cxx-api.cc b/sherpa-onnx/c-api/cxx-api.cc index 866dec0b..ec4e9e5d 100644 --- a/sherpa-onnx/c-api/cxx-api.cc +++ b/sherpa-onnx/c-api/cxx-api.cc @@ -5,6 +5,7 @@ #include #include +#include namespace sherpa_onnx::cxx { diff --git 
a/sherpa-onnx/csrc/CMakeLists.txt b/sherpa-onnx/csrc/CMakeLists.txt index a1e81003..37d1d869 100644 --- a/sherpa-onnx/csrc/CMakeLists.txt +++ b/sherpa-onnx/csrc/CMakeLists.txt @@ -25,6 +25,8 @@ set(sources jieba.cc keyword-spotter-impl.cc keyword-spotter.cc + offline-canary-model-config.cc + offline-canary-model.cc offline-ctc-fst-decoder-config.cc offline-ctc-fst-decoder.cc offline-ctc-greedy-search-decoder.cc @@ -50,7 +52,6 @@ set(sources offline-rnn-lm.cc offline-sense-voice-model-config.cc offline-sense-voice-model.cc - offline-source-separation-impl.cc offline-source-separation-model-config.cc offline-source-separation-spleeter-model-config.cc @@ -58,7 +59,6 @@ set(sources offline-source-separation-uvr-model-config.cc offline-source-separation-uvr-model.cc offline-source-separation.cc - offline-stream.cc offline-tdnn-ctc-model.cc offline-tdnn-model-config.cc diff --git a/sherpa-onnx/csrc/offline-canary-model-config.cc b/sherpa-onnx/csrc/offline-canary-model-config.cc new file mode 100644 index 00000000..2821c10d --- /dev/null +++ b/sherpa-onnx/csrc/offline-canary-model-config.cc @@ -0,0 +1,86 @@ +// sherpa-onnx/csrc/offline-canary-model-config.cc +// +// Copyright (c) 2025 Xiaomi Corporation + +#include "sherpa-onnx/csrc/offline-canary-model-config.h" + +#include + +#include "sherpa-onnx/csrc/file-utils.h" +#include "sherpa-onnx/csrc/macros.h" + +namespace sherpa_onnx { + +void OfflineCanaryModelConfig::Register(ParseOptions *po) { + po->Register("canary-encoder", &encoder, + "Path to onnx encoder of Canary, e.g., encoder.int8.onnx"); + + po->Register("canary-decoder", &decoder, + "Path to onnx decoder of Canary, e.g., decoder.int8.onnx"); + + po->Register("canary-src-lang", &src_lang, + "Valid values: en, de, es, fr. If empty, default to use en"); + + po->Register("canary-tgt-lang", &tgt_lang, + "Valid values: en, de, es, fr. If empty, default to use en"); + + po->Register("canary-use-pnc", &use_pnc, + "true to enable punctuations and casing. 
false to disable them"); +} + +bool OfflineCanaryModelConfig::Validate() const { + if (encoder.empty()) { + SHERPA_ONNX_LOGE("Please provide --canary-encoder"); + return false; + } + + if (!FileExists(encoder)) { + SHERPA_ONNX_LOGE("Canary encoder file '%s' does not exist", + encoder.c_str()); + return false; + } + + if (decoder.empty()) { + SHERPA_ONNX_LOGE("Please provide --canary-decoder"); + return false; + } + + if (!FileExists(decoder)) { + SHERPA_ONNX_LOGE("Canary decoder file '%s' does not exist", + decoder.c_str()); + return false; + } + + if (!src_lang.empty()) { + if (src_lang != "en" && src_lang != "de" && src_lang != "es" && + src_lang != "fr") { + SHERPA_ONNX_LOGE("Please use en, de, es, or fr for --canary-src-lang"); + return false; + } + } + + if (!tgt_lang.empty()) { + if (tgt_lang != "en" && tgt_lang != "de" && tgt_lang != "es" && + tgt_lang != "fr") { + SHERPA_ONNX_LOGE("Please use en, de, es, or fr for --canary-tgt-lang"); + return false; + } + } + + return true; +} + +std::string OfflineCanaryModelConfig::ToString() const { + std::ostringstream os; + + os << "OfflineCanaryModelConfig("; + os << "encoder=\"" << encoder << "\", "; + os << "decoder=\"" << decoder << "\", "; + os << "src_lang=\"" << src_lang << "\", "; + os << "tgt_lang=\"" << tgt_lang << "\", "; + os << "use_pnc=" << (use_pnc ? 
"True" : "False") << ")"; + + return os.str(); +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/offline-canary-model-config.h b/sherpa-onnx/csrc/offline-canary-model-config.h new file mode 100644 index 00000000..e57d2ee4 --- /dev/null +++ b/sherpa-onnx/csrc/offline-canary-model-config.h @@ -0,0 +1,47 @@ +// sherpa-onnx/csrc/offline-canary-model-config.h +// +// Copyright (c) 2025 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_OFFLINE_CANARY_MODEL_CONFIG_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_CANARY_MODEL_CONFIG_H_ + +#include + +#include "sherpa-onnx/csrc/parse-options.h" + +namespace sherpa_onnx { + +struct OfflineCanaryModelConfig { + std::string encoder; + std::string decoder; + + // en, de, es, fr, or leave it empty to use en + std::string src_lang; + + // en, de, es, fr, or leave it empty to use en + std::string tgt_lang; + + // true to enable punctuations and casing + // false to disable punctuations and casing + bool use_pnc = true; + + OfflineCanaryModelConfig() = default; + OfflineCanaryModelConfig(const std::string &encoder, + const std::string &decoder, + const std::string &src_lang, + const std::string &tgt_lang, bool use_pnc) + : encoder(encoder), + decoder(decoder), + src_lang(src_lang), + tgt_lang(tgt_lang), + use_pnc(use_pnc) {} + + void Register(ParseOptions *po); + bool Validate() const; + + std::string ToString() const; +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_OFFLINE_CANARY_MODEL_CONFIG_H_ diff --git a/sherpa-onnx/csrc/offline-canary-model-meta-data.h b/sherpa-onnx/csrc/offline-canary-model-meta-data.h new file mode 100644 index 00000000..2322a3ad --- /dev/null +++ b/sherpa-onnx/csrc/offline-canary-model-meta-data.h @@ -0,0 +1,23 @@ +// sherpa-onnx/csrc/offline-canary-model-meta-data.h +// +// Copyright (c) 2024 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_OFFLINE_CANARY_MODEL_META_DATA_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_CANARY_MODEL_META_DATA_H_ + +#include +#include +#include + +namespace sherpa_onnx { + 
+struct OfflineCanaryModelMetaData { + int32_t vocab_size; + int32_t subsampling_factor = 8; + int32_t feat_dim = 120; + std::string normalize_type; + std::unordered_map lang2id; +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_OFFLINE_CANARY_MODEL_META_DATA_H_ diff --git a/sherpa-onnx/csrc/offline-canary-model.cc b/sherpa-onnx/csrc/offline-canary-model.cc new file mode 100644 index 00000000..37471420 --- /dev/null +++ b/sherpa-onnx/csrc/offline-canary-model.cc @@ -0,0 +1,264 @@ +// sherpa-onnx/csrc/offline-canary-model.cc +// +// Copyright (c) 2025 Xiaomi Corporation + +#include "sherpa-onnx/csrc/offline-canary-model.h" + +#include +#include +#include +#include +#include +#include + +#include "sherpa-onnx/csrc/offline-canary-model-meta-data.h" + +#if __ANDROID_API__ >= 9 +#include "android/asset_manager.h" +#include "android/asset_manager_jni.h" +#endif + +#if __OHOS__ +#include "rawfile/raw_file_manager.h" +#endif + +#include "sherpa-onnx/csrc/file-utils.h" +#include "sherpa-onnx/csrc/macros.h" +#include "sherpa-onnx/csrc/onnx-utils.h" +#include "sherpa-onnx/csrc/session.h" +#include "sherpa-onnx/csrc/text-utils.h" + +namespace sherpa_onnx { + +class OfflineCanaryModel::Impl { + public: + explicit Impl(const OfflineModelConfig &config) + : config_(config), + env_(ORT_LOGGING_LEVEL_ERROR), + sess_opts_(GetSessionOptions(config)), + allocator_{} { + { + auto buf = ReadFile(config.canary.encoder); + InitEncoder(buf.data(), buf.size()); + } + + { + auto buf = ReadFile(config.canary.decoder); + InitDecoder(buf.data(), buf.size()); + } + } + + template + Impl(Manager *mgr, const OfflineModelConfig &config) + : config_(config), + env_(ORT_LOGGING_LEVEL_ERROR), + sess_opts_(GetSessionOptions(config)), + allocator_{} { + { + auto buf = ReadFile(mgr, config.canary.encoder); + InitEncoder(buf.data(), buf.size()); + } + + { + auto buf = ReadFile(mgr, config.canary.decoder); + InitDecoder(buf.data(), buf.size()); + } + } + + std::vector 
ForwardEncoder(Ort::Value features, + Ort::Value features_length) { + std::array encoder_inputs = {std::move(features), + std::move(features_length)}; + + auto encoder_out = encoder_sess_->Run( + {}, encoder_input_names_ptr_.data(), encoder_inputs.data(), + encoder_inputs.size(), encoder_output_names_ptr_.data(), + encoder_output_names_ptr_.size()); + + return encoder_out; + } + + std::pair> ForwardDecoder( + Ort::Value tokens, std::vector decoder_states, + Ort::Value encoder_states, Ort::Value enc_mask) { + std::vector decoder_inputs; + decoder_inputs.reserve(3 + decoder_states.size()); + + decoder_inputs.push_back(std::move(tokens)); + for (auto &s : decoder_states) { + decoder_inputs.push_back(std::move(s)); + } + + decoder_inputs.push_back(std::move(encoder_states)); + decoder_inputs.push_back(std::move(enc_mask)); + + auto decoder_outputs = decoder_sess_->Run( + {}, decoder_input_names_ptr_.data(), decoder_inputs.data(), + decoder_inputs.size(), decoder_output_names_ptr_.data(), + decoder_output_names_ptr_.size()); + + Ort::Value logits = std::move(decoder_outputs[0]); + + std::vector output_decoder_states; + output_decoder_states.reserve(decoder_states.size()); + + int32_t i = 0; + for (auto &s : decoder_outputs) { + i += 1; + if (i == 1) { + continue; + } + output_decoder_states.push_back(std::move(s)); + } + + return {std::move(logits), std::move(output_decoder_states)}; + } + + std::vector GetInitialDecoderStates() { + std::array shape{1, 0, 1024}; + + std::vector ans; + ans.reserve(6); + for (int32_t i = 0; i < 6; ++i) { + Ort::Value state = Ort::Value::CreateTensor( + Allocator(), shape.data(), shape.size()); + + ans.push_back(std::move(state)); + } + + return ans; + } + + OrtAllocator *Allocator() { return allocator_; } + + const OfflineCanaryModelMetaData &GetModelMetadata() const { return meta_; } + + OfflineCanaryModelMetaData &GetModelMetadata() { return meta_; } + + private: + void InitEncoder(void *model_data, size_t model_data_length) { + 
encoder_sess_ = std::make_unique( + env_, model_data, model_data_length, sess_opts_); + + GetInputNames(encoder_sess_.get(), &encoder_input_names_, + &encoder_input_names_ptr_); + + GetOutputNames(encoder_sess_.get(), &encoder_output_names_, + &encoder_output_names_ptr_); + + // get meta data + Ort::ModelMetadata meta_data = encoder_sess_->GetModelMetadata(); + if (config_.debug) { + std::ostringstream os; + os << "---encoder---\n"; + PrintModelMetadata(os, meta_data); +#if __OHOS__ + SHERPA_ONNX_LOGE("%{public}s\n", os.str().c_str()); +#else + SHERPA_ONNX_LOGE("%s\n", os.str().c_str()); +#endif + } + + Ort::AllocatorWithDefaultOptions allocator; // used in the macro below + + std::string model_type; + SHERPA_ONNX_READ_META_DATA_STR(model_type, "model_type"); + + if (model_type != "EncDecMultiTaskModel") { + SHERPA_ONNX_LOGE( + "Expected model type 'EncDecMultiTaskModel'. Given: '%s'", + model_type.c_str()); + SHERPA_ONNX_EXIT(-1); + } + + SHERPA_ONNX_READ_META_DATA(meta_.vocab_size, "vocab_size"); + SHERPA_ONNX_READ_META_DATA_STR_ALLOW_EMPTY(meta_.normalize_type, + "normalize_type"); + SHERPA_ONNX_READ_META_DATA(meta_.subsampling_factor, "subsampling_factor"); + SHERPA_ONNX_READ_META_DATA(meta_.feat_dim, "feat_dim"); + } + + void InitDecoder(void *model_data, size_t model_data_length) { + decoder_sess_ = std::make_unique( + env_, model_data, model_data_length, sess_opts_); + + GetInputNames(decoder_sess_.get(), &decoder_input_names_, + &decoder_input_names_ptr_); + + GetOutputNames(decoder_sess_.get(), &decoder_output_names_, + &decoder_output_names_ptr_); + } + + private: + OfflineCanaryModelMetaData meta_; + OfflineModelConfig config_; + Ort::Env env_; + Ort::SessionOptions sess_opts_; + Ort::AllocatorWithDefaultOptions allocator_; + + std::unique_ptr encoder_sess_; + std::unique_ptr decoder_sess_; + + std::vector encoder_input_names_; + std::vector encoder_input_names_ptr_; + + std::vector encoder_output_names_; + std::vector encoder_output_names_ptr_; + + 
std::vector decoder_input_names_; + std::vector decoder_input_names_ptr_; + + std::vector decoder_output_names_; + std::vector decoder_output_names_ptr_; +}; + +OfflineCanaryModel::OfflineCanaryModel(const OfflineModelConfig &config) + : impl_(std::make_unique(config)) {} + +template +OfflineCanaryModel::OfflineCanaryModel(Manager *mgr, + const OfflineModelConfig &config) + : impl_(std::make_unique(mgr, config)) {} + +OfflineCanaryModel::~OfflineCanaryModel() = default; + +std::vector OfflineCanaryModel::ForwardEncoder( + Ort::Value features, Ort::Value features_length) const { + return impl_->ForwardEncoder(std::move(features), std::move(features_length)); +} + +std::pair> +OfflineCanaryModel::ForwardDecoder(Ort::Value tokens, + std::vector decoder_states, + Ort::Value encoder_states, + Ort::Value enc_mask) const { + return impl_->ForwardDecoder(std::move(tokens), std::move(decoder_states), + std::move(encoder_states), std::move(enc_mask)); +} + +std::vector OfflineCanaryModel::GetInitialDecoderStates() const { + return impl_->GetInitialDecoderStates(); +} + +OrtAllocator *OfflineCanaryModel::Allocator() const { + return impl_->Allocator(); +} + +const OfflineCanaryModelMetaData &OfflineCanaryModel::GetModelMetadata() const { + return impl_->GetModelMetadata(); +} +OfflineCanaryModelMetaData &OfflineCanaryModel::GetModelMetadata() { + return impl_->GetModelMetadata(); +} + +#if __ANDROID_API__ >= 9 +template OfflineCanaryModel::OfflineCanaryModel( + AAssetManager *mgr, const OfflineModelConfig &config); +#endif + +#if __OHOS__ +template OfflineCanaryModel::OfflineCanaryModel( + NativeResourceManager *mgr, const OfflineModelConfig &config); +#endif + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/offline-canary-model.h b/sherpa-onnx/csrc/offline-canary-model.h new file mode 100644 index 00000000..2b2a6113 --- /dev/null +++ b/sherpa-onnx/csrc/offline-canary-model.h @@ -0,0 +1,81 @@ +// sherpa-onnx/csrc/offline-canary-model.h +// +// Copyright (c) 2025 
Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_OFFLINE_CANARY_MODEL_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_CANARY_MODEL_H_ + +#include +#include +#include +#include +#include + +#include "onnxruntime_cxx_api.h" // NOLINT +#include "sherpa-onnx/csrc/offline-canary-model-meta-data.h" +#include "sherpa-onnx/csrc/offline-model-config.h" + +namespace sherpa_onnx { + +// see +// https://github.com/k2-fsa/sherpa-onnx/blob/master/scripts/nemo/canary/test_180m_flash.py +class OfflineCanaryModel { + public: + explicit OfflineCanaryModel(const OfflineModelConfig &config); + + template + OfflineCanaryModel(Manager *mgr, const OfflineModelConfig &config); + + ~OfflineCanaryModel(); + + /** Run the encoder. + * + * @param features A tensor of shape (N, T, C) of dtype float32. + * @param features_length A 1-D tensor of shape (N,) containing number of + * valid frames in `features` before padding. + * Its dtype is int64_t. + * + * @return Return a vector containing: + * - encoder_states: A 3-D tensor of shape (N, T', encoder_dim) + * - encoder_len: A 1-D tensor of shape (N,) containing number + * of frames in `encoder_out` before padding. + * Its dtype is int64_t + * - enc_mask: A 2-D tensor of shape (N, T') with dtype bool + */ + std::vector ForwardEncoder(Ort::Value features, + Ort::Value features_length) const; + + /** Run the decoder model. 
+ * + * @param tokens A int32 tensor of shape (N, num_tokens) + * @param decoder_states std::vector + * @param encoder_states Output from ForwardEncoder() + * @param enc_mask Output from ForwardEncoder() + * + * @return Return a pair: + * + * - logits A 3-D tensor of shape (N, num_words, vocab_size) + * - new_decoder_states: Can be used as input for ForwardDecoder() + */ + std::pair> ForwardDecoder( + Ort::Value tokens, std::vector decoder_states, + Ort::Value encoder_states, Ort::Value enc_mask) const; + + // The return value can be used as input for ForwardDecoder() + std::vector GetInitialDecoderStates() const; + + /** Return an allocator for allocating memory + */ + OrtAllocator *Allocator() const; + + const OfflineCanaryModelMetaData &GetModelMetadata() const; + + OfflineCanaryModelMetaData &GetModelMetadata(); + + private: + class Impl; + std::unique_ptr impl_; +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_OFFLINE_CANARY_MODEL_H_ diff --git a/sherpa-onnx/csrc/offline-model-config.cc b/sherpa-onnx/csrc/offline-model-config.cc index 9ab59b3b..493309fc 100644 --- a/sherpa-onnx/csrc/offline-model-config.cc +++ b/sherpa-onnx/csrc/offline-model-config.cc @@ -22,6 +22,7 @@ void OfflineModelConfig::Register(ParseOptions *po) { sense_voice.Register(po); moonshine.Register(po); dolphin.Register(po); + canary.Register(po); po->Register("telespeech-ctc", &telespeech_ctc, "Path to model.onnx for telespeech ctc"); @@ -114,6 +115,10 @@ bool OfflineModelConfig::Validate() const { return dolphin.Validate(); } + if (!canary.encoder.empty()) { + return canary.Validate(); + } + if (!telespeech_ctc.empty() && !FileExists(telespeech_ctc)) { SHERPA_ONNX_LOGE("telespeech_ctc: '%s' does not exist", telespeech_ctc.c_str()); @@ -142,6 +147,7 @@ std::string OfflineModelConfig::ToString() const { os << "sense_voice=" << sense_voice.ToString() << ", "; os << "moonshine=" << moonshine.ToString() << ", "; os << "dolphin=" << dolphin.ToString() << ", "; + os << "canary=" << 
canary.ToString() << ", "; os << "telespeech_ctc=\"" << telespeech_ctc << "\", "; os << "tokens=\"" << tokens << "\", "; os << "num_threads=" << num_threads << ", "; diff --git a/sherpa-onnx/csrc/offline-model-config.h b/sherpa-onnx/csrc/offline-model-config.h index c12a480a..8164c7f7 100644 --- a/sherpa-onnx/csrc/offline-model-config.h +++ b/sherpa-onnx/csrc/offline-model-config.h @@ -6,6 +6,7 @@ #include +#include "sherpa-onnx/csrc/offline-canary-model-config.h" #include "sherpa-onnx/csrc/offline-dolphin-model-config.h" #include "sherpa-onnx/csrc/offline-fire-red-asr-model-config.h" #include "sherpa-onnx/csrc/offline-moonshine-model-config.h" @@ -32,6 +33,7 @@ struct OfflineModelConfig { OfflineSenseVoiceModelConfig sense_voice; OfflineMoonshineModelConfig moonshine; OfflineDolphinModelConfig dolphin; + OfflineCanaryModelConfig canary; std::string telespeech_ctc; std::string tokens; @@ -65,6 +67,7 @@ struct OfflineModelConfig { const OfflineSenseVoiceModelConfig &sense_voice, const OfflineMoonshineModelConfig &moonshine, const OfflineDolphinModelConfig &dolphin, + const OfflineCanaryModelConfig &canary, const std::string &telespeech_ctc, const std::string &tokens, int32_t num_threads, bool debug, const std::string &provider, const std::string &model_type, @@ -81,6 +84,7 @@ struct OfflineModelConfig { sense_voice(sense_voice), moonshine(moonshine), dolphin(dolphin), + canary(canary), telespeech_ctc(telespeech_ctc), tokens(tokens), num_threads(num_threads), diff --git a/sherpa-onnx/csrc/offline-recognizer-canary-impl.h b/sherpa-onnx/csrc/offline-recognizer-canary-impl.h new file mode 100644 index 00000000..8744899c --- /dev/null +++ b/sherpa-onnx/csrc/offline-recognizer-canary-impl.h @@ -0,0 +1,261 @@ +// sherpa-onnx/csrc/offline-recognizer-canary-impl.h +// +// Copyright (c) 2025 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_CANARY_IMPL_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_CANARY_IMPL_H_ + +#include +#include +#include +#include 
+#include <utility>
+#include <vector>
+
+#include "sherpa-onnx/csrc/macros.h"
+#include "sherpa-onnx/csrc/offline-canary-model.h"
+#include "sherpa-onnx/csrc/offline-recognizer-impl.h"
+#include "sherpa-onnx/csrc/offline-recognizer.h"
+#include "sherpa-onnx/csrc/onnx-utils.h"
+#include "sherpa-onnx/csrc/symbol-table.h"
+#include "sherpa-onnx/csrc/utils.h"
+
+namespace sherpa_onnx {
+
+class OfflineRecognizerCanaryImpl : public OfflineRecognizerImpl {
+ public:
+  explicit OfflineRecognizerCanaryImpl(const OfflineRecognizerConfig &config)
+      : OfflineRecognizerImpl(config),
+        config_(config),
+        symbol_table_(config_.model_config.tokens),
+        model_(std::make_unique<OfflineCanaryModel>(config_.model_config)) {
+    PostInit();
+  }
+
+  template <typename Manager>
+  explicit OfflineRecognizerCanaryImpl(Manager *mgr,
+                                       const OfflineRecognizerConfig &config)
+      : OfflineRecognizerImpl(mgr, config),
+        config_(config),
+        symbol_table_(mgr, config_.model_config.tokens),
+        model_(
+            std::make_unique<OfflineCanaryModel>(mgr, config_.model_config)) {
+    PostInit();
+  }
+
+  std::unique_ptr<OfflineStream> CreateStream() const override {
+    return std::make_unique<OfflineStream>(config_.feat_config);
+  }
+
+  void DecodeStreams(OfflineStream **ss, int32_t n) const override {
+    for (int32_t i = 0; i < n; ++i) {
+      DecodeStream(ss[i]);  // streams are decoded one at a time
+    }
+  }
+
+  void DecodeStream(OfflineStream *s) const {
+    const auto &meta = model_->GetModelMetadata();  // by reference: no copy
+    auto enc_out = RunEncoder(s);
+    Ort::Value enc_states = std::move(enc_out[0]);
+    Ort::Value enc_mask = std::move(enc_out[2]);
+    // enc_out[1] is discarded
+    std::vector<int32_t> decoder_input = GetInitialDecoderInput();
+    auto decoder_states = model_->GetInitialDecoderStates();
+    Ort::Value logits{nullptr};
+
+    for (int32_t i = 0; i < static_cast<int32_t>(decoder_input.size()); ++i) {
+      std::tie(logits, decoder_states) =
+          RunDecoder(decoder_input[i], i, std::move(decoder_states),
+                     View(&enc_states), View(&enc_mask));
+    }
+
+    int32_t max_token_id = GetMaxTokenId(&logits);
+    int32_t eos = symbol_table_["<|endoftext|>"];
+
+    int32_t num_feature_frames =
+        enc_states.GetTensorTypeAndShapeInfo().GetShape()[1] *
+        meta.subsampling_factor;
+
+    std::vector<int32_t> tokens = {max_token_id};
+
+    // Assume 30 tokens per second. It is to avoid the following for loop
+    // running indefinitely.
+    int32_t num_tokens =
+        static_cast<int32_t>(num_feature_frames / 100.0 * 30) + 1;
+
+    for (int32_t i = 1; i <= num_tokens; ++i) {
+      if (tokens.back() == eos) {
+        break;
+      }
+
+      std::tie(logits, decoder_states) =
+          RunDecoder(tokens.back(), i, std::move(decoder_states),
+                     View(&enc_states), View(&enc_mask));
+      tokens.push_back(GetMaxTokenId(&logits));
+    }
+
+    // drop the trailing eos; keep the last token if we hit the length limit
+    if (tokens.back() == eos) tokens.pop_back();
+
+    auto r = Convert(tokens);
+
+    r.text = ApplyInverseTextNormalization(std::move(r.text));
+    r.text = ApplyHomophoneReplacer(std::move(r.text));
+
+    s->SetResult(r);
+  }
+
+  OfflineRecognizerConfig GetConfig() const override { return config_; }
+
+  void SetConfig(const OfflineRecognizerConfig &config) override {
+    config_.model_config.canary.src_lang = config.model_config.canary.src_lang;
+    config_.model_config.canary.tgt_lang = config.model_config.canary.tgt_lang;
+    config_.model_config.canary.use_pnc = config.model_config.canary.use_pnc;
+
+    // we don't change the config_ in the base class
+  }
+
+ private:
+  OfflineRecognitionResult Convert(const std::vector<int32_t> &tokens) const {
+    OfflineRecognitionResult r;
+    r.tokens.reserve(tokens.size());
+
+    std::string text;
+    for (auto i : tokens) {
+      if (!symbol_table_.Contains(i)) {
+        continue;  // ids missing from tokens.txt are skipped
+      }
+
+      const auto &s = symbol_table_[i];
+      text += s;
+      r.tokens.push_back(s);
+    }
+
+    r.text = std::move(text);
+
+    return r;
+  }
+
+  int32_t GetMaxTokenId(Ort::Value *logits) const {
+    // logits is of shape (1, 1, vocab_size)
+    const auto &meta = model_->GetModelMetadata();  // called once per token
+    const float *p_logits = logits->GetTensorData<float>();
+
+    int32_t max_token_id = static_cast<int32_t>(std::distance(
+        p_logits, std::max_element(p_logits, p_logits + meta.vocab_size)));
+
+    return max_token_id;
+  }
+
+  std::vector<Ort::Value> RunEncoder(OfflineStream *s) const {
+    auto memory_info =
+        Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
+
+    int32_t feat_dim = config_.feat_config.feature_dim;
+    std::vector<float> f = s->GetFrames();
+
+    int32_t num_frames = f.size() / feat_dim;
+
+    std::array<int64_t, 3> shape = {1, num_frames, feat_dim};
+
+    Ort::Value x = Ort::Value::CreateTensor(memory_info, f.data(), f.size(),
+                                            shape.data(), shape.size());
+
+    int64_t x_length_scalar = num_frames;
+    std::array<int64_t, 1> x_length_shape = {1};
+    Ort::Value x_length =
+        Ort::Value::CreateTensor(memory_info, &x_length_scalar, 1,
+                                 x_length_shape.data(), x_length_shape.size());
+    return model_->ForwardEncoder(std::move(x), std::move(x_length));
+  }
+
+  std::pair<Ort::Value, std::vector<Ort::Value>> RunDecoder(
+      int32_t token, int32_t pos, std::vector<Ort::Value> decoder_states,
+      Ort::Value enc_states, Ort::Value enc_mask) const {
+    auto memory_info =
+        Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
+
+    std::array<int64_t, 2> shape = {1, 2};
+    std::array<int32_t, 2> _decoder_input = {token, pos};
+
+    Ort::Value decoder_input = Ort::Value::CreateTensor(
+        memory_info, _decoder_input.data(), _decoder_input.size(), shape.data(),
+        shape.size());
+
+    return model_->ForwardDecoder(std::move(decoder_input),
+                                  std::move(decoder_states),
+                                  std::move(enc_states), std::move(enc_mask));
+  }
+
+  // see
+  // https://github.com/k2-fsa/sherpa-onnx/blob/master/scripts/nemo/canary/test_180m_flash.py#L242
+  std::vector<int32_t> GetInitialDecoderInput() const {
+    const auto &canary_config = config_.model_config.canary;
+    const auto &meta = model_->GetModelMetadata();
+
+    std::vector<int32_t> decoder_input(9);
+    decoder_input[0] = symbol_table_["<|startofcontext|>"];
+    decoder_input[1] = symbol_table_["<|startoftranscript|>"];
+    decoder_input[2] = symbol_table_["<|emo:undefined|>"];
+
+    if (canary_config.src_lang.empty() ||
+        !meta.lang2id.count(canary_config.src_lang)) {
+      decoder_input[3] = meta.lang2id.at("en");  // empty/unknown -> English
+    } else {
+      decoder_input[3] = meta.lang2id.at(canary_config.src_lang);
+    }
+
+    if (canary_config.tgt_lang.empty() ||
+        !meta.lang2id.count(canary_config.tgt_lang)) {
+      decoder_input[4] = meta.lang2id.at("en");  // empty/unknown -> English
+    } else {
+      decoder_input[4] = meta.lang2id.at(canary_config.tgt_lang);
+    }
+
+    if (canary_config.use_pnc) {
+      decoder_input[5] = symbol_table_["<|pnc|>"];
+    } else {
+      decoder_input[5] = symbol_table_["<|nopnc|>"];
+    }
+
+    decoder_input[6] = symbol_table_["<|noitn|>"];
+    decoder_input[7] = symbol_table_["<|notimestamp|>"];
+    decoder_input[8] = symbol_table_["<|nodiarize|>"];
+
+    return decoder_input;
+  }
+
+ private:
+  void PostInit() {
+    auto &meta = model_->GetModelMetadata();  // mutable: lang2id is filled in
+    config_.feat_config.feature_dim = meta.feat_dim;
+
+    config_.feat_config.nemo_normalize_type = meta.normalize_type;
+
+    config_.feat_config.dither = 0;
+    config_.feat_config.remove_dc_offset = false;
+    config_.feat_config.low_freq = 0;
+    config_.feat_config.window_type = "hann";
+    config_.feat_config.is_librosa = true;
+
+    meta.lang2id["en"] = symbol_table_["<|en|>"];
+    meta.lang2id["es"] = symbol_table_["<|es|>"];
+    meta.lang2id["de"] = symbol_table_["<|de|>"];
+    meta.lang2id["fr"] = symbol_table_["<|fr|>"];
+
+    if (symbol_table_.NumSymbols() != meta.vocab_size) {
+      SHERPA_ONNX_LOGE("number of lines in tokens.txt %d != %d (vocab_size)",
+                       symbol_table_.NumSymbols(), meta.vocab_size);
+      SHERPA_ONNX_EXIT(-1);
+    }
+  }
+
+ private:
+  OfflineRecognizerConfig config_;
+  SymbolTable symbol_table_;
+  std::unique_ptr<OfflineCanaryModel> model_;
+};
+
+}  // namespace sherpa_onnx
+
+#endif  // SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_CANARY_IMPL_H_
diff --git a/sherpa-onnx/csrc/offline-recognizer-fire-red-asr-impl.h b/sherpa-onnx/csrc/offline-recognizer-fire-red-asr-impl.h
index 159a48b8..71a6af2d 100644
--- a/sherpa-onnx/csrc/offline-recognizer-fire-red-asr-impl.h
+++ b/sherpa-onnx/csrc/offline-recognizer-fire-red-asr-impl.h
@@ -39,7 +39,7 @@ static OfflineRecognitionResult Convert(
     r.tokens.push_back(s);
   }
 
-  r.text = text;
+  r.text = std::move(text);
 
   return r;
 }
diff --git a/sherpa-onnx/csrc/offline-recognizer-impl.cc
b/sherpa-onnx/csrc/offline-recognizer-impl.cc index 9e19dfd2..837bcc89 100644 --- a/sherpa-onnx/csrc/offline-recognizer-impl.cc +++ b/sherpa-onnx/csrc/offline-recognizer-impl.cc @@ -24,6 +24,7 @@ #include "onnxruntime_cxx_api.h" // NOLINT #include "sherpa-onnx/csrc/file-utils.h" #include "sherpa-onnx/csrc/macros.h" +#include "sherpa-onnx/csrc/offline-recognizer-canary-impl.h" #include "sherpa-onnx/csrc/offline-recognizer-ctc-impl.h" #include "sherpa-onnx/csrc/offline-recognizer-fire-red-asr-impl.h" #include "sherpa-onnx/csrc/offline-recognizer-moonshine-impl.h" @@ -66,6 +67,10 @@ std::unique_ptr OfflineRecognizerImpl::Create( return std::make_unique(config); } + if (!config.model_config.canary.encoder.empty()) { + return std::make_unique(config); + } + // TODO(fangjun): Refactor it. We only need to use model type for the // following models: // 1. transducer and nemo_transducer @@ -252,6 +257,10 @@ std::unique_ptr OfflineRecognizerImpl::Create( return std::make_unique(mgr, config); } + if (!config.model_config.canary.encoder.empty()) { + return std::make_unique(mgr, config); + } + // TODO(fangjun): Refactor it. We only need to use model type for the // following models: // 1. 
transducer and nemo_transducer
diff --git a/sherpa-onnx/csrc/onnx-utils.cc b/sherpa-onnx/csrc/onnx-utils.cc
index a32f92f9..d3110877 100644
--- a/sherpa-onnx/csrc/onnx-utils.cc
+++ b/sherpa-onnx/csrc/onnx-utils.cc
@@ -183,6 +183,10 @@ Ort::Value View(Ort::Value *v) {
       return Ort::Value::CreateTensor(
           memory_info, v->GetTensorMutableData(),
           type_and_shape.GetElementCount(), shape.data(), shape.size());
+    case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL:
+      return Ort::Value::CreateTensor(
+          memory_info, v->GetTensorMutableData(),
+          type_and_shape.GetElementCount(), shape.data(), shape.size());
     default:
       fprintf(stderr, "Unsupported type: %d\n",
               static_cast(type_and_shape.GetElementType()));
diff --git a/sherpa-onnx/python/csrc/CMakeLists.txt b/sherpa-onnx/python/csrc/CMakeLists.txt
index 7d6cabb3..ac2a5468 100644
--- a/sherpa-onnx/python/csrc/CMakeLists.txt
+++ b/sherpa-onnx/python/csrc/CMakeLists.txt
@@ -9,6 +9,7 @@ set(srcs
   features.cc
   homophone-replacer.cc
   keyword-spotter.cc
+  offline-canary-model-config.cc
   offline-ctc-fst-decoder-config.cc
   offline-dolphin-model-config.cc
   offline-fire-red-asr-model-config.cc
diff --git a/sherpa-onnx/python/csrc/offline-canary-model-config.cc b/sherpa-onnx/python/csrc/offline-canary-model-config.cc
new file mode 100644
index 00000000..43da9664
--- /dev/null
+++ b/sherpa-onnx/python/csrc/offline-canary-model-config.cc
@@ -0,0 +1,32 @@
+// sherpa-onnx/python/csrc/offline-canary-model-config.cc
+//
+// Copyright (c) 2025 Xiaomi Corporation
+
+#include "sherpa-onnx/csrc/offline-canary-model-config.h"
+
+#include <string>
+#include <vector>
+
+#include "sherpa-onnx/python/csrc/offline-canary-model-config.h"
+
+namespace sherpa_onnx {
+
+// Registers OfflineCanaryModelConfig with the Python module `m`: a keyword
+// constructor (all arguments optional), read/write fields, and __str__.
+void PybindOfflineCanaryModelConfig(py::module *m) {
+  using PyClass = OfflineCanaryModelConfig;
+  py::class_<PyClass>(*m, "OfflineCanaryModelConfig")
+      .def(py::init<const std::string &, const std::string &,
+                    const std::string &, const std::string &, bool>(),
+           py::arg("encoder") = "", py::arg("decoder") = "",
+           py::arg("src_lang") = "", py::arg("tgt_lang") = "",
+           py::arg("use_pnc") = true)
+      .def_readwrite("encoder", &PyClass::encoder)
+      .def_readwrite("decoder", &PyClass::decoder)
+      .def_readwrite("src_lang", &PyClass::src_lang)
+      .def_readwrite("tgt_lang", &PyClass::tgt_lang)
+      .def_readwrite("use_pnc", &PyClass::use_pnc)
+      .def("__str__", &PyClass::ToString);
+}
+
+}  // namespace sherpa_onnx
diff --git a/sherpa-onnx/python/csrc/offline-canary-model-config.h b/sherpa-onnx/python/csrc/offline-canary-model-config.h
new file mode 100644
index 00000000..9cf55519
--- /dev/null
+++ b/sherpa-onnx/python/csrc/offline-canary-model-config.h
@@ -0,0 +1,16 @@
+// sherpa-onnx/python/csrc/offline-canary-model-config.h
+//
+// Copyright (c) 2025 Xiaomi Corporation
+
+#ifndef SHERPA_ONNX_PYTHON_CSRC_OFFLINE_CANARY_MODEL_CONFIG_H_
+#define SHERPA_ONNX_PYTHON_CSRC_OFFLINE_CANARY_MODEL_CONFIG_H_
+
+#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
+
+namespace sherpa_onnx {
+
+void PybindOfflineCanaryModelConfig(py::module *m);
+
+}
+
+#endif  // SHERPA_ONNX_PYTHON_CSRC_OFFLINE_CANARY_MODEL_CONFIG_H_
diff --git a/sherpa-onnx/python/csrc/offline-model-config.cc b/sherpa-onnx/python/csrc/offline-model-config.cc
index c73eafd7..dc3c65dc 100644
--- a/sherpa-onnx/python/csrc/offline-model-config.cc
+++ b/sherpa-onnx/python/csrc/offline-model-config.cc
@@ -8,6 +8,7 @@
 #include
 
 #include "sherpa-onnx/csrc/offline-model-config.h"
+#include "sherpa-onnx/python/csrc/offline-canary-model-config.h"
 #include "sherpa-onnx/python/csrc/offline-dolphin-model-config.h"
 #include "sherpa-onnx/python/csrc/offline-fire-red-asr-model-config.h"
 #include "sherpa-onnx/python/csrc/offline-moonshine-model-config.h"
@@ -34,6 +35,7 @@ void PybindOfflineModelConfig(py::module *m) {
   PybindOfflineSenseVoiceModelConfig(m);
   PybindOfflineMoonshineModelConfig(m);
   PybindOfflineDolphinModelConfig(m);
+  PybindOfflineCanaryModelConfig(m);
 
   using PyClass = OfflineModelConfig;
   py::class_(*m, "OfflineModelConfig")
@@ -47,7 +49,8 @@ void PybindOfflineModelConfig(py::module *m) {
                     const OfflineWenetCtcModelConfig &,
                     const
OfflineSenseVoiceModelConfig &, const OfflineMoonshineModelConfig &, - const OfflineDolphinModelConfig &, const std::string &, + const OfflineDolphinModelConfig &, + const OfflineCanaryModelConfig &, const std::string &, const std::string &, int32_t, bool, const std::string &, const std::string &, const std::string &, const std::string &>(), @@ -62,8 +65,9 @@ void PybindOfflineModelConfig(py::module *m) { py::arg("sense_voice") = OfflineSenseVoiceModelConfig(), py::arg("moonshine") = OfflineMoonshineModelConfig(), py::arg("dolphin") = OfflineDolphinModelConfig(), - py::arg("telespeech_ctc") = "", py::arg("tokens"), - py::arg("num_threads"), py::arg("debug") = false, + py::arg("canary") = OfflineCanaryModelConfig(), + py::arg("telespeech_ctc") = "", py::arg("tokens") = "", + py::arg("num_threads") = 1, py::arg("debug") = false, py::arg("provider") = "cpu", py::arg("model_type") = "", py::arg("modeling_unit") = "cjkchar", py::arg("bpe_vocab") = "") .def_readwrite("transducer", &PyClass::transducer) @@ -77,6 +81,7 @@ void PybindOfflineModelConfig(py::module *m) { .def_readwrite("sense_voice", &PyClass::sense_voice) .def_readwrite("moonshine", &PyClass::moonshine) .def_readwrite("dolphin", &PyClass::dolphin) + .def_readwrite("canary", &PyClass::canary) .def_readwrite("telespeech_ctc", &PyClass::telespeech_ctc) .def_readwrite("tokens", &PyClass::tokens) .def_readwrite("num_threads", &PyClass::num_threads) diff --git a/sherpa-onnx/python/csrc/offline-recognizer.cc b/sherpa-onnx/python/csrc/offline-recognizer.cc index dbc68862..389b0fe1 100644 --- a/sherpa-onnx/python/csrc/offline-recognizer.cc +++ b/sherpa-onnx/python/csrc/offline-recognizer.cc @@ -19,7 +19,8 @@ static void PybindOfflineRecognizerConfig(py::module *m) { const std::string &, int32_t, const std::string &, float, float, const std::string &, const std::string &, const HomophoneReplacerConfig &>(), - py::arg("feat_config"), py::arg("model_config"), + py::arg("feat_config") = FeatureExtractorConfig(), + 
py::arg("model_config") = OfflineModelConfig(), py::arg("lm_config") = OfflineLMConfig(), py::arg("ctc_fst_decoder_config") = OfflineCtcFstDecoderConfig(), py::arg("decoding_method") = "greedy_search", @@ -61,6 +62,8 @@ void PybindOfflineRecognizer(py::module *m) { py::arg("hotwords"), py::call_guard()) .def("decode_stream", &PyClass::DecodeStream, py::arg("s"), py::call_guard()) + .def("set_config", &PyClass::SetConfig, py::arg("config"), + py::call_guard()) .def( "decode_streams", [](const PyClass &self, std::vector ss) { diff --git a/sherpa-onnx/python/sherpa_onnx/__init__.py b/sherpa-onnx/python/sherpa_onnx/__init__.py index 705c2ad0..1215da74 100644 --- a/sherpa-onnx/python/sherpa_onnx/__init__.py +++ b/sherpa-onnx/python/sherpa_onnx/__init__.py @@ -8,9 +8,22 @@ from _sherpa_onnx import ( DenoisedAudio, FastClustering, FastClusteringConfig, + FeatureExtractorConfig, + HomophoneReplacerConfig, + OfflineCanaryModelConfig, + OfflineCtcFstDecoderConfig, + OfflineDolphinModelConfig, + OfflineFireRedAsrModelConfig, + OfflineLMConfig, + OfflineModelConfig, + OfflineMoonshineModelConfig, + OfflineNemoEncDecCtcModelConfig, + OfflineParaformerModelConfig, OfflinePunctuation, OfflinePunctuationConfig, OfflinePunctuationModelConfig, + OfflineRecognizerConfig, + OfflineSenseVoiceModelConfig, OfflineSourceSeparation, OfflineSourceSeparationConfig, OfflineSourceSeparationModelConfig, @@ -27,13 +40,18 @@ from _sherpa_onnx import ( OfflineSpeechDenoiserGtcrnModelConfig, OfflineSpeechDenoiserModelConfig, OfflineStream, + OfflineTdnnModelConfig, + OfflineTransducerModelConfig, OfflineTts, OfflineTtsConfig, OfflineTtsKokoroModelConfig, OfflineTtsMatchaModelConfig, OfflineTtsModelConfig, OfflineTtsVitsModelConfig, + OfflineWenetCtcModelConfig, + OfflineWhisperModelConfig, OfflineZipformerAudioTaggingModelConfig, + OfflineZipformerCtcModelConfig, OnlinePunctuation, OnlinePunctuationConfig, OnlinePunctuationModelConfig, diff --git 
a/sherpa-onnx/python/sherpa_onnx/offline_recognizer.py b/sherpa-onnx/python/sherpa_onnx/offline_recognizer.py
index 5a0475ec..b8586d26 100644
--- a/sherpa-onnx/python/sherpa_onnx/offline_recognizer.py
+++ b/sherpa-onnx/python/sherpa_onnx/offline_recognizer.py
@@ -6,6 +6,7 @@ from typing import List, Optional
 from _sherpa_onnx import (
     FeatureExtractorConfig,
     HomophoneReplacerConfig,
+    OfflineCanaryModelConfig,
     OfflineCtcFstDecoderConfig,
     OfflineDolphinModelConfig,
     OfflineFireRedAsrModelConfig,
@@ -425,7 +426,6 @@ class OfflineRecognizer(object):
             num_threads=num_threads,
             debug=debug,
             provider=provider,
-            model_type="nemo_ctc",
         )
 
         feat_config = FeatureExtractorConfig(
@@ -690,6 +690,113 @@ class OfflineRecognizer(object):
         self.config = recognizer_config
         return self
 
+    @classmethod
+    def from_nemo_canary(
+        cls,
+        encoder: str,
+        decoder: str,
+        tokens: str,
+        src_lang: str = "en",
+        tgt_lang: str = "en",
+        num_threads: int = 1,
+        sample_rate: int = 16000,
+        feature_dim: int = 128,  # not used
+        decoding_method: str = "greedy_search",  # not used
+        debug: bool = False,
+        provider: str = "cpu",
+        rule_fsts: str = "",
+        rule_fars: str = "",
+        hr_dict_dir: str = "",
+        hr_rule_fsts: str = "",
+        hr_lexicon: str = "",
+        use_pnc: bool = True,
+    ):
+        """
+        Please refer to
+        `<https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models>`_
+        to download pre-trained models for different languages.
+
+        Args:
+          encoder:
+            Path to ``encoder.onnx`` or ``encoder.int8.onnx``.
+          decoder:
+            Path to ``decoder.onnx`` or ``decoder.int8.onnx``.
+          tokens:
+            Path to ``tokens.txt``. Each line in ``tokens.txt`` contains two
+            columns::
+
+                symbol integer_id
+
+          src_lang:
+            The language of the input audio. Valid values are: en, es, de, fr.
+            If you leave it empty, it uses en internally.
+          tgt_lang:
+            The language of the output text. Valid values are: en, es, de, fr.
+            If you leave it empty, it uses en internally.
+          num_threads:
+            Number of threads for neural network computation.
+          sample_rate:
+            Sample rate of the training data used to train the model. Not used
+          feature_dim:
+            Dimension of the feature used to train the model. Not used
+          decoding_method:
+            Valid values are greedy_search. Not used
+          debug:
+            True to show debug messages.
+          provider:
+            onnxruntime execution providers. Valid values are: cpu, cuda, coreml.
+          rule_fsts:
+            If not empty, it specifies fsts for inverse text normalization.
+            If there are multiple fsts, they are separated by a comma.
+          rule_fars:
+            If not empty, it specifies fst archives for inverse text normalization.
+            If there are multiple archives, they are separated by a comma.
+          hr_dict_dir:
+            Passed to ``HomophoneReplacerConfig`` as ``dict_dir``.
+          hr_rule_fsts:
+            Passed to ``HomophoneReplacerConfig`` as ``rule_fsts``.
+          hr_lexicon:
+            Passed to ``HomophoneReplacerConfig`` as ``lexicon``.
+          use_pnc:
+            True to output text with punctuation and capitalization (PnC).
+            False to output text without them.
+        """
+        self = cls.__new__(cls)
+        model_config = OfflineModelConfig(
+            canary=OfflineCanaryModelConfig(
+                encoder=encoder,
+                decoder=decoder,
+                src_lang=src_lang,
+                tgt_lang=tgt_lang,
+                use_pnc=use_pnc,
+            ),
+            tokens=tokens,
+            num_threads=num_threads,
+            debug=debug,
+            provider=provider,
+        )
+
+        feat_config = FeatureExtractorConfig(
+            sampling_rate=sample_rate,
+            feature_dim=feature_dim,
+        )
+
+        recognizer_config = OfflineRecognizerConfig(
+            feat_config=feat_config,
+            model_config=model_config,
+            decoding_method=decoding_method,
+            rule_fsts=rule_fsts,
+            rule_fars=rule_fars,
+            hr=HomophoneReplacerConfig(
+                dict_dir=hr_dict_dir,
+                lexicon=hr_lexicon,
+                rule_fsts=hr_rule_fsts,
+            ),
+        )
+        self.recognizer = _Recognizer(recognizer_config)
+        self.config = recognizer_config
+        return self
+
     @classmethod
     def from_whisper(
         cls,