// sherpa-onnx/csrc/stack.cc // // Copyright (c) 2023 Jingzhao Ou (jingzhao.ou@gmail.com) #include "sherpa-onnx/csrc/stack.h" #include #include #include #include #include "sherpa-onnx/csrc/onnx-utils.h" namespace sherpa_onnx { static bool Compare(const std::vector &a, const std::vector &b) { if (a.size() != b.size()) return false; for (int32_t i = 0; i != static_cast(a.size()); ++i) { if (a[i] != b[i]) return false; } return true; } static void PrintShape(const std::vector &a) { for (auto i : a) { fprintf(stderr, "%d ", static_cast(i)); } fprintf(stderr, "\n"); } template Ort::Value Stack(OrtAllocator *allocator, const std::vector &values, int32_t dim) { std::vector v0_shape = values[0]->GetTensorTypeAndShapeInfo().GetShape(); for (int32_t i = 1; i != static_cast(values.size()); ++i) { auto s = values[i]->GetTensorTypeAndShapeInfo().GetShape(); bool ret = Compare(v0_shape, s); if (!ret) { fprintf(stderr, "Incorrect shape in Stack !\n"); fprintf(stderr, "Shape for tensor 0: "); PrintShape(v0_shape); fprintf(stderr, "Shape for tensor %d: ", i); PrintShape(s); exit(-1); } } std::vector ans_shape; ans_shape.reserve(v0_shape.size() + 1); ans_shape.insert(ans_shape.end(), v0_shape.data(), v0_shape.data() + dim); ans_shape.push_back(values.size()); ans_shape.insert(ans_shape.end(), v0_shape.data() + dim, v0_shape.data() + v0_shape.size()); auto leading_size = static_cast(std::accumulate( v0_shape.begin(), v0_shape.begin() + dim, 1, std::multiplies())); auto trailing_size = static_cast(std::accumulate( v0_shape.begin() + dim, v0_shape.end(), 1, std::multiplies())); Ort::Value ans = Ort::Value::CreateTensor(allocator, ans_shape.data(), ans_shape.size()); T *dst = ans.GetTensorMutableData(); for (int32_t i = 0; i != leading_size; ++i) { for (auto value : values) { const T *src = value->GetTensorData(); src += i * trailing_size; std::copy(src, src + trailing_size, dst); dst += trailing_size; } } return ans; } template Ort::Value Stack(OrtAllocator *allocator, const std::vector &values, int32_t dim); template Ort::Value Stack( OrtAllocator *allocator, const std::vector &values, int32_t dim); } // namespace sherpa_onnx