Add Streaming zipformer (#50)
This commit is contained in:
7
.github/workflows/run-python-test.yaml
vendored
7
.github/workflows/run-python-test.yaml
vendored
@@ -33,8 +33,13 @@ jobs:
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
os: [ubuntu-latest, macos-latest, windows-latest]
|
||||
os: [ubuntu-latest, macos-latest] # windows-latest]
|
||||
python-version: ["3.7", "3.8", "3.9", "3.10"]
|
||||
exclude:
|
||||
- os: macos-latest
|
||||
python-version: "3.9"
|
||||
- os: macos-latest
|
||||
python-version: "3.10"
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
|
||||
1
.gitignore
vendored
1
.gitignore
vendored
@@ -8,3 +8,4 @@ sherpa-onnx-*
|
||||
__pycache__
|
||||
dist/
|
||||
sherpa_onnx.egg-info/
|
||||
.DS_Store
|
||||
|
||||
@@ -62,6 +62,7 @@ endif()
|
||||
|
||||
if(SHERPA_ONNX_ENABLE_TESTS)
|
||||
enable_testing()
|
||||
include(googletest)
|
||||
endif()
|
||||
|
||||
add_subdirectory(sherpa-onnx)
|
||||
|
||||
@@ -1,18 +1,3 @@
|
||||
# Copyright 2020 Fangjun Kuang (csukuangfj@gmail.com)
|
||||
# See ../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
function(download_googltest)
|
||||
include(FetchContent)
|
||||
|
||||
@@ -26,6 +11,7 @@ function(download_googltest)
|
||||
${PROJECT_SOURCE_DIR}/googletest-1.13.0.tar.gz
|
||||
${PROJECT_BINARY_DIR}/googletest-1.13.0.tar.gz
|
||||
/tmp/googletest-1.13.0.tar.gz
|
||||
/star-fj/fangjun/download/github/googletest-1.13.0.tar.gz
|
||||
)
|
||||
|
||||
foreach(f IN LISTS possible_file_locations)
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
include_directories(${CMAKE_SOURCE_DIR})
|
||||
|
||||
add_library(sherpa-onnx-core
|
||||
cat.cc
|
||||
features.cc
|
||||
online-lstm-transducer-model.cc
|
||||
online-recognizer.cc
|
||||
@@ -8,8 +9,11 @@ add_library(sherpa-onnx-core
|
||||
online-transducer-greedy-search-decoder.cc
|
||||
online-transducer-model-config.cc
|
||||
online-transducer-model.cc
|
||||
online-zipformer-transducer-model.cc
|
||||
onnx-utils.cc
|
||||
symbol-table.cc
|
||||
text-utils.cc
|
||||
unbind.cc
|
||||
wave-reader.cc
|
||||
)
|
||||
|
||||
@@ -27,3 +31,32 @@ endif()
|
||||
|
||||
install(TARGETS sherpa-onnx-core DESTINATION lib)
|
||||
install(TARGETS sherpa-onnx DESTINATION bin)
|
||||
|
||||
if(SHERPA_ONNX_ENABLE_TESTS)
|
||||
set(sherpa_onnx_test_srcs
|
||||
cat-test.cc
|
||||
unbind-test.cc
|
||||
)
|
||||
|
||||
function(sherpa_onnx_add_test source)
|
||||
get_filename_component(name ${source} NAME_WE)
|
||||
set(target_name ${name})
|
||||
add_executable(${target_name} "${source}")
|
||||
|
||||
target_link_libraries(${target_name}
|
||||
PRIVATE
|
||||
gtest
|
||||
gtest_main
|
||||
sherpa-onnx-core
|
||||
)
|
||||
|
||||
add_test(NAME "${target_name}"
|
||||
COMMAND
|
||||
$<TARGET_FILE:${target_name}>
|
||||
)
|
||||
endfunction()
|
||||
|
||||
foreach(source IN LISTS sherpa_onnx_test_srcs)
|
||||
sherpa_onnx_add_test(${source})
|
||||
endforeach()
|
||||
endif()
|
||||
|
||||
254
sherpa-onnx/csrc/cat-test.cc
Normal file
254
sherpa-onnx/csrc/cat-test.cc
Normal file
@@ -0,0 +1,254 @@
|
||||
// sherpa-onnx/csrc/cat-test.cc
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
|
||||
#include "sherpa-onnx/csrc/cat.h"
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
TEST(Cat, Test1DTensors) {
|
||||
Ort::AllocatorWithDefaultOptions allocator;
|
||||
|
||||
std::array<int64_t, 1> a_shape{3};
|
||||
std::array<int64_t, 1> b_shape{6};
|
||||
|
||||
Ort::Value a = Ort::Value::CreateTensor<float>(allocator, a_shape.data(),
|
||||
a_shape.size());
|
||||
|
||||
Ort::Value b = Ort::Value::CreateTensor<float>(allocator, b_shape.data(),
|
||||
b_shape.size());
|
||||
float *pa = a.GetTensorMutableData<float>();
|
||||
float *pb = b.GetTensorMutableData<float>();
|
||||
for (int32_t i = 0; i != static_cast<int32_t>(a_shape[0]); ++i) {
|
||||
pa[i] = i;
|
||||
}
|
||||
for (int32_t i = 0; i != static_cast<int32_t>(b_shape[0]); ++i) {
|
||||
pb[i] = i + 10;
|
||||
}
|
||||
|
||||
Ort::Value ans = Cat(allocator, {&a, &b}, 0);
|
||||
|
||||
const float *pans = ans.GetTensorData<float>();
|
||||
for (int32_t i = 0; i != static_cast<int32_t>(a_shape[0]); ++i) {
|
||||
EXPECT_EQ(pa[i], pans[i]);
|
||||
}
|
||||
|
||||
for (int32_t i = 0; i != static_cast<int32_t>(b_shape[0]); ++i) {
|
||||
EXPECT_EQ(pb[i], pans[i + a_shape[0]]);
|
||||
}
|
||||
|
||||
Print1D(&a);
|
||||
Print1D(&b);
|
||||
Print1D(&ans);
|
||||
}
|
||||
|
||||
TEST(Cat, Test2DTensorsDim0) {
|
||||
Ort::AllocatorWithDefaultOptions allocator;
|
||||
|
||||
std::array<int64_t, 2> a_shape{2, 3};
|
||||
std::array<int64_t, 2> b_shape{4, 3};
|
||||
|
||||
Ort::Value a = Ort::Value::CreateTensor<float>(allocator, a_shape.data(),
|
||||
a_shape.size());
|
||||
|
||||
Ort::Value b = Ort::Value::CreateTensor<float>(allocator, b_shape.data(),
|
||||
b_shape.size());
|
||||
|
||||
float *pa = a.GetTensorMutableData<float>();
|
||||
float *pb = b.GetTensorMutableData<float>();
|
||||
for (int32_t i = 0; i != static_cast<int32_t>(a_shape[0] * a_shape[1]); ++i) {
|
||||
pa[i] = i;
|
||||
}
|
||||
for (int32_t i = 0; i != static_cast<int32_t>(b_shape[0] * b_shape[1]); ++i) {
|
||||
pb[i] = i + 10;
|
||||
}
|
||||
|
||||
Ort::Value ans = Cat(allocator, {&a, &b}, 0);
|
||||
|
||||
const float *pans = ans.GetTensorData<float>();
|
||||
for (int32_t i = 0; i != static_cast<int32_t>(a_shape[0] * a_shape[1]); ++i) {
|
||||
EXPECT_EQ(pa[i], pans[i]);
|
||||
}
|
||||
for (int32_t i = 0; i != static_cast<int32_t>(b_shape[0] * b_shape[1]); ++i) {
|
||||
EXPECT_EQ(pb[i], pans[i + a_shape[0] * a_shape[1]]);
|
||||
}
|
||||
|
||||
Print2D(&a);
|
||||
Print2D(&b);
|
||||
Print2D(&ans);
|
||||
}
|
||||
|
||||
TEST(Cat, Test2DTensorsDim1) {
|
||||
Ort::AllocatorWithDefaultOptions allocator;
|
||||
|
||||
std::array<int64_t, 2> a_shape{4, 3};
|
||||
std::array<int64_t, 2> b_shape{4, 2};
|
||||
|
||||
Ort::Value a = Ort::Value::CreateTensor<float>(allocator, a_shape.data(),
|
||||
a_shape.size());
|
||||
|
||||
Ort::Value b = Ort::Value::CreateTensor<float>(allocator, b_shape.data(),
|
||||
b_shape.size());
|
||||
|
||||
float *pa = a.GetTensorMutableData<float>();
|
||||
float *pb = b.GetTensorMutableData<float>();
|
||||
for (int32_t i = 0; i != static_cast<int32_t>(a_shape[0] * a_shape[1]); ++i) {
|
||||
pa[i] = i;
|
||||
}
|
||||
for (int32_t i = 0; i != static_cast<int32_t>(b_shape[0] * b_shape[1]); ++i) {
|
||||
pb[i] = i + 10;
|
||||
}
|
||||
|
||||
Ort::Value ans = Cat(allocator, {&a, &b}, 1);
|
||||
|
||||
const float *pans = ans.GetTensorData<float>();
|
||||
|
||||
for (int32_t r = 0; r != static_cast<int32_t>(a_shape[0]); ++r) {
|
||||
for (int32_t i = 0; i != static_cast<int32_t>(a_shape[1]);
|
||||
++i, ++pa, ++pans) {
|
||||
EXPECT_EQ(*pa, *pans);
|
||||
}
|
||||
|
||||
for (int32_t i = 0; i != static_cast<int32_t>(b_shape[1]);
|
||||
++i, ++pb, ++pans) {
|
||||
EXPECT_EQ(*pb, *pans);
|
||||
}
|
||||
}
|
||||
|
||||
Print2D(&a);
|
||||
Print2D(&b);
|
||||
Print2D(&ans);
|
||||
}
|
||||
|
||||
TEST(Cat, Test3DTensorsDim0) {
|
||||
Ort::AllocatorWithDefaultOptions allocator;
|
||||
|
||||
std::array<int64_t, 3> a_shape{2, 3, 2};
|
||||
std::array<int64_t, 3> b_shape{4, 3, 2};
|
||||
|
||||
Ort::Value a = Ort::Value::CreateTensor<float>(allocator, a_shape.data(),
|
||||
a_shape.size());
|
||||
|
||||
Ort::Value b = Ort::Value::CreateTensor<float>(allocator, b_shape.data(),
|
||||
b_shape.size());
|
||||
|
||||
float *pa = a.GetTensorMutableData<float>();
|
||||
float *pb = b.GetTensorMutableData<float>();
|
||||
for (int32_t i = 0;
|
||||
i != static_cast<int32_t>(a_shape[0] * a_shape[1] * a_shape[2]); ++i) {
|
||||
pa[i] = i;
|
||||
}
|
||||
for (int32_t i = 0;
|
||||
i != static_cast<int32_t>(b_shape[0] * b_shape[1] * b_shape[2]); ++i) {
|
||||
pb[i] = i + 10;
|
||||
}
|
||||
|
||||
Ort::Value ans = Cat(allocator, {&a, &b}, 0);
|
||||
|
||||
const float *pans = ans.GetTensorData<float>();
|
||||
for (int32_t i = 0;
|
||||
i != static_cast<int32_t>(a_shape[0] * a_shape[1] * a_shape[2]); ++i) {
|
||||
EXPECT_EQ(pa[i], pans[i]);
|
||||
}
|
||||
for (int32_t i = 0;
|
||||
i != static_cast<int32_t>(b_shape[0] * b_shape[1] * b_shape[2]); ++i) {
|
||||
EXPECT_EQ(pb[i], pans[i + a_shape[0] * a_shape[1] * a_shape[2]]);
|
||||
}
|
||||
|
||||
Print3D(&a);
|
||||
Print3D(&b);
|
||||
Print3D(&ans);
|
||||
}
|
||||
|
||||
TEST(Cat, Test3DTensorsDim1) {
|
||||
Ort::AllocatorWithDefaultOptions allocator;
|
||||
|
||||
std::array<int64_t, 3> a_shape{2, 2, 3};
|
||||
std::array<int64_t, 3> b_shape{2, 4, 3};
|
||||
|
||||
Ort::Value a = Ort::Value::CreateTensor<float>(allocator, a_shape.data(),
|
||||
a_shape.size());
|
||||
|
||||
Ort::Value b = Ort::Value::CreateTensor<float>(allocator, b_shape.data(),
|
||||
b_shape.size());
|
||||
|
||||
float *pa = a.GetTensorMutableData<float>();
|
||||
float *pb = b.GetTensorMutableData<float>();
|
||||
for (int32_t i = 0;
|
||||
i != static_cast<int32_t>(a_shape[0] * a_shape[1] * a_shape[2]); ++i) {
|
||||
pa[i] = i;
|
||||
}
|
||||
for (int32_t i = 0;
|
||||
i != static_cast<int32_t>(b_shape[0] * b_shape[1] * b_shape[2]); ++i) {
|
||||
pb[i] = i + 10;
|
||||
}
|
||||
|
||||
Ort::Value ans = Cat(allocator, {&a, &b}, 1);
|
||||
|
||||
const float *pans = ans.GetTensorData<float>();
|
||||
|
||||
for (int32_t i = 0; i != static_cast<int32_t>(a_shape[0]); ++i) {
|
||||
for (int32_t k = 0; k != static_cast<int32_t>(a_shape[1] * a_shape[2]);
|
||||
++k, ++pa, ++pans) {
|
||||
EXPECT_EQ(*pa, *pans);
|
||||
}
|
||||
|
||||
for (int32_t k = 0; k != static_cast<int32_t>(b_shape[1] * b_shape[2]);
|
||||
++k, ++pb, ++pans) {
|
||||
EXPECT_EQ(*pb, *pans);
|
||||
}
|
||||
}
|
||||
|
||||
Print3D(&a);
|
||||
Print3D(&b);
|
||||
Print3D(&ans);
|
||||
}
|
||||
|
||||
TEST(Cat, Test3DTensorsDim2) {
|
||||
Ort::AllocatorWithDefaultOptions allocator;
|
||||
|
||||
std::array<int64_t, 3> a_shape{2, 3, 4};
|
||||
std::array<int64_t, 3> b_shape{2, 3, 5};
|
||||
|
||||
Ort::Value a = Ort::Value::CreateTensor<float>(allocator, a_shape.data(),
|
||||
a_shape.size());
|
||||
|
||||
Ort::Value b = Ort::Value::CreateTensor<float>(allocator, b_shape.data(),
|
||||
b_shape.size());
|
||||
|
||||
float *pa = a.GetTensorMutableData<float>();
|
||||
float *pb = b.GetTensorMutableData<float>();
|
||||
for (int32_t i = 0;
|
||||
i != static_cast<int32_t>(a_shape[0] * a_shape[1] * a_shape[2]); ++i) {
|
||||
pa[i] = i;
|
||||
}
|
||||
for (int32_t i = 0;
|
||||
i != static_cast<int32_t>(b_shape[0] * b_shape[1] * b_shape[2]); ++i) {
|
||||
pb[i] = i + 10;
|
||||
}
|
||||
|
||||
Ort::Value ans = Cat(allocator, {&a, &b}, 2);
|
||||
|
||||
const float *pans = ans.GetTensorData<float>();
|
||||
|
||||
for (int32_t i = 0; i != static_cast<int32_t>(a_shape[0] * a_shape[1]); ++i) {
|
||||
for (int32_t k = 0; k != static_cast<int32_t>(a_shape[2]);
|
||||
++k, ++pa, ++pans) {
|
||||
EXPECT_EQ(*pa, *pans);
|
||||
}
|
||||
|
||||
for (int32_t k = 0; k != static_cast<int32_t>(b_shape[2]);
|
||||
++k, ++pb, ++pans) {
|
||||
EXPECT_EQ(*pb, *pans);
|
||||
}
|
||||
}
|
||||
|
||||
Print3D(&a);
|
||||
Print3D(&b);
|
||||
Print3D(&ans);
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
106
sherpa-onnx/csrc/cat.cc
Normal file
106
sherpa-onnx/csrc/cat.cc
Normal file
@@ -0,0 +1,106 @@
|
||||
// sherpa-onnx/csrc/cat.cc
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
|
||||
#include "sherpa-onnx/csrc/cat.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <functional>
|
||||
#include <numeric>
|
||||
#include <utility>
|
||||
|
||||
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
static bool Compare(const std::vector<int64_t> &a,
|
||||
const std::vector<int64_t> &b, int32_t skip_dim) {
|
||||
if (a.size() != b.size()) return false;
|
||||
|
||||
for (int32_t i = 0; i != static_cast<int32_t>(a.size()); ++i) {
|
||||
if (i == skip_dim) continue;
|
||||
|
||||
if (a[i] != b[i]) return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
static void PrintShape(const std::vector<int64_t> &a) {
|
||||
for (auto i : a) {
|
||||
fprintf(stderr, "%d ", static_cast<int32_t>(i));
|
||||
}
|
||||
fprintf(stderr, "\n");
|
||||
}
|
||||
|
||||
template <typename T /*=float*/>
|
||||
Ort::Value Cat(OrtAllocator *allocator,
|
||||
const std::vector<const Ort::Value *> &values, int32_t dim) {
|
||||
if (values.size() == 1u) {
|
||||
return Clone(values[0]);
|
||||
}
|
||||
|
||||
std::vector<int64_t> v0_shape =
|
||||
values[0]->GetTensorTypeAndShapeInfo().GetShape();
|
||||
|
||||
int64_t total_dim = v0_shape[dim];
|
||||
|
||||
for (int32_t i = 1; i != static_cast<int32_t>(values.size()); ++i) {
|
||||
auto s = values[i]->GetTensorTypeAndShapeInfo().GetShape();
|
||||
total_dim += s[dim];
|
||||
|
||||
bool ret = Compare(v0_shape, s, dim);
|
||||
if (!ret) {
|
||||
fprintf(stderr, "Incorrect shape in Cat !\n");
|
||||
|
||||
fprintf(stderr, "Shape for tensor 0: ");
|
||||
PrintShape(v0_shape);
|
||||
|
||||
fprintf(stderr, "Shape for tensor %d: ", i);
|
||||
PrintShape(s);
|
||||
|
||||
exit(-1);
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<int64_t> ans_shape;
|
||||
ans_shape.reserve(v0_shape.size());
|
||||
ans_shape.insert(ans_shape.end(), v0_shape.data(), v0_shape.data() + dim);
|
||||
ans_shape.push_back(total_dim);
|
||||
ans_shape.insert(ans_shape.end(), v0_shape.data() + dim + 1,
|
||||
v0_shape.data() + v0_shape.size());
|
||||
|
||||
auto leading_size = static_cast<int32_t>(std::accumulate(
|
||||
v0_shape.begin(), v0_shape.begin() + dim, 1, std::multiplies<int64_t>()));
|
||||
|
||||
auto trailing_size = static_cast<int32_t>(
|
||||
std::accumulate(v0_shape.begin() + dim + 1, v0_shape.end(), 1,
|
||||
std::multiplies<int64_t>()));
|
||||
|
||||
Ort::Value ans = Ort::Value::CreateTensor<T>(allocator, ans_shape.data(),
|
||||
ans_shape.size());
|
||||
T *dst = ans.GetTensorMutableData<T>();
|
||||
|
||||
for (int32_t i = 0; i != leading_size; ++i) {
|
||||
for (int32_t n = 0; n != static_cast<int32_t>(values.size()); ++n) {
|
||||
auto this_dim = values[n]->GetTensorTypeAndShapeInfo().GetShape()[dim];
|
||||
const T *src = values[n]->GetTensorData<T>();
|
||||
src += i * this_dim * trailing_size;
|
||||
|
||||
std::copy(src, src + this_dim * trailing_size, dst);
|
||||
dst += this_dim * trailing_size;
|
||||
}
|
||||
}
|
||||
|
||||
return std::move(ans);
|
||||
}
|
||||
|
||||
template Ort::Value Cat<float>(OrtAllocator *allocator,
|
||||
const std::vector<const Ort::Value *> &values,
|
||||
int32_t dim);
|
||||
|
||||
template Ort::Value Cat<int64_t>(OrtAllocator *allocator,
|
||||
const std::vector<const Ort::Value *> &values,
|
||||
int32_t dim);
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
28
sherpa-onnx/csrc/cat.h
Normal file
28
sherpa-onnx/csrc/cat.h
Normal file
@@ -0,0 +1,28 @@
|
||||
// sherpa-onnx/csrc/cat.h
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
#ifndef SHERPA_ONNX_CSRC_CAT_H_
|
||||
#define SHERPA_ONNX_CSRC_CAT_H_
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
/** Cat a list of tensors along the given dim.
|
||||
*
|
||||
* @param allocator Allocator to allocate space for the returned tensor
|
||||
* @param values Pointer to a list of tensors. The shape of the tensor must
|
||||
* be the same except on the dim to be concatenated.
|
||||
* @param dim The dim along which to concatenate the input tensors
|
||||
*
|
||||
* @return Return the concatenated tensor
|
||||
*/
|
||||
template <typename T = float>
|
||||
Ort::Value Cat(OrtAllocator *allocator,
|
||||
const std::vector<const Ort::Value *> &values, int32_t dim);
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_CAT_H_
|
||||
41
sherpa-onnx/csrc/macros.h
Normal file
41
sherpa-onnx/csrc/macros.h
Normal file
@@ -0,0 +1,41 @@
|
||||
|
||||
// sherpa-onnx/csrc/macros.h
|
||||
//
|
||||
// Copyright 2023 Xiaomi Corporation
|
||||
|
||||
#ifndef SHERPA_ONNX_CSRC_MACROS_H_
|
||||
#define SHERPA_ONNX_CSRC_MACROS_H_
|
||||
|
||||
#define SHERPA_ONNX_READ_META_DATA(dst, src_key) \
|
||||
do { \
|
||||
auto value = \
|
||||
meta_data.LookupCustomMetadataMapAllocated(src_key, allocator); \
|
||||
if (!value) { \
|
||||
fprintf(stderr, "%s does not exist in the metadata\n", src_key); \
|
||||
exit(-1); \
|
||||
} \
|
||||
\
|
||||
dst = atoi(value.get()); \
|
||||
if (dst <= 0) { \
|
||||
fprintf(stderr, "Invalid value %d for %s\n", dst, src_key); \
|
||||
exit(-1); \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
#define SHERPA_ONNX_READ_META_DATA_VEC(dst, src_key) \
|
||||
do { \
|
||||
auto value = \
|
||||
meta_data.LookupCustomMetadataMapAllocated(src_key, allocator); \
|
||||
if (!value) { \
|
||||
fprintf(stderr, "%s does not exist in the metadata\n", src_key); \
|
||||
exit(-1); \
|
||||
} \
|
||||
\
|
||||
bool ret = SplitStringToIntegers(value.get(), ",", true, &dst); \
|
||||
if (!ret) { \
|
||||
fprintf(stderr, "Invalid value %s for %s\n", value.get(), src_key); \
|
||||
exit(-1); \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_MACROS_H_
|
||||
@@ -3,6 +3,8 @@
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
#include "sherpa-onnx/csrc/online-lstm-transducer-model.h"
|
||||
|
||||
#include <assert.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include <sstream>
|
||||
@@ -11,23 +13,11 @@
|
||||
#include <vector>
|
||||
|
||||
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||
#include "sherpa-onnx/csrc/cat.h"
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/csrc/online-transducer-decoder.h"
|
||||
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||
|
||||
#define SHERPA_ONNX_READ_META_DATA(dst, src_key) \
|
||||
do { \
|
||||
auto value = \
|
||||
meta_data.LookupCustomMetadataMapAllocated(src_key, allocator); \
|
||||
if (!value) { \
|
||||
fprintf(stderr, "%s does not exist in the metadata\n", src_key); \
|
||||
exit(-1); \
|
||||
} \
|
||||
dst = atoi(value.get()); \
|
||||
if (dst <= 0) { \
|
||||
fprintf(stderr, "Invalud value %d for %s\n", dst, src_key); \
|
||||
exit(-1); \
|
||||
} \
|
||||
} while (0)
|
||||
#include "sherpa-onnx/csrc/unbind.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
@@ -64,7 +54,7 @@ void OnlineLstmTransducerModel::InitEncoder(const std::string &filename) {
|
||||
fprintf(stderr, "%s\n", os.str().c_str());
|
||||
}
|
||||
|
||||
Ort::AllocatorWithDefaultOptions allocator;
|
||||
Ort::AllocatorWithDefaultOptions allocator; // used in the macro below
|
||||
SHERPA_ONNX_READ_META_DATA(num_encoder_layers_, "num_encoder_layers");
|
||||
SHERPA_ONNX_READ_META_DATA(T_, "T");
|
||||
SHERPA_ONNX_READ_META_DATA(decode_chunk_len_, "decode_chunk_len");
|
||||
@@ -91,7 +81,7 @@ void OnlineLstmTransducerModel::InitDecoder(const std::string &filename) {
|
||||
fprintf(stderr, "%s\n", os.str().c_str());
|
||||
}
|
||||
|
||||
Ort::AllocatorWithDefaultOptions allocator;
|
||||
Ort::AllocatorWithDefaultOptions allocator; // used in the macro below
|
||||
SHERPA_ONNX_READ_META_DATA(vocab_size_, "vocab_size");
|
||||
SHERPA_ONNX_READ_META_DATA(context_size_, "context_size");
|
||||
}
|
||||
@@ -120,37 +110,19 @@ std::vector<Ort::Value> OnlineLstmTransducerModel::StackStates(
|
||||
const std::vector<std::vector<Ort::Value>> &states) const {
|
||||
int32_t batch_size = static_cast<int32_t>(states.size());
|
||||
|
||||
std::array<int64_t, 3> h_shape{num_encoder_layers_, batch_size, d_model_};
|
||||
Ort::Value h = Ort::Value::CreateTensor<float>(allocator_, h_shape.data(),
|
||||
h_shape.size());
|
||||
std::vector<const Ort::Value *> h_buf(batch_size);
|
||||
std::vector<const Ort::Value *> c_buf(batch_size);
|
||||
|
||||
std::array<int64_t, 3> c_shape{num_encoder_layers_, batch_size,
|
||||
rnn_hidden_size_};
|
||||
|
||||
Ort::Value c = Ort::Value::CreateTensor<float>(allocator_, c_shape.data(),
|
||||
c_shape.size());
|
||||
|
||||
float *dst_h = h.GetTensorMutableData<float>();
|
||||
float *dst_c = c.GetTensorMutableData<float>();
|
||||
|
||||
for (int32_t layer = 0; layer != num_encoder_layers_; ++layer) {
|
||||
for (int32_t i = 0; i != batch_size; ++i) {
|
||||
const float *src_h =
|
||||
states[i][0].GetTensorData<float>() + layer * d_model_;
|
||||
|
||||
const float *src_c =
|
||||
states[i][1].GetTensorData<float>() + layer * rnn_hidden_size_;
|
||||
|
||||
std::copy(src_h, src_h + d_model_, dst_h);
|
||||
std::copy(src_c, src_c + rnn_hidden_size_, dst_c);
|
||||
|
||||
dst_h += d_model_;
|
||||
dst_c += rnn_hidden_size_;
|
||||
}
|
||||
for (int32_t i = 0; i != batch_size; ++i) {
|
||||
assert(states[i].size() == 2);
|
||||
h_buf[i] = &states[i][0];
|
||||
c_buf[i] = &states[i][1];
|
||||
}
|
||||
|
||||
std::vector<Ort::Value> ans;
|
||||
Ort::Value h = Cat(allocator_, h_buf, 1);
|
||||
Ort::Value c = Cat(allocator_, c_buf, 1);
|
||||
|
||||
std::vector<Ort::Value> ans;
|
||||
ans.reserve(2);
|
||||
ans.push_back(std::move(h));
|
||||
ans.push_back(std::move(c));
|
||||
@@ -161,37 +133,19 @@ std::vector<Ort::Value> OnlineLstmTransducerModel::StackStates(
|
||||
std::vector<std::vector<Ort::Value>> OnlineLstmTransducerModel::UnStackStates(
|
||||
const std::vector<Ort::Value> &states) const {
|
||||
int32_t batch_size = states[0].GetTensorTypeAndShapeInfo().GetShape()[1];
|
||||
assert(states.size() == 2);
|
||||
|
||||
std::vector<std::vector<Ort::Value>> ans(batch_size);
|
||||
|
||||
// allocate space
|
||||
std::array<int64_t, 3> h_shape{num_encoder_layers_, 1, d_model_};
|
||||
std::array<int64_t, 3> c_shape{num_encoder_layers_, 1, rnn_hidden_size_};
|
||||
std::vector<Ort::Value> h_vec = Unbind(allocator_, &states[0], 1);
|
||||
std::vector<Ort::Value> c_vec = Unbind(allocator_, &states[1], 1);
|
||||
|
||||
assert(h_vec.size() == batch_size);
|
||||
assert(c_vec.size() == batch_size);
|
||||
|
||||
for (int32_t i = 0; i != batch_size; ++i) {
|
||||
Ort::Value h = Ort::Value::CreateTensor<float>(allocator_, h_shape.data(),
|
||||
h_shape.size());
|
||||
Ort::Value c = Ort::Value::CreateTensor<float>(allocator_, c_shape.data(),
|
||||
c_shape.size());
|
||||
ans[i].push_back(std::move(h));
|
||||
ans[i].push_back(std::move(c));
|
||||
}
|
||||
|
||||
for (int32_t layer = 0; layer != num_encoder_layers_; ++layer) {
|
||||
for (int32_t i = 0; i != batch_size; ++i) {
|
||||
const float *src_h = states[0].GetTensorData<float>() +
|
||||
layer * batch_size * d_model_ + i * d_model_;
|
||||
const float *src_c = states[1].GetTensorData<float>() +
|
||||
layer * batch_size * rnn_hidden_size_ +
|
||||
i * rnn_hidden_size_;
|
||||
|
||||
float *dst_h = ans[i][0].GetTensorMutableData<float>() + layer * d_model_;
|
||||
float *dst_c =
|
||||
ans[i][1].GetTensorMutableData<float>() + layer * rnn_hidden_size_;
|
||||
|
||||
std::copy(src_h, src_h + d_model_, dst_h);
|
||||
std::copy(src_c, src_c + rnn_hidden_size_, dst_c);
|
||||
}
|
||||
ans[i].push_back(std::move(h_vec[i]));
|
||||
ans[i].push_back(std::move(c_vec[i]));
|
||||
}
|
||||
|
||||
return ans;
|
||||
@@ -206,20 +160,15 @@ std::vector<Ort::Value> OnlineLstmTransducerModel::GetEncoderInitStates() {
|
||||
Ort::Value h = Ort::Value::CreateTensor<float>(allocator_, h_shape.data(),
|
||||
h_shape.size());
|
||||
|
||||
std::fill(h.GetTensorMutableData<float>(),
|
||||
h.GetTensorMutableData<float>() +
|
||||
num_encoder_layers_ * kBatchSize * d_model_,
|
||||
0);
|
||||
Fill<float>(&h, 0);
|
||||
|
||||
std::array<int64_t, 3> c_shape{num_encoder_layers_, kBatchSize,
|
||||
rnn_hidden_size_};
|
||||
|
||||
Ort::Value c = Ort::Value::CreateTensor<float>(allocator_, c_shape.data(),
|
||||
c_shape.size());
|
||||
|
||||
std::fill(c.GetTensorMutableData<float>(),
|
||||
c.GetTensorMutableData<float>() +
|
||||
num_encoder_layers_ * kBatchSize * rnn_hidden_size_,
|
||||
0);
|
||||
Fill<float>(&c, 0);
|
||||
|
||||
std::vector<Ort::Value> states;
|
||||
|
||||
|
||||
@@ -8,11 +8,13 @@
|
||||
#include <string>
|
||||
|
||||
#include "sherpa-onnx/csrc/online-lstm-transducer-model.h"
|
||||
#include "sherpa-onnx/csrc/online-zipformer-transducer-model.h"
|
||||
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||
namespace sherpa_onnx {
|
||||
|
||||
enum class ModelType {
|
||||
kLstm,
|
||||
kZipformer,
|
||||
kUnkown,
|
||||
};
|
||||
|
||||
@@ -40,6 +42,8 @@ static ModelType GetModelType(const OnlineTransducerModelConfig &config) {
|
||||
|
||||
if (model_type.get() == std::string("lstm")) {
|
||||
return ModelType::kLstm;
|
||||
} else if (model_type.get() == std::string("zipformer")) {
|
||||
return ModelType::kZipformer;
|
||||
} else {
|
||||
fprintf(stderr, "Unsupported model_type: %s\n", model_type.get());
|
||||
return ModelType::kUnkown;
|
||||
@@ -53,6 +57,8 @@ std::unique_ptr<OnlineTransducerModel> OnlineTransducerModel::Create(
|
||||
switch (model_type) {
|
||||
case ModelType::kLstm:
|
||||
return std::make_unique<OnlineLstmTransducerModel>(config);
|
||||
case ModelType::kZipformer:
|
||||
return std::make_unique<OnlineZipformerTransducerModel>(config);
|
||||
case ModelType::kUnkown:
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
456
sherpa-onnx/csrc/online-zipformer-transducer-model.cc
Normal file
456
sherpa-onnx/csrc/online-zipformer-transducer-model.cc
Normal file
@@ -0,0 +1,456 @@
|
||||
// sherpa-onnx/csrc/online-zipformer-transducer-model.cc
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
|
||||
#include "sherpa-onnx/csrc/online-zipformer-transducer-model.h"
|
||||
|
||||
#include <assert.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||
#include "sherpa-onnx/csrc/cat.h"
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/csrc/online-transducer-decoder.h"
|
||||
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||
#include "sherpa-onnx/csrc/text-utils.h"
|
||||
#include "sherpa-onnx/csrc/unbind.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
OnlineZipformerTransducerModel::OnlineZipformerTransducerModel(
|
||||
const OnlineTransducerModelConfig &config)
|
||||
: env_(ORT_LOGGING_LEVEL_WARNING),
|
||||
config_(config),
|
||||
sess_opts_{},
|
||||
allocator_{} {
|
||||
sess_opts_.SetIntraOpNumThreads(config.num_threads);
|
||||
sess_opts_.SetInterOpNumThreads(config.num_threads);
|
||||
|
||||
InitEncoder(config.encoder_filename);
|
||||
InitDecoder(config.decoder_filename);
|
||||
InitJoiner(config.joiner_filename);
|
||||
}
|
||||
|
||||
void OnlineZipformerTransducerModel::InitEncoder(const std::string &filename) {
|
||||
encoder_sess_ = std::make_unique<Ort::Session>(
|
||||
env_, SHERPA_MAYBE_WIDE(filename).c_str(), sess_opts_);
|
||||
|
||||
GetInputNames(encoder_sess_.get(), &encoder_input_names_,
|
||||
&encoder_input_names_ptr_);
|
||||
|
||||
GetOutputNames(encoder_sess_.get(), &encoder_output_names_,
|
||||
&encoder_output_names_ptr_);
|
||||
|
||||
// get meta data
|
||||
Ort::ModelMetadata meta_data = encoder_sess_->GetModelMetadata();
|
||||
if (config_.debug) {
|
||||
std::ostringstream os;
|
||||
os << "---encoder---\n";
|
||||
PrintModelMetadata(os, meta_data);
|
||||
fprintf(stderr, "%s\n", os.str().c_str());
|
||||
}
|
||||
|
||||
Ort::AllocatorWithDefaultOptions allocator; // used in the macro below
|
||||
SHERPA_ONNX_READ_META_DATA_VEC(encoder_dims_, "encoder_dims");
|
||||
SHERPA_ONNX_READ_META_DATA_VEC(attention_dims_, "attention_dims");
|
||||
SHERPA_ONNX_READ_META_DATA_VEC(num_encoder_layers_, "num_encoder_layers");
|
||||
SHERPA_ONNX_READ_META_DATA_VEC(cnn_module_kernels_, "cnn_module_kernels");
|
||||
SHERPA_ONNX_READ_META_DATA_VEC(left_context_len_, "left_context_len");
|
||||
|
||||
SHERPA_ONNX_READ_META_DATA(T_, "T");
|
||||
SHERPA_ONNX_READ_META_DATA(decode_chunk_len_, "decode_chunk_len");
|
||||
|
||||
if (config_.debug) {
|
||||
auto print = [](const std::vector<int32_t> &v, const char *name) {
|
||||
fprintf(stderr, "%s: ", name);
|
||||
for (auto i : v) {
|
||||
fprintf(stderr, "%d ", i);
|
||||
}
|
||||
fprintf(stderr, "\n");
|
||||
};
|
||||
print(encoder_dims_, "encoder_dims");
|
||||
print(attention_dims_, "attention_dims");
|
||||
print(num_encoder_layers_, "num_encoder_layers");
|
||||
print(cnn_module_kernels_, "cnn_module_kernels");
|
||||
print(left_context_len_, "left_context_len");
|
||||
fprintf(stderr, "T: %d\n", T_);
|
||||
fprintf(stderr, "decode_chunk_len_: %d\n", decode_chunk_len_);
|
||||
}
|
||||
}
|
||||
|
||||
void OnlineZipformerTransducerModel::InitDecoder(const std::string &filename) {
|
||||
decoder_sess_ = std::make_unique<Ort::Session>(
|
||||
env_, SHERPA_MAYBE_WIDE(filename).c_str(), sess_opts_);
|
||||
|
||||
GetInputNames(decoder_sess_.get(), &decoder_input_names_,
|
||||
&decoder_input_names_ptr_);
|
||||
|
||||
GetOutputNames(decoder_sess_.get(), &decoder_output_names_,
|
||||
&decoder_output_names_ptr_);
|
||||
|
||||
// get meta data
|
||||
Ort::ModelMetadata meta_data = decoder_sess_->GetModelMetadata();
|
||||
if (config_.debug) {
|
||||
std::ostringstream os;
|
||||
os << "---decoder---\n";
|
||||
PrintModelMetadata(os, meta_data);
|
||||
fprintf(stderr, "%s\n", os.str().c_str());
|
||||
}
|
||||
|
||||
Ort::AllocatorWithDefaultOptions allocator; // used in the macro below
|
||||
SHERPA_ONNX_READ_META_DATA(vocab_size_, "vocab_size");
|
||||
SHERPA_ONNX_READ_META_DATA(context_size_, "context_size");
|
||||
}
|
||||
|
||||
void OnlineZipformerTransducerModel::InitJoiner(const std::string &filename) {
|
||||
joiner_sess_ = std::make_unique<Ort::Session>(
|
||||
env_, SHERPA_MAYBE_WIDE(filename).c_str(), sess_opts_);
|
||||
|
||||
GetInputNames(joiner_sess_.get(), &joiner_input_names_,
|
||||
&joiner_input_names_ptr_);
|
||||
|
||||
GetOutputNames(joiner_sess_.get(), &joiner_output_names_,
|
||||
&joiner_output_names_ptr_);
|
||||
|
||||
// get meta data
|
||||
Ort::ModelMetadata meta_data = joiner_sess_->GetModelMetadata();
|
||||
if (config_.debug) {
|
||||
std::ostringstream os;
|
||||
os << "---joiner---\n";
|
||||
PrintModelMetadata(os, meta_data);
|
||||
fprintf(stderr, "%s\n", os.str().c_str());
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<Ort::Value> OnlineZipformerTransducerModel::StackStates(
|
||||
const std::vector<std::vector<Ort::Value>> &states) const {
|
||||
int32_t batch_size = static_cast<int32_t>(states.size());
|
||||
int32_t num_encoders = static_cast<int32_t>(num_encoder_layers_.size());
|
||||
|
||||
std::vector<const Ort::Value *> buf(batch_size);
|
||||
|
||||
std::vector<Ort::Value> ans;
|
||||
ans.reserve(states[0].size());
|
||||
|
||||
// cached_len
|
||||
for (int32_t i = 0; i != num_encoders; ++i) {
|
||||
for (int32_t n = 0; n != batch_size; ++n) {
|
||||
buf[n] = &states[n][i];
|
||||
}
|
||||
auto v = Cat<int64_t>(allocator_, buf, 1); // (num_layers, 1)
|
||||
ans.push_back(std::move(v));
|
||||
}
|
||||
|
||||
// cached_avg
|
||||
for (int32_t i = 0; i != num_encoders; ++i) {
|
||||
for (int32_t n = 0; n != batch_size; ++n) {
|
||||
buf[n] = &states[n][num_encoders + i];
|
||||
}
|
||||
auto v = Cat(allocator_, buf, 1); // (num_layers, 1, encoder_dims)
|
||||
ans.push_back(std::move(v));
|
||||
}
|
||||
|
||||
// cached_key
|
||||
for (int32_t i = 0; i != num_encoders; ++i) {
|
||||
for (int32_t n = 0; n != batch_size; ++n) {
|
||||
buf[n] = &states[n][num_encoders * 2 + i];
|
||||
}
|
||||
// (num_layers, left_context_len, 1, attention_dims)
|
||||
auto v = Cat(allocator_, buf, 2);
|
||||
ans.push_back(std::move(v));
|
||||
}
|
||||
|
||||
// cached_val
|
||||
for (int32_t i = 0; i != num_encoders; ++i) {
|
||||
for (int32_t n = 0; n != batch_size; ++n) {
|
||||
buf[n] = &states[n][num_encoders * 3 + i];
|
||||
}
|
||||
// (num_layers, left_context_len, 1, attention_dims/2)
|
||||
auto v = Cat(allocator_, buf, 2);
|
||||
ans.push_back(std::move(v));
|
||||
}
|
||||
|
||||
// cached_val2
|
||||
for (int32_t i = 0; i != num_encoders; ++i) {
|
||||
for (int32_t n = 0; n != batch_size; ++n) {
|
||||
buf[n] = &states[n][num_encoders * 4 + i];
|
||||
}
|
||||
// (num_layers, left_context_len, 1, attention_dims/2)
|
||||
auto v = Cat(allocator_, buf, 2);
|
||||
ans.push_back(std::move(v));
|
||||
}
|
||||
|
||||
// cached_conv1
|
||||
for (int32_t i = 0; i != num_encoders; ++i) {
|
||||
for (int32_t n = 0; n != batch_size; ++n) {
|
||||
buf[n] = &states[n][num_encoders * 5 + i];
|
||||
}
|
||||
// (num_layers, 1, encoder_dims, cnn_module_kernels-1)
|
||||
auto v = Cat(allocator_, buf, 1);
|
||||
ans.push_back(std::move(v));
|
||||
}
|
||||
|
||||
// cached_conv2
|
||||
for (int32_t i = 0; i != num_encoders; ++i) {
|
||||
for (int32_t n = 0; n != batch_size; ++n) {
|
||||
buf[n] = &states[n][num_encoders * 6 + i];
|
||||
}
|
||||
// (num_layers, 1, encoder_dims, cnn_module_kernels-1)
|
||||
auto v = Cat(allocator_, buf, 1);
|
||||
ans.push_back(std::move(v));
|
||||
}
|
||||
|
||||
return ans;
|
||||
}
|
||||
|
||||
std::vector<std::vector<Ort::Value>>
OnlineZipformerTransducerModel::UnStackStates(
    const std::vector<Ort::Value> &states) const {
  // Inverse of StackStates(): split batched state tensors back into
  // per-stream state lists.
  assert(states.size() == num_encoder_layers_.size() * 7);

  // The batch dimension of the stacked tensors is axis 1 for cached_len.
  int32_t batch_size =
      static_cast<int32_t>(states[0].GetTensorTypeAndShapeInfo().GetShape()[1]);
  int32_t num_encoders = static_cast<int32_t>(num_encoder_layers_.size());

  std::vector<std::vector<Ort::Value>> ans;
  ans.resize(batch_size);

  // Split every tensor of state group `group` along dimension `dim` and
  // append the per-stream slices to ans[n].  `type_tag` selects the
  // element type.
  auto unstack_group = [&](auto type_tag, int32_t group, int32_t dim) {
    using T = decltype(type_tag);
    int32_t begin = group * num_encoders;
    int32_t end = begin + num_encoders;
    for (int32_t i = begin; i != end; ++i) {
      auto v = Unbind<T>(allocator_, &states[i], dim);
      assert(static_cast<int32_t>(v.size()) == batch_size);

      for (int32_t n = 0; n != batch_size; ++n) {
        ans[n].push_back(std::move(v[n]));
      }
    }
  };

  unstack_group(int64_t{}, 0, 1);  // cached_len
  unstack_group(float{}, 1, 1);    // cached_avg
  unstack_group(float{}, 2, 2);    // cached_key
  unstack_group(float{}, 3, 2);    // cached_val
  unstack_group(float{}, 4, 2);    // cached_val2
  unstack_group(float{}, 5, 1);    // cached_conv1
  unstack_group(float{}, 6, 1);    // cached_conv2

  return ans;
}
|
||||
|
||||
std::vector<Ort::Value> OnlineZipformerTransducerModel::GetEncoderInitStates() {
|
||||
// Please see
|
||||
// https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer.py#L673
|
||||
// for details
|
||||
|
||||
int32_t n = static_cast<int32_t>(encoder_dims_.size());
|
||||
std::vector<Ort::Value> cached_len_vec;
|
||||
std::vector<Ort::Value> cached_avg_vec;
|
||||
std::vector<Ort::Value> cached_key_vec;
|
||||
std::vector<Ort::Value> cached_val_vec;
|
||||
std::vector<Ort::Value> cached_val2_vec;
|
||||
std::vector<Ort::Value> cached_conv1_vec;
|
||||
std::vector<Ort::Value> cached_conv2_vec;
|
||||
|
||||
cached_len_vec.reserve(n);
|
||||
cached_avg_vec.reserve(n);
|
||||
cached_key_vec.reserve(n);
|
||||
cached_val_vec.reserve(n);
|
||||
cached_val2_vec.reserve(n);
|
||||
cached_conv1_vec.reserve(n);
|
||||
cached_conv2_vec.reserve(n);
|
||||
|
||||
for (int32_t i = 0; i != n; ++i) {
|
||||
{
|
||||
std::array<int64_t, 2> s{num_encoder_layers_[i], 1};
|
||||
auto v =
|
||||
Ort::Value::CreateTensor<int64_t>(allocator_, s.data(), s.size());
|
||||
Fill<int64_t>(&v, 0);
|
||||
cached_len_vec.push_back(std::move(v));
|
||||
}
|
||||
|
||||
{
|
||||
std::array<int64_t, 3> s{num_encoder_layers_[i], 1, encoder_dims_[i]};
|
||||
auto v = Ort::Value::CreateTensor<float>(allocator_, s.data(), s.size());
|
||||
Fill(&v, 0);
|
||||
cached_avg_vec.push_back(std::move(v));
|
||||
}
|
||||
|
||||
{
|
||||
std::array<int64_t, 4> s{num_encoder_layers_[i], left_context_len_[i], 1,
|
||||
attention_dims_[i]};
|
||||
auto v = Ort::Value::CreateTensor<float>(allocator_, s.data(), s.size());
|
||||
Fill(&v, 0);
|
||||
cached_key_vec.push_back(std::move(v));
|
||||
}
|
||||
|
||||
{
|
||||
std::array<int64_t, 4> s{num_encoder_layers_[i], left_context_len_[i], 1,
|
||||
attention_dims_[i] / 2};
|
||||
auto v = Ort::Value::CreateTensor<float>(allocator_, s.data(), s.size());
|
||||
Fill(&v, 0);
|
||||
cached_val_vec.push_back(std::move(v));
|
||||
}
|
||||
|
||||
{
|
||||
std::array<int64_t, 4> s{num_encoder_layers_[i], left_context_len_[i], 1,
|
||||
attention_dims_[i] / 2};
|
||||
auto v = Ort::Value::CreateTensor<float>(allocator_, s.data(), s.size());
|
||||
Fill(&v, 0);
|
||||
cached_val2_vec.push_back(std::move(v));
|
||||
}
|
||||
|
||||
{
|
||||
std::array<int64_t, 4> s{num_encoder_layers_[i], 1, encoder_dims_[i],
|
||||
cnn_module_kernels_[i] - 1};
|
||||
auto v = Ort::Value::CreateTensor<float>(allocator_, s.data(), s.size());
|
||||
Fill(&v, 0);
|
||||
cached_conv1_vec.push_back(std::move(v));
|
||||
}
|
||||
|
||||
{
|
||||
std::array<int64_t, 4> s{num_encoder_layers_[i], 1, encoder_dims_[i],
|
||||
cnn_module_kernels_[i] - 1};
|
||||
auto v = Ort::Value::CreateTensor<float>(allocator_, s.data(), s.size());
|
||||
Fill(&v, 0);
|
||||
cached_conv2_vec.push_back(std::move(v));
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<Ort::Value> ans;
|
||||
ans.reserve(n * 7);
|
||||
|
||||
for (auto &v : cached_len_vec) ans.push_back(std::move(v));
|
||||
for (auto &v : cached_avg_vec) ans.push_back(std::move(v));
|
||||
for (auto &v : cached_key_vec) ans.push_back(std::move(v));
|
||||
for (auto &v : cached_val_vec) ans.push_back(std::move(v));
|
||||
for (auto &v : cached_val2_vec) ans.push_back(std::move(v));
|
||||
for (auto &v : cached_conv1_vec) ans.push_back(std::move(v));
|
||||
for (auto &v : cached_conv2_vec) ans.push_back(std::move(v));
|
||||
|
||||
return ans;
|
||||
}
|
||||
|
||||
std::pair<Ort::Value, std::vector<Ort::Value>>
|
||||
OnlineZipformerTransducerModel::RunEncoder(Ort::Value features,
|
||||
std::vector<Ort::Value> states) {
|
||||
auto memory_info =
|
||||
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
|
||||
|
||||
std::vector<Ort::Value> encoder_inputs;
|
||||
encoder_inputs.reserve(1 + states.size());
|
||||
|
||||
encoder_inputs.push_back(std::move(features));
|
||||
for (auto &v : states) {
|
||||
encoder_inputs.push_back(std::move(v));
|
||||
}
|
||||
|
||||
auto encoder_out = encoder_sess_->Run(
|
||||
{}, encoder_input_names_ptr_.data(), encoder_inputs.data(),
|
||||
encoder_inputs.size(), encoder_output_names_ptr_.data(),
|
||||
encoder_output_names_ptr_.size());
|
||||
|
||||
std::vector<Ort::Value> next_states;
|
||||
next_states.reserve(states.size());
|
||||
|
||||
for (int32_t i = 1; i != static_cast<int32_t>(encoder_out.size()); ++i) {
|
||||
next_states.push_back(std::move(encoder_out[i]));
|
||||
}
|
||||
|
||||
return {std::move(encoder_out[0]), std::move(next_states)};
|
||||
}
|
||||
|
||||
Ort::Value OnlineZipformerTransducerModel::BuildDecoderInput(
    const std::vector<OnlineTransducerDecoderResult> &results) {
  // Build a (batch_size, context_size) int64 tensor holding, for every
  // stream, the last context_size_ decoded tokens.
  int32_t batch_size = static_cast<int32_t>(results.size());
  std::array<int64_t, 2> shape{batch_size, context_size_};

  Ort::Value decoder_input =
      Ort::Value::CreateTensor<int64_t>(allocator_, shape.data(), shape.size());
  int64_t *dst = decoder_input.GetTensorMutableData<int64_t>();

  // NOTE(review): assumes r.tokens.size() >= context_size_ for every
  // result — confirm callers always seed the token history with blanks.
  for (const auto &r : results) {
    const int64_t *last = r.tokens.data() + r.tokens.size();
    std::copy(last - context_size_, last, dst);
    dst += context_size_;
  }

  return decoder_input;
}
|
||||
|
||||
Ort::Value OnlineZipformerTransducerModel::RunDecoder(
    Ort::Value decoder_input) {
  // Run the decoder (prediction) network and return its only output.
  auto outputs = decoder_sess_->Run(
      {}, decoder_input_names_ptr_.data(), &decoder_input, 1,
      decoder_output_names_ptr_.data(), decoder_output_names_ptr_.size());

  return std::move(outputs[0]);
}
|
||||
|
||||
Ort::Value OnlineZipformerTransducerModel::RunJoiner(Ort::Value encoder_out,
                                                     Ort::Value decoder_out) {
  // The joiner combines encoder and decoder outputs into logits.
  std::array<Ort::Value, 2> inputs = {std::move(encoder_out),
                                      std::move(decoder_out)};

  auto outputs = joiner_sess_->Run(
      {}, joiner_input_names_ptr_.data(), inputs.data(), inputs.size(),
      joiner_output_names_ptr_.data(), joiner_output_names_ptr_.size());

  return std::move(outputs[0]);
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
99
sherpa-onnx/csrc/online-zipformer-transducer-model.h
Normal file
99
sherpa-onnx/csrc/online-zipformer-transducer-model.h
Normal file
@@ -0,0 +1,99 @@
|
||||
// sherpa-onnx/csrc/online-zipformer-transducer-model.h
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
#ifndef SHERPA_ONNX_CSRC_ONLINE_ZIPFORMER_TRANSDUCER_MODEL_H_
|
||||
#define SHERPA_ONNX_CSRC_ONLINE_ZIPFORMER_TRANSDUCER_MODEL_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||
#include "sherpa-onnx/csrc/online-transducer-model-config.h"
|
||||
#include "sherpa-onnx/csrc/online-transducer-model.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
// Streaming (online) zipformer transducer model.
//
// Wraps three ONNX Runtime sessions (encoder, decoder, joiner) exported
// from an icefall pruned_transducer_stateless7_streaming model and
// implements the OnlineTransducerModel interface on top of them.
class OnlineZipformerTransducerModel : public OnlineTransducerModel {
 public:
  explicit OnlineZipformerTransducerModel(
      const OnlineTransducerModelConfig &config);

  // Combine per-stream encoder states into batched state tensors.
  std::vector<Ort::Value> StackStates(
      const std::vector<std::vector<Ort::Value>> &states) const override;

  // Inverse of StackStates(): split batched states back per stream.
  std::vector<std::vector<Ort::Value>> UnStackStates(
      const std::vector<Ort::Value> &states) const override;

  // Return the all-zero initial states expected by the encoder.
  std::vector<Ort::Value> GetEncoderInitStates() override;

  // Run one chunk through the encoder.
  // Returns {encoder_out, next_states}.
  std::pair<Ort::Value, std::vector<Ort::Value>> RunEncoder(
      Ort::Value features, std::vector<Ort::Value> states) override;

  // Build the (batch, context_size) int64 decoder input from the last
  // context_size_ tokens of each result.
  Ort::Value BuildDecoderInput(
      const std::vector<OnlineTransducerDecoderResult> &results) override;

  // Run the decoder (prediction) network.
  Ort::Value RunDecoder(Ort::Value decoder_input) override;

  // Run the joiner network; returns the logits.
  Ort::Value RunJoiner(Ort::Value encoder_out, Ort::Value decoder_out) override;

  // Number of previous tokens the decoder conditions on.
  int32_t ContextSize() const override { return context_size_; }

  // Number of feature frames consumed per encoder call.
  int32_t ChunkSize() const override { return T_; }

  // Number of frames to advance between consecutive chunks.
  int32_t ChunkShift() const override { return decode_chunk_len_; }

  int32_t VocabSize() const override { return vocab_size_; }
  OrtAllocator *Allocator() override { return allocator_; }

 private:
  // Each Init*() creates the session and caches its input/output names
  // plus the meta data needed at decoding time.
  void InitEncoder(const std::string &encoder_filename);
  void InitDecoder(const std::string &decoder_filename);
  void InitJoiner(const std::string &joiner_filename);

 private:
  Ort::Env env_;
  Ort::SessionOptions sess_opts_;
  Ort::AllocatorWithDefaultOptions allocator_;

  std::unique_ptr<Ort::Session> encoder_sess_;
  std::unique_ptr<Ort::Session> decoder_sess_;
  std::unique_ptr<Ort::Session> joiner_sess_;

  // Input/output names for each session; the *_ptr_ vectors hold
  // C-string views into the corresponding string vectors, as required
  // by Ort::Session::Run().
  std::vector<std::string> encoder_input_names_;
  std::vector<const char *> encoder_input_names_ptr_;

  std::vector<std::string> encoder_output_names_;
  std::vector<const char *> encoder_output_names_ptr_;

  std::vector<std::string> decoder_input_names_;
  std::vector<const char *> decoder_input_names_ptr_;

  std::vector<std::string> decoder_output_names_;
  std::vector<const char *> decoder_output_names_ptr_;

  std::vector<std::string> joiner_input_names_;
  std::vector<const char *> joiner_input_names_ptr_;

  std::vector<std::string> joiner_output_names_;
  std::vector<const char *> joiner_output_names_ptr_;

  OnlineTransducerModelConfig config_;

  // Per-encoder-stack configuration read from the model meta data.
  std::vector<int32_t> encoder_dims_;
  std::vector<int32_t> attention_dims_;
  std::vector<int32_t> num_encoder_layers_;
  std::vector<int32_t> cnn_module_kernels_;
  std::vector<int32_t> left_context_len_;

  int32_t T_ = 0;                 // frames per chunk
  int32_t decode_chunk_len_ = 0;  // frames to shift between chunks

  int32_t context_size_ = 0;
  int32_t vocab_size_ = 0;
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_ONLINE_ZIPFORMER_TRANSDUCER_MODEL_H_
|
||||
@@ -46,16 +46,74 @@ void PrintModelMetadata(std::ostream &os, const Ort::ModelMetadata &meta_data) {
|
||||
}
|
||||
}
|
||||
|
||||
// Return a SHALLOW copy of `v`: the returned tensor aliases the data
// buffer of `v`, so `v` must outlive the returned value.
//
// Fix: this span contained leftover pre-refactor lines (an unconditional
// float-only CreateTensor return before the type switch) merged in from
// the diff; only the type-dispatching version is kept.
Ort::Value Clone(const Ort::Value *v) {
  auto type_and_shape = v->GetTensorTypeAndShapeInfo();
  std::vector<int64_t> shape = type_and_shape.GetShape();

  auto memory_info =
      Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);

  // GetTensorMutableData() is non-const in the onnxruntime C++ API even
  // though the data is not modified here, hence the const_cast.
  switch (type_and_shape.GetElementType()) {
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32:
      return Ort::Value::CreateTensor(
          memory_info,
          const_cast<Ort::Value *>(v)->GetTensorMutableData<int32_t>(),
          type_and_shape.GetElementCount(), shape.data(), shape.size());
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64:
      return Ort::Value::CreateTensor(
          memory_info,
          const_cast<Ort::Value *>(v)->GetTensorMutableData<int64_t>(),
          type_and_shape.GetElementCount(), shape.data(), shape.size());
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT:
      return Ort::Value::CreateTensor(
          memory_info,
          const_cast<Ort::Value *>(v)->GetTensorMutableData<float>(),
          type_and_shape.GetElementCount(), shape.data(), shape.size());
    default:
      fprintf(stderr, "Unsupported type: %d\n",
              static_cast<int32_t>(type_and_shape.GetElementType()));
      exit(-1);
      // unreachable code
      return Ort::Value{nullptr};
  }
}
|
||||
|
||||
// Print a 1-D float tensor to stderr.
void Print1D(Ort::Value *v) {
  std::vector<int64_t> shape = v->GetTensorTypeAndShapeInfo().GetShape();
  const float *data = v->GetTensorData<float>();

  int32_t dim0 = static_cast<int32_t>(shape[0]);
  for (int32_t i = 0; i != dim0; ++i) {
    fprintf(stderr, "%.3f ", data[i]);
  }
  fprintf(stderr, "\n");
}
|
||||
|
||||
// Print a 2-D float tensor to stderr, one row per line.
void Print2D(Ort::Value *v) {
  std::vector<int64_t> shape = v->GetTensorTypeAndShapeInfo().GetShape();
  const float *data = v->GetTensorData<float>();

  int32_t rows = static_cast<int32_t>(shape[0]);
  int32_t cols = static_cast<int32_t>(shape[1]);
  for (int32_t r = 0; r != rows; ++r) {
    for (int32_t c = 0; c != cols; ++c, ++data) {
      fprintf(stderr, "%.3f ", *data);
    }
    fprintf(stderr, "\n");
  }
  fprintf(stderr, "\n");
}
|
||||
|
||||
// Print a 3-D float tensor to stderr, one 2-D plane at a time.
void Print3D(Ort::Value *v) {
  std::vector<int64_t> shape = v->GetTensorTypeAndShapeInfo().GetShape();
  const float *data = v->GetTensorData<float>();

  int32_t planes = static_cast<int32_t>(shape[0]);
  int32_t rows = static_cast<int32_t>(shape[1]);
  int32_t cols = static_cast<int32_t>(shape[2]);
  for (int32_t plane = 0; plane != planes; ++plane) {
    fprintf(stderr, "---plane %d---\n", plane);
    for (int32_t r = 0; r != rows; ++r) {
      for (int32_t c = 0; c != cols; ++c, ++data) {
        fprintf(stderr, "%.3f ", *data);
      }
      fprintf(stderr, "\n");
    }
  }
  fprintf(stderr, "\n");
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
@@ -56,7 +56,23 @@ void PrintModelMetadata(std::ostream &os,
|
||||
const Ort::ModelMetadata &meta_data); // NOLINT
|
||||
|
||||
// Return a shallow copy of v
|
||||
Ort::Value Clone(Ort::Value *v);
|
||||
Ort::Value Clone(const Ort::Value *v);
|
||||
|
||||
// Print a 1-D tensor to stderr
|
||||
void Print1D(Ort::Value *v);
|
||||
|
||||
// Print a 2-D tensor to stderr
|
||||
void Print2D(Ort::Value *v);
|
||||
|
||||
// Print a 3-D tensor to stderr
|
||||
void Print3D(Ort::Value *v);
|
||||
|
||||
template <typename T = float>
|
||||
void Fill(Ort::Value *tensor, T value) {
|
||||
auto n = tensor->GetTypeInfo().GetTensorTypeAndShapeInfo().GetElementCount();
|
||||
auto p = tensor->GetTensorMutableData<T>();
|
||||
std::fill(p, p + n, value);
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
|
||||
30
sherpa-onnx/csrc/text-utils.cc
Normal file
30
sherpa-onnx/csrc/text-utils.cc
Normal file
@@ -0,0 +1,30 @@
|
||||
// sherpa-onnx/csrc/text-utils.cc
|
||||
//
|
||||
// Copyright 2009-2011 Saarland University; Microsoft Corporation
|
||||
// Copyright 2023 Xiaomi Corporation
|
||||
|
||||
#include "sherpa-onnx/csrc/text-utils.h"
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
// This file is copied/modified from
|
||||
// https://github.com/kaldi-asr/kaldi/blob/master/src/util/text-utils.cc
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
// Split `full` on any character of `delim` and store the pieces in *out.
// With omit_empty_strings == true, empty pieces (consecutive delimiters,
// or a delimiter at either end) are dropped.
void SplitStringToVector(const std::string &full, const char *delim,
                         bool omit_empty_strings,
                         std::vector<std::string> *out) {
  out->clear();

  size_t start = 0;
  const size_t end = full.size();
  size_t found = 0;
  while (found != std::string::npos) {
    found = full.find_first_of(delim, start);
    // `start != end` drops the empty piece produced by a trailing delimiter.
    bool keep = !omit_empty_strings || (found != start && start != end);
    if (keep) {
      out->push_back(full.substr(start, found - start));
    }
    start = found + 1;
  }
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
85
sherpa-onnx/csrc/text-utils.h
Normal file
85
sherpa-onnx/csrc/text-utils.h
Normal file
@@ -0,0 +1,85 @@
|
||||
// sherpa-onnx/csrc/text-utils.h
|
||||
//
|
||||
// Copyright 2009-2011 Saarland University; Microsoft Corporation
|
||||
// Copyright 2023 Xiaomi Corporation
|
||||
#ifndef SHERPA_ONNX_CSRC_TEXT_UTILS_H_
|
||||
#define SHERPA_ONNX_CSRC_TEXT_UTILS_H_
|
||||
#include <stdlib.h>
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#ifdef _MSC_VER
|
||||
#define SHERPA_ONNX_STRTOLL(cur_cstr, end_cstr) \
|
||||
_strtoi64(cur_cstr, end_cstr, 10);
|
||||
#else
|
||||
#define SHERPA_ONNX_STRTOLL(cur_cstr, end_cstr) strtoll(cur_cstr, end_cstr, 10);
|
||||
#endif
|
||||
|
||||
// This file is copied/modified from
|
||||
// https://github.com/kaldi-asr/kaldi/blob/master/src/util/text-utils.h
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
/// Split a string using any of the single character delimiters.
|
||||
/// If omit_empty_strings == true, the output will contain any
|
||||
/// nonempty strings after splitting on any of the
|
||||
/// characters in the delimiter. If omit_empty_strings == false,
|
||||
/// the output will contain n+1 strings if there are n characters
|
||||
/// in the set "delim" within the input string. In this case
|
||||
/// the empty string is split to a single empty string.
|
||||
void SplitStringToVector(const std::string &full, const char *delim,
|
||||
bool omit_empty_strings,
|
||||
std::vector<std::string> *out);
|
||||
|
||||
/**
|
||||
\brief Split a string (e.g. 1:2:3) into a vector of integers.
|
||||
|
||||
\param [in] delim String containing a list of characters, any of which
|
||||
is allowed as a delimiter.
|
||||
\param [in] omit_empty_strings If true, empty strings between delimiters are
|
||||
allowed and will not produce an output integer; if false,
|
||||
instances of characters in 'delim' that are consecutive or
|
||||
at the start or end of the string would be an error.
|
||||
You'll normally want this to be true if 'delim' consists
|
||||
of spaces, and false otherwise.
|
||||
\param [out] out The output list of integers.
|
||||
*/
|
||||
template <class I>
|
||||
bool SplitStringToIntegers(const std::string &full, const char *delim,
|
||||
bool omit_empty_strings, // typically false [but
|
||||
// should probably be true
|
||||
// if "delim" is spaces].
|
||||
std::vector<I> *out) {
|
||||
static_assert(std::is_integral<I>::value, "");
|
||||
if (*(full.c_str()) == '\0') {
|
||||
out->clear();
|
||||
return true;
|
||||
}
|
||||
std::vector<std::string> split;
|
||||
SplitStringToVector(full, delim, omit_empty_strings, &split);
|
||||
out->resize(split.size());
|
||||
for (size_t i = 0; i < split.size(); i++) {
|
||||
const char *this_str = split[i].c_str();
|
||||
char *end = NULL;
|
||||
int64_t j = 0;
|
||||
j = SHERPA_ONNX_STRTOLL(this_str, &end);
|
||||
if (end == this_str || *end != '\0') {
|
||||
out->clear();
|
||||
return false;
|
||||
} else {
|
||||
I jI = static_cast<I>(j);
|
||||
if (static_cast<int64_t>(jI) != j) {
|
||||
// output type cannot fit this integer.
|
||||
out->clear();
|
||||
return false;
|
||||
}
|
||||
(*out)[i] = jI;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_TEXT_UTILS_H_
|
||||
223
sherpa-onnx/csrc/unbind-test.cc
Normal file
223
sherpa-onnx/csrc/unbind-test.cc
Normal file
@@ -0,0 +1,223 @@
|
||||
// sherpa-onnx/csrc/unbind-test.cc
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
|
||||
#include "sherpa-onnx/csrc/unbind.h"
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "sherpa-onnx/csrc/cat.h"
|
||||
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
// Unbind a 1-D tensor along dim 0 and verify Cat reassembles it.
// Fix: test-suite name typo "Ubind" -> "Unbind".
TEST(Unbind, Test1DTensors) {
  Ort::AllocatorWithDefaultOptions allocator;
  std::array<int64_t, 1> shape{3};
  Ort::Value v =
      Ort::Value::CreateTensor<float>(allocator, shape.data(), shape.size());
  float *p = v.GetTensorMutableData<float>();

  for (int32_t i = 0; i != static_cast<int32_t>(shape[0]); ++i) {
    p[i] = i;
  }
  auto ans = Unbind(allocator, &v, 0);
  EXPECT_EQ(ans.size(), shape[0]);
  for (int32_t i = 0; i != static_cast<int32_t>(shape[0]); ++i) {
    EXPECT_EQ(ans[i].GetTensorData<float>()[0], p[i]);
  }
  Print1D(&v);
  for (int32_t i = 0; i != static_cast<int32_t>(shape[0]); ++i) {
    Print1D(&ans[i]);
  }

  // Cat should be the inverse of Unbind.
  std::vector<const Ort::Value *> vec(ans.size());
  for (int32_t i = 0; i != static_cast<int32_t>(vec.size()); ++i) {
    vec[i] = &ans[i];
  }
  Ort::Value v2 = Cat(allocator, vec, 0);
  const float *p2 = v2.GetTensorData<float>();
  for (int32_t i = 0; i != shape[0]; ++i) {
    EXPECT_EQ(p[i], p2[i]);
  }
}
|
||||
|
||||
// Unbind a 2-D tensor along dim 0 and verify Cat reassembles it.
// Fix: test-suite name typo "Ubind" -> "Unbind".
TEST(Unbind, Test2DTensorsDim0) {
  Ort::AllocatorWithDefaultOptions allocator;
  std::array<int64_t, 2> shape{3, 2};
  Ort::Value v =
      Ort::Value::CreateTensor<float>(allocator, shape.data(), shape.size());
  float *p = v.GetTensorMutableData<float>();

  for (int32_t i = 0; i != static_cast<int32_t>(shape[0] * shape[1]); ++i) {
    p[i] = i;
  }
  auto ans = Unbind(allocator, &v, 0);

  Print2D(&v);
  for (int32_t i = 0; i != static_cast<int32_t>(shape[0]); ++i) {
    Print2D(&ans[i]);
  }

  // Each slice must contain the corresponding row of v.
  for (int32_t i = 0; i != static_cast<int32_t>(shape[0]); ++i) {
    const float *pans = ans[i].GetTensorData<float>();
    for (int32_t k = 0; k != static_cast<int32_t>(shape[1]); ++k, ++p) {
      EXPECT_EQ(*p, pans[k]);
    }
  }

  // Cat should be the inverse of Unbind.
  std::vector<const Ort::Value *> vec(ans.size());
  for (int32_t i = 0; i != static_cast<int32_t>(vec.size()); ++i) {
    vec[i] = &ans[i];
  }
  Ort::Value v2 = Cat(allocator, vec, 0);
  Print2D(&v2);

  p = v.GetTensorMutableData<float>();
  const float *p2 = v2.GetTensorData<float>();
  for (int32_t i = 0; i != shape[0] * shape[1]; ++i) {
    EXPECT_EQ(p[i], p2[i]);
  }
}
|
||||
|
||||
// Unbind a 2-D tensor along dim 1 and verify Cat reassembles it.
// Fix: test-suite name typo "Ubind" -> "Unbind".
TEST(Unbind, Test2DTensorsDim1) {
  Ort::AllocatorWithDefaultOptions allocator;
  std::array<int64_t, 2> shape{3, 2};
  Ort::Value v =
      Ort::Value::CreateTensor<float>(allocator, shape.data(), shape.size());
  float *p = v.GetTensorMutableData<float>();

  for (int32_t i = 0; i != static_cast<int32_t>(shape[0] * shape[1]); ++i) {
    p[i] = i;
  }
  auto ans = Unbind(allocator, &v, 1);

  Print2D(&v);
  for (int32_t i = 0; i != static_cast<int32_t>(shape[1]); ++i) {
    Print2D(&ans[i]);
  }

  // Cat should be the inverse of Unbind.
  std::vector<const Ort::Value *> vec(ans.size());
  for (int32_t i = 0; i != static_cast<int32_t>(vec.size()); ++i) {
    vec[i] = &ans[i];
  }
  Ort::Value v2 = Cat(allocator, vec, 1);
  Print2D(&v2);

  p = v.GetTensorMutableData<float>();
  const float *p2 = v2.GetTensorData<float>();
  for (int32_t i = 0; i != shape[0] * shape[1]; ++i) {
    EXPECT_EQ(p[i], p2[i]);
  }
}
|
||||
|
||||
// Unbind a 3-D tensor along dim 0 and verify Cat reassembles it.
// Fix: test-suite name typo "Ubind" -> "Unbind".
TEST(Unbind, Test3DTensorsDim0) {
  Ort::AllocatorWithDefaultOptions allocator;
  std::array<int64_t, 3> shape{3, 2, 5};
  Ort::Value v =
      Ort::Value::CreateTensor<float>(allocator, shape.data(), shape.size());
  float *p = v.GetTensorMutableData<float>();

  for (int32_t i = 0; i != static_cast<int32_t>(shape[0] * shape[1] * shape[2]);
       ++i) {
    p[i] = i;
  }
  auto ans = Unbind(allocator, &v, 0);

  Print3D(&v);
  for (int32_t i = 0; i != static_cast<int32_t>(shape[0]); ++i) {
    Print3D(&ans[i]);
  }

  // Each slice must contain the corresponding plane of v.
  for (int32_t i = 0; i != static_cast<int32_t>(shape[0]); ++i) {
    const float *pans = ans[i].GetTensorData<float>();
    for (int32_t k = 0; k != static_cast<int32_t>(shape[1] * shape[2]);
         ++k, ++p) {
      EXPECT_EQ(*p, pans[k]);
    }
  }

  // Cat should be the inverse of Unbind.
  std::vector<const Ort::Value *> vec(ans.size());
  for (int32_t i = 0; i != static_cast<int32_t>(vec.size()); ++i) {
    vec[i] = &ans[i];
  }
  Ort::Value v2 = Cat(allocator, vec, 0);
  Print3D(&v2);

  p = v.GetTensorMutableData<float>();
  const float *p2 = v2.GetTensorData<float>();
  for (int32_t i = 0; i != shape[0] * shape[1] * shape[2]; ++i) {
    EXPECT_EQ(p[i], p2[i]);
  }
}
|
||||
|
||||
// Unbind a 3-D tensor along dim 1 and verify Cat reassembles it.
// Fix: test-suite name typo "Ubind" -> "Unbind".
TEST(Unbind, Test3DTensorsDim1) {
  Ort::AllocatorWithDefaultOptions allocator;
  std::array<int64_t, 3> shape{3, 2, 5};
  Ort::Value v =
      Ort::Value::CreateTensor<float>(allocator, shape.data(), shape.size());
  float *p = v.GetTensorMutableData<float>();

  for (int32_t i = 0; i != static_cast<int32_t>(shape[0] * shape[1] * shape[2]);
       ++i) {
    p[i] = i;
  }
  auto ans = Unbind(allocator, &v, 1);

  Print3D(&v);
  for (int32_t i = 0; i != static_cast<int32_t>(shape[1]); ++i) {
    Print3D(&ans[i]);
  }

  // Cat should be the inverse of Unbind.
  std::vector<const Ort::Value *> vec(ans.size());
  for (int32_t i = 0; i != static_cast<int32_t>(vec.size()); ++i) {
    vec[i] = &ans[i];
  }
  Ort::Value v2 = Cat(allocator, vec, 1);
  Print3D(&v2);

  p = v.GetTensorMutableData<float>();
  const float *p2 = v2.GetTensorData<float>();
  for (int32_t i = 0; i != shape[0] * shape[1] * shape[2]; ++i) {
    EXPECT_EQ(p[i], p2[i]);
  }
}
|
||||
|
||||
// Unbind a 3-D tensor along dim 2 and verify Cat reassembles it.
// Fix: test-suite name typo "Ubind" -> "Unbind".
TEST(Unbind, Test3DTensorsDim2) {
  Ort::AllocatorWithDefaultOptions allocator;
  std::array<int64_t, 3> shape{3, 2, 5};
  Ort::Value v =
      Ort::Value::CreateTensor<float>(allocator, shape.data(), shape.size());
  float *p = v.GetTensorMutableData<float>();

  for (int32_t i = 0; i != static_cast<int32_t>(shape[0] * shape[1] * shape[2]);
       ++i) {
    p[i] = i;
  }
  auto ans = Unbind(allocator, &v, 2);

  Print3D(&v);
  for (int32_t i = 0; i != static_cast<int32_t>(shape[2]); ++i) {
    Print3D(&ans[i]);
  }

  // Cat should be the inverse of Unbind.
  std::vector<const Ort::Value *> vec(ans.size());
  for (int32_t i = 0; i != static_cast<int32_t>(vec.size()); ++i) {
    vec[i] = &ans[i];
  }
  Ort::Value v2 = Cat(allocator, vec, 2);
  Print3D(&v2);

  p = v.GetTensorMutableData<float>();
  const float *p2 = v2.GetTensorData<float>();
  for (int32_t i = 0; i != shape[0] * shape[1] * shape[2]; ++i) {
    EXPECT_EQ(p[i], p2[i]);
  }
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
72
sherpa-onnx/csrc/unbind.cc
Normal file
72
sherpa-onnx/csrc/unbind.cc
Normal file
@@ -0,0 +1,72 @@
|
||||
// sherpa-onnx/csrc/unbind.cc
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
|
||||
#include "sherpa-onnx/csrc/unbind.h"
|
||||
|
||||
#include <assert.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <functional>
|
||||
#include <numeric>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
template <typename T /*= float*/>
|
||||
std::vector<Ort::Value> Unbind(OrtAllocator *allocator, const Ort::Value *value,
|
||||
int32_t dim) {
|
||||
std::vector<int64_t> shape = value->GetTensorTypeAndShapeInfo().GetShape();
|
||||
assert(dim >= 0);
|
||||
assert(dim < static_cast<int32_t>(shape.size()));
|
||||
int32_t n = static_cast<int32_t>(shape[dim]);
|
||||
if (n == 1) {
|
||||
std::vector<Ort::Value> ans;
|
||||
ans.push_back(Clone(value));
|
||||
return ans;
|
||||
}
|
||||
|
||||
std::vector<int64_t> ans_shape = shape;
|
||||
ans_shape[dim] = 1; // // Unlike torch, we keep the dim to 1
|
||||
|
||||
// allocator tensors
|
||||
std::vector<Ort::Value> ans;
|
||||
ans.reserve(n);
|
||||
for (int32_t i = 0; i != n; ++i) {
|
||||
Ort::Value t = Ort::Value::CreateTensor<T>(allocator, ans_shape.data(),
|
||||
ans_shape.size());
|
||||
ans.push_back(std::move(t));
|
||||
}
|
||||
|
||||
auto leading_size = static_cast<int32_t>(std::accumulate(
|
||||
shape.begin(), shape.begin() + dim, 1, std::multiplies<int64_t>()));
|
||||
|
||||
auto trailing_size = static_cast<int32_t>(std::accumulate(
|
||||
shape.begin() + dim + 1, shape.end(), 1, std::multiplies<int64_t>()));
|
||||
|
||||
const T *src = value->GetTensorData<T>();
|
||||
|
||||
for (int32_t i = 0; i != leading_size; ++i) {
|
||||
for (int32_t k = 0; k != n; ++k) {
|
||||
T *dst = ans[k].GetTensorMutableData<T>() + i * trailing_size;
|
||||
std::copy(src, src + trailing_size, dst);
|
||||
src += trailing_size;
|
||||
}
|
||||
}
|
||||
|
||||
return std::move(ans);
|
||||
}
|
||||
|
||||
template std::vector<Ort::Value> Unbind<float>(OrtAllocator *allocator,
|
||||
const Ort::Value *value,
|
||||
int32_t dim);
|
||||
|
||||
template std::vector<Ort::Value> Unbind<int64_t>(OrtAllocator *allocator,
|
||||
const Ort::Value *value,
|
||||
int32_t dim);
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
28
sherpa-onnx/csrc/unbind.h
Normal file
28
sherpa-onnx/csrc/unbind.h
Normal file
@@ -0,0 +1,28 @@
|
||||
// sherpa-onnx/csrc/unbind.h
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
#ifndef SHERPA_ONNX_CSRC_UNBIND_H_
|
||||
#define SHERPA_ONNX_CSRC_UNBIND_H_
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
/** It is similar to torch.unbind() but we keep the unbind dim to 1 in
|
||||
* the output
|
||||
*
|
||||
* @param allocator Allocator to allocate space for the returned tensor
|
||||
* @param value The tensor to unbind
|
||||
* @param dim The dim along which to unbind the tensor
|
||||
*
|
||||
* @return Return a list of tensors
|
||||
*/
|
||||
template <typename T = float>
|
||||
std::vector<Ort::Value> Unbind(OrtAllocator *allocator, const Ort::Value *value,
|
||||
int32_t dim);
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_UNBIND_H_
|
||||
Reference in New Issue
Block a user