// sherpa-onnx/csrc/onnx-utils.cc // // Copyright (c) 2023 Xiaomi Corporation #include "sherpa-onnx/csrc/onnx-utils.h" #include #include #include "onnxruntime_cxx_api.h" // NOLINT namespace sherpa_onnx { void GetInputNames(Ort::Session *sess, std::vector *input_names, std::vector *input_names_ptr) { Ort::AllocatorWithDefaultOptions allocator; size_t node_count = sess->GetInputCount(); input_names->resize(node_count); input_names_ptr->resize(node_count); for (size_t i = 0; i != node_count; ++i) { auto tmp = sess->GetInputNameAllocated(i, allocator); (*input_names)[i] = tmp.get(); (*input_names_ptr)[i] = (*input_names)[i].c_str(); } } void GetOutputNames(Ort::Session *sess, std::vector *output_names, std::vector *output_names_ptr) { Ort::AllocatorWithDefaultOptions allocator; size_t node_count = sess->GetOutputCount(); output_names->resize(node_count); output_names_ptr->resize(node_count); for (size_t i = 0; i != node_count; ++i) { auto tmp = sess->GetOutputNameAllocated(i, allocator); (*output_names)[i] = tmp.get(); (*output_names_ptr)[i] = (*output_names)[i].c_str(); } } void PrintModelMetadata(std::ostream &os, const Ort::ModelMetadata &meta_data) { Ort::AllocatorWithDefaultOptions allocator; std::vector v = meta_data.GetCustomMetadataMapKeysAllocated(allocator); for (const auto &key : v) { auto p = meta_data.LookupCustomMetadataMapAllocated(key.get(), allocator); os << key.get() << "=" << p.get() << "\n"; } } Ort::Value Clone(const Ort::Value *v) { auto type_and_shape = v->GetTensorTypeAndShapeInfo(); std::vector shape = type_and_shape.GetShape(); auto memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); switch (type_and_shape.GetElementType()) { case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: return Ort::Value::CreateTensor( memory_info, const_cast(v)->GetTensorMutableData(), type_and_shape.GetElementCount(), shape.data(), shape.size()); case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: return Ort::Value::CreateTensor( memory_info, const_cast(v)->GetTensorMutableData(), type_and_shape.GetElementCount(), shape.data(), shape.size()); case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: return Ort::Value::CreateTensor( memory_info, const_cast(v)->GetTensorMutableData(), type_and_shape.GetElementCount(), shape.data(), shape.size()); default: fprintf(stderr, "Unsupported type: %d\n", static_cast(type_and_shape.GetElementType())); exit(-1); // unreachable code return Ort::Value{nullptr}; } } void Print1D(Ort::Value *v) { std::vector shape = v->GetTensorTypeAndShapeInfo().GetShape(); const float *d = v->GetTensorData(); for (int32_t i = 0; i != static_cast(shape[0]); ++i) { fprintf(stderr, "%.3f ", d[i]); } fprintf(stderr, "\n"); } void Print2D(Ort::Value *v) { std::vector shape = v->GetTensorTypeAndShapeInfo().GetShape(); const float *d = v->GetTensorData(); for (int32_t r = 0; r != static_cast(shape[0]); ++r) { for (int32_t c = 0; c != static_cast(shape[1]); ++c, ++d) { fprintf(stderr, "%.3f ", *d); } fprintf(stderr, "\n"); } fprintf(stderr, "\n"); } void Print3D(Ort::Value *v) { std::vector shape = v->GetTensorTypeAndShapeInfo().GetShape(); const float *d = v->GetTensorData(); for (int32_t p = 0; p != static_cast(shape[0]); ++p) { fprintf(stderr, "---plane %d---\n", p); for (int32_t r = 0; r != static_cast(shape[1]); ++r) { for (int32_t c = 0; c != static_cast(shape[2]); ++c, ++d) { fprintf(stderr, "%.3f ", *d); } fprintf(stderr, "\n"); } } fprintf(stderr, "\n"); } } // namespace sherpa_onnx