diff --git a/sherpa-onnx/csrc/CMakeLists.txt b/sherpa-onnx/csrc/CMakeLists.txt index 60a56690..650ec5ec 100644 --- a/sherpa-onnx/csrc/CMakeLists.txt +++ b/sherpa-onnx/csrc/CMakeLists.txt @@ -18,6 +18,7 @@ set(sources onnx-utils.cc parse-options.cc resample.cc + slice.cc symbol-table.cc text-utils.cc transpose.cc @@ -121,6 +122,7 @@ endif() if(SHERPA_ONNX_ENABLE_TESTS) set(sherpa_onnx_test_srcs cat-test.cc + slice-test.cc transpose-test.cc unbind-test.cc ) diff --git a/sherpa-onnx/csrc/onnx-utils.cc b/sherpa-onnx/csrc/onnx-utils.cc index 8357bc31..8b0cf34e 100644 --- a/sherpa-onnx/csrc/onnx-utils.cc +++ b/sherpa-onnx/csrc/onnx-utils.cc @@ -57,9 +57,6 @@ Ort::Value GetEncoderOutFrame(OrtAllocator *allocator, Ort::Value *encoder_out, auto offset = num_frames * encoder_out_dim; - auto memory_info = - Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); - std::array shape{batch_size, encoder_out_dim}; Ort::Value ans = @@ -90,9 +87,6 @@ Ort::Value Clone(OrtAllocator *allocator, const Ort::Value *v) { auto type_and_shape = v->GetTensorTypeAndShapeInfo(); std::vector shape = type_and_shape.GetShape(); - auto memory_info = - Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); - switch (type_and_shape.GetElementType()) { case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: { Ort::Value ans = Ort::Value::CreateTensor( diff --git a/sherpa-onnx/csrc/slice-test.cc b/sherpa-onnx/csrc/slice-test.cc new file mode 100644 index 00000000..5cf93091 --- /dev/null +++ b/sherpa-onnx/csrc/slice-test.cc @@ -0,0 +1,33 @@ +// sherpa-onnx/csrc/slice-test.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-onnx/csrc/slice.h" + +#include + +#include "gtest/gtest.h" +#include "sherpa-onnx/csrc/onnx-utils.h" + +namespace sherpa_onnx { + +TEST(Slice, Slice3D) { + Ort::AllocatorWithDefaultOptions allocator; + std::array shape{3, 5, 4}; + Ort::Value v = + Ort::Value::CreateTensor(allocator, shape.data(), shape.size()); + float *p = v.GetTensorMutableData(); + + std::iota(p, p + shape[0] * shape[1] * shape[2], 0); + + auto v1 = Slice(&v, 0, 2, 5); + auto v2 = Slice(&v, 1, 2, 4); + + Print3D(&v); + Print2D(&v1); + Print2D(&v2); + + // TODO(fangjun): Check that the results are correct +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/slice.cc b/sherpa-onnx/csrc/slice.cc new file mode 100644 index 00000000..47897eda --- /dev/null +++ b/sherpa-onnx/csrc/slice.cc @@ -0,0 +1,34 @@ +// sherpa-onnx/csrc/slice.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-onnx/csrc/slice.h" + +#include + +#include + +namespace sherpa_onnx { + +template +Ort::Value Slice(const Ort::Value *v, int32_t dim0, int32_t dim1_start, + int32_t dim1_end) { + std::vector shape = v->GetTensorTypeAndShapeInfo().GetShape(); + assert(shape.size() == 3); + + auto memory_info = + Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); + + std::array ans_shape{dim1_end - dim1_start, shape[2]}; + const T *src = v->GetTensorData(); + src += dim0 * shape[1] * shape[2] + dim1_start * shape[2]; + + return Ort::Value::CreateTensor(memory_info, const_cast(src), + ans_shape[0] * ans_shape[1], ans_shape.data(), + ans_shape.size()); +} + +template Ort::Value Slice(const Ort::Value *v, int32_t dim0, + int32_t dim1_start, int32_t dim1_end); + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/slice.h b/sherpa-onnx/csrc/slice.h new file mode 100644 index 00000000..08c0c82c --- /dev/null +++ b/sherpa-onnx/csrc/slice.h @@ -0,0 +1,29 @@ +// sherpa-onnx/csrc/slice.h +// +// Copyright (c) 2023 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_SLICE_H_ +#define SHERPA_ONNX_CSRC_SLICE_H_ + +#include "onnxruntime_cxx_api.h" // NOLINT + +namespace sherpa_onnx { + +/** Get a shallow copy by slicing v. + * + * It returns v[dim0, dim1_start:dim1_end] + * + * @param v A 3-D tensor. Its data type is T. + * @param dim0 Start index of the first dimension.. + * @param dim1_start Start index of the second dimension. + * @param dim1_end End index of the second dimension. + * + * @return Return a 2-D tensor of shape (dim1_end-dim1_start, v.shape[2]) + * + * @caution: The returned tensor is a shallow copy of `v`! + */ +template +Ort::Value Slice(const Ort::Value *v, int32_t dim0, int32_t dim1_start, + int32_t dim1_end); +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_SLICE_H_