Support windows (#17)

* add onnxruntime for windows
This commit is contained in:
Fangjun Kuang
2022-10-13 17:30:30 +08:00
committed by GitHub
parent c70f5625f4
commit 4614d02d6d
7 changed files with 280 additions and 148 deletions

View File

@@ -21,8 +21,26 @@
#include <utility>
#include <vector>
#ifdef _MSC_VER
// For ToWide() below
#include <codecvt>
#include <locale>
#endif
namespace sherpa_onnx {
#ifdef _MSC_VER
// See
// https://stackoverflow.com/questions/2573834/c-convert-string-or-char-to-wstring-or-wchar-t
static std::wstring ToWide(const std::string &s) {
std::wstring_convert<std::codecvt_utf8_utf16<wchar_t>> converter;
return converter.from_bytes(s);
}
#define SHERPA_MAYBE_WIDE(s) ToWide(s)
#else
#define SHERPA_MAYBE_WIDE(s) s
#endif
/**
* Get the input names of a model.
*
@@ -85,8 +103,8 @@ RnntModel::RnntModel(const std::string &encoder_filename,
}
void RnntModel::InitEncoder(const std::string &filename) {
encoder_sess_ =
std::make_unique<Ort::Session>(env_, filename.c_str(), sess_opts_);
encoder_sess_ = std::make_unique<Ort::Session>(
env_, SHERPA_MAYBE_WIDE(filename).c_str(), sess_opts_);
GetInputNames(encoder_sess_.get(), &encoder_input_names_,
&encoder_input_names_ptr_);
@@ -95,8 +113,8 @@ void RnntModel::InitEncoder(const std::string &filename) {
}
void RnntModel::InitDecoder(const std::string &filename) {
decoder_sess_ =
std::make_unique<Ort::Session>(env_, filename.c_str(), sess_opts_);
decoder_sess_ = std::make_unique<Ort::Session>(
env_, SHERPA_MAYBE_WIDE(filename).c_str(), sess_opts_);
GetInputNames(decoder_sess_.get(), &decoder_input_names_,
&decoder_input_names_ptr_);
@@ -106,8 +124,8 @@ void RnntModel::InitDecoder(const std::string &filename) {
}
void RnntModel::InitJoiner(const std::string &filename) {
joiner_sess_ =
std::make_unique<Ort::Session>(env_, filename.c_str(), sess_opts_);
joiner_sess_ = std::make_unique<Ort::Session>(
env_, SHERPA_MAYBE_WIDE(filename).c_str(), sess_opts_);
GetInputNames(joiner_sess_.get(), &joiner_input_names_,
&joiner_input_names_ptr_);
@@ -117,8 +135,8 @@ void RnntModel::InitJoiner(const std::string &filename) {
}
void RnntModel::InitJoinerEncoderProj(const std::string &filename) {
joiner_encoder_proj_sess_ =
std::make_unique<Ort::Session>(env_, filename.c_str(), sess_opts_);
joiner_encoder_proj_sess_ = std::make_unique<Ort::Session>(
env_, SHERPA_MAYBE_WIDE(filename).c_str(), sess_opts_);
GetInputNames(joiner_encoder_proj_sess_.get(),
&joiner_encoder_proj_input_names_,
@@ -130,8 +148,8 @@ void RnntModel::InitJoinerEncoderProj(const std::string &filename) {
}
void RnntModel::InitJoinerDecoderProj(const std::string &filename) {
joiner_decoder_proj_sess_ =
std::make_unique<Ort::Session>(env_, filename.c_str(), sess_opts_);
joiner_decoder_proj_sess_ = std::make_unique<Ort::Session>(
env_, SHERPA_MAYBE_WIDE(filename).c_str(), sess_opts_);
GetInputNames(joiner_decoder_proj_sess_.get(),
&joiner_decoder_proj_input_names_,