// sherpa-onnx/csrc/unbind.cc // // Copyright (c) 2023 Xiaomi Corporation #include "sherpa-onnx/csrc/unbind.h" #include #include #include #include #include #include #include "onnxruntime_cxx_api.h" // NOLINT #include "sherpa-onnx/csrc/onnx-utils.h" namespace sherpa_onnx { template std::vector Unbind(OrtAllocator *allocator, const Ort::Value *value, int32_t dim) { std::vector shape = value->GetTensorTypeAndShapeInfo().GetShape(); assert(dim >= 0); assert(dim < static_cast(shape.size())); int32_t n = static_cast(shape[dim]); if (n == 1) { std::vector ans; ans.push_back(Clone(allocator, value)); return ans; } std::vector ans_shape = shape; ans_shape[dim] = 1; // // Unlike torch, we keep the dim to 1 // allocator tensors std::vector ans; ans.reserve(n); for (int32_t i = 0; i != n; ++i) { Ort::Value t = Ort::Value::CreateTensor(allocator, ans_shape.data(), ans_shape.size()); ans.push_back(std::move(t)); } auto leading_size = static_cast(std::accumulate( shape.begin(), shape.begin() + dim, 1, std::multiplies())); auto trailing_size = static_cast(std::accumulate( shape.begin() + dim + 1, shape.end(), 1, std::multiplies())); const T *src = value->GetTensorData(); for (int32_t i = 0; i != leading_size; ++i) { for (int32_t k = 0; k != n; ++k) { T *dst = ans[k].GetTensorMutableData() + i * trailing_size; std::copy(src, src + trailing_size, dst); src += trailing_size; } } return ans; } template std::vector Unbind(OrtAllocator *allocator, const Ort::Value *value, int32_t dim); template std::vector Unbind(OrtAllocator *allocator, const Ort::Value *value, int32_t dim); } // namespace sherpa_onnx