diff --git a/sherpa-onnx/csrc/CMakeLists.txt b/sherpa-onnx/csrc/CMakeLists.txt
index 628cf95e..0f0626e1 100644
--- a/sherpa-onnx/csrc/CMakeLists.txt
+++ b/sherpa-onnx/csrc/CMakeLists.txt
@@ -16,6 +16,7 @@ set(sources
   online-transducer-modified-beam-search-decoder.cc
   online-zipformer-transducer-model.cc
   onnx-utils.cc
+  packed-sequence.cc
   pad-sequence.cc
   parse-options.cc
   resample.cc
@@ -123,6 +124,7 @@ endif()
 if(SHERPA_ONNX_ENABLE_TESTS)
   set(sherpa_onnx_test_srcs
     cat-test.cc
+    packed-sequence-test.cc
     pad-sequence-test.cc
     slice-test.cc
     transpose-test.cc
diff --git a/sherpa-onnx/csrc/packed-sequence-test.cc b/sherpa-onnx/csrc/packed-sequence-test.cc
new file mode 100644
index 00000000..eda38914
--- /dev/null
+++ b/sherpa-onnx/csrc/packed-sequence-test.cc
@@ -0,0 +1,52 @@
+// sherpa-onnx/csrc/packed-sequence-test.cc
+//
+// Copyright (c) 2023 Xiaomi Corporation
+
+#include "sherpa-onnx/csrc/packed-sequence.h"
+
+#include <numeric>
+
+#include "gtest/gtest.h"
+#include "sherpa-onnx/csrc/onnx-utils.h"
+
+namespace sherpa_onnx {
+
+TEST(PackedSequence, Case1) {
+  Ort::AllocatorWithDefaultOptions allocator;
+  std::array<int64_t, 3> shape{5, 5, 4};
+  Ort::Value v =
+      Ort::Value::CreateTensor<float>(allocator, shape.data(), shape.size());
+  float *p = v.GetTensorMutableData<float>();
+
+  std::iota(p, p + shape[0] * shape[1] * shape[2], 0);
+
+  Ort::Value length =
+      Ort::Value::CreateTensor<int64_t>(allocator, shape.data(), 1);
+  int64_t *p_length = length.GetTensorMutableData<int64_t>();
+  p_length[0] = 1;
+  p_length[1] = 2;
+  p_length[2] = 3;
+  p_length[3] = 5;
+  p_length[4] = 2;
+
+  auto packed_seq = PackPaddedSequence(allocator, &v, &length);
+  fprintf(stderr, "sorted indexes: ");
+  for (auto i : packed_seq.sorted_indexes) {
+    fprintf(stderr, "%d ", static_cast<int32_t>(i));
+  }
+  fprintf(stderr, "\n");
+  // output index: 0 1 2 3 4
+  // sorted indexes: 3 2 1 4 0
+  // length: 5 3 2 2 1
+  Print3D(&v);
+  Print2D(&packed_seq.data);
+  fprintf(stderr, "batch sizes per time step: ");
+  for (auto i : packed_seq.batch_sizes) {
+    fprintf(stderr, "%d ", static_cast<int32_t>(i));
+  }
+  fprintf(stderr, "\n");
+
+  // TODO(fangjun): Check that the return value is correct
+}
+
+}  // namespace sherpa_onnx
diff --git a/sherpa-onnx/csrc/packed-sequence.cc b/sherpa-onnx/csrc/packed-sequence.cc
new file mode 100644
index 00000000..09f88018
--- /dev/null
+++ b/sherpa-onnx/csrc/packed-sequence.cc
@@ -0,0 +1,107 @@
+// sherpa-onnx/csrc/packed-sequence.cc
+//
+// Copyright (c) 2023 Xiaomi Corporation
+
+#include "sherpa-onnx/csrc/packed-sequence.h"
+
+#include <assert.h>
+
+#include <algorithm>
+#include <numeric>
+#include <utility>
+
+#include "sherpa-onnx/csrc/slice.h"
+#include "sherpa-onnx/csrc/transpose.h"
+
+namespace sherpa_onnx {
+
+static Ort::Value IndexSelect(OrtAllocator *allocator, const Ort::Value *value,
+                              const std::vector<int32_t> &sorted_indexes) {
+  auto shape = value->GetTensorTypeAndShapeInfo().GetShape();
+  assert(shape.size() == 3);
+  std::array<int64_t, 3> ans_shape{static_cast<int64_t>(sorted_indexes.size()),
+                                   shape[1], shape[2]};
+
+  Ort::Value ans = Ort::Value::CreateTensor<float>(allocator, ans_shape.data(),
+                                                   ans_shape.size());
+  float *dst = ans.GetTensorMutableData<float>();
+  const float *src = value->GetTensorData<float>();
+
+  for (auto i : sorted_indexes) {
+    const float *start = src + i * shape[1] * shape[2];
+    std::copy(start, start + shape[1] * shape[2], dst);
+    dst += shape[1] * shape[2];
+  }
+  return ans;
+}
+
+PackedSequence PackPaddedSequence(OrtAllocator *allocator,
+                                  const Ort::Value *value, Ort::Value *length) {
+  std::vector<int64_t> v_shape = value->GetTensorTypeAndShapeInfo().GetShape();
+  std::vector<int64_t> l_shape = length->GetTensorTypeAndShapeInfo().GetShape();
+
+  assert(v_shape.size() == 3);
+  assert(l_shape.size() == 1);
+  assert(v_shape[0] == l_shape[0]);
+
+  std::vector<int32_t> indexes(v_shape[0]);
+  std::iota(indexes.begin(), indexes.end(), 0);
+
+  const int64_t *p_length = length->GetTensorData<int64_t>();
+  // sort in descending order
+  std::sort(indexes.begin(), indexes.end(), [p_length](int32_t i, int32_t j) {
+    return p_length[i] > p_length[j];
+  });
+
+  int32_t n =
static_cast<int32_t>(v_shape[0]);
+
+  int64_t max_T = p_length[indexes[0]];
+
+  int32_t sum_T = std::accumulate(p_length, p_length + n, 0);
+
+  std::array<int64_t, 2> data_shape{sum_T, v_shape[2]};
+
+  Ort::Value data = Ort::Value::CreateTensor<float>(
+      allocator, data_shape.data(), data_shape.size());
+  float *dst = data.GetTensorMutableData<float>();
+
+  Ort::Value tensor = IndexSelect(allocator, value, indexes);
+  tensor = Transpose01(allocator, &tensor);
+
+  // batch size at each time step
+  std::vector<int32_t> batch_sizes;
+  batch_sizes.reserve(max_T);
+
+  int64_t prev_l = 0;
+  for (int32_t i = 0; i != n; ++i) {
+    auto cur_l = p_length[indexes[n - 1 - i]];
+    assert(cur_l >= prev_l);
+    if (cur_l == prev_l) {
+      continue;
+    }
+
+    auto cur_batch_size = n - i;
+
+    Ort::Value cur_batch =
+        Slice<float>(allocator, &tensor, prev_l, cur_l, 0, cur_batch_size);
+    auto count = cur_batch.GetTensorTypeAndShapeInfo().GetElementCount();
+    const float *src = cur_batch.GetTensorData<float>();
+    std::copy(src, src + count, dst);
+    dst += count;
+
+    for (int32_t j = prev_l; j < cur_l; ++j) {
+      batch_sizes.push_back(cur_batch_size);
+    }
+
+    prev_l = cur_l;
+  }
+
+  PackedSequence packed_seq;
+  packed_seq.sorted_indexes = std::move(indexes);
+  packed_seq.data = std::move(data);
+  packed_seq.batch_sizes = std::move(batch_sizes);
+
+  return packed_seq;
+}
+
+}  // namespace sherpa_onnx
diff --git a/sherpa-onnx/csrc/packed-sequence.h b/sherpa-onnx/csrc/packed-sequence.h
new file mode 100644
index 00000000..d2d125c7
--- /dev/null
+++ b/sherpa-onnx/csrc/packed-sequence.h
@@ -0,0 +1,33 @@
+// sherpa-onnx/csrc/packed-sequence.h
+//
+// Copyright (c) 2023 Xiaomi Corporation
+#ifndef SHERPA_ONNX_CSRC_PACKED_SEQUENCE_H_
+#define SHERPA_ONNX_CSRC_PACKED_SEQUENCE_H_
+
+#include <vector>
+
+#include "onnxruntime_cxx_api.h"  // NOLINT
+
+namespace sherpa_onnx {
+
+struct PackedSequence {
+  std::vector<int32_t> sorted_indexes;
+  std::vector<int32_t> batch_sizes;
+  Ort::Value data{nullptr};
+};
+
+/** Similar to torch.nn.utils.rnn.pack_padded_sequence but it supports only
+ *
batch_first=true.
+ *
+ * @param allocator
+ * @param value A 3-D tensor of shape (B, T, C). Its dtype is float.
+ * @param length A 1-D tensor of shape (B,). Its dtype is int64_t. Each
+ *               element in it specifies the valid length of the corresponding
+ *               entry in value before padding.
+ */
+PackedSequence PackPaddedSequence(OrtAllocator *allocator,
+                                  const Ort::Value *value, Ort::Value *length);
+
+}  // namespace sherpa_onnx
+
+#endif  // SHERPA_ONNX_CSRC_PACKED_SEQUENCE_H_
diff --git a/sherpa-onnx/csrc/slice-test.cc b/sherpa-onnx/csrc/slice-test.cc
index 5cf93091..6f7bde3e 100644
--- a/sherpa-onnx/csrc/slice-test.cc
+++ b/sherpa-onnx/csrc/slice-test.cc
@@ -13,19 +13,19 @@ namespace sherpa_onnx {
 TEST(Slice, Slice3D) {
   Ort::AllocatorWithDefaultOptions allocator;
-  std::array<int64_t, 3> shape{3, 5, 4};
+  std::array<int64_t, 3> shape{5, 5, 4};
   Ort::Value v =
       Ort::Value::CreateTensor<float>(allocator, shape.data(), shape.size());
   float *p = v.GetTensorMutableData<float>();
 
   std::iota(p, p + shape[0] * shape[1] * shape[2], 0);
 
-  auto v1 = Slice<float>(&v, 0, 2, 5);
-  auto v2 = Slice<float>(&v, 1, 2, 4);
+  auto v1 = Slice<float>(allocator, &v, 2, 4, 0, 2);
+  auto v2 = Slice<float>(allocator, &v, 1, 3, 1, 3);
 
   Print3D(&v);
-  Print2D(&v1);
-  Print2D(&v2);
+  Print3D(&v1);
+  Print3D(&v2);
 
   // TODO(fangjun): Check that the results are correct
 }
diff --git a/sherpa-onnx/csrc/slice.cc b/sherpa-onnx/csrc/slice.cc
index 47897eda..189f8517 100644
--- a/sherpa-onnx/csrc/slice.cc
+++ b/sherpa-onnx/csrc/slice.cc
@@ -6,29 +6,48 @@
 
 #include <assert.h>
 
+#include <algorithm>
 #include <vector>
 
 namespace sherpa_onnx {
 
 template <typename T>
-Ort::Value Slice(const Ort::Value *v, int32_t dim0, int32_t dim1_start,
+Ort::Value Slice(OrtAllocator *allocator, const Ort::Value *v,
+                 int32_t dim0_start, int32_t dim0_end, int32_t dim1_start,
                  int32_t dim1_end) {
   std::vector<int64_t> shape = v->GetTensorTypeAndShapeInfo().GetShape();
   assert(shape.size() == 3);
 
-  auto memory_info =
-      Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
+  assert(0 <= dim0_start);
+  assert(dim0_start < dim0_end);
assert(dim0_end <= shape[0]);
+
+  assert(0 <= dim1_start);
+  assert(dim1_start < dim1_end);
+  assert(dim1_end <= shape[1]);
 
-  std::array<int64_t, 2> ans_shape{dim1_end - dim1_start, shape[2]};
 
   const T *src = v->GetTensorData<T>();
-  src += dim0 * shape[1] * shape[2] + dim1_start * shape[2];
 
-  return Ort::Value::CreateTensor(memory_info, const_cast<T *>(src),
-                                  ans_shape[0] * ans_shape[1], ans_shape.data(),
-                                  ans_shape.size());
+  std::array<int64_t, 3> ans_shape{dim0_end - dim0_start, dim1_end - dim1_start,
+                                   shape[2]};
+
+  Ort::Value ans = Ort::Value::CreateTensor<T>(allocator, ans_shape.data(),
+                                               ans_shape.size());
+  T *dst = ans.GetTensorMutableData<T>();
+  for (int32_t i = dim0_start; i != dim0_end; ++i) {
+    const T *src = v->GetTensorData<T>() + i * shape[1] * shape[2];
+    const T *start = src + dim1_start * shape[2];
+    const T *end = src + dim1_end * shape[2];
+
+    std::copy(start, end, dst);
+    dst += ans_shape[1] * ans_shape[2];
+  }
+
+  return ans;
 }
 
-template Ort::Value Slice<float>(const Ort::Value *v, int32_t dim0,
+template Ort::Value Slice<float>(OrtAllocator *allocator, const Ort::Value *v,
+                                 int32_t dim0_start, int32_t dim0_end,
                                  int32_t dim1_start, int32_t dim1_end);
 
 }  // namespace sherpa_onnx
diff --git a/sherpa-onnx/csrc/slice.h b/sherpa-onnx/csrc/slice.h
index 08c0c82c..fb406cae 100644
--- a/sherpa-onnx/csrc/slice.h
+++ b/sherpa-onnx/csrc/slice.h
@@ -8,21 +8,23 @@
 
 namespace sherpa_onnx {
 
-/** Get a shallow copy by slicing v.
+/** Get a deep copy by slicing v.
  *
- * It returns v[dim0, dim1_start:dim1_end]
+ * It returns v[dim0_start:dim0_end, dim1_start:dim1_end]
  *
+ * @param allocator
  * @param v A 3-D tensor. Its data type is T.
- * @param dim0 Start index of the first dimension..
+ * @param dim0_start Start index of the first dimension.
+ * @param dim0_end End index of the first dimension.
  * @param dim1_start Start index of the second dimension.
  * @param dim1_end End index of the second dimension.
 *
- * @return Return a 2-D tensor of shape (dim1_end-dim1_start, v.shape[2])
- *
- * @caution: The returned tensor is a shallow copy of `v`!
+ * @return Return a 3-D tensor of shape
+ *         (dim0_end-dim0_start, dim1_end-dim1_start, v.shape[2])
  */
 template <typename T>
-Ort::Value Slice(const Ort::Value *v, int32_t dim0, int32_t dim1_start,
+Ort::Value Slice(OrtAllocator *allocator, const Ort::Value *v,
+                 int32_t dim0_start, int32_t dim0_end, int32_t dim1_start,
                  int32_t dim1_end);
 
 }  // namespace sherpa_onnx