Add transpose (#82)
This commit is contained in:
@@ -18,6 +18,7 @@ function(download_asio)
|
|||||||
foreach(f IN LISTS possible_file_locations)
|
foreach(f IN LISTS possible_file_locations)
|
||||||
if(EXISTS ${f})
|
if(EXISTS ${f})
|
||||||
set(asio_URL "file://${f}")
|
set(asio_URL "file://${f}")
|
||||||
|
set(asio_URL2)
|
||||||
break()
|
break()
|
||||||
endif()
|
endif()
|
||||||
endforeach()
|
endforeach()
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ function(download_googltest)
|
|||||||
foreach(f IN LISTS possible_file_locations)
|
foreach(f IN LISTS possible_file_locations)
|
||||||
if(EXISTS ${f})
|
if(EXISTS ${f})
|
||||||
set(googletest_URL "file://${f}")
|
set(googletest_URL "file://${f}")
|
||||||
|
set(googletest_URL2)
|
||||||
break()
|
break()
|
||||||
endif()
|
endif()
|
||||||
endforeach()
|
endforeach()
|
||||||
|
|||||||
@@ -22,6 +22,7 @@ function(download_kaldi_native_fbank)
|
|||||||
foreach(f IN LISTS possible_file_locations)
|
foreach(f IN LISTS possible_file_locations)
|
||||||
if(EXISTS ${f})
|
if(EXISTS ${f})
|
||||||
set(kaldi_native_fbank_URL "file://${f}")
|
set(kaldi_native_fbank_URL "file://${f}")
|
||||||
|
set(kaldi_native_fbank_URL2 )
|
||||||
break()
|
break()
|
||||||
endif()
|
endif()
|
||||||
endforeach()
|
endforeach()
|
||||||
|
|||||||
@@ -78,6 +78,7 @@ function(download_onnxruntime)
|
|||||||
foreach(f IN LISTS possible_file_locations)
|
foreach(f IN LISTS possible_file_locations)
|
||||||
if(EXISTS ${f})
|
if(EXISTS ${f})
|
||||||
set(onnxruntime_URL "file://${f}")
|
set(onnxruntime_URL "file://${f}")
|
||||||
|
set(onnxruntime_URL2)
|
||||||
break()
|
break()
|
||||||
endif()
|
endif()
|
||||||
endforeach()
|
endforeach()
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ function(download_portaudio)
|
|||||||
foreach(f IN LISTS possible_file_locations)
|
foreach(f IN LISTS possible_file_locations)
|
||||||
if(EXISTS ${f})
|
if(EXISTS ${f})
|
||||||
set(portaudio_URL "file://${f}")
|
set(portaudio_URL "file://${f}")
|
||||||
|
set(portaudio_URL2)
|
||||||
break()
|
break()
|
||||||
endif()
|
endif()
|
||||||
endforeach()
|
endforeach()
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ function(download_pybind11)
|
|||||||
foreach(f IN LISTS possible_file_locations)
|
foreach(f IN LISTS possible_file_locations)
|
||||||
if(EXISTS ${f})
|
if(EXISTS ${f})
|
||||||
set(pybind11_URL "file://${f}")
|
set(pybind11_URL "file://${f}")
|
||||||
|
set(pybind11_URL2)
|
||||||
break()
|
break()
|
||||||
endif()
|
endif()
|
||||||
endforeach()
|
endforeach()
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ function(download_websocketpp)
|
|||||||
foreach(f IN LISTS possible_file_locations)
|
foreach(f IN LISTS possible_file_locations)
|
||||||
if(EXISTS ${f})
|
if(EXISTS ${f})
|
||||||
set(websocketpp_URL "file://${f}")
|
set(websocketpp_URL "file://${f}")
|
||||||
|
set(websocketpp_URL2)
|
||||||
break()
|
break()
|
||||||
endif()
|
endif()
|
||||||
endforeach()
|
endforeach()
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ set(sources
|
|||||||
resample.cc
|
resample.cc
|
||||||
symbol-table.cc
|
symbol-table.cc
|
||||||
text-utils.cc
|
text-utils.cc
|
||||||
|
transpose.cc
|
||||||
unbind.cc
|
unbind.cc
|
||||||
wave-reader.cc
|
wave-reader.cc
|
||||||
)
|
)
|
||||||
@@ -120,6 +121,7 @@ endif()
|
|||||||
if(SHERPA_ONNX_ENABLE_TESTS)
|
if(SHERPA_ONNX_ENABLE_TESTS)
|
||||||
set(sherpa_onnx_test_srcs
|
set(sherpa_onnx_test_srcs
|
||||||
cat-test.cc
|
cat-test.cc
|
||||||
|
transpose-test.cc
|
||||||
unbind-test.cc
|
unbind-test.cc
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
38
sherpa-onnx/csrc/transpose-test.cc
Normal file
38
sherpa-onnx/csrc/transpose-test.cc
Normal file
@@ -0,0 +1,38 @@
|
|||||||
|
// sherpa-onnx/csrc/transpose-test.cc
|
||||||
|
//
|
||||||
|
// Copyright (c) 2023 Xiaomi Corporation
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/transpose.h"
|
||||||
|
|
||||||
|
#include <numeric>
|
||||||
|
|
||||||
|
#include "gtest/gtest.h"
|
||||||
|
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
TEST(Tranpose, Tranpose01) {
|
||||||
|
Ort::AllocatorWithDefaultOptions allocator;
|
||||||
|
std::array<int64_t, 3> shape{3, 2, 5};
|
||||||
|
Ort::Value v =
|
||||||
|
Ort::Value::CreateTensor<float>(allocator, shape.data(), shape.size());
|
||||||
|
float *p = v.GetTensorMutableData<float>();
|
||||||
|
|
||||||
|
std::iota(p, p + shape[0] * shape[1] * shape[2], 0);
|
||||||
|
|
||||||
|
auto ans = Transpose01(allocator, &v);
|
||||||
|
auto v2 = Transpose01(allocator, &ans);
|
||||||
|
|
||||||
|
Print3D(&v);
|
||||||
|
Print3D(&ans);
|
||||||
|
Print3D(&v2);
|
||||||
|
|
||||||
|
const float *q = v2.GetTensorData<float>();
|
||||||
|
|
||||||
|
for (int32_t i = 0; i != static_cast<int32_t>(shape[0] * shape[1] * shape[2]);
|
||||||
|
++i) {
|
||||||
|
EXPECT_EQ(p[i], q[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace sherpa_onnx
|
||||||
41
sherpa-onnx/csrc/transpose.cc
Normal file
41
sherpa-onnx/csrc/transpose.cc
Normal file
@@ -0,0 +1,41 @@
|
|||||||
|
// sherpa-onnx/csrc/transpose.cc
|
||||||
|
//
|
||||||
|
// Copyright (c) 2023 Xiaomi Corporation
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/transpose.h"
|
||||||
|
|
||||||
|
#include <assert.h>
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
template <typename T /*=float*/>
|
||||||
|
Ort::Value Transpose01(OrtAllocator *allocator, const Ort::Value *v) {
|
||||||
|
std::vector<int64_t> shape = v->GetTensorTypeAndShapeInfo().GetShape();
|
||||||
|
assert(shape.size() == 3);
|
||||||
|
|
||||||
|
std::array<int64_t, 3> ans_shape{shape[1], shape[0], shape[2]};
|
||||||
|
Ort::Value ans = Ort::Value::CreateTensor<float>(allocator, ans_shape.data(),
|
||||||
|
ans_shape.size());
|
||||||
|
|
||||||
|
T *dst = ans.GetTensorMutableData<T>();
|
||||||
|
auto plane_offset = shape[1] * shape[2];
|
||||||
|
|
||||||
|
for (int64_t i = 0; i != ans_shape[0]; ++i) {
|
||||||
|
const T *src = v->GetTensorData<T>() + i * shape[2];
|
||||||
|
for (int64_t k = 0; k != ans_shape[1]; ++k) {
|
||||||
|
std::copy(src, src + shape[2], dst);
|
||||||
|
src += plane_offset;
|
||||||
|
dst += shape[2];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return ans;
|
||||||
|
}
|
||||||
|
|
||||||
|
template Ort::Value Transpose01<float>(OrtAllocator *allocator,
|
||||||
|
const Ort::Value *v);
|
||||||
|
|
||||||
|
} // namespace sherpa_onnx
|
||||||
22
sherpa-onnx/csrc/transpose.h
Normal file
22
sherpa-onnx/csrc/transpose.h
Normal file
@@ -0,0 +1,22 @@
|
|||||||
|
// sherpa-onnx/csrc/transpose.h
|
||||||
|
//
|
||||||
|
// Copyright (c) 2023 Xiaomi Corporation
|
||||||
|
#ifndef SHERPA_ONNX_CSRC_TRANSPOSE_H_
|
||||||
|
#define SHERPA_ONNX_CSRC_TRANSPOSE_H_
|
||||||
|
|
||||||
|
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
/** Transpose a 3-D tensor from shape (B, T, C) to (T, B, C).
|
||||||
|
*
|
||||||
|
* @param allocator
|
||||||
|
* @param v A 3-D tensor of shape (B, T, C). Its dataype is T.
|
||||||
|
*
|
||||||
|
* @return Return a 3-D tensor of shape (T, B, C). Its datatype is T.
|
||||||
|
*/
|
||||||
|
template <typename T = float>
|
||||||
|
Ort::Value Transpose01(OrtAllocator *allocator, const Ort::Value *v);
|
||||||
|
|
||||||
|
} // namespace sherpa_onnx
|
||||||
|
|
||||||
|
#endif // SHERPA_ONNX_CSRC_TRANSPOSE_H_
|
||||||
Reference in New Issue
Block a user