Support building GPU-capable sherpa-onnx on Linux aarch64. (#1500)
Thanks to @Peakyxh for providing pre-built onnxruntime libraries with CUDA support for Linux aarch64. Tested on Jetson nano b01
This commit is contained in:
65
.github/workflows/aarch64-linux-gnu-shared.yaml
vendored
65
.github/workflows/aarch64-linux-gnu-shared.yaml
vendored
@@ -34,11 +34,12 @@ concurrency:
|
||||
jobs:
|
||||
aarch64_linux_gnu_shared:
|
||||
runs-on: ${{ matrix.os }}
|
||||
name: aarch64 shared lib test
|
||||
name: aarch64 shared GPU ${{ matrix.gpu }}
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
os: [ubuntu-latest]
|
||||
gpu: [ON, OFF]
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
@@ -79,15 +80,24 @@ jobs:
|
||||
make -j2
|
||||
make install
|
||||
|
||||
- name: cache-toolchain
|
||||
id: cache-toolchain
|
||||
- name: cache-toolchain (CPU)
|
||||
if: matrix.gpu == 'OFF'
|
||||
id: cache-toolchain-cpu
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: toolchain
|
||||
key: gcc-linaro-7.5.0-2019.12-x86_64_aarch64-linux-gnu.tar.xz
|
||||
|
||||
- name: Download toolchain
|
||||
if: steps.cache-toolchain.outputs.cache-hit != 'true'
|
||||
- name: cache-toolchain (GPU)
|
||||
if: matrix.gpu == 'ON'
|
||||
id: cache-toolchain-gpu
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: toolchain
|
||||
key: gcc-arm-10.3-2021.07-x86_64-aarch64-none-linux-gnu.tar.xz
|
||||
|
||||
- name: Download toolchain (CPU, gcc 7.5)
|
||||
if: steps.cache-toolchain-cpu.outputs.cache-hit != 'true' && matrix.gpu == 'OFF'
|
||||
shell: bash
|
||||
run: |
|
||||
wget -qq https://huggingface.co/csukuangfj/sherpa-ncnn-toolchains/resolve/main/gcc-linaro-7.5.0-2019.12-x86_64_aarch64-linux-gnu.tar.xz
|
||||
@@ -95,6 +105,15 @@ jobs:
|
||||
mkdir $GITHUB_WORKSPACE/toolchain
|
||||
tar xf ./gcc-linaro-7.5.0-2019.12-x86_64_aarch64-linux-gnu.tar.xz --strip-components 1 -C $GITHUB_WORKSPACE/toolchain
|
||||
|
||||
- name: Download toolchain (GPU, gcc 10.3)
|
||||
if: steps.cache-toolchain-gpu.outputs.cache-hit != 'true' && matrix.gpu == 'ON'
|
||||
shell: bash
|
||||
run: |
|
||||
wget -qq https://huggingface.co/csukuangfj/sherpa-ncnn-toolchains/resolve/main/gcc-arm-10.3-2021.07-x86_64-aarch64-none-linux-gnu.tar.xz
|
||||
|
||||
mkdir $GITHUB_WORKSPACE/toolchain
|
||||
tar xf ./gcc-arm-10.3-2021.07-x86_64-aarch64-none-linux-gnu.tar.xz --strip-components 1 -C $GITHUB_WORKSPACE/toolchain
|
||||
|
||||
- name: Set environment variable
|
||||
if: steps.cache-build-result.outputs.cache-hit != 'true'
|
||||
shell: bash
|
||||
@@ -103,19 +122,31 @@ jobs:
|
||||
echo "$GITHUB_WORKSPACE/bin" >> "$GITHUB_PATH"
|
||||
ls -lh "$GITHUB_WORKSPACE/toolchain/bin"
|
||||
|
||||
if [[ ${{ matrix.gpu }} == OFF ]]; then
|
||||
echo "CC=aarch64-linux-gnu-gcc" >> "$GITHUB_ENV"
|
||||
echo "CXX=aarch64-linux-gnu-g++" >> "$GITHUB_ENV"
|
||||
else
|
||||
echo "CC=aarch64-none-linux-gnu-gcc" >> "$GITHUB_ENV"
|
||||
echo "CXX=aarch64-none-linux-gnu-g++" >> "$GITHUB_ENV"
|
||||
fi
|
||||
|
||||
- name: Display toolchain info
|
||||
shell: bash
|
||||
run: |
|
||||
if [[ ${{ matrix.gpu }} == OFF ]]; then
|
||||
which aarch64-linux-gnu-gcc
|
||||
aarch64-linux-gnu-gcc --version
|
||||
else
|
||||
which aarch64-none-linux-gnu-gcc
|
||||
aarch64-none-linux-gnu-gcc --version
|
||||
fi
|
||||
|
||||
- name: Display qemu-aarch64 -h
|
||||
shell: bash
|
||||
run: |
|
||||
export PATH=$GITHUB_WORKSPACE/qemu-install/bin:$PATH
|
||||
export QEMU_LD_PREFIX=$GITHUB_WORKSPACE/toolchain/aarch64-linux-gnu/libc
|
||||
export QEMU_LD_PREFIX=$GITHUB_WORKSPACE/toolchain/aarch64-none-linux-gnu/libc
|
||||
qemu-aarch64 -h
|
||||
|
||||
- name: build aarch64-linux-gnu
|
||||
@@ -127,6 +158,7 @@ jobs:
|
||||
cmake --version
|
||||
|
||||
export BUILD_SHARED_LIBS=ON
|
||||
export SHERPA_ONNX_ENABLE_GPU=${{ matrix.gpu }}
|
||||
|
||||
./build-aarch64-linux-gnu.sh
|
||||
|
||||
@@ -140,7 +172,11 @@ jobs:
|
||||
run: |
|
||||
export PATH=$GITHUB_WORKSPACE/toolchain/bin:$PATH
|
||||
export PATH=$GITHUB_WORKSPACE/qemu-install/bin:$PATH
|
||||
if [[ ${{ matrix.gpu }} == OFF ]]; then
|
||||
export QEMU_LD_PREFIX=$GITHUB_WORKSPACE/toolchain/aarch64-linux-gnu/libc
|
||||
else
|
||||
export QEMU_LD_PREFIX=$GITHUB_WORKSPACE/toolchain/aarch64-none-linux-gnu/libc
|
||||
fi
|
||||
|
||||
ls -lh ./build-aarch64-linux-gnu/bin
|
||||
|
||||
@@ -151,11 +187,20 @@ jobs:
|
||||
- name: Copy files
|
||||
shell: bash
|
||||
run: |
|
||||
if [[ ${{ matrix.gpu }} == OFF ]]; then
|
||||
aarch64-linux-gnu-strip --version
|
||||
else
|
||||
aarch64-none-linux-gnu-strip --version
|
||||
fi
|
||||
|
||||
SHERPA_ONNX_VERSION=v$(grep "SHERPA_ONNX_VERSION" ./CMakeLists.txt | cut -d " " -f 2 | cut -d '"' -f 2)
|
||||
|
||||
dst=sherpa-onnx-${SHERPA_ONNX_VERSION}-linux-aarch64-shared
|
||||
if [[ ${{ matrix.gpu }} == OFF ]]; then
|
||||
dst=${dst}-cpu
|
||||
else
|
||||
dst=${dst}-gpu
|
||||
fi
|
||||
mkdir $dst
|
||||
|
||||
cp -a build-aarch64-linux-gnu/install/bin $dst/
|
||||
@@ -166,7 +211,11 @@ jobs:
|
||||
|
||||
ls -lh $dst/bin/
|
||||
echo "strip"
|
||||
if [[ ${{ matrix.gpu }} == OFF ]]; then
|
||||
aarch64-linux-gnu-strip $dst/bin/*
|
||||
else
|
||||
aarch64-none-linux-gnu-strip $dst/bin/*
|
||||
fi
|
||||
|
||||
tree $dst
|
||||
|
||||
@@ -174,8 +223,8 @@ jobs:
|
||||
|
||||
- uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: sherpa-onnx-linux-aarch64-shared
|
||||
path: sherpa-onnx-*linux-aarch64-shared.tar.bz2
|
||||
name: sherpa-onnx-linux-aarch64-shared-gpu-${{ matrix.gpu }}
|
||||
path: sherpa-onnx-*linux-aarch64-shared*.tar.bz2
|
||||
|
||||
# https://huggingface.co/docs/hub/spaces-github-actions
|
||||
- name: Publish to huggingface
|
||||
@@ -198,7 +247,7 @@ jobs:
|
||||
cd huggingface
|
||||
mkdir -p aarch64
|
||||
|
||||
cp -v ../sherpa-onnx-*-shared.tar.bz2 ./aarch64
|
||||
cp -v ../sherpa-onnx-*-shared*.tar.bz2 ./aarch64
|
||||
|
||||
git status
|
||||
git lfs track "*.bz2"
|
||||
|
||||
@@ -44,6 +44,21 @@ if [[ x"$BUILD_SHARED_LIBS" == x"" ]]; then
|
||||
BUILD_SHARED_LIBS=OFF
|
||||
fi
|
||||
|
||||
if [[ x"$SHERPA_ONNX_ENABLE_GPU" == x"" ]]; then
|
||||
# By default, use CPU
|
||||
SHERPA_ONNX_ENABLE_GPU=OFF
|
||||
|
||||
# If you use GPU, then please make sure you have NVIDIA GPUs on your board.
|
||||
# It uses onnxruntime 1.11.0.
|
||||
#
|
||||
# Tested on Jetson Nano B01
|
||||
fi
|
||||
|
||||
if [[ x"$SHERPA_ONNX_ENABLE_GPU" == x"ON" ]]; then
|
||||
# Build shared libs if building GPU is enabled.
|
||||
BUILD_SHARED_LIBS=ON
|
||||
fi
|
||||
|
||||
cmake \
|
||||
-DBUILD_PIPER_PHONMIZE_EXE=OFF \
|
||||
-DBUILD_PIPER_PHONMIZE_TESTS=OFF \
|
||||
@@ -51,6 +66,7 @@ cmake \
|
||||
-DBUILD_ESPEAK_NG_TESTS=OFF \
|
||||
-DCMAKE_INSTALL_PREFIX=./install \
|
||||
-DCMAKE_BUILD_TYPE=Release \
|
||||
-DSHERPA_ONNX_ENABLE_GPU=$SHERPA_ONNX_ENABLE_GPU \
|
||||
-DBUILD_SHARED_LIBS=$BUILD_SHARED_LIBS \
|
||||
-DSHERPA_ONNX_ENABLE_TESTS=OFF \
|
||||
-DSHERPA_ONNX_ENABLE_PYTHON=OFF \
|
||||
|
||||
101
cmake/onnxruntime-linux-aarch64-gpu.cmake
Normal file
101
cmake/onnxruntime-linux-aarch64-gpu.cmake
Normal file
@@ -0,0 +1,101 @@
|
||||
# Copyright (c) 2022-2024 Xiaomi Corporation
|
||||
message(STATUS "CMAKE_SYSTEM_NAME: ${CMAKE_SYSTEM_NAME}")
|
||||
message(STATUS "CMAKE_SYSTEM_PROCESSOR: ${CMAKE_SYSTEM_PROCESSOR}")
|
||||
|
||||
if(NOT CMAKE_SYSTEM_NAME STREQUAL Linux)
|
||||
message(FATAL_ERROR "This file is for Linux only. Given: ${CMAKE_SYSTEM_NAME}")
|
||||
endif()
|
||||
|
||||
if(NOT CMAKE_SYSTEM_PROCESSOR STREQUAL aarch64)
|
||||
message(FATAL_ERROR "This file is for aarch64 only. Given: ${CMAKE_SYSTEM_PROCESSOR}")
|
||||
endif()
|
||||
|
||||
if(NOT BUILD_SHARED_LIBS)
|
||||
message(FATAL_ERROR "This file is for building shared libraries. BUILD_SHARED_LIBS: ${BUILD_SHARED_LIBS}")
|
||||
endif()
|
||||
|
||||
if(NOT SHERPA_ONNX_ENABLE_GPU)
|
||||
message(FATAL_ERROR "This file is for NVIDIA GPU only. Given SHERPA_ONNX_ENABLE_GPU: ${SHERPA_ONNX_ENABLE_GPU}")
|
||||
endif()
|
||||
|
||||
set(onnxruntime_URL "https://github.com/csukuangfj/onnxruntime-libs/releases/download/v1.11.0/onnxruntime-linux-aarch64-gpu-1.11.0.tar.bz2")
|
||||
set(onnxruntime_URL2 "https://hf-mirror.com/csukuangfj/onnxruntime-libs/resolve/main/onnxruntime-linux-aarch64-gpu-1.11.0.tar.bz2")
|
||||
set(onnxruntime_HASH "SHA256=36eded935551e23aead09d4173bdf0bd1e7b01fdec15d77f97d6e34029aa60d7")
|
||||
|
||||
# If you don't have access to the Internet,
|
||||
# please download onnxruntime to one of the following locations.
|
||||
# You can add more if you want.
|
||||
set(possible_file_locations
|
||||
$ENV{HOME}/Downloads/onnxruntime-linux-aarch64-gpu-1.11.0.tar.bz2
|
||||
${CMAKE_SOURCE_DIR}/onnxruntime-linux-aarch64-gpu-1.11.0.tar.bz2
|
||||
${CMAKE_BINARY_DIR}/onnxruntime-linux-aarch64-gpu-1.11.0.tar.bz2
|
||||
/tmp/onnxruntime-linux-aarch64-gpu-1.11.0.tar.bz2
|
||||
/star-fj/fangjun/download/github/onnxruntime-linux-aarch64-gpu-1.11.0.tar.bz2
|
||||
)
|
||||
|
||||
foreach(f IN LISTS possible_file_locations)
|
||||
if(EXISTS ${f})
|
||||
set(onnxruntime_URL "${f}")
|
||||
file(TO_CMAKE_PATH "${onnxruntime_URL}" onnxruntime_URL)
|
||||
message(STATUS "Found local downloaded onnxruntime: ${onnxruntime_URL}")
|
||||
set(onnxruntime_URL2)
|
||||
break()
|
||||
endif()
|
||||
endforeach()
|
||||
|
||||
FetchContent_Declare(onnxruntime
|
||||
URL
|
||||
${onnxruntime_URL}
|
||||
${onnxruntime_URL2}
|
||||
URL_HASH ${onnxruntime_HASH}
|
||||
)
|
||||
|
||||
FetchContent_GetProperties(onnxruntime)
|
||||
if(NOT onnxruntime_POPULATED)
|
||||
message(STATUS "Downloading onnxruntime from ${onnxruntime_URL}")
|
||||
FetchContent_Populate(onnxruntime)
|
||||
endif()
|
||||
message(STATUS "onnxruntime is downloaded to ${onnxruntime_SOURCE_DIR}")
|
||||
|
||||
find_library(location_onnxruntime onnxruntime
|
||||
PATHS
|
||||
"${onnxruntime_SOURCE_DIR}/lib"
|
||||
NO_CMAKE_SYSTEM_PATH
|
||||
)
|
||||
|
||||
message(STATUS "location_onnxruntime: ${location_onnxruntime}")
|
||||
|
||||
add_library(onnxruntime SHARED IMPORTED)
|
||||
|
||||
set_target_properties(onnxruntime PROPERTIES
|
||||
IMPORTED_LOCATION ${location_onnxruntime}
|
||||
INTERFACE_INCLUDE_DIRECTORIES "${onnxruntime_SOURCE_DIR}/include"
|
||||
)
|
||||
|
||||
find_library(location_onnxruntime_cuda_lib onnxruntime_providers_cuda
|
||||
PATHS
|
||||
"${onnxruntime_SOURCE_DIR}/lib"
|
||||
NO_CMAKE_SYSTEM_PATH
|
||||
)
|
||||
|
||||
add_library(onnxruntime_providers_cuda SHARED IMPORTED)
|
||||
set_target_properties(onnxruntime_providers_cuda PROPERTIES
|
||||
IMPORTED_LOCATION ${location_onnxruntime_cuda_lib}
|
||||
)
|
||||
message(STATUS "location_onnxruntime_cuda_lib: ${location_onnxruntime_cuda_lib}")
|
||||
|
||||
# for libonnxruntime_providers_shared.so
|
||||
find_library(location_onnxruntime_providers_shared_lib onnxruntime_providers_shared
|
||||
PATHS
|
||||
"${onnxruntime_SOURCE_DIR}/lib"
|
||||
NO_CMAKE_SYSTEM_PATH
|
||||
)
|
||||
add_library(onnxruntime_providers_shared SHARED IMPORTED)
|
||||
set_target_properties(onnxruntime_providers_shared PROPERTIES
|
||||
IMPORTED_LOCATION ${location_onnxruntime_providers_shared_lib}
|
||||
)
|
||||
message(STATUS "location_onnxruntime_providers_shared_lib: ${location_onnxruntime_providers_shared_lib}")
|
||||
|
||||
file(GLOB onnxruntime_lib_files "${onnxruntime_SOURCE_DIR}/lib/libonnxruntime*")
|
||||
message(STATUS "onnxruntime lib files: ${onnxruntime_lib_files}")
|
||||
install(FILES ${onnxruntime_lib_files} DESTINATION lib)
|
||||
@@ -13,7 +13,9 @@ function(download_onnxruntime)
|
||||
include(onnxruntime-linux-riscv64-static)
|
||||
endif()
|
||||
elseif(CMAKE_SYSTEM_NAME STREQUAL Linux AND CMAKE_SYSTEM_PROCESSOR STREQUAL aarch64)
|
||||
if(BUILD_SHARED_LIBS)
|
||||
if(SHERPA_ONNX_ENABLE_GPU)
|
||||
include(onnxruntime-linux-aarch64-gpu)
|
||||
elseif(BUILD_SHARED_LIBS)
|
||||
include(onnxruntime-linux-aarch64)
|
||||
else()
|
||||
include(onnxruntime-linux-aarch64-static)
|
||||
|
||||
@@ -1,18 +1,18 @@
|
||||
function(download_piper_phonemize)
|
||||
include(FetchContent)
|
||||
|
||||
set(piper_phonemize_URL "https://github.com/csukuangfj/piper-phonemize/archive/dc6b5f4441bffe521047086930b0fc12686acd56.zip")
|
||||
set(piper_phonemize_URL2 "https://hf-mirror.com/csukuangfj/sherpa-onnx-cmake-deps/resolve/main/piper-phonemize-dc6b5f4441bffe521047086930b0fc12686acd56.zip")
|
||||
set(piper_phonemize_HASH "SHA256=b9faa04204b1756fa455a962abb1f037041c040133d55be58d11f11ab9b3ce14")
|
||||
set(piper_phonemize_URL "https://github.com/csukuangfj/piper-phonemize/archive/38ee199dcc49c7b6de89f7ebfb32ed682763fa1b.zip")
|
||||
set(piper_phonemize_URL2 "https://hf-mirror.com/csukuangfj/sherpa-onnx-cmake-deps/resolve/main/piper-phonemize-38ee199dcc49c7b6de89f7ebfb32ed682763fa1b.zip")
|
||||
set(piper_phonemize_HASH "SHA256=ab4d06ca76047e1585c63c482f39ffead5315785345055360703cc9382c5e74b")
|
||||
|
||||
# If you don't have access to the Internet,
|
||||
# please pre-download kaldi-decoder
|
||||
set(possible_file_locations
|
||||
$ENV{HOME}/Downloads/piper-phonemize-dc6b5f4441bffe521047086930b0fc12686acd56.zip
|
||||
${CMAKE_SOURCE_DIR}/piper-phonemize-dc6b5f4441bffe521047086930b0fc12686acd56.zip
|
||||
${CMAKE_BINARY_DIR}/piper-phonemize-dc6b5f4441bffe521047086930b0fc12686acd56.zip
|
||||
/tmp/piper-phonemize-dc6b5f4441bffe521047086930b0fc12686acd56.zip
|
||||
/star-fj/fangjun/download/github/piper-phonemize-dc6b5f4441bffe521047086930b0fc12686acd56.zip
|
||||
$ENV{HOME}/Downloads/piper-phonemize-38ee199dcc49c7b6de89f7ebfb32ed682763fa1b.zip
|
||||
${CMAKE_SOURCE_DIR}/piper-phonemize-38ee199dcc49c7b6de89f7ebfb32ed682763fa1b.zip
|
||||
${CMAKE_BINARY_DIR}/piper-phonemize-38ee199dcc49c7b6de89f7ebfb32ed682763fa1b.zip
|
||||
/tmp/piper-phonemize-38ee199dcc49c7b6de89f7ebfb32ed682763fa1b.zip
|
||||
/star-fj/fangjun/download/github/piper-phonemize-38ee199dcc49c7b6de89f7ebfb32ed682763fa1b.zip
|
||||
)
|
||||
|
||||
foreach(f IN LISTS possible_file_locations)
|
||||
|
||||
@@ -7,6 +7,8 @@
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
|
||||
#include <utility>
|
||||
|
||||
#if __ANDROID_API__ >= 8
|
||||
#include "android/log.h"
|
||||
#define SHERPA_ONNX_LOGE(...) \
|
||||
@@ -38,14 +40,13 @@
|
||||
// Read an integer
|
||||
#define SHERPA_ONNX_READ_META_DATA(dst, src_key) \
|
||||
do { \
|
||||
auto value = \
|
||||
meta_data.LookupCustomMetadataMapAllocated(src_key, allocator); \
|
||||
if (!value) { \
|
||||
auto value = LookupCustomModelMetaData(meta_data, src_key, allocator); \
|
||||
if (value.empty()) { \
|
||||
SHERPA_ONNX_LOGE("'%s' does not exist in the metadata", src_key); \
|
||||
exit(-1); \
|
||||
} \
|
||||
\
|
||||
dst = atoi(value.get()); \
|
||||
dst = atoi(value.c_str()); \
|
||||
if (dst < 0) { \
|
||||
SHERPA_ONNX_LOGE("Invalid value %d for '%s'", dst, src_key); \
|
||||
exit(-1); \
|
||||
@@ -54,12 +55,11 @@
|
||||
|
||||
#define SHERPA_ONNX_READ_META_DATA_WITH_DEFAULT(dst, src_key, default_value) \
|
||||
do { \
|
||||
auto value = \
|
||||
meta_data.LookupCustomMetadataMapAllocated(src_key, allocator); \
|
||||
if (!value) { \
|
||||
auto value = LookupCustomModelMetaData(meta_data, src_key, allocator); \
|
||||
if (value.empty()) { \
|
||||
dst = default_value; \
|
||||
} else { \
|
||||
dst = atoi(value.get()); \
|
||||
dst = atoi(value.c_str()); \
|
||||
if (dst < 0) { \
|
||||
SHERPA_ONNX_LOGE("Invalid value %d for '%s'", dst, src_key); \
|
||||
exit(-1); \
|
||||
@@ -70,16 +70,15 @@
|
||||
// read a vector of integers
|
||||
#define SHERPA_ONNX_READ_META_DATA_VEC(dst, src_key) \
|
||||
do { \
|
||||
auto value = \
|
||||
meta_data.LookupCustomMetadataMapAllocated(src_key, allocator); \
|
||||
if (!value) { \
|
||||
auto value = LookupCustomModelMetaData(meta_data, src_key, allocator); \
|
||||
if (value.empty()) { \
|
||||
SHERPA_ONNX_LOGE("'%s' does not exist in the metadata", src_key); \
|
||||
exit(-1); \
|
||||
} \
|
||||
\
|
||||
bool ret = SplitStringToIntegers(value.get(), ",", true, &dst); \
|
||||
bool ret = SplitStringToIntegers(value.c_str(), ",", true, &dst); \
|
||||
if (!ret) { \
|
||||
SHERPA_ONNX_LOGE("Invalid value '%s' for '%s'", value.get(), src_key); \
|
||||
SHERPA_ONNX_LOGE("Invalid value '%s' for '%s'", value.c_str(), src_key); \
|
||||
exit(-1); \
|
||||
} \
|
||||
} while (0)
|
||||
@@ -87,16 +86,15 @@
|
||||
// read a vector of floats
|
||||
#define SHERPA_ONNX_READ_META_DATA_VEC_FLOAT(dst, src_key) \
|
||||
do { \
|
||||
auto value = \
|
||||
meta_data.LookupCustomMetadataMapAllocated(src_key, allocator); \
|
||||
if (!value) { \
|
||||
auto value = LookupCustomModelMetaData(meta_data, src_key, allocator); \
|
||||
if (value.empty()) { \
|
||||
SHERPA_ONNX_LOGE("%s does not exist in the metadata", src_key); \
|
||||
exit(-1); \
|
||||
} \
|
||||
\
|
||||
bool ret = SplitStringToFloats(value.get(), ",", true, &dst); \
|
||||
bool ret = SplitStringToFloats(value.c_str(), ",", true, &dst); \
|
||||
if (!ret) { \
|
||||
SHERPA_ONNX_LOGE("Invalid value '%s' for '%s'", value.get(), src_key); \
|
||||
SHERPA_ONNX_LOGE("Invalid value '%s' for '%s'", value.c_str(), src_key); \
|
||||
exit(-1); \
|
||||
} \
|
||||
} while (0)
|
||||
@@ -104,17 +102,16 @@
|
||||
// read a vector of strings
|
||||
#define SHERPA_ONNX_READ_META_DATA_VEC_STRING(dst, src_key) \
|
||||
do { \
|
||||
auto value = \
|
||||
meta_data.LookupCustomMetadataMapAllocated(src_key, allocator); \
|
||||
if (!value) { \
|
||||
auto value = LookupCustomModelMetaData(meta_data, src_key, allocator); \
|
||||
if (value.empty()) { \
|
||||
SHERPA_ONNX_LOGE("'%s' does not exist in the metadata", src_key); \
|
||||
exit(-1); \
|
||||
} \
|
||||
SplitStringToVector(value.get(), ",", false, &dst); \
|
||||
SplitStringToVector(value.c_str(), ",", false, &dst); \
|
||||
\
|
||||
if (dst.empty()) { \
|
||||
SHERPA_ONNX_LOGE("Invalid value '%s' for '%s'. Empty vector!", \
|
||||
value.get(), src_key); \
|
||||
value.c_str(), src_key); \
|
||||
exit(-1); \
|
||||
} \
|
||||
} while (0)
|
||||
@@ -122,17 +119,16 @@
|
||||
// read a vector of strings separated by sep
|
||||
#define SHERPA_ONNX_READ_META_DATA_VEC_STRING_SEP(dst, src_key, sep) \
|
||||
do { \
|
||||
auto value = \
|
||||
meta_data.LookupCustomMetadataMapAllocated(src_key, allocator); \
|
||||
if (!value) { \
|
||||
auto value = LookupCustomModelMetaData(meta_data, src_key, allocator); \
|
||||
if (value.empty()) { \
|
||||
SHERPA_ONNX_LOGE("'%s' does not exist in the metadata", src_key); \
|
||||
exit(-1); \
|
||||
} \
|
||||
SplitStringToVector(value.get(), sep, false, &dst); \
|
||||
SplitStringToVector(value.c_str(), sep, false, &dst); \
|
||||
\
|
||||
if (dst.empty()) { \
|
||||
SHERPA_ONNX_LOGE("Invalid value '%s' for '%s'. Empty vector!", \
|
||||
value.get(), src_key); \
|
||||
value.c_str(), src_key); \
|
||||
exit(-1); \
|
||||
} \
|
||||
} while (0)
|
||||
@@ -140,14 +136,13 @@
|
||||
// Read a string
|
||||
#define SHERPA_ONNX_READ_META_DATA_STR(dst, src_key) \
|
||||
do { \
|
||||
auto value = \
|
||||
meta_data.LookupCustomMetadataMapAllocated(src_key, allocator); \
|
||||
if (!value) { \
|
||||
auto value = LookupCustomModelMetaData(meta_data, src_key, allocator); \
|
||||
if (value.empty()) { \
|
||||
SHERPA_ONNX_LOGE("'%s' does not exist in the metadata", src_key); \
|
||||
exit(-1); \
|
||||
} \
|
||||
\
|
||||
dst = value.get(); \
|
||||
dst = std::move(value); \
|
||||
if (dst.empty()) { \
|
||||
SHERPA_ONNX_LOGE("Invalid value for '%s'\n", src_key); \
|
||||
exit(-1); \
|
||||
@@ -156,25 +151,23 @@
|
||||
|
||||
#define SHERPA_ONNX_READ_META_DATA_STR_ALLOW_EMPTY(dst, src_key) \
|
||||
do { \
|
||||
auto value = \
|
||||
meta_data.LookupCustomMetadataMapAllocated(src_key, allocator); \
|
||||
if (!value) { \
|
||||
auto value = LookupCustomModelMetaData(meta_data, src_key, allocator); \
|
||||
if (value.empty()) { \
|
||||
SHERPA_ONNX_LOGE("'%s' does not exist in the metadata", src_key); \
|
||||
exit(-1); \
|
||||
} \
|
||||
\
|
||||
dst = value.get(); \
|
||||
dst = std::move(value); \
|
||||
} while (0)
|
||||
|
||||
#define SHERPA_ONNX_READ_META_DATA_STR_WITH_DEFAULT(dst, src_key, \
|
||||
default_value) \
|
||||
do { \
|
||||
auto value = \
|
||||
meta_data.LookupCustomMetadataMapAllocated(src_key, allocator); \
|
||||
if (!value) { \
|
||||
auto value = LookupCustomModelMetaData(meta_data, src_key, allocator); \
|
||||
if (value.empty()) { \
|
||||
dst = default_value; \
|
||||
} else { \
|
||||
dst = value.get(); \
|
||||
dst = std::move(value); \
|
||||
if (dst.empty()) { \
|
||||
SHERPA_ONNX_LOGE("Invalid value for '%s'\n", src_key); \
|
||||
exit(-1); \
|
||||
|
||||
@@ -46,7 +46,7 @@ class OfflineCEDModel::Impl {
|
||||
|
||||
int32_t NumEventClasses() const { return num_event_classes_; }
|
||||
|
||||
OrtAllocator *Allocator() const { return allocator_; }
|
||||
OrtAllocator *Allocator() { return allocator_; }
|
||||
|
||||
private:
|
||||
void Init(void *model_data, size_t model_data_length) {
|
||||
|
||||
@@ -44,7 +44,7 @@ class OfflineCtTransformerModel::Impl {
|
||||
return std::move(ans[0]);
|
||||
}
|
||||
|
||||
OrtAllocator *Allocator() const { return allocator_; }
|
||||
OrtAllocator *Allocator() { return allocator_; }
|
||||
|
||||
const OfflineCtTransformerModelMetaData &GetModelMetadata() const {
|
||||
return meta_data_;
|
||||
|
||||
@@ -53,8 +53,8 @@ static ModelType GetModelType(char *model_data, size_t model_data_length,
|
||||
|
||||
Ort::AllocatorWithDefaultOptions allocator;
|
||||
auto model_type =
|
||||
meta_data.LookupCustomMetadataMapAllocated("model_type", allocator);
|
||||
if (!model_type) {
|
||||
LookupCustomModelMetaData(meta_data, "model_type", allocator);
|
||||
if (model_type.empty()) {
|
||||
SHERPA_ONNX_LOGE(
|
||||
"No model_type in the metadata!\n"
|
||||
"If you are using models from NeMo, please refer to\n"
|
||||
@@ -74,22 +74,22 @@ static ModelType GetModelType(char *model_data, size_t model_data_length,
|
||||
return ModelType::kUnknown;
|
||||
}
|
||||
|
||||
if (model_type.get() == std::string("EncDecCTCModelBPE")) {
|
||||
if (model_type == "EncDecCTCModelBPE") {
|
||||
return ModelType::kEncDecCTCModelBPE;
|
||||
} else if (model_type.get() == std::string("EncDecCTCModel")) {
|
||||
} else if (model_type == "EncDecCTCModel") {
|
||||
return ModelType::kEncDecCTCModel;
|
||||
} else if (model_type.get() == std::string("EncDecHybridRNNTCTCBPEModel")) {
|
||||
} else if (model_type == "EncDecHybridRNNTCTCBPEModel") {
|
||||
return ModelType::kEncDecHybridRNNTCTCBPEModel;
|
||||
} else if (model_type.get() == std::string("tdnn")) {
|
||||
} else if (model_type == "tdnn") {
|
||||
return ModelType::kTdnn;
|
||||
} else if (model_type.get() == std::string("zipformer2_ctc")) {
|
||||
} else if (model_type == "zipformer2_ctc") {
|
||||
return ModelType::kZipformerCtc;
|
||||
} else if (model_type.get() == std::string("wenet_ctc")) {
|
||||
} else if (model_type == "wenet_ctc") {
|
||||
return ModelType::kWenetCtc;
|
||||
} else if (model_type.get() == std::string("telespeech_ctc")) {
|
||||
} else if (model_type == "telespeech_ctc") {
|
||||
return ModelType::kTeleSpeechCtc;
|
||||
} else {
|
||||
SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.get());
|
||||
SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.c_str());
|
||||
return ModelType::kUnknown;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -155,7 +155,7 @@ class OfflineMoonshineModel::Impl {
|
||||
return {std::move(cached_decoder_out[0]), std::move(next_states)};
|
||||
}
|
||||
|
||||
OrtAllocator *Allocator() const { return allocator_; }
|
||||
OrtAllocator *Allocator() { return allocator_; }
|
||||
|
||||
private:
|
||||
void InitPreprocessor(void *model_data, size_t model_data_length) {
|
||||
|
||||
@@ -68,7 +68,7 @@ class OfflineNemoEncDecCtcModel::Impl {
|
||||
|
||||
int32_t SubsamplingFactor() const { return subsampling_factor_; }
|
||||
|
||||
OrtAllocator *Allocator() const { return allocator_; }
|
||||
OrtAllocator *Allocator() { return allocator_; }
|
||||
|
||||
std::string FeatureNormalizationMethod() const { return normalize_type_; }
|
||||
|
||||
|
||||
@@ -56,7 +56,7 @@ class OfflineParaformerModel::Impl {
|
||||
|
||||
const std::vector<float> &InverseStdDev() const { return inv_stddev_; }
|
||||
|
||||
OrtAllocator *Allocator() const { return allocator_; }
|
||||
OrtAllocator *Allocator() { return allocator_; }
|
||||
|
||||
private:
|
||||
void Init(void *model_data, size_t model_data_length) {
|
||||
|
||||
@@ -121,9 +121,9 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
|
||||
|
||||
Ort::AllocatorWithDefaultOptions allocator; // used in the macro below
|
||||
|
||||
auto model_type_ptr =
|
||||
meta_data.LookupCustomMetadataMapAllocated("model_type", allocator);
|
||||
if (!model_type_ptr) {
|
||||
auto model_type =
|
||||
LookupCustomModelMetaData(meta_data, "model_type", allocator);
|
||||
if (!model_type.empty()) {
|
||||
SHERPA_ONNX_LOGE(
|
||||
"No model_type in the metadata!\n\n"
|
||||
"Please refer to the following URLs to add metadata"
|
||||
@@ -164,7 +164,6 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
|
||||
"\n");
|
||||
exit(-1);
|
||||
}
|
||||
std::string model_type(model_type_ptr.get());
|
||||
|
||||
if (model_type == "conformer" || model_type == "zipformer" ||
|
||||
model_type == "zipformer2") {
|
||||
@@ -301,9 +300,9 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
|
||||
|
||||
Ort::AllocatorWithDefaultOptions allocator; // used in the macro below
|
||||
|
||||
auto model_type_ptr =
|
||||
meta_data.LookupCustomMetadataMapAllocated("model_type", allocator);
|
||||
if (!model_type_ptr) {
|
||||
auto model_type =
|
||||
LookupCustomModelMetaData(meta_data, "model_type", allocator);
|
||||
if (model_type.empty()) {
|
||||
SHERPA_ONNX_LOGE(
|
||||
"No model_type in the metadata!\n\n"
|
||||
"Please refer to the following URLs to add metadata"
|
||||
@@ -344,7 +343,6 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
|
||||
"\n");
|
||||
exit(-1);
|
||||
}
|
||||
std::string model_type(model_type_ptr.get());
|
||||
|
||||
if (model_type == "conformer" || model_type == "zipformer" ||
|
||||
model_type == "zipformer2") {
|
||||
|
||||
@@ -56,7 +56,7 @@ class OfflineSenseVoiceModel::Impl {
|
||||
return meta_data_;
|
||||
}
|
||||
|
||||
OrtAllocator *Allocator() const { return allocator_; }
|
||||
OrtAllocator *Allocator() { return allocator_; }
|
||||
|
||||
private:
|
||||
void Init(void *model_data, size_t model_data_length) {
|
||||
|
||||
@@ -63,7 +63,7 @@ class OfflineTdnnCtcModel::Impl {
|
||||
|
||||
int32_t VocabSize() const { return vocab_size_; }
|
||||
|
||||
OrtAllocator *Allocator() const { return allocator_; }
|
||||
OrtAllocator *Allocator() { return allocator_; }
|
||||
|
||||
private:
|
||||
void Init(void *model_data, size_t model_data_length) {
|
||||
|
||||
@@ -69,7 +69,7 @@ class OfflineTeleSpeechCtcModel::Impl {
|
||||
|
||||
int32_t SubsamplingFactor() const { return subsampling_factor_; }
|
||||
|
||||
OrtAllocator *Allocator() const { return allocator_; }
|
||||
OrtAllocator *Allocator() { return allocator_; }
|
||||
|
||||
private:
|
||||
void Init(void *model_data, size_t model_data_length) {
|
||||
|
||||
@@ -95,11 +95,11 @@ class OfflineTransducerModel::Impl {
|
||||
int32_t VocabSize() const { return vocab_size_; }
|
||||
int32_t ContextSize() const { return context_size_; }
|
||||
int32_t SubsamplingFactor() const { return 4; }
|
||||
OrtAllocator *Allocator() const { return allocator_; }
|
||||
OrtAllocator *Allocator() { return allocator_; }
|
||||
|
||||
Ort::Value BuildDecoderInput(
|
||||
const std::vector<OfflineTransducerDecoderResult> &results,
|
||||
int32_t end_index) const {
|
||||
int32_t end_index) {
|
||||
assert(end_index <= results.size());
|
||||
|
||||
int32_t batch_size = end_index;
|
||||
@@ -122,7 +122,7 @@ class OfflineTransducerModel::Impl {
|
||||
}
|
||||
|
||||
Ort::Value BuildDecoderInput(const std::vector<Hypothesis> &results,
|
||||
int32_t end_index) const {
|
||||
int32_t end_index) {
|
||||
assert(end_index <= results.size());
|
||||
|
||||
int32_t batch_size = end_index;
|
||||
|
||||
@@ -123,7 +123,7 @@ class OfflineTransducerNeMoModel::Impl {
|
||||
return std::move(logit[0]);
|
||||
}
|
||||
|
||||
std::vector<Ort::Value> GetDecoderInitStates(int32_t batch_size) const {
|
||||
std::vector<Ort::Value> GetDecoderInitStates(int32_t batch_size) {
|
||||
std::array<int64_t, 3> s0_shape{pred_rnn_layers_, batch_size, pred_hidden_};
|
||||
Ort::Value s0 = Ort::Value::CreateTensor<float>(allocator_, s0_shape.data(),
|
||||
s0_shape.size());
|
||||
@@ -149,7 +149,7 @@ class OfflineTransducerNeMoModel::Impl {
|
||||
int32_t SubsamplingFactor() const { return subsampling_factor_; }
|
||||
int32_t VocabSize() const { return vocab_size_; }
|
||||
|
||||
OrtAllocator *Allocator() const { return allocator_; }
|
||||
OrtAllocator *Allocator() { return allocator_; }
|
||||
|
||||
std::string FeatureNormalizationMethod() const { return normalize_type_; }
|
||||
|
||||
|
||||
@@ -47,7 +47,7 @@ class OfflineWenetCtcModel::Impl {
|
||||
|
||||
int32_t SubsamplingFactor() const { return subsampling_factor_; }
|
||||
|
||||
OrtAllocator *Allocator() const { return allocator_; }
|
||||
OrtAllocator *Allocator() { return allocator_; }
|
||||
|
||||
private:
|
||||
void Init(void *model_data, size_t model_data_length) {
|
||||
|
||||
@@ -188,7 +188,7 @@ class OfflineWhisperModel::Impl {
|
||||
return {std::move(n_layer_self_k_cache), std::move(n_layer_self_v_cache)};
|
||||
}
|
||||
|
||||
OrtAllocator *Allocator() const { return allocator_; }
|
||||
OrtAllocator *Allocator() { return allocator_; }
|
||||
|
||||
const std::vector<int64_t> &GetInitialTokens() const { return sot_sequence_; }
|
||||
|
||||
|
||||
@@ -47,7 +47,7 @@ class OfflineZipformerAudioTaggingModel::Impl {
|
||||
|
||||
int32_t NumEventClasses() const { return num_event_classes_; }
|
||||
|
||||
OrtAllocator *Allocator() const { return allocator_; }
|
||||
OrtAllocator *Allocator() { return allocator_; }
|
||||
|
||||
private:
|
||||
void Init(void *model_data, size_t model_data_length) {
|
||||
|
||||
@@ -48,7 +48,7 @@ class OfflineZipformerCtcModel::Impl {
|
||||
int32_t VocabSize() const { return vocab_size_; }
|
||||
int32_t SubsamplingFactor() const { return 4; }
|
||||
|
||||
OrtAllocator *Allocator() const { return allocator_; }
|
||||
OrtAllocator *Allocator() { return allocator_; }
|
||||
|
||||
private:
|
||||
void Init(void *model_data, size_t model_data_length) {
|
||||
|
||||
@@ -47,7 +47,7 @@ class OnlineCNNBiLSTMModel::Impl {
|
||||
return {std::move(ans[0]), std::move(ans[1])};
|
||||
}
|
||||
|
||||
OrtAllocator *Allocator() const { return allocator_; }
|
||||
OrtAllocator *Allocator() { return allocator_; }
|
||||
|
||||
const OnlineCNNBiLSTMModelMetaData &GetModelMetadata() const {
|
||||
return meta_data_;
|
||||
|
||||
@@ -163,8 +163,11 @@ std::vector<Ort::Value> OnlineConformerTransducerModel::StackStates(
|
||||
conv_vec[i] = &states[i][1];
|
||||
}
|
||||
|
||||
Ort::Value attn = Cat(allocator_, attn_vec, 2);
|
||||
Ort::Value conv = Cat(allocator_, conv_vec, 2);
|
||||
auto allocator =
|
||||
const_cast<OnlineConformerTransducerModel *>(this)->allocator_;
|
||||
|
||||
Ort::Value attn = Cat(allocator, attn_vec, 2);
|
||||
Ort::Value conv = Cat(allocator, conv_vec, 2);
|
||||
|
||||
std::vector<Ort::Value> ans;
|
||||
ans.reserve(2);
|
||||
@@ -183,8 +186,11 @@ OnlineConformerTransducerModel::UnStackStates(
|
||||
|
||||
std::vector<std::vector<Ort::Value>> ans(batch_size);
|
||||
|
||||
std::vector<Ort::Value> attn_vec = Unbind(allocator_, &states[0], 2);
|
||||
std::vector<Ort::Value> conv_vec = Unbind(allocator_, &states[1], 2);
|
||||
auto allocator =
|
||||
const_cast<OnlineConformerTransducerModel *>(this)->allocator_;
|
||||
|
||||
std::vector<Ort::Value> attn_vec = Unbind(allocator, &states[0], 2);
|
||||
std::vector<Ort::Value> conv_vec = Unbind(allocator, &states[1], 2);
|
||||
|
||||
assert(attn_vec.size() == batch_size);
|
||||
assert(conv_vec.size() == batch_size);
|
||||
|
||||
@@ -158,9 +158,10 @@ std::vector<Ort::Value> OnlineLstmTransducerModel::StackStates(
|
||||
h_buf[i] = &states[i][0];
|
||||
c_buf[i] = &states[i][1];
|
||||
}
|
||||
auto allocator = const_cast<OnlineLstmTransducerModel *>(this)->allocator_;
|
||||
|
||||
Ort::Value h = Cat(allocator_, h_buf, 1);
|
||||
Ort::Value c = Cat(allocator_, c_buf, 1);
|
||||
Ort::Value h = Cat(allocator, h_buf, 1);
|
||||
Ort::Value c = Cat(allocator, c_buf, 1);
|
||||
|
||||
std::vector<Ort::Value> ans;
|
||||
ans.reserve(2);
|
||||
@@ -177,8 +178,10 @@ std::vector<std::vector<Ort::Value>> OnlineLstmTransducerModel::UnStackStates(
|
||||
|
||||
std::vector<std::vector<Ort::Value>> ans(batch_size);
|
||||
|
||||
std::vector<Ort::Value> h_vec = Unbind(allocator_, &states[0], 1);
|
||||
std::vector<Ort::Value> c_vec = Unbind(allocator_, &states[1], 1);
|
||||
auto allocator = const_cast<OnlineLstmTransducerModel *>(this)->allocator_;
|
||||
|
||||
std::vector<Ort::Value> h_vec = Unbind(allocator, &states[0], 1);
|
||||
std::vector<Ort::Value> c_vec = Unbind(allocator, &states[1], 1);
|
||||
|
||||
assert(h_vec.size() == batch_size);
|
||||
assert(c_vec.size() == batch_size);
|
||||
|
||||
@@ -102,7 +102,7 @@ class OnlineNeMoCtcModel::Impl {
|
||||
|
||||
int32_t ChunkShift() const { return chunk_shift_; }
|
||||
|
||||
OrtAllocator *Allocator() const { return allocator_; }
|
||||
OrtAllocator *Allocator() { return allocator_; }
|
||||
|
||||
// Return a vector containing 3 tensors
|
||||
// - cache_last_channel
|
||||
@@ -119,7 +119,7 @@ class OnlineNeMoCtcModel::Impl {
|
||||
}
|
||||
|
||||
std::vector<Ort::Value> StackStates(
|
||||
std::vector<std::vector<Ort::Value>> states) const {
|
||||
std::vector<std::vector<Ort::Value>> states) {
|
||||
int32_t batch_size = static_cast<int32_t>(states.size());
|
||||
if (batch_size == 1) {
|
||||
return std::move(states[0]);
|
||||
@@ -157,6 +157,8 @@ class OnlineNeMoCtcModel::Impl {
|
||||
std::vector<Ort::Value> states) const {
|
||||
assert(states.size() == 3);
|
||||
|
||||
auto allocator = const_cast<Impl *>(this)->allocator_;
|
||||
|
||||
std::vector<std::vector<Ort::Value>> ans;
|
||||
|
||||
auto shape = states[0].GetTensorTypeAndShapeInfo().GetShape();
|
||||
@@ -171,9 +173,9 @@ class OnlineNeMoCtcModel::Impl {
|
||||
for (int32_t i = 0; i != 3; ++i) {
|
||||
std::vector<Ort::Value> v;
|
||||
if (i == 2) {
|
||||
v = Unbind<int64_t>(allocator_, &states[i], 0);
|
||||
v = Unbind<int64_t>(allocator, &states[i], 0);
|
||||
} else {
|
||||
v = Unbind(allocator_, &states[i], 0);
|
||||
v = Unbind(allocator, &states[i], 0);
|
||||
}
|
||||
|
||||
assert(v.size() == batch_size);
|
||||
|
||||
@@ -105,7 +105,7 @@ class OnlineParaformerModel::Impl {
|
||||
|
||||
const std::vector<float> &InverseStdDev() const { return inv_stddev_; }
|
||||
|
||||
OrtAllocator *Allocator() const { return allocator_; }
|
||||
OrtAllocator *Allocator() { return allocator_; }
|
||||
|
||||
private:
|
||||
void InitEncoder(void *model_data, size_t model_data_length) {
|
||||
|
||||
@@ -5,10 +5,10 @@
|
||||
|
||||
#include "sherpa-onnx/csrc/online-rnn-lm.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
#include <algorithm>
|
||||
|
||||
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
@@ -77,12 +77,12 @@ class OnlineRnnLM::Impl {
|
||||
Ort::Value x = Ort::Value::CreateTensor<int64_t>(
|
||||
allocator, x_shape.data(), x_shape.size());
|
||||
int64_t *p_x = x.GetTensorMutableData<int64_t>();
|
||||
std::copy(ys.begin() + context_size + h.cur_scored_pos,
|
||||
ys.end() - 1, p_x);
|
||||
std::copy(ys.begin() + context_size + h.cur_scored_pos, ys.end() - 1,
|
||||
p_x);
|
||||
|
||||
// streaming forward by NN LM
|
||||
auto out = ScoreToken(std::move(x),
|
||||
Convert(std::move(h.nn_lm_states)));
|
||||
auto out =
|
||||
ScoreToken(std::move(x), Convert(std::move(h.nn_lm_states)));
|
||||
|
||||
// update NN LM score in hyp
|
||||
const float *p_nll = out.first.GetTensorData<float>();
|
||||
@@ -125,7 +125,7 @@ class OnlineRnnLM::Impl {
|
||||
}
|
||||
|
||||
// get init states for classic rescore
|
||||
std::vector<Ort::Value> GetInitStates() const {
|
||||
std::vector<Ort::Value> GetInitStates() {
|
||||
std::vector<Ort::Value> ans;
|
||||
ans.reserve(init_states_.size());
|
||||
|
||||
@@ -235,5 +235,4 @@ void OnlineRnnLM::ComputeLMScoreSF(float scale, Hypothesis *hyp) {
|
||||
return impl_->ComputeLMScoreSF(scale, hyp);
|
||||
}
|
||||
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
@@ -54,8 +54,8 @@ static ModelType GetModelType(char *model_data, size_t model_data_length,
|
||||
|
||||
Ort::AllocatorWithDefaultOptions allocator;
|
||||
auto model_type =
|
||||
meta_data.LookupCustomMetadataMapAllocated("model_type", allocator);
|
||||
if (!model_type) {
|
||||
LookupCustomModelMetaData(meta_data, "model_type", allocator);
|
||||
if (model_type.empty()) {
|
||||
SHERPA_ONNX_LOGE(
|
||||
"No model_type in the metadata!\n"
|
||||
"Please make sure you are using the latest export-onnx.py from icefall "
|
||||
@@ -63,16 +63,16 @@ static ModelType GetModelType(char *model_data, size_t model_data_length,
|
||||
return ModelType::kUnknown;
|
||||
}
|
||||
|
||||
if (model_type.get() == std::string("conformer")) {
|
||||
if (model_type == "conformer") {
|
||||
return ModelType::kConformer;
|
||||
} else if (model_type.get() == std::string("lstm")) {
|
||||
} else if (model_type == "lstm") {
|
||||
return ModelType::kLstm;
|
||||
} else if (model_type.get() == std::string("zipformer")) {
|
||||
} else if (model_type == "zipformer") {
|
||||
return ModelType::kZipformer;
|
||||
} else if (model_type.get() == std::string("zipformer2")) {
|
||||
} else if (model_type == "zipformer2") {
|
||||
return ModelType::kZipformer2;
|
||||
} else {
|
||||
SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.get());
|
||||
SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.c_str());
|
||||
return ModelType::kUnknown;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -197,7 +197,7 @@ class OnlineTransducerNeMoModel::Impl {
|
||||
|
||||
int32_t VocabSize() const { return vocab_size_; }
|
||||
|
||||
OrtAllocator *Allocator() const { return allocator_; }
|
||||
OrtAllocator *Allocator() { return allocator_; }
|
||||
|
||||
std::string FeatureNormalizationMethod() const { return normalize_type_; }
|
||||
|
||||
@@ -224,6 +224,8 @@ class OnlineTransducerNeMoModel::Impl {
|
||||
|
||||
std::vector<Ort::Value> ans;
|
||||
|
||||
auto allocator = const_cast<Impl *>(this)->allocator_;
|
||||
|
||||
// stack cache_last_channel
|
||||
std::vector<const Ort::Value *> buf(batch_size);
|
||||
|
||||
@@ -239,9 +241,9 @@ class OnlineTransducerNeMoModel::Impl {
|
||||
|
||||
Ort::Value c{nullptr};
|
||||
if (i == 2) {
|
||||
c = Cat<int64_t>(allocator_, buf, 0);
|
||||
c = Cat<int64_t>(allocator, buf, 0);
|
||||
} else {
|
||||
c = Cat(allocator_, buf, 0);
|
||||
c = Cat(allocator, buf, 0);
|
||||
}
|
||||
|
||||
ans.push_back(std::move(c));
|
||||
@@ -251,7 +253,7 @@ class OnlineTransducerNeMoModel::Impl {
|
||||
}
|
||||
|
||||
std::vector<std::vector<Ort::Value>> UnStackStates(
|
||||
std::vector<Ort::Value> states) const {
|
||||
std::vector<Ort::Value> states) {
|
||||
assert(states.size() == 3);
|
||||
|
||||
std::vector<std::vector<Ort::Value>> ans;
|
||||
|
||||
@@ -101,7 +101,7 @@ class OnlineWenetCtcModel::Impl {
|
||||
return config_.wenet_ctc.chunk_size * subsampling_factor_;
|
||||
}
|
||||
|
||||
OrtAllocator *Allocator() const { return allocator_; }
|
||||
OrtAllocator *Allocator() { return allocator_; }
|
||||
|
||||
// Return a vector containing 3 tensors
|
||||
// - attn_cache
|
||||
|
||||
@@ -179,12 +179,15 @@ std::vector<Ort::Value> OnlineZipformerTransducerModel::StackStates(
|
||||
std::vector<Ort::Value> ans;
|
||||
ans.reserve(states[0].size());
|
||||
|
||||
auto allocator =
|
||||
const_cast<OnlineZipformerTransducerModel *>(this)->allocator_;
|
||||
|
||||
// cached_len
|
||||
for (int32_t i = 0; i != num_encoders; ++i) {
|
||||
for (int32_t n = 0; n != batch_size; ++n) {
|
||||
buf[n] = &states[n][i];
|
||||
}
|
||||
auto v = Cat<int64_t>(allocator_, buf, 1); // (num_layers, 1)
|
||||
auto v = Cat<int64_t>(allocator, buf, 1); // (num_layers, 1)
|
||||
ans.push_back(std::move(v));
|
||||
}
|
||||
|
||||
@@ -193,7 +196,7 @@ std::vector<Ort::Value> OnlineZipformerTransducerModel::StackStates(
|
||||
for (int32_t n = 0; n != batch_size; ++n) {
|
||||
buf[n] = &states[n][num_encoders + i];
|
||||
}
|
||||
auto v = Cat(allocator_, buf, 1); // (num_layers, 1, encoder_dims)
|
||||
auto v = Cat(allocator, buf, 1); // (num_layers, 1, encoder_dims)
|
||||
ans.push_back(std::move(v));
|
||||
}
|
||||
|
||||
@@ -203,7 +206,7 @@ std::vector<Ort::Value> OnlineZipformerTransducerModel::StackStates(
|
||||
buf[n] = &states[n][num_encoders * 2 + i];
|
||||
}
|
||||
// (num_layers, left_context_len, 1, attention_dims)
|
||||
auto v = Cat(allocator_, buf, 2);
|
||||
auto v = Cat(allocator, buf, 2);
|
||||
ans.push_back(std::move(v));
|
||||
}
|
||||
|
||||
@@ -213,7 +216,7 @@ std::vector<Ort::Value> OnlineZipformerTransducerModel::StackStates(
|
||||
buf[n] = &states[n][num_encoders * 3 + i];
|
||||
}
|
||||
// (num_layers, left_context_len, 1, attention_dims/2)
|
||||
auto v = Cat(allocator_, buf, 2);
|
||||
auto v = Cat(allocator, buf, 2);
|
||||
ans.push_back(std::move(v));
|
||||
}
|
||||
|
||||
@@ -223,7 +226,7 @@ std::vector<Ort::Value> OnlineZipformerTransducerModel::StackStates(
|
||||
buf[n] = &states[n][num_encoders * 4 + i];
|
||||
}
|
||||
// (num_layers, left_context_len, 1, attention_dims/2)
|
||||
auto v = Cat(allocator_, buf, 2);
|
||||
auto v = Cat(allocator, buf, 2);
|
||||
ans.push_back(std::move(v));
|
||||
}
|
||||
|
||||
@@ -233,7 +236,7 @@ std::vector<Ort::Value> OnlineZipformerTransducerModel::StackStates(
|
||||
buf[n] = &states[n][num_encoders * 5 + i];
|
||||
}
|
||||
// (num_layers, 1, encoder_dims, cnn_module_kernels-1)
|
||||
auto v = Cat(allocator_, buf, 1);
|
||||
auto v = Cat(allocator, buf, 1);
|
||||
ans.push_back(std::move(v));
|
||||
}
|
||||
|
||||
@@ -243,7 +246,7 @@ std::vector<Ort::Value> OnlineZipformerTransducerModel::StackStates(
|
||||
buf[n] = &states[n][num_encoders * 6 + i];
|
||||
}
|
||||
// (num_layers, 1, encoder_dims, cnn_module_kernels-1)
|
||||
auto v = Cat(allocator_, buf, 1);
|
||||
auto v = Cat(allocator, buf, 1);
|
||||
ans.push_back(std::move(v));
|
||||
}
|
||||
|
||||
@@ -258,12 +261,15 @@ OnlineZipformerTransducerModel::UnStackStates(
|
||||
int32_t batch_size = states[0].GetTensorTypeAndShapeInfo().GetShape()[1];
|
||||
int32_t num_encoders = num_encoder_layers_.size();
|
||||
|
||||
auto allocator =
|
||||
const_cast<OnlineZipformerTransducerModel *>(this)->allocator_;
|
||||
|
||||
std::vector<std::vector<Ort::Value>> ans;
|
||||
ans.resize(batch_size);
|
||||
|
||||
// cached_len
|
||||
for (int32_t i = 0; i != num_encoders; ++i) {
|
||||
auto v = Unbind<int64_t>(allocator_, &states[i], 1);
|
||||
auto v = Unbind<int64_t>(allocator, &states[i], 1);
|
||||
assert(v.size() == batch_size);
|
||||
|
||||
for (int32_t n = 0; n != batch_size; ++n) {
|
||||
@@ -273,7 +279,7 @@ OnlineZipformerTransducerModel::UnStackStates(
|
||||
|
||||
// cached_avg
|
||||
for (int32_t i = num_encoders; i != 2 * num_encoders; ++i) {
|
||||
auto v = Unbind(allocator_, &states[i], 1);
|
||||
auto v = Unbind(allocator, &states[i], 1);
|
||||
assert(v.size() == batch_size);
|
||||
|
||||
for (int32_t n = 0; n != batch_size; ++n) {
|
||||
@@ -283,7 +289,7 @@ OnlineZipformerTransducerModel::UnStackStates(
|
||||
|
||||
// cached_key
|
||||
for (int32_t i = 2 * num_encoders; i != 3 * num_encoders; ++i) {
|
||||
auto v = Unbind(allocator_, &states[i], 2);
|
||||
auto v = Unbind(allocator, &states[i], 2);
|
||||
assert(v.size() == batch_size);
|
||||
|
||||
for (int32_t n = 0; n != batch_size; ++n) {
|
||||
@@ -293,7 +299,7 @@ OnlineZipformerTransducerModel::UnStackStates(
|
||||
|
||||
// cached_val
|
||||
for (int32_t i = 3 * num_encoders; i != 4 * num_encoders; ++i) {
|
||||
auto v = Unbind(allocator_, &states[i], 2);
|
||||
auto v = Unbind(allocator, &states[i], 2);
|
||||
assert(v.size() == batch_size);
|
||||
|
||||
for (int32_t n = 0; n != batch_size; ++n) {
|
||||
@@ -303,7 +309,7 @@ OnlineZipformerTransducerModel::UnStackStates(
|
||||
|
||||
// cached_val2
|
||||
for (int32_t i = 4 * num_encoders; i != 5 * num_encoders; ++i) {
|
||||
auto v = Unbind(allocator_, &states[i], 2);
|
||||
auto v = Unbind(allocator, &states[i], 2);
|
||||
assert(v.size() == batch_size);
|
||||
|
||||
for (int32_t n = 0; n != batch_size; ++n) {
|
||||
@@ -313,7 +319,7 @@ OnlineZipformerTransducerModel::UnStackStates(
|
||||
|
||||
// cached_conv1
|
||||
for (int32_t i = 5 * num_encoders; i != 6 * num_encoders; ++i) {
|
||||
auto v = Unbind(allocator_, &states[i], 1);
|
||||
auto v = Unbind(allocator, &states[i], 1);
|
||||
assert(v.size() == batch_size);
|
||||
|
||||
for (int32_t n = 0; n != batch_size; ++n) {
|
||||
@@ -323,7 +329,7 @@ OnlineZipformerTransducerModel::UnStackStates(
|
||||
|
||||
// cached_conv2
|
||||
for (int32_t i = 6 * num_encoders; i != 7 * num_encoders; ++i) {
|
||||
auto v = Unbind(allocator_, &states[i], 1);
|
||||
auto v = Unbind(allocator, &states[i], 1);
|
||||
assert(v.size() == batch_size);
|
||||
|
||||
for (int32_t n = 0; n != batch_size; ++n) {
|
||||
|
||||
@@ -70,7 +70,7 @@ class OnlineZipformer2CtcModel::Impl {
|
||||
|
||||
int32_t ChunkShift() const { return decode_chunk_len_; }
|
||||
|
||||
OrtAllocator *Allocator() const { return allocator_; }
|
||||
OrtAllocator *Allocator() { return allocator_; }
|
||||
|
||||
// Return a vector containing 3 tensors
|
||||
// - attn_cache
|
||||
@@ -86,7 +86,7 @@ class OnlineZipformer2CtcModel::Impl {
|
||||
}
|
||||
|
||||
std::vector<Ort::Value> StackStates(
|
||||
std::vector<std::vector<Ort::Value>> states) const {
|
||||
std::vector<std::vector<Ort::Value>> states) {
|
||||
int32_t batch_size = static_cast<int32_t>(states.size());
|
||||
|
||||
std::vector<const Ort::Value *> buf(batch_size);
|
||||
@@ -159,7 +159,7 @@ class OnlineZipformer2CtcModel::Impl {
|
||||
}
|
||||
|
||||
std::vector<std::vector<Ort::Value>> UnStackStates(
|
||||
std::vector<Ort::Value> states) const {
|
||||
std::vector<Ort::Value> states) {
|
||||
int32_t m = std::accumulate(num_encoder_layers_.begin(),
|
||||
num_encoder_layers_.end(), 0);
|
||||
assert(states.size() == m * 6 + 2);
|
||||
|
||||
@@ -185,6 +185,9 @@ std::vector<Ort::Value> OnlineZipformer2TransducerModel::StackStates(
|
||||
|
||||
std::vector<const Ort::Value *> buf(batch_size);
|
||||
|
||||
auto allocator =
|
||||
const_cast<OnlineZipformer2TransducerModel *>(this)->allocator_;
|
||||
|
||||
std::vector<Ort::Value> ans;
|
||||
int32_t num_states = static_cast<int32_t>(states[0].size());
|
||||
ans.reserve(num_states);
|
||||
@@ -194,42 +197,42 @@ std::vector<Ort::Value> OnlineZipformer2TransducerModel::StackStates(
|
||||
for (int32_t n = 0; n != batch_size; ++n) {
|
||||
buf[n] = &states[n][6 * i];
|
||||
}
|
||||
auto v = Cat(allocator_, buf, 1);
|
||||
auto v = Cat(allocator, buf, 1);
|
||||
ans.push_back(std::move(v));
|
||||
}
|
||||
{
|
||||
for (int32_t n = 0; n != batch_size; ++n) {
|
||||
buf[n] = &states[n][6 * i + 1];
|
||||
}
|
||||
auto v = Cat(allocator_, buf, 1);
|
||||
auto v = Cat(allocator, buf, 1);
|
||||
ans.push_back(std::move(v));
|
||||
}
|
||||
{
|
||||
for (int32_t n = 0; n != batch_size; ++n) {
|
||||
buf[n] = &states[n][6 * i + 2];
|
||||
}
|
||||
auto v = Cat(allocator_, buf, 1);
|
||||
auto v = Cat(allocator, buf, 1);
|
||||
ans.push_back(std::move(v));
|
||||
}
|
||||
{
|
||||
for (int32_t n = 0; n != batch_size; ++n) {
|
||||
buf[n] = &states[n][6 * i + 3];
|
||||
}
|
||||
auto v = Cat(allocator_, buf, 1);
|
||||
auto v = Cat(allocator, buf, 1);
|
||||
ans.push_back(std::move(v));
|
||||
}
|
||||
{
|
||||
for (int32_t n = 0; n != batch_size; ++n) {
|
||||
buf[n] = &states[n][6 * i + 4];
|
||||
}
|
||||
auto v = Cat(allocator_, buf, 0);
|
||||
auto v = Cat(allocator, buf, 0);
|
||||
ans.push_back(std::move(v));
|
||||
}
|
||||
{
|
||||
for (int32_t n = 0; n != batch_size; ++n) {
|
||||
buf[n] = &states[n][6 * i + 5];
|
||||
}
|
||||
auto v = Cat(allocator_, buf, 0);
|
||||
auto v = Cat(allocator, buf, 0);
|
||||
ans.push_back(std::move(v));
|
||||
}
|
||||
}
|
||||
@@ -238,7 +241,7 @@ std::vector<Ort::Value> OnlineZipformer2TransducerModel::StackStates(
|
||||
for (int32_t n = 0; n != batch_size; ++n) {
|
||||
buf[n] = &states[n][num_states - 2];
|
||||
}
|
||||
auto v = Cat(allocator_, buf, 0);
|
||||
auto v = Cat(allocator, buf, 0);
|
||||
ans.push_back(std::move(v));
|
||||
}
|
||||
|
||||
@@ -246,7 +249,7 @@ std::vector<Ort::Value> OnlineZipformer2TransducerModel::StackStates(
|
||||
for (int32_t n = 0; n != batch_size; ++n) {
|
||||
buf[n] = &states[n][num_states - 1];
|
||||
}
|
||||
auto v = Cat<int64_t>(allocator_, buf, 0);
|
||||
auto v = Cat<int64_t>(allocator, buf, 0);
|
||||
ans.push_back(std::move(v));
|
||||
}
|
||||
return ans;
|
||||
@@ -261,12 +264,15 @@ OnlineZipformer2TransducerModel::UnStackStates(
|
||||
|
||||
int32_t batch_size = states[0].GetTensorTypeAndShapeInfo().GetShape()[1];
|
||||
|
||||
auto allocator =
|
||||
const_cast<OnlineZipformer2TransducerModel *>(this)->allocator_;
|
||||
|
||||
std::vector<std::vector<Ort::Value>> ans;
|
||||
ans.resize(batch_size);
|
||||
|
||||
for (int32_t i = 0; i != m; ++i) {
|
||||
{
|
||||
auto v = Unbind(allocator_, &states[i * 6], 1);
|
||||
auto v = Unbind(allocator, &states[i * 6], 1);
|
||||
assert(static_cast<int32_t>(v.size()) == batch_size);
|
||||
|
||||
for (int32_t n = 0; n != batch_size; ++n) {
|
||||
@@ -274,7 +280,7 @@ OnlineZipformer2TransducerModel::UnStackStates(
|
||||
}
|
||||
}
|
||||
{
|
||||
auto v = Unbind(allocator_, &states[i * 6 + 1], 1);
|
||||
auto v = Unbind(allocator, &states[i * 6 + 1], 1);
|
||||
assert(static_cast<int32_t>(v.size()) == batch_size);
|
||||
|
||||
for (int32_t n = 0; n != batch_size; ++n) {
|
||||
@@ -282,7 +288,7 @@ OnlineZipformer2TransducerModel::UnStackStates(
|
||||
}
|
||||
}
|
||||
{
|
||||
auto v = Unbind(allocator_, &states[i * 6 + 2], 1);
|
||||
auto v = Unbind(allocator, &states[i * 6 + 2], 1);
|
||||
assert(static_cast<int32_t>(v.size()) == batch_size);
|
||||
|
||||
for (int32_t n = 0; n != batch_size; ++n) {
|
||||
@@ -290,7 +296,7 @@ OnlineZipformer2TransducerModel::UnStackStates(
|
||||
}
|
||||
}
|
||||
{
|
||||
auto v = Unbind(allocator_, &states[i * 6 + 3], 1);
|
||||
auto v = Unbind(allocator, &states[i * 6 + 3], 1);
|
||||
assert(static_cast<int32_t>(v.size()) == batch_size);
|
||||
|
||||
for (int32_t n = 0; n != batch_size; ++n) {
|
||||
@@ -298,7 +304,7 @@ OnlineZipformer2TransducerModel::UnStackStates(
|
||||
}
|
||||
}
|
||||
{
|
||||
auto v = Unbind(allocator_, &states[i * 6 + 4], 0);
|
||||
auto v = Unbind(allocator, &states[i * 6 + 4], 0);
|
||||
assert(static_cast<int32_t>(v.size()) == batch_size);
|
||||
|
||||
for (int32_t n = 0; n != batch_size; ++n) {
|
||||
@@ -306,7 +312,7 @@ OnlineZipformer2TransducerModel::UnStackStates(
|
||||
}
|
||||
}
|
||||
{
|
||||
auto v = Unbind(allocator_, &states[i * 6 + 5], 0);
|
||||
auto v = Unbind(allocator, &states[i * 6 + 5], 0);
|
||||
assert(static_cast<int32_t>(v.size()) == batch_size);
|
||||
|
||||
for (int32_t n = 0; n != batch_size; ++n) {
|
||||
@@ -316,7 +322,7 @@ OnlineZipformer2TransducerModel::UnStackStates(
|
||||
}
|
||||
|
||||
{
|
||||
auto v = Unbind(allocator_, &states[m * 6], 0);
|
||||
auto v = Unbind(allocator, &states[m * 6], 0);
|
||||
assert(static_cast<int32_t>(v.size()) == batch_size);
|
||||
|
||||
for (int32_t n = 0; n != batch_size; ++n) {
|
||||
@@ -324,7 +330,7 @@ OnlineZipformer2TransducerModel::UnStackStates(
|
||||
}
|
||||
}
|
||||
{
|
||||
auto v = Unbind<int64_t>(allocator_, &states[m * 6 + 1], 0);
|
||||
auto v = Unbind<int64_t>(allocator, &states[m * 6 + 1], 0);
|
||||
assert(static_cast<int32_t>(v.size()) == batch_size);
|
||||
|
||||
for (int32_t n = 0; n != batch_size; ++n) {
|
||||
|
||||
@@ -21,6 +21,36 @@
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
static std::string GetInputName(Ort::Session *sess, size_t index,
|
||||
OrtAllocator *allocator) {
|
||||
// Note(fangjun): We only tested 1.17.1 and 1.11.0
|
||||
// For other versions, we may need to change it
|
||||
#if ORT_API_VERSION >= 17
|
||||
auto v = sess->GetInputNameAllocated(index, allocator);
|
||||
return v.get();
|
||||
#else
|
||||
auto v = sess->GetInputName(index, allocator);
|
||||
std::string ans = v;
|
||||
allocator->Free(allocator, v);
|
||||
return ans;
|
||||
#endif
|
||||
}
|
||||
|
||||
static std::string GetOutputName(Ort::Session *sess, size_t index,
|
||||
OrtAllocator *allocator) {
|
||||
// Note(fangjun): We only tested 1.17.1 and 1.11.0
|
||||
// For other versions, we may need to change it
|
||||
#if ORT_API_VERSION >= 17
|
||||
auto v = sess->GetOutputNameAllocated(index, allocator);
|
||||
return v.get();
|
||||
#else
|
||||
auto v = sess->GetOutputName(index, allocator);
|
||||
std::string ans = v;
|
||||
allocator->Free(allocator, v);
|
||||
return ans;
|
||||
#endif
|
||||
}
|
||||
|
||||
void GetInputNames(Ort::Session *sess, std::vector<std::string> *input_names,
|
||||
std::vector<const char *> *input_names_ptr) {
|
||||
Ort::AllocatorWithDefaultOptions allocator;
|
||||
@@ -28,8 +58,7 @@ void GetInputNames(Ort::Session *sess, std::vector<std::string> *input_names,
|
||||
input_names->resize(node_count);
|
||||
input_names_ptr->resize(node_count);
|
||||
for (size_t i = 0; i != node_count; ++i) {
|
||||
auto tmp = sess->GetInputNameAllocated(i, allocator);
|
||||
(*input_names)[i] = tmp.get();
|
||||
(*input_names)[i] = GetInputName(sess, i, allocator);
|
||||
(*input_names_ptr)[i] = (*input_names)[i].c_str();
|
||||
}
|
||||
}
|
||||
@@ -41,8 +70,7 @@ void GetOutputNames(Ort::Session *sess, std::vector<std::string> *output_names,
|
||||
output_names->resize(node_count);
|
||||
output_names_ptr->resize(node_count);
|
||||
for (size_t i = 0; i != node_count; ++i) {
|
||||
auto tmp = sess->GetOutputNameAllocated(i, allocator);
|
||||
(*output_names)[i] = tmp.get();
|
||||
(*output_names)[i] = GetOutputName(sess, i, allocator);
|
||||
(*output_names_ptr)[i] = (*output_names)[i].c_str();
|
||||
}
|
||||
}
|
||||
@@ -78,12 +106,24 @@ Ort::Value GetEncoderOutFrame(OrtAllocator *allocator, Ort::Value *encoder_out,
|
||||
|
||||
void PrintModelMetadata(std::ostream &os, const Ort::ModelMetadata &meta_data) {
|
||||
Ort::AllocatorWithDefaultOptions allocator;
|
||||
#if ORT_API_VERSION >= 17
|
||||
std::vector<Ort::AllocatedStringPtr> v =
|
||||
meta_data.GetCustomMetadataMapKeysAllocated(allocator);
|
||||
for (const auto &key : v) {
|
||||
auto p = meta_data.LookupCustomMetadataMapAllocated(key.get(), allocator);
|
||||
os << key.get() << "=" << p.get() << "\n";
|
||||
}
|
||||
#else
|
||||
int64_t num_keys = 0;
|
||||
char **keys = meta_data.GetCustomMetadataMapKeys(allocator, num_keys);
|
||||
for (int32_t i = 0; i < num_keys; ++i) {
|
||||
auto v = LookupCustomModelMetaData(meta_data, keys[i], allocator);
|
||||
os << keys[i] << "=" << v << "\n";
|
||||
allocator.Free(keys[i]);
|
||||
}
|
||||
|
||||
allocator.Free(keys);
|
||||
#endif
|
||||
}
|
||||
|
||||
Ort::Value Clone(OrtAllocator *allocator, const Ort::Value *v) {
|
||||
@@ -361,4 +401,20 @@ std::vector<Ort::Value> Convert(std::vector<CopyableOrtValue> values) {
|
||||
return ans;
|
||||
}
|
||||
|
||||
std::string LookupCustomModelMetaData(const Ort::ModelMetadata &meta_data,
|
||||
const char *key,
|
||||
OrtAllocator *allocator) {
|
||||
// Note(fangjun): We only tested 1.17.1 and 1.11.0
|
||||
// For other versions, we may need to change it
|
||||
#if ORT_API_VERSION >= 17
|
||||
auto v = meta_data.LookupCustomMetadataMapAllocated(key, allocator);
|
||||
return v.get();
|
||||
#else
|
||||
auto v = meta_data.LookupCustomMetadataMap(key, allocator);
|
||||
std::string ans = v;
|
||||
allocator->Free(allocator, v);
|
||||
return ans;
|
||||
#endif
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
@@ -59,6 +59,9 @@ void GetOutputNames(Ort::Session *sess, std::vector<std::string> *output_names,
|
||||
Ort::Value GetEncoderOutFrame(OrtAllocator *allocator, Ort::Value *encoder_out,
|
||||
int32_t t);
|
||||
|
||||
std::string LookupCustomModelMetaData(const Ort::ModelMetadata &meta_data,
|
||||
const char *key, OrtAllocator *allocator);
|
||||
|
||||
void PrintModelMetadata(std::ostream &os,
|
||||
const Ort::ModelMetadata &meta_data); // NOLINT
|
||||
|
||||
|
||||
@@ -60,6 +60,7 @@ Ort::SessionOptions GetSessionOptionsImpl(
|
||||
case Provider::kCPU:
|
||||
break; // nothing to do for the CPU provider
|
||||
case Provider::kXnnpack: {
|
||||
#if ORT_API_VERSION >= 17
|
||||
if (std::find(available_providers.begin(), available_providers.end(),
|
||||
"XnnpackExecutionProvider") != available_providers.end()) {
|
||||
sess_opts.AppendExecutionProvider("XNNPACK");
|
||||
@@ -67,6 +68,11 @@ Ort::SessionOptions GetSessionOptionsImpl(
|
||||
SHERPA_ONNX_LOGE("Available providers: %s. Fallback to cpu!",
|
||||
os.str().c_str());
|
||||
}
|
||||
#else
|
||||
SHERPA_ONNX_LOGE(
|
||||
"Does not support xnnpack for onnxruntime: %d. Fallback to cpu!",
|
||||
static_cast<int32_t>(ORT_API_VERSION));
|
||||
#endif
|
||||
break;
|
||||
}
|
||||
case Provider::kTRT: {
|
||||
|
||||
@@ -40,8 +40,8 @@ static ModelType GetModelType(char *model_data, size_t model_data_length,
|
||||
|
||||
Ort::AllocatorWithDefaultOptions allocator;
|
||||
auto model_type =
|
||||
meta_data.LookupCustomMetadataMapAllocated("framework", allocator);
|
||||
if (!model_type) {
|
||||
LookupCustomModelMetaData(meta_data, "framework", allocator);
|
||||
if (model_type.empty()) {
|
||||
SHERPA_ONNX_LOGE(
|
||||
"No model_type in the metadata!\n"
|
||||
"Please make sure you have added metadata to the model.\n\n"
|
||||
@@ -52,14 +52,14 @@ static ModelType GetModelType(char *model_data, size_t model_data_length,
|
||||
return ModelType::kUnknown;
|
||||
}
|
||||
|
||||
if (model_type.get() == std::string("wespeaker")) {
|
||||
if (model_type == "wespeaker") {
|
||||
return ModelType::kWeSpeaker;
|
||||
} else if (model_type.get() == std::string("3d-speaker")) {
|
||||
} else if (model_type == "3d-speaker") {
|
||||
return ModelType::k3dSpeaker;
|
||||
} else if (model_type.get() == std::string("nemo")) {
|
||||
} else if (model_type == "nemo") {
|
||||
return ModelType::kNeMo;
|
||||
} else {
|
||||
SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.get());
|
||||
SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.c_str());
|
||||
return ModelType::kUnknown;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -53,7 +53,7 @@ class SpeakerEmbeddingExtractorNeMoModel::Impl {
|
||||
return std::move(outputs[0]);
|
||||
}
|
||||
|
||||
OrtAllocator *Allocator() const { return allocator_; }
|
||||
OrtAllocator *Allocator() { return allocator_; }
|
||||
|
||||
const SpeakerEmbeddingExtractorNeMoModelMetaData &GetMetaData() const {
|
||||
return meta_data_;
|
||||
|
||||
@@ -42,8 +42,8 @@ static ModelType GetModelType(char *model_data, size_t model_data_length,
|
||||
|
||||
Ort::AllocatorWithDefaultOptions allocator;
|
||||
auto model_type =
|
||||
meta_data.LookupCustomMetadataMapAllocated("model_type", allocator);
|
||||
if (!model_type) {
|
||||
LookupCustomModelMetaData(meta_data, "model_type", allocator);
|
||||
if (model_type.empty()) {
|
||||
SHERPA_ONNX_LOGE(
|
||||
"No model_type in the metadata!\n"
|
||||
"Please make sure you have added metadata to the model.\n\n"
|
||||
@@ -54,11 +54,10 @@ static ModelType GetModelType(char *model_data, size_t model_data_length,
|
||||
return ModelType::kUnknown;
|
||||
}
|
||||
|
||||
auto model_type_str = std::string(model_type.get());
|
||||
if (model_type_str.find("whisper") == 0) {
|
||||
if (model_type.find("whisper") == 0) {
|
||||
return ModelType::kWhisper;
|
||||
} else {
|
||||
SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.get());
|
||||
SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.c_str());
|
||||
return ModelType::kUnknown;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -29,20 +29,19 @@ namespace {
|
||||
const char *ws = " \t\n\r\f\v";
|
||||
|
||||
// trim from end of string (right)
|
||||
inline std::string &TrimRight(std::string &s, const char *t = ws) {
|
||||
s.erase(s.find_last_not_of(t) + 1);
|
||||
return s;
|
||||
inline void TrimRight(std::string *s, const char *t = ws) {
|
||||
s->erase(s->find_last_not_of(t) + 1);
|
||||
}
|
||||
|
||||
// trim from beginning of string (left)
|
||||
inline std::string &TrimLeft(std::string &s, const char *t = ws) {
|
||||
s.erase(0, s.find_first_not_of(t));
|
||||
return s;
|
||||
inline void TrimLeft(std::string *s, const char *t = ws) {
|
||||
s->erase(0, s->find_first_not_of(t));
|
||||
}
|
||||
|
||||
// trim from both ends of string (right then left)
|
||||
inline std::string &Trim(std::string &s, const char *t = ws) {
|
||||
return TrimLeft(TrimRight(s, t), t);
|
||||
inline void Trim(std::string *s, const char *t = ws) {
|
||||
TrimRight(s, t);
|
||||
TrimLeft(s, t);
|
||||
}
|
||||
} // namespace
|
||||
|
||||
@@ -56,7 +55,7 @@ std::unordered_map<std::string, int32_t> ReadTokens(
|
||||
std::string sym;
|
||||
int32_t id = -1;
|
||||
while (std::getline(is, line)) {
|
||||
Trim(line);
|
||||
Trim(&line);
|
||||
std::istringstream iss(line);
|
||||
iss >> sym;
|
||||
if (iss.eof()) {
|
||||
|
||||
Reference in New Issue
Block a user