61 lines
1.6 KiB
C++
61 lines
1.6 KiB
C++
// sherpa-onnx/csrc/session.cc
|
|
//
|
|
// Copyright (c) 2023 Xiaomi Corporation
|
|
|
|
#include "sherpa-onnx/csrc/session.h"
|
|
|
|
#include <string>
|
|
#include <utility>
|
|
|
|
#include "sherpa-onnx/csrc/macros.h"
|
|
#include "sherpa-onnx/csrc/provider.h"
|
|
#if defined(__APPLE__)
|
|
#include "coreml_provider_factory.h" // NOLINT
|
|
#endif
|
|
|
|
namespace sherpa_onnx {
|
|
|
|
static Ort::SessionOptions GetSessionOptionsImpl(int32_t num_threads,
|
|
std::string provider_str) {
|
|
Provider p = StringToProvider(std::move(provider_str));
|
|
|
|
Ort::SessionOptions sess_opts;
|
|
sess_opts.SetIntraOpNumThreads(num_threads);
|
|
sess_opts.SetInterOpNumThreads(num_threads);
|
|
|
|
switch (p) {
|
|
case Provider::kCPU:
|
|
break; // nothing to do for the CPU provider
|
|
case Provider::kCUDA: {
|
|
OrtCUDAProviderOptions options;
|
|
options.device_id = 0;
|
|
// set more options on need
|
|
sess_opts.AppendExecutionProvider_CUDA(options);
|
|
break;
|
|
}
|
|
case Provider::kCoreML: {
|
|
#if defined(__APPLE__)
|
|
uint32_t coreml_flags = 0;
|
|
(void)OrtSessionOptionsAppendExecutionProvider_CoreML(sess_opts,
|
|
coreml_flags);
|
|
#else
|
|
SHERPA_ONNX_LOGE("CoreML is for Apple only. Fallback to cpu!");
|
|
#endif
|
|
break;
|
|
}
|
|
}
|
|
|
|
return sess_opts;
|
|
}
|
|
|
|
Ort::SessionOptions GetSessionOptions(
|
|
const OnlineTransducerModelConfig &config) {
|
|
return GetSessionOptionsImpl(config.num_threads, config.provider);
|
|
}
|
|
|
|
Ort::SessionOptions GetSessionOptions(const OfflineModelConfig &config) {
|
|
return GetSessionOptionsImpl(config.num_threads, config.provider);
|
|
}
|
|
|
|
} // namespace sherpa_onnx
|