Add transpose (#82)
This commit is contained in:
@@ -18,6 +18,7 @@ function(download_asio)
|
||||
foreach(f IN LISTS possible_file_locations)
|
||||
if(EXISTS ${f})
|
||||
set(asio_URL "file://${f}")
|
||||
set(asio_URL2)
|
||||
break()
|
||||
endif()
|
||||
endforeach()
|
||||
|
||||
@@ -18,6 +18,7 @@ function(download_googltest)
|
||||
foreach(f IN LISTS possible_file_locations)
|
||||
if(EXISTS ${f})
|
||||
set(googletest_URL "file://${f}")
|
||||
set(googletest_URL2)
|
||||
break()
|
||||
endif()
|
||||
endforeach()
|
||||
|
||||
@@ -22,6 +22,7 @@ function(download_kaldi_native_fbank)
|
||||
foreach(f IN LISTS possible_file_locations)
|
||||
if(EXISTS ${f})
|
||||
set(kaldi_native_fbank_URL "file://${f}")
|
||||
set(kaldi_native_fbank_URL2 )
|
||||
break()
|
||||
endif()
|
||||
endforeach()
|
||||
|
||||
@@ -78,6 +78,7 @@ function(download_onnxruntime)
|
||||
foreach(f IN LISTS possible_file_locations)
|
||||
if(EXISTS ${f})
|
||||
set(onnxruntime_URL "file://${f}")
|
||||
set(onnxruntime_URL2)
|
||||
break()
|
||||
endif()
|
||||
endforeach()
|
||||
|
||||
@@ -19,6 +19,7 @@ function(download_portaudio)
|
||||
foreach(f IN LISTS possible_file_locations)
|
||||
if(EXISTS ${f})
|
||||
set(portaudio_URL "file://${f}")
|
||||
set(portaudio_URL2)
|
||||
break()
|
||||
endif()
|
||||
endforeach()
|
||||
|
||||
@@ -18,6 +18,7 @@ function(download_pybind11)
|
||||
foreach(f IN LISTS possible_file_locations)
|
||||
if(EXISTS ${f})
|
||||
set(pybind11_URL "file://${f}")
|
||||
set(pybind11_URL2)
|
||||
break()
|
||||
endif()
|
||||
endforeach()
|
||||
|
||||
@@ -19,6 +19,7 @@ function(download_websocketpp)
|
||||
foreach(f IN LISTS possible_file_locations)
|
||||
if(EXISTS ${f})
|
||||
set(websocketpp_URL "file://${f}")
|
||||
set(websocketpp_URL2)
|
||||
break()
|
||||
endif()
|
||||
endforeach()
|
||||
|
||||
@@ -20,6 +20,7 @@ set(sources
|
||||
resample.cc
|
||||
symbol-table.cc
|
||||
text-utils.cc
|
||||
transpose.cc
|
||||
unbind.cc
|
||||
wave-reader.cc
|
||||
)
|
||||
@@ -120,6 +121,7 @@ endif()
|
||||
if(SHERPA_ONNX_ENABLE_TESTS)
|
||||
set(sherpa_onnx_test_srcs
|
||||
cat-test.cc
|
||||
transpose-test.cc
|
||||
unbind-test.cc
|
||||
)
|
||||
|
||||
|
||||
38
sherpa-onnx/csrc/transpose-test.cc
Normal file
38
sherpa-onnx/csrc/transpose-test.cc
Normal file
@@ -0,0 +1,38 @@
|
||||
// sherpa-onnx/csrc/transpose-test.cc
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
|
||||
#include "sherpa-onnx/csrc/transpose.h"
|
||||
|
||||
#include <numeric>
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
TEST(Tranpose, Tranpose01) {
|
||||
Ort::AllocatorWithDefaultOptions allocator;
|
||||
std::array<int64_t, 3> shape{3, 2, 5};
|
||||
Ort::Value v =
|
||||
Ort::Value::CreateTensor<float>(allocator, shape.data(), shape.size());
|
||||
float *p = v.GetTensorMutableData<float>();
|
||||
|
||||
std::iota(p, p + shape[0] * shape[1] * shape[2], 0);
|
||||
|
||||
auto ans = Transpose01(allocator, &v);
|
||||
auto v2 = Transpose01(allocator, &ans);
|
||||
|
||||
Print3D(&v);
|
||||
Print3D(&ans);
|
||||
Print3D(&v2);
|
||||
|
||||
const float *q = v2.GetTensorData<float>();
|
||||
|
||||
for (int32_t i = 0; i != static_cast<int32_t>(shape[0] * shape[1] * shape[2]);
|
||||
++i) {
|
||||
EXPECT_EQ(p[i], q[i]);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
41
sherpa-onnx/csrc/transpose.cc
Normal file
41
sherpa-onnx/csrc/transpose.cc
Normal file
@@ -0,0 +1,41 @@
|
||||
// sherpa-onnx/csrc/transpose.cc
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
|
||||
#include "sherpa-onnx/csrc/transpose.h"
|
||||
|
||||
#include <assert.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <vector>
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
template <typename T /*=float*/>
|
||||
Ort::Value Transpose01(OrtAllocator *allocator, const Ort::Value *v) {
|
||||
std::vector<int64_t> shape = v->GetTensorTypeAndShapeInfo().GetShape();
|
||||
assert(shape.size() == 3);
|
||||
|
||||
std::array<int64_t, 3> ans_shape{shape[1], shape[0], shape[2]};
|
||||
Ort::Value ans = Ort::Value::CreateTensor<float>(allocator, ans_shape.data(),
|
||||
ans_shape.size());
|
||||
|
||||
T *dst = ans.GetTensorMutableData<T>();
|
||||
auto plane_offset = shape[1] * shape[2];
|
||||
|
||||
for (int64_t i = 0; i != ans_shape[0]; ++i) {
|
||||
const T *src = v->GetTensorData<T>() + i * shape[2];
|
||||
for (int64_t k = 0; k != ans_shape[1]; ++k) {
|
||||
std::copy(src, src + shape[2], dst);
|
||||
src += plane_offset;
|
||||
dst += shape[2];
|
||||
}
|
||||
}
|
||||
|
||||
return ans;
|
||||
}
|
||||
|
||||
template Ort::Value Transpose01<float>(OrtAllocator *allocator,
|
||||
const Ort::Value *v);
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
22
sherpa-onnx/csrc/transpose.h
Normal file
22
sherpa-onnx/csrc/transpose.h
Normal file
@@ -0,0 +1,22 @@
|
||||
// sherpa-onnx/csrc/transpose.h
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
#ifndef SHERPA_ONNX_CSRC_TRANSPOSE_H_
|
||||
#define SHERPA_ONNX_CSRC_TRANSPOSE_H_
|
||||
|
||||
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||
|
||||
namespace sherpa_onnx {
|
||||
/** Transpose a 3-D tensor from shape (B, T, C) to (T, B, C).
|
||||
*
|
||||
* @param allocator
|
||||
* @param v A 3-D tensor of shape (B, T, C). Its dataype is T.
|
||||
*
|
||||
* @return Return a 3-D tensor of shape (T, B, C). Its datatype is T.
|
||||
*/
|
||||
template <typename T = float>
|
||||
Ort::Value Transpose01(OrtAllocator *allocator, const Ort::Value *v);
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_TRANSPOSE_H_
|
||||
Reference in New Issue
Block a user