diff --git a/cmake/asio.cmake b/cmake/asio.cmake index 4941b6c1..221e10b6 100644 --- a/cmake/asio.cmake +++ b/cmake/asio.cmake @@ -18,6 +18,7 @@ function(download_asio) foreach(f IN LISTS possible_file_locations) if(EXISTS ${f}) set(asio_URL "file://${f}") + set(asio_URL2) break() endif() endforeach() diff --git a/cmake/googletest.cmake b/cmake/googletest.cmake index 949de081..4d20c258 100644 --- a/cmake/googletest.cmake +++ b/cmake/googletest.cmake @@ -18,6 +18,7 @@ function(download_googltest) foreach(f IN LISTS possible_file_locations) if(EXISTS ${f}) set(googletest_URL "file://${f}") + set(googletest_URL2) break() endif() endforeach() diff --git a/cmake/kaldi-native-fbank.cmake b/cmake/kaldi-native-fbank.cmake index 6c40fce3..f70b9625 100644 --- a/cmake/kaldi-native-fbank.cmake +++ b/cmake/kaldi-native-fbank.cmake @@ -22,6 +22,7 @@ function(download_kaldi_native_fbank) foreach(f IN LISTS possible_file_locations) if(EXISTS ${f}) set(kaldi_native_fbank_URL "file://${f}") + set(kaldi_native_fbank_URL2 ) break() endif() endforeach() diff --git a/cmake/onnxruntime.cmake b/cmake/onnxruntime.cmake index 6143adf3..056f16d9 100644 --- a/cmake/onnxruntime.cmake +++ b/cmake/onnxruntime.cmake @@ -78,6 +78,7 @@ function(download_onnxruntime) foreach(f IN LISTS possible_file_locations) if(EXISTS ${f}) set(onnxruntime_URL "file://${f}") + set(onnxruntime_URL2) break() endif() endforeach() diff --git a/cmake/portaudio.cmake b/cmake/portaudio.cmake index 78414f4d..01cb73a6 100644 --- a/cmake/portaudio.cmake +++ b/cmake/portaudio.cmake @@ -19,6 +19,7 @@ function(download_portaudio) foreach(f IN LISTS possible_file_locations) if(EXISTS ${f}) set(portaudio_URL "file://${f}") + set(portaudio_URL2) break() endif() endforeach() diff --git a/cmake/pybind11.cmake b/cmake/pybind11.cmake index b8aecdc7..b32941de 100644 --- a/cmake/pybind11.cmake +++ b/cmake/pybind11.cmake @@ -18,6 +18,7 @@ function(download_pybind11) foreach(f IN LISTS possible_file_locations) if(EXISTS ${f}) set(pybind11_URL "file://${f}") + set(pybind11_URL2) break() endif() endforeach() diff --git a/cmake/websocketpp.cmake b/cmake/websocketpp.cmake index 2b6c38e6..35eddc90 100644 --- a/cmake/websocketpp.cmake +++ b/cmake/websocketpp.cmake @@ -19,6 +19,7 @@ function(download_websocketpp) foreach(f IN LISTS possible_file_locations) if(EXISTS ${f}) set(websocketpp_URL "file://${f}") + set(websocketpp_URL2) break() endif() endforeach() diff --git a/sherpa-onnx/csrc/CMakeLists.txt b/sherpa-onnx/csrc/CMakeLists.txt index 68eeeb9e..60a56690 100644 --- a/sherpa-onnx/csrc/CMakeLists.txt +++ b/sherpa-onnx/csrc/CMakeLists.txt @@ -20,6 +20,7 @@ set(sources resample.cc symbol-table.cc text-utils.cc + transpose.cc unbind.cc wave-reader.cc ) @@ -120,6 +121,7 @@ endif() if(SHERPA_ONNX_ENABLE_TESTS) set(sherpa_onnx_test_srcs cat-test.cc + transpose-test.cc unbind-test.cc ) diff --git a/sherpa-onnx/csrc/transpose-test.cc b/sherpa-onnx/csrc/transpose-test.cc new file mode 100644 index 00000000..98fd179b --- /dev/null +++ b/sherpa-onnx/csrc/transpose-test.cc @@ -0,0 +1,38 @@ +// sherpa-onnx/csrc/transpose-test.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-onnx/csrc/transpose.h" + +#include + +#include "gtest/gtest.h" +#include "sherpa-onnx/csrc/onnx-utils.h" + +namespace sherpa_onnx { + +TEST(Tranpose, Tranpose01) { + Ort::AllocatorWithDefaultOptions allocator; + std::array shape{3, 2, 5}; + Ort::Value v = + Ort::Value::CreateTensor(allocator, shape.data(), shape.size()); + float *p = v.GetTensorMutableData(); + + std::iota(p, p + shape[0] * shape[1] * shape[2], 0); + + auto ans = Transpose01(allocator, &v); + auto v2 = Transpose01(allocator, &ans); + + Print3D(&v); + Print3D(&ans); + Print3D(&v2); + + const float *q = v2.GetTensorData(); + + for (int32_t i = 0; i != static_cast(shape[0] * shape[1] * shape[2]); + ++i) { + EXPECT_EQ(p[i], q[i]); + } +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/transpose.cc b/sherpa-onnx/csrc/transpose.cc new file mode 100644 index 00000000..09a434de --- /dev/null +++ b/sherpa-onnx/csrc/transpose.cc @@ -0,0 +1,41 @@ +// sherpa-onnx/csrc/transpose.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-onnx/csrc/transpose.h" + +#include + +#include +#include + +namespace sherpa_onnx { + +template +Ort::Value Transpose01(OrtAllocator *allocator, const Ort::Value *v) { + std::vector shape = v->GetTensorTypeAndShapeInfo().GetShape(); + assert(shape.size() == 3); + + std::array ans_shape{shape[1], shape[0], shape[2]}; + Ort::Value ans = Ort::Value::CreateTensor(allocator, ans_shape.data(), + ans_shape.size()); + + T *dst = ans.GetTensorMutableData(); + auto plane_offset = shape[1] * shape[2]; + + for (int64_t i = 0; i != ans_shape[0]; ++i) { + const T *src = v->GetTensorData() + i * shape[2]; + for (int64_t k = 0; k != ans_shape[1]; ++k) { + std::copy(src, src + shape[2], dst); + src += plane_offset; + dst += shape[2]; + } + } + + return ans; +} + +template Ort::Value Transpose01(OrtAllocator *allocator, + const Ort::Value *v); + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/transpose.h b/sherpa-onnx/csrc/transpose.h new file mode 100644 index 00000000..404064a3 --- /dev/null +++ b/sherpa-onnx/csrc/transpose.h @@ -0,0 +1,22 @@ +// sherpa-onnx/csrc/transpose.h +// +// Copyright (c) 2023 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_TRANSPOSE_H_ +#define SHERPA_ONNX_CSRC_TRANSPOSE_H_ + +#include "onnxruntime_cxx_api.h" // NOLINT + +namespace sherpa_onnx { +/** Transpose a 3-D tensor from shape (B, T, C) to (T, B, C). + * + * @param allocator + * @param v A 3-D tensor of shape (B, T, C). Its dataype is T. + * + * @return Return a 3-D tensor of shape (T, B, C). Its datatype is T. + */ +template +Ort::Value Transpose01(OrtAllocator *allocator, const Ort::Value *v); + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_TRANSPOSE_H_