// sherpa-onnx/csrc/onnx-utils.cc // // Copyright (c) 2023 Xiaomi Corporation // Copyright (c) 2023 Pingfeng Luo #include "sherpa-onnx/csrc/onnx-utils.h" #include #include #include #include #include #include #include #include #include "sherpa-onnx/csrc/macros.h" #if __ANDROID_API__ >= 9 #include "android/asset_manager.h" #include "android/asset_manager_jni.h" #include "android/log.h" #endif #include "onnxruntime_cxx_api.h" // NOLINT namespace sherpa_onnx { static std::string GetInputName(Ort::Session *sess, size_t index, OrtAllocator *allocator) { // Note(fangjun): We only tested 1.17.1 and 1.11.0 // For other versions, we may need to change it #if ORT_API_VERSION >= 12 auto v = sess->GetInputNameAllocated(index, allocator); return v.get(); #else auto v = sess->GetInputName(index, allocator); std::string ans = v; allocator->Free(allocator, v); return ans; #endif } static std::string GetOutputName(Ort::Session *sess, size_t index, OrtAllocator *allocator) { // Note(fangjun): We only tested 1.17.1 and 1.11.0 // For other versions, we may need to change it #if ORT_API_VERSION >= 12 auto v = sess->GetOutputNameAllocated(index, allocator); return v.get(); #else auto v = sess->GetOutputName(index, allocator); std::string ans = v; allocator->Free(allocator, v); return ans; #endif } void GetInputNames(Ort::Session *sess, std::vector *input_names, std::vector *input_names_ptr) { Ort::AllocatorWithDefaultOptions allocator; size_t node_count = sess->GetInputCount(); input_names->resize(node_count); input_names_ptr->resize(node_count); for (size_t i = 0; i != node_count; ++i) { (*input_names)[i] = GetInputName(sess, i, allocator); (*input_names_ptr)[i] = (*input_names)[i].c_str(); } } void GetOutputNames(Ort::Session *sess, std::vector *output_names, std::vector *output_names_ptr) { Ort::AllocatorWithDefaultOptions allocator; size_t node_count = sess->GetOutputCount(); output_names->resize(node_count); output_names_ptr->resize(node_count); for (size_t i = 0; i != node_count; ++i) { (*output_names)[i] = GetOutputName(sess, i, allocator); (*output_names_ptr)[i] = (*output_names)[i].c_str(); } } Ort::Value GetEncoderOutFrame(OrtAllocator *allocator, Ort::Value *encoder_out, int32_t t) { std::vector encoder_out_shape = encoder_out->GetTensorTypeAndShapeInfo().GetShape(); auto batch_size = encoder_out_shape[0]; auto num_frames = encoder_out_shape[1]; assert(t < num_frames); auto encoder_out_dim = encoder_out_shape[2]; auto offset = num_frames * encoder_out_dim; std::array shape{batch_size, encoder_out_dim}; Ort::Value ans = Ort::Value::CreateTensor(allocator, shape.data(), shape.size()); float *dst = ans.GetTensorMutableData(); const float *src = encoder_out->GetTensorData(); for (int32_t i = 0; i != batch_size; ++i) { std::copy(src + t * encoder_out_dim, src + (t + 1) * encoder_out_dim, dst); src += offset; dst += encoder_out_dim; } return ans; } void PrintModelMetadata(std::ostream &os, const Ort::ModelMetadata &meta_data) { Ort::AllocatorWithDefaultOptions allocator; #if ORT_API_VERSION >= 12 std::vector v = meta_data.GetCustomMetadataMapKeysAllocated(allocator); for (const auto &key : v) { auto p = meta_data.LookupCustomMetadataMapAllocated(key.get(), allocator); os << key.get() << "=" << p.get() << "\n"; } #else int64_t num_keys = 0; char **keys = meta_data.GetCustomMetadataMapKeys(allocator, num_keys); for (int32_t i = 0; i < num_keys; ++i) { auto v = LookupCustomModelMetaData(meta_data, keys[i], allocator); os << keys[i] << "=" << v << "\n"; allocator.Free(keys[i]); } allocator.Free(keys); #endif } Ort::Value Clone(OrtAllocator *allocator, const Ort::Value *v) { auto type_and_shape = v->GetTensorTypeAndShapeInfo(); std::vector shape = type_and_shape.GetShape(); switch (type_and_shape.GetElementType()) { case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: { Ort::Value ans = Ort::Value::CreateTensor( allocator, shape.data(), shape.size()); const int32_t *start = v->GetTensorData(); const int32_t *end = start + type_and_shape.GetElementCount(); int32_t *dst = ans.GetTensorMutableData(); std::copy(start, end, dst); return ans; } case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: { Ort::Value ans = Ort::Value::CreateTensor( allocator, shape.data(), shape.size()); const int64_t *start = v->GetTensorData(); const int64_t *end = start + type_and_shape.GetElementCount(); int64_t *dst = ans.GetTensorMutableData(); std::copy(start, end, dst); return ans; } case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: { Ort::Value ans = Ort::Value::CreateTensor(allocator, shape.data(), shape.size()); const float *start = v->GetTensorData(); const float *end = start + type_and_shape.GetElementCount(); float *dst = ans.GetTensorMutableData(); std::copy(start, end, dst); return ans; } default: fprintf(stderr, "Unsupported type: %d\n", static_cast(type_and_shape.GetElementType())); exit(-1); // unreachable code return Ort::Value{nullptr}; } } Ort::Value View(Ort::Value *v) { auto type_and_shape = v->GetTensorTypeAndShapeInfo(); std::vector shape = type_and_shape.GetShape(); auto memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); switch (type_and_shape.GetElementType()) { case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: return Ort::Value::CreateTensor( memory_info, v->GetTensorMutableData(), type_and_shape.GetElementCount(), shape.data(), shape.size()); case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: return Ort::Value::CreateTensor( memory_info, v->GetTensorMutableData(), type_and_shape.GetElementCount(), shape.data(), shape.size()); case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: return Ort::Value::CreateTensor( memory_info, v->GetTensorMutableData(), type_and_shape.GetElementCount(), shape.data(), shape.size()); default: fprintf(stderr, "Unsupported type: %d\n", static_cast(type_and_shape.GetElementType())); exit(-1); // unreachable code return Ort::Value{nullptr}; } } float ComputeSum(const Ort::Value *v, int32_t n /*= -1*/) { std::vector shape = v->GetTensorTypeAndShapeInfo().GetShape(); auto size = static_cast( std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<>())); if (n != -1 && n < size && n > 0) { size = n; } const float *p = v->GetTensorData(); return std::accumulate(p, p + size, 1.0f); } float ComputeMean(const Ort::Value *v, int32_t n /*= -1*/) { std::vector shape = v->GetTensorTypeAndShapeInfo().GetShape(); auto size = static_cast( std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<>())); if (n != -1 && n < size && n > 0) { size = n; } auto sum = ComputeSum(v, n); return sum / size; } void PrintShape(const Ort::Value *v) { std::vector shape = v->GetTensorTypeAndShapeInfo().GetShape(); std::ostringstream os; for (auto i : shape) { os << i << ", "; } os << "\n"; fprintf(stderr, "%s", os.str().c_str()); } template void Print1D(const Ort::Value *v) { std::vector shape = v->GetTensorTypeAndShapeInfo().GetShape(); const T *d = v->GetTensorData(); std::ostringstream os; for (int32_t i = 0; i != static_cast(shape[0]); ++i) { os << d[i] << " "; } os << "\n"; fprintf(stderr, "%s\n", os.str().c_str()); } template void Print1D(const Ort::Value *v); template void Print1D(const Ort::Value *v); template void Print1D(const Ort::Value *v); template void Print2D(const Ort::Value *v) { std::vector shape = v->GetTensorTypeAndShapeInfo().GetShape(); const T *d = v->GetTensorData(); std::ostringstream os; for (int32_t r = 0; r != static_cast(shape[0]); ++r) { for (int32_t c = 0; c != static_cast(shape[1]); ++c, ++d) { os << *d << " "; } os << "\n"; } fprintf(stderr, "%s\n", os.str().c_str()); } template void Print2D(const Ort::Value *v); template void Print2D(const Ort::Value *v); void Print3D(const Ort::Value *v) { std::vector shape = v->GetTensorTypeAndShapeInfo().GetShape(); const float *d = v->GetTensorData(); for (int32_t p = 0; p != static_cast(shape[0]); ++p) { fprintf(stderr, "---plane %d---\n", p); for (int32_t r = 0; r != static_cast(shape[1]); ++r) { for (int32_t c = 0; c != static_cast(shape[2]); ++c, ++d) { fprintf(stderr, "%.3f ", *d); } fprintf(stderr, "\n"); } } fprintf(stderr, "\n"); } void Print4D(const Ort::Value *v) { std::vector shape = v->GetTensorTypeAndShapeInfo().GetShape(); const float *d = v->GetTensorData(); for (int32_t p = 0; p != static_cast(shape[0]); ++p) { fprintf(stderr, "---plane %d---\n", p); for (int32_t q = 0; q != static_cast(shape[1]); ++q) { fprintf(stderr, "---subplane %d---\n", q); for (int32_t r = 0; r != static_cast(shape[2]); ++r) { for (int32_t c = 0; c != static_cast(shape[3]); ++c, ++d) { fprintf(stderr, "%.3f ", *d); } fprintf(stderr, "\n"); } fprintf(stderr, "\n"); } } fprintf(stderr, "\n"); } std::vector ReadFile(const std::string &filename) { std::ifstream input(filename, std::ios::binary); std::vector buffer(std::istreambuf_iterator(input), {}); return buffer; } #if __ANDROID_API__ >= 9 std::vector ReadFile(AAssetManager *mgr, const std::string &filename) { AAsset *asset = AAssetManager_open(mgr, filename.c_str(), AASSET_MODE_BUFFER); if (!asset) { __android_log_print(ANDROID_LOG_FATAL, "sherpa-onnx", "Read binary file: Load %s failed", filename.c_str()); exit(-1); } auto p = reinterpret_cast(AAsset_getBuffer(asset)); size_t asset_length = AAsset_getLength(asset); std::vector buffer(p, p + asset_length); AAsset_close(asset); return buffer; } #endif #if __OHOS__ std::vector ReadFile(NativeResourceManager *mgr, const std::string &filename) { std::unique_ptr fp( OH_ResourceManager_OpenRawFile(mgr, filename.c_str()), OH_ResourceManager_CloseRawFile); if (!fp) { std::ostringstream os; os << "Read file '" << filename << "' failed."; SHERPA_ONNX_LOGE("%s", os.str().c_str()); return {}; } auto len = static_cast(OH_ResourceManager_GetRawFileSize(fp.get())); std::vector buffer(len); int32_t n = OH_ResourceManager_ReadRawFile(fp.get(), buffer.data(), len); if (n != len) { std::ostringstream os; os << "Read file '" << filename << "' failed. Number of bytes read: " << n << ". Expected bytes to read: " << len; SHERPA_ONNX_LOGE("%s", os.str().c_str()); return {}; } return buffer; } #endif Ort::Value Repeat(OrtAllocator *allocator, Ort::Value *cur_encoder_out, const std::vector &hyps_num_split) { std::vector cur_encoder_out_shape = cur_encoder_out->GetTensorTypeAndShapeInfo().GetShape(); std::array ans_shape{hyps_num_split.back(), cur_encoder_out_shape[1]}; Ort::Value ans = Ort::Value::CreateTensor(allocator, ans_shape.data(), ans_shape.size()); const float *src = cur_encoder_out->GetTensorData(); float *dst = ans.GetTensorMutableData(); int32_t batch_size = static_cast(hyps_num_split.size()) - 1; for (int32_t b = 0; b != batch_size; ++b) { int32_t cur_stream_hyps_num = hyps_num_split[b + 1] - hyps_num_split[b]; for (int32_t i = 0; i != cur_stream_hyps_num; ++i) { std::copy(src, src + cur_encoder_out_shape[1], dst); dst += cur_encoder_out_shape[1]; } src += cur_encoder_out_shape[1]; } return ans; } CopyableOrtValue::CopyableOrtValue(const CopyableOrtValue &other) { *this = other; } CopyableOrtValue &CopyableOrtValue::operator=(const CopyableOrtValue &other) { if (this == &other) { return *this; } if (other.value) { Ort::AllocatorWithDefaultOptions allocator; value = Clone(allocator, &other.value); } return *this; } CopyableOrtValue::CopyableOrtValue(CopyableOrtValue &&other) noexcept { *this = std::move(other); } CopyableOrtValue &CopyableOrtValue::operator=( CopyableOrtValue &&other) noexcept { if (this == &other) { return *this; } value = std::move(other.value); return *this; } std::vector Convert(std::vector values) { std::vector ans; ans.reserve(values.size()); for (auto &v : values) { ans.emplace_back(std::move(v)); } return ans; } std::vector Convert(std::vector values) { std::vector ans; ans.reserve(values.size()); for (auto &v : values) { ans.emplace_back(std::move(v.value)); } return ans; } std::string LookupCustomModelMetaData(const Ort::ModelMetadata &meta_data, const char *key, OrtAllocator *allocator) { // Note(fangjun): We only tested 1.17.1 and 1.11.0 // For other versions, we may need to change it #if ORT_API_VERSION >= 12 auto v = meta_data.LookupCustomMetadataMapAllocated(key, allocator); return v ? v.get() : ""; #else auto v = meta_data.LookupCustomMetadataMap(key, allocator); std::string ans = v ? v : ""; allocator->Free(allocator, v); return ans; #endif } } // namespace sherpa_onnx