diff --git a/sherpa-onnx/csrc/CMakeLists.txt b/sherpa-onnx/csrc/CMakeLists.txt index 650ec5ec..628cf95e 100644 --- a/sherpa-onnx/csrc/CMakeLists.txt +++ b/sherpa-onnx/csrc/CMakeLists.txt @@ -16,6 +16,7 @@ set(sources online-transducer-modified-beam-search-decoder.cc online-zipformer-transducer-model.cc onnx-utils.cc + pad-sequence.cc parse-options.cc resample.cc slice.cc @@ -122,6 +123,7 @@ endif() if(SHERPA_ONNX_ENABLE_TESTS) set(sherpa_onnx_test_srcs cat-test.cc + pad-sequence-test.cc slice-test.cc transpose-test.cc unbind-test.cc diff --git a/sherpa-onnx/csrc/pad-sequence-test.cc b/sherpa-onnx/csrc/pad-sequence-test.cc new file mode 100644 index 00000000..e6bd259c --- /dev/null +++ b/sherpa-onnx/csrc/pad-sequence-test.cc @@ -0,0 +1,43 @@ +// sherpa-onnx/csrc/pad-sequence-test.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-onnx/csrc/pad-sequence.h" + +#include + +#include "gtest/gtest.h" +#include "sherpa-onnx/csrc/onnx-utils.h" + +namespace sherpa_onnx { + +TEST(PadSequence, ThreeTensors) { + Ort::AllocatorWithDefaultOptions allocator; + + std::array shape1{3, 5}; + Ort::Value v1 = + Ort::Value::CreateTensor(allocator, shape1.data(), shape1.size()); + float *p1 = v1.GetTensorMutableData(); + std::iota(p1, p1 + shape1[0] * shape1[1], 0); + + std::array shape2{4, 5}; + Ort::Value v2 = + Ort::Value::CreateTensor(allocator, shape2.data(), shape2.size()); + float *p2 = v2.GetTensorMutableData(); + std::iota(p2, p2 + shape2[0] * shape2[1], 0); + + std::array shape3{2, 5}; + Ort::Value v3 = + Ort::Value::CreateTensor(allocator, shape3.data(), shape3.size()); + float *p3 = v3.GetTensorMutableData(); + std::iota(p3, p3 + shape3[0] * shape3[1], 0); + + auto ans = PadSequence(allocator, {&v1, &v2, &v3}, -1); + + Print2D(&v1); + Print2D(&v2); + Print2D(&v3); + Print3D(&ans); +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/pad-sequence.cc b/sherpa-onnx/csrc/pad-sequence.cc new file mode 100644 index 00000000..d9f8ebf9 --- /dev/null +++ b/sherpa-onnx/csrc/pad-sequence.cc @@ -0,0 +1,53 @@ +// sherpa-onnx/csrc/pad-sequence.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-onnx/csrc/pad-sequence.h" + +#include + +#include +#include + +namespace sherpa_onnx { + +Ort::Value PadSequence(OrtAllocator *allocator, + const std::vector &values, + float padding_value) { + int32_t batch_size = static_cast(values.size()); + + std::vector shape0 = + values[0]->GetTensorTypeAndShapeInfo().GetShape(); + assert(shape0.size() == 2); + + auto feature_dim = shape0[1]; + auto max_T = shape0[0]; + + for (int32_t i = 1; i != batch_size; ++i) { + auto shape = values[i]->GetTensorTypeAndShapeInfo().GetShape(); + + assert(shape.size() == 2); + assert(shape[1] == feature_dim); + + max_T = std::max(max_T, shape[0]); + } + std::array ans_shape{batch_size, max_T, feature_dim}; + + Ort::Value ans = Ort::Value::CreateTensor(allocator, ans_shape.data(), + ans_shape.size()); + float *dst = ans.GetTensorMutableData(); + std::fill(dst, dst + batch_size * max_T * feature_dim, padding_value); + + for (const auto *v : values) { + const float *src = v->GetTensorData(); + auto shape = v->GetTensorTypeAndShapeInfo().GetShape(); + std::copy(src, src + shape[0] * shape[1], dst); + dst += max_T * feature_dim; + } + + return ans; + + // TODO(fangjun): Check that the returned value is correct. +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/pad-sequence.h b/sherpa-onnx/csrc/pad-sequence.h new file mode 100644 index 00000000..d44370ca --- /dev/null +++ b/sherpa-onnx/csrc/pad-sequence.h @@ -0,0 +1,31 @@ +// sherpa-onnx/csrc/pad-sequence.h +// +// Copyright (c) 2023 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_PAD_SEQUENCE_H_ +#define SHERPA_ONNX_CSRC_PAD_SEQUENCE_H_ + +#include + +#include "onnxruntime_cxx_api.h" // NOLINT + +namespace sherpa_onnx { + +/** Similar to torch.nn.utils.rnn.pad_sequence but it supports only + * batch_first=true. + * + * @param allocator + * @param values A list of 2-D tensors. Each tensor's second dimension + * must be the same and the data type of each tensor should + * be float. + * @param padding_value Value used for padding. For log-fbank, you usually use + * -23.025850929940457f as the padding value. + * + * @return Return a 3-D tensor of shape (B, max_T, C). + */ +Ort::Value PadSequence(OrtAllocator *allocator, + const std::vector &values, + float padding_value); + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_PAD_SEQUENCE_H_