// sherpa-onnx/csrc/onnx-utils.h // // Copyright (c) 2023 Xiaomi Corporation // Copyright (c) 2023 Pingfeng Luo #ifndef SHERPA_ONNX_CSRC_ONNX_UTILS_H_ #define SHERPA_ONNX_CSRC_ONNX_UTILS_H_ #ifdef _MSC_VER // For ToWide() below #include #include #endif #include #include #include #include #include #if __ANDROID_API__ >= 9 #include "android/asset_manager.h" #include "android/asset_manager_jni.h" #endif #include "onnxruntime_cxx_api.h" // NOLINT namespace sherpa_onnx { /** * Get the input names of a model. * * @param sess An onnxruntime session. * @param input_names. On return, it contains the input names of the model. * @param input_names_ptr. On return, input_names_ptr[i] contains * input_names[i].c_str() */ void GetInputNames(Ort::Session *sess, std::vector *input_names, std::vector *input_names_ptr); /** * Get the output names of a model. * * @param sess An onnxruntime session. * @param output_names. On return, it contains the output names of the model. * @param output_names_ptr. On return, output_names_ptr[i] contains * output_names[i].c_str() */ void GetOutputNames(Ort::Session *sess, std::vector *output_names, std::vector *output_names_ptr); /** * Get the output frame of Encoder * * @param allocator allocator of onnxruntime * @param encoder_out encoder out tensor * @param t frame_index * */ Ort::Value GetEncoderOutFrame(OrtAllocator *allocator, Ort::Value *encoder_out, int32_t t); void PrintModelMetadata(std::ostream &os, const Ort::ModelMetadata &meta_data); // NOLINT // Return a deep copy of v Ort::Value Clone(OrtAllocator *allocator, const Ort::Value *v); // Return a shallow copy Ort::Value View(Ort::Value *v); // Print a 1-D tensor to stderr void Print1D(Ort::Value *v); // Print a 2-D tensor to stderr template void Print2D(Ort::Value *v); // Print a 3-D tensor to stderr void Print3D(Ort::Value *v); // Print a 4-D tensor to stderr void Print4D(Ort::Value *v); template void Fill(Ort::Value *tensor, T value) { auto n = tensor->GetTypeInfo().GetTensorTypeAndShapeInfo().GetElementCount(); auto p = tensor->GetTensorMutableData(); std::fill(p, p + n, value); } std::vector ReadFile(const std::string &filename); #if __ANDROID_API__ >= 9 std::vector ReadFile(AAssetManager *mgr, const std::string &filename); #endif // TODO(fangjun): Document it Ort::Value Repeat(OrtAllocator *allocator, Ort::Value *cur_encoder_out, const std::vector &hyps_num_split); struct CopyableOrtValue { Ort::Value value{nullptr}; CopyableOrtValue() = default; /*explicit*/ CopyableOrtValue(Ort::Value v) // NOLINT : value(std::move(v)) {} CopyableOrtValue(const CopyableOrtValue &other); CopyableOrtValue &operator=(const CopyableOrtValue &other); CopyableOrtValue(CopyableOrtValue &&other) noexcept; CopyableOrtValue &operator=(CopyableOrtValue &&other) noexcept; }; std::vector Convert(std::vector values); std::vector Convert(std::vector values); } // namespace sherpa_onnx #endif // SHERPA_ONNX_CSRC_ONNX_UTILS_H_