// sherpa-onnx/csrc/session.cc // // Copyright (c) 2023 Xiaomi Corporation #include "sherpa-onnx/csrc/session.h" #include #include #include "sherpa-onnx/csrc/macros.h" #include "sherpa-onnx/csrc/provider.h" #if defined(__APPLE__) #include "coreml_provider_factory.h" // NOLINT #endif namespace sherpa_onnx { static Ort::SessionOptions GetSessionOptionsImpl(int32_t num_threads, std::string provider_str) { Provider p = StringToProvider(std::move(provider_str)); Ort::SessionOptions sess_opts; sess_opts.SetIntraOpNumThreads(num_threads); sess_opts.SetInterOpNumThreads(num_threads); switch (p) { case Provider::kCPU: break; // nothing to do for the CPU provider case Provider::kCUDA: { OrtCUDAProviderOptions options; options.device_id = 0; // set more options on need sess_opts.AppendExecutionProvider_CUDA(options); break; } case Provider::kCoreML: { #if defined(__APPLE__) uint32_t coreml_flags = 0; (void)OrtSessionOptionsAppendExecutionProvider_CoreML(sess_opts, coreml_flags); #else SHERPA_ONNX_LOGE("CoreML is for Apple only. Fallback to cpu!"); #endif break; } } return sess_opts; } Ort::SessionOptions GetSessionOptions( const OnlineTransducerModelConfig &config) { return GetSessionOptionsImpl(config.num_threads, config.provider); } Ort::SessionOptions GetSessionOptions(const OfflineModelConfig &config) { return GetSessionOptionsImpl(config.num_threads, config.provider); } } // namespace sherpa_onnx