diff --git a/sherpa-onnx/csharp-api/CPPLINT.cfg b/sherpa-onnx/csharp-api/CPPLINT.cfg
deleted file mode 100644
index 51ff339c..00000000
--- a/sherpa-onnx/csharp-api/CPPLINT.cfg
+++ /dev/null
@@ -1 +0,0 @@
-exclude_files=.*
diff --git a/sherpa-onnx/csrc/CMakeLists.txt b/sherpa-onnx/csrc/CMakeLists.txt
index bef262f9..a8efdf59 100644
--- a/sherpa-onnx/csrc/CMakeLists.txt
+++ b/sherpa-onnx/csrc/CMakeLists.txt
@@ -35,8 +35,8 @@ set(sources
   offline-transducer-model.cc
   offline-transducer-modified-beam-search-decoder.cc
   online-conformer-transducer-model.cc
-  online-lm.cc
   online-lm-config.cc
+  online-lm.cc
   online-lstm-transducer-model.cc
   online-recognizer.cc
   online-rnn-lm.cc
@@ -48,9 +48,11 @@ set(sources
   online-transducer-modified-beam-search-decoder.cc
   online-zipformer-transducer-model.cc
   onnx-utils.cc
+  session.cc
   packed-sequence.cc
   pad-sequence.cc
   parse-options.cc
+  provider.cc
   resample.cc
   slice.cc
   stack.cc
diff --git a/sherpa-onnx/csrc/offline-model-config.cc b/sherpa-onnx/csrc/offline-model-config.cc
index c4912abb..d9736649 100644
--- a/sherpa-onnx/csrc/offline-model-config.cc
+++ b/sherpa-onnx/csrc/offline-model-config.cc
@@ -22,6 +22,9 @@ void OfflineModelConfig::Register(ParseOptions *po) {
   po->Register("debug", &debug,
                "true to print model information while loading it.");
+
+  po->Register("provider", &provider,
+               "Specify a provider to use: cpu, cuda, coreml");
 }
 
 bool OfflineModelConfig::Validate() const {
@@ -55,7 +58,8 @@ std::string OfflineModelConfig::ToString() const {
   os << "nemo_ctc=" << nemo_ctc.ToString() << ", ";
   os << "tokens=\"" << tokens << "\", ";
   os << "num_threads=" << num_threads << ", ";
-  os << "debug=" << (debug ? "True" : "False") << ")";
+  os << "debug=" << (debug ? "True" : "False") << ", ";
+  os << "provider=\"" << provider << "\")";
 
   return os.str();
 }
diff --git a/sherpa-onnx/csrc/offline-model-config.h b/sherpa-onnx/csrc/offline-model-config.h
index da17c7b5..b440c9e3 100644
--- a/sherpa-onnx/csrc/offline-model-config.h
+++ b/sherpa-onnx/csrc/offline-model-config.h
@@ -20,18 +20,21 @@ struct OfflineModelConfig {
   std::string tokens;
   int32_t num_threads = 2;
   bool debug = false;
+  std::string provider = "cpu";
 
   OfflineModelConfig() = default;
   OfflineModelConfig(const OfflineTransducerModelConfig &transducer,
                      const OfflineParaformerModelConfig &paraformer,
                      const OfflineNemoEncDecCtcModelConfig &nemo_ctc,
-                     const std::string &tokens, int32_t num_threads, bool debug)
+                     const std::string &tokens, int32_t num_threads, bool debug,
+                     const std::string &provider)
       : transducer(transducer),
         paraformer(paraformer),
         nemo_ctc(nemo_ctc),
         tokens(tokens),
         num_threads(num_threads),
-        debug(debug) {}
+        debug(debug),
+        provider(provider) {}
 
   void Register(ParseOptions *po);
   bool Validate() const;
diff --git a/sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model.cc b/sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model.cc
index c981a453..ee629b4b 100644
--- a/sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model.cc
+++ b/sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model.cc
@@ -6,6 +6,7 @@
 
 #include "sherpa-onnx/csrc/macros.h"
 #include "sherpa-onnx/csrc/onnx-utils.h"
+#include "sherpa-onnx/csrc/session.h"
 #include "sherpa-onnx/csrc/text-utils.h"
 #include "sherpa-onnx/csrc/transpose.h"
 
@@ -16,11 +17,8 @@ class OfflineNemoEncDecCtcModel::Impl {
   explicit Impl(const OfflineModelConfig &config)
       : config_(config),
         env_(ORT_LOGGING_LEVEL_ERROR),
-        sess_opts_{},
+        sess_opts_(GetSessionOptions(config)),
        allocator_{} {
-    sess_opts_.SetIntraOpNumThreads(config_.num_threads);
-    sess_opts_.SetInterOpNumThreads(config_.num_threads);
-
     Init();
   }
diff --git a/sherpa-onnx/csrc/offline-paraformer-model.cc b/sherpa-onnx/csrc/offline-paraformer-model.cc
index 3accce35..9e76e7ba 100644
--- a/sherpa-onnx/csrc/offline-paraformer-model.cc
+++ b/sherpa-onnx/csrc/offline-paraformer-model.cc
@@ -9,6 +9,7 @@
 
 #include "sherpa-onnx/csrc/macros.h"
 #include "sherpa-onnx/csrc/onnx-utils.h"
+#include "sherpa-onnx/csrc/session.h"
 #include "sherpa-onnx/csrc/text-utils.h"
 
 namespace sherpa_onnx {
@@ -18,11 +19,8 @@ class OfflineParaformerModel::Impl {
   explicit Impl(const OfflineModelConfig &config)
      : config_(config),
        env_(ORT_LOGGING_LEVEL_ERROR),
-        sess_opts_{},
+        sess_opts_(GetSessionOptions(config)),
        allocator_{} {
-    sess_opts_.SetIntraOpNumThreads(config_.num_threads);
-    sess_opts_.SetInterOpNumThreads(config_.num_threads);
-
     Init();
   }
diff --git a/sherpa-onnx/csrc/offline-transducer-model.cc b/sherpa-onnx/csrc/offline-transducer-model.cc
index f5d8e773..254f5e91 100644
--- a/sherpa-onnx/csrc/offline-transducer-model.cc
+++ b/sherpa-onnx/csrc/offline-transducer-model.cc
@@ -11,6 +11,7 @@
 #include "sherpa-onnx/csrc/macros.h"
 #include "sherpa-onnx/csrc/offline-transducer-decoder.h"
 #include "sherpa-onnx/csrc/onnx-utils.h"
+#include "sherpa-onnx/csrc/session.h"
 
 namespace sherpa_onnx {
 
@@ -19,10 +20,8 @@ class OfflineTransducerModel::Impl {
   explicit Impl(const OfflineModelConfig &config)
      : config_(config),
        env_(ORT_LOGGING_LEVEL_WARNING),
-        sess_opts_{},
+        sess_opts_(GetSessionOptions(config)),
        allocator_{} {
-    sess_opts_.SetIntraOpNumThreads(config.num_threads);
-    sess_opts_.SetInterOpNumThreads(config.num_threads);
     {
       auto buf = ReadFile(config.transducer.encoder_filename);
       InitEncoder(buf.data(), buf.size());
diff --git a/sherpa-onnx/csrc/online-conformer-transducer-model.cc b/sherpa-onnx/csrc/online-conformer-transducer-model.cc
index 8584f0ec..0d0ade3a 100644
--- a/sherpa-onnx/csrc/online-conformer-transducer-model.cc
+++ b/sherpa-onnx/csrc/online-conformer-transducer-model.cc
@@ -9,7 +9,6 @@
 #include
 #include
 #include
-#include
 #include
 #include
 #include
@@ -24,6 +23,7 @@
 #include "sherpa-onnx/csrc/macros.h"
 #include "sherpa-onnx/csrc/online-transducer-decoder.h"
 #include "sherpa-onnx/csrc/onnx-utils.h"
+#include "sherpa-onnx/csrc/session.h"
 #include "sherpa-onnx/csrc/text-utils.h"
 #include "sherpa-onnx/csrc/unbind.h"
 
@@ -33,11 +33,8 @@ OnlineConformerTransducerModel::OnlineConformerTransducerModel(
     const OnlineTransducerModelConfig &config)
     : env_(ORT_LOGGING_LEVEL_WARNING),
       config_(config),
-      sess_opts_{},
+      sess_opts_(GetSessionOptions(config)),
      allocator_{} {
-  sess_opts_.SetIntraOpNumThreads(config.num_threads);
-  sess_opts_.SetInterOpNumThreads(config.num_threads);
-
   {
     auto buf = ReadFile(config.encoder_filename);
     InitEncoder(buf.data(), buf.size());
@@ -59,11 +56,8 @@ OnlineConformerTransducerModel::OnlineConformerTransducerModel(
     AAssetManager *mgr, const OnlineTransducerModelConfig &config)
     : env_(ORT_LOGGING_LEVEL_WARNING),
       config_(config),
-      sess_opts_{},
+      sess_opts_(GetSessionOptions(config)),
      allocator_{} {
-  sess_opts_.SetIntraOpNumThreads(config.num_threads);
-  sess_opts_.SetInterOpNumThreads(config.num_threads);
-
   {
     auto buf = ReadFile(mgr, config.encoder_filename);
     InitEncoder(buf.data(), buf.size());
@@ -185,7 +179,7 @@ std::vector<std::vector<Ort::Value>>
 OnlineConformerTransducerModel::UnStackStates(
     const std::vector<Ort::Value> &states) const {
   const int32_t batch_size =
-    states[0].GetTensorTypeAndShapeInfo().GetShape()[2];
+      states[0].GetTensorTypeAndShapeInfo().GetShape()[2];
 
   assert(states.size() == 2);
   std::vector<std::vector<Ort::Value>> ans(batch_size);
@@ -209,8 +203,8 @@ std::vector<Ort::Value> OnlineConformerTransducerModel::GetEncoderInitStates() {
   // https://github.com/k2-fsa/icefall/blob/86b0db6eb9c84d9bc90a71d92774fe2a7f73e6ab/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py#L203
   // for details
   constexpr int32_t kBatchSize = 1;
-  std::array<int64_t, 4> h_shape{
-      num_encoder_layers_, left_context_, kBatchSize, encoder_dim_};
+  std::array<int64_t, 4> h_shape{num_encoder_layers_, left_context_, kBatchSize,
+                                 encoder_dim_};
   Ort::Value h = Ort::Value::CreateTensor(allocator_, h_shape.data(),
                                           h_shape.size());
@@ -238,9 +232,7 @@ OnlineConformerTransducerModel::RunEncoder(Ort::Value features,
                                            std::vector<Ort::Value> states,
                                            Ort::Value processed_frames) {
   std::array<Ort::Value, 4> encoder_inputs = {
-      std::move(features),
-      std::move(states[0]),
-      std::move(states[1]),
+      std::move(features), std::move(states[0]), std::move(states[1]),
       std::move(processed_frames)};
 
   auto encoder_out = encoder_sess_->Run(
diff --git a/sherpa-onnx/csrc/online-lstm-transducer-model.cc b/sherpa-onnx/csrc/online-lstm-transducer-model.cc
index ee40b10c..3419cfc0 100644
--- a/sherpa-onnx/csrc/online-lstm-transducer-model.cc
+++ b/sherpa-onnx/csrc/online-lstm-transducer-model.cc
@@ -22,6 +22,7 @@
 #include "sherpa-onnx/csrc/macros.h"
 #include "sherpa-onnx/csrc/online-transducer-decoder.h"
 #include "sherpa-onnx/csrc/onnx-utils.h"
+#include "sherpa-onnx/csrc/session.h"
 #include "sherpa-onnx/csrc/unbind.h"
 
 namespace sherpa_onnx {
@@ -30,11 +31,8 @@ OnlineLstmTransducerModel::OnlineLstmTransducerModel(
     const OnlineTransducerModelConfig &config)
     : env_(ORT_LOGGING_LEVEL_WARNING),
       config_(config),
-      sess_opts_{},
+      sess_opts_(GetSessionOptions(config)),
      allocator_{} {
-  sess_opts_.SetIntraOpNumThreads(config.num_threads);
-  sess_opts_.SetInterOpNumThreads(config.num_threads);
-
   {
     auto buf = ReadFile(config.encoder_filename);
     InitEncoder(buf.data(), buf.size());
@@ -56,11 +54,8 @@ OnlineLstmTransducerModel::OnlineLstmTransducerModel(
     AAssetManager *mgr, const OnlineTransducerModelConfig &config)
     : env_(ORT_LOGGING_LEVEL_WARNING),
       config_(config),
-      sess_opts_{},
+      sess_opts_(GetSessionOptions(config)),
      allocator_{} {
-  sess_opts_.SetIntraOpNumThreads(config.num_threads);
-  sess_opts_.SetInterOpNumThreads(config.num_threads);
-
   {
     auto buf = ReadFile(mgr, config.encoder_filename);
     InitEncoder(buf.data(), buf.size());
diff --git a/sherpa-onnx/csrc/online-recognizer.cc b/sherpa-onnx/csrc/online-recognizer.cc
index b1126cd3..06aab880 100644
--- a/sherpa-onnx/csrc/online-recognizer.cc
+++ b/sherpa-onnx/csrc/online-recognizer.cc
@@ -9,7 +9,6 @@
 #include
 #include
-#include
 #include
 #include
 #include
@@ -140,8 +139,8 @@ class OnlineRecognizer::Impl {
       decoder_ = std::make_unique(model_.get());
     } else {
-      fprintf(stderr, "Unsupported decoding method: %s\n",
-              config.decoding_method.c_str());
+      SHERPA_ONNX_LOGE("Unsupported decoding method: %s",
+                       config.decoding_method.c_str());
       exit(-1);
     }
   }
@@ -160,8 +159,8 @@ class OnlineRecognizer::Impl {
       decoder_ = std::make_unique(model_.get());
     } else {
-      fprintf(stderr, "Unsupported decoding method: %s\n",
-              config.decoding_method.c_str());
+      SHERPA_ONNX_LOGE("Unsupported decoding method: %s",
+                       config.decoding_method.c_str());
       exit(-1);
     }
   }
@@ -216,19 +215,16 @@ class OnlineRecognizer::Impl {
                                           x_shape.size());
 
     std::array<int64_t, 1> processed_frames_shape{
-      static_cast<int64_t>(all_processed_frames.size())};
+        static_cast<int64_t>(all_processed_frames.size())};
 
     Ort::Value processed_frames = Ort::Value::CreateTensor(
-        memory_info,
-        all_processed_frames.data(),
-        all_processed_frames.size(),
-        processed_frames_shape.data(),
-        processed_frames_shape.size());
+        memory_info, all_processed_frames.data(), all_processed_frames.size(),
+        processed_frames_shape.data(), processed_frames_shape.size());
 
     auto states = model_->StackStates(states_vec);
 
-    auto pair = model_->RunEncoder(
-        std::move(x), std::move(states), std::move(processed_frames));
+    auto pair = model_->RunEncoder(std::move(x), std::move(states),
+                                   std::move(processed_frames));
 
     decoder_->Decode(std::move(pair.first), &results);
diff --git a/sherpa-onnx/csrc/online-transducer-model-config.h b/sherpa-onnx/csrc/online-transducer-model-config.h
index 62c5d3d8..c9fc1b73 100644
--- a/sherpa-onnx/csrc/online-transducer-model-config.h
+++ b/sherpa-onnx/csrc/online-transducer-model-config.h
@@ -17,19 +17,21 @@ struct OnlineTransducerModelConfig {
   std::string tokens;
   int32_t num_threads = 2;
   bool debug = false;
+  std::string provider = "cpu";
 
   OnlineTransducerModelConfig() = default;
   OnlineTransducerModelConfig(const std::string &encoder_filename,
                               const std::string &decoder_filename,
                               const std::string &joiner_filename,
                               const std::string &tokens, int32_t num_threads,
-                              bool debug)
+                              bool debug, const std::string &provider)
       : encoder_filename(encoder_filename),
         decoder_filename(decoder_filename),
         joiner_filename(joiner_filename),
         tokens(tokens),
         num_threads(num_threads),
-        debug(debug) {}
+        debug(debug),
+        provider(provider) {}
 
   void Register(ParseOptions *po);
   bool Validate() const;
diff --git a/sherpa-onnx/csrc/online-transducer-model.cc b/sherpa-onnx/csrc/online-transducer-model.cc
index 75fb3c56..5d60a021 100644
--- a/sherpa-onnx/csrc/online-transducer-model.cc
+++ b/sherpa-onnx/csrc/online-transducer-model.cc
@@ -10,7 +10,6 @@
 #endif
 
 #include
-#include
 #include
 #include
 #include
diff --git a/sherpa-onnx/csrc/online-zipformer-transducer-model.cc b/sherpa-onnx/csrc/online-zipformer-transducer-model.cc
index 7af95cc9..238a84d3 100644
--- a/sherpa-onnx/csrc/online-zipformer-transducer-model.cc
+++ b/sherpa-onnx/csrc/online-zipformer-transducer-model.cc
@@ -23,6 +23,7 @@
 #include "sherpa-onnx/csrc/macros.h"
 #include "sherpa-onnx/csrc/online-transducer-decoder.h"
 #include "sherpa-onnx/csrc/onnx-utils.h"
+#include "sherpa-onnx/csrc/session.h"
 #include "sherpa-onnx/csrc/text-utils.h"
 #include "sherpa-onnx/csrc/unbind.h"
 
@@ -32,11 +33,8 @@ OnlineZipformerTransducerModel::OnlineZipformerTransducerModel(
     const OnlineTransducerModelConfig &config)
     : env_(ORT_LOGGING_LEVEL_WARNING),
       config_(config),
-      sess_opts_{},
+      sess_opts_(GetSessionOptions(config)),
      allocator_{} {
-  sess_opts_.SetIntraOpNumThreads(config.num_threads);
-  sess_opts_.SetInterOpNumThreads(config.num_threads);
-
   {
     auto buf = ReadFile(config.encoder_filename);
     InitEncoder(buf.data(), buf.size());
@@ -58,11 +56,8 @@ OnlineZipformerTransducerModel::OnlineZipformerTransducerModel(
     AAssetManager *mgr, const OnlineTransducerModelConfig &config)
     : env_(ORT_LOGGING_LEVEL_WARNING),
       config_(config),
-      sess_opts_{},
+      sess_opts_(GetSessionOptions(config)),
      allocator_{} {
-  sess_opts_.SetIntraOpNumThreads(config.num_threads);
-  sess_opts_.SetInterOpNumThreads(config.num_threads);
-
   {
     auto buf = ReadFile(mgr, config.encoder_filename);
     InitEncoder(buf.data(), buf.size());
diff --git a/sherpa-onnx/csrc/provider.cc b/sherpa-onnx/csrc/provider.cc
new file mode 100644
index 00000000..9c50eb8c
--- /dev/null
+++ b/sherpa-onnx/csrc/provider.cc
@@ -0,0 +1,29 @@
+// sherpa-onnx/csrc/provider.cc
+//
+// Copyright (c)  2023  Xiaomi Corporation
+
"sherpa-onnx/csrc/provider.h" + +#include +#include + +#include "sherpa-onnx/csrc/macros.h" + +namespace sherpa_onnx { + +Provider StringToProvider(std::string s) { + std::transform(s.cbegin(), s.cend(), s.begin(), + [](unsigned char c) { return std::tolower(c); }); + if (s == "cpu") { + return Provider::kCPU; + } else if (s == "cuda") { + return Provider::kCUDA; + } else if (s == "coreml") { + return Provider::kCoreML; + } else { + SHERPA_ONNX_LOGE("Unsupported string: %s. Fallback to cpu", s.c_str()); + return Provider::kCPU; + } +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/provider.h b/sherpa-onnx/csrc/provider.h new file mode 100644 index 00000000..8e0dcc0a --- /dev/null +++ b/sherpa-onnx/csrc/provider.h @@ -0,0 +1,31 @@ +// sherpa-onnx/csrc/provider.h +// +// Copyright (c) 2023 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_PROVIDER_H_ +#define SHERPA_ONNX_CSRC_PROVIDER_H_ + +#include + +namespace sherpa_onnx { + +// Please refer to +// https://github.com/microsoft/onnxruntime/blob/main/java/src/main/java/ai/onnxruntime/OrtProvider.java +// for a list of available providers +enum class Provider { + kCPU = 0, // CPUExecutionProvider + kCUDA = 1, // CUDAExecutionProvider + kCoreML = 2, // CoreMLExecutionProvider +}; + +/** + * Convert a string to an enum. + * + * @param s We will convert it to lowercase before comparing. + * @return Return an instance of Provider. + */ +Provider StringToProvider(std::string s); + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_PROVIDER_H_ diff --git a/sherpa-onnx/csrc/session.cc b/sherpa-onnx/csrc/session.cc new file mode 100644 index 00000000..9920ec17 --- /dev/null +++ b/sherpa-onnx/csrc/session.cc @@ -0,0 +1,60 @@ +// sherpa-onnx/csrc/session.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-onnx/csrc/session.h" + +#include +#include + +#include "sherpa-onnx/csrc/macros.h" +#include "sherpa-onnx/csrc/provider.h" +#if defined(__APPLE__) +#include "coreml_provider_factory.h" // NOLINT +#endif + +namespace sherpa_onnx { + +static Ort::SessionOptions GetSessionOptionsImpl(int32_t num_threads, + std::string provider_str) { + Provider p = StringToProvider(std::move(provider_str)); + + Ort::SessionOptions sess_opts; + sess_opts.SetIntraOpNumThreads(num_threads); + sess_opts.SetInterOpNumThreads(num_threads); + + switch (p) { + case Provider::kCPU: + break; // nothing to do for the CPU provider + case Provider::kCUDA: { + OrtCUDAProviderOptions options; + options.device_id = 0; + // set more options on need + sess_opts.AppendExecutionProvider_CUDA(options); + break; + } + case Provider::kCoreML: { +#if defined(__APPLE__) + uint32_t coreml_flags = 0; + (void)OrtSessionOptionsAppendExecutionProvider_CoreML(sess_opts, + coreml_flags); +#else + SHERPA_ONNX_LOGE("CoreML is for Apple only. 
+      SHERPA_ONNX_LOGE("CoreML is for Apple only. Fallback to cpu!");
+#endif
+      break;
+    }
+  }
+
+  return sess_opts;
+}
+
+Ort::SessionOptions GetSessionOptions(
+    const OnlineTransducerModelConfig &config) {
+  return GetSessionOptionsImpl(config.num_threads, config.provider);
+}
+
+Ort::SessionOptions GetSessionOptions(const OfflineModelConfig &config) {
+  return GetSessionOptionsImpl(config.num_threads, config.provider);
+}
+
+}  // namespace sherpa_onnx
diff --git a/sherpa-onnx/csrc/session.h b/sherpa-onnx/csrc/session.h
new file mode 100644
index 00000000..8e0508ed
--- /dev/null
+++ b/sherpa-onnx/csrc/session.h
@@ -0,0 +1,21 @@
+// sherpa-onnx/csrc/session.h
+//
+// Copyright (c)  2023  Xiaomi Corporation
+
+#ifndef SHERPA_ONNX_CSRC_SESSION_H_
+#define SHERPA_ONNX_CSRC_SESSION_H_
+
+#include "onnxruntime_cxx_api.h"  // NOLINT
+#include "sherpa-onnx/csrc/offline-model-config.h"
+#include "sherpa-onnx/csrc/online-transducer-model-config.h"
+
+namespace sherpa_onnx {
+
+Ort::SessionOptions GetSessionOptions(
+    const OnlineTransducerModelConfig &config);
+
+Ort::SessionOptions GetSessionOptions(const OfflineModelConfig &config);
+
+}  // namespace sherpa_onnx
+
+#endif  // SHERPA_ONNX_CSRC_SESSION_H_
diff --git a/sherpa-onnx/csrc/stack.cc b/sherpa-onnx/csrc/stack.cc
index c7ae6bee..302ec733 100644
--- a/sherpa-onnx/csrc/stack.cc
+++ b/sherpa-onnx/csrc/stack.cc
@@ -6,7 +6,6 @@
 #include
 #include
-#include
 #include
 #include
@@ -36,7 +35,7 @@ template
 Ort::Value Stack(OrtAllocator *allocator,
                  const std::vector<const Ort::Value *> &values, int32_t dim) {
   std::vector<int64_t> v0_shape =
-    values[0]->GetTensorTypeAndShapeInfo().GetShape();
+      values[0]->GetTensorTypeAndShapeInfo().GetShape();
 
   for (int32_t i = 1; i != static_cast<int32_t>(values.size()); ++i) {
     auto s = values[i]->GetTensorTypeAndShapeInfo().GetShape();
@@ -58,21 +57,17 @@ Ort::Value Stack(OrtAllocator *allocator,
   ans_shape.reserve(v0_shape.size() + 1);
   ans_shape.insert(ans_shape.end(), v0_shape.data(), v0_shape.data() + dim);
   ans_shape.push_back(values.size());
-  ans_shape.insert(
-      ans_shape.end(),
-      v0_shape.data() + dim,
-      v0_shape.data() + v0_shape.size());
+  ans_shape.insert(ans_shape.end(), v0_shape.data() + dim,
+                   v0_shape.data() + v0_shape.size());
 
   auto leading_size = static_cast<int64_t>(std::accumulate(
-    v0_shape.begin(), v0_shape.begin() + dim, 1, std::multiplies<int64_t>()));
+      v0_shape.begin(), v0_shape.begin() + dim, 1, std::multiplies<int64_t>()));
 
-  auto trailing_size = static_cast<int64_t>(
-      std::accumulate(v0_shape.begin() + dim,
-                      v0_shape.end(), 1,
-                      std::multiplies<int64_t>()));
+  auto trailing_size = static_cast<int64_t>(std::accumulate(
+      v0_shape.begin() + dim, v0_shape.end(), 1, std::multiplies<int64_t>()));
 
-  Ort::Value ans = Ort::Value::CreateTensor<T>(
-      allocator, ans_shape.data(), ans_shape.size());
+  Ort::Value ans = Ort::Value::CreateTensor<T>(allocator, ans_shape.data(),
+                                               ans_shape.size());
 
   T *dst = ans.GetTensorMutableData<T>();
 
   for (int32_t i = 0; i != leading_size; ++i) {
@@ -88,14 +83,12 @@ Ort::Value Stack(OrtAllocator *allocator,
   return ans;
 }
 
-template Ort::Value Stack<float>(
-    OrtAllocator *allocator,
-    const std::vector<const Ort::Value *> &values,
-    int32_t dim);
+template Ort::Value Stack<float>(OrtAllocator *allocator,
+                                 const std::vector<const Ort::Value *> &values,
+                                 int32_t dim);
 
 template Ort::Value Stack<int64_t>(
-    OrtAllocator *allocator,
-    const std::vector<const Ort::Value *> &values,
-    int32_t dim);
+    OrtAllocator *allocator, const std::vector<const Ort::Value *> &values,
+    int32_t dim);
 
 }  // namespace sherpa_onnx
diff --git a/sherpa-onnx/python/csrc/offline-model-config.cc b/sherpa-onnx/python/csrc/offline-model-config.cc
index 26b561ea..48f99954 100644
--- a/sherpa-onnx/python/csrc/offline-model-config.cc
+++
 b/sherpa-onnx/python/csrc/offline-model-config.cc
@@ -24,17 +24,19 @@ void PybindOfflineModelConfig(py::module *m) {
       .def(py::init<const OfflineTransducerModelConfig &,
                     const OfflineParaformerModelConfig &,
                     const OfflineNemoEncDecCtcModelConfig &,
-                    const std::string &, int32_t, bool>(),
+                    const std::string &, int32_t, bool, const std::string &>(),
           py::arg("transducer") = OfflineTransducerModelConfig(),
           py::arg("paraformer") = OfflineParaformerModelConfig(),
           py::arg("nemo_ctc") = OfflineNemoEncDecCtcModelConfig(),
-          py::arg("tokens"), py::arg("num_threads"), py::arg("debug") = false)
+          py::arg("tokens"), py::arg("num_threads"), py::arg("debug") = false,
+          py::arg("provider") = "cpu")
       .def_readwrite("transducer", &PyClass::transducer)
       .def_readwrite("paraformer", &PyClass::paraformer)
       .def_readwrite("nemo_ctc", &PyClass::nemo_ctc)
      .def_readwrite("tokens", &PyClass::tokens)
      .def_readwrite("num_threads", &PyClass::num_threads)
      .def_readwrite("debug", &PyClass::debug)
+      .def_readwrite("provider", &PyClass::provider)
      .def("__str__", &PyClass::ToString);
 }
diff --git a/sherpa-onnx/python/csrc/online-transducer-model-config.cc b/sherpa-onnx/python/csrc/online-transducer-model-config.cc
index a92d0627..62c89e3e 100644
--- a/sherpa-onnx/python/csrc/online-transducer-model-config.cc
+++ b/sherpa-onnx/python/csrc/online-transducer-model-config.cc
@@ -14,16 +14,19 @@ void PybindOnlineTransducerModelConfig(py::module *m) {
   using PyClass = OnlineTransducerModelConfig;
   py::class_<PyClass>(*m, "OnlineTransducerModelConfig")
       .def(py::init<const std::string &, const std::string &,
-                    const std::string &, const std::string &, int32_t, bool>(),
+                    const std::string &, const std::string &, int32_t, bool,
+                    const std::string &>(),
           py::arg("encoder_filename"), py::arg("decoder_filename"),
           py::arg("joiner_filename"), py::arg("tokens"),
-          py::arg("num_threads"), py::arg("debug") = false)
+          py::arg("num_threads"), py::arg("debug") = false,
+          py::arg("provider") = "cpu")
      .def_readwrite("encoder_filename", &PyClass::encoder_filename)
      .def_readwrite("decoder_filename", &PyClass::decoder_filename)
      .def_readwrite("joiner_filename", &PyClass::joiner_filename)
      .def_readwrite("tokens", &PyClass::tokens)
      .def_readwrite("num_threads", &PyClass::num_threads)
      .def_readwrite("debug", &PyClass::debug)
+      .def_readwrite("provider", &PyClass::provider)
      .def("__str__", &PyClass::ToString);
 }
diff --git a/sherpa-onnx/python/sherpa_onnx/offline_recognizer.py b/sherpa-onnx/python/sherpa_onnx/offline_recognizer.py
index 1c25c7d1..5b3f2a3e 100644
--- a/sherpa-onnx/python/sherpa_onnx/offline_recognizer.py
+++ b/sherpa-onnx/python/sherpa_onnx/offline_recognizer.py
@@ -40,6 +40,7 @@ class OfflineRecognizer(object):
         feature_dim: int = 80,
         decoding_method: str = "greedy_search",
         debug: bool = False,
+        provider: str = "cpu",
     ):
         """
         Please refer to
@@ -70,6 +71,8 @@ class OfflineRecognizer(object):
             Support only greedy_search for now.
           debug:
             True to show debug messages.
+          provider:
+            onnxruntime execution providers. Valid values are: cpu, cuda, coreml.
""" self = cls.__new__(cls) model_config = OfflineModelConfig( @@ -81,6 +84,7 @@ class OfflineRecognizer(object): tokens=tokens, num_threads=num_threads, debug=debug, + provider=provider, ) feat_config = OfflineFeatureExtractorConfig( diff --git a/sherpa-onnx/python/sherpa_onnx/online_recognizer.py b/sherpa-onnx/python/sherpa_onnx/online_recognizer.py index 9ea09633..c01bd2fb 100644 --- a/sherpa-onnx/python/sherpa_onnx/online_recognizer.py +++ b/sherpa-onnx/python/sherpa_onnx/online_recognizer.py @@ -39,6 +39,7 @@ class OnlineRecognizer(object): rule3_min_utterance_length: float = 20.0, decoding_method: str = "greedy_search", max_active_paths: int = 4, + provider: str = "cpu", ): """ Please refer to @@ -86,6 +87,8 @@ class OnlineRecognizer(object): max_active_paths: Use only when decoding_method is modified_beam_search. It specifies the maximum number of active paths during beam search. + provider: + onnxruntime execution providers. Valid values are: cpu, cuda, coreml. """ _assert_file_exists(tokens) _assert_file_exists(encoder) @@ -100,6 +103,7 @@ class OnlineRecognizer(object): joiner_filename=joiner, tokens=tokens, num_threads=num_threads, + provider=provider, ) feat_config = FeatureExtractorConfig(