From 3ea6aa949d9b48da546fb59d0f4956c0520c0483 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Tue, 21 Feb 2023 20:00:03 +0800 Subject: [PATCH] Add Streaming zipformer (#50) --- .github/workflows/run-python-test.yaml | 7 +- .gitignore | 1 + CMakeLists.txt | 1 + cmake/googletest.cmake | 16 +- sherpa-onnx/csrc/CMakeLists.txt | 33 ++ sherpa-onnx/csrc/cat-test.cc | 254 ++++++++++ sherpa-onnx/csrc/cat.cc | 106 ++++ sherpa-onnx/csrc/cat.h | 28 ++ sherpa-onnx/csrc/macros.h | 41 ++ .../csrc/online-lstm-transducer-model.cc | 105 ++-- sherpa-onnx/csrc/online-transducer-model.cc | 6 + .../csrc/online-zipformer-transducer-model.cc | 456 ++++++++++++++++++ .../csrc/online-zipformer-transducer-model.h | 99 ++++ sherpa-onnx/csrc/onnx-utils.cc | 66 ++- sherpa-onnx/csrc/onnx-utils.h | 18 +- sherpa-onnx/csrc/text-utils.cc | 30 ++ sherpa-onnx/csrc/text-utils.h | 85 ++++ sherpa-onnx/csrc/unbind-test.cc | 223 +++++++++ sherpa-onnx/csrc/unbind.cc | 72 +++ sherpa-onnx/csrc/unbind.h | 28 ++ 20 files changed, 1576 insertions(+), 99 deletions(-) create mode 100644 sherpa-onnx/csrc/cat-test.cc create mode 100644 sherpa-onnx/csrc/cat.cc create mode 100644 sherpa-onnx/csrc/cat.h create mode 100644 sherpa-onnx/csrc/macros.h create mode 100644 sherpa-onnx/csrc/online-zipformer-transducer-model.cc create mode 100644 sherpa-onnx/csrc/online-zipformer-transducer-model.h create mode 100644 sherpa-onnx/csrc/text-utils.cc create mode 100644 sherpa-onnx/csrc/text-utils.h create mode 100644 sherpa-onnx/csrc/unbind-test.cc create mode 100644 sherpa-onnx/csrc/unbind.cc create mode 100644 sherpa-onnx/csrc/unbind.h diff --git a/.github/workflows/run-python-test.yaml b/.github/workflows/run-python-test.yaml index 07fd6f14..223db1f8 100644 --- a/.github/workflows/run-python-test.yaml +++ b/.github/workflows/run-python-test.yaml @@ -33,8 +33,13 @@ jobs: strategy: fail-fast: false matrix: - os: [ubuntu-latest, macos-latest, windows-latest] + os: [ubuntu-latest, macos-latest] # windows-latest] python-version: 
["3.7", "3.8", "3.9", "3.10"] + exclude: + - os: macos-latest + python-version: "3.9" + - os: macos-latest + python-version: "3.10" steps: - uses: actions/checkout@v2 diff --git a/.gitignore b/.gitignore index 0c1f4139..b898c51c 100644 --- a/.gitignore +++ b/.gitignore @@ -8,3 +8,4 @@ sherpa-onnx-* __pycache__ dist/ sherpa_onnx.egg-info/ +.DS_Store diff --git a/CMakeLists.txt b/CMakeLists.txt index 27e43708..4e2e16a8 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -62,6 +62,7 @@ endif() if(SHERPA_ONNX_ENABLE_TESTS) enable_testing() + include(googletest) endif() add_subdirectory(sherpa-onnx) diff --git a/cmake/googletest.cmake b/cmake/googletest.cmake index 05b26927..cfd89aef 100644 --- a/cmake/googletest.cmake +++ b/cmake/googletest.cmake @@ -1,18 +1,3 @@ -# Copyright 2020 Fangjun Kuang (csukuangfj@gmail.com) -# See ../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- function(download_googltest) include(FetchContent) @@ -26,6 +11,7 @@ function(download_googltest) ${PROJECT_SOURCE_DIR}/googletest-1.13.0.tar.gz ${PROJECT_BINARY_DIR}/googletest-1.13.0.tar.gz /tmp/googletest-1.13.0.tar.gz + /star-fj/fangjun/download/github/googletest-1.13.0.tar.gz ) foreach(f IN LISTS possible_file_locations) diff --git a/sherpa-onnx/csrc/CMakeLists.txt b/sherpa-onnx/csrc/CMakeLists.txt index 45c4eef2..8895fab4 100644 --- a/sherpa-onnx/csrc/CMakeLists.txt +++ b/sherpa-onnx/csrc/CMakeLists.txt @@ -1,6 +1,7 @@ include_directories(${CMAKE_SOURCE_DIR}) add_library(sherpa-onnx-core + cat.cc features.cc online-lstm-transducer-model.cc online-recognizer.cc @@ -8,8 +9,11 @@ add_library(sherpa-onnx-core online-transducer-greedy-search-decoder.cc online-transducer-model-config.cc online-transducer-model.cc + online-zipformer-transducer-model.cc onnx-utils.cc symbol-table.cc + text-utils.cc + unbind.cc wave-reader.cc ) @@ -27,3 +31,32 @@ endif() install(TARGETS sherpa-onnx-core DESTINATION lib) install(TARGETS sherpa-onnx DESTINATION bin) + +if(SHERPA_ONNX_ENABLE_TESTS) + set(sherpa_onnx_test_srcs + cat-test.cc + unbind-test.cc + ) + + function(sherpa_onnx_add_test source) + get_filename_component(name ${source} NAME_WE) + set(target_name ${name}) + add_executable(${target_name} "${source}") + + target_link_libraries(${target_name} + PRIVATE + gtest + gtest_main + sherpa-onnx-core + ) + + add_test(NAME "${target_name}" + COMMAND + $ + ) + endfunction() + + foreach(source IN LISTS sherpa_onnx_test_srcs) + sherpa_onnx_add_test(${source}) + endforeach() +endif() diff --git a/sherpa-onnx/csrc/cat-test.cc b/sherpa-onnx/csrc/cat-test.cc new file mode 100644 index 00000000..a49477af --- /dev/null +++ b/sherpa-onnx/csrc/cat-test.cc @@ -0,0 +1,254 @@ +// sherpa-onnx/csrc/cat-test.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-onnx/csrc/cat.h" + +#include "gtest/gtest.h" +#include "sherpa-onnx/csrc/onnx-utils.h" + +namespace sherpa_onnx { + 
+TEST(Cat, Test1DTensors) { + Ort::AllocatorWithDefaultOptions allocator; + + std::array a_shape{3}; + std::array b_shape{6}; + + Ort::Value a = Ort::Value::CreateTensor(allocator, a_shape.data(), + a_shape.size()); + + Ort::Value b = Ort::Value::CreateTensor(allocator, b_shape.data(), + b_shape.size()); + float *pa = a.GetTensorMutableData(); + float *pb = b.GetTensorMutableData(); + for (int32_t i = 0; i != static_cast(a_shape[0]); ++i) { + pa[i] = i; + } + for (int32_t i = 0; i != static_cast(b_shape[0]); ++i) { + pb[i] = i + 10; + } + + Ort::Value ans = Cat(allocator, {&a, &b}, 0); + + const float *pans = ans.GetTensorData(); + for (int32_t i = 0; i != static_cast(a_shape[0]); ++i) { + EXPECT_EQ(pa[i], pans[i]); + } + + for (int32_t i = 0; i != static_cast(b_shape[0]); ++i) { + EXPECT_EQ(pb[i], pans[i + a_shape[0]]); + } + + Print1D(&a); + Print1D(&b); + Print1D(&ans); +} + +TEST(Cat, Test2DTensorsDim0) { + Ort::AllocatorWithDefaultOptions allocator; + + std::array a_shape{2, 3}; + std::array b_shape{4, 3}; + + Ort::Value a = Ort::Value::CreateTensor(allocator, a_shape.data(), + a_shape.size()); + + Ort::Value b = Ort::Value::CreateTensor(allocator, b_shape.data(), + b_shape.size()); + + float *pa = a.GetTensorMutableData(); + float *pb = b.GetTensorMutableData(); + for (int32_t i = 0; i != static_cast(a_shape[0] * a_shape[1]); ++i) { + pa[i] = i; + } + for (int32_t i = 0; i != static_cast(b_shape[0] * b_shape[1]); ++i) { + pb[i] = i + 10; + } + + Ort::Value ans = Cat(allocator, {&a, &b}, 0); + + const float *pans = ans.GetTensorData(); + for (int32_t i = 0; i != static_cast(a_shape[0] * a_shape[1]); ++i) { + EXPECT_EQ(pa[i], pans[i]); + } + for (int32_t i = 0; i != static_cast(b_shape[0] * b_shape[1]); ++i) { + EXPECT_EQ(pb[i], pans[i + a_shape[0] * a_shape[1]]); + } + + Print2D(&a); + Print2D(&b); + Print2D(&ans); +} + +TEST(Cat, Test2DTensorsDim1) { + Ort::AllocatorWithDefaultOptions allocator; + + std::array a_shape{4, 3}; + std::array b_shape{4, 2}; + + 
Ort::Value a = Ort::Value::CreateTensor(allocator, a_shape.data(), + a_shape.size()); + + Ort::Value b = Ort::Value::CreateTensor(allocator, b_shape.data(), + b_shape.size()); + + float *pa = a.GetTensorMutableData(); + float *pb = b.GetTensorMutableData(); + for (int32_t i = 0; i != static_cast(a_shape[0] * a_shape[1]); ++i) { + pa[i] = i; + } + for (int32_t i = 0; i != static_cast(b_shape[0] * b_shape[1]); ++i) { + pb[i] = i + 10; + } + + Ort::Value ans = Cat(allocator, {&a, &b}, 1); + + const float *pans = ans.GetTensorData(); + + for (int32_t r = 0; r != static_cast(a_shape[0]); ++r) { + for (int32_t i = 0; i != static_cast(a_shape[1]); + ++i, ++pa, ++pans) { + EXPECT_EQ(*pa, *pans); + } + + for (int32_t i = 0; i != static_cast(b_shape[1]); + ++i, ++pb, ++pans) { + EXPECT_EQ(*pb, *pans); + } + } + + Print2D(&a); + Print2D(&b); + Print2D(&ans); +} + +TEST(Cat, Test3DTensorsDim0) { + Ort::AllocatorWithDefaultOptions allocator; + + std::array a_shape{2, 3, 2}; + std::array b_shape{4, 3, 2}; + + Ort::Value a = Ort::Value::CreateTensor(allocator, a_shape.data(), + a_shape.size()); + + Ort::Value b = Ort::Value::CreateTensor(allocator, b_shape.data(), + b_shape.size()); + + float *pa = a.GetTensorMutableData(); + float *pb = b.GetTensorMutableData(); + for (int32_t i = 0; + i != static_cast(a_shape[0] * a_shape[1] * a_shape[2]); ++i) { + pa[i] = i; + } + for (int32_t i = 0; + i != static_cast(b_shape[0] * b_shape[1] * b_shape[2]); ++i) { + pb[i] = i + 10; + } + + Ort::Value ans = Cat(allocator, {&a, &b}, 0); + + const float *pans = ans.GetTensorData(); + for (int32_t i = 0; + i != static_cast(a_shape[0] * a_shape[1] * a_shape[2]); ++i) { + EXPECT_EQ(pa[i], pans[i]); + } + for (int32_t i = 0; + i != static_cast(b_shape[0] * b_shape[1] * b_shape[2]); ++i) { + EXPECT_EQ(pb[i], pans[i + a_shape[0] * a_shape[1] * a_shape[2]]); + } + + Print3D(&a); + Print3D(&b); + Print3D(&ans); +} + +TEST(Cat, Test3DTensorsDim1) { + Ort::AllocatorWithDefaultOptions allocator; + + 
std::array a_shape{2, 2, 3}; + std::array b_shape{2, 4, 3}; + + Ort::Value a = Ort::Value::CreateTensor(allocator, a_shape.data(), + a_shape.size()); + + Ort::Value b = Ort::Value::CreateTensor(allocator, b_shape.data(), + b_shape.size()); + + float *pa = a.GetTensorMutableData(); + float *pb = b.GetTensorMutableData(); + for (int32_t i = 0; + i != static_cast(a_shape[0] * a_shape[1] * a_shape[2]); ++i) { + pa[i] = i; + } + for (int32_t i = 0; + i != static_cast(b_shape[0] * b_shape[1] * b_shape[2]); ++i) { + pb[i] = i + 10; + } + + Ort::Value ans = Cat(allocator, {&a, &b}, 1); + + const float *pans = ans.GetTensorData(); + + for (int32_t i = 0; i != static_cast(a_shape[0]); ++i) { + for (int32_t k = 0; k != static_cast(a_shape[1] * a_shape[2]); + ++k, ++pa, ++pans) { + EXPECT_EQ(*pa, *pans); + } + + for (int32_t k = 0; k != static_cast(b_shape[1] * b_shape[2]); + ++k, ++pb, ++pans) { + EXPECT_EQ(*pb, *pans); + } + } + + Print3D(&a); + Print3D(&b); + Print3D(&ans); +} + +TEST(Cat, Test3DTensorsDim2) { + Ort::AllocatorWithDefaultOptions allocator; + + std::array a_shape{2, 3, 4}; + std::array b_shape{2, 3, 5}; + + Ort::Value a = Ort::Value::CreateTensor(allocator, a_shape.data(), + a_shape.size()); + + Ort::Value b = Ort::Value::CreateTensor(allocator, b_shape.data(), + b_shape.size()); + + float *pa = a.GetTensorMutableData(); + float *pb = b.GetTensorMutableData(); + for (int32_t i = 0; + i != static_cast(a_shape[0] * a_shape[1] * a_shape[2]); ++i) { + pa[i] = i; + } + for (int32_t i = 0; + i != static_cast(b_shape[0] * b_shape[1] * b_shape[2]); ++i) { + pb[i] = i + 10; + } + + Ort::Value ans = Cat(allocator, {&a, &b}, 2); + + const float *pans = ans.GetTensorData(); + + for (int32_t i = 0; i != static_cast(a_shape[0] * a_shape[1]); ++i) { + for (int32_t k = 0; k != static_cast(a_shape[2]); + ++k, ++pa, ++pans) { + EXPECT_EQ(*pa, *pans); + } + + for (int32_t k = 0; k != static_cast(b_shape[2]); + ++k, ++pb, ++pans) { + EXPECT_EQ(*pb, *pans); + } + } + + 
Print3D(&a); + Print3D(&b); + Print3D(&ans); +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/cat.cc b/sherpa-onnx/csrc/cat.cc new file mode 100644 index 00000000..f1938193 --- /dev/null +++ b/sherpa-onnx/csrc/cat.cc @@ -0,0 +1,106 @@ +// sherpa-onnx/csrc/cat.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-onnx/csrc/cat.h" + +#include +#include +#include +#include + +#include "sherpa-onnx/csrc/onnx-utils.h" + +namespace sherpa_onnx { + +static bool Compare(const std::vector &a, + const std::vector &b, int32_t skip_dim) { + if (a.size() != b.size()) return false; + + for (int32_t i = 0; i != static_cast(a.size()); ++i) { + if (i == skip_dim) continue; + + if (a[i] != b[i]) return false; + } + + return true; +} + +static void PrintShape(const std::vector &a) { + for (auto i : a) { + fprintf(stderr, "%d ", static_cast(i)); + } + fprintf(stderr, "\n"); +} + +template +Ort::Value Cat(OrtAllocator *allocator, + const std::vector &values, int32_t dim) { + if (values.size() == 1u) { + return Clone(values[0]); + } + + std::vector v0_shape = + values[0]->GetTensorTypeAndShapeInfo().GetShape(); + + int64_t total_dim = v0_shape[dim]; + + for (int32_t i = 1; i != static_cast(values.size()); ++i) { + auto s = values[i]->GetTensorTypeAndShapeInfo().GetShape(); + total_dim += s[dim]; + + bool ret = Compare(v0_shape, s, dim); + if (!ret) { + fprintf(stderr, "Incorrect shape in Cat !\n"); + + fprintf(stderr, "Shape for tensor 0: "); + PrintShape(v0_shape); + + fprintf(stderr, "Shape for tensor %d: ", i); + PrintShape(s); + + exit(-1); + } + } + + std::vector ans_shape; + ans_shape.reserve(v0_shape.size()); + ans_shape.insert(ans_shape.end(), v0_shape.data(), v0_shape.data() + dim); + ans_shape.push_back(total_dim); + ans_shape.insert(ans_shape.end(), v0_shape.data() + dim + 1, + v0_shape.data() + v0_shape.size()); + + auto leading_size = static_cast(std::accumulate( + v0_shape.begin(), v0_shape.begin() + dim, 1, std::multiplies())); + + auto 
trailing_size = static_cast( + std::accumulate(v0_shape.begin() + dim + 1, v0_shape.end(), 1, + std::multiplies())); + + Ort::Value ans = Ort::Value::CreateTensor(allocator, ans_shape.data(), + ans_shape.size()); + T *dst = ans.GetTensorMutableData(); + + for (int32_t i = 0; i != leading_size; ++i) { + for (int32_t n = 0; n != static_cast(values.size()); ++n) { + auto this_dim = values[n]->GetTensorTypeAndShapeInfo().GetShape()[dim]; + const T *src = values[n]->GetTensorData(); + src += i * this_dim * trailing_size; + + std::copy(src, src + this_dim * trailing_size, dst); + dst += this_dim * trailing_size; + } + } + + return std::move(ans); +} + +template Ort::Value Cat(OrtAllocator *allocator, + const std::vector &values, + int32_t dim); + +template Ort::Value Cat(OrtAllocator *allocator, + const std::vector &values, + int32_t dim); + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/cat.h b/sherpa-onnx/csrc/cat.h new file mode 100644 index 00000000..92ea8ba9 --- /dev/null +++ b/sherpa-onnx/csrc/cat.h @@ -0,0 +1,28 @@ +// sherpa-onnx/csrc/cat.h +// +// Copyright (c) 2023 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_CAT_H_ +#define SHERPA_ONNX_CSRC_CAT_H_ + +#include + +#include "onnxruntime_cxx_api.h" // NOLINT + +namespace sherpa_onnx { + +/** Cat a list of tensors along the given dim. + * + * @param allocator Allocator to allocate space for the returned tensor + * @param values Pointer to a list of tensors. The shape of the tensor must + * be the same except on the dim to be concatenated. 
+ * @param dim The dim along which to concatenate the input tensors + * + * @return Return the concatenated tensor + */ +template +Ort::Value Cat(OrtAllocator *allocator, + const std::vector &values, int32_t dim); + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_CAT_H_ diff --git a/sherpa-onnx/csrc/macros.h b/sherpa-onnx/csrc/macros.h new file mode 100644 index 00000000..5b97afef --- /dev/null +++ b/sherpa-onnx/csrc/macros.h @@ -0,0 +1,41 @@ + +// sherpa-onnx/csrc/macros.h +// +// Copyright 2023 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_MACROS_H_ +#define SHERPA_ONNX_CSRC_MACROS_H_ + +#define SHERPA_ONNX_READ_META_DATA(dst, src_key) \ + do { \ + auto value = \ + meta_data.LookupCustomMetadataMapAllocated(src_key, allocator); \ + if (!value) { \ + fprintf(stderr, "%s does not exist in the metadata\n", src_key); \ + exit(-1); \ + } \ + \ + dst = atoi(value.get()); \ + if (dst <= 0) { \ + fprintf(stderr, "Invalid value %d for %s\n", dst, src_key); \ + exit(-1); \ + } \ + } while (0) + +#define SHERPA_ONNX_READ_META_DATA_VEC(dst, src_key) \ + do { \ + auto value = \ + meta_data.LookupCustomMetadataMapAllocated(src_key, allocator); \ + if (!value) { \ + fprintf(stderr, "%s does not exist in the metadata\n", src_key); \ + exit(-1); \ + } \ + \ + bool ret = SplitStringToIntegers(value.get(), ",", true, &dst); \ + if (!ret) { \ + fprintf(stderr, "Invalid value %s for %s\n", value.get(), src_key); \ + exit(-1); \ + } \ + } while (0) + +#endif // SHERPA_ONNX_CSRC_MACROS_H_ diff --git a/sherpa-onnx/csrc/online-lstm-transducer-model.cc b/sherpa-onnx/csrc/online-lstm-transducer-model.cc index 022a3376..0e29a3c8 100644 --- a/sherpa-onnx/csrc/online-lstm-transducer-model.cc +++ b/sherpa-onnx/csrc/online-lstm-transducer-model.cc @@ -3,6 +3,8 @@ // Copyright (c) 2023 Xiaomi Corporation #include "sherpa-onnx/csrc/online-lstm-transducer-model.h" +#include + #include #include #include @@ -11,23 +13,11 @@ #include #include "onnxruntime_cxx_api.h" // NOLINT +#include 
"sherpa-onnx/csrc/cat.h" +#include "sherpa-onnx/csrc/macros.h" #include "sherpa-onnx/csrc/online-transducer-decoder.h" #include "sherpa-onnx/csrc/onnx-utils.h" - -#define SHERPA_ONNX_READ_META_DATA(dst, src_key) \ - do { \ - auto value = \ - meta_data.LookupCustomMetadataMapAllocated(src_key, allocator); \ - if (!value) { \ - fprintf(stderr, "%s does not exist in the metadata\n", src_key); \ - exit(-1); \ - } \ - dst = atoi(value.get()); \ - if (dst <= 0) { \ - fprintf(stderr, "Invalud value %d for %s\n", dst, src_key); \ - exit(-1); \ - } \ - } while (0) +#include "sherpa-onnx/csrc/unbind.h" namespace sherpa_onnx { @@ -64,7 +54,7 @@ void OnlineLstmTransducerModel::InitEncoder(const std::string &filename) { fprintf(stderr, "%s\n", os.str().c_str()); } - Ort::AllocatorWithDefaultOptions allocator; + Ort::AllocatorWithDefaultOptions allocator; // used in the macro below SHERPA_ONNX_READ_META_DATA(num_encoder_layers_, "num_encoder_layers"); SHERPA_ONNX_READ_META_DATA(T_, "T"); SHERPA_ONNX_READ_META_DATA(decode_chunk_len_, "decode_chunk_len"); @@ -91,7 +81,7 @@ void OnlineLstmTransducerModel::InitDecoder(const std::string &filename) { fprintf(stderr, "%s\n", os.str().c_str()); } - Ort::AllocatorWithDefaultOptions allocator; + Ort::AllocatorWithDefaultOptions allocator; // used in the macro below SHERPA_ONNX_READ_META_DATA(vocab_size_, "vocab_size"); SHERPA_ONNX_READ_META_DATA(context_size_, "context_size"); } @@ -120,37 +110,19 @@ std::vector OnlineLstmTransducerModel::StackStates( const std::vector> &states) const { int32_t batch_size = static_cast(states.size()); - std::array h_shape{num_encoder_layers_, batch_size, d_model_}; - Ort::Value h = Ort::Value::CreateTensor(allocator_, h_shape.data(), - h_shape.size()); + std::vector h_buf(batch_size); + std::vector c_buf(batch_size); - std::array c_shape{num_encoder_layers_, batch_size, - rnn_hidden_size_}; - - Ort::Value c = Ort::Value::CreateTensor(allocator_, c_shape.data(), - c_shape.size()); - - float *dst_h = 
h.GetTensorMutableData(); - float *dst_c = c.GetTensorMutableData(); - - for (int32_t layer = 0; layer != num_encoder_layers_; ++layer) { - for (int32_t i = 0; i != batch_size; ++i) { - const float *src_h = - states[i][0].GetTensorData() + layer * d_model_; - - const float *src_c = - states[i][1].GetTensorData() + layer * rnn_hidden_size_; - - std::copy(src_h, src_h + d_model_, dst_h); - std::copy(src_c, src_c + rnn_hidden_size_, dst_c); - - dst_h += d_model_; - dst_c += rnn_hidden_size_; - } + for (int32_t i = 0; i != batch_size; ++i) { + assert(states[i].size() == 2); + h_buf[i] = &states[i][0]; + c_buf[i] = &states[i][1]; } - std::vector ans; + Ort::Value h = Cat(allocator_, h_buf, 1); + Ort::Value c = Cat(allocator_, c_buf, 1); + std::vector ans; ans.reserve(2); ans.push_back(std::move(h)); ans.push_back(std::move(c)); @@ -161,37 +133,19 @@ std::vector OnlineLstmTransducerModel::StackStates( std::vector> OnlineLstmTransducerModel::UnStackStates( const std::vector &states) const { int32_t batch_size = states[0].GetTensorTypeAndShapeInfo().GetShape()[1]; + assert(states.size() == 2); std::vector> ans(batch_size); - // allocate space - std::array h_shape{num_encoder_layers_, 1, d_model_}; - std::array c_shape{num_encoder_layers_, 1, rnn_hidden_size_}; + std::vector h_vec = Unbind(allocator_, &states[0], 1); + std::vector c_vec = Unbind(allocator_, &states[1], 1); + + assert(h_vec.size() == batch_size); + assert(c_vec.size() == batch_size); for (int32_t i = 0; i != batch_size; ++i) { - Ort::Value h = Ort::Value::CreateTensor(allocator_, h_shape.data(), - h_shape.size()); - Ort::Value c = Ort::Value::CreateTensor(allocator_, c_shape.data(), - c_shape.size()); - ans[i].push_back(std::move(h)); - ans[i].push_back(std::move(c)); - } - - for (int32_t layer = 0; layer != num_encoder_layers_; ++layer) { - for (int32_t i = 0; i != batch_size; ++i) { - const float *src_h = states[0].GetTensorData() + - layer * batch_size * d_model_ + i * d_model_; - const float *src_c = 
states[1].GetTensorData() + - layer * batch_size * rnn_hidden_size_ + - i * rnn_hidden_size_; - - float *dst_h = ans[i][0].GetTensorMutableData() + layer * d_model_; - float *dst_c = - ans[i][1].GetTensorMutableData() + layer * rnn_hidden_size_; - - std::copy(src_h, src_h + d_model_, dst_h); - std::copy(src_c, src_c + rnn_hidden_size_, dst_c); - } + ans[i].push_back(std::move(h_vec[i])); + ans[i].push_back(std::move(c_vec[i])); } return ans; @@ -206,20 +160,15 @@ std::vector OnlineLstmTransducerModel::GetEncoderInitStates() { Ort::Value h = Ort::Value::CreateTensor(allocator_, h_shape.data(), h_shape.size()); - std::fill(h.GetTensorMutableData(), - h.GetTensorMutableData() + - num_encoder_layers_ * kBatchSize * d_model_, - 0); + Fill(&h, 0); std::array c_shape{num_encoder_layers_, kBatchSize, rnn_hidden_size_}; + Ort::Value c = Ort::Value::CreateTensor(allocator_, c_shape.data(), c_shape.size()); - std::fill(c.GetTensorMutableData(), - c.GetTensorMutableData() + - num_encoder_layers_ * kBatchSize * rnn_hidden_size_, - 0); + Fill(&c, 0); std::vector states; diff --git a/sherpa-onnx/csrc/online-transducer-model.cc b/sherpa-onnx/csrc/online-transducer-model.cc index 14eaf16b..22f58199 100644 --- a/sherpa-onnx/csrc/online-transducer-model.cc +++ b/sherpa-onnx/csrc/online-transducer-model.cc @@ -8,11 +8,13 @@ #include #include "sherpa-onnx/csrc/online-lstm-transducer-model.h" +#include "sherpa-onnx/csrc/online-zipformer-transducer-model.h" #include "sherpa-onnx/csrc/onnx-utils.h" namespace sherpa_onnx { enum class ModelType { kLstm, + kZipformer, kUnkown, }; @@ -40,6 +42,8 @@ static ModelType GetModelType(const OnlineTransducerModelConfig &config) { if (model_type.get() == std::string("lstm")) { return ModelType::kLstm; + } else if (model_type.get() == std::string("zipformer")) { + return ModelType::kZipformer; } else { fprintf(stderr, "Unsupported model_type: %s\n", model_type.get()); return ModelType::kUnkown; @@ -53,6 +57,8 @@ std::unique_ptr 
OnlineTransducerModel::Create( switch (model_type) { case ModelType::kLstm: return std::make_unique(config); + case ModelType::kZipformer: + return std::make_unique(config); case ModelType::kUnkown: return nullptr; } diff --git a/sherpa-onnx/csrc/online-zipformer-transducer-model.cc b/sherpa-onnx/csrc/online-zipformer-transducer-model.cc new file mode 100644 index 00000000..2a675fc0 --- /dev/null +++ b/sherpa-onnx/csrc/online-zipformer-transducer-model.cc @@ -0,0 +1,456 @@ +// sherpa-onnx/csrc/online-zipformer-transducer-model.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-onnx/csrc/online-zipformer-transducer-model.h" + +#include + +#include +#include +#include +#include +#include +#include + +#include "onnxruntime_cxx_api.h" // NOLINT +#include "sherpa-onnx/csrc/cat.h" +#include "sherpa-onnx/csrc/macros.h" +#include "sherpa-onnx/csrc/online-transducer-decoder.h" +#include "sherpa-onnx/csrc/onnx-utils.h" +#include "sherpa-onnx/csrc/text-utils.h" +#include "sherpa-onnx/csrc/unbind.h" + +namespace sherpa_onnx { + +OnlineZipformerTransducerModel::OnlineZipformerTransducerModel( + const OnlineTransducerModelConfig &config) + : env_(ORT_LOGGING_LEVEL_WARNING), + config_(config), + sess_opts_{}, + allocator_{} { + sess_opts_.SetIntraOpNumThreads(config.num_threads); + sess_opts_.SetInterOpNumThreads(config.num_threads); + + InitEncoder(config.encoder_filename); + InitDecoder(config.decoder_filename); + InitJoiner(config.joiner_filename); +} + +void OnlineZipformerTransducerModel::InitEncoder(const std::string &filename) { + encoder_sess_ = std::make_unique( + env_, SHERPA_MAYBE_WIDE(filename).c_str(), sess_opts_); + + GetInputNames(encoder_sess_.get(), &encoder_input_names_, + &encoder_input_names_ptr_); + + GetOutputNames(encoder_sess_.get(), &encoder_output_names_, + &encoder_output_names_ptr_); + + // get meta data + Ort::ModelMetadata meta_data = encoder_sess_->GetModelMetadata(); + if (config_.debug) { + std::ostringstream os; + os << 
"---encoder---\n"; + PrintModelMetadata(os, meta_data); + fprintf(stderr, "%s\n", os.str().c_str()); + } + + Ort::AllocatorWithDefaultOptions allocator; // used in the macro below + SHERPA_ONNX_READ_META_DATA_VEC(encoder_dims_, "encoder_dims"); + SHERPA_ONNX_READ_META_DATA_VEC(attention_dims_, "attention_dims"); + SHERPA_ONNX_READ_META_DATA_VEC(num_encoder_layers_, "num_encoder_layers"); + SHERPA_ONNX_READ_META_DATA_VEC(cnn_module_kernels_, "cnn_module_kernels"); + SHERPA_ONNX_READ_META_DATA_VEC(left_context_len_, "left_context_len"); + + SHERPA_ONNX_READ_META_DATA(T_, "T"); + SHERPA_ONNX_READ_META_DATA(decode_chunk_len_, "decode_chunk_len"); + + if (config_.debug) { + auto print = [](const std::vector &v, const char *name) { + fprintf(stderr, "%s: ", name); + for (auto i : v) { + fprintf(stderr, "%d ", i); + } + fprintf(stderr, "\n"); + }; + print(encoder_dims_, "encoder_dims"); + print(attention_dims_, "attention_dims"); + print(num_encoder_layers_, "num_encoder_layers"); + print(cnn_module_kernels_, "cnn_module_kernels"); + print(left_context_len_, "left_context_len"); + fprintf(stderr, "T: %d\n", T_); + fprintf(stderr, "decode_chunk_len_: %d\n", decode_chunk_len_); + } +} + +void OnlineZipformerTransducerModel::InitDecoder(const std::string &filename) { + decoder_sess_ = std::make_unique( + env_, SHERPA_MAYBE_WIDE(filename).c_str(), sess_opts_); + + GetInputNames(decoder_sess_.get(), &decoder_input_names_, + &decoder_input_names_ptr_); + + GetOutputNames(decoder_sess_.get(), &decoder_output_names_, + &decoder_output_names_ptr_); + + // get meta data + Ort::ModelMetadata meta_data = decoder_sess_->GetModelMetadata(); + if (config_.debug) { + std::ostringstream os; + os << "---decoder---\n"; + PrintModelMetadata(os, meta_data); + fprintf(stderr, "%s\n", os.str().c_str()); + } + + Ort::AllocatorWithDefaultOptions allocator; // used in the macro below + SHERPA_ONNX_READ_META_DATA(vocab_size_, "vocab_size"); + SHERPA_ONNX_READ_META_DATA(context_size_, 
"context_size"); +} + +void OnlineZipformerTransducerModel::InitJoiner(const std::string &filename) { + joiner_sess_ = std::make_unique( + env_, SHERPA_MAYBE_WIDE(filename).c_str(), sess_opts_); + + GetInputNames(joiner_sess_.get(), &joiner_input_names_, + &joiner_input_names_ptr_); + + GetOutputNames(joiner_sess_.get(), &joiner_output_names_, + &joiner_output_names_ptr_); + + // get meta data + Ort::ModelMetadata meta_data = joiner_sess_->GetModelMetadata(); + if (config_.debug) { + std::ostringstream os; + os << "---joiner---\n"; + PrintModelMetadata(os, meta_data); + fprintf(stderr, "%s\n", os.str().c_str()); + } +} + +std::vector OnlineZipformerTransducerModel::StackStates( + const std::vector> &states) const { + int32_t batch_size = static_cast(states.size()); + int32_t num_encoders = static_cast(num_encoder_layers_.size()); + + std::vector buf(batch_size); + + std::vector ans; + ans.reserve(states[0].size()); + + // cached_len + for (int32_t i = 0; i != num_encoders; ++i) { + for (int32_t n = 0; n != batch_size; ++n) { + buf[n] = &states[n][i]; + } + auto v = Cat(allocator_, buf, 1); // (num_layers, 1) + ans.push_back(std::move(v)); + } + + // cached_avg + for (int32_t i = 0; i != num_encoders; ++i) { + for (int32_t n = 0; n != batch_size; ++n) { + buf[n] = &states[n][num_encoders + i]; + } + auto v = Cat(allocator_, buf, 1); // (num_layers, 1, encoder_dims) + ans.push_back(std::move(v)); + } + + // cached_key + for (int32_t i = 0; i != num_encoders; ++i) { + for (int32_t n = 0; n != batch_size; ++n) { + buf[n] = &states[n][num_encoders * 2 + i]; + } + // (num_layers, left_context_len, 1, attention_dims) + auto v = Cat(allocator_, buf, 2); + ans.push_back(std::move(v)); + } + + // cached_val + for (int32_t i = 0; i != num_encoders; ++i) { + for (int32_t n = 0; n != batch_size; ++n) { + buf[n] = &states[n][num_encoders * 3 + i]; + } + // (num_layers, left_context_len, 1, attention_dims/2) + auto v = Cat(allocator_, buf, 2); + ans.push_back(std::move(v)); + } + 
+ // cached_val2 + for (int32_t i = 0; i != num_encoders; ++i) { + for (int32_t n = 0; n != batch_size; ++n) { + buf[n] = &states[n][num_encoders * 4 + i]; + } + // (num_layers, left_context_len, 1, attention_dims/2) + auto v = Cat(allocator_, buf, 2); + ans.push_back(std::move(v)); + } + + // cached_conv1 + for (int32_t i = 0; i != num_encoders; ++i) { + for (int32_t n = 0; n != batch_size; ++n) { + buf[n] = &states[n][num_encoders * 5 + i]; + } + // (num_layers, 1, encoder_dims, cnn_module_kernels-1) + auto v = Cat(allocator_, buf, 1); + ans.push_back(std::move(v)); + } + + // cached_conv2 + for (int32_t i = 0; i != num_encoders; ++i) { + for (int32_t n = 0; n != batch_size; ++n) { + buf[n] = &states[n][num_encoders * 6 + i]; + } + // (num_layers, 1, encoder_dims, cnn_module_kernels-1) + auto v = Cat(allocator_, buf, 1); + ans.push_back(std::move(v)); + } + + return ans; +} + +std::vector> +OnlineZipformerTransducerModel::UnStackStates( + const std::vector &states) const { + assert(states.size() == num_encoder_layers_.size() * 7); + + int32_t batch_size = states[0].GetTensorTypeAndShapeInfo().GetShape()[1]; + int32_t num_encoders = num_encoder_layers_.size(); + + std::vector> ans; + ans.resize(batch_size); + + // cached_len + for (int32_t i = 0; i != num_encoders; ++i) { + auto v = Unbind(allocator_, &states[i], 1); + assert(v.size() == batch_size); + + for (int32_t n = 0; n != batch_size; ++n) { + ans[n].push_back(std::move(v[n])); + } + } + + // cached_avg + for (int32_t i = num_encoders; i != 2 * num_encoders; ++i) { + auto v = Unbind(allocator_, &states[i], 1); + assert(v.size() == batch_size); + + for (int32_t n = 0; n != batch_size; ++n) { + ans[n].push_back(std::move(v[n])); + } + } + + // cached_key + for (int32_t i = 2 * num_encoders; i != 3 * num_encoders; ++i) { + auto v = Unbind(allocator_, &states[i], 2); + assert(v.size() == batch_size); + + for (int32_t n = 0; n != batch_size; ++n) { + ans[n].push_back(std::move(v[n])); + } + } + + // cached_val + 
for (int32_t i = 3 * num_encoders; i != 4 * num_encoders; ++i) { + auto v = Unbind(allocator_, &states[i], 2); + assert(v.size() == batch_size); + + for (int32_t n = 0; n != batch_size; ++n) { + ans[n].push_back(std::move(v[n])); + } + } + + // cached_val2 + for (int32_t i = 4 * num_encoders; i != 5 * num_encoders; ++i) { + auto v = Unbind(allocator_, &states[i], 2); + assert(v.size() == batch_size); + + for (int32_t n = 0; n != batch_size; ++n) { + ans[n].push_back(std::move(v[n])); + } + } + + // cached_conv1 + for (int32_t i = 5 * num_encoders; i != 6 * num_encoders; ++i) { + auto v = Unbind(allocator_, &states[i], 1); + assert(v.size() == batch_size); + + for (int32_t n = 0; n != batch_size; ++n) { + ans[n].push_back(std::move(v[n])); + } + } + + // cached_conv2 + for (int32_t i = 6 * num_encoders; i != 7 * num_encoders; ++i) { + auto v = Unbind(allocator_, &states[i], 1); + assert(v.size() == batch_size); + + for (int32_t n = 0; n != batch_size; ++n) { + ans[n].push_back(std::move(v[n])); + } + } + + return ans; +} + +std::vector OnlineZipformerTransducerModel::GetEncoderInitStates() { + // Please see + // https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer.py#L673 + // for details + + int32_t n = static_cast(encoder_dims_.size()); + std::vector cached_len_vec; + std::vector cached_avg_vec; + std::vector cached_key_vec; + std::vector cached_val_vec; + std::vector cached_val2_vec; + std::vector cached_conv1_vec; + std::vector cached_conv2_vec; + + cached_len_vec.reserve(n); + cached_avg_vec.reserve(n); + cached_key_vec.reserve(n); + cached_val_vec.reserve(n); + cached_val2_vec.reserve(n); + cached_conv1_vec.reserve(n); + cached_conv2_vec.reserve(n); + + for (int32_t i = 0; i != n; ++i) { + { + std::array s{num_encoder_layers_[i], 1}; + auto v = + Ort::Value::CreateTensor(allocator_, s.data(), s.size()); + Fill(&v, 0); + cached_len_vec.push_back(std::move(v)); + } + + { + std::array 
s{num_encoder_layers_[i], 1, encoder_dims_[i]}; + auto v = Ort::Value::CreateTensor(allocator_, s.data(), s.size()); + Fill(&v, 0); + cached_avg_vec.push_back(std::move(v)); + } + + { + std::array s{num_encoder_layers_[i], left_context_len_[i], 1, + attention_dims_[i]}; + auto v = Ort::Value::CreateTensor(allocator_, s.data(), s.size()); + Fill(&v, 0); + cached_key_vec.push_back(std::move(v)); + } + + { + std::array s{num_encoder_layers_[i], left_context_len_[i], 1, + attention_dims_[i] / 2}; + auto v = Ort::Value::CreateTensor(allocator_, s.data(), s.size()); + Fill(&v, 0); + cached_val_vec.push_back(std::move(v)); + } + + { + std::array s{num_encoder_layers_[i], left_context_len_[i], 1, + attention_dims_[i] / 2}; + auto v = Ort::Value::CreateTensor(allocator_, s.data(), s.size()); + Fill(&v, 0); + cached_val2_vec.push_back(std::move(v)); + } + + { + std::array s{num_encoder_layers_[i], 1, encoder_dims_[i], + cnn_module_kernels_[i] - 1}; + auto v = Ort::Value::CreateTensor(allocator_, s.data(), s.size()); + Fill(&v, 0); + cached_conv1_vec.push_back(std::move(v)); + } + + { + std::array s{num_encoder_layers_[i], 1, encoder_dims_[i], + cnn_module_kernels_[i] - 1}; + auto v = Ort::Value::CreateTensor(allocator_, s.data(), s.size()); + Fill(&v, 0); + cached_conv2_vec.push_back(std::move(v)); + } + } + + std::vector ans; + ans.reserve(n * 7); + + for (auto &v : cached_len_vec) ans.push_back(std::move(v)); + for (auto &v : cached_avg_vec) ans.push_back(std::move(v)); + for (auto &v : cached_key_vec) ans.push_back(std::move(v)); + for (auto &v : cached_val_vec) ans.push_back(std::move(v)); + for (auto &v : cached_val2_vec) ans.push_back(std::move(v)); + for (auto &v : cached_conv1_vec) ans.push_back(std::move(v)); + for (auto &v : cached_conv2_vec) ans.push_back(std::move(v)); + + return ans; +} + +std::pair> +OnlineZipformerTransducerModel::RunEncoder(Ort::Value features, + std::vector states) { + auto memory_info = + Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, 
OrtMemTypeDefault); + + std::vector encoder_inputs; + encoder_inputs.reserve(1 + states.size()); + + encoder_inputs.push_back(std::move(features)); + for (auto &v : states) { + encoder_inputs.push_back(std::move(v)); + } + + auto encoder_out = encoder_sess_->Run( + {}, encoder_input_names_ptr_.data(), encoder_inputs.data(), + encoder_inputs.size(), encoder_output_names_ptr_.data(), + encoder_output_names_ptr_.size()); + + std::vector next_states; + next_states.reserve(states.size()); + + for (int32_t i = 1; i != static_cast(encoder_out.size()); ++i) { + next_states.push_back(std::move(encoder_out[i])); + } + + return {std::move(encoder_out[0]), std::move(next_states)}; +} + +Ort::Value OnlineZipformerTransducerModel::BuildDecoderInput( + const std::vector &results) { + int32_t batch_size = static_cast(results.size()); + std::array shape{batch_size, context_size_}; + Ort::Value decoder_input = + Ort::Value::CreateTensor(allocator_, shape.data(), shape.size()); + int64_t *p = decoder_input.GetTensorMutableData(); + + for (const auto &r : results) { + const int64_t *begin = r.tokens.data() + r.tokens.size() - context_size_; + const int64_t *end = r.tokens.data() + r.tokens.size(); + std::copy(begin, end, p); + p += context_size_; + } + + return decoder_input; +} + +Ort::Value OnlineZipformerTransducerModel::RunDecoder( + Ort::Value decoder_input) { + auto decoder_out = decoder_sess_->Run( + {}, decoder_input_names_ptr_.data(), &decoder_input, 1, + decoder_output_names_ptr_.data(), decoder_output_names_ptr_.size()); + return std::move(decoder_out[0]); +} + +Ort::Value OnlineZipformerTransducerModel::RunJoiner(Ort::Value encoder_out, + Ort::Value decoder_out) { + std::array joiner_input = {std::move(encoder_out), + std::move(decoder_out)}; + auto logit = + joiner_sess_->Run({}, joiner_input_names_ptr_.data(), joiner_input.data(), + joiner_input.size(), joiner_output_names_ptr_.data(), + joiner_output_names_ptr_.size()); + + return std::move(logit[0]); +} + +} // 
namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/online-zipformer-transducer-model.h b/sherpa-onnx/csrc/online-zipformer-transducer-model.h new file mode 100644 index 00000000..779ac288 --- /dev/null +++ b/sherpa-onnx/csrc/online-zipformer-transducer-model.h @@ -0,0 +1,99 @@ +// sherpa-onnx/csrc/online-zipformer-transducer-model.h +// +// Copyright (c) 2023 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_ONLINE_ZIPFORMER_TRANSDUCER_MODEL_H_ +#define SHERPA_ONNX_CSRC_ONLINE_ZIPFORMER_TRANSDUCER_MODEL_H_ + +#include +#include +#include +#include + +#include "onnxruntime_cxx_api.h" // NOLINT +#include "sherpa-onnx/csrc/online-transducer-model-config.h" +#include "sherpa-onnx/csrc/online-transducer-model.h" + +namespace sherpa_onnx { + +class OnlineZipformerTransducerModel : public OnlineTransducerModel { + public: + explicit OnlineZipformerTransducerModel( + const OnlineTransducerModelConfig &config); + + std::vector StackStates( + const std::vector> &states) const override; + + std::vector> UnStackStates( + const std::vector &states) const override; + + std::vector GetEncoderInitStates() override; + + std::pair> RunEncoder( + Ort::Value features, std::vector states) override; + + Ort::Value BuildDecoderInput( + const std::vector &results) override; + + Ort::Value RunDecoder(Ort::Value decoder_input) override; + + Ort::Value RunJoiner(Ort::Value encoder_out, Ort::Value decoder_out) override; + + int32_t ContextSize() const override { return context_size_; } + + int32_t ChunkSize() const override { return T_; } + + int32_t ChunkShift() const override { return decode_chunk_len_; } + + int32_t VocabSize() const override { return vocab_size_; } + OrtAllocator *Allocator() override { return allocator_; } + + private: + void InitEncoder(const std::string &encoder_filename); + void InitDecoder(const std::string &decoder_filename); + void InitJoiner(const std::string &joiner_filename); + + private: + Ort::Env env_; + Ort::SessionOptions sess_opts_; + 
Ort::AllocatorWithDefaultOptions allocator_; + + std::unique_ptr encoder_sess_; + std::unique_ptr decoder_sess_; + std::unique_ptr joiner_sess_; + + std::vector encoder_input_names_; + std::vector encoder_input_names_ptr_; + + std::vector encoder_output_names_; + std::vector encoder_output_names_ptr_; + + std::vector decoder_input_names_; + std::vector decoder_input_names_ptr_; + + std::vector decoder_output_names_; + std::vector decoder_output_names_ptr_; + + std::vector joiner_input_names_; + std::vector joiner_input_names_ptr_; + + std::vector joiner_output_names_; + std::vector joiner_output_names_ptr_; + + OnlineTransducerModelConfig config_; + + std::vector encoder_dims_; + std::vector attention_dims_; + std::vector num_encoder_layers_; + std::vector cnn_module_kernels_; + std::vector left_context_len_; + + int32_t T_ = 0; + int32_t decode_chunk_len_ = 0; + + int32_t context_size_ = 0; + int32_t vocab_size_ = 0; +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_ONLINE_ZIPFORMER_TRANSDUCER_MODEL_H_ diff --git a/sherpa-onnx/csrc/onnx-utils.cc b/sherpa-onnx/csrc/onnx-utils.cc index 8a9cf301..fd23552d 100644 --- a/sherpa-onnx/csrc/onnx-utils.cc +++ b/sherpa-onnx/csrc/onnx-utils.cc @@ -46,16 +46,74 @@ void PrintModelMetadata(std::ostream &os, const Ort::ModelMetadata &meta_data) { } } -Ort::Value Clone(Ort::Value *v) { +Ort::Value Clone(const Ort::Value *v) { auto type_and_shape = v->GetTensorTypeAndShapeInfo(); std::vector shape = type_and_shape.GetShape(); auto memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); - return Ort::Value::CreateTensor(memory_info, v->GetTensorMutableData(), - type_and_shape.GetElementCount(), - shape.data(), shape.size()); + switch (type_and_shape.GetElementType()) { + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: + return Ort::Value::CreateTensor( + memory_info, + const_cast(v)->GetTensorMutableData(), + type_and_shape.GetElementCount(), shape.data(), shape.size()); + case 
ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: + return Ort::Value::CreateTensor( + memory_info, + const_cast(v)->GetTensorMutableData(), + type_and_shape.GetElementCount(), shape.data(), shape.size()); + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: + return Ort::Value::CreateTensor( + memory_info, + const_cast(v)->GetTensorMutableData(), + type_and_shape.GetElementCount(), shape.data(), shape.size()); + default: + fprintf(stderr, "Unsupported type: %d\n", + static_cast(type_and_shape.GetElementType())); + exit(-1); + // unreachable code + return Ort::Value{nullptr}; + } +} + +void Print1D(Ort::Value *v) { + std::vector shape = v->GetTensorTypeAndShapeInfo().GetShape(); + const float *d = v->GetTensorData(); + for (int32_t i = 0; i != static_cast(shape[0]); ++i) { + fprintf(stderr, "%.3f ", d[i]); + } + fprintf(stderr, "\n"); +} + +void Print2D(Ort::Value *v) { + std::vector shape = v->GetTensorTypeAndShapeInfo().GetShape(); + const float *d = v->GetTensorData(); + + for (int32_t r = 0; r != static_cast(shape[0]); ++r) { + for (int32_t c = 0; c != static_cast(shape[1]); ++c, ++d) { + fprintf(stderr, "%.3f ", *d); + } + fprintf(stderr, "\n"); + } + fprintf(stderr, "\n"); +} + +void Print3D(Ort::Value *v) { + std::vector shape = v->GetTensorTypeAndShapeInfo().GetShape(); + const float *d = v->GetTensorData(); + + for (int32_t p = 0; p != static_cast(shape[0]); ++p) { + fprintf(stderr, "---plane %d---\n", p); + for (int32_t r = 0; r != static_cast(shape[1]); ++r) { + for (int32_t c = 0; c != static_cast(shape[2]); ++c, ++d) { + fprintf(stderr, "%.3f ", *d); + } + fprintf(stderr, "\n"); + } + } + fprintf(stderr, "\n"); } } // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/onnx-utils.h b/sherpa-onnx/csrc/onnx-utils.h index 38a7b143..8efcefd1 100644 --- a/sherpa-onnx/csrc/onnx-utils.h +++ b/sherpa-onnx/csrc/onnx-utils.h @@ -56,7 +56,23 @@ void PrintModelMetadata(std::ostream &os, const Ort::ModelMetadata &meta_data); // NOLINT // Return a shallow copy of v -Ort::Value 
namespace sherpa_onnx {

// Split `full` on any of the single-character delimiters in `delim`.
//
// With omit_empty_strings == true only non-empty pieces are kept; with
// false, n delimiters yield n+1 pieces (so the empty string splits into
// one empty piece).  The output vector is cleared first.
void SplitStringToVector(const std::string &full, const char *delim,
                         bool omit_empty_strings,
                         std::vector<std::string> *out) {
  out->clear();
  const size_t len = full.size();
  size_t begin = 0;
  while (true) {
    size_t pos = full.find_first_of(delim, begin);
    // Keep the piece unless it is empty (pos == begin, or we are past the
    // end of the string) and empty pieces were asked to be dropped.
    bool keep = !omit_empty_strings || (pos != begin && begin != len);
    if (keep) {
      out->push_back(full.substr(begin, pos - begin));
    }
    if (pos == std::string::npos) {
      break;
    }
    begin = pos + 1;
  }
}

}  // namespace sherpa_onnx
// strtoll() portability shim: MSVC spells it _strtoi64.
//
// FIX: the original macro had a trailing semicolon baked into the
// replacement text, so every use site expanded to a double semicolon and
// an un-braced `if (...) x = SHERPA_ONNX_STRTOLL(...); else ...` would not
// compile.  The macro now expands to a plain expression.
#ifdef _MSC_VER
#define SHERPA_ONNX_STRTOLL(cur_cstr, end_cstr) \
  _strtoi64(cur_cstr, end_cstr, 10)
#else
#define SHERPA_ONNX_STRTOLL(cur_cstr, end_cstr) \
  strtoll(cur_cstr, end_cstr, 10)
#endif

// This file is copied/modified from
// https://github.com/kaldi-asr/kaldi/blob/master/src/util/text-utils.h

namespace sherpa_onnx {

/// Split a string using any of the single character delimiters.
/// If omit_empty_strings == true, the output will contain any
/// nonempty strings after splitting on any of the
/// characters in the delimiter. If omit_empty_strings == false,
/// the output will contain n+1 strings if there are n characters
/// in the set "delim" within the input string. In this case
/// the empty string is split to a single empty string.
void SplitStringToVector(const std::string &full, const char *delim,
                         bool omit_empty_strings,
                         std::vector<std::string> *out);

/**
  \brief Split a string (e.g. 1:2:3) into a vector of integers.

  \param [in] delim  String containing a list of characters, any of which
                     is allowed as a delimiter.
  \param [in] omit_empty_strings  If true, empty strings between delimiters
                     are allowed and will not produce an output integer; if
                     false, instances of characters in 'delim' that are
                     consecutive or at the start or end of the string would
                     be an error.  You'll normally want this to be true if
                     'delim' consists of spaces, and false otherwise.
  \param [out] out   The output list of integers.
  \return true on success; false (with *out cleared) if any piece is not a
          valid integer or does not fit in type I.
*/
template <class I>
bool SplitStringToIntegers(const std::string &full, const char *delim,
                           bool omit_empty_strings,  // typically false [but
                                                     // should probably be true
                                                     // if "delim" is spaces].
                           std::vector<I> *out) {
  static_assert(std::is_integral<I>::value, "");
  if (*(full.c_str()) == '\0') {
    out->clear();
    return true;
  }
  std::vector<std::string> split;
  SplitStringToVector(full, delim, omit_empty_strings, &split);
  out->resize(split.size());
  for (size_t i = 0; i < split.size(); i++) {
    const char *this_str = split[i].c_str();
    char *end = nullptr;
    int64_t j = SHERPA_ONNX_STRTOLL(this_str, &end);
    if (end == this_str || *end != '\0') {
      // No digits consumed, or trailing junk after the number.
      out->clear();
      return false;
    }
    I jI = static_cast<I>(j);
    if (static_cast<int64_t>(jI) != j) {
      // The output type cannot represent this integer.
      out->clear();
      return false;
    }
    (*out)[i] = jI;
  }
  return true;
}

}  // namespace sherpa_onnx
(int32_t i = 0; i != shape[0]; ++i) { + EXPECT_EQ(p[i], p2[i]); + } +} + +TEST(Ubind, Test2DTensorsDim0) { + Ort::AllocatorWithDefaultOptions allocator; + std::array shape{3, 2}; + Ort::Value v = + Ort::Value::CreateTensor(allocator, shape.data(), shape.size()); + float *p = v.GetTensorMutableData(); + + for (int32_t i = 0; i != static_cast(shape[0] * shape[1]); ++i) { + p[i] = i; + } + auto ans = Unbind(allocator, &v, 0); + + Print2D(&v); + for (int32_t i = 0; i != static_cast(shape[0]); ++i) { + Print2D(&ans[i]); + } + + for (int32_t i = 0; i != static_cast(shape[0]); ++i) { + const float *pans = ans[i].GetTensorData(); + for (int32_t k = 0; k != static_cast(shape[1]); ++k, ++p) { + EXPECT_EQ(*p, pans[k]); + } + } + + // For Cat + std::vector vec(ans.size()); + for (int32_t i = 0; i != static_cast(vec.size()); ++i) { + vec[i] = &ans[i]; + } + Ort::Value v2 = Cat(allocator, vec, 0); + Print2D(&v2); + + p = v.GetTensorMutableData(); + const float *p2 = v2.GetTensorData(); + for (int32_t i = 0; i != shape[0] * shape[1]; ++i) { + EXPECT_EQ(p[i], p2[i]); + } +} + +TEST(Ubind, Test2DTensorsDim1) { + Ort::AllocatorWithDefaultOptions allocator; + std::array shape{3, 2}; + Ort::Value v = + Ort::Value::CreateTensor(allocator, shape.data(), shape.size()); + float *p = v.GetTensorMutableData(); + + for (int32_t i = 0; i != static_cast(shape[0] * shape[1]); ++i) { + p[i] = i; + } + auto ans = Unbind(allocator, &v, 1); + + Print2D(&v); + for (int32_t i = 0; i != static_cast(shape[1]); ++i) { + Print2D(&ans[i]); + } + + // For Cat + std::vector vec(ans.size()); + for (int32_t i = 0; i != static_cast(vec.size()); ++i) { + vec[i] = &ans[i]; + } + Ort::Value v2 = Cat(allocator, vec, 1); + Print2D(&v2); + + p = v.GetTensorMutableData(); + const float *p2 = v2.GetTensorData(); + for (int32_t i = 0; i != shape[0] * shape[1]; ++i) { + EXPECT_EQ(p[i], p2[i]); + } +} + +TEST(Ubind, Test3DTensorsDim0) { + Ort::AllocatorWithDefaultOptions allocator; + std::array shape{3, 2, 5}; + 
Ort::Value v = + Ort::Value::CreateTensor(allocator, shape.data(), shape.size()); + float *p = v.GetTensorMutableData(); + + for (int32_t i = 0; i != static_cast(shape[0] * shape[1] * shape[2]); + ++i) { + p[i] = i; + } + auto ans = Unbind(allocator, &v, 0); + + Print3D(&v); + for (int32_t i = 0; i != static_cast(shape[0]); ++i) { + Print3D(&ans[i]); + } + + for (int32_t i = 0; i != static_cast(shape[0]); ++i) { + const float *pans = ans[i].GetTensorData(); + for (int32_t k = 0; k != static_cast(shape[1] * shape[2]); + ++k, ++p) { + EXPECT_EQ(*p, pans[k]); + } + } + + // For Cat + std::vector vec(ans.size()); + for (int32_t i = 0; i != static_cast(vec.size()); ++i) { + vec[i] = &ans[i]; + } + Ort::Value v2 = Cat(allocator, vec, 0); + Print3D(&v2); + + p = v.GetTensorMutableData(); + const float *p2 = v2.GetTensorData(); + for (int32_t i = 0; i != shape[0] * shape[1] * shape[2]; ++i) { + EXPECT_EQ(p[i], p2[i]); + } +} + +TEST(Ubind, Test3DTensorsDim1) { + Ort::AllocatorWithDefaultOptions allocator; + std::array shape{3, 2, 5}; + Ort::Value v = + Ort::Value::CreateTensor(allocator, shape.data(), shape.size()); + float *p = v.GetTensorMutableData(); + + for (int32_t i = 0; i != static_cast(shape[0] * shape[1] * shape[2]); + ++i) { + p[i] = i; + } + auto ans = Unbind(allocator, &v, 1); + + Print3D(&v); + for (int32_t i = 0; i != static_cast(shape[1]); ++i) { + Print3D(&ans[i]); + } + + // For Cat + std::vector vec(ans.size()); + for (int32_t i = 0; i != static_cast(vec.size()); ++i) { + vec[i] = &ans[i]; + } + Ort::Value v2 = Cat(allocator, vec, 1); + Print3D(&v2); + + p = v.GetTensorMutableData(); + const float *p2 = v2.GetTensorData(); + for (int32_t i = 0; i != shape[0] * shape[1] * shape[2]; ++i) { + EXPECT_EQ(p[i], p2[i]); + } +} + +TEST(Ubind, Test3DTensorsDim2) { + Ort::AllocatorWithDefaultOptions allocator; + std::array shape{3, 2, 5}; + Ort::Value v = + Ort::Value::CreateTensor(allocator, shape.data(), shape.size()); + float *p = v.GetTensorMutableData(); + + 
for (int32_t i = 0; i != static_cast(shape[0] * shape[1] * shape[2]); + ++i) { + p[i] = i; + } + auto ans = Unbind(allocator, &v, 2); + + Print3D(&v); + for (int32_t i = 0; i != static_cast(shape[2]); ++i) { + Print3D(&ans[i]); + } + + // For Cat + std::vector vec(ans.size()); + for (int32_t i = 0; i != static_cast(vec.size()); ++i) { + vec[i] = &ans[i]; + } + Ort::Value v2 = Cat(allocator, vec, 2); + Print3D(&v2); + + p = v.GetTensorMutableData(); + const float *p2 = v2.GetTensorData(); + for (int32_t i = 0; i != shape[0] * shape[1] * shape[2]; ++i) { + EXPECT_EQ(p[i], p2[i]); + } +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/unbind.cc b/sherpa-onnx/csrc/unbind.cc new file mode 100644 index 00000000..ec8c96ee --- /dev/null +++ b/sherpa-onnx/csrc/unbind.cc @@ -0,0 +1,72 @@ +// sherpa-onnx/csrc/unbind.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-onnx/csrc/unbind.h" + +#include + +#include +#include +#include +#include +#include + +#include "onnxruntime_cxx_api.h" // NOLINT +#include "sherpa-onnx/csrc/onnx-utils.h" + +namespace sherpa_onnx { + +template +std::vector Unbind(OrtAllocator *allocator, const Ort::Value *value, + int32_t dim) { + std::vector shape = value->GetTensorTypeAndShapeInfo().GetShape(); + assert(dim >= 0); + assert(dim < static_cast(shape.size())); + int32_t n = static_cast(shape[dim]); + if (n == 1) { + std::vector ans; + ans.push_back(Clone(value)); + return ans; + } + + std::vector ans_shape = shape; + ans_shape[dim] = 1; // // Unlike torch, we keep the dim to 1 + + // allocator tensors + std::vector ans; + ans.reserve(n); + for (int32_t i = 0; i != n; ++i) { + Ort::Value t = Ort::Value::CreateTensor(allocator, ans_shape.data(), + ans_shape.size()); + ans.push_back(std::move(t)); + } + + auto leading_size = static_cast(std::accumulate( + shape.begin(), shape.begin() + dim, 1, std::multiplies())); + + auto trailing_size = static_cast(std::accumulate( + shape.begin() + dim + 1, shape.end(), 1, 
std::multiplies())); + + const T *src = value->GetTensorData(); + + for (int32_t i = 0; i != leading_size; ++i) { + for (int32_t k = 0; k != n; ++k) { + T *dst = ans[k].GetTensorMutableData() + i * trailing_size; + std::copy(src, src + trailing_size, dst); + src += trailing_size; + } + } + + return std::move(ans); +} + +template std::vector Unbind(OrtAllocator *allocator, + const Ort::Value *value, + int32_t dim); + +template std::vector Unbind(OrtAllocator *allocator, + const Ort::Value *value, + int32_t dim); + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/unbind.h b/sherpa-onnx/csrc/unbind.h new file mode 100644 index 00000000..9c799120 --- /dev/null +++ b/sherpa-onnx/csrc/unbind.h @@ -0,0 +1,28 @@ +// sherpa-onnx/csrc/unbind.h +// +// Copyright (c) 2023 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_UNBIND_H_ +#define SHERPA_ONNX_CSRC_UNBIND_H_ + +#include + +#include "onnxruntime_cxx_api.h" // NOLINT + +namespace sherpa_onnx { + +/** It is similar to torch.unbind() but we keep the unbind dim to 1 in + * the output + * + * @param allocator Allocator to allocate space for the returned tensor + * @param value The tensor to unbind + * @param dim The dim along which to unbind the tensor + * + * @return Return a list of tensors + */ +template +std::vector Unbind(OrtAllocator *allocator, const Ort::Value *value, + int32_t dim); + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_UNBIND_H_