Support building GPU-capable sherpa-onnx on Linux aarch64. (#1500)

Thanks to @Peakyxh for providing pre-built onnxruntime libraries 
with CUDA support for Linux aarch64.

Tested on Jetson nano b01
This commit is contained in:
Fangjun Kuang
2024-11-01 11:16:28 +08:00
committed by GitHub
parent a3c89aa0d8
commit 9ab89c33bc
41 changed files with 537 additions and 291 deletions

View File

@@ -34,11 +34,12 @@ concurrency:
jobs: jobs:
aarch64_linux_gnu_shared: aarch64_linux_gnu_shared:
runs-on: ${{ matrix.os }} runs-on: ${{ matrix.os }}
name: aarch64 shared lib test name: aarch64 shared GPU ${{ matrix.gpu }}
strategy: strategy:
fail-fast: false fail-fast: false
matrix: matrix:
os: [ubuntu-latest] os: [ubuntu-latest]
gpu: [ON, OFF]
steps: steps:
- uses: actions/checkout@v4 - uses: actions/checkout@v4
@@ -79,15 +80,24 @@ jobs:
make -j2 make -j2
make install make install
- name: cache-toolchain - name: cache-toolchain (CPU)
id: cache-toolchain if: matrix.gpu == 'OFF'
id: cache-toolchain-cpu
uses: actions/cache@v4 uses: actions/cache@v4
with: with:
path: toolchain path: toolchain
key: gcc-linaro-7.5.0-2019.12-x86_64_aarch64-linux-gnu.tar.xz key: gcc-linaro-7.5.0-2019.12-x86_64_aarch64-linux-gnu.tar.xz
- name: Download toolchain - name: cache-toolchain (GPU)
if: steps.cache-toolchain.outputs.cache-hit != 'true' if: matrix.gpu == 'ON'
id: cache-toolchain-gpu
uses: actions/cache@v4
with:
path: toolchain
key: gcc-arm-10.3-2021.07-x86_64-aarch64-none-linux-gnu.tar.xz
- name: Download toolchain (CPU, gcc 7.5)
if: steps.cache-toolchain-cpu.outputs.cache-hit != 'true' && matrix.gpu == 'OFF'
shell: bash shell: bash
run: | run: |
wget -qq https://huggingface.co/csukuangfj/sherpa-ncnn-toolchains/resolve/main/gcc-linaro-7.5.0-2019.12-x86_64_aarch64-linux-gnu.tar.xz wget -qq https://huggingface.co/csukuangfj/sherpa-ncnn-toolchains/resolve/main/gcc-linaro-7.5.0-2019.12-x86_64_aarch64-linux-gnu.tar.xz
@@ -95,6 +105,15 @@ jobs:
mkdir $GITHUB_WORKSPACE/toolchain mkdir $GITHUB_WORKSPACE/toolchain
tar xf ./gcc-linaro-7.5.0-2019.12-x86_64_aarch64-linux-gnu.tar.xz --strip-components 1 -C $GITHUB_WORKSPACE/toolchain tar xf ./gcc-linaro-7.5.0-2019.12-x86_64_aarch64-linux-gnu.tar.xz --strip-components 1 -C $GITHUB_WORKSPACE/toolchain
- name: Download toolchain (GPU, gcc 10.3)
if: steps.cache-toolchain-gpu.outputs.cache-hit != 'true' && matrix.gpu == 'ON'
shell: bash
run: |
wget -qq https://huggingface.co/csukuangfj/sherpa-ncnn-toolchains/resolve/main/gcc-arm-10.3-2021.07-x86_64-aarch64-none-linux-gnu.tar.xz
mkdir $GITHUB_WORKSPACE/toolchain
tar xf ./gcc-arm-10.3-2021.07-x86_64-aarch64-none-linux-gnu.tar.xz --strip-components 1 -C $GITHUB_WORKSPACE/toolchain
- name: Set environment variable - name: Set environment variable
if: steps.cache-build-result.outputs.cache-hit != 'true' if: steps.cache-build-result.outputs.cache-hit != 'true'
shell: bash shell: bash
@@ -103,19 +122,31 @@ jobs:
echo "$GITHUB_WORKSPACE/bin" >> "$GITHUB_PATH" echo "$GITHUB_WORKSPACE/bin" >> "$GITHUB_PATH"
ls -lh "$GITHUB_WORKSPACE/toolchain/bin" ls -lh "$GITHUB_WORKSPACE/toolchain/bin"
if [[ ${{ matrix.gpu }} == OFF ]]; then
echo "CC=aarch64-linux-gnu-gcc" >> "$GITHUB_ENV" echo "CC=aarch64-linux-gnu-gcc" >> "$GITHUB_ENV"
echo "CXX=aarch64-linux-gnu-g++" >> "$GITHUB_ENV" echo "CXX=aarch64-linux-gnu-g++" >> "$GITHUB_ENV"
else
echo "CC=aarch64-none-linux-gnu-gcc" >> "$GITHUB_ENV"
echo "CXX=aarch64-none-linux-gnu-g++" >> "$GITHUB_ENV"
fi
- name: Display toolchain info - name: Display toolchain info
shell: bash shell: bash
run: | run: |
if [[ ${{ matrix.gpu }} == OFF ]]; then
which aarch64-linux-gnu-gcc
aarch64-linux-gnu-gcc --version aarch64-linux-gnu-gcc --version
else
which aarch64-none-linux-gnu-gcc
aarch64-none-linux-gnu-gcc --version
fi
- name: Display qemu-aarch64 -h - name: Display qemu-aarch64 -h
shell: bash shell: bash
run: | run: |
export PATH=$GITHUB_WORKSPACE/qemu-install/bin:$PATH export PATH=$GITHUB_WORKSPACE/qemu-install/bin:$PATH
export QEMU_LD_PREFIX=$GITHUB_WORKSPACE/toolchain/aarch64-linux-gnu/libc export QEMU_LD_PREFIX=$GITHUB_WORKSPACE/toolchain/aarch64-linux-gnu/libc
export QEMU_LD_PREFIX=$GITHUB_WORKSPACE/toolchain/aarch64-none-linux-gnu/libc
qemu-aarch64 -h qemu-aarch64 -h
- name: build aarch64-linux-gnu - name: build aarch64-linux-gnu
@@ -127,6 +158,7 @@ jobs:
cmake --version cmake --version
export BUILD_SHARED_LIBS=ON export BUILD_SHARED_LIBS=ON
export SHERPA_ONNX_ENABLE_GPU=${{ matrix.gpu }}
./build-aarch64-linux-gnu.sh ./build-aarch64-linux-gnu.sh
@@ -140,7 +172,11 @@ jobs:
run: | run: |
export PATH=$GITHUB_WORKSPACE/toolchain/bin:$PATH export PATH=$GITHUB_WORKSPACE/toolchain/bin:$PATH
export PATH=$GITHUB_WORKSPACE/qemu-install/bin:$PATH export PATH=$GITHUB_WORKSPACE/qemu-install/bin:$PATH
if [[ ${{ matrix.gpu }} == OFF ]]; then
export QEMU_LD_PREFIX=$GITHUB_WORKSPACE/toolchain/aarch64-linux-gnu/libc export QEMU_LD_PREFIX=$GITHUB_WORKSPACE/toolchain/aarch64-linux-gnu/libc
else
export QEMU_LD_PREFIX=$GITHUB_WORKSPACE/toolchain/aarch64-none-linux-gnu/libc
fi
ls -lh ./build-aarch64-linux-gnu/bin ls -lh ./build-aarch64-linux-gnu/bin
@@ -151,11 +187,20 @@ jobs:
- name: Copy files - name: Copy files
shell: bash shell: bash
run: | run: |
if [[ ${{ matrix.gpu }} == OFF ]]; then
aarch64-linux-gnu-strip --version aarch64-linux-gnu-strip --version
else
aarch64-none-linux-gnu-strip --version
fi
SHERPA_ONNX_VERSION=v$(grep "SHERPA_ONNX_VERSION" ./CMakeLists.txt | cut -d " " -f 2 | cut -d '"' -f 2) SHERPA_ONNX_VERSION=v$(grep "SHERPA_ONNX_VERSION" ./CMakeLists.txt | cut -d " " -f 2 | cut -d '"' -f 2)
dst=sherpa-onnx-${SHERPA_ONNX_VERSION}-linux-aarch64-shared dst=sherpa-onnx-${SHERPA_ONNX_VERSION}-linux-aarch64-shared
if [[ ${{ matrix.gpu }} == OFF ]]; then
dst=${dst}-cpu
else
dst=${dst}-gpu
fi
mkdir $dst mkdir $dst
cp -a build-aarch64-linux-gnu/install/bin $dst/ cp -a build-aarch64-linux-gnu/install/bin $dst/
@@ -166,7 +211,11 @@ jobs:
ls -lh $dst/bin/ ls -lh $dst/bin/
echo "strip" echo "strip"
if [[ ${{ matrix.gpu }} == OFF ]]; then
aarch64-linux-gnu-strip $dst/bin/* aarch64-linux-gnu-strip $dst/bin/*
else
aarch64-none-linux-gnu-strip $dst/bin/*
fi
tree $dst tree $dst
@@ -174,8 +223,8 @@ jobs:
- uses: actions/upload-artifact@v4 - uses: actions/upload-artifact@v4
with: with:
name: sherpa-onnx-linux-aarch64-shared name: sherpa-onnx-linux-aarch64-shared-gpu-${{ matrix.gpu }}
path: sherpa-onnx-*linux-aarch64-shared.tar.bz2 path: sherpa-onnx-*linux-aarch64-shared*.tar.bz2
# https://huggingface.co/docs/hub/spaces-github-actions # https://huggingface.co/docs/hub/spaces-github-actions
- name: Publish to huggingface - name: Publish to huggingface
@@ -198,7 +247,7 @@ jobs:
cd huggingface cd huggingface
mkdir -p aarch64 mkdir -p aarch64
cp -v ../sherpa-onnx-*-shared.tar.bz2 ./aarch64 cp -v ../sherpa-onnx-*-shared*.tar.bz2 ./aarch64
git status git status
git lfs track "*.bz2" git lfs track "*.bz2"

View File

@@ -44,6 +44,21 @@ if [[ x"$BUILD_SHARED_LIBS" == x"" ]]; then
BUILD_SHARED_LIBS=OFF BUILD_SHARED_LIBS=OFF
fi fi
if [[ x"$SHERPA_ONNX_ENABLE_GPU" == x"" ]]; then
# By default, use CPU
SHERPA_ONNX_ENABLE_GPU=OFF
# If you use GPU, then please make sure you have NVIDIA GPUs on your board.
# It uses onnxruntime 1.11.0.
#
# Tested on Jetson Nano B01
fi
if [[ x"$SHERPA_ONNX_ENABLE_GPU" == x"ON" ]]; then
# Build shared libs if building GPU is enabled.
BUILD_SHARED_LIBS=ON
fi
cmake \ cmake \
-DBUILD_PIPER_PHONMIZE_EXE=OFF \ -DBUILD_PIPER_PHONMIZE_EXE=OFF \
-DBUILD_PIPER_PHONMIZE_TESTS=OFF \ -DBUILD_PIPER_PHONMIZE_TESTS=OFF \
@@ -51,6 +66,7 @@ cmake \
-DBUILD_ESPEAK_NG_TESTS=OFF \ -DBUILD_ESPEAK_NG_TESTS=OFF \
-DCMAKE_INSTALL_PREFIX=./install \ -DCMAKE_INSTALL_PREFIX=./install \
-DCMAKE_BUILD_TYPE=Release \ -DCMAKE_BUILD_TYPE=Release \
-DSHERPA_ONNX_ENABLE_GPU=$SHERPA_ONNX_ENABLE_GPU \
-DBUILD_SHARED_LIBS=$BUILD_SHARED_LIBS \ -DBUILD_SHARED_LIBS=$BUILD_SHARED_LIBS \
-DSHERPA_ONNX_ENABLE_TESTS=OFF \ -DSHERPA_ONNX_ENABLE_TESTS=OFF \
-DSHERPA_ONNX_ENABLE_PYTHON=OFF \ -DSHERPA_ONNX_ENABLE_PYTHON=OFF \

View File

@@ -0,0 +1,101 @@
# Copyright (c) 2022-2024 Xiaomi Corporation
message(STATUS "CMAKE_SYSTEM_NAME: ${CMAKE_SYSTEM_NAME}")
message(STATUS "CMAKE_SYSTEM_PROCESSOR: ${CMAKE_SYSTEM_PROCESSOR}")
if(NOT CMAKE_SYSTEM_NAME STREQUAL Linux)
message(FATAL_ERROR "This file is for Linux only. Given: ${CMAKE_SYSTEM_NAME}")
endif()
if(NOT CMAKE_SYSTEM_PROCESSOR STREQUAL aarch64)
message(FATAL_ERROR "This file is for aarch64 only. Given: ${CMAKE_SYSTEM_PROCESSOR}")
endif()
if(NOT BUILD_SHARED_LIBS)
message(FATAL_ERROR "This file is for building shared libraries. BUILD_SHARED_LIBS: ${BUILD_SHARED_LIBS}")
endif()
if(NOT SHERPA_ONNX_ENABLE_GPU)
message(FATAL_ERROR "This file is for NVIDIA GPU only. Given SHERPA_ONNX_ENABLE_GPU: ${SHERPA_ONNX_ENABLE_GPU}")
endif()
set(onnxruntime_URL "https://github.com/csukuangfj/onnxruntime-libs/releases/download/v1.11.0/onnxruntime-linux-aarch64-gpu-1.11.0.tar.bz2")
set(onnxruntime_URL2 "https://hf-mirror.com/csukuangfj/onnxruntime-libs/resolve/main/onnxruntime-linux-aarch64-gpu-1.11.0.tar.bz2")
set(onnxruntime_HASH "SHA256=36eded935551e23aead09d4173bdf0bd1e7b01fdec15d77f97d6e34029aa60d7")
# If you don't have access to the Internet,
# please download onnxruntime to one of the following locations.
# You can add more if you want.
set(possible_file_locations
$ENV{HOME}/Downloads/onnxruntime-linux-aarch64-gpu-1.11.0.tar.bz2
${CMAKE_SOURCE_DIR}/onnxruntime-linux-aarch64-gpu-1.11.0.tar.bz2
${CMAKE_BINARY_DIR}/onnxruntime-linux-aarch64-gpu-1.11.0.tar.bz2
/tmp/onnxruntime-linux-aarch64-gpu-1.11.0.tar.bz2
/star-fj/fangjun/download/github/onnxruntime-linux-aarch64-gpu-1.11.0.tar.bz2
)
foreach(f IN LISTS possible_file_locations)
if(EXISTS ${f})
set(onnxruntime_URL "${f}")
file(TO_CMAKE_PATH "${onnxruntime_URL}" onnxruntime_URL)
message(STATUS "Found local downloaded onnxruntime: ${onnxruntime_URL}")
set(onnxruntime_URL2)
break()
endif()
endforeach()
FetchContent_Declare(onnxruntime
URL
${onnxruntime_URL}
${onnxruntime_URL2}
URL_HASH ${onnxruntime_HASH}
)
FetchContent_GetProperties(onnxruntime)
if(NOT onnxruntime_POPULATED)
message(STATUS "Downloading onnxruntime from ${onnxruntime_URL}")
FetchContent_Populate(onnxruntime)
endif()
message(STATUS "onnxruntime is downloaded to ${onnxruntime_SOURCE_DIR}")
find_library(location_onnxruntime onnxruntime
PATHS
"${onnxruntime_SOURCE_DIR}/lib"
NO_CMAKE_SYSTEM_PATH
)
message(STATUS "location_onnxruntime: ${location_onnxruntime}")
add_library(onnxruntime SHARED IMPORTED)
set_target_properties(onnxruntime PROPERTIES
IMPORTED_LOCATION ${location_onnxruntime}
INTERFACE_INCLUDE_DIRECTORIES "${onnxruntime_SOURCE_DIR}/include"
)
find_library(location_onnxruntime_cuda_lib onnxruntime_providers_cuda
PATHS
"${onnxruntime_SOURCE_DIR}/lib"
NO_CMAKE_SYSTEM_PATH
)
add_library(onnxruntime_providers_cuda SHARED IMPORTED)
set_target_properties(onnxruntime_providers_cuda PROPERTIES
IMPORTED_LOCATION ${location_onnxruntime_cuda_lib}
)
message(STATUS "location_onnxruntime_cuda_lib: ${location_onnxruntime_cuda_lib}")
# for libonnxruntime_providers_shared.so
find_library(location_onnxruntime_providers_shared_lib onnxruntime_providers_shared
PATHS
"${onnxruntime_SOURCE_DIR}/lib"
NO_CMAKE_SYSTEM_PATH
)
add_library(onnxruntime_providers_shared SHARED IMPORTED)
set_target_properties(onnxruntime_providers_shared PROPERTIES
IMPORTED_LOCATION ${location_onnxruntime_providers_shared_lib}
)
message(STATUS "location_onnxruntime_providers_shared_lib: ${location_onnxruntime_providers_shared_lib}")
file(GLOB onnxruntime_lib_files "${onnxruntime_SOURCE_DIR}/lib/libonnxruntime*")
message(STATUS "onnxruntime lib files: ${onnxruntime_lib_files}")
install(FILES ${onnxruntime_lib_files} DESTINATION lib)

View File

@@ -13,7 +13,9 @@ function(download_onnxruntime)
include(onnxruntime-linux-riscv64-static) include(onnxruntime-linux-riscv64-static)
endif() endif()
elseif(CMAKE_SYSTEM_NAME STREQUAL Linux AND CMAKE_SYSTEM_PROCESSOR STREQUAL aarch64) elseif(CMAKE_SYSTEM_NAME STREQUAL Linux AND CMAKE_SYSTEM_PROCESSOR STREQUAL aarch64)
if(BUILD_SHARED_LIBS) if(SHERPA_ONNX_ENABLE_GPU)
include(onnxruntime-linux-aarch64-gpu)
elseif(BUILD_SHARED_LIBS)
include(onnxruntime-linux-aarch64) include(onnxruntime-linux-aarch64)
else() else()
include(onnxruntime-linux-aarch64-static) include(onnxruntime-linux-aarch64-static)

View File

@@ -1,18 +1,18 @@
function(download_piper_phonemize) function(download_piper_phonemize)
include(FetchContent) include(FetchContent)
set(piper_phonemize_URL "https://github.com/csukuangfj/piper-phonemize/archive/dc6b5f4441bffe521047086930b0fc12686acd56.zip") set(piper_phonemize_URL "https://github.com/csukuangfj/piper-phonemize/archive/38ee199dcc49c7b6de89f7ebfb32ed682763fa1b.zip")
set(piper_phonemize_URL2 "https://hf-mirror.com/csukuangfj/sherpa-onnx-cmake-deps/resolve/main/piper-phonemize-dc6b5f4441bffe521047086930b0fc12686acd56.zip") set(piper_phonemize_URL2 "https://hf-mirror.com/csukuangfj/sherpa-onnx-cmake-deps/resolve/main/piper-phonemize-38ee199dcc49c7b6de89f7ebfb32ed682763fa1b.zip")
set(piper_phonemize_HASH "SHA256=b9faa04204b1756fa455a962abb1f037041c040133d55be58d11f11ab9b3ce14") set(piper_phonemize_HASH "SHA256=ab4d06ca76047e1585c63c482f39ffead5315785345055360703cc9382c5e74b")
# If you don't have access to the Internet, # If you don't have access to the Internet,
# please pre-download kaldi-decoder # please pre-download kaldi-decoder
set(possible_file_locations set(possible_file_locations
$ENV{HOME}/Downloads/piper-phonemize-dc6b5f4441bffe521047086930b0fc12686acd56.zip $ENV{HOME}/Downloads/piper-phonemize-38ee199dcc49c7b6de89f7ebfb32ed682763fa1b.zip
${CMAKE_SOURCE_DIR}/piper-phonemize-dc6b5f4441bffe521047086930b0fc12686acd56.zip ${CMAKE_SOURCE_DIR}/piper-phonemize-38ee199dcc49c7b6de89f7ebfb32ed682763fa1b.zip
${CMAKE_BINARY_DIR}/piper-phonemize-dc6b5f4441bffe521047086930b0fc12686acd56.zip ${CMAKE_BINARY_DIR}/piper-phonemize-38ee199dcc49c7b6de89f7ebfb32ed682763fa1b.zip
/tmp/piper-phonemize-dc6b5f4441bffe521047086930b0fc12686acd56.zip /tmp/piper-phonemize-38ee199dcc49c7b6de89f7ebfb32ed682763fa1b.zip
/star-fj/fangjun/download/github/piper-phonemize-dc6b5f4441bffe521047086930b0fc12686acd56.zip /star-fj/fangjun/download/github/piper-phonemize-38ee199dcc49c7b6de89f7ebfb32ed682763fa1b.zip
) )
foreach(f IN LISTS possible_file_locations) foreach(f IN LISTS possible_file_locations)

View File

@@ -7,6 +7,8 @@
#include <stdio.h> #include <stdio.h>
#include <stdlib.h> #include <stdlib.h>
#include <utility>
#if __ANDROID_API__ >= 8 #if __ANDROID_API__ >= 8
#include "android/log.h" #include "android/log.h"
#define SHERPA_ONNX_LOGE(...) \ #define SHERPA_ONNX_LOGE(...) \
@@ -38,14 +40,13 @@
// Read an integer // Read an integer
#define SHERPA_ONNX_READ_META_DATA(dst, src_key) \ #define SHERPA_ONNX_READ_META_DATA(dst, src_key) \
do { \ do { \
auto value = \ auto value = LookupCustomModelMetaData(meta_data, src_key, allocator); \
meta_data.LookupCustomMetadataMapAllocated(src_key, allocator); \ if (value.empty()) { \
if (!value) { \
SHERPA_ONNX_LOGE("'%s' does not exist in the metadata", src_key); \ SHERPA_ONNX_LOGE("'%s' does not exist in the metadata", src_key); \
exit(-1); \ exit(-1); \
} \ } \
\ \
dst = atoi(value.get()); \ dst = atoi(value.c_str()); \
if (dst < 0) { \ if (dst < 0) { \
SHERPA_ONNX_LOGE("Invalid value %d for '%s'", dst, src_key); \ SHERPA_ONNX_LOGE("Invalid value %d for '%s'", dst, src_key); \
exit(-1); \ exit(-1); \
@@ -54,12 +55,11 @@
#define SHERPA_ONNX_READ_META_DATA_WITH_DEFAULT(dst, src_key, default_value) \ #define SHERPA_ONNX_READ_META_DATA_WITH_DEFAULT(dst, src_key, default_value) \
do { \ do { \
auto value = \ auto value = LookupCustomModelMetaData(meta_data, src_key, allocator); \
meta_data.LookupCustomMetadataMapAllocated(src_key, allocator); \ if (value.empty()) { \
if (!value) { \
dst = default_value; \ dst = default_value; \
} else { \ } else { \
dst = atoi(value.get()); \ dst = atoi(value.c_str()); \
if (dst < 0) { \ if (dst < 0) { \
SHERPA_ONNX_LOGE("Invalid value %d for '%s'", dst, src_key); \ SHERPA_ONNX_LOGE("Invalid value %d for '%s'", dst, src_key); \
exit(-1); \ exit(-1); \
@@ -70,16 +70,15 @@
// read a vector of integers // read a vector of integers
#define SHERPA_ONNX_READ_META_DATA_VEC(dst, src_key) \ #define SHERPA_ONNX_READ_META_DATA_VEC(dst, src_key) \
do { \ do { \
auto value = \ auto value = LookupCustomModelMetaData(meta_data, src_key, allocator); \
meta_data.LookupCustomMetadataMapAllocated(src_key, allocator); \ if (value.empty()) { \
if (!value) { \
SHERPA_ONNX_LOGE("'%s' does not exist in the metadata", src_key); \ SHERPA_ONNX_LOGE("'%s' does not exist in the metadata", src_key); \
exit(-1); \ exit(-1); \
} \ } \
\ \
bool ret = SplitStringToIntegers(value.get(), ",", true, &dst); \ bool ret = SplitStringToIntegers(value.c_str(), ",", true, &dst); \
if (!ret) { \ if (!ret) { \
SHERPA_ONNX_LOGE("Invalid value '%s' for '%s'", value.get(), src_key); \ SHERPA_ONNX_LOGE("Invalid value '%s' for '%s'", value.c_str(), src_key); \
exit(-1); \ exit(-1); \
} \ } \
} while (0) } while (0)
@@ -87,16 +86,15 @@
// read a vector of floats // read a vector of floats
#define SHERPA_ONNX_READ_META_DATA_VEC_FLOAT(dst, src_key) \ #define SHERPA_ONNX_READ_META_DATA_VEC_FLOAT(dst, src_key) \
do { \ do { \
auto value = \ auto value = LookupCustomModelMetaData(meta_data, src_key, allocator); \
meta_data.LookupCustomMetadataMapAllocated(src_key, allocator); \ if (value.empty()) { \
if (!value) { \
SHERPA_ONNX_LOGE("%s does not exist in the metadata", src_key); \ SHERPA_ONNX_LOGE("%s does not exist in the metadata", src_key); \
exit(-1); \ exit(-1); \
} \ } \
\ \
bool ret = SplitStringToFloats(value.get(), ",", true, &dst); \ bool ret = SplitStringToFloats(value.c_str(), ",", true, &dst); \
if (!ret) { \ if (!ret) { \
SHERPA_ONNX_LOGE("Invalid value '%s' for '%s'", value.get(), src_key); \ SHERPA_ONNX_LOGE("Invalid value '%s' for '%s'", value.c_str(), src_key); \
exit(-1); \ exit(-1); \
} \ } \
} while (0) } while (0)
@@ -104,17 +102,16 @@
// read a vector of strings // read a vector of strings
#define SHERPA_ONNX_READ_META_DATA_VEC_STRING(dst, src_key) \ #define SHERPA_ONNX_READ_META_DATA_VEC_STRING(dst, src_key) \
do { \ do { \
auto value = \ auto value = LookupCustomModelMetaData(meta_data, src_key, allocator); \
meta_data.LookupCustomMetadataMapAllocated(src_key, allocator); \ if (value.empty()) { \
if (!value) { \
SHERPA_ONNX_LOGE("'%s' does not exist in the metadata", src_key); \ SHERPA_ONNX_LOGE("'%s' does not exist in the metadata", src_key); \
exit(-1); \ exit(-1); \
} \ } \
SplitStringToVector(value.get(), ",", false, &dst); \ SplitStringToVector(value.c_str(), ",", false, &dst); \
\ \
if (dst.empty()) { \ if (dst.empty()) { \
SHERPA_ONNX_LOGE("Invalid value '%s' for '%s'. Empty vector!", \ SHERPA_ONNX_LOGE("Invalid value '%s' for '%s'. Empty vector!", \
value.get(), src_key); \ value.c_str(), src_key); \
exit(-1); \ exit(-1); \
} \ } \
} while (0) } while (0)
@@ -122,17 +119,16 @@
// read a vector of strings separated by sep // read a vector of strings separated by sep
#define SHERPA_ONNX_READ_META_DATA_VEC_STRING_SEP(dst, src_key, sep) \ #define SHERPA_ONNX_READ_META_DATA_VEC_STRING_SEP(dst, src_key, sep) \
do { \ do { \
auto value = \ auto value = LookupCustomModelMetaData(meta_data, src_key, allocator); \
meta_data.LookupCustomMetadataMapAllocated(src_key, allocator); \ if (value.empty()) { \
if (!value) { \
SHERPA_ONNX_LOGE("'%s' does not exist in the metadata", src_key); \ SHERPA_ONNX_LOGE("'%s' does not exist in the metadata", src_key); \
exit(-1); \ exit(-1); \
} \ } \
SplitStringToVector(value.get(), sep, false, &dst); \ SplitStringToVector(value.c_str(), sep, false, &dst); \
\ \
if (dst.empty()) { \ if (dst.empty()) { \
SHERPA_ONNX_LOGE("Invalid value '%s' for '%s'. Empty vector!", \ SHERPA_ONNX_LOGE("Invalid value '%s' for '%s'. Empty vector!", \
value.get(), src_key); \ value.c_str(), src_key); \
exit(-1); \ exit(-1); \
} \ } \
} while (0) } while (0)
@@ -140,14 +136,13 @@
// Read a string // Read a string
#define SHERPA_ONNX_READ_META_DATA_STR(dst, src_key) \ #define SHERPA_ONNX_READ_META_DATA_STR(dst, src_key) \
do { \ do { \
auto value = \ auto value = LookupCustomModelMetaData(meta_data, src_key, allocator); \
meta_data.LookupCustomMetadataMapAllocated(src_key, allocator); \ if (value.empty()) { \
if (!value) { \
SHERPA_ONNX_LOGE("'%s' does not exist in the metadata", src_key); \ SHERPA_ONNX_LOGE("'%s' does not exist in the metadata", src_key); \
exit(-1); \ exit(-1); \
} \ } \
\ \
dst = value.get(); \ dst = std::move(value); \
if (dst.empty()) { \ if (dst.empty()) { \
SHERPA_ONNX_LOGE("Invalid value for '%s'\n", src_key); \ SHERPA_ONNX_LOGE("Invalid value for '%s'\n", src_key); \
exit(-1); \ exit(-1); \
@@ -156,25 +151,23 @@
#define SHERPA_ONNX_READ_META_DATA_STR_ALLOW_EMPTY(dst, src_key) \ #define SHERPA_ONNX_READ_META_DATA_STR_ALLOW_EMPTY(dst, src_key) \
do { \ do { \
auto value = \ auto value = LookupCustomModelMetaData(meta_data, src_key, allocator); \
meta_data.LookupCustomMetadataMapAllocated(src_key, allocator); \ if (value.empty()) { \
if (!value) { \
SHERPA_ONNX_LOGE("'%s' does not exist in the metadata", src_key); \ SHERPA_ONNX_LOGE("'%s' does not exist in the metadata", src_key); \
exit(-1); \ exit(-1); \
} \ } \
\ \
dst = value.get(); \ dst = std::move(value); \
} while (0) } while (0)
#define SHERPA_ONNX_READ_META_DATA_STR_WITH_DEFAULT(dst, src_key, \ #define SHERPA_ONNX_READ_META_DATA_STR_WITH_DEFAULT(dst, src_key, \
default_value) \ default_value) \
do { \ do { \
auto value = \ auto value = LookupCustomModelMetaData(meta_data, src_key, allocator); \
meta_data.LookupCustomMetadataMapAllocated(src_key, allocator); \ if (value.empty()) { \
if (!value) { \
dst = default_value; \ dst = default_value; \
} else { \ } else { \
dst = value.get(); \ dst = std::move(value); \
if (dst.empty()) { \ if (dst.empty()) { \
SHERPA_ONNX_LOGE("Invalid value for '%s'\n", src_key); \ SHERPA_ONNX_LOGE("Invalid value for '%s'\n", src_key); \
exit(-1); \ exit(-1); \

View File

@@ -46,7 +46,7 @@ class OfflineCEDModel::Impl {
int32_t NumEventClasses() const { return num_event_classes_; } int32_t NumEventClasses() const { return num_event_classes_; }
OrtAllocator *Allocator() const { return allocator_; } OrtAllocator *Allocator() { return allocator_; }
private: private:
void Init(void *model_data, size_t model_data_length) { void Init(void *model_data, size_t model_data_length) {

View File

@@ -44,7 +44,7 @@ class OfflineCtTransformerModel::Impl {
return std::move(ans[0]); return std::move(ans[0]);
} }
OrtAllocator *Allocator() const { return allocator_; } OrtAllocator *Allocator() { return allocator_; }
const OfflineCtTransformerModelMetaData &GetModelMetadata() const { const OfflineCtTransformerModelMetaData &GetModelMetadata() const {
return meta_data_; return meta_data_;

View File

@@ -53,8 +53,8 @@ static ModelType GetModelType(char *model_data, size_t model_data_length,
Ort::AllocatorWithDefaultOptions allocator; Ort::AllocatorWithDefaultOptions allocator;
auto model_type = auto model_type =
meta_data.LookupCustomMetadataMapAllocated("model_type", allocator); LookupCustomModelMetaData(meta_data, "model_type", allocator);
if (!model_type) { if (model_type.empty()) {
SHERPA_ONNX_LOGE( SHERPA_ONNX_LOGE(
"No model_type in the metadata!\n" "No model_type in the metadata!\n"
"If you are using models from NeMo, please refer to\n" "If you are using models from NeMo, please refer to\n"
@@ -74,22 +74,22 @@ static ModelType GetModelType(char *model_data, size_t model_data_length,
return ModelType::kUnknown; return ModelType::kUnknown;
} }
if (model_type.get() == std::string("EncDecCTCModelBPE")) { if (model_type == "EncDecCTCModelBPE") {
return ModelType::kEncDecCTCModelBPE; return ModelType::kEncDecCTCModelBPE;
} else if (model_type.get() == std::string("EncDecCTCModel")) { } else if (model_type == "EncDecCTCModel") {
return ModelType::kEncDecCTCModel; return ModelType::kEncDecCTCModel;
} else if (model_type.get() == std::string("EncDecHybridRNNTCTCBPEModel")) { } else if (model_type == "EncDecHybridRNNTCTCBPEModel") {
return ModelType::kEncDecHybridRNNTCTCBPEModel; return ModelType::kEncDecHybridRNNTCTCBPEModel;
} else if (model_type.get() == std::string("tdnn")) { } else if (model_type == "tdnn") {
return ModelType::kTdnn; return ModelType::kTdnn;
} else if (model_type.get() == std::string("zipformer2_ctc")) { } else if (model_type == "zipformer2_ctc") {
return ModelType::kZipformerCtc; return ModelType::kZipformerCtc;
} else if (model_type.get() == std::string("wenet_ctc")) { } else if (model_type == "wenet_ctc") {
return ModelType::kWenetCtc; return ModelType::kWenetCtc;
} else if (model_type.get() == std::string("telespeech_ctc")) { } else if (model_type == "telespeech_ctc") {
return ModelType::kTeleSpeechCtc; return ModelType::kTeleSpeechCtc;
} else { } else {
SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.get()); SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.c_str());
return ModelType::kUnknown; return ModelType::kUnknown;
} }
} }

View File

@@ -155,7 +155,7 @@ class OfflineMoonshineModel::Impl {
return {std::move(cached_decoder_out[0]), std::move(next_states)}; return {std::move(cached_decoder_out[0]), std::move(next_states)};
} }
OrtAllocator *Allocator() const { return allocator_; } OrtAllocator *Allocator() { return allocator_; }
private: private:
void InitPreprocessor(void *model_data, size_t model_data_length) { void InitPreprocessor(void *model_data, size_t model_data_length) {

View File

@@ -68,7 +68,7 @@ class OfflineNemoEncDecCtcModel::Impl {
int32_t SubsamplingFactor() const { return subsampling_factor_; } int32_t SubsamplingFactor() const { return subsampling_factor_; }
OrtAllocator *Allocator() const { return allocator_; } OrtAllocator *Allocator() { return allocator_; }
std::string FeatureNormalizationMethod() const { return normalize_type_; } std::string FeatureNormalizationMethod() const { return normalize_type_; }

View File

@@ -56,7 +56,7 @@ class OfflineParaformerModel::Impl {
const std::vector<float> &InverseStdDev() const { return inv_stddev_; } const std::vector<float> &InverseStdDev() const { return inv_stddev_; }
OrtAllocator *Allocator() const { return allocator_; } OrtAllocator *Allocator() { return allocator_; }
private: private:
void Init(void *model_data, size_t model_data_length) { void Init(void *model_data, size_t model_data_length) {

View File

@@ -121,9 +121,9 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
Ort::AllocatorWithDefaultOptions allocator; // used in the macro below Ort::AllocatorWithDefaultOptions allocator; // used in the macro below
auto model_type_ptr = auto model_type =
meta_data.LookupCustomMetadataMapAllocated("model_type", allocator); LookupCustomModelMetaData(meta_data, "model_type", allocator);
if (!model_type_ptr) { if (!model_type.empty()) {
SHERPA_ONNX_LOGE( SHERPA_ONNX_LOGE(
"No model_type in the metadata!\n\n" "No model_type in the metadata!\n\n"
"Please refer to the following URLs to add metadata" "Please refer to the following URLs to add metadata"
@@ -164,7 +164,6 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
"\n"); "\n");
exit(-1); exit(-1);
} }
std::string model_type(model_type_ptr.get());
if (model_type == "conformer" || model_type == "zipformer" || if (model_type == "conformer" || model_type == "zipformer" ||
model_type == "zipformer2") { model_type == "zipformer2") {
@@ -301,9 +300,9 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
Ort::AllocatorWithDefaultOptions allocator; // used in the macro below Ort::AllocatorWithDefaultOptions allocator; // used in the macro below
auto model_type_ptr = auto model_type =
meta_data.LookupCustomMetadataMapAllocated("model_type", allocator); LookupCustomModelMetaData(meta_data, "model_type", allocator);
if (!model_type_ptr) { if (model_type.empty()) {
SHERPA_ONNX_LOGE( SHERPA_ONNX_LOGE(
"No model_type in the metadata!\n\n" "No model_type in the metadata!\n\n"
"Please refer to the following URLs to add metadata" "Please refer to the following URLs to add metadata"
@@ -344,7 +343,6 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
"\n"); "\n");
exit(-1); exit(-1);
} }
std::string model_type(model_type_ptr.get());
if (model_type == "conformer" || model_type == "zipformer" || if (model_type == "conformer" || model_type == "zipformer" ||
model_type == "zipformer2") { model_type == "zipformer2") {

View File

@@ -56,7 +56,7 @@ class OfflineSenseVoiceModel::Impl {
return meta_data_; return meta_data_;
} }
OrtAllocator *Allocator() const { return allocator_; } OrtAllocator *Allocator() { return allocator_; }
private: private:
void Init(void *model_data, size_t model_data_length) { void Init(void *model_data, size_t model_data_length) {

View File

@@ -63,7 +63,7 @@ class OfflineTdnnCtcModel::Impl {
int32_t VocabSize() const { return vocab_size_; } int32_t VocabSize() const { return vocab_size_; }
OrtAllocator *Allocator() const { return allocator_; } OrtAllocator *Allocator() { return allocator_; }
private: private:
void Init(void *model_data, size_t model_data_length) { void Init(void *model_data, size_t model_data_length) {

View File

@@ -69,7 +69,7 @@ class OfflineTeleSpeechCtcModel::Impl {
int32_t SubsamplingFactor() const { return subsampling_factor_; } int32_t SubsamplingFactor() const { return subsampling_factor_; }
OrtAllocator *Allocator() const { return allocator_; } OrtAllocator *Allocator() { return allocator_; }
private: private:
void Init(void *model_data, size_t model_data_length) { void Init(void *model_data, size_t model_data_length) {

View File

@@ -95,11 +95,11 @@ class OfflineTransducerModel::Impl {
int32_t VocabSize() const { return vocab_size_; } int32_t VocabSize() const { return vocab_size_; }
int32_t ContextSize() const { return context_size_; } int32_t ContextSize() const { return context_size_; }
int32_t SubsamplingFactor() const { return 4; } int32_t SubsamplingFactor() const { return 4; }
OrtAllocator *Allocator() const { return allocator_; } OrtAllocator *Allocator() { return allocator_; }
Ort::Value BuildDecoderInput( Ort::Value BuildDecoderInput(
const std::vector<OfflineTransducerDecoderResult> &results, const std::vector<OfflineTransducerDecoderResult> &results,
int32_t end_index) const { int32_t end_index) {
assert(end_index <= results.size()); assert(end_index <= results.size());
int32_t batch_size = end_index; int32_t batch_size = end_index;
@@ -122,7 +122,7 @@ class OfflineTransducerModel::Impl {
} }
Ort::Value BuildDecoderInput(const std::vector<Hypothesis> &results, Ort::Value BuildDecoderInput(const std::vector<Hypothesis> &results,
int32_t end_index) const { int32_t end_index) {
assert(end_index <= results.size()); assert(end_index <= results.size());
int32_t batch_size = end_index; int32_t batch_size = end_index;

View File

@@ -123,7 +123,7 @@ class OfflineTransducerNeMoModel::Impl {
return std::move(logit[0]); return std::move(logit[0]);
} }
std::vector<Ort::Value> GetDecoderInitStates(int32_t batch_size) const { std::vector<Ort::Value> GetDecoderInitStates(int32_t batch_size) {
std::array<int64_t, 3> s0_shape{pred_rnn_layers_, batch_size, pred_hidden_}; std::array<int64_t, 3> s0_shape{pred_rnn_layers_, batch_size, pred_hidden_};
Ort::Value s0 = Ort::Value::CreateTensor<float>(allocator_, s0_shape.data(), Ort::Value s0 = Ort::Value::CreateTensor<float>(allocator_, s0_shape.data(),
s0_shape.size()); s0_shape.size());
@@ -149,7 +149,7 @@ class OfflineTransducerNeMoModel::Impl {
int32_t SubsamplingFactor() const { return subsampling_factor_; } int32_t SubsamplingFactor() const { return subsampling_factor_; }
int32_t VocabSize() const { return vocab_size_; } int32_t VocabSize() const { return vocab_size_; }
OrtAllocator *Allocator() const { return allocator_; } OrtAllocator *Allocator() { return allocator_; }
std::string FeatureNormalizationMethod() const { return normalize_type_; } std::string FeatureNormalizationMethod() const { return normalize_type_; }

View File

@@ -47,7 +47,7 @@ class OfflineWenetCtcModel::Impl {
int32_t SubsamplingFactor() const { return subsampling_factor_; } int32_t SubsamplingFactor() const { return subsampling_factor_; }
OrtAllocator *Allocator() const { return allocator_; } OrtAllocator *Allocator() { return allocator_; }
private: private:
void Init(void *model_data, size_t model_data_length) { void Init(void *model_data, size_t model_data_length) {

View File

@@ -188,7 +188,7 @@ class OfflineWhisperModel::Impl {
return {std::move(n_layer_self_k_cache), std::move(n_layer_self_v_cache)}; return {std::move(n_layer_self_k_cache), std::move(n_layer_self_v_cache)};
} }
OrtAllocator *Allocator() const { return allocator_; } OrtAllocator *Allocator() { return allocator_; }
const std::vector<int64_t> &GetInitialTokens() const { return sot_sequence_; } const std::vector<int64_t> &GetInitialTokens() const { return sot_sequence_; }

View File

@@ -47,7 +47,7 @@ class OfflineZipformerAudioTaggingModel::Impl {
int32_t NumEventClasses() const { return num_event_classes_; } int32_t NumEventClasses() const { return num_event_classes_; }
OrtAllocator *Allocator() const { return allocator_; } OrtAllocator *Allocator() { return allocator_; }
private: private:
void Init(void *model_data, size_t model_data_length) { void Init(void *model_data, size_t model_data_length) {

View File

@@ -48,7 +48,7 @@ class OfflineZipformerCtcModel::Impl {
int32_t VocabSize() const { return vocab_size_; } int32_t VocabSize() const { return vocab_size_; }
int32_t SubsamplingFactor() const { return 4; } int32_t SubsamplingFactor() const { return 4; }
OrtAllocator *Allocator() const { return allocator_; } OrtAllocator *Allocator() { return allocator_; }
private: private:
void Init(void *model_data, size_t model_data_length) { void Init(void *model_data, size_t model_data_length) {

View File

@@ -47,7 +47,7 @@ class OnlineCNNBiLSTMModel::Impl {
return {std::move(ans[0]), std::move(ans[1])}; return {std::move(ans[0]), std::move(ans[1])};
} }
OrtAllocator *Allocator() const { return allocator_; } OrtAllocator *Allocator() { return allocator_; }
const OnlineCNNBiLSTMModelMetaData &GetModelMetadata() const { const OnlineCNNBiLSTMModelMetaData &GetModelMetadata() const {
return meta_data_; return meta_data_;

View File

@@ -163,8 +163,11 @@ std::vector<Ort::Value> OnlineConformerTransducerModel::StackStates(
conv_vec[i] = &states[i][1]; conv_vec[i] = &states[i][1];
} }
Ort::Value attn = Cat(allocator_, attn_vec, 2); auto allocator =
Ort::Value conv = Cat(allocator_, conv_vec, 2); const_cast<OnlineConformerTransducerModel *>(this)->allocator_;
Ort::Value attn = Cat(allocator, attn_vec, 2);
Ort::Value conv = Cat(allocator, conv_vec, 2);
std::vector<Ort::Value> ans; std::vector<Ort::Value> ans;
ans.reserve(2); ans.reserve(2);
@@ -183,8 +186,11 @@ OnlineConformerTransducerModel::UnStackStates(
std::vector<std::vector<Ort::Value>> ans(batch_size); std::vector<std::vector<Ort::Value>> ans(batch_size);
std::vector<Ort::Value> attn_vec = Unbind(allocator_, &states[0], 2); auto allocator =
std::vector<Ort::Value> conv_vec = Unbind(allocator_, &states[1], 2); const_cast<OnlineConformerTransducerModel *>(this)->allocator_;
std::vector<Ort::Value> attn_vec = Unbind(allocator, &states[0], 2);
std::vector<Ort::Value> conv_vec = Unbind(allocator, &states[1], 2);
assert(attn_vec.size() == batch_size); assert(attn_vec.size() == batch_size);
assert(conv_vec.size() == batch_size); assert(conv_vec.size() == batch_size);

View File

@@ -158,9 +158,10 @@ std::vector<Ort::Value> OnlineLstmTransducerModel::StackStates(
h_buf[i] = &states[i][0]; h_buf[i] = &states[i][0];
c_buf[i] = &states[i][1]; c_buf[i] = &states[i][1];
} }
auto allocator = const_cast<OnlineLstmTransducerModel *>(this)->allocator_;
Ort::Value h = Cat(allocator_, h_buf, 1); Ort::Value h = Cat(allocator, h_buf, 1);
Ort::Value c = Cat(allocator_, c_buf, 1); Ort::Value c = Cat(allocator, c_buf, 1);
std::vector<Ort::Value> ans; std::vector<Ort::Value> ans;
ans.reserve(2); ans.reserve(2);
@@ -177,8 +178,10 @@ std::vector<std::vector<Ort::Value>> OnlineLstmTransducerModel::UnStackStates(
std::vector<std::vector<Ort::Value>> ans(batch_size); std::vector<std::vector<Ort::Value>> ans(batch_size);
std::vector<Ort::Value> h_vec = Unbind(allocator_, &states[0], 1); auto allocator = const_cast<OnlineLstmTransducerModel *>(this)->allocator_;
std::vector<Ort::Value> c_vec = Unbind(allocator_, &states[1], 1);
std::vector<Ort::Value> h_vec = Unbind(allocator, &states[0], 1);
std::vector<Ort::Value> c_vec = Unbind(allocator, &states[1], 1);
assert(h_vec.size() == batch_size); assert(h_vec.size() == batch_size);
assert(c_vec.size() == batch_size); assert(c_vec.size() == batch_size);

View File

@@ -102,7 +102,7 @@ class OnlineNeMoCtcModel::Impl {
int32_t ChunkShift() const { return chunk_shift_; } int32_t ChunkShift() const { return chunk_shift_; }
OrtAllocator *Allocator() const { return allocator_; } OrtAllocator *Allocator() { return allocator_; }
// Return a vector containing 3 tensors // Return a vector containing 3 tensors
// - cache_last_channel // - cache_last_channel
@@ -119,7 +119,7 @@ class OnlineNeMoCtcModel::Impl {
} }
std::vector<Ort::Value> StackStates( std::vector<Ort::Value> StackStates(
std::vector<std::vector<Ort::Value>> states) const { std::vector<std::vector<Ort::Value>> states) {
int32_t batch_size = static_cast<int32_t>(states.size()); int32_t batch_size = static_cast<int32_t>(states.size());
if (batch_size == 1) { if (batch_size == 1) {
return std::move(states[0]); return std::move(states[0]);
@@ -157,6 +157,8 @@ class OnlineNeMoCtcModel::Impl {
std::vector<Ort::Value> states) const { std::vector<Ort::Value> states) const {
assert(states.size() == 3); assert(states.size() == 3);
auto allocator = const_cast<Impl *>(this)->allocator_;
std::vector<std::vector<Ort::Value>> ans; std::vector<std::vector<Ort::Value>> ans;
auto shape = states[0].GetTensorTypeAndShapeInfo().GetShape(); auto shape = states[0].GetTensorTypeAndShapeInfo().GetShape();
@@ -171,9 +173,9 @@ class OnlineNeMoCtcModel::Impl {
for (int32_t i = 0; i != 3; ++i) { for (int32_t i = 0; i != 3; ++i) {
std::vector<Ort::Value> v; std::vector<Ort::Value> v;
if (i == 2) { if (i == 2) {
v = Unbind<int64_t>(allocator_, &states[i], 0); v = Unbind<int64_t>(allocator, &states[i], 0);
} else { } else {
v = Unbind(allocator_, &states[i], 0); v = Unbind(allocator, &states[i], 0);
} }
assert(v.size() == batch_size); assert(v.size() == batch_size);

View File

@@ -105,7 +105,7 @@ class OnlineParaformerModel::Impl {
const std::vector<float> &InverseStdDev() const { return inv_stddev_; } const std::vector<float> &InverseStdDev() const { return inv_stddev_; }
OrtAllocator *Allocator() const { return allocator_; } OrtAllocator *Allocator() { return allocator_; }
private: private:
void InitEncoder(void *model_data, size_t model_data_length) { void InitEncoder(void *model_data, size_t model_data_length) {

View File

@@ -5,10 +5,10 @@
#include "sherpa-onnx/csrc/online-rnn-lm.h" #include "sherpa-onnx/csrc/online-rnn-lm.h"
#include <algorithm>
#include <string> #include <string>
#include <utility> #include <utility>
#include <vector> #include <vector>
#include <algorithm>
#include "onnxruntime_cxx_api.h" // NOLINT #include "onnxruntime_cxx_api.h" // NOLINT
#include "sherpa-onnx/csrc/macros.h" #include "sherpa-onnx/csrc/macros.h"
@@ -77,12 +77,12 @@ class OnlineRnnLM::Impl {
Ort::Value x = Ort::Value::CreateTensor<int64_t>( Ort::Value x = Ort::Value::CreateTensor<int64_t>(
allocator, x_shape.data(), x_shape.size()); allocator, x_shape.data(), x_shape.size());
int64_t *p_x = x.GetTensorMutableData<int64_t>(); int64_t *p_x = x.GetTensorMutableData<int64_t>();
std::copy(ys.begin() + context_size + h.cur_scored_pos, std::copy(ys.begin() + context_size + h.cur_scored_pos, ys.end() - 1,
ys.end() - 1, p_x); p_x);
// streaming forward by NN LM // streaming forward by NN LM
auto out = ScoreToken(std::move(x), auto out =
Convert(std::move(h.nn_lm_states))); ScoreToken(std::move(x), Convert(std::move(h.nn_lm_states)));
// update NN LM score in hyp // update NN LM score in hyp
const float *p_nll = out.first.GetTensorData<float>(); const float *p_nll = out.first.GetTensorData<float>();
@@ -125,7 +125,7 @@ class OnlineRnnLM::Impl {
} }
// get init states for classic rescore // get init states for classic rescore
std::vector<Ort::Value> GetInitStates() const { std::vector<Ort::Value> GetInitStates() {
std::vector<Ort::Value> ans; std::vector<Ort::Value> ans;
ans.reserve(init_states_.size()); ans.reserve(init_states_.size());
@@ -235,5 +235,4 @@ void OnlineRnnLM::ComputeLMScoreSF(float scale, Hypothesis *hyp) {
return impl_->ComputeLMScoreSF(scale, hyp); return impl_->ComputeLMScoreSF(scale, hyp);
} }
} // namespace sherpa_onnx } // namespace sherpa_onnx

View File

@@ -54,8 +54,8 @@ static ModelType GetModelType(char *model_data, size_t model_data_length,
Ort::AllocatorWithDefaultOptions allocator; Ort::AllocatorWithDefaultOptions allocator;
auto model_type = auto model_type =
meta_data.LookupCustomMetadataMapAllocated("model_type", allocator); LookupCustomModelMetaData(meta_data, "model_type", allocator);
if (!model_type) { if (model_type.empty()) {
SHERPA_ONNX_LOGE( SHERPA_ONNX_LOGE(
"No model_type in the metadata!\n" "No model_type in the metadata!\n"
"Please make sure you are using the latest export-onnx.py from icefall " "Please make sure you are using the latest export-onnx.py from icefall "
@@ -63,16 +63,16 @@ static ModelType GetModelType(char *model_data, size_t model_data_length,
return ModelType::kUnknown; return ModelType::kUnknown;
} }
if (model_type.get() == std::string("conformer")) { if (model_type == "conformer") {
return ModelType::kConformer; return ModelType::kConformer;
} else if (model_type.get() == std::string("lstm")) { } else if (model_type == "lstm") {
return ModelType::kLstm; return ModelType::kLstm;
} else if (model_type.get() == std::string("zipformer")) { } else if (model_type == "zipformer") {
return ModelType::kZipformer; return ModelType::kZipformer;
} else if (model_type.get() == std::string("zipformer2")) { } else if (model_type == "zipformer2") {
return ModelType::kZipformer2; return ModelType::kZipformer2;
} else { } else {
SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.get()); SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.c_str());
return ModelType::kUnknown; return ModelType::kUnknown;
} }
} }

View File

@@ -197,7 +197,7 @@ class OnlineTransducerNeMoModel::Impl {
int32_t VocabSize() const { return vocab_size_; } int32_t VocabSize() const { return vocab_size_; }
OrtAllocator *Allocator() const { return allocator_; } OrtAllocator *Allocator() { return allocator_; }
std::string FeatureNormalizationMethod() const { return normalize_type_; } std::string FeatureNormalizationMethod() const { return normalize_type_; }
@@ -224,6 +224,8 @@ class OnlineTransducerNeMoModel::Impl {
std::vector<Ort::Value> ans; std::vector<Ort::Value> ans;
auto allocator = const_cast<Impl *>(this)->allocator_;
// stack cache_last_channel // stack cache_last_channel
std::vector<const Ort::Value *> buf(batch_size); std::vector<const Ort::Value *> buf(batch_size);
@@ -239,9 +241,9 @@ class OnlineTransducerNeMoModel::Impl {
Ort::Value c{nullptr}; Ort::Value c{nullptr};
if (i == 2) { if (i == 2) {
c = Cat<int64_t>(allocator_, buf, 0); c = Cat<int64_t>(allocator, buf, 0);
} else { } else {
c = Cat(allocator_, buf, 0); c = Cat(allocator, buf, 0);
} }
ans.push_back(std::move(c)); ans.push_back(std::move(c));
@@ -251,7 +253,7 @@ class OnlineTransducerNeMoModel::Impl {
} }
std::vector<std::vector<Ort::Value>> UnStackStates( std::vector<std::vector<Ort::Value>> UnStackStates(
std::vector<Ort::Value> states) const { std::vector<Ort::Value> states) {
assert(states.size() == 3); assert(states.size() == 3);
std::vector<std::vector<Ort::Value>> ans; std::vector<std::vector<Ort::Value>> ans;

View File

@@ -101,7 +101,7 @@ class OnlineWenetCtcModel::Impl {
return config_.wenet_ctc.chunk_size * subsampling_factor_; return config_.wenet_ctc.chunk_size * subsampling_factor_;
} }
OrtAllocator *Allocator() const { return allocator_; } OrtAllocator *Allocator() { return allocator_; }
// Return a vector containing 3 tensors // Return a vector containing 3 tensors
// - attn_cache // - attn_cache

View File

@@ -179,12 +179,15 @@ std::vector<Ort::Value> OnlineZipformerTransducerModel::StackStates(
std::vector<Ort::Value> ans; std::vector<Ort::Value> ans;
ans.reserve(states[0].size()); ans.reserve(states[0].size());
auto allocator =
const_cast<OnlineZipformerTransducerModel *>(this)->allocator_;
// cached_len // cached_len
for (int32_t i = 0; i != num_encoders; ++i) { for (int32_t i = 0; i != num_encoders; ++i) {
for (int32_t n = 0; n != batch_size; ++n) { for (int32_t n = 0; n != batch_size; ++n) {
buf[n] = &states[n][i]; buf[n] = &states[n][i];
} }
auto v = Cat<int64_t>(allocator_, buf, 1); // (num_layers, 1) auto v = Cat<int64_t>(allocator, buf, 1); // (num_layers, 1)
ans.push_back(std::move(v)); ans.push_back(std::move(v));
} }
@@ -193,7 +196,7 @@ std::vector<Ort::Value> OnlineZipformerTransducerModel::StackStates(
for (int32_t n = 0; n != batch_size; ++n) { for (int32_t n = 0; n != batch_size; ++n) {
buf[n] = &states[n][num_encoders + i]; buf[n] = &states[n][num_encoders + i];
} }
auto v = Cat(allocator_, buf, 1); // (num_layers, 1, encoder_dims) auto v = Cat(allocator, buf, 1); // (num_layers, 1, encoder_dims)
ans.push_back(std::move(v)); ans.push_back(std::move(v));
} }
@@ -203,7 +206,7 @@ std::vector<Ort::Value> OnlineZipformerTransducerModel::StackStates(
buf[n] = &states[n][num_encoders * 2 + i]; buf[n] = &states[n][num_encoders * 2 + i];
} }
// (num_layers, left_context_len, 1, attention_dims) // (num_layers, left_context_len, 1, attention_dims)
auto v = Cat(allocator_, buf, 2); auto v = Cat(allocator, buf, 2);
ans.push_back(std::move(v)); ans.push_back(std::move(v));
} }
@@ -213,7 +216,7 @@ std::vector<Ort::Value> OnlineZipformerTransducerModel::StackStates(
buf[n] = &states[n][num_encoders * 3 + i]; buf[n] = &states[n][num_encoders * 3 + i];
} }
// (num_layers, left_context_len, 1, attention_dims/2) // (num_layers, left_context_len, 1, attention_dims/2)
auto v = Cat(allocator_, buf, 2); auto v = Cat(allocator, buf, 2);
ans.push_back(std::move(v)); ans.push_back(std::move(v));
} }
@@ -223,7 +226,7 @@ std::vector<Ort::Value> OnlineZipformerTransducerModel::StackStates(
buf[n] = &states[n][num_encoders * 4 + i]; buf[n] = &states[n][num_encoders * 4 + i];
} }
// (num_layers, left_context_len, 1, attention_dims/2) // (num_layers, left_context_len, 1, attention_dims/2)
auto v = Cat(allocator_, buf, 2); auto v = Cat(allocator, buf, 2);
ans.push_back(std::move(v)); ans.push_back(std::move(v));
} }
@@ -233,7 +236,7 @@ std::vector<Ort::Value> OnlineZipformerTransducerModel::StackStates(
buf[n] = &states[n][num_encoders * 5 + i]; buf[n] = &states[n][num_encoders * 5 + i];
} }
// (num_layers, 1, encoder_dims, cnn_module_kernels-1) // (num_layers, 1, encoder_dims, cnn_module_kernels-1)
auto v = Cat(allocator_, buf, 1); auto v = Cat(allocator, buf, 1);
ans.push_back(std::move(v)); ans.push_back(std::move(v));
} }
@@ -243,7 +246,7 @@ std::vector<Ort::Value> OnlineZipformerTransducerModel::StackStates(
buf[n] = &states[n][num_encoders * 6 + i]; buf[n] = &states[n][num_encoders * 6 + i];
} }
// (num_layers, 1, encoder_dims, cnn_module_kernels-1) // (num_layers, 1, encoder_dims, cnn_module_kernels-1)
auto v = Cat(allocator_, buf, 1); auto v = Cat(allocator, buf, 1);
ans.push_back(std::move(v)); ans.push_back(std::move(v));
} }
@@ -258,12 +261,15 @@ OnlineZipformerTransducerModel::UnStackStates(
int32_t batch_size = states[0].GetTensorTypeAndShapeInfo().GetShape()[1]; int32_t batch_size = states[0].GetTensorTypeAndShapeInfo().GetShape()[1];
int32_t num_encoders = num_encoder_layers_.size(); int32_t num_encoders = num_encoder_layers_.size();
auto allocator =
const_cast<OnlineZipformerTransducerModel *>(this)->allocator_;
std::vector<std::vector<Ort::Value>> ans; std::vector<std::vector<Ort::Value>> ans;
ans.resize(batch_size); ans.resize(batch_size);
// cached_len // cached_len
for (int32_t i = 0; i != num_encoders; ++i) { for (int32_t i = 0; i != num_encoders; ++i) {
auto v = Unbind<int64_t>(allocator_, &states[i], 1); auto v = Unbind<int64_t>(allocator, &states[i], 1);
assert(v.size() == batch_size); assert(v.size() == batch_size);
for (int32_t n = 0; n != batch_size; ++n) { for (int32_t n = 0; n != batch_size; ++n) {
@@ -273,7 +279,7 @@ OnlineZipformerTransducerModel::UnStackStates(
// cached_avg // cached_avg
for (int32_t i = num_encoders; i != 2 * num_encoders; ++i) { for (int32_t i = num_encoders; i != 2 * num_encoders; ++i) {
auto v = Unbind(allocator_, &states[i], 1); auto v = Unbind(allocator, &states[i], 1);
assert(v.size() == batch_size); assert(v.size() == batch_size);
for (int32_t n = 0; n != batch_size; ++n) { for (int32_t n = 0; n != batch_size; ++n) {
@@ -283,7 +289,7 @@ OnlineZipformerTransducerModel::UnStackStates(
// cached_key // cached_key
for (int32_t i = 2 * num_encoders; i != 3 * num_encoders; ++i) { for (int32_t i = 2 * num_encoders; i != 3 * num_encoders; ++i) {
auto v = Unbind(allocator_, &states[i], 2); auto v = Unbind(allocator, &states[i], 2);
assert(v.size() == batch_size); assert(v.size() == batch_size);
for (int32_t n = 0; n != batch_size; ++n) { for (int32_t n = 0; n != batch_size; ++n) {
@@ -293,7 +299,7 @@ OnlineZipformerTransducerModel::UnStackStates(
// cached_val // cached_val
for (int32_t i = 3 * num_encoders; i != 4 * num_encoders; ++i) { for (int32_t i = 3 * num_encoders; i != 4 * num_encoders; ++i) {
auto v = Unbind(allocator_, &states[i], 2); auto v = Unbind(allocator, &states[i], 2);
assert(v.size() == batch_size); assert(v.size() == batch_size);
for (int32_t n = 0; n != batch_size; ++n) { for (int32_t n = 0; n != batch_size; ++n) {
@@ -303,7 +309,7 @@ OnlineZipformerTransducerModel::UnStackStates(
// cached_val2 // cached_val2
for (int32_t i = 4 * num_encoders; i != 5 * num_encoders; ++i) { for (int32_t i = 4 * num_encoders; i != 5 * num_encoders; ++i) {
auto v = Unbind(allocator_, &states[i], 2); auto v = Unbind(allocator, &states[i], 2);
assert(v.size() == batch_size); assert(v.size() == batch_size);
for (int32_t n = 0; n != batch_size; ++n) { for (int32_t n = 0; n != batch_size; ++n) {
@@ -313,7 +319,7 @@ OnlineZipformerTransducerModel::UnStackStates(
// cached_conv1 // cached_conv1
for (int32_t i = 5 * num_encoders; i != 6 * num_encoders; ++i) { for (int32_t i = 5 * num_encoders; i != 6 * num_encoders; ++i) {
auto v = Unbind(allocator_, &states[i], 1); auto v = Unbind(allocator, &states[i], 1);
assert(v.size() == batch_size); assert(v.size() == batch_size);
for (int32_t n = 0; n != batch_size; ++n) { for (int32_t n = 0; n != batch_size; ++n) {
@@ -323,7 +329,7 @@ OnlineZipformerTransducerModel::UnStackStates(
// cached_conv2 // cached_conv2
for (int32_t i = 6 * num_encoders; i != 7 * num_encoders; ++i) { for (int32_t i = 6 * num_encoders; i != 7 * num_encoders; ++i) {
auto v = Unbind(allocator_, &states[i], 1); auto v = Unbind(allocator, &states[i], 1);
assert(v.size() == batch_size); assert(v.size() == batch_size);
for (int32_t n = 0; n != batch_size; ++n) { for (int32_t n = 0; n != batch_size; ++n) {

View File

@@ -70,7 +70,7 @@ class OnlineZipformer2CtcModel::Impl {
int32_t ChunkShift() const { return decode_chunk_len_; } int32_t ChunkShift() const { return decode_chunk_len_; }
OrtAllocator *Allocator() const { return allocator_; } OrtAllocator *Allocator() { return allocator_; }
// Return a vector containing 3 tensors // Return a vector containing 3 tensors
// - attn_cache // - attn_cache
@@ -86,7 +86,7 @@ class OnlineZipformer2CtcModel::Impl {
} }
std::vector<Ort::Value> StackStates( std::vector<Ort::Value> StackStates(
std::vector<std::vector<Ort::Value>> states) const { std::vector<std::vector<Ort::Value>> states) {
int32_t batch_size = static_cast<int32_t>(states.size()); int32_t batch_size = static_cast<int32_t>(states.size());
std::vector<const Ort::Value *> buf(batch_size); std::vector<const Ort::Value *> buf(batch_size);
@@ -159,7 +159,7 @@ class OnlineZipformer2CtcModel::Impl {
} }
std::vector<std::vector<Ort::Value>> UnStackStates( std::vector<std::vector<Ort::Value>> UnStackStates(
std::vector<Ort::Value> states) const { std::vector<Ort::Value> states) {
int32_t m = std::accumulate(num_encoder_layers_.begin(), int32_t m = std::accumulate(num_encoder_layers_.begin(),
num_encoder_layers_.end(), 0); num_encoder_layers_.end(), 0);
assert(states.size() == m * 6 + 2); assert(states.size() == m * 6 + 2);

View File

@@ -185,6 +185,9 @@ std::vector<Ort::Value> OnlineZipformer2TransducerModel::StackStates(
std::vector<const Ort::Value *> buf(batch_size); std::vector<const Ort::Value *> buf(batch_size);
auto allocator =
const_cast<OnlineZipformer2TransducerModel *>(this)->allocator_;
std::vector<Ort::Value> ans; std::vector<Ort::Value> ans;
int32_t num_states = static_cast<int32_t>(states[0].size()); int32_t num_states = static_cast<int32_t>(states[0].size());
ans.reserve(num_states); ans.reserve(num_states);
@@ -194,42 +197,42 @@ std::vector<Ort::Value> OnlineZipformer2TransducerModel::StackStates(
for (int32_t n = 0; n != batch_size; ++n) { for (int32_t n = 0; n != batch_size; ++n) {
buf[n] = &states[n][6 * i]; buf[n] = &states[n][6 * i];
} }
auto v = Cat(allocator_, buf, 1); auto v = Cat(allocator, buf, 1);
ans.push_back(std::move(v)); ans.push_back(std::move(v));
} }
{ {
for (int32_t n = 0; n != batch_size; ++n) { for (int32_t n = 0; n != batch_size; ++n) {
buf[n] = &states[n][6 * i + 1]; buf[n] = &states[n][6 * i + 1];
} }
auto v = Cat(allocator_, buf, 1); auto v = Cat(allocator, buf, 1);
ans.push_back(std::move(v)); ans.push_back(std::move(v));
} }
{ {
for (int32_t n = 0; n != batch_size; ++n) { for (int32_t n = 0; n != batch_size; ++n) {
buf[n] = &states[n][6 * i + 2]; buf[n] = &states[n][6 * i + 2];
} }
auto v = Cat(allocator_, buf, 1); auto v = Cat(allocator, buf, 1);
ans.push_back(std::move(v)); ans.push_back(std::move(v));
} }
{ {
for (int32_t n = 0; n != batch_size; ++n) { for (int32_t n = 0; n != batch_size; ++n) {
buf[n] = &states[n][6 * i + 3]; buf[n] = &states[n][6 * i + 3];
} }
auto v = Cat(allocator_, buf, 1); auto v = Cat(allocator, buf, 1);
ans.push_back(std::move(v)); ans.push_back(std::move(v));
} }
{ {
for (int32_t n = 0; n != batch_size; ++n) { for (int32_t n = 0; n != batch_size; ++n) {
buf[n] = &states[n][6 * i + 4]; buf[n] = &states[n][6 * i + 4];
} }
auto v = Cat(allocator_, buf, 0); auto v = Cat(allocator, buf, 0);
ans.push_back(std::move(v)); ans.push_back(std::move(v));
} }
{ {
for (int32_t n = 0; n != batch_size; ++n) { for (int32_t n = 0; n != batch_size; ++n) {
buf[n] = &states[n][6 * i + 5]; buf[n] = &states[n][6 * i + 5];
} }
auto v = Cat(allocator_, buf, 0); auto v = Cat(allocator, buf, 0);
ans.push_back(std::move(v)); ans.push_back(std::move(v));
} }
} }
@@ -238,7 +241,7 @@ std::vector<Ort::Value> OnlineZipformer2TransducerModel::StackStates(
for (int32_t n = 0; n != batch_size; ++n) { for (int32_t n = 0; n != batch_size; ++n) {
buf[n] = &states[n][num_states - 2]; buf[n] = &states[n][num_states - 2];
} }
auto v = Cat(allocator_, buf, 0); auto v = Cat(allocator, buf, 0);
ans.push_back(std::move(v)); ans.push_back(std::move(v));
} }
@@ -246,7 +249,7 @@ std::vector<Ort::Value> OnlineZipformer2TransducerModel::StackStates(
for (int32_t n = 0; n != batch_size; ++n) { for (int32_t n = 0; n != batch_size; ++n) {
buf[n] = &states[n][num_states - 1]; buf[n] = &states[n][num_states - 1];
} }
auto v = Cat<int64_t>(allocator_, buf, 0); auto v = Cat<int64_t>(allocator, buf, 0);
ans.push_back(std::move(v)); ans.push_back(std::move(v));
} }
return ans; return ans;
@@ -261,12 +264,15 @@ OnlineZipformer2TransducerModel::UnStackStates(
int32_t batch_size = states[0].GetTensorTypeAndShapeInfo().GetShape()[1]; int32_t batch_size = states[0].GetTensorTypeAndShapeInfo().GetShape()[1];
auto allocator =
const_cast<OnlineZipformer2TransducerModel *>(this)->allocator_;
std::vector<std::vector<Ort::Value>> ans; std::vector<std::vector<Ort::Value>> ans;
ans.resize(batch_size); ans.resize(batch_size);
for (int32_t i = 0; i != m; ++i) { for (int32_t i = 0; i != m; ++i) {
{ {
auto v = Unbind(allocator_, &states[i * 6], 1); auto v = Unbind(allocator, &states[i * 6], 1);
assert(static_cast<int32_t>(v.size()) == batch_size); assert(static_cast<int32_t>(v.size()) == batch_size);
for (int32_t n = 0; n != batch_size; ++n) { for (int32_t n = 0; n != batch_size; ++n) {
@@ -274,7 +280,7 @@ OnlineZipformer2TransducerModel::UnStackStates(
} }
} }
{ {
auto v = Unbind(allocator_, &states[i * 6 + 1], 1); auto v = Unbind(allocator, &states[i * 6 + 1], 1);
assert(static_cast<int32_t>(v.size()) == batch_size); assert(static_cast<int32_t>(v.size()) == batch_size);
for (int32_t n = 0; n != batch_size; ++n) { for (int32_t n = 0; n != batch_size; ++n) {
@@ -282,7 +288,7 @@ OnlineZipformer2TransducerModel::UnStackStates(
} }
} }
{ {
auto v = Unbind(allocator_, &states[i * 6 + 2], 1); auto v = Unbind(allocator, &states[i * 6 + 2], 1);
assert(static_cast<int32_t>(v.size()) == batch_size); assert(static_cast<int32_t>(v.size()) == batch_size);
for (int32_t n = 0; n != batch_size; ++n) { for (int32_t n = 0; n != batch_size; ++n) {
@@ -290,7 +296,7 @@ OnlineZipformer2TransducerModel::UnStackStates(
} }
} }
{ {
auto v = Unbind(allocator_, &states[i * 6 + 3], 1); auto v = Unbind(allocator, &states[i * 6 + 3], 1);
assert(static_cast<int32_t>(v.size()) == batch_size); assert(static_cast<int32_t>(v.size()) == batch_size);
for (int32_t n = 0; n != batch_size; ++n) { for (int32_t n = 0; n != batch_size; ++n) {
@@ -298,7 +304,7 @@ OnlineZipformer2TransducerModel::UnStackStates(
} }
} }
{ {
auto v = Unbind(allocator_, &states[i * 6 + 4], 0); auto v = Unbind(allocator, &states[i * 6 + 4], 0);
assert(static_cast<int32_t>(v.size()) == batch_size); assert(static_cast<int32_t>(v.size()) == batch_size);
for (int32_t n = 0; n != batch_size; ++n) { for (int32_t n = 0; n != batch_size; ++n) {
@@ -306,7 +312,7 @@ OnlineZipformer2TransducerModel::UnStackStates(
} }
} }
{ {
auto v = Unbind(allocator_, &states[i * 6 + 5], 0); auto v = Unbind(allocator, &states[i * 6 + 5], 0);
assert(static_cast<int32_t>(v.size()) == batch_size); assert(static_cast<int32_t>(v.size()) == batch_size);
for (int32_t n = 0; n != batch_size; ++n) { for (int32_t n = 0; n != batch_size; ++n) {
@@ -316,7 +322,7 @@ OnlineZipformer2TransducerModel::UnStackStates(
} }
{ {
auto v = Unbind(allocator_, &states[m * 6], 0); auto v = Unbind(allocator, &states[m * 6], 0);
assert(static_cast<int32_t>(v.size()) == batch_size); assert(static_cast<int32_t>(v.size()) == batch_size);
for (int32_t n = 0; n != batch_size; ++n) { for (int32_t n = 0; n != batch_size; ++n) {
@@ -324,7 +330,7 @@ OnlineZipformer2TransducerModel::UnStackStates(
} }
} }
{ {
auto v = Unbind<int64_t>(allocator_, &states[m * 6 + 1], 0); auto v = Unbind<int64_t>(allocator, &states[m * 6 + 1], 0);
assert(static_cast<int32_t>(v.size()) == batch_size); assert(static_cast<int32_t>(v.size()) == batch_size);
for (int32_t n = 0; n != batch_size; ++n) { for (int32_t n = 0; n != batch_size; ++n) {

View File

@@ -21,6 +21,36 @@
namespace sherpa_onnx { namespace sherpa_onnx {
static std::string GetInputName(Ort::Session *sess, size_t index,
OrtAllocator *allocator) {
// Note(fangjun): We only tested 1.17.1 and 1.11.0
// For other versions, we may need to change it
#if ORT_API_VERSION >= 17
auto v = sess->GetInputNameAllocated(index, allocator);
return v.get();
#else
auto v = sess->GetInputName(index, allocator);
std::string ans = v;
allocator->Free(allocator, v);
return ans;
#endif
}
static std::string GetOutputName(Ort::Session *sess, size_t index,
OrtAllocator *allocator) {
// Note(fangjun): We only tested 1.17.1 and 1.11.0
// For other versions, we may need to change it
#if ORT_API_VERSION >= 17
auto v = sess->GetOutputNameAllocated(index, allocator);
return v.get();
#else
auto v = sess->GetOutputName(index, allocator);
std::string ans = v;
allocator->Free(allocator, v);
return ans;
#endif
}
void GetInputNames(Ort::Session *sess, std::vector<std::string> *input_names, void GetInputNames(Ort::Session *sess, std::vector<std::string> *input_names,
std::vector<const char *> *input_names_ptr) { std::vector<const char *> *input_names_ptr) {
Ort::AllocatorWithDefaultOptions allocator; Ort::AllocatorWithDefaultOptions allocator;
@@ -28,8 +58,7 @@ void GetInputNames(Ort::Session *sess, std::vector<std::string> *input_names,
input_names->resize(node_count); input_names->resize(node_count);
input_names_ptr->resize(node_count); input_names_ptr->resize(node_count);
for (size_t i = 0; i != node_count; ++i) { for (size_t i = 0; i != node_count; ++i) {
auto tmp = sess->GetInputNameAllocated(i, allocator); (*input_names)[i] = GetInputName(sess, i, allocator);
(*input_names)[i] = tmp.get();
(*input_names_ptr)[i] = (*input_names)[i].c_str(); (*input_names_ptr)[i] = (*input_names)[i].c_str();
} }
} }
@@ -41,8 +70,7 @@ void GetOutputNames(Ort::Session *sess, std::vector<std::string> *output_names,
output_names->resize(node_count); output_names->resize(node_count);
output_names_ptr->resize(node_count); output_names_ptr->resize(node_count);
for (size_t i = 0; i != node_count; ++i) { for (size_t i = 0; i != node_count; ++i) {
auto tmp = sess->GetOutputNameAllocated(i, allocator); (*output_names)[i] = GetOutputName(sess, i, allocator);
(*output_names)[i] = tmp.get();
(*output_names_ptr)[i] = (*output_names)[i].c_str(); (*output_names_ptr)[i] = (*output_names)[i].c_str();
} }
} }
@@ -78,12 +106,24 @@ Ort::Value GetEncoderOutFrame(OrtAllocator *allocator, Ort::Value *encoder_out,
void PrintModelMetadata(std::ostream &os, const Ort::ModelMetadata &meta_data) { void PrintModelMetadata(std::ostream &os, const Ort::ModelMetadata &meta_data) {
Ort::AllocatorWithDefaultOptions allocator; Ort::AllocatorWithDefaultOptions allocator;
#if ORT_API_VERSION >= 17
std::vector<Ort::AllocatedStringPtr> v = std::vector<Ort::AllocatedStringPtr> v =
meta_data.GetCustomMetadataMapKeysAllocated(allocator); meta_data.GetCustomMetadataMapKeysAllocated(allocator);
for (const auto &key : v) { for (const auto &key : v) {
auto p = meta_data.LookupCustomMetadataMapAllocated(key.get(), allocator); auto p = meta_data.LookupCustomMetadataMapAllocated(key.get(), allocator);
os << key.get() << "=" << p.get() << "\n"; os << key.get() << "=" << p.get() << "\n";
} }
#else
int64_t num_keys = 0;
char **keys = meta_data.GetCustomMetadataMapKeys(allocator, num_keys);
for (int32_t i = 0; i < num_keys; ++i) {
auto v = LookupCustomModelMetaData(meta_data, keys[i], allocator);
os << keys[i] << "=" << v << "\n";
allocator.Free(keys[i]);
}
allocator.Free(keys);
#endif
} }
Ort::Value Clone(OrtAllocator *allocator, const Ort::Value *v) { Ort::Value Clone(OrtAllocator *allocator, const Ort::Value *v) {
@@ -361,4 +401,20 @@ std::vector<Ort::Value> Convert(std::vector<CopyableOrtValue> values) {
return ans; return ans;
} }
std::string LookupCustomModelMetaData(const Ort::ModelMetadata &meta_data,
const char *key,
OrtAllocator *allocator) {
// Note(fangjun): We only tested 1.17.1 and 1.11.0
// For other versions, we may need to change it
#if ORT_API_VERSION >= 17
auto v = meta_data.LookupCustomMetadataMapAllocated(key, allocator);
return v.get();
#else
auto v = meta_data.LookupCustomMetadataMap(key, allocator);
std::string ans = v;
allocator->Free(allocator, v);
return ans;
#endif
}
} // namespace sherpa_onnx } // namespace sherpa_onnx

View File

@@ -59,6 +59,9 @@ void GetOutputNames(Ort::Session *sess, std::vector<std::string> *output_names,
Ort::Value GetEncoderOutFrame(OrtAllocator *allocator, Ort::Value *encoder_out, Ort::Value GetEncoderOutFrame(OrtAllocator *allocator, Ort::Value *encoder_out,
int32_t t); int32_t t);
std::string LookupCustomModelMetaData(const Ort::ModelMetadata &meta_data,
const char *key, OrtAllocator *allocator);
void PrintModelMetadata(std::ostream &os, void PrintModelMetadata(std::ostream &os,
const Ort::ModelMetadata &meta_data); // NOLINT const Ort::ModelMetadata &meta_data); // NOLINT

View File

@@ -60,6 +60,7 @@ Ort::SessionOptions GetSessionOptionsImpl(
case Provider::kCPU: case Provider::kCPU:
break; // nothing to do for the CPU provider break; // nothing to do for the CPU provider
case Provider::kXnnpack: { case Provider::kXnnpack: {
#if ORT_API_VERSION >= 17
if (std::find(available_providers.begin(), available_providers.end(), if (std::find(available_providers.begin(), available_providers.end(),
"XnnpackExecutionProvider") != available_providers.end()) { "XnnpackExecutionProvider") != available_providers.end()) {
sess_opts.AppendExecutionProvider("XNNPACK"); sess_opts.AppendExecutionProvider("XNNPACK");
@@ -67,6 +68,11 @@ Ort::SessionOptions GetSessionOptionsImpl(
SHERPA_ONNX_LOGE("Available providers: %s. Fallback to cpu!", SHERPA_ONNX_LOGE("Available providers: %s. Fallback to cpu!",
os.str().c_str()); os.str().c_str());
} }
#else
SHERPA_ONNX_LOGE(
"Does not support xnnpack for onnxruntime: %d. Fallback to cpu!",
static_cast<int32_t>(ORT_API_VERSION));
#endif
break; break;
} }
case Provider::kTRT: { case Provider::kTRT: {

View File

@@ -40,8 +40,8 @@ static ModelType GetModelType(char *model_data, size_t model_data_length,
Ort::AllocatorWithDefaultOptions allocator; Ort::AllocatorWithDefaultOptions allocator;
auto model_type = auto model_type =
meta_data.LookupCustomMetadataMapAllocated("framework", allocator); LookupCustomModelMetaData(meta_data, "framework", allocator);
if (!model_type) { if (model_type.empty()) {
SHERPA_ONNX_LOGE( SHERPA_ONNX_LOGE(
"No model_type in the metadata!\n" "No model_type in the metadata!\n"
"Please make sure you have added metadata to the model.\n\n" "Please make sure you have added metadata to the model.\n\n"
@@ -52,14 +52,14 @@ static ModelType GetModelType(char *model_data, size_t model_data_length,
return ModelType::kUnknown; return ModelType::kUnknown;
} }
if (model_type.get() == std::string("wespeaker")) { if (model_type == "wespeaker") {
return ModelType::kWeSpeaker; return ModelType::kWeSpeaker;
} else if (model_type.get() == std::string("3d-speaker")) { } else if (model_type == "3d-speaker") {
return ModelType::k3dSpeaker; return ModelType::k3dSpeaker;
} else if (model_type.get() == std::string("nemo")) { } else if (model_type == "nemo") {
return ModelType::kNeMo; return ModelType::kNeMo;
} else { } else {
SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.get()); SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.c_str());
return ModelType::kUnknown; return ModelType::kUnknown;
} }
} }

View File

@@ -53,7 +53,7 @@ class SpeakerEmbeddingExtractorNeMoModel::Impl {
return std::move(outputs[0]); return std::move(outputs[0]);
} }
OrtAllocator *Allocator() const { return allocator_; } OrtAllocator *Allocator() { return allocator_; }
const SpeakerEmbeddingExtractorNeMoModelMetaData &GetMetaData() const { const SpeakerEmbeddingExtractorNeMoModelMetaData &GetMetaData() const {
return meta_data_; return meta_data_;

View File

@@ -42,8 +42,8 @@ static ModelType GetModelType(char *model_data, size_t model_data_length,
Ort::AllocatorWithDefaultOptions allocator; Ort::AllocatorWithDefaultOptions allocator;
auto model_type = auto model_type =
meta_data.LookupCustomMetadataMapAllocated("model_type", allocator); LookupCustomModelMetaData(meta_data, "model_type", allocator);
if (!model_type) { if (model_type.empty()) {
SHERPA_ONNX_LOGE( SHERPA_ONNX_LOGE(
"No model_type in the metadata!\n" "No model_type in the metadata!\n"
"Please make sure you have added metadata to the model.\n\n" "Please make sure you have added metadata to the model.\n\n"
@@ -54,11 +54,10 @@ static ModelType GetModelType(char *model_data, size_t model_data_length,
return ModelType::kUnknown; return ModelType::kUnknown;
} }
auto model_type_str = std::string(model_type.get()); if (model_type.find("whisper") == 0) {
if (model_type_str.find("whisper") == 0) {
return ModelType::kWhisper; return ModelType::kWhisper;
} else { } else {
SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.get()); SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.c_str());
return ModelType::kUnknown; return ModelType::kUnknown;
} }
} }

View File

@@ -29,20 +29,19 @@ namespace {
const char *ws = " \t\n\r\f\v"; const char *ws = " \t\n\r\f\v";
// trim from end of string (right) // trim from end of string (right)
inline std::string &TrimRight(std::string &s, const char *t = ws) { inline void TrimRight(std::string *s, const char *t = ws) {
s.erase(s.find_last_not_of(t) + 1); s->erase(s->find_last_not_of(t) + 1);
return s;
} }
// trim from beginning of string (left) // trim from beginning of string (left)
inline std::string &TrimLeft(std::string &s, const char *t = ws) { inline void TrimLeft(std::string *s, const char *t = ws) {
s.erase(0, s.find_first_not_of(t)); s->erase(0, s->find_first_not_of(t));
return s;
} }
// trim from both ends of string (right then left) // trim from both ends of string (right then left)
inline std::string &Trim(std::string &s, const char *t = ws) { inline void Trim(std::string *s, const char *t = ws) {
return TrimLeft(TrimRight(s, t), t); TrimRight(s, t);
TrimLeft(s, t);
} }
} // namespace } // namespace
@@ -56,7 +55,7 @@ std::unordered_map<std::string, int32_t> ReadTokens(
std::string sym; std::string sym;
int32_t id = -1; int32_t id = -1;
while (std::getline(is, line)) { while (std::getline(is, line)) {
Trim(line); Trim(&line);
std::istringstream iss(line); std::istringstream iss(line);
iss >> sym; iss >> sym;
if (iss.eof()) { if (iss.eof()) {