@@ -21,8 +21,26 @@
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#ifdef _MSC_VER
|
||||
// For ToWide() below
|
||||
#include <codecvt>
|
||||
#include <locale>
|
||||
#endif
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
#ifdef _MSC_VER
|
||||
// See
|
||||
// https://stackoverflow.com/questions/2573834/c-convert-string-or-char-to-wstring-or-wchar-t
|
||||
static std::wstring ToWide(const std::string &s) {
|
||||
std::wstring_convert<std::codecvt_utf8_utf16<wchar_t>> converter;
|
||||
return converter.from_bytes(s);
|
||||
}
|
||||
#define SHERPA_MAYBE_WIDE(s) ToWide(s)
|
||||
#else
|
||||
#define SHERPA_MAYBE_WIDE(s) s
|
||||
#endif
|
||||
|
||||
/**
|
||||
* Get the input names of a model.
|
||||
*
|
||||
@@ -85,8 +103,8 @@ RnntModel::RnntModel(const std::string &encoder_filename,
|
||||
}
|
||||
|
||||
void RnntModel::InitEncoder(const std::string &filename) {
|
||||
encoder_sess_ =
|
||||
std::make_unique<Ort::Session>(env_, filename.c_str(), sess_opts_);
|
||||
encoder_sess_ = std::make_unique<Ort::Session>(
|
||||
env_, SHERPA_MAYBE_WIDE(filename).c_str(), sess_opts_);
|
||||
GetInputNames(encoder_sess_.get(), &encoder_input_names_,
|
||||
&encoder_input_names_ptr_);
|
||||
|
||||
@@ -95,8 +113,8 @@ void RnntModel::InitEncoder(const std::string &filename) {
|
||||
}
|
||||
|
||||
void RnntModel::InitDecoder(const std::string &filename) {
|
||||
decoder_sess_ =
|
||||
std::make_unique<Ort::Session>(env_, filename.c_str(), sess_opts_);
|
||||
decoder_sess_ = std::make_unique<Ort::Session>(
|
||||
env_, SHERPA_MAYBE_WIDE(filename).c_str(), sess_opts_);
|
||||
|
||||
GetInputNames(decoder_sess_.get(), &decoder_input_names_,
|
||||
&decoder_input_names_ptr_);
|
||||
@@ -106,8 +124,8 @@ void RnntModel::InitDecoder(const std::string &filename) {
|
||||
}
|
||||
|
||||
void RnntModel::InitJoiner(const std::string &filename) {
|
||||
joiner_sess_ =
|
||||
std::make_unique<Ort::Session>(env_, filename.c_str(), sess_opts_);
|
||||
joiner_sess_ = std::make_unique<Ort::Session>(
|
||||
env_, SHERPA_MAYBE_WIDE(filename).c_str(), sess_opts_);
|
||||
|
||||
GetInputNames(joiner_sess_.get(), &joiner_input_names_,
|
||||
&joiner_input_names_ptr_);
|
||||
@@ -117,8 +135,8 @@ void RnntModel::InitJoiner(const std::string &filename) {
|
||||
}
|
||||
|
||||
void RnntModel::InitJoinerEncoderProj(const std::string &filename) {
|
||||
joiner_encoder_proj_sess_ =
|
||||
std::make_unique<Ort::Session>(env_, filename.c_str(), sess_opts_);
|
||||
joiner_encoder_proj_sess_ = std::make_unique<Ort::Session>(
|
||||
env_, SHERPA_MAYBE_WIDE(filename).c_str(), sess_opts_);
|
||||
|
||||
GetInputNames(joiner_encoder_proj_sess_.get(),
|
||||
&joiner_encoder_proj_input_names_,
|
||||
@@ -130,8 +148,8 @@ void RnntModel::InitJoinerEncoderProj(const std::string &filename) {
|
||||
}
|
||||
|
||||
void RnntModel::InitJoinerDecoderProj(const std::string &filename) {
|
||||
joiner_decoder_proj_sess_ =
|
||||
std::make_unique<Ort::Session>(env_, filename.c_str(), sess_opts_);
|
||||
joiner_decoder_proj_sess_ = std::make_unique<Ort::Session>(
|
||||
env_, SHERPA_MAYBE_WIDE(filename).c_str(), sess_opts_);
|
||||
|
||||
GetInputNames(joiner_decoder_proj_sess_.get(),
|
||||
&joiner_decoder_proj_input_names_,
|
||||
|
||||
Reference in New Issue
Block a user