// sherpa-onnx/csrc/transpose-test.cc // // Copyright (c) 2023 Xiaomi Corporation #include "sherpa-onnx/csrc/transpose.h" #include #include "gtest/gtest.h" #include "sherpa-onnx/csrc/onnx-utils.h" namespace sherpa_onnx { TEST(Tranpose, Tranpose01) { Ort::AllocatorWithDefaultOptions allocator; std::array shape{3, 2, 5}; Ort::Value v = Ort::Value::CreateTensor(allocator, shape.data(), shape.size()); float *p = v.GetTensorMutableData(); std::iota(p, p + shape[0] * shape[1] * shape[2], 0); auto ans = Transpose01(allocator, &v); auto v2 = Transpose01(allocator, &ans); Print3D(&v); Print3D(&ans); Print3D(&v2); const float *q = v2.GetTensorData(); for (int32_t i = 0; i != static_cast(shape[0] * shape[1] * shape[2]); ++i) { EXPECT_EQ(p[i], q[i]); } } TEST(Tranpose, Tranpose12) { Ort::AllocatorWithDefaultOptions allocator; std::array shape{3, 2, 5}; Ort::Value v = Ort::Value::CreateTensor(allocator, shape.data(), shape.size()); float *p = v.GetTensorMutableData(); std::iota(p, p + shape[0] * shape[1] * shape[2], 0); auto ans = Transpose12(allocator, &v); auto v2 = Transpose12(allocator, &ans); Print3D(&v); Print3D(&ans); Print3D(&v2); const float *q = v2.GetTensorData(); for (int32_t i = 0; i != static_cast(shape[0] * shape[1] * shape[2]); ++i) { EXPECT_EQ(p[i], q[i]); } } } // namespace sherpa_onnx