Support building GPU-capable sherpa-onnx on Linux aarch64. (#1500)

Thanks to @Peakyxh for providing pre-built onnxruntime libraries 
with CUDA support for Linux aarch64.

Tested on a Jetson Nano B01.
This commit is contained in:
Fangjun Kuang
2024-11-01 11:16:28 +08:00
committed by GitHub
parent a3c89aa0d8
commit 9ab89c33bc
41 changed files with 537 additions and 291 deletions

View File

@@ -7,6 +7,8 @@
#include <stdio.h>
#include <stdlib.h>
#include <utility>
#if __ANDROID_API__ >= 8
#include "android/log.h"
#define SHERPA_ONNX_LOGE(...) \
@@ -36,30 +38,28 @@
#endif
// Read an integer
#define SHERPA_ONNX_READ_META_DATA(dst, src_key) \
do { \
auto value = \
meta_data.LookupCustomMetadataMapAllocated(src_key, allocator); \
if (!value) { \
SHERPA_ONNX_LOGE("'%s' does not exist in the metadata", src_key); \
exit(-1); \
} \
\
dst = atoi(value.get()); \
if (dst < 0) { \
SHERPA_ONNX_LOGE("Invalid value %d for '%s'", dst, src_key); \
exit(-1); \
} \
#define SHERPA_ONNX_READ_META_DATA(dst, src_key) \
do { \
auto value = LookupCustomModelMetaData(meta_data, src_key, allocator); \
if (value.empty()) { \
SHERPA_ONNX_LOGE("'%s' does not exist in the metadata", src_key); \
exit(-1); \
} \
\
dst = atoi(value.c_str()); \
if (dst < 0) { \
SHERPA_ONNX_LOGE("Invalid value %d for '%s'", dst, src_key); \
exit(-1); \
} \
} while (0)
#define SHERPA_ONNX_READ_META_DATA_WITH_DEFAULT(dst, src_key, default_value) \
do { \
auto value = \
meta_data.LookupCustomMetadataMapAllocated(src_key, allocator); \
if (!value) { \
auto value = LookupCustomModelMetaData(meta_data, src_key, allocator); \
if (value.empty()) { \
dst = default_value; \
} else { \
dst = atoi(value.get()); \
dst = atoi(value.c_str()); \
if (dst < 0) { \
SHERPA_ONNX_LOGE("Invalid value %d for '%s'", dst, src_key); \
exit(-1); \
@@ -68,118 +68,111 @@
} while (0)
// read a vector of integers
#define SHERPA_ONNX_READ_META_DATA_VEC(dst, src_key) \
do { \
auto value = \
meta_data.LookupCustomMetadataMapAllocated(src_key, allocator); \
if (!value) { \
SHERPA_ONNX_LOGE("'%s' does not exist in the metadata", src_key); \
exit(-1); \
} \
\
bool ret = SplitStringToIntegers(value.get(), ",", true, &dst); \
if (!ret) { \
SHERPA_ONNX_LOGE("Invalid value '%s' for '%s'", value.get(), src_key); \
exit(-1); \
} \
#define SHERPA_ONNX_READ_META_DATA_VEC(dst, src_key) \
do { \
auto value = LookupCustomModelMetaData(meta_data, src_key, allocator); \
if (value.empty()) { \
SHERPA_ONNX_LOGE("'%s' does not exist in the metadata", src_key); \
exit(-1); \
} \
\
bool ret = SplitStringToIntegers(value.c_str(), ",", true, &dst); \
if (!ret) { \
SHERPA_ONNX_LOGE("Invalid value '%s' for '%s'", value.c_str(), src_key); \
exit(-1); \
} \
} while (0)
// read a vector of floats
#define SHERPA_ONNX_READ_META_DATA_VEC_FLOAT(dst, src_key) \
do { \
auto value = \
meta_data.LookupCustomMetadataMapAllocated(src_key, allocator); \
if (!value) { \
SHERPA_ONNX_LOGE("%s does not exist in the metadata", src_key); \
exit(-1); \
} \
\
bool ret = SplitStringToFloats(value.get(), ",", true, &dst); \
if (!ret) { \
SHERPA_ONNX_LOGE("Invalid value '%s' for '%s'", value.get(), src_key); \
exit(-1); \
} \
#define SHERPA_ONNX_READ_META_DATA_VEC_FLOAT(dst, src_key) \
do { \
auto value = LookupCustomModelMetaData(meta_data, src_key, allocator); \
if (value.empty()) { \
SHERPA_ONNX_LOGE("%s does not exist in the metadata", src_key); \
exit(-1); \
} \
\
bool ret = SplitStringToFloats(value.c_str(), ",", true, &dst); \
if (!ret) { \
SHERPA_ONNX_LOGE("Invalid value '%s' for '%s'", value.c_str(), src_key); \
exit(-1); \
} \
} while (0)
// read a vector of strings
#define SHERPA_ONNX_READ_META_DATA_VEC_STRING(dst, src_key) \
do { \
auto value = \
meta_data.LookupCustomMetadataMapAllocated(src_key, allocator); \
if (!value) { \
SHERPA_ONNX_LOGE("'%s' does not exist in the metadata", src_key); \
exit(-1); \
} \
SplitStringToVector(value.get(), ",", false, &dst); \
\
if (dst.empty()) { \
SHERPA_ONNX_LOGE("Invalid value '%s' for '%s'. Empty vector!", \
value.get(), src_key); \
exit(-1); \
} \
#define SHERPA_ONNX_READ_META_DATA_VEC_STRING(dst, src_key) \
do { \
auto value = LookupCustomModelMetaData(meta_data, src_key, allocator); \
if (value.empty()) { \
SHERPA_ONNX_LOGE("'%s' does not exist in the metadata", src_key); \
exit(-1); \
} \
SplitStringToVector(value.c_str(), ",", false, &dst); \
\
if (dst.empty()) { \
SHERPA_ONNX_LOGE("Invalid value '%s' for '%s'. Empty vector!", \
value.c_str(), src_key); \
exit(-1); \
} \
} while (0)
// read a vector of strings separated by sep
#define SHERPA_ONNX_READ_META_DATA_VEC_STRING_SEP(dst, src_key, sep) \
do { \
auto value = \
meta_data.LookupCustomMetadataMapAllocated(src_key, allocator); \
if (!value) { \
SHERPA_ONNX_LOGE("'%s' does not exist in the metadata", src_key); \
exit(-1); \
} \
SplitStringToVector(value.get(), sep, false, &dst); \
\
if (dst.empty()) { \
SHERPA_ONNX_LOGE("Invalid value '%s' for '%s'. Empty vector!", \
value.get(), src_key); \
exit(-1); \
} \
#define SHERPA_ONNX_READ_META_DATA_VEC_STRING_SEP(dst, src_key, sep) \
do { \
auto value = LookupCustomModelMetaData(meta_data, src_key, allocator); \
if (value.empty()) { \
SHERPA_ONNX_LOGE("'%s' does not exist in the metadata", src_key); \
exit(-1); \
} \
SplitStringToVector(value.c_str(), sep, false, &dst); \
\
if (dst.empty()) { \
SHERPA_ONNX_LOGE("Invalid value '%s' for '%s'. Empty vector!", \
value.c_str(), src_key); \
exit(-1); \
} \
} while (0)
// Read a string
#define SHERPA_ONNX_READ_META_DATA_STR(dst, src_key) \
do { \
auto value = \
meta_data.LookupCustomMetadataMapAllocated(src_key, allocator); \
if (!value) { \
SHERPA_ONNX_LOGE("'%s' does not exist in the metadata", src_key); \
exit(-1); \
} \
\
dst = value.get(); \
if (dst.empty()) { \
SHERPA_ONNX_LOGE("Invalid value for '%s'\n", src_key); \
exit(-1); \
} \
#define SHERPA_ONNX_READ_META_DATA_STR(dst, src_key) \
do { \
auto value = LookupCustomModelMetaData(meta_data, src_key, allocator); \
if (value.empty()) { \
SHERPA_ONNX_LOGE("'%s' does not exist in the metadata", src_key); \
exit(-1); \
} \
\
dst = std::move(value); \
if (dst.empty()) { \
SHERPA_ONNX_LOGE("Invalid value for '%s'\n", src_key); \
exit(-1); \
} \
} while (0)
#define SHERPA_ONNX_READ_META_DATA_STR_ALLOW_EMPTY(dst, src_key) \
do { \
auto value = \
meta_data.LookupCustomMetadataMapAllocated(src_key, allocator); \
if (!value) { \
SHERPA_ONNX_LOGE("'%s' does not exist in the metadata", src_key); \
exit(-1); \
} \
\
dst = value.get(); \
#define SHERPA_ONNX_READ_META_DATA_STR_ALLOW_EMPTY(dst, src_key) \
do { \
auto value = LookupCustomModelMetaData(meta_data, src_key, allocator); \
if (value.empty()) { \
SHERPA_ONNX_LOGE("'%s' does not exist in the metadata", src_key); \
exit(-1); \
} \
\
dst = std::move(value); \
} while (0)
#define SHERPA_ONNX_READ_META_DATA_STR_WITH_DEFAULT(dst, src_key, \
default_value) \
do { \
auto value = \
meta_data.LookupCustomMetadataMapAllocated(src_key, allocator); \
if (!value) { \
dst = default_value; \
} else { \
dst = value.get(); \
if (dst.empty()) { \
SHERPA_ONNX_LOGE("Invalid value for '%s'\n", src_key); \
exit(-1); \
} \
} \
#define SHERPA_ONNX_READ_META_DATA_STR_WITH_DEFAULT(dst, src_key, \
default_value) \
do { \
auto value = LookupCustomModelMetaData(meta_data, src_key, allocator); \
if (value.empty()) { \
dst = default_value; \
} else { \
dst = std::move(value); \
if (dst.empty()) { \
SHERPA_ONNX_LOGE("Invalid value for '%s'\n", src_key); \
exit(-1); \
} \
} \
} while (0)
#define SHERPA_ONNX_EXIT(code) exit(code)

View File

@@ -46,7 +46,7 @@ class OfflineCEDModel::Impl {
int32_t NumEventClasses() const { return num_event_classes_; }
OrtAllocator *Allocator() const { return allocator_; }
OrtAllocator *Allocator() { return allocator_; }
private:
void Init(void *model_data, size_t model_data_length) {

View File

@@ -44,7 +44,7 @@ class OfflineCtTransformerModel::Impl {
return std::move(ans[0]);
}
OrtAllocator *Allocator() const { return allocator_; }
OrtAllocator *Allocator() { return allocator_; }
const OfflineCtTransformerModelMetaData &GetModelMetadata() const {
return meta_data_;

View File

@@ -53,8 +53,8 @@ static ModelType GetModelType(char *model_data, size_t model_data_length,
Ort::AllocatorWithDefaultOptions allocator;
auto model_type =
meta_data.LookupCustomMetadataMapAllocated("model_type", allocator);
if (!model_type) {
LookupCustomModelMetaData(meta_data, "model_type", allocator);
if (model_type.empty()) {
SHERPA_ONNX_LOGE(
"No model_type in the metadata!\n"
"If you are using models from NeMo, please refer to\n"
@@ -74,22 +74,22 @@ static ModelType GetModelType(char *model_data, size_t model_data_length,
return ModelType::kUnknown;
}
if (model_type.get() == std::string("EncDecCTCModelBPE")) {
if (model_type == "EncDecCTCModelBPE") {
return ModelType::kEncDecCTCModelBPE;
} else if (model_type.get() == std::string("EncDecCTCModel")) {
} else if (model_type == "EncDecCTCModel") {
return ModelType::kEncDecCTCModel;
} else if (model_type.get() == std::string("EncDecHybridRNNTCTCBPEModel")) {
} else if (model_type == "EncDecHybridRNNTCTCBPEModel") {
return ModelType::kEncDecHybridRNNTCTCBPEModel;
} else if (model_type.get() == std::string("tdnn")) {
} else if (model_type == "tdnn") {
return ModelType::kTdnn;
} else if (model_type.get() == std::string("zipformer2_ctc")) {
} else if (model_type == "zipformer2_ctc") {
return ModelType::kZipformerCtc;
} else if (model_type.get() == std::string("wenet_ctc")) {
} else if (model_type == "wenet_ctc") {
return ModelType::kWenetCtc;
} else if (model_type.get() == std::string("telespeech_ctc")) {
} else if (model_type == "telespeech_ctc") {
return ModelType::kTeleSpeechCtc;
} else {
SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.get());
SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.c_str());
return ModelType::kUnknown;
}
}

View File

@@ -155,7 +155,7 @@ class OfflineMoonshineModel::Impl {
return {std::move(cached_decoder_out[0]), std::move(next_states)};
}
OrtAllocator *Allocator() const { return allocator_; }
OrtAllocator *Allocator() { return allocator_; }
private:
void InitPreprocessor(void *model_data, size_t model_data_length) {

View File

@@ -68,7 +68,7 @@ class OfflineNemoEncDecCtcModel::Impl {
int32_t SubsamplingFactor() const { return subsampling_factor_; }
OrtAllocator *Allocator() const { return allocator_; }
OrtAllocator *Allocator() { return allocator_; }
std::string FeatureNormalizationMethod() const { return normalize_type_; }

View File

@@ -56,7 +56,7 @@ class OfflineParaformerModel::Impl {
const std::vector<float> &InverseStdDev() const { return inv_stddev_; }
OrtAllocator *Allocator() const { return allocator_; }
OrtAllocator *Allocator() { return allocator_; }
private:
void Init(void *model_data, size_t model_data_length) {

View File

@@ -121,9 +121,9 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
Ort::AllocatorWithDefaultOptions allocator; // used in the macro below
auto model_type_ptr =
meta_data.LookupCustomMetadataMapAllocated("model_type", allocator);
if (!model_type_ptr) {
auto model_type =
LookupCustomModelMetaData(meta_data, "model_type", allocator);
if (!model_type.empty()) {
SHERPA_ONNX_LOGE(
"No model_type in the metadata!\n\n"
"Please refer to the following URLs to add metadata"
@@ -164,7 +164,6 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
"\n");
exit(-1);
}
std::string model_type(model_type_ptr.get());
if (model_type == "conformer" || model_type == "zipformer" ||
model_type == "zipformer2") {
@@ -301,9 +300,9 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
Ort::AllocatorWithDefaultOptions allocator; // used in the macro below
auto model_type_ptr =
meta_data.LookupCustomMetadataMapAllocated("model_type", allocator);
if (!model_type_ptr) {
auto model_type =
LookupCustomModelMetaData(meta_data, "model_type", allocator);
if (model_type.empty()) {
SHERPA_ONNX_LOGE(
"No model_type in the metadata!\n\n"
"Please refer to the following URLs to add metadata"
@@ -344,7 +343,6 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
"\n");
exit(-1);
}
std::string model_type(model_type_ptr.get());
if (model_type == "conformer" || model_type == "zipformer" ||
model_type == "zipformer2") {

View File

@@ -56,7 +56,7 @@ class OfflineSenseVoiceModel::Impl {
return meta_data_;
}
OrtAllocator *Allocator() const { return allocator_; }
OrtAllocator *Allocator() { return allocator_; }
private:
void Init(void *model_data, size_t model_data_length) {

View File

@@ -63,7 +63,7 @@ class OfflineTdnnCtcModel::Impl {
int32_t VocabSize() const { return vocab_size_; }
OrtAllocator *Allocator() const { return allocator_; }
OrtAllocator *Allocator() { return allocator_; }
private:
void Init(void *model_data, size_t model_data_length) {

View File

@@ -69,7 +69,7 @@ class OfflineTeleSpeechCtcModel::Impl {
int32_t SubsamplingFactor() const { return subsampling_factor_; }
OrtAllocator *Allocator() const { return allocator_; }
OrtAllocator *Allocator() { return allocator_; }
private:
void Init(void *model_data, size_t model_data_length) {

View File

@@ -95,11 +95,11 @@ class OfflineTransducerModel::Impl {
int32_t VocabSize() const { return vocab_size_; }
int32_t ContextSize() const { return context_size_; }
int32_t SubsamplingFactor() const { return 4; }
OrtAllocator *Allocator() const { return allocator_; }
OrtAllocator *Allocator() { return allocator_; }
Ort::Value BuildDecoderInput(
const std::vector<OfflineTransducerDecoderResult> &results,
int32_t end_index) const {
int32_t end_index) {
assert(end_index <= results.size());
int32_t batch_size = end_index;
@@ -122,7 +122,7 @@ class OfflineTransducerModel::Impl {
}
Ort::Value BuildDecoderInput(const std::vector<Hypothesis> &results,
int32_t end_index) const {
int32_t end_index) {
assert(end_index <= results.size());
int32_t batch_size = end_index;

View File

@@ -123,7 +123,7 @@ class OfflineTransducerNeMoModel::Impl {
return std::move(logit[0]);
}
std::vector<Ort::Value> GetDecoderInitStates(int32_t batch_size) const {
std::vector<Ort::Value> GetDecoderInitStates(int32_t batch_size) {
std::array<int64_t, 3> s0_shape{pred_rnn_layers_, batch_size, pred_hidden_};
Ort::Value s0 = Ort::Value::CreateTensor<float>(allocator_, s0_shape.data(),
s0_shape.size());
@@ -149,7 +149,7 @@ class OfflineTransducerNeMoModel::Impl {
int32_t SubsamplingFactor() const { return subsampling_factor_; }
int32_t VocabSize() const { return vocab_size_; }
OrtAllocator *Allocator() const { return allocator_; }
OrtAllocator *Allocator() { return allocator_; }
std::string FeatureNormalizationMethod() const { return normalize_type_; }

View File

@@ -47,7 +47,7 @@ class OfflineWenetCtcModel::Impl {
int32_t SubsamplingFactor() const { return subsampling_factor_; }
OrtAllocator *Allocator() const { return allocator_; }
OrtAllocator *Allocator() { return allocator_; }
private:
void Init(void *model_data, size_t model_data_length) {

View File

@@ -188,7 +188,7 @@ class OfflineWhisperModel::Impl {
return {std::move(n_layer_self_k_cache), std::move(n_layer_self_v_cache)};
}
OrtAllocator *Allocator() const { return allocator_; }
OrtAllocator *Allocator() { return allocator_; }
const std::vector<int64_t> &GetInitialTokens() const { return sot_sequence_; }

View File

@@ -47,7 +47,7 @@ class OfflineZipformerAudioTaggingModel::Impl {
int32_t NumEventClasses() const { return num_event_classes_; }
OrtAllocator *Allocator() const { return allocator_; }
OrtAllocator *Allocator() { return allocator_; }
private:
void Init(void *model_data, size_t model_data_length) {

View File

@@ -48,7 +48,7 @@ class OfflineZipformerCtcModel::Impl {
int32_t VocabSize() const { return vocab_size_; }
int32_t SubsamplingFactor() const { return 4; }
OrtAllocator *Allocator() const { return allocator_; }
OrtAllocator *Allocator() { return allocator_; }
private:
void Init(void *model_data, size_t model_data_length) {

View File

@@ -47,7 +47,7 @@ class OnlineCNNBiLSTMModel::Impl {
return {std::move(ans[0]), std::move(ans[1])};
}
OrtAllocator *Allocator() const { return allocator_; }
OrtAllocator *Allocator() { return allocator_; }
const OnlineCNNBiLSTMModelMetaData &GetModelMetadata() const {
return meta_data_;

View File

@@ -163,8 +163,11 @@ std::vector<Ort::Value> OnlineConformerTransducerModel::StackStates(
conv_vec[i] = &states[i][1];
}
Ort::Value attn = Cat(allocator_, attn_vec, 2);
Ort::Value conv = Cat(allocator_, conv_vec, 2);
auto allocator =
const_cast<OnlineConformerTransducerModel *>(this)->allocator_;
Ort::Value attn = Cat(allocator, attn_vec, 2);
Ort::Value conv = Cat(allocator, conv_vec, 2);
std::vector<Ort::Value> ans;
ans.reserve(2);
@@ -183,8 +186,11 @@ OnlineConformerTransducerModel::UnStackStates(
std::vector<std::vector<Ort::Value>> ans(batch_size);
std::vector<Ort::Value> attn_vec = Unbind(allocator_, &states[0], 2);
std::vector<Ort::Value> conv_vec = Unbind(allocator_, &states[1], 2);
auto allocator =
const_cast<OnlineConformerTransducerModel *>(this)->allocator_;
std::vector<Ort::Value> attn_vec = Unbind(allocator, &states[0], 2);
std::vector<Ort::Value> conv_vec = Unbind(allocator, &states[1], 2);
assert(attn_vec.size() == batch_size);
assert(conv_vec.size() == batch_size);

View File

@@ -158,9 +158,10 @@ std::vector<Ort::Value> OnlineLstmTransducerModel::StackStates(
h_buf[i] = &states[i][0];
c_buf[i] = &states[i][1];
}
auto allocator = const_cast<OnlineLstmTransducerModel *>(this)->allocator_;
Ort::Value h = Cat(allocator_, h_buf, 1);
Ort::Value c = Cat(allocator_, c_buf, 1);
Ort::Value h = Cat(allocator, h_buf, 1);
Ort::Value c = Cat(allocator, c_buf, 1);
std::vector<Ort::Value> ans;
ans.reserve(2);
@@ -177,8 +178,10 @@ std::vector<std::vector<Ort::Value>> OnlineLstmTransducerModel::UnStackStates(
std::vector<std::vector<Ort::Value>> ans(batch_size);
std::vector<Ort::Value> h_vec = Unbind(allocator_, &states[0], 1);
std::vector<Ort::Value> c_vec = Unbind(allocator_, &states[1], 1);
auto allocator = const_cast<OnlineLstmTransducerModel *>(this)->allocator_;
std::vector<Ort::Value> h_vec = Unbind(allocator, &states[0], 1);
std::vector<Ort::Value> c_vec = Unbind(allocator, &states[1], 1);
assert(h_vec.size() == batch_size);
assert(c_vec.size() == batch_size);

View File

@@ -102,7 +102,7 @@ class OnlineNeMoCtcModel::Impl {
int32_t ChunkShift() const { return chunk_shift_; }
OrtAllocator *Allocator() const { return allocator_; }
OrtAllocator *Allocator() { return allocator_; }
// Return a vector containing 3 tensors
// - cache_last_channel
@@ -119,7 +119,7 @@ class OnlineNeMoCtcModel::Impl {
}
std::vector<Ort::Value> StackStates(
std::vector<std::vector<Ort::Value>> states) const {
std::vector<std::vector<Ort::Value>> states) {
int32_t batch_size = static_cast<int32_t>(states.size());
if (batch_size == 1) {
return std::move(states[0]);
@@ -157,6 +157,8 @@ class OnlineNeMoCtcModel::Impl {
std::vector<Ort::Value> states) const {
assert(states.size() == 3);
auto allocator = const_cast<Impl *>(this)->allocator_;
std::vector<std::vector<Ort::Value>> ans;
auto shape = states[0].GetTensorTypeAndShapeInfo().GetShape();
@@ -171,9 +173,9 @@ class OnlineNeMoCtcModel::Impl {
for (int32_t i = 0; i != 3; ++i) {
std::vector<Ort::Value> v;
if (i == 2) {
v = Unbind<int64_t>(allocator_, &states[i], 0);
v = Unbind<int64_t>(allocator, &states[i], 0);
} else {
v = Unbind(allocator_, &states[i], 0);
v = Unbind(allocator, &states[i], 0);
}
assert(v.size() == batch_size);

View File

@@ -105,7 +105,7 @@ class OnlineParaformerModel::Impl {
const std::vector<float> &InverseStdDev() const { return inv_stddev_; }
OrtAllocator *Allocator() const { return allocator_; }
OrtAllocator *Allocator() { return allocator_; }
private:
void InitEncoder(void *model_data, size_t model_data_length) {

View File

@@ -5,10 +5,10 @@
#include "sherpa-onnx/csrc/online-rnn-lm.h"
#include <algorithm>
#include <string>
#include <utility>
#include <vector>
#include <algorithm>
#include "onnxruntime_cxx_api.h" // NOLINT
#include "sherpa-onnx/csrc/macros.h"
@@ -53,49 +53,49 @@ class OnlineRnnLM::Impl {
// classic rescore function
void ComputeLMScore(float scale, int32_t context_size,
std::vector<Hypotheses> *hyps) {
Ort::AllocatorWithDefaultOptions allocator;
std::vector<Hypotheses> *hyps) {
Ort::AllocatorWithDefaultOptions allocator;
for (auto &hyp : *hyps) {
for (auto &h_m : hyp) {
auto &h = h_m.second;
auto &ys = h.ys;
const int32_t token_num_in_chunk =
ys.size() - context_size - h.cur_scored_pos - 1;
for (auto &hyp : *hyps) {
for (auto &h_m : hyp) {
auto &h = h_m.second;
auto &ys = h.ys;
const int32_t token_num_in_chunk =
ys.size() - context_size - h.cur_scored_pos - 1;
if (token_num_in_chunk < 1) {
continue;
}
if (token_num_in_chunk < 1) {
continue;
}
if (h.nn_lm_states.empty()) {
h.nn_lm_states = Convert(GetInitStates());
}
if (h.nn_lm_states.empty()) {
h.nn_lm_states = Convert(GetInitStates());
}
if (token_num_in_chunk >= h.lm_rescore_min_chunk) {
std::array<int64_t, 2> x_shape{1, token_num_in_chunk};
if (token_num_in_chunk >= h.lm_rescore_min_chunk) {
std::array<int64_t, 2> x_shape{1, token_num_in_chunk};
Ort::Value x = Ort::Value::CreateTensor<int64_t>(
allocator, x_shape.data(), x_shape.size());
int64_t *p_x = x.GetTensorMutableData<int64_t>();
std::copy(ys.begin() + context_size + h.cur_scored_pos,
ys.end() - 1, p_x);
Ort::Value x = Ort::Value::CreateTensor<int64_t>(
allocator, x_shape.data(), x_shape.size());
int64_t *p_x = x.GetTensorMutableData<int64_t>();
std::copy(ys.begin() + context_size + h.cur_scored_pos, ys.end() - 1,
p_x);
// streaming forward by NN LM
auto out = ScoreToken(std::move(x),
Convert(std::move(h.nn_lm_states)));
// streaming forward by NN LM
auto out =
ScoreToken(std::move(x), Convert(std::move(h.nn_lm_states)));
// update NN LM score in hyp
const float *p_nll = out.first.GetTensorData<float>();
h.lm_log_prob = -scale * (*p_nll);
// update NN LM score in hyp
const float *p_nll = out.first.GetTensorData<float>();
h.lm_log_prob = -scale * (*p_nll);
// update NN LM states in hyp
h.nn_lm_states = Convert(std::move(out.second));
// update NN LM states in hyp
h.nn_lm_states = Convert(std::move(out.second));
h.cur_scored_pos += token_num_in_chunk;
}
h.cur_scored_pos += token_num_in_chunk;
}
}
}
}
std::pair<Ort::Value, std::vector<Ort::Value>> ScoreToken(
Ort::Value x, std::vector<Ort::Value> states) {
@@ -125,7 +125,7 @@ class OnlineRnnLM::Impl {
}
// get init states for classic rescore
std::vector<Ort::Value> GetInitStates() const {
std::vector<Ort::Value> GetInitStates() {
std::vector<Ort::Value> ans;
ans.reserve(init_states_.size());
@@ -226,7 +226,7 @@ std::pair<Ort::Value, std::vector<Ort::Value>> OnlineRnnLM::ScoreToken(
// classic rescore scores
void OnlineRnnLM::ComputeLMScore(float scale, int32_t context_size,
std::vector<Hypotheses> *hyps) {
std::vector<Hypotheses> *hyps) {
return impl_->ComputeLMScore(scale, context_size, hyps);
}
@@ -235,5 +235,4 @@ void OnlineRnnLM::ComputeLMScoreSF(float scale, Hypothesis *hyp) {
return impl_->ComputeLMScoreSF(scale, hyp);
}
} // namespace sherpa_onnx

View File

@@ -54,8 +54,8 @@ static ModelType GetModelType(char *model_data, size_t model_data_length,
Ort::AllocatorWithDefaultOptions allocator;
auto model_type =
meta_data.LookupCustomMetadataMapAllocated("model_type", allocator);
if (!model_type) {
LookupCustomModelMetaData(meta_data, "model_type", allocator);
if (model_type.empty()) {
SHERPA_ONNX_LOGE(
"No model_type in the metadata!\n"
"Please make sure you are using the latest export-onnx.py from icefall "
@@ -63,16 +63,16 @@ static ModelType GetModelType(char *model_data, size_t model_data_length,
return ModelType::kUnknown;
}
if (model_type.get() == std::string("conformer")) {
if (model_type == "conformer") {
return ModelType::kConformer;
} else if (model_type.get() == std::string("lstm")) {
} else if (model_type == "lstm") {
return ModelType::kLstm;
} else if (model_type.get() == std::string("zipformer")) {
} else if (model_type == "zipformer") {
return ModelType::kZipformer;
} else if (model_type.get() == std::string("zipformer2")) {
} else if (model_type == "zipformer2") {
return ModelType::kZipformer2;
} else {
SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.get());
SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.c_str());
return ModelType::kUnknown;
}
}

View File

@@ -197,7 +197,7 @@ class OnlineTransducerNeMoModel::Impl {
int32_t VocabSize() const { return vocab_size_; }
OrtAllocator *Allocator() const { return allocator_; }
OrtAllocator *Allocator() { return allocator_; }
std::string FeatureNormalizationMethod() const { return normalize_type_; }
@@ -224,6 +224,8 @@ class OnlineTransducerNeMoModel::Impl {
std::vector<Ort::Value> ans;
auto allocator = const_cast<Impl *>(this)->allocator_;
// stack cache_last_channel
std::vector<const Ort::Value *> buf(batch_size);
@@ -239,9 +241,9 @@ class OnlineTransducerNeMoModel::Impl {
Ort::Value c{nullptr};
if (i == 2) {
c = Cat<int64_t>(allocator_, buf, 0);
c = Cat<int64_t>(allocator, buf, 0);
} else {
c = Cat(allocator_, buf, 0);
c = Cat(allocator, buf, 0);
}
ans.push_back(std::move(c));
@@ -251,7 +253,7 @@ class OnlineTransducerNeMoModel::Impl {
}
std::vector<std::vector<Ort::Value>> UnStackStates(
std::vector<Ort::Value> states) const {
std::vector<Ort::Value> states) {
assert(states.size() == 3);
std::vector<std::vector<Ort::Value>> ans;

View File

@@ -101,7 +101,7 @@ class OnlineWenetCtcModel::Impl {
return config_.wenet_ctc.chunk_size * subsampling_factor_;
}
OrtAllocator *Allocator() const { return allocator_; }
OrtAllocator *Allocator() { return allocator_; }
// Return a vector containing 3 tensors
// - attn_cache

View File

@@ -179,12 +179,15 @@ std::vector<Ort::Value> OnlineZipformerTransducerModel::StackStates(
std::vector<Ort::Value> ans;
ans.reserve(states[0].size());
auto allocator =
const_cast<OnlineZipformerTransducerModel *>(this)->allocator_;
// cached_len
for (int32_t i = 0; i != num_encoders; ++i) {
for (int32_t n = 0; n != batch_size; ++n) {
buf[n] = &states[n][i];
}
auto v = Cat<int64_t>(allocator_, buf, 1); // (num_layers, 1)
auto v = Cat<int64_t>(allocator, buf, 1); // (num_layers, 1)
ans.push_back(std::move(v));
}
@@ -193,7 +196,7 @@ std::vector<Ort::Value> OnlineZipformerTransducerModel::StackStates(
for (int32_t n = 0; n != batch_size; ++n) {
buf[n] = &states[n][num_encoders + i];
}
auto v = Cat(allocator_, buf, 1); // (num_layers, 1, encoder_dims)
auto v = Cat(allocator, buf, 1); // (num_layers, 1, encoder_dims)
ans.push_back(std::move(v));
}
@@ -203,7 +206,7 @@ std::vector<Ort::Value> OnlineZipformerTransducerModel::StackStates(
buf[n] = &states[n][num_encoders * 2 + i];
}
// (num_layers, left_context_len, 1, attention_dims)
auto v = Cat(allocator_, buf, 2);
auto v = Cat(allocator, buf, 2);
ans.push_back(std::move(v));
}
@@ -213,7 +216,7 @@ std::vector<Ort::Value> OnlineZipformerTransducerModel::StackStates(
buf[n] = &states[n][num_encoders * 3 + i];
}
// (num_layers, left_context_len, 1, attention_dims/2)
auto v = Cat(allocator_, buf, 2);
auto v = Cat(allocator, buf, 2);
ans.push_back(std::move(v));
}
@@ -223,7 +226,7 @@ std::vector<Ort::Value> OnlineZipformerTransducerModel::StackStates(
buf[n] = &states[n][num_encoders * 4 + i];
}
// (num_layers, left_context_len, 1, attention_dims/2)
auto v = Cat(allocator_, buf, 2);
auto v = Cat(allocator, buf, 2);
ans.push_back(std::move(v));
}
@@ -233,7 +236,7 @@ std::vector<Ort::Value> OnlineZipformerTransducerModel::StackStates(
buf[n] = &states[n][num_encoders * 5 + i];
}
// (num_layers, 1, encoder_dims, cnn_module_kernels-1)
auto v = Cat(allocator_, buf, 1);
auto v = Cat(allocator, buf, 1);
ans.push_back(std::move(v));
}
@@ -243,7 +246,7 @@ std::vector<Ort::Value> OnlineZipformerTransducerModel::StackStates(
buf[n] = &states[n][num_encoders * 6 + i];
}
// (num_layers, 1, encoder_dims, cnn_module_kernels-1)
auto v = Cat(allocator_, buf, 1);
auto v = Cat(allocator, buf, 1);
ans.push_back(std::move(v));
}
@@ -258,12 +261,15 @@ OnlineZipformerTransducerModel::UnStackStates(
int32_t batch_size = states[0].GetTensorTypeAndShapeInfo().GetShape()[1];
int32_t num_encoders = num_encoder_layers_.size();
auto allocator =
const_cast<OnlineZipformerTransducerModel *>(this)->allocator_;
std::vector<std::vector<Ort::Value>> ans;
ans.resize(batch_size);
// cached_len
for (int32_t i = 0; i != num_encoders; ++i) {
auto v = Unbind<int64_t>(allocator_, &states[i], 1);
auto v = Unbind<int64_t>(allocator, &states[i], 1);
assert(v.size() == batch_size);
for (int32_t n = 0; n != batch_size; ++n) {
@@ -273,7 +279,7 @@ OnlineZipformerTransducerModel::UnStackStates(
// cached_avg
for (int32_t i = num_encoders; i != 2 * num_encoders; ++i) {
auto v = Unbind(allocator_, &states[i], 1);
auto v = Unbind(allocator, &states[i], 1);
assert(v.size() == batch_size);
for (int32_t n = 0; n != batch_size; ++n) {
@@ -283,7 +289,7 @@ OnlineZipformerTransducerModel::UnStackStates(
// cached_key
for (int32_t i = 2 * num_encoders; i != 3 * num_encoders; ++i) {
auto v = Unbind(allocator_, &states[i], 2);
auto v = Unbind(allocator, &states[i], 2);
assert(v.size() == batch_size);
for (int32_t n = 0; n != batch_size; ++n) {
@@ -293,7 +299,7 @@ OnlineZipformerTransducerModel::UnStackStates(
// cached_val
for (int32_t i = 3 * num_encoders; i != 4 * num_encoders; ++i) {
auto v = Unbind(allocator_, &states[i], 2);
auto v = Unbind(allocator, &states[i], 2);
assert(v.size() == batch_size);
for (int32_t n = 0; n != batch_size; ++n) {
@@ -303,7 +309,7 @@ OnlineZipformerTransducerModel::UnStackStates(
// cached_val2
for (int32_t i = 4 * num_encoders; i != 5 * num_encoders; ++i) {
auto v = Unbind(allocator_, &states[i], 2);
auto v = Unbind(allocator, &states[i], 2);
assert(v.size() == batch_size);
for (int32_t n = 0; n != batch_size; ++n) {
@@ -313,7 +319,7 @@ OnlineZipformerTransducerModel::UnStackStates(
// cached_conv1
for (int32_t i = 5 * num_encoders; i != 6 * num_encoders; ++i) {
auto v = Unbind(allocator_, &states[i], 1);
auto v = Unbind(allocator, &states[i], 1);
assert(v.size() == batch_size);
for (int32_t n = 0; n != batch_size; ++n) {
@@ -323,7 +329,7 @@ OnlineZipformerTransducerModel::UnStackStates(
// cached_conv2
for (int32_t i = 6 * num_encoders; i != 7 * num_encoders; ++i) {
auto v = Unbind(allocator_, &states[i], 1);
auto v = Unbind(allocator, &states[i], 1);
assert(v.size() == batch_size);
for (int32_t n = 0; n != batch_size; ++n) {

View File

@@ -70,7 +70,7 @@ class OnlineZipformer2CtcModel::Impl {
int32_t ChunkShift() const { return decode_chunk_len_; }
OrtAllocator *Allocator() const { return allocator_; }
OrtAllocator *Allocator() { return allocator_; }
// Return a vector containing 3 tensors
// - attn_cache
@@ -86,7 +86,7 @@ class OnlineZipformer2CtcModel::Impl {
}
std::vector<Ort::Value> StackStates(
std::vector<std::vector<Ort::Value>> states) const {
std::vector<std::vector<Ort::Value>> states) {
int32_t batch_size = static_cast<int32_t>(states.size());
std::vector<const Ort::Value *> buf(batch_size);
@@ -159,7 +159,7 @@ class OnlineZipformer2CtcModel::Impl {
}
std::vector<std::vector<Ort::Value>> UnStackStates(
std::vector<Ort::Value> states) const {
std::vector<Ort::Value> states) {
int32_t m = std::accumulate(num_encoder_layers_.begin(),
num_encoder_layers_.end(), 0);
assert(states.size() == m * 6 + 2);

View File

@@ -185,6 +185,9 @@ std::vector<Ort::Value> OnlineZipformer2TransducerModel::StackStates(
std::vector<const Ort::Value *> buf(batch_size);
auto allocator =
const_cast<OnlineZipformer2TransducerModel *>(this)->allocator_;
std::vector<Ort::Value> ans;
int32_t num_states = static_cast<int32_t>(states[0].size());
ans.reserve(num_states);
@@ -194,42 +197,42 @@ std::vector<Ort::Value> OnlineZipformer2TransducerModel::StackStates(
for (int32_t n = 0; n != batch_size; ++n) {
buf[n] = &states[n][6 * i];
}
auto v = Cat(allocator_, buf, 1);
auto v = Cat(allocator, buf, 1);
ans.push_back(std::move(v));
}
{
for (int32_t n = 0; n != batch_size; ++n) {
buf[n] = &states[n][6 * i + 1];
}
auto v = Cat(allocator_, buf, 1);
auto v = Cat(allocator, buf, 1);
ans.push_back(std::move(v));
}
{
for (int32_t n = 0; n != batch_size; ++n) {
buf[n] = &states[n][6 * i + 2];
}
auto v = Cat(allocator_, buf, 1);
auto v = Cat(allocator, buf, 1);
ans.push_back(std::move(v));
}
{
for (int32_t n = 0; n != batch_size; ++n) {
buf[n] = &states[n][6 * i + 3];
}
auto v = Cat(allocator_, buf, 1);
auto v = Cat(allocator, buf, 1);
ans.push_back(std::move(v));
}
{
for (int32_t n = 0; n != batch_size; ++n) {
buf[n] = &states[n][6 * i + 4];
}
auto v = Cat(allocator_, buf, 0);
auto v = Cat(allocator, buf, 0);
ans.push_back(std::move(v));
}
{
for (int32_t n = 0; n != batch_size; ++n) {
buf[n] = &states[n][6 * i + 5];
}
auto v = Cat(allocator_, buf, 0);
auto v = Cat(allocator, buf, 0);
ans.push_back(std::move(v));
}
}
@@ -238,7 +241,7 @@ std::vector<Ort::Value> OnlineZipformer2TransducerModel::StackStates(
for (int32_t n = 0; n != batch_size; ++n) {
buf[n] = &states[n][num_states - 2];
}
auto v = Cat(allocator_, buf, 0);
auto v = Cat(allocator, buf, 0);
ans.push_back(std::move(v));
}
@@ -246,7 +249,7 @@ std::vector<Ort::Value> OnlineZipformer2TransducerModel::StackStates(
for (int32_t n = 0; n != batch_size; ++n) {
buf[n] = &states[n][num_states - 1];
}
auto v = Cat<int64_t>(allocator_, buf, 0);
auto v = Cat<int64_t>(allocator, buf, 0);
ans.push_back(std::move(v));
}
return ans;
@@ -261,12 +264,15 @@ OnlineZipformer2TransducerModel::UnStackStates(
int32_t batch_size = states[0].GetTensorTypeAndShapeInfo().GetShape()[1];
auto allocator =
const_cast<OnlineZipformer2TransducerModel *>(this)->allocator_;
std::vector<std::vector<Ort::Value>> ans;
ans.resize(batch_size);
for (int32_t i = 0; i != m; ++i) {
{
auto v = Unbind(allocator_, &states[i * 6], 1);
auto v = Unbind(allocator, &states[i * 6], 1);
assert(static_cast<int32_t>(v.size()) == batch_size);
for (int32_t n = 0; n != batch_size; ++n) {
@@ -274,7 +280,7 @@ OnlineZipformer2TransducerModel::UnStackStates(
}
}
{
auto v = Unbind(allocator_, &states[i * 6 + 1], 1);
auto v = Unbind(allocator, &states[i * 6 + 1], 1);
assert(static_cast<int32_t>(v.size()) == batch_size);
for (int32_t n = 0; n != batch_size; ++n) {
@@ -282,7 +288,7 @@ OnlineZipformer2TransducerModel::UnStackStates(
}
}
{
auto v = Unbind(allocator_, &states[i * 6 + 2], 1);
auto v = Unbind(allocator, &states[i * 6 + 2], 1);
assert(static_cast<int32_t>(v.size()) == batch_size);
for (int32_t n = 0; n != batch_size; ++n) {
@@ -290,7 +296,7 @@ OnlineZipformer2TransducerModel::UnStackStates(
}
}
{
auto v = Unbind(allocator_, &states[i * 6 + 3], 1);
auto v = Unbind(allocator, &states[i * 6 + 3], 1);
assert(static_cast<int32_t>(v.size()) == batch_size);
for (int32_t n = 0; n != batch_size; ++n) {
@@ -298,7 +304,7 @@ OnlineZipformer2TransducerModel::UnStackStates(
}
}
{
auto v = Unbind(allocator_, &states[i * 6 + 4], 0);
auto v = Unbind(allocator, &states[i * 6 + 4], 0);
assert(static_cast<int32_t>(v.size()) == batch_size);
for (int32_t n = 0; n != batch_size; ++n) {
@@ -306,7 +312,7 @@ OnlineZipformer2TransducerModel::UnStackStates(
}
}
{
auto v = Unbind(allocator_, &states[i * 6 + 5], 0);
auto v = Unbind(allocator, &states[i * 6 + 5], 0);
assert(static_cast<int32_t>(v.size()) == batch_size);
for (int32_t n = 0; n != batch_size; ++n) {
@@ -316,7 +322,7 @@ OnlineZipformer2TransducerModel::UnStackStates(
}
{
auto v = Unbind(allocator_, &states[m * 6], 0);
auto v = Unbind(allocator, &states[m * 6], 0);
assert(static_cast<int32_t>(v.size()) == batch_size);
for (int32_t n = 0; n != batch_size; ++n) {
@@ -324,7 +330,7 @@ OnlineZipformer2TransducerModel::UnStackStates(
}
}
{
auto v = Unbind<int64_t>(allocator_, &states[m * 6 + 1], 0);
auto v = Unbind<int64_t>(allocator, &states[m * 6 + 1], 0);
assert(static_cast<int32_t>(v.size()) == batch_size);
for (int32_t n = 0; n != batch_size; ++n) {

View File

@@ -21,6 +21,36 @@
namespace sherpa_onnx {
static std::string GetInputName(Ort::Session *sess, size_t index,
OrtAllocator *allocator) {
// Note(fangjun): We only tested 1.17.1 and 1.11.0
// For other versions, we may need to change it
#if ORT_API_VERSION >= 17
auto v = sess->GetInputNameAllocated(index, allocator);
return v.get();
#else
auto v = sess->GetInputName(index, allocator);
std::string ans = v;
allocator->Free(allocator, v);
return ans;
#endif
}
static std::string GetOutputName(Ort::Session *sess, size_t index,
OrtAllocator *allocator) {
// Note(fangjun): We only tested 1.17.1 and 1.11.0
// For other versions, we may need to change it
#if ORT_API_VERSION >= 17
auto v = sess->GetOutputNameAllocated(index, allocator);
return v.get();
#else
auto v = sess->GetOutputName(index, allocator);
std::string ans = v;
allocator->Free(allocator, v);
return ans;
#endif
}
void GetInputNames(Ort::Session *sess, std::vector<std::string> *input_names,
std::vector<const char *> *input_names_ptr) {
Ort::AllocatorWithDefaultOptions allocator;
@@ -28,8 +58,7 @@ void GetInputNames(Ort::Session *sess, std::vector<std::string> *input_names,
input_names->resize(node_count);
input_names_ptr->resize(node_count);
for (size_t i = 0; i != node_count; ++i) {
auto tmp = sess->GetInputNameAllocated(i, allocator);
(*input_names)[i] = tmp.get();
(*input_names)[i] = GetInputName(sess, i, allocator);
(*input_names_ptr)[i] = (*input_names)[i].c_str();
}
}
@@ -41,8 +70,7 @@ void GetOutputNames(Ort::Session *sess, std::vector<std::string> *output_names,
output_names->resize(node_count);
output_names_ptr->resize(node_count);
for (size_t i = 0; i != node_count; ++i) {
auto tmp = sess->GetOutputNameAllocated(i, allocator);
(*output_names)[i] = tmp.get();
(*output_names)[i] = GetOutputName(sess, i, allocator);
(*output_names_ptr)[i] = (*output_names)[i].c_str();
}
}
@@ -78,12 +106,24 @@ Ort::Value GetEncoderOutFrame(OrtAllocator *allocator, Ort::Value *encoder_out,
void PrintModelMetadata(std::ostream &os, const Ort::ModelMetadata &meta_data) {
Ort::AllocatorWithDefaultOptions allocator;
#if ORT_API_VERSION >= 17
std::vector<Ort::AllocatedStringPtr> v =
meta_data.GetCustomMetadataMapKeysAllocated(allocator);
for (const auto &key : v) {
auto p = meta_data.LookupCustomMetadataMapAllocated(key.get(), allocator);
os << key.get() << "=" << p.get() << "\n";
}
#else
int64_t num_keys = 0;
char **keys = meta_data.GetCustomMetadataMapKeys(allocator, num_keys);
for (int32_t i = 0; i < num_keys; ++i) {
auto v = LookupCustomModelMetaData(meta_data, keys[i], allocator);
os << keys[i] << "=" << v << "\n";
allocator.Free(keys[i]);
}
allocator.Free(keys);
#endif
}
Ort::Value Clone(OrtAllocator *allocator, const Ort::Value *v) {
@@ -361,4 +401,20 @@ std::vector<Ort::Value> Convert(std::vector<CopyableOrtValue> values) {
return ans;
}
std::string LookupCustomModelMetaData(const Ort::ModelMetadata &meta_data,
const char *key,
OrtAllocator *allocator) {
// Note(fangjun): We only tested 1.17.1 and 1.11.0
// For other versions, we may need to change it
#if ORT_API_VERSION >= 17
auto v = meta_data.LookupCustomMetadataMapAllocated(key, allocator);
return v.get();
#else
auto v = meta_data.LookupCustomMetadataMap(key, allocator);
std::string ans = v;
allocator->Free(allocator, v);
return ans;
#endif
}
} // namespace sherpa_onnx

View File

@@ -59,6 +59,9 @@ void GetOutputNames(Ort::Session *sess, std::vector<std::string> *output_names,
Ort::Value GetEncoderOutFrame(OrtAllocator *allocator, Ort::Value *encoder_out,
int32_t t);
std::string LookupCustomModelMetaData(const Ort::ModelMetadata &meta_data,
const char *key, OrtAllocator *allocator);
void PrintModelMetadata(std::ostream &os,
const Ort::ModelMetadata &meta_data); // NOLINT

View File

@@ -60,6 +60,7 @@ Ort::SessionOptions GetSessionOptionsImpl(
case Provider::kCPU:
break; // nothing to do for the CPU provider
case Provider::kXnnpack: {
#if ORT_API_VERSION >= 17
if (std::find(available_providers.begin(), available_providers.end(),
"XnnpackExecutionProvider") != available_providers.end()) {
sess_opts.AppendExecutionProvider("XNNPACK");
@@ -67,6 +68,11 @@ Ort::SessionOptions GetSessionOptionsImpl(
SHERPA_ONNX_LOGE("Available providers: %s. Fallback to cpu!",
os.str().c_str());
}
#else
SHERPA_ONNX_LOGE(
"Does not support xnnpack for onnxruntime: %d. Fallback to cpu!",
static_cast<int32_t>(ORT_API_VERSION));
#endif
break;
}
case Provider::kTRT: {

View File

@@ -40,8 +40,8 @@ static ModelType GetModelType(char *model_data, size_t model_data_length,
Ort::AllocatorWithDefaultOptions allocator;
auto model_type =
meta_data.LookupCustomMetadataMapAllocated("framework", allocator);
if (!model_type) {
LookupCustomModelMetaData(meta_data, "framework", allocator);
if (model_type.empty()) {
SHERPA_ONNX_LOGE(
"No model_type in the metadata!\n"
"Please make sure you have added metadata to the model.\n\n"
@@ -52,14 +52,14 @@ static ModelType GetModelType(char *model_data, size_t model_data_length,
return ModelType::kUnknown;
}
if (model_type.get() == std::string("wespeaker")) {
if (model_type == "wespeaker") {
return ModelType::kWeSpeaker;
} else if (model_type.get() == std::string("3d-speaker")) {
} else if (model_type == "3d-speaker") {
return ModelType::k3dSpeaker;
} else if (model_type.get() == std::string("nemo")) {
} else if (model_type == "nemo") {
return ModelType::kNeMo;
} else {
SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.get());
SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.c_str());
return ModelType::kUnknown;
}
}

View File

@@ -53,7 +53,7 @@ class SpeakerEmbeddingExtractorNeMoModel::Impl {
return std::move(outputs[0]);
}
OrtAllocator *Allocator() const { return allocator_; }
OrtAllocator *Allocator() { return allocator_; }
const SpeakerEmbeddingExtractorNeMoModelMetaData &GetMetaData() const {
return meta_data_;

View File

@@ -42,8 +42,8 @@ static ModelType GetModelType(char *model_data, size_t model_data_length,
Ort::AllocatorWithDefaultOptions allocator;
auto model_type =
meta_data.LookupCustomMetadataMapAllocated("model_type", allocator);
if (!model_type) {
LookupCustomModelMetaData(meta_data, "model_type", allocator);
if (model_type.empty()) {
SHERPA_ONNX_LOGE(
"No model_type in the metadata!\n"
"Please make sure you have added metadata to the model.\n\n"
@@ -54,11 +54,10 @@ static ModelType GetModelType(char *model_data, size_t model_data_length,
return ModelType::kUnknown;
}
auto model_type_str = std::string(model_type.get());
if (model_type_str.find("whisper") == 0) {
if (model_type.find("whisper") == 0) {
return ModelType::kWhisper;
} else {
SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.get());
SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.c_str());
return ModelType::kUnknown;
}
}

View File

@@ -29,20 +29,19 @@ namespace {
const char *ws = " \t\n\r\f\v";
// trim from end of string (right)
inline std::string &TrimRight(std::string &s, const char *t = ws) {
s.erase(s.find_last_not_of(t) + 1);
return s;
inline void TrimRight(std::string *s, const char *t = ws) {
s->erase(s->find_last_not_of(t) + 1);
}
// trim from beginning of string (left)
inline std::string &TrimLeft(std::string &s, const char *t = ws) {
s.erase(0, s.find_first_not_of(t));
return s;
inline void TrimLeft(std::string *s, const char *t = ws) {
s->erase(0, s->find_first_not_of(t));
}
// trim from both ends of string (right then left)
inline std::string &Trim(std::string &s, const char *t = ws) {
return TrimLeft(TrimRight(s, t), t);
inline void Trim(std::string *s, const char *t = ws) {
TrimRight(s, t);
TrimLeft(s, t);
}
} // namespace
@@ -56,7 +55,7 @@ std::unordered_map<std::string, int32_t> ReadTokens(
std::string sym;
int32_t id = -1;
while (std::getline(is, line)) {
Trim(line);
Trim(&line);
std::istringstream iss(line);
iss >> sym;
if (iss.eof()) {