Support building GPU-capable sherpa-onnx on Linux aarch64. (#1500)
Thanks to @Peakyxh for providing pre-built onnxruntime libraries with CUDA support for Linux aarch64. Tested on a Jetson Nano B01.
This commit is contained in:
77
.github/workflows/aarch64-linux-gnu-shared.yaml
vendored
77
.github/workflows/aarch64-linux-gnu-shared.yaml
vendored
@@ -34,11 +34,12 @@ concurrency:
|
|||||||
jobs:
|
jobs:
|
||||||
aarch64_linux_gnu_shared:
|
aarch64_linux_gnu_shared:
|
||||||
runs-on: ${{ matrix.os }}
|
runs-on: ${{ matrix.os }}
|
||||||
name: aarch64 shared lib test
|
name: aarch64 shared GPU ${{ matrix.gpu }}
|
||||||
strategy:
|
strategy:
|
||||||
fail-fast: false
|
fail-fast: false
|
||||||
matrix:
|
matrix:
|
||||||
os: [ubuntu-latest]
|
os: [ubuntu-latest]
|
||||||
|
gpu: [ON, OFF]
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v4
|
||||||
@@ -79,15 +80,24 @@ jobs:
|
|||||||
make -j2
|
make -j2
|
||||||
make install
|
make install
|
||||||
|
|
||||||
- name: cache-toolchain
|
- name: cache-toolchain (CPU)
|
||||||
id: cache-toolchain
|
if: matrix.gpu == 'OFF'
|
||||||
|
id: cache-toolchain-cpu
|
||||||
uses: actions/cache@v4
|
uses: actions/cache@v4
|
||||||
with:
|
with:
|
||||||
path: toolchain
|
path: toolchain
|
||||||
key: gcc-linaro-7.5.0-2019.12-x86_64_aarch64-linux-gnu.tar.xz
|
key: gcc-linaro-7.5.0-2019.12-x86_64_aarch64-linux-gnu.tar.xz
|
||||||
|
|
||||||
- name: Download toolchain
|
- name: cache-toolchain (GPU)
|
||||||
if: steps.cache-toolchain.outputs.cache-hit != 'true'
|
if: matrix.gpu == 'ON'
|
||||||
|
id: cache-toolchain-gpu
|
||||||
|
uses: actions/cache@v4
|
||||||
|
with:
|
||||||
|
path: toolchain
|
||||||
|
key: gcc-arm-10.3-2021.07-x86_64-aarch64-none-linux-gnu.tar.xz
|
||||||
|
|
||||||
|
- name: Download toolchain (CPU, gcc 7.5)
|
||||||
|
if: steps.cache-toolchain-cpu.outputs.cache-hit != 'true' && matrix.gpu == 'OFF'
|
||||||
shell: bash
|
shell: bash
|
||||||
run: |
|
run: |
|
||||||
wget -qq https://huggingface.co/csukuangfj/sherpa-ncnn-toolchains/resolve/main/gcc-linaro-7.5.0-2019.12-x86_64_aarch64-linux-gnu.tar.xz
|
wget -qq https://huggingface.co/csukuangfj/sherpa-ncnn-toolchains/resolve/main/gcc-linaro-7.5.0-2019.12-x86_64_aarch64-linux-gnu.tar.xz
|
||||||
@@ -95,6 +105,15 @@ jobs:
|
|||||||
mkdir $GITHUB_WORKSPACE/toolchain
|
mkdir $GITHUB_WORKSPACE/toolchain
|
||||||
tar xf ./gcc-linaro-7.5.0-2019.12-x86_64_aarch64-linux-gnu.tar.xz --strip-components 1 -C $GITHUB_WORKSPACE/toolchain
|
tar xf ./gcc-linaro-7.5.0-2019.12-x86_64_aarch64-linux-gnu.tar.xz --strip-components 1 -C $GITHUB_WORKSPACE/toolchain
|
||||||
|
|
||||||
|
- name: Download toolchain (GPU, gcc 10.3)
|
||||||
|
if: steps.cache-toolchain-gpu.outputs.cache-hit != 'true' && matrix.gpu == 'ON'
|
||||||
|
shell: bash
|
||||||
|
run: |
|
||||||
|
wget -qq https://huggingface.co/csukuangfj/sherpa-ncnn-toolchains/resolve/main/gcc-arm-10.3-2021.07-x86_64-aarch64-none-linux-gnu.tar.xz
|
||||||
|
|
||||||
|
mkdir $GITHUB_WORKSPACE/toolchain
|
||||||
|
tar xf ./gcc-arm-10.3-2021.07-x86_64-aarch64-none-linux-gnu.tar.xz --strip-components 1 -C $GITHUB_WORKSPACE/toolchain
|
||||||
|
|
||||||
- name: Set environment variable
|
- name: Set environment variable
|
||||||
if: steps.cache-build-result.outputs.cache-hit != 'true'
|
if: steps.cache-build-result.outputs.cache-hit != 'true'
|
||||||
shell: bash
|
shell: bash
|
||||||
@@ -103,19 +122,31 @@ jobs:
|
|||||||
echo "$GITHUB_WORKSPACE/bin" >> "$GITHUB_PATH"
|
echo "$GITHUB_WORKSPACE/bin" >> "$GITHUB_PATH"
|
||||||
ls -lh "$GITHUB_WORKSPACE/toolchain/bin"
|
ls -lh "$GITHUB_WORKSPACE/toolchain/bin"
|
||||||
|
|
||||||
echo "CC=aarch64-linux-gnu-gcc" >> "$GITHUB_ENV"
|
if [[ ${{ matrix.gpu }} == OFF ]]; then
|
||||||
echo "CXX=aarch64-linux-gnu-g++" >> "$GITHUB_ENV"
|
echo "CC=aarch64-linux-gnu-gcc" >> "$GITHUB_ENV"
|
||||||
|
echo "CXX=aarch64-linux-gnu-g++" >> "$GITHUB_ENV"
|
||||||
|
else
|
||||||
|
echo "CC=aarch64-none-linux-gnu-gcc" >> "$GITHUB_ENV"
|
||||||
|
echo "CXX=aarch64-none-linux-gnu-g++" >> "$GITHUB_ENV"
|
||||||
|
fi
|
||||||
|
|
||||||
- name: Display toolchain info
|
- name: Display toolchain info
|
||||||
shell: bash
|
shell: bash
|
||||||
run: |
|
run: |
|
||||||
aarch64-linux-gnu-gcc --version
|
if [[ ${{ matrix.gpu }} == OFF ]]; then
|
||||||
|
which aarch64-linux-gnu-gcc
|
||||||
|
aarch64-linux-gnu-gcc --version
|
||||||
|
else
|
||||||
|
which aarch64-none-linux-gnu-gcc
|
||||||
|
aarch64-none-linux-gnu-gcc --version
|
||||||
|
fi
|
||||||
|
|
||||||
- name: Display qemu-aarch64 -h
|
- name: Display qemu-aarch64 -h
|
||||||
shell: bash
|
shell: bash
|
||||||
run: |
|
run: |
|
||||||
export PATH=$GITHUB_WORKSPACE/qemu-install/bin:$PATH
|
export PATH=$GITHUB_WORKSPACE/qemu-install/bin:$PATH
|
||||||
export QEMU_LD_PREFIX=$GITHUB_WORKSPACE/toolchain/aarch64-linux-gnu/libc
|
export QEMU_LD_PREFIX=$GITHUB_WORKSPACE/toolchain/aarch64-linux-gnu/libc
|
||||||
|
if [[ ${{ matrix.gpu }} == ON ]]; then export QEMU_LD_PREFIX=$GITHUB_WORKSPACE/toolchain/aarch64-none-linux-gnu/libc; fi  # GPU toolchain sysroot; keep the CPU prefix otherwise
|
||||||
qemu-aarch64 -h
|
qemu-aarch64 -h
|
||||||
|
|
||||||
- name: build aarch64-linux-gnu
|
- name: build aarch64-linux-gnu
|
||||||
@@ -127,6 +158,7 @@ jobs:
|
|||||||
cmake --version
|
cmake --version
|
||||||
|
|
||||||
export BUILD_SHARED_LIBS=ON
|
export BUILD_SHARED_LIBS=ON
|
||||||
|
export SHERPA_ONNX_ENABLE_GPU=${{ matrix.gpu }}
|
||||||
|
|
||||||
./build-aarch64-linux-gnu.sh
|
./build-aarch64-linux-gnu.sh
|
||||||
|
|
||||||
@@ -140,7 +172,11 @@ jobs:
|
|||||||
run: |
|
run: |
|
||||||
export PATH=$GITHUB_WORKSPACE/toolchain/bin:$PATH
|
export PATH=$GITHUB_WORKSPACE/toolchain/bin:$PATH
|
||||||
export PATH=$GITHUB_WORKSPACE/qemu-install/bin:$PATH
|
export PATH=$GITHUB_WORKSPACE/qemu-install/bin:$PATH
|
||||||
export QEMU_LD_PREFIX=$GITHUB_WORKSPACE/toolchain/aarch64-linux-gnu/libc
|
if [[ ${{ matrix.gpu }} == OFF ]]; then
|
||||||
|
export QEMU_LD_PREFIX=$GITHUB_WORKSPACE/toolchain/aarch64-linux-gnu/libc
|
||||||
|
else
|
||||||
|
export QEMU_LD_PREFIX=$GITHUB_WORKSPACE/toolchain/aarch64-none-linux-gnu/libc
|
||||||
|
fi
|
||||||
|
|
||||||
ls -lh ./build-aarch64-linux-gnu/bin
|
ls -lh ./build-aarch64-linux-gnu/bin
|
||||||
|
|
||||||
@@ -151,11 +187,20 @@ jobs:
|
|||||||
- name: Copy files
|
- name: Copy files
|
||||||
shell: bash
|
shell: bash
|
||||||
run: |
|
run: |
|
||||||
aarch64-linux-gnu-strip --version
|
if [[ ${{ matrix.gpu }} == OFF ]]; then
|
||||||
|
aarch64-linux-gnu-strip --version
|
||||||
|
else
|
||||||
|
aarch64-none-linux-gnu-strip --version
|
||||||
|
fi
|
||||||
|
|
||||||
SHERPA_ONNX_VERSION=v$(grep "SHERPA_ONNX_VERSION" ./CMakeLists.txt | cut -d " " -f 2 | cut -d '"' -f 2)
|
SHERPA_ONNX_VERSION=v$(grep "SHERPA_ONNX_VERSION" ./CMakeLists.txt | cut -d " " -f 2 | cut -d '"' -f 2)
|
||||||
|
|
||||||
dst=sherpa-onnx-${SHERPA_ONNX_VERSION}-linux-aarch64-shared
|
dst=sherpa-onnx-${SHERPA_ONNX_VERSION}-linux-aarch64-shared
|
||||||
|
if [[ ${{ matrix.gpu }} == OFF ]]; then
|
||||||
|
dst=${dst}-cpu
|
||||||
|
else
|
||||||
|
dst=${dst}-gpu
|
||||||
|
fi
|
||||||
mkdir $dst
|
mkdir $dst
|
||||||
|
|
||||||
cp -a build-aarch64-linux-gnu/install/bin $dst/
|
cp -a build-aarch64-linux-gnu/install/bin $dst/
|
||||||
@@ -166,7 +211,11 @@ jobs:
|
|||||||
|
|
||||||
ls -lh $dst/bin/
|
ls -lh $dst/bin/
|
||||||
echo "strip"
|
echo "strip"
|
||||||
aarch64-linux-gnu-strip $dst/bin/*
|
if [[ ${{ matrix.gpu }} == OFF ]]; then
|
||||||
|
aarch64-linux-gnu-strip $dst/bin/*
|
||||||
|
else
|
||||||
|
aarch64-none-linux-gnu-strip $dst/bin/*
|
||||||
|
fi
|
||||||
|
|
||||||
tree $dst
|
tree $dst
|
||||||
|
|
||||||
@@ -174,8 +223,8 @@ jobs:
|
|||||||
|
|
||||||
- uses: actions/upload-artifact@v4
|
- uses: actions/upload-artifact@v4
|
||||||
with:
|
with:
|
||||||
name: sherpa-onnx-linux-aarch64-shared
|
name: sherpa-onnx-linux-aarch64-shared-gpu-${{ matrix.gpu }}
|
||||||
path: sherpa-onnx-*linux-aarch64-shared.tar.bz2
|
path: sherpa-onnx-*linux-aarch64-shared*.tar.bz2
|
||||||
|
|
||||||
# https://huggingface.co/docs/hub/spaces-github-actions
|
# https://huggingface.co/docs/hub/spaces-github-actions
|
||||||
- name: Publish to huggingface
|
- name: Publish to huggingface
|
||||||
@@ -198,7 +247,7 @@ jobs:
|
|||||||
cd huggingface
|
cd huggingface
|
||||||
mkdir -p aarch64
|
mkdir -p aarch64
|
||||||
|
|
||||||
cp -v ../sherpa-onnx-*-shared.tar.bz2 ./aarch64
|
cp -v ../sherpa-onnx-*-shared*.tar.bz2 ./aarch64
|
||||||
|
|
||||||
git status
|
git status
|
||||||
git lfs track "*.bz2"
|
git lfs track "*.bz2"
|
||||||
|
|||||||
@@ -44,6 +44,21 @@ if [[ x"$BUILD_SHARED_LIBS" == x"" ]]; then
|
|||||||
BUILD_SHARED_LIBS=OFF
|
BUILD_SHARED_LIBS=OFF
|
||||||
fi
|
fi
|
||||||
|
|
||||||
|
if [[ x"$SHERPA_ONNX_ENABLE_GPU" == x"" ]]; then
|
||||||
|
# By default, use CPU
|
||||||
|
SHERPA_ONNX_ENABLE_GPU=OFF
|
||||||
|
|
||||||
|
# If you use GPU, then please make sure you have NVIDIA GPUs on your board.
|
||||||
|
# It uses onnxruntime 1.11.0.
|
||||||
|
#
|
||||||
|
# Tested on Jetson Nano B01
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [[ x"$SHERPA_ONNX_ENABLE_GPU" == x"ON" ]]; then
|
||||||
|
# Build shared libs if building GPU is enabled.
|
||||||
|
BUILD_SHARED_LIBS=ON
|
||||||
|
fi
|
||||||
|
|
||||||
cmake \
|
cmake \
|
||||||
-DBUILD_PIPER_PHONMIZE_EXE=OFF \
|
-DBUILD_PIPER_PHONMIZE_EXE=OFF \
|
||||||
-DBUILD_PIPER_PHONMIZE_TESTS=OFF \
|
-DBUILD_PIPER_PHONMIZE_TESTS=OFF \
|
||||||
@@ -51,6 +66,7 @@ cmake \
|
|||||||
-DBUILD_ESPEAK_NG_TESTS=OFF \
|
-DBUILD_ESPEAK_NG_TESTS=OFF \
|
||||||
-DCMAKE_INSTALL_PREFIX=./install \
|
-DCMAKE_INSTALL_PREFIX=./install \
|
||||||
-DCMAKE_BUILD_TYPE=Release \
|
-DCMAKE_BUILD_TYPE=Release \
|
||||||
|
-DSHERPA_ONNX_ENABLE_GPU=$SHERPA_ONNX_ENABLE_GPU \
|
||||||
-DBUILD_SHARED_LIBS=$BUILD_SHARED_LIBS \
|
-DBUILD_SHARED_LIBS=$BUILD_SHARED_LIBS \
|
||||||
-DSHERPA_ONNX_ENABLE_TESTS=OFF \
|
-DSHERPA_ONNX_ENABLE_TESTS=OFF \
|
||||||
-DSHERPA_ONNX_ENABLE_PYTHON=OFF \
|
-DSHERPA_ONNX_ENABLE_PYTHON=OFF \
|
||||||
|
|||||||
101
cmake/onnxruntime-linux-aarch64-gpu.cmake
Normal file
101
cmake/onnxruntime-linux-aarch64-gpu.cmake
Normal file
@@ -0,0 +1,101 @@
|
|||||||
|
# Copyright (c) 2022-2024 Xiaomi Corporation
|
||||||
|
message(STATUS "CMAKE_SYSTEM_NAME: ${CMAKE_SYSTEM_NAME}")
|
||||||
|
message(STATUS "CMAKE_SYSTEM_PROCESSOR: ${CMAKE_SYSTEM_PROCESSOR}")
|
||||||
|
|
||||||
|
if(NOT CMAKE_SYSTEM_NAME STREQUAL Linux)
|
||||||
|
message(FATAL_ERROR "This file is for Linux only. Given: ${CMAKE_SYSTEM_NAME}")
|
||||||
|
endif()
|
||||||
|
|
||||||
|
if(NOT CMAKE_SYSTEM_PROCESSOR STREQUAL aarch64)
|
||||||
|
message(FATAL_ERROR "This file is for aarch64 only. Given: ${CMAKE_SYSTEM_PROCESSOR}")
|
||||||
|
endif()
|
||||||
|
|
||||||
|
if(NOT BUILD_SHARED_LIBS)
|
||||||
|
message(FATAL_ERROR "This file is for building shared libraries. BUILD_SHARED_LIBS: ${BUILD_SHARED_LIBS}")
|
||||||
|
endif()
|
||||||
|
|
||||||
|
if(NOT SHERPA_ONNX_ENABLE_GPU)
|
||||||
|
message(FATAL_ERROR "This file is for NVIDIA GPU only. Given SHERPA_ONNX_ENABLE_GPU: ${SHERPA_ONNX_ENABLE_GPU}")
|
||||||
|
endif()
|
||||||
|
|
||||||
|
set(onnxruntime_URL "https://github.com/csukuangfj/onnxruntime-libs/releases/download/v1.11.0/onnxruntime-linux-aarch64-gpu-1.11.0.tar.bz2")
|
||||||
|
set(onnxruntime_URL2 "https://hf-mirror.com/csukuangfj/onnxruntime-libs/resolve/main/onnxruntime-linux-aarch64-gpu-1.11.0.tar.bz2")
|
||||||
|
set(onnxruntime_HASH "SHA256=36eded935551e23aead09d4173bdf0bd1e7b01fdec15d77f97d6e34029aa60d7")
|
||||||
|
|
||||||
|
# If you don't have access to the Internet,
|
||||||
|
# please download onnxruntime to one of the following locations.
|
||||||
|
# You can add more if you want.
|
||||||
|
set(possible_file_locations
|
||||||
|
$ENV{HOME}/Downloads/onnxruntime-linux-aarch64-gpu-1.11.0.tar.bz2
|
||||||
|
${CMAKE_SOURCE_DIR}/onnxruntime-linux-aarch64-gpu-1.11.0.tar.bz2
|
||||||
|
${CMAKE_BINARY_DIR}/onnxruntime-linux-aarch64-gpu-1.11.0.tar.bz2
|
||||||
|
/tmp/onnxruntime-linux-aarch64-gpu-1.11.0.tar.bz2
|
||||||
|
/star-fj/fangjun/download/github/onnxruntime-linux-aarch64-gpu-1.11.0.tar.bz2
|
||||||
|
)
|
||||||
|
|
||||||
|
foreach(f IN LISTS possible_file_locations)
|
||||||
|
if(EXISTS ${f})
|
||||||
|
set(onnxruntime_URL "${f}")
|
||||||
|
file(TO_CMAKE_PATH "${onnxruntime_URL}" onnxruntime_URL)
|
||||||
|
message(STATUS "Found local downloaded onnxruntime: ${onnxruntime_URL}")
|
||||||
|
set(onnxruntime_URL2)
|
||||||
|
break()
|
||||||
|
endif()
|
||||||
|
endforeach()
|
||||||
|
|
||||||
|
FetchContent_Declare(onnxruntime
|
||||||
|
URL
|
||||||
|
${onnxruntime_URL}
|
||||||
|
${onnxruntime_URL2}
|
||||||
|
URL_HASH ${onnxruntime_HASH}
|
||||||
|
)
|
||||||
|
|
||||||
|
FetchContent_GetProperties(onnxruntime)
|
||||||
|
if(NOT onnxruntime_POPULATED)
|
||||||
|
message(STATUS "Downloading onnxruntime from ${onnxruntime_URL}")
|
||||||
|
FetchContent_Populate(onnxruntime)
|
||||||
|
endif()
|
||||||
|
message(STATUS "onnxruntime is downloaded to ${onnxruntime_SOURCE_DIR}")
|
||||||
|
|
||||||
|
find_library(location_onnxruntime onnxruntime
|
||||||
|
PATHS
|
||||||
|
"${onnxruntime_SOURCE_DIR}/lib"
|
||||||
|
NO_CMAKE_SYSTEM_PATH
|
||||||
|
)
|
||||||
|
|
||||||
|
message(STATUS "location_onnxruntime: ${location_onnxruntime}")
|
||||||
|
|
||||||
|
add_library(onnxruntime SHARED IMPORTED)
|
||||||
|
|
||||||
|
set_target_properties(onnxruntime PROPERTIES
|
||||||
|
IMPORTED_LOCATION ${location_onnxruntime}
|
||||||
|
INTERFACE_INCLUDE_DIRECTORIES "${onnxruntime_SOURCE_DIR}/include"
|
||||||
|
)
|
||||||
|
|
||||||
|
find_library(location_onnxruntime_cuda_lib onnxruntime_providers_cuda
|
||||||
|
PATHS
|
||||||
|
"${onnxruntime_SOURCE_DIR}/lib"
|
||||||
|
NO_CMAKE_SYSTEM_PATH
|
||||||
|
)
|
||||||
|
|
||||||
|
add_library(onnxruntime_providers_cuda SHARED IMPORTED)
|
||||||
|
set_target_properties(onnxruntime_providers_cuda PROPERTIES
|
||||||
|
IMPORTED_LOCATION ${location_onnxruntime_cuda_lib}
|
||||||
|
)
|
||||||
|
message(STATUS "location_onnxruntime_cuda_lib: ${location_onnxruntime_cuda_lib}")
|
||||||
|
|
||||||
|
# for libonnxruntime_providers_shared.so
|
||||||
|
find_library(location_onnxruntime_providers_shared_lib onnxruntime_providers_shared
|
||||||
|
PATHS
|
||||||
|
"${onnxruntime_SOURCE_DIR}/lib"
|
||||||
|
NO_CMAKE_SYSTEM_PATH
|
||||||
|
)
|
||||||
|
add_library(onnxruntime_providers_shared SHARED IMPORTED)
|
||||||
|
set_target_properties(onnxruntime_providers_shared PROPERTIES
|
||||||
|
IMPORTED_LOCATION ${location_onnxruntime_providers_shared_lib}
|
||||||
|
)
|
||||||
|
message(STATUS "location_onnxruntime_providers_shared_lib: ${location_onnxruntime_providers_shared_lib}")
|
||||||
|
|
||||||
|
file(GLOB onnxruntime_lib_files "${onnxruntime_SOURCE_DIR}/lib/libonnxruntime*")
|
||||||
|
message(STATUS "onnxruntime lib files: ${onnxruntime_lib_files}")
|
||||||
|
install(FILES ${onnxruntime_lib_files} DESTINATION lib)
|
||||||
@@ -13,7 +13,9 @@ function(download_onnxruntime)
|
|||||||
include(onnxruntime-linux-riscv64-static)
|
include(onnxruntime-linux-riscv64-static)
|
||||||
endif()
|
endif()
|
||||||
elseif(CMAKE_SYSTEM_NAME STREQUAL Linux AND CMAKE_SYSTEM_PROCESSOR STREQUAL aarch64)
|
elseif(CMAKE_SYSTEM_NAME STREQUAL Linux AND CMAKE_SYSTEM_PROCESSOR STREQUAL aarch64)
|
||||||
if(BUILD_SHARED_LIBS)
|
if(SHERPA_ONNX_ENABLE_GPU)
|
||||||
|
include(onnxruntime-linux-aarch64-gpu)
|
||||||
|
elseif(BUILD_SHARED_LIBS)
|
||||||
include(onnxruntime-linux-aarch64)
|
include(onnxruntime-linux-aarch64)
|
||||||
else()
|
else()
|
||||||
include(onnxruntime-linux-aarch64-static)
|
include(onnxruntime-linux-aarch64-static)
|
||||||
|
|||||||
@@ -1,18 +1,18 @@
|
|||||||
function(download_piper_phonemize)
|
function(download_piper_phonemize)
|
||||||
include(FetchContent)
|
include(FetchContent)
|
||||||
|
|
||||||
set(piper_phonemize_URL "https://github.com/csukuangfj/piper-phonemize/archive/dc6b5f4441bffe521047086930b0fc12686acd56.zip")
|
set(piper_phonemize_URL "https://github.com/csukuangfj/piper-phonemize/archive/38ee199dcc49c7b6de89f7ebfb32ed682763fa1b.zip")
|
||||||
set(piper_phonemize_URL2 "https://hf-mirror.com/csukuangfj/sherpa-onnx-cmake-deps/resolve/main/piper-phonemize-dc6b5f4441bffe521047086930b0fc12686acd56.zip")
|
set(piper_phonemize_URL2 "https://hf-mirror.com/csukuangfj/sherpa-onnx-cmake-deps/resolve/main/piper-phonemize-38ee199dcc49c7b6de89f7ebfb32ed682763fa1b.zip")
|
||||||
set(piper_phonemize_HASH "SHA256=b9faa04204b1756fa455a962abb1f037041c040133d55be58d11f11ab9b3ce14")
|
set(piper_phonemize_HASH "SHA256=ab4d06ca76047e1585c63c482f39ffead5315785345055360703cc9382c5e74b")
|
||||||
|
|
||||||
# If you don't have access to the Internet,
|
# If you don't have access to the Internet,
|
||||||
# please pre-download kaldi-decoder
|
# please pre-download kaldi-decoder
|
||||||
set(possible_file_locations
|
set(possible_file_locations
|
||||||
$ENV{HOME}/Downloads/piper-phonemize-dc6b5f4441bffe521047086930b0fc12686acd56.zip
|
$ENV{HOME}/Downloads/piper-phonemize-38ee199dcc49c7b6de89f7ebfb32ed682763fa1b.zip
|
||||||
${CMAKE_SOURCE_DIR}/piper-phonemize-dc6b5f4441bffe521047086930b0fc12686acd56.zip
|
${CMAKE_SOURCE_DIR}/piper-phonemize-38ee199dcc49c7b6de89f7ebfb32ed682763fa1b.zip
|
||||||
${CMAKE_BINARY_DIR}/piper-phonemize-dc6b5f4441bffe521047086930b0fc12686acd56.zip
|
${CMAKE_BINARY_DIR}/piper-phonemize-38ee199dcc49c7b6de89f7ebfb32ed682763fa1b.zip
|
||||||
/tmp/piper-phonemize-dc6b5f4441bffe521047086930b0fc12686acd56.zip
|
/tmp/piper-phonemize-38ee199dcc49c7b6de89f7ebfb32ed682763fa1b.zip
|
||||||
/star-fj/fangjun/download/github/piper-phonemize-dc6b5f4441bffe521047086930b0fc12686acd56.zip
|
/star-fj/fangjun/download/github/piper-phonemize-38ee199dcc49c7b6de89f7ebfb32ed682763fa1b.zip
|
||||||
)
|
)
|
||||||
|
|
||||||
foreach(f IN LISTS possible_file_locations)
|
foreach(f IN LISTS possible_file_locations)
|
||||||
|
|||||||
@@ -7,6 +7,8 @@
|
|||||||
#include <stdio.h>
|
#include <stdio.h>
|
||||||
#include <stdlib.h>
|
#include <stdlib.h>
|
||||||
|
|
||||||
|
#include <utility>
|
||||||
|
|
||||||
#if __ANDROID_API__ >= 8
|
#if __ANDROID_API__ >= 8
|
||||||
#include "android/log.h"
|
#include "android/log.h"
|
||||||
#define SHERPA_ONNX_LOGE(...) \
|
#define SHERPA_ONNX_LOGE(...) \
|
||||||
@@ -36,30 +38,28 @@
|
|||||||
#endif
|
#endif
|
||||||
|
|
||||||
// Read an integer
|
// Read an integer
|
||||||
#define SHERPA_ONNX_READ_META_DATA(dst, src_key) \
|
#define SHERPA_ONNX_READ_META_DATA(dst, src_key) \
|
||||||
do { \
|
do { \
|
||||||
auto value = \
|
auto value = LookupCustomModelMetaData(meta_data, src_key, allocator); \
|
||||||
meta_data.LookupCustomMetadataMapAllocated(src_key, allocator); \
|
if (value.empty()) { \
|
||||||
if (!value) { \
|
SHERPA_ONNX_LOGE("'%s' does not exist in the metadata", src_key); \
|
||||||
SHERPA_ONNX_LOGE("'%s' does not exist in the metadata", src_key); \
|
exit(-1); \
|
||||||
exit(-1); \
|
} \
|
||||||
} \
|
\
|
||||||
\
|
dst = atoi(value.c_str()); \
|
||||||
dst = atoi(value.get()); \
|
if (dst < 0) { \
|
||||||
if (dst < 0) { \
|
SHERPA_ONNX_LOGE("Invalid value %d for '%s'", dst, src_key); \
|
||||||
SHERPA_ONNX_LOGE("Invalid value %d for '%s'", dst, src_key); \
|
exit(-1); \
|
||||||
exit(-1); \
|
} \
|
||||||
} \
|
|
||||||
} while (0)
|
} while (0)
|
||||||
|
|
||||||
#define SHERPA_ONNX_READ_META_DATA_WITH_DEFAULT(dst, src_key, default_value) \
|
#define SHERPA_ONNX_READ_META_DATA_WITH_DEFAULT(dst, src_key, default_value) \
|
||||||
do { \
|
do { \
|
||||||
auto value = \
|
auto value = LookupCustomModelMetaData(meta_data, src_key, allocator); \
|
||||||
meta_data.LookupCustomMetadataMapAllocated(src_key, allocator); \
|
if (value.empty()) { \
|
||||||
if (!value) { \
|
|
||||||
dst = default_value; \
|
dst = default_value; \
|
||||||
} else { \
|
} else { \
|
||||||
dst = atoi(value.get()); \
|
dst = atoi(value.c_str()); \
|
||||||
if (dst < 0) { \
|
if (dst < 0) { \
|
||||||
SHERPA_ONNX_LOGE("Invalid value %d for '%s'", dst, src_key); \
|
SHERPA_ONNX_LOGE("Invalid value %d for '%s'", dst, src_key); \
|
||||||
exit(-1); \
|
exit(-1); \
|
||||||
@@ -68,118 +68,111 @@
|
|||||||
} while (0)
|
} while (0)
|
||||||
|
|
||||||
// read a vector of integers
|
// read a vector of integers
|
||||||
#define SHERPA_ONNX_READ_META_DATA_VEC(dst, src_key) \
|
#define SHERPA_ONNX_READ_META_DATA_VEC(dst, src_key) \
|
||||||
do { \
|
do { \
|
||||||
auto value = \
|
auto value = LookupCustomModelMetaData(meta_data, src_key, allocator); \
|
||||||
meta_data.LookupCustomMetadataMapAllocated(src_key, allocator); \
|
if (value.empty()) { \
|
||||||
if (!value) { \
|
SHERPA_ONNX_LOGE("'%s' does not exist in the metadata", src_key); \
|
||||||
SHERPA_ONNX_LOGE("'%s' does not exist in the metadata", src_key); \
|
exit(-1); \
|
||||||
exit(-1); \
|
} \
|
||||||
} \
|
\
|
||||||
\
|
bool ret = SplitStringToIntegers(value.c_str(), ",", true, &dst); \
|
||||||
bool ret = SplitStringToIntegers(value.get(), ",", true, &dst); \
|
if (!ret) { \
|
||||||
if (!ret) { \
|
SHERPA_ONNX_LOGE("Invalid value '%s' for '%s'", value.c_str(), src_key); \
|
||||||
SHERPA_ONNX_LOGE("Invalid value '%s' for '%s'", value.get(), src_key); \
|
exit(-1); \
|
||||||
exit(-1); \
|
} \
|
||||||
} \
|
|
||||||
} while (0)
|
} while (0)
|
||||||
|
|
||||||
// read a vector of floats
|
// read a vector of floats
|
||||||
#define SHERPA_ONNX_READ_META_DATA_VEC_FLOAT(dst, src_key) \
|
#define SHERPA_ONNX_READ_META_DATA_VEC_FLOAT(dst, src_key) \
|
||||||
do { \
|
do { \
|
||||||
auto value = \
|
auto value = LookupCustomModelMetaData(meta_data, src_key, allocator); \
|
||||||
meta_data.LookupCustomMetadataMapAllocated(src_key, allocator); \
|
if (value.empty()) { \
|
||||||
if (!value) { \
|
SHERPA_ONNX_LOGE("%s does not exist in the metadata", src_key); \
|
||||||
SHERPA_ONNX_LOGE("%s does not exist in the metadata", src_key); \
|
exit(-1); \
|
||||||
exit(-1); \
|
} \
|
||||||
} \
|
\
|
||||||
\
|
bool ret = SplitStringToFloats(value.c_str(), ",", true, &dst); \
|
||||||
bool ret = SplitStringToFloats(value.get(), ",", true, &dst); \
|
if (!ret) { \
|
||||||
if (!ret) { \
|
SHERPA_ONNX_LOGE("Invalid value '%s' for '%s'", value.c_str(), src_key); \
|
||||||
SHERPA_ONNX_LOGE("Invalid value '%s' for '%s'", value.get(), src_key); \
|
exit(-1); \
|
||||||
exit(-1); \
|
} \
|
||||||
} \
|
|
||||||
} while (0)
|
} while (0)
|
||||||
|
|
||||||
// read a vector of strings
|
// read a vector of strings
|
||||||
#define SHERPA_ONNX_READ_META_DATA_VEC_STRING(dst, src_key) \
|
#define SHERPA_ONNX_READ_META_DATA_VEC_STRING(dst, src_key) \
|
||||||
do { \
|
do { \
|
||||||
auto value = \
|
auto value = LookupCustomModelMetaData(meta_data, src_key, allocator); \
|
||||||
meta_data.LookupCustomMetadataMapAllocated(src_key, allocator); \
|
if (value.empty()) { \
|
||||||
if (!value) { \
|
SHERPA_ONNX_LOGE("'%s' does not exist in the metadata", src_key); \
|
||||||
SHERPA_ONNX_LOGE("'%s' does not exist in the metadata", src_key); \
|
exit(-1); \
|
||||||
exit(-1); \
|
} \
|
||||||
} \
|
SplitStringToVector(value.c_str(), ",", false, &dst); \
|
||||||
SplitStringToVector(value.get(), ",", false, &dst); \
|
\
|
||||||
\
|
if (dst.empty()) { \
|
||||||
if (dst.empty()) { \
|
SHERPA_ONNX_LOGE("Invalid value '%s' for '%s'. Empty vector!", \
|
||||||
SHERPA_ONNX_LOGE("Invalid value '%s' for '%s'. Empty vector!", \
|
value.c_str(), src_key); \
|
||||||
value.get(), src_key); \
|
exit(-1); \
|
||||||
exit(-1); \
|
} \
|
||||||
} \
|
|
||||||
} while (0)
|
} while (0)
|
||||||
|
|
||||||
// read a vector of strings separated by sep
|
// read a vector of strings separated by sep
|
||||||
#define SHERPA_ONNX_READ_META_DATA_VEC_STRING_SEP(dst, src_key, sep) \
|
#define SHERPA_ONNX_READ_META_DATA_VEC_STRING_SEP(dst, src_key, sep) \
|
||||||
do { \
|
do { \
|
||||||
auto value = \
|
auto value = LookupCustomModelMetaData(meta_data, src_key, allocator); \
|
||||||
meta_data.LookupCustomMetadataMapAllocated(src_key, allocator); \
|
if (value.empty()) { \
|
||||||
if (!value) { \
|
SHERPA_ONNX_LOGE("'%s' does not exist in the metadata", src_key); \
|
||||||
SHERPA_ONNX_LOGE("'%s' does not exist in the metadata", src_key); \
|
exit(-1); \
|
||||||
exit(-1); \
|
} \
|
||||||
} \
|
SplitStringToVector(value.c_str(), sep, false, &dst); \
|
||||||
SplitStringToVector(value.get(), sep, false, &dst); \
|
\
|
||||||
\
|
if (dst.empty()) { \
|
||||||
if (dst.empty()) { \
|
SHERPA_ONNX_LOGE("Invalid value '%s' for '%s'. Empty vector!", \
|
||||||
SHERPA_ONNX_LOGE("Invalid value '%s' for '%s'. Empty vector!", \
|
value.c_str(), src_key); \
|
||||||
value.get(), src_key); \
|
exit(-1); \
|
||||||
exit(-1); \
|
} \
|
||||||
} \
|
|
||||||
} while (0)
|
} while (0)
|
||||||
|
|
||||||
// Read a string
|
// Read a string
|
||||||
#define SHERPA_ONNX_READ_META_DATA_STR(dst, src_key) \
|
#define SHERPA_ONNX_READ_META_DATA_STR(dst, src_key) \
|
||||||
do { \
|
do { \
|
||||||
auto value = \
|
auto value = LookupCustomModelMetaData(meta_data, src_key, allocator); \
|
||||||
meta_data.LookupCustomMetadataMapAllocated(src_key, allocator); \
|
if (value.empty()) { \
|
||||||
if (!value) { \
|
SHERPA_ONNX_LOGE("'%s' does not exist in the metadata", src_key); \
|
||||||
SHERPA_ONNX_LOGE("'%s' does not exist in the metadata", src_key); \
|
exit(-1); \
|
||||||
exit(-1); \
|
} \
|
||||||
} \
|
\
|
||||||
\
|
dst = std::move(value); \
|
||||||
dst = value.get(); \
|
if (dst.empty()) { \
|
||||||
if (dst.empty()) { \
|
SHERPA_ONNX_LOGE("Invalid value for '%s'\n", src_key); \
|
||||||
SHERPA_ONNX_LOGE("Invalid value for '%s'\n", src_key); \
|
exit(-1); \
|
||||||
exit(-1); \
|
} \
|
||||||
} \
|
|
||||||
} while (0)
|
} while (0)
|
||||||
|
|
||||||
#define SHERPA_ONNX_READ_META_DATA_STR_ALLOW_EMPTY(dst, src_key) \
|
#define SHERPA_ONNX_READ_META_DATA_STR_ALLOW_EMPTY(dst, src_key) \
|
||||||
do { \
|
do { \
|
||||||
auto value = \
|
auto value = LookupCustomModelMetaData(meta_data, src_key, allocator); \
|
||||||
meta_data.LookupCustomMetadataMapAllocated(src_key, allocator); \
|
if (value.empty()) { \
|
||||||
if (!value) { \
|
SHERPA_ONNX_LOGE("'%s' does not exist in the metadata", src_key); \
|
||||||
SHERPA_ONNX_LOGE("'%s' does not exist in the metadata", src_key); \
|
exit(-1); \
|
||||||
exit(-1); \
|
} \
|
||||||
} \
|
\
|
||||||
\
|
dst = std::move(value); \
|
||||||
dst = value.get(); \
|
|
||||||
} while (0)
|
} while (0)
|
||||||
|
|
||||||
#define SHERPA_ONNX_READ_META_DATA_STR_WITH_DEFAULT(dst, src_key, \
|
#define SHERPA_ONNX_READ_META_DATA_STR_WITH_DEFAULT(dst, src_key, \
|
||||||
default_value) \
|
default_value) \
|
||||||
do { \
|
do { \
|
||||||
auto value = \
|
auto value = LookupCustomModelMetaData(meta_data, src_key, allocator); \
|
||||||
meta_data.LookupCustomMetadataMapAllocated(src_key, allocator); \
|
if (value.empty()) { \
|
||||||
if (!value) { \
|
dst = default_value; \
|
||||||
dst = default_value; \
|
} else { \
|
||||||
} else { \
|
dst = std::move(value); \
|
||||||
dst = value.get(); \
|
if (dst.empty()) { \
|
||||||
if (dst.empty()) { \
|
SHERPA_ONNX_LOGE("Invalid value for '%s'\n", src_key); \
|
||||||
SHERPA_ONNX_LOGE("Invalid value for '%s'\n", src_key); \
|
exit(-1); \
|
||||||
exit(-1); \
|
} \
|
||||||
} \
|
} \
|
||||||
} \
|
|
||||||
} while (0)
|
} while (0)
|
||||||
|
|
||||||
#define SHERPA_ONNX_EXIT(code) exit(code)
|
#define SHERPA_ONNX_EXIT(code) exit(code)
|
||||||
|
|||||||
@@ -46,7 +46,7 @@ class OfflineCEDModel::Impl {
|
|||||||
|
|
||||||
int32_t NumEventClasses() const { return num_event_classes_; }
|
int32_t NumEventClasses() const { return num_event_classes_; }
|
||||||
|
|
||||||
OrtAllocator *Allocator() const { return allocator_; }
|
OrtAllocator *Allocator() { return allocator_; }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void Init(void *model_data, size_t model_data_length) {
|
void Init(void *model_data, size_t model_data_length) {
|
||||||
|
|||||||
@@ -44,7 +44,7 @@ class OfflineCtTransformerModel::Impl {
|
|||||||
return std::move(ans[0]);
|
return std::move(ans[0]);
|
||||||
}
|
}
|
||||||
|
|
||||||
OrtAllocator *Allocator() const { return allocator_; }
|
OrtAllocator *Allocator() { return allocator_; }
|
||||||
|
|
||||||
const OfflineCtTransformerModelMetaData &GetModelMetadata() const {
|
const OfflineCtTransformerModelMetaData &GetModelMetadata() const {
|
||||||
return meta_data_;
|
return meta_data_;
|
||||||
|
|||||||
@@ -53,8 +53,8 @@ static ModelType GetModelType(char *model_data, size_t model_data_length,
|
|||||||
|
|
||||||
Ort::AllocatorWithDefaultOptions allocator;
|
Ort::AllocatorWithDefaultOptions allocator;
|
||||||
auto model_type =
|
auto model_type =
|
||||||
meta_data.LookupCustomMetadataMapAllocated("model_type", allocator);
|
LookupCustomModelMetaData(meta_data, "model_type", allocator);
|
||||||
if (!model_type) {
|
if (model_type.empty()) {
|
||||||
SHERPA_ONNX_LOGE(
|
SHERPA_ONNX_LOGE(
|
||||||
"No model_type in the metadata!\n"
|
"No model_type in the metadata!\n"
|
||||||
"If you are using models from NeMo, please refer to\n"
|
"If you are using models from NeMo, please refer to\n"
|
||||||
@@ -74,22 +74,22 @@ static ModelType GetModelType(char *model_data, size_t model_data_length,
|
|||||||
return ModelType::kUnknown;
|
return ModelType::kUnknown;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (model_type.get() == std::string("EncDecCTCModelBPE")) {
|
if (model_type == "EncDecCTCModelBPE") {
|
||||||
return ModelType::kEncDecCTCModelBPE;
|
return ModelType::kEncDecCTCModelBPE;
|
||||||
} else if (model_type.get() == std::string("EncDecCTCModel")) {
|
} else if (model_type == "EncDecCTCModel") {
|
||||||
return ModelType::kEncDecCTCModel;
|
return ModelType::kEncDecCTCModel;
|
||||||
} else if (model_type.get() == std::string("EncDecHybridRNNTCTCBPEModel")) {
|
} else if (model_type == "EncDecHybridRNNTCTCBPEModel") {
|
||||||
return ModelType::kEncDecHybridRNNTCTCBPEModel;
|
return ModelType::kEncDecHybridRNNTCTCBPEModel;
|
||||||
} else if (model_type.get() == std::string("tdnn")) {
|
} else if (model_type == "tdnn") {
|
||||||
return ModelType::kTdnn;
|
return ModelType::kTdnn;
|
||||||
} else if (model_type.get() == std::string("zipformer2_ctc")) {
|
} else if (model_type == "zipformer2_ctc") {
|
||||||
return ModelType::kZipformerCtc;
|
return ModelType::kZipformerCtc;
|
||||||
} else if (model_type.get() == std::string("wenet_ctc")) {
|
} else if (model_type == "wenet_ctc") {
|
||||||
return ModelType::kWenetCtc;
|
return ModelType::kWenetCtc;
|
||||||
} else if (model_type.get() == std::string("telespeech_ctc")) {
|
} else if (model_type == "telespeech_ctc") {
|
||||||
return ModelType::kTeleSpeechCtc;
|
return ModelType::kTeleSpeechCtc;
|
||||||
} else {
|
} else {
|
||||||
SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.get());
|
SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.c_str());
|
||||||
return ModelType::kUnknown;
|
return ModelType::kUnknown;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -155,7 +155,7 @@ class OfflineMoonshineModel::Impl {
|
|||||||
return {std::move(cached_decoder_out[0]), std::move(next_states)};
|
return {std::move(cached_decoder_out[0]), std::move(next_states)};
|
||||||
}
|
}
|
||||||
|
|
||||||
OrtAllocator *Allocator() const { return allocator_; }
|
OrtAllocator *Allocator() { return allocator_; }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void InitPreprocessor(void *model_data, size_t model_data_length) {
|
void InitPreprocessor(void *model_data, size_t model_data_length) {
|
||||||
|
|||||||
@@ -68,7 +68,7 @@ class OfflineNemoEncDecCtcModel::Impl {
|
|||||||
|
|
||||||
int32_t SubsamplingFactor() const { return subsampling_factor_; }
|
int32_t SubsamplingFactor() const { return subsampling_factor_; }
|
||||||
|
|
||||||
OrtAllocator *Allocator() const { return allocator_; }
|
OrtAllocator *Allocator() { return allocator_; }
|
||||||
|
|
||||||
std::string FeatureNormalizationMethod() const { return normalize_type_; }
|
std::string FeatureNormalizationMethod() const { return normalize_type_; }
|
||||||
|
|
||||||
|
|||||||
@@ -56,7 +56,7 @@ class OfflineParaformerModel::Impl {
|
|||||||
|
|
||||||
const std::vector<float> &InverseStdDev() const { return inv_stddev_; }
|
const std::vector<float> &InverseStdDev() const { return inv_stddev_; }
|
||||||
|
|
||||||
OrtAllocator *Allocator() const { return allocator_; }
|
OrtAllocator *Allocator() { return allocator_; }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void Init(void *model_data, size_t model_data_length) {
|
void Init(void *model_data, size_t model_data_length) {
|
||||||
|
|||||||
@@ -121,9 +121,9 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
|
|||||||
|
|
||||||
Ort::AllocatorWithDefaultOptions allocator; // used in the macro below
|
Ort::AllocatorWithDefaultOptions allocator; // used in the macro below
|
||||||
|
|
||||||
auto model_type_ptr =
|
auto model_type =
|
||||||
meta_data.LookupCustomMetadataMapAllocated("model_type", allocator);
|
LookupCustomModelMetaData(meta_data, "model_type", allocator);
|
||||||
if (!model_type_ptr) {
|
if (model_type.empty()) {  // nothing was found for "model_type": report and exit
|
||||||
SHERPA_ONNX_LOGE(
|
SHERPA_ONNX_LOGE(
|
||||||
"No model_type in the metadata!\n\n"
|
"No model_type in the metadata!\n\n"
|
||||||
"Please refer to the following URLs to add metadata"
|
"Please refer to the following URLs to add metadata"
|
||||||
@@ -164,7 +164,6 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
|
|||||||
"\n");
|
"\n");
|
||||||
exit(-1);
|
exit(-1);
|
||||||
}
|
}
|
||||||
std::string model_type(model_type_ptr.get());
|
|
||||||
|
|
||||||
if (model_type == "conformer" || model_type == "zipformer" ||
|
if (model_type == "conformer" || model_type == "zipformer" ||
|
||||||
model_type == "zipformer2") {
|
model_type == "zipformer2") {
|
||||||
@@ -301,9 +300,9 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
|
|||||||
|
|
||||||
Ort::AllocatorWithDefaultOptions allocator; // used in the macro below
|
Ort::AllocatorWithDefaultOptions allocator; // used in the macro below
|
||||||
|
|
||||||
auto model_type_ptr =
|
auto model_type =
|
||||||
meta_data.LookupCustomMetadataMapAllocated("model_type", allocator);
|
LookupCustomModelMetaData(meta_data, "model_type", allocator);
|
||||||
if (!model_type_ptr) {
|
if (model_type.empty()) {
|
||||||
SHERPA_ONNX_LOGE(
|
SHERPA_ONNX_LOGE(
|
||||||
"No model_type in the metadata!\n\n"
|
"No model_type in the metadata!\n\n"
|
||||||
"Please refer to the following URLs to add metadata"
|
"Please refer to the following URLs to add metadata"
|
||||||
@@ -344,7 +343,6 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
|
|||||||
"\n");
|
"\n");
|
||||||
exit(-1);
|
exit(-1);
|
||||||
}
|
}
|
||||||
std::string model_type(model_type_ptr.get());
|
|
||||||
|
|
||||||
if (model_type == "conformer" || model_type == "zipformer" ||
|
if (model_type == "conformer" || model_type == "zipformer" ||
|
||||||
model_type == "zipformer2") {
|
model_type == "zipformer2") {
|
||||||
|
|||||||
@@ -56,7 +56,7 @@ class OfflineSenseVoiceModel::Impl {
|
|||||||
return meta_data_;
|
return meta_data_;
|
||||||
}
|
}
|
||||||
|
|
||||||
OrtAllocator *Allocator() const { return allocator_; }
|
OrtAllocator *Allocator() { return allocator_; }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void Init(void *model_data, size_t model_data_length) {
|
void Init(void *model_data, size_t model_data_length) {
|
||||||
|
|||||||
@@ -63,7 +63,7 @@ class OfflineTdnnCtcModel::Impl {
|
|||||||
|
|
||||||
int32_t VocabSize() const { return vocab_size_; }
|
int32_t VocabSize() const { return vocab_size_; }
|
||||||
|
|
||||||
OrtAllocator *Allocator() const { return allocator_; }
|
OrtAllocator *Allocator() { return allocator_; }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void Init(void *model_data, size_t model_data_length) {
|
void Init(void *model_data, size_t model_data_length) {
|
||||||
|
|||||||
@@ -69,7 +69,7 @@ class OfflineTeleSpeechCtcModel::Impl {
|
|||||||
|
|
||||||
int32_t SubsamplingFactor() const { return subsampling_factor_; }
|
int32_t SubsamplingFactor() const { return subsampling_factor_; }
|
||||||
|
|
||||||
OrtAllocator *Allocator() const { return allocator_; }
|
OrtAllocator *Allocator() { return allocator_; }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void Init(void *model_data, size_t model_data_length) {
|
void Init(void *model_data, size_t model_data_length) {
|
||||||
|
|||||||
@@ -95,11 +95,11 @@ class OfflineTransducerModel::Impl {
|
|||||||
int32_t VocabSize() const { return vocab_size_; }
|
int32_t VocabSize() const { return vocab_size_; }
|
||||||
int32_t ContextSize() const { return context_size_; }
|
int32_t ContextSize() const { return context_size_; }
|
||||||
int32_t SubsamplingFactor() const { return 4; }
|
int32_t SubsamplingFactor() const { return 4; }
|
||||||
OrtAllocator *Allocator() const { return allocator_; }
|
OrtAllocator *Allocator() { return allocator_; }
|
||||||
|
|
||||||
Ort::Value BuildDecoderInput(
|
Ort::Value BuildDecoderInput(
|
||||||
const std::vector<OfflineTransducerDecoderResult> &results,
|
const std::vector<OfflineTransducerDecoderResult> &results,
|
||||||
int32_t end_index) const {
|
int32_t end_index) {
|
||||||
assert(end_index <= results.size());
|
assert(end_index <= results.size());
|
||||||
|
|
||||||
int32_t batch_size = end_index;
|
int32_t batch_size = end_index;
|
||||||
@@ -122,7 +122,7 @@ class OfflineTransducerModel::Impl {
|
|||||||
}
|
}
|
||||||
|
|
||||||
Ort::Value BuildDecoderInput(const std::vector<Hypothesis> &results,
|
Ort::Value BuildDecoderInput(const std::vector<Hypothesis> &results,
|
||||||
int32_t end_index) const {
|
int32_t end_index) {
|
||||||
assert(end_index <= results.size());
|
assert(end_index <= results.size());
|
||||||
|
|
||||||
int32_t batch_size = end_index;
|
int32_t batch_size = end_index;
|
||||||
|
|||||||
@@ -123,7 +123,7 @@ class OfflineTransducerNeMoModel::Impl {
|
|||||||
return std::move(logit[0]);
|
return std::move(logit[0]);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<Ort::Value> GetDecoderInitStates(int32_t batch_size) const {
|
std::vector<Ort::Value> GetDecoderInitStates(int32_t batch_size) {
|
||||||
std::array<int64_t, 3> s0_shape{pred_rnn_layers_, batch_size, pred_hidden_};
|
std::array<int64_t, 3> s0_shape{pred_rnn_layers_, batch_size, pred_hidden_};
|
||||||
Ort::Value s0 = Ort::Value::CreateTensor<float>(allocator_, s0_shape.data(),
|
Ort::Value s0 = Ort::Value::CreateTensor<float>(allocator_, s0_shape.data(),
|
||||||
s0_shape.size());
|
s0_shape.size());
|
||||||
@@ -149,7 +149,7 @@ class OfflineTransducerNeMoModel::Impl {
|
|||||||
int32_t SubsamplingFactor() const { return subsampling_factor_; }
|
int32_t SubsamplingFactor() const { return subsampling_factor_; }
|
||||||
int32_t VocabSize() const { return vocab_size_; }
|
int32_t VocabSize() const { return vocab_size_; }
|
||||||
|
|
||||||
OrtAllocator *Allocator() const { return allocator_; }
|
OrtAllocator *Allocator() { return allocator_; }
|
||||||
|
|
||||||
std::string FeatureNormalizationMethod() const { return normalize_type_; }
|
std::string FeatureNormalizationMethod() const { return normalize_type_; }
|
||||||
|
|
||||||
|
|||||||
@@ -47,7 +47,7 @@ class OfflineWenetCtcModel::Impl {
|
|||||||
|
|
||||||
int32_t SubsamplingFactor() const { return subsampling_factor_; }
|
int32_t SubsamplingFactor() const { return subsampling_factor_; }
|
||||||
|
|
||||||
OrtAllocator *Allocator() const { return allocator_; }
|
OrtAllocator *Allocator() { return allocator_; }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void Init(void *model_data, size_t model_data_length) {
|
void Init(void *model_data, size_t model_data_length) {
|
||||||
|
|||||||
@@ -188,7 +188,7 @@ class OfflineWhisperModel::Impl {
|
|||||||
return {std::move(n_layer_self_k_cache), std::move(n_layer_self_v_cache)};
|
return {std::move(n_layer_self_k_cache), std::move(n_layer_self_v_cache)};
|
||||||
}
|
}
|
||||||
|
|
||||||
OrtAllocator *Allocator() const { return allocator_; }
|
OrtAllocator *Allocator() { return allocator_; }
|
||||||
|
|
||||||
const std::vector<int64_t> &GetInitialTokens() const { return sot_sequence_; }
|
const std::vector<int64_t> &GetInitialTokens() const { return sot_sequence_; }
|
||||||
|
|
||||||
|
|||||||
@@ -47,7 +47,7 @@ class OfflineZipformerAudioTaggingModel::Impl {
|
|||||||
|
|
||||||
int32_t NumEventClasses() const { return num_event_classes_; }
|
int32_t NumEventClasses() const { return num_event_classes_; }
|
||||||
|
|
||||||
OrtAllocator *Allocator() const { return allocator_; }
|
OrtAllocator *Allocator() { return allocator_; }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void Init(void *model_data, size_t model_data_length) {
|
void Init(void *model_data, size_t model_data_length) {
|
||||||
|
|||||||
@@ -48,7 +48,7 @@ class OfflineZipformerCtcModel::Impl {
|
|||||||
int32_t VocabSize() const { return vocab_size_; }
|
int32_t VocabSize() const { return vocab_size_; }
|
||||||
int32_t SubsamplingFactor() const { return 4; }
|
int32_t SubsamplingFactor() const { return 4; }
|
||||||
|
|
||||||
OrtAllocator *Allocator() const { return allocator_; }
|
OrtAllocator *Allocator() { return allocator_; }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void Init(void *model_data, size_t model_data_length) {
|
void Init(void *model_data, size_t model_data_length) {
|
||||||
|
|||||||
@@ -47,7 +47,7 @@ class OnlineCNNBiLSTMModel::Impl {
|
|||||||
return {std::move(ans[0]), std::move(ans[1])};
|
return {std::move(ans[0]), std::move(ans[1])};
|
||||||
}
|
}
|
||||||
|
|
||||||
OrtAllocator *Allocator() const { return allocator_; }
|
OrtAllocator *Allocator() { return allocator_; }
|
||||||
|
|
||||||
const OnlineCNNBiLSTMModelMetaData &GetModelMetadata() const {
|
const OnlineCNNBiLSTMModelMetaData &GetModelMetadata() const {
|
||||||
return meta_data_;
|
return meta_data_;
|
||||||
|
|||||||
@@ -163,8 +163,11 @@ std::vector<Ort::Value> OnlineConformerTransducerModel::StackStates(
|
|||||||
conv_vec[i] = &states[i][1];
|
conv_vec[i] = &states[i][1];
|
||||||
}
|
}
|
||||||
|
|
||||||
Ort::Value attn = Cat(allocator_, attn_vec, 2);
|
auto allocator =
|
||||||
Ort::Value conv = Cat(allocator_, conv_vec, 2);
|
const_cast<OnlineConformerTransducerModel *>(this)->allocator_;
|
||||||
|
|
||||||
|
Ort::Value attn = Cat(allocator, attn_vec, 2);
|
||||||
|
Ort::Value conv = Cat(allocator, conv_vec, 2);
|
||||||
|
|
||||||
std::vector<Ort::Value> ans;
|
std::vector<Ort::Value> ans;
|
||||||
ans.reserve(2);
|
ans.reserve(2);
|
||||||
@@ -183,8 +186,11 @@ OnlineConformerTransducerModel::UnStackStates(
|
|||||||
|
|
||||||
std::vector<std::vector<Ort::Value>> ans(batch_size);
|
std::vector<std::vector<Ort::Value>> ans(batch_size);
|
||||||
|
|
||||||
std::vector<Ort::Value> attn_vec = Unbind(allocator_, &states[0], 2);
|
auto allocator =
|
||||||
std::vector<Ort::Value> conv_vec = Unbind(allocator_, &states[1], 2);
|
const_cast<OnlineConformerTransducerModel *>(this)->allocator_;
|
||||||
|
|
||||||
|
std::vector<Ort::Value> attn_vec = Unbind(allocator, &states[0], 2);
|
||||||
|
std::vector<Ort::Value> conv_vec = Unbind(allocator, &states[1], 2);
|
||||||
|
|
||||||
assert(attn_vec.size() == batch_size);
|
assert(attn_vec.size() == batch_size);
|
||||||
assert(conv_vec.size() == batch_size);
|
assert(conv_vec.size() == batch_size);
|
||||||
|
|||||||
@@ -158,9 +158,10 @@ std::vector<Ort::Value> OnlineLstmTransducerModel::StackStates(
|
|||||||
h_buf[i] = &states[i][0];
|
h_buf[i] = &states[i][0];
|
||||||
c_buf[i] = &states[i][1];
|
c_buf[i] = &states[i][1];
|
||||||
}
|
}
|
||||||
|
auto allocator = const_cast<OnlineLstmTransducerModel *>(this)->allocator_;
|
||||||
|
|
||||||
Ort::Value h = Cat(allocator_, h_buf, 1);
|
Ort::Value h = Cat(allocator, h_buf, 1);
|
||||||
Ort::Value c = Cat(allocator_, c_buf, 1);
|
Ort::Value c = Cat(allocator, c_buf, 1);
|
||||||
|
|
||||||
std::vector<Ort::Value> ans;
|
std::vector<Ort::Value> ans;
|
||||||
ans.reserve(2);
|
ans.reserve(2);
|
||||||
@@ -177,8 +178,10 @@ std::vector<std::vector<Ort::Value>> OnlineLstmTransducerModel::UnStackStates(
|
|||||||
|
|
||||||
std::vector<std::vector<Ort::Value>> ans(batch_size);
|
std::vector<std::vector<Ort::Value>> ans(batch_size);
|
||||||
|
|
||||||
std::vector<Ort::Value> h_vec = Unbind(allocator_, &states[0], 1);
|
auto allocator = const_cast<OnlineLstmTransducerModel *>(this)->allocator_;
|
||||||
std::vector<Ort::Value> c_vec = Unbind(allocator_, &states[1], 1);
|
|
||||||
|
std::vector<Ort::Value> h_vec = Unbind(allocator, &states[0], 1);
|
||||||
|
std::vector<Ort::Value> c_vec = Unbind(allocator, &states[1], 1);
|
||||||
|
|
||||||
assert(h_vec.size() == batch_size);
|
assert(h_vec.size() == batch_size);
|
||||||
assert(c_vec.size() == batch_size);
|
assert(c_vec.size() == batch_size);
|
||||||
|
|||||||
@@ -102,7 +102,7 @@ class OnlineNeMoCtcModel::Impl {
|
|||||||
|
|
||||||
int32_t ChunkShift() const { return chunk_shift_; }
|
int32_t ChunkShift() const { return chunk_shift_; }
|
||||||
|
|
||||||
OrtAllocator *Allocator() const { return allocator_; }
|
OrtAllocator *Allocator() { return allocator_; }
|
||||||
|
|
||||||
// Return a vector containing 3 tensors
|
// Return a vector containing 3 tensors
|
||||||
// - cache_last_channel
|
// - cache_last_channel
|
||||||
@@ -119,7 +119,7 @@ class OnlineNeMoCtcModel::Impl {
|
|||||||
}
|
}
|
||||||
|
|
||||||
std::vector<Ort::Value> StackStates(
|
std::vector<Ort::Value> StackStates(
|
||||||
std::vector<std::vector<Ort::Value>> states) const {
|
std::vector<std::vector<Ort::Value>> states) {
|
||||||
int32_t batch_size = static_cast<int32_t>(states.size());
|
int32_t batch_size = static_cast<int32_t>(states.size());
|
||||||
if (batch_size == 1) {
|
if (batch_size == 1) {
|
||||||
return std::move(states[0]);
|
return std::move(states[0]);
|
||||||
@@ -157,6 +157,8 @@ class OnlineNeMoCtcModel::Impl {
|
|||||||
std::vector<Ort::Value> states) const {
|
std::vector<Ort::Value> states) const {
|
||||||
assert(states.size() == 3);
|
assert(states.size() == 3);
|
||||||
|
|
||||||
|
auto allocator = const_cast<Impl *>(this)->allocator_;
|
||||||
|
|
||||||
std::vector<std::vector<Ort::Value>> ans;
|
std::vector<std::vector<Ort::Value>> ans;
|
||||||
|
|
||||||
auto shape = states[0].GetTensorTypeAndShapeInfo().GetShape();
|
auto shape = states[0].GetTensorTypeAndShapeInfo().GetShape();
|
||||||
@@ -171,9 +173,9 @@ class OnlineNeMoCtcModel::Impl {
|
|||||||
for (int32_t i = 0; i != 3; ++i) {
|
for (int32_t i = 0; i != 3; ++i) {
|
||||||
std::vector<Ort::Value> v;
|
std::vector<Ort::Value> v;
|
||||||
if (i == 2) {
|
if (i == 2) {
|
||||||
v = Unbind<int64_t>(allocator_, &states[i], 0);
|
v = Unbind<int64_t>(allocator, &states[i], 0);
|
||||||
} else {
|
} else {
|
||||||
v = Unbind(allocator_, &states[i], 0);
|
v = Unbind(allocator, &states[i], 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
assert(v.size() == batch_size);
|
assert(v.size() == batch_size);
|
||||||
|
|||||||
@@ -105,7 +105,7 @@ class OnlineParaformerModel::Impl {
|
|||||||
|
|
||||||
const std::vector<float> &InverseStdDev() const { return inv_stddev_; }
|
const std::vector<float> &InverseStdDev() const { return inv_stddev_; }
|
||||||
|
|
||||||
OrtAllocator *Allocator() const { return allocator_; }
|
OrtAllocator *Allocator() { return allocator_; }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void InitEncoder(void *model_data, size_t model_data_length) {
|
void InitEncoder(void *model_data, size_t model_data_length) {
|
||||||
|
|||||||
@@ -5,10 +5,10 @@
|
|||||||
|
|
||||||
#include "sherpa-onnx/csrc/online-rnn-lm.h"
|
#include "sherpa-onnx/csrc/online-rnn-lm.h"
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <algorithm>
|
|
||||||
|
|
||||||
#include "onnxruntime_cxx_api.h" // NOLINT
|
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||||
#include "sherpa-onnx/csrc/macros.h"
|
#include "sherpa-onnx/csrc/macros.h"
|
||||||
@@ -53,49 +53,49 @@ class OnlineRnnLM::Impl {
|
|||||||
|
|
||||||
// classic rescore function
|
// classic rescore function
|
||||||
void ComputeLMScore(float scale, int32_t context_size,
|
void ComputeLMScore(float scale, int32_t context_size,
|
||||||
std::vector<Hypotheses> *hyps) {
|
std::vector<Hypotheses> *hyps) {
|
||||||
Ort::AllocatorWithDefaultOptions allocator;
|
Ort::AllocatorWithDefaultOptions allocator;
|
||||||
|
|
||||||
for (auto &hyp : *hyps) {
|
for (auto &hyp : *hyps) {
|
||||||
for (auto &h_m : hyp) {
|
for (auto &h_m : hyp) {
|
||||||
auto &h = h_m.second;
|
auto &h = h_m.second;
|
||||||
auto &ys = h.ys;
|
auto &ys = h.ys;
|
||||||
const int32_t token_num_in_chunk =
|
const int32_t token_num_in_chunk =
|
||||||
ys.size() - context_size - h.cur_scored_pos - 1;
|
ys.size() - context_size - h.cur_scored_pos - 1;
|
||||||
|
|
||||||
if (token_num_in_chunk < 1) {
|
if (token_num_in_chunk < 1) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (h.nn_lm_states.empty()) {
|
if (h.nn_lm_states.empty()) {
|
||||||
h.nn_lm_states = Convert(GetInitStates());
|
h.nn_lm_states = Convert(GetInitStates());
|
||||||
}
|
}
|
||||||
|
|
||||||
if (token_num_in_chunk >= h.lm_rescore_min_chunk) {
|
if (token_num_in_chunk >= h.lm_rescore_min_chunk) {
|
||||||
std::array<int64_t, 2> x_shape{1, token_num_in_chunk};
|
std::array<int64_t, 2> x_shape{1, token_num_in_chunk};
|
||||||
|
|
||||||
Ort::Value x = Ort::Value::CreateTensor<int64_t>(
|
Ort::Value x = Ort::Value::CreateTensor<int64_t>(
|
||||||
allocator, x_shape.data(), x_shape.size());
|
allocator, x_shape.data(), x_shape.size());
|
||||||
int64_t *p_x = x.GetTensorMutableData<int64_t>();
|
int64_t *p_x = x.GetTensorMutableData<int64_t>();
|
||||||
std::copy(ys.begin() + context_size + h.cur_scored_pos,
|
std::copy(ys.begin() + context_size + h.cur_scored_pos, ys.end() - 1,
|
||||||
ys.end() - 1, p_x);
|
p_x);
|
||||||
|
|
||||||
// streaming forward by NN LM
|
// streaming forward by NN LM
|
||||||
auto out = ScoreToken(std::move(x),
|
auto out =
|
||||||
Convert(std::move(h.nn_lm_states)));
|
ScoreToken(std::move(x), Convert(std::move(h.nn_lm_states)));
|
||||||
|
|
||||||
// update NN LM score in hyp
|
// update NN LM score in hyp
|
||||||
const float *p_nll = out.first.GetTensorData<float>();
|
const float *p_nll = out.first.GetTensorData<float>();
|
||||||
h.lm_log_prob = -scale * (*p_nll);
|
h.lm_log_prob = -scale * (*p_nll);
|
||||||
|
|
||||||
// update NN LM states in hyp
|
// update NN LM states in hyp
|
||||||
h.nn_lm_states = Convert(std::move(out.second));
|
h.nn_lm_states = Convert(std::move(out.second));
|
||||||
|
|
||||||
h.cur_scored_pos += token_num_in_chunk;
|
h.cur_scored_pos += token_num_in_chunk;
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
std::pair<Ort::Value, std::vector<Ort::Value>> ScoreToken(
|
std::pair<Ort::Value, std::vector<Ort::Value>> ScoreToken(
|
||||||
Ort::Value x, std::vector<Ort::Value> states) {
|
Ort::Value x, std::vector<Ort::Value> states) {
|
||||||
@@ -125,7 +125,7 @@ class OnlineRnnLM::Impl {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// get init states for classic rescore
|
// get init states for classic rescore
|
||||||
std::vector<Ort::Value> GetInitStates() const {
|
std::vector<Ort::Value> GetInitStates() {
|
||||||
std::vector<Ort::Value> ans;
|
std::vector<Ort::Value> ans;
|
||||||
ans.reserve(init_states_.size());
|
ans.reserve(init_states_.size());
|
||||||
|
|
||||||
@@ -226,7 +226,7 @@ std::pair<Ort::Value, std::vector<Ort::Value>> OnlineRnnLM::ScoreToken(
|
|||||||
|
|
||||||
// classic rescore scores
|
// classic rescore scores
|
||||||
void OnlineRnnLM::ComputeLMScore(float scale, int32_t context_size,
|
void OnlineRnnLM::ComputeLMScore(float scale, int32_t context_size,
|
||||||
std::vector<Hypotheses> *hyps) {
|
std::vector<Hypotheses> *hyps) {
|
||||||
return impl_->ComputeLMScore(scale, context_size, hyps);
|
return impl_->ComputeLMScore(scale, context_size, hyps);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -235,5 +235,4 @@ void OnlineRnnLM::ComputeLMScoreSF(float scale, Hypothesis *hyp) {
|
|||||||
return impl_->ComputeLMScoreSF(scale, hyp);
|
return impl_->ComputeLMScoreSF(scale, hyp);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
} // namespace sherpa_onnx
|
} // namespace sherpa_onnx
|
||||||
|
|||||||
@@ -54,8 +54,8 @@ static ModelType GetModelType(char *model_data, size_t model_data_length,
|
|||||||
|
|
||||||
Ort::AllocatorWithDefaultOptions allocator;
|
Ort::AllocatorWithDefaultOptions allocator;
|
||||||
auto model_type =
|
auto model_type =
|
||||||
meta_data.LookupCustomMetadataMapAllocated("model_type", allocator);
|
LookupCustomModelMetaData(meta_data, "model_type", allocator);
|
||||||
if (!model_type) {
|
if (model_type.empty()) {
|
||||||
SHERPA_ONNX_LOGE(
|
SHERPA_ONNX_LOGE(
|
||||||
"No model_type in the metadata!\n"
|
"No model_type in the metadata!\n"
|
||||||
"Please make sure you are using the latest export-onnx.py from icefall "
|
"Please make sure you are using the latest export-onnx.py from icefall "
|
||||||
@@ -63,16 +63,16 @@ static ModelType GetModelType(char *model_data, size_t model_data_length,
|
|||||||
return ModelType::kUnknown;
|
return ModelType::kUnknown;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (model_type.get() == std::string("conformer")) {
|
if (model_type == "conformer") {
|
||||||
return ModelType::kConformer;
|
return ModelType::kConformer;
|
||||||
} else if (model_type.get() == std::string("lstm")) {
|
} else if (model_type == "lstm") {
|
||||||
return ModelType::kLstm;
|
return ModelType::kLstm;
|
||||||
} else if (model_type.get() == std::string("zipformer")) {
|
} else if (model_type == "zipformer") {
|
||||||
return ModelType::kZipformer;
|
return ModelType::kZipformer;
|
||||||
} else if (model_type.get() == std::string("zipformer2")) {
|
} else if (model_type == "zipformer2") {
|
||||||
return ModelType::kZipformer2;
|
return ModelType::kZipformer2;
|
||||||
} else {
|
} else {
|
||||||
SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.get());
|
SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.c_str());
|
||||||
return ModelType::kUnknown;
|
return ModelType::kUnknown;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -197,7 +197,7 @@ class OnlineTransducerNeMoModel::Impl {
|
|||||||
|
|
||||||
int32_t VocabSize() const { return vocab_size_; }
|
int32_t VocabSize() const { return vocab_size_; }
|
||||||
|
|
||||||
OrtAllocator *Allocator() const { return allocator_; }
|
OrtAllocator *Allocator() { return allocator_; }
|
||||||
|
|
||||||
std::string FeatureNormalizationMethod() const { return normalize_type_; }
|
std::string FeatureNormalizationMethod() const { return normalize_type_; }
|
||||||
|
|
||||||
@@ -224,6 +224,8 @@ class OnlineTransducerNeMoModel::Impl {
|
|||||||
|
|
||||||
std::vector<Ort::Value> ans;
|
std::vector<Ort::Value> ans;
|
||||||
|
|
||||||
|
auto allocator = const_cast<Impl *>(this)->allocator_;
|
||||||
|
|
||||||
// stack cache_last_channel
|
// stack cache_last_channel
|
||||||
std::vector<const Ort::Value *> buf(batch_size);
|
std::vector<const Ort::Value *> buf(batch_size);
|
||||||
|
|
||||||
@@ -239,9 +241,9 @@ class OnlineTransducerNeMoModel::Impl {
|
|||||||
|
|
||||||
Ort::Value c{nullptr};
|
Ort::Value c{nullptr};
|
||||||
if (i == 2) {
|
if (i == 2) {
|
||||||
c = Cat<int64_t>(allocator_, buf, 0);
|
c = Cat<int64_t>(allocator, buf, 0);
|
||||||
} else {
|
} else {
|
||||||
c = Cat(allocator_, buf, 0);
|
c = Cat(allocator, buf, 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
ans.push_back(std::move(c));
|
ans.push_back(std::move(c));
|
||||||
@@ -251,7 +253,7 @@ class OnlineTransducerNeMoModel::Impl {
|
|||||||
}
|
}
|
||||||
|
|
||||||
std::vector<std::vector<Ort::Value>> UnStackStates(
|
std::vector<std::vector<Ort::Value>> UnStackStates(
|
||||||
std::vector<Ort::Value> states) const {
|
std::vector<Ort::Value> states) {
|
||||||
assert(states.size() == 3);
|
assert(states.size() == 3);
|
||||||
|
|
||||||
std::vector<std::vector<Ort::Value>> ans;
|
std::vector<std::vector<Ort::Value>> ans;
|
||||||
|
|||||||
@@ -101,7 +101,7 @@ class OnlineWenetCtcModel::Impl {
|
|||||||
return config_.wenet_ctc.chunk_size * subsampling_factor_;
|
return config_.wenet_ctc.chunk_size * subsampling_factor_;
|
||||||
}
|
}
|
||||||
|
|
||||||
OrtAllocator *Allocator() const { return allocator_; }
|
OrtAllocator *Allocator() { return allocator_; }
|
||||||
|
|
||||||
// Return a vector containing 3 tensors
|
// Return a vector containing 3 tensors
|
||||||
// - attn_cache
|
// - attn_cache
|
||||||
|
|||||||
@@ -179,12 +179,15 @@ std::vector<Ort::Value> OnlineZipformerTransducerModel::StackStates(
|
|||||||
std::vector<Ort::Value> ans;
|
std::vector<Ort::Value> ans;
|
||||||
ans.reserve(states[0].size());
|
ans.reserve(states[0].size());
|
||||||
|
|
||||||
|
auto allocator =
|
||||||
|
const_cast<OnlineZipformerTransducerModel *>(this)->allocator_;
|
||||||
|
|
||||||
// cached_len
|
// cached_len
|
||||||
for (int32_t i = 0; i != num_encoders; ++i) {
|
for (int32_t i = 0; i != num_encoders; ++i) {
|
||||||
for (int32_t n = 0; n != batch_size; ++n) {
|
for (int32_t n = 0; n != batch_size; ++n) {
|
||||||
buf[n] = &states[n][i];
|
buf[n] = &states[n][i];
|
||||||
}
|
}
|
||||||
auto v = Cat<int64_t>(allocator_, buf, 1); // (num_layers, 1)
|
auto v = Cat<int64_t>(allocator, buf, 1); // (num_layers, 1)
|
||||||
ans.push_back(std::move(v));
|
ans.push_back(std::move(v));
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -193,7 +196,7 @@ std::vector<Ort::Value> OnlineZipformerTransducerModel::StackStates(
|
|||||||
for (int32_t n = 0; n != batch_size; ++n) {
|
for (int32_t n = 0; n != batch_size; ++n) {
|
||||||
buf[n] = &states[n][num_encoders + i];
|
buf[n] = &states[n][num_encoders + i];
|
||||||
}
|
}
|
||||||
auto v = Cat(allocator_, buf, 1); // (num_layers, 1, encoder_dims)
|
auto v = Cat(allocator, buf, 1); // (num_layers, 1, encoder_dims)
|
||||||
ans.push_back(std::move(v));
|
ans.push_back(std::move(v));
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -203,7 +206,7 @@ std::vector<Ort::Value> OnlineZipformerTransducerModel::StackStates(
|
|||||||
buf[n] = &states[n][num_encoders * 2 + i];
|
buf[n] = &states[n][num_encoders * 2 + i];
|
||||||
}
|
}
|
||||||
// (num_layers, left_context_len, 1, attention_dims)
|
// (num_layers, left_context_len, 1, attention_dims)
|
||||||
auto v = Cat(allocator_, buf, 2);
|
auto v = Cat(allocator, buf, 2);
|
||||||
ans.push_back(std::move(v));
|
ans.push_back(std::move(v));
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -213,7 +216,7 @@ std::vector<Ort::Value> OnlineZipformerTransducerModel::StackStates(
|
|||||||
buf[n] = &states[n][num_encoders * 3 + i];
|
buf[n] = &states[n][num_encoders * 3 + i];
|
||||||
}
|
}
|
||||||
// (num_layers, left_context_len, 1, attention_dims/2)
|
// (num_layers, left_context_len, 1, attention_dims/2)
|
||||||
auto v = Cat(allocator_, buf, 2);
|
auto v = Cat(allocator, buf, 2);
|
||||||
ans.push_back(std::move(v));
|
ans.push_back(std::move(v));
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -223,7 +226,7 @@ std::vector<Ort::Value> OnlineZipformerTransducerModel::StackStates(
|
|||||||
buf[n] = &states[n][num_encoders * 4 + i];
|
buf[n] = &states[n][num_encoders * 4 + i];
|
||||||
}
|
}
|
||||||
// (num_layers, left_context_len, 1, attention_dims/2)
|
// (num_layers, left_context_len, 1, attention_dims/2)
|
||||||
auto v = Cat(allocator_, buf, 2);
|
auto v = Cat(allocator, buf, 2);
|
||||||
ans.push_back(std::move(v));
|
ans.push_back(std::move(v));
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -233,7 +236,7 @@ std::vector<Ort::Value> OnlineZipformerTransducerModel::StackStates(
|
|||||||
buf[n] = &states[n][num_encoders * 5 + i];
|
buf[n] = &states[n][num_encoders * 5 + i];
|
||||||
}
|
}
|
||||||
// (num_layers, 1, encoder_dims, cnn_module_kernels-1)
|
// (num_layers, 1, encoder_dims, cnn_module_kernels-1)
|
||||||
auto v = Cat(allocator_, buf, 1);
|
auto v = Cat(allocator, buf, 1);
|
||||||
ans.push_back(std::move(v));
|
ans.push_back(std::move(v));
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -243,7 +246,7 @@ std::vector<Ort::Value> OnlineZipformerTransducerModel::StackStates(
|
|||||||
buf[n] = &states[n][num_encoders * 6 + i];
|
buf[n] = &states[n][num_encoders * 6 + i];
|
||||||
}
|
}
|
||||||
// (num_layers, 1, encoder_dims, cnn_module_kernels-1)
|
// (num_layers, 1, encoder_dims, cnn_module_kernels-1)
|
||||||
auto v = Cat(allocator_, buf, 1);
|
auto v = Cat(allocator, buf, 1);
|
||||||
ans.push_back(std::move(v));
|
ans.push_back(std::move(v));
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -258,12 +261,15 @@ OnlineZipformerTransducerModel::UnStackStates(
|
|||||||
int32_t batch_size = states[0].GetTensorTypeAndShapeInfo().GetShape()[1];
|
int32_t batch_size = states[0].GetTensorTypeAndShapeInfo().GetShape()[1];
|
||||||
int32_t num_encoders = num_encoder_layers_.size();
|
int32_t num_encoders = num_encoder_layers_.size();
|
||||||
|
|
||||||
|
auto allocator =
|
||||||
|
const_cast<OnlineZipformerTransducerModel *>(this)->allocator_;
|
||||||
|
|
||||||
std::vector<std::vector<Ort::Value>> ans;
|
std::vector<std::vector<Ort::Value>> ans;
|
||||||
ans.resize(batch_size);
|
ans.resize(batch_size);
|
||||||
|
|
||||||
// cached_len
|
// cached_len
|
||||||
for (int32_t i = 0; i != num_encoders; ++i) {
|
for (int32_t i = 0; i != num_encoders; ++i) {
|
||||||
auto v = Unbind<int64_t>(allocator_, &states[i], 1);
|
auto v = Unbind<int64_t>(allocator, &states[i], 1);
|
||||||
assert(v.size() == batch_size);
|
assert(v.size() == batch_size);
|
||||||
|
|
||||||
for (int32_t n = 0; n != batch_size; ++n) {
|
for (int32_t n = 0; n != batch_size; ++n) {
|
||||||
@@ -273,7 +279,7 @@ OnlineZipformerTransducerModel::UnStackStates(
|
|||||||
|
|
||||||
// cached_avg
|
// cached_avg
|
||||||
for (int32_t i = num_encoders; i != 2 * num_encoders; ++i) {
|
for (int32_t i = num_encoders; i != 2 * num_encoders; ++i) {
|
||||||
auto v = Unbind(allocator_, &states[i], 1);
|
auto v = Unbind(allocator, &states[i], 1);
|
||||||
assert(v.size() == batch_size);
|
assert(v.size() == batch_size);
|
||||||
|
|
||||||
for (int32_t n = 0; n != batch_size; ++n) {
|
for (int32_t n = 0; n != batch_size; ++n) {
|
||||||
@@ -283,7 +289,7 @@ OnlineZipformerTransducerModel::UnStackStates(
|
|||||||
|
|
||||||
// cached_key
|
// cached_key
|
||||||
for (int32_t i = 2 * num_encoders; i != 3 * num_encoders; ++i) {
|
for (int32_t i = 2 * num_encoders; i != 3 * num_encoders; ++i) {
|
||||||
auto v = Unbind(allocator_, &states[i], 2);
|
auto v = Unbind(allocator, &states[i], 2);
|
||||||
assert(v.size() == batch_size);
|
assert(v.size() == batch_size);
|
||||||
|
|
||||||
for (int32_t n = 0; n != batch_size; ++n) {
|
for (int32_t n = 0; n != batch_size; ++n) {
|
||||||
@@ -293,7 +299,7 @@ OnlineZipformerTransducerModel::UnStackStates(
|
|||||||
|
|
||||||
// cached_val
|
// cached_val
|
||||||
for (int32_t i = 3 * num_encoders; i != 4 * num_encoders; ++i) {
|
for (int32_t i = 3 * num_encoders; i != 4 * num_encoders; ++i) {
|
||||||
auto v = Unbind(allocator_, &states[i], 2);
|
auto v = Unbind(allocator, &states[i], 2);
|
||||||
assert(v.size() == batch_size);
|
assert(v.size() == batch_size);
|
||||||
|
|
||||||
for (int32_t n = 0; n != batch_size; ++n) {
|
for (int32_t n = 0; n != batch_size; ++n) {
|
||||||
@@ -303,7 +309,7 @@ OnlineZipformerTransducerModel::UnStackStates(
|
|||||||
|
|
||||||
// cached_val2
|
// cached_val2
|
||||||
for (int32_t i = 4 * num_encoders; i != 5 * num_encoders; ++i) {
|
for (int32_t i = 4 * num_encoders; i != 5 * num_encoders; ++i) {
|
||||||
auto v = Unbind(allocator_, &states[i], 2);
|
auto v = Unbind(allocator, &states[i], 2);
|
||||||
assert(v.size() == batch_size);
|
assert(v.size() == batch_size);
|
||||||
|
|
||||||
for (int32_t n = 0; n != batch_size; ++n) {
|
for (int32_t n = 0; n != batch_size; ++n) {
|
||||||
@@ -313,7 +319,7 @@ OnlineZipformerTransducerModel::UnStackStates(
|
|||||||
|
|
||||||
// cached_conv1
|
// cached_conv1
|
||||||
for (int32_t i = 5 * num_encoders; i != 6 * num_encoders; ++i) {
|
for (int32_t i = 5 * num_encoders; i != 6 * num_encoders; ++i) {
|
||||||
auto v = Unbind(allocator_, &states[i], 1);
|
auto v = Unbind(allocator, &states[i], 1);
|
||||||
assert(v.size() == batch_size);
|
assert(v.size() == batch_size);
|
||||||
|
|
||||||
for (int32_t n = 0; n != batch_size; ++n) {
|
for (int32_t n = 0; n != batch_size; ++n) {
|
||||||
@@ -323,7 +329,7 @@ OnlineZipformerTransducerModel::UnStackStates(
|
|||||||
|
|
||||||
// cached_conv2
|
// cached_conv2
|
||||||
for (int32_t i = 6 * num_encoders; i != 7 * num_encoders; ++i) {
|
for (int32_t i = 6 * num_encoders; i != 7 * num_encoders; ++i) {
|
||||||
auto v = Unbind(allocator_, &states[i], 1);
|
auto v = Unbind(allocator, &states[i], 1);
|
||||||
assert(v.size() == batch_size);
|
assert(v.size() == batch_size);
|
||||||
|
|
||||||
for (int32_t n = 0; n != batch_size; ++n) {
|
for (int32_t n = 0; n != batch_size; ++n) {
|
||||||
|
|||||||
@@ -70,7 +70,7 @@ class OnlineZipformer2CtcModel::Impl {
|
|||||||
|
|
||||||
int32_t ChunkShift() const { return decode_chunk_len_; }
|
int32_t ChunkShift() const { return decode_chunk_len_; }
|
||||||
|
|
||||||
OrtAllocator *Allocator() const { return allocator_; }
|
OrtAllocator *Allocator() { return allocator_; }
|
||||||
|
|
||||||
// Return a vector containing 3 tensors
|
// Return a vector containing 3 tensors
|
||||||
// - attn_cache
|
// - attn_cache
|
||||||
@@ -86,7 +86,7 @@ class OnlineZipformer2CtcModel::Impl {
|
|||||||
}
|
}
|
||||||
|
|
||||||
std::vector<Ort::Value> StackStates(
|
std::vector<Ort::Value> StackStates(
|
||||||
std::vector<std::vector<Ort::Value>> states) const {
|
std::vector<std::vector<Ort::Value>> states) {
|
||||||
int32_t batch_size = static_cast<int32_t>(states.size());
|
int32_t batch_size = static_cast<int32_t>(states.size());
|
||||||
|
|
||||||
std::vector<const Ort::Value *> buf(batch_size);
|
std::vector<const Ort::Value *> buf(batch_size);
|
||||||
@@ -159,7 +159,7 @@ class OnlineZipformer2CtcModel::Impl {
|
|||||||
}
|
}
|
||||||
|
|
||||||
std::vector<std::vector<Ort::Value>> UnStackStates(
|
std::vector<std::vector<Ort::Value>> UnStackStates(
|
||||||
std::vector<Ort::Value> states) const {
|
std::vector<Ort::Value> states) {
|
||||||
int32_t m = std::accumulate(num_encoder_layers_.begin(),
|
int32_t m = std::accumulate(num_encoder_layers_.begin(),
|
||||||
num_encoder_layers_.end(), 0);
|
num_encoder_layers_.end(), 0);
|
||||||
assert(states.size() == m * 6 + 2);
|
assert(states.size() == m * 6 + 2);
|
||||||
|
|||||||
@@ -185,6 +185,9 @@ std::vector<Ort::Value> OnlineZipformer2TransducerModel::StackStates(
|
|||||||
|
|
||||||
std::vector<const Ort::Value *> buf(batch_size);
|
std::vector<const Ort::Value *> buf(batch_size);
|
||||||
|
|
||||||
|
auto allocator =
|
||||||
|
const_cast<OnlineZipformer2TransducerModel *>(this)->allocator_;
|
||||||
|
|
||||||
std::vector<Ort::Value> ans;
|
std::vector<Ort::Value> ans;
|
||||||
int32_t num_states = static_cast<int32_t>(states[0].size());
|
int32_t num_states = static_cast<int32_t>(states[0].size());
|
||||||
ans.reserve(num_states);
|
ans.reserve(num_states);
|
||||||
@@ -194,42 +197,42 @@ std::vector<Ort::Value> OnlineZipformer2TransducerModel::StackStates(
|
|||||||
for (int32_t n = 0; n != batch_size; ++n) {
|
for (int32_t n = 0; n != batch_size; ++n) {
|
||||||
buf[n] = &states[n][6 * i];
|
buf[n] = &states[n][6 * i];
|
||||||
}
|
}
|
||||||
auto v = Cat(allocator_, buf, 1);
|
auto v = Cat(allocator, buf, 1);
|
||||||
ans.push_back(std::move(v));
|
ans.push_back(std::move(v));
|
||||||
}
|
}
|
||||||
{
|
{
|
||||||
for (int32_t n = 0; n != batch_size; ++n) {
|
for (int32_t n = 0; n != batch_size; ++n) {
|
||||||
buf[n] = &states[n][6 * i + 1];
|
buf[n] = &states[n][6 * i + 1];
|
||||||
}
|
}
|
||||||
auto v = Cat(allocator_, buf, 1);
|
auto v = Cat(allocator, buf, 1);
|
||||||
ans.push_back(std::move(v));
|
ans.push_back(std::move(v));
|
||||||
}
|
}
|
||||||
{
|
{
|
||||||
for (int32_t n = 0; n != batch_size; ++n) {
|
for (int32_t n = 0; n != batch_size; ++n) {
|
||||||
buf[n] = &states[n][6 * i + 2];
|
buf[n] = &states[n][6 * i + 2];
|
||||||
}
|
}
|
||||||
auto v = Cat(allocator_, buf, 1);
|
auto v = Cat(allocator, buf, 1);
|
||||||
ans.push_back(std::move(v));
|
ans.push_back(std::move(v));
|
||||||
}
|
}
|
||||||
{
|
{
|
||||||
for (int32_t n = 0; n != batch_size; ++n) {
|
for (int32_t n = 0; n != batch_size; ++n) {
|
||||||
buf[n] = &states[n][6 * i + 3];
|
buf[n] = &states[n][6 * i + 3];
|
||||||
}
|
}
|
||||||
auto v = Cat(allocator_, buf, 1);
|
auto v = Cat(allocator, buf, 1);
|
||||||
ans.push_back(std::move(v));
|
ans.push_back(std::move(v));
|
||||||
}
|
}
|
||||||
{
|
{
|
||||||
for (int32_t n = 0; n != batch_size; ++n) {
|
for (int32_t n = 0; n != batch_size; ++n) {
|
||||||
buf[n] = &states[n][6 * i + 4];
|
buf[n] = &states[n][6 * i + 4];
|
||||||
}
|
}
|
||||||
auto v = Cat(allocator_, buf, 0);
|
auto v = Cat(allocator, buf, 0);
|
||||||
ans.push_back(std::move(v));
|
ans.push_back(std::move(v));
|
||||||
}
|
}
|
||||||
{
|
{
|
||||||
for (int32_t n = 0; n != batch_size; ++n) {
|
for (int32_t n = 0; n != batch_size; ++n) {
|
||||||
buf[n] = &states[n][6 * i + 5];
|
buf[n] = &states[n][6 * i + 5];
|
||||||
}
|
}
|
||||||
auto v = Cat(allocator_, buf, 0);
|
auto v = Cat(allocator, buf, 0);
|
||||||
ans.push_back(std::move(v));
|
ans.push_back(std::move(v));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -238,7 +241,7 @@ std::vector<Ort::Value> OnlineZipformer2TransducerModel::StackStates(
|
|||||||
for (int32_t n = 0; n != batch_size; ++n) {
|
for (int32_t n = 0; n != batch_size; ++n) {
|
||||||
buf[n] = &states[n][num_states - 2];
|
buf[n] = &states[n][num_states - 2];
|
||||||
}
|
}
|
||||||
auto v = Cat(allocator_, buf, 0);
|
auto v = Cat(allocator, buf, 0);
|
||||||
ans.push_back(std::move(v));
|
ans.push_back(std::move(v));
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -246,7 +249,7 @@ std::vector<Ort::Value> OnlineZipformer2TransducerModel::StackStates(
|
|||||||
for (int32_t n = 0; n != batch_size; ++n) {
|
for (int32_t n = 0; n != batch_size; ++n) {
|
||||||
buf[n] = &states[n][num_states - 1];
|
buf[n] = &states[n][num_states - 1];
|
||||||
}
|
}
|
||||||
auto v = Cat<int64_t>(allocator_, buf, 0);
|
auto v = Cat<int64_t>(allocator, buf, 0);
|
||||||
ans.push_back(std::move(v));
|
ans.push_back(std::move(v));
|
||||||
}
|
}
|
||||||
return ans;
|
return ans;
|
||||||
@@ -261,12 +264,15 @@ OnlineZipformer2TransducerModel::UnStackStates(
|
|||||||
|
|
||||||
int32_t batch_size = states[0].GetTensorTypeAndShapeInfo().GetShape()[1];
|
int32_t batch_size = states[0].GetTensorTypeAndShapeInfo().GetShape()[1];
|
||||||
|
|
||||||
|
auto allocator =
|
||||||
|
const_cast<OnlineZipformer2TransducerModel *>(this)->allocator_;
|
||||||
|
|
||||||
std::vector<std::vector<Ort::Value>> ans;
|
std::vector<std::vector<Ort::Value>> ans;
|
||||||
ans.resize(batch_size);
|
ans.resize(batch_size);
|
||||||
|
|
||||||
for (int32_t i = 0; i != m; ++i) {
|
for (int32_t i = 0; i != m; ++i) {
|
||||||
{
|
{
|
||||||
auto v = Unbind(allocator_, &states[i * 6], 1);
|
auto v = Unbind(allocator, &states[i * 6], 1);
|
||||||
assert(static_cast<int32_t>(v.size()) == batch_size);
|
assert(static_cast<int32_t>(v.size()) == batch_size);
|
||||||
|
|
||||||
for (int32_t n = 0; n != batch_size; ++n) {
|
for (int32_t n = 0; n != batch_size; ++n) {
|
||||||
@@ -274,7 +280,7 @@ OnlineZipformer2TransducerModel::UnStackStates(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
{
|
{
|
||||||
auto v = Unbind(allocator_, &states[i * 6 + 1], 1);
|
auto v = Unbind(allocator, &states[i * 6 + 1], 1);
|
||||||
assert(static_cast<int32_t>(v.size()) == batch_size);
|
assert(static_cast<int32_t>(v.size()) == batch_size);
|
||||||
|
|
||||||
for (int32_t n = 0; n != batch_size; ++n) {
|
for (int32_t n = 0; n != batch_size; ++n) {
|
||||||
@@ -282,7 +288,7 @@ OnlineZipformer2TransducerModel::UnStackStates(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
{
|
{
|
||||||
auto v = Unbind(allocator_, &states[i * 6 + 2], 1);
|
auto v = Unbind(allocator, &states[i * 6 + 2], 1);
|
||||||
assert(static_cast<int32_t>(v.size()) == batch_size);
|
assert(static_cast<int32_t>(v.size()) == batch_size);
|
||||||
|
|
||||||
for (int32_t n = 0; n != batch_size; ++n) {
|
for (int32_t n = 0; n != batch_size; ++n) {
|
||||||
@@ -290,7 +296,7 @@ OnlineZipformer2TransducerModel::UnStackStates(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
{
|
{
|
||||||
auto v = Unbind(allocator_, &states[i * 6 + 3], 1);
|
auto v = Unbind(allocator, &states[i * 6 + 3], 1);
|
||||||
assert(static_cast<int32_t>(v.size()) == batch_size);
|
assert(static_cast<int32_t>(v.size()) == batch_size);
|
||||||
|
|
||||||
for (int32_t n = 0; n != batch_size; ++n) {
|
for (int32_t n = 0; n != batch_size; ++n) {
|
||||||
@@ -298,7 +304,7 @@ OnlineZipformer2TransducerModel::UnStackStates(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
{
|
{
|
||||||
auto v = Unbind(allocator_, &states[i * 6 + 4], 0);
|
auto v = Unbind(allocator, &states[i * 6 + 4], 0);
|
||||||
assert(static_cast<int32_t>(v.size()) == batch_size);
|
assert(static_cast<int32_t>(v.size()) == batch_size);
|
||||||
|
|
||||||
for (int32_t n = 0; n != batch_size; ++n) {
|
for (int32_t n = 0; n != batch_size; ++n) {
|
||||||
@@ -306,7 +312,7 @@ OnlineZipformer2TransducerModel::UnStackStates(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
{
|
{
|
||||||
auto v = Unbind(allocator_, &states[i * 6 + 5], 0);
|
auto v = Unbind(allocator, &states[i * 6 + 5], 0);
|
||||||
assert(static_cast<int32_t>(v.size()) == batch_size);
|
assert(static_cast<int32_t>(v.size()) == batch_size);
|
||||||
|
|
||||||
for (int32_t n = 0; n != batch_size; ++n) {
|
for (int32_t n = 0; n != batch_size; ++n) {
|
||||||
@@ -316,7 +322,7 @@ OnlineZipformer2TransducerModel::UnStackStates(
|
|||||||
}
|
}
|
||||||
|
|
||||||
{
|
{
|
||||||
auto v = Unbind(allocator_, &states[m * 6], 0);
|
auto v = Unbind(allocator, &states[m * 6], 0);
|
||||||
assert(static_cast<int32_t>(v.size()) == batch_size);
|
assert(static_cast<int32_t>(v.size()) == batch_size);
|
||||||
|
|
||||||
for (int32_t n = 0; n != batch_size; ++n) {
|
for (int32_t n = 0; n != batch_size; ++n) {
|
||||||
@@ -324,7 +330,7 @@ OnlineZipformer2TransducerModel::UnStackStates(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
{
|
{
|
||||||
auto v = Unbind<int64_t>(allocator_, &states[m * 6 + 1], 0);
|
auto v = Unbind<int64_t>(allocator, &states[m * 6 + 1], 0);
|
||||||
assert(static_cast<int32_t>(v.size()) == batch_size);
|
assert(static_cast<int32_t>(v.size()) == batch_size);
|
||||||
|
|
||||||
for (int32_t n = 0; n != batch_size; ++n) {
|
for (int32_t n = 0; n != batch_size; ++n) {
|
||||||
|
|||||||
@@ -21,6 +21,36 @@
|
|||||||
|
|
||||||
namespace sherpa_onnx {
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
// Returns the name of input number `index` of `sess` as an owning
// std::string.
//
// `allocator` is handed to onnxruntime for the temporary copy of the name.
static std::string GetInputName(Ort::Session *sess, size_t index,
                                OrtAllocator *allocator) {
// Note(fangjun): We only tested 1.17.1 and 1.11.0
// For other versions, we may need to change it
#if ORT_API_VERSION >= 17
  auto name_holder = sess->GetInputNameAllocated(index, allocator);
  return std::string(name_holder.get());
#else
  auto raw_name = sess->GetInputName(index, allocator);
  // Copy the characters first, then hand the buffer back to the allocator.
  std::string name(raw_name);
  allocator->Free(allocator, raw_name);
  return name;
#endif
}
|
||||||
|
|
||||||
|
// Returns the name of output number `index` of `sess` as an owning
// std::string.
//
// `allocator` is handed to onnxruntime for the temporary copy of the name.
static std::string GetOutputName(Ort::Session *sess, size_t index,
                                 OrtAllocator *allocator) {
// Note(fangjun): We only tested 1.17.1 and 1.11.0
// For other versions, we may need to change it
#if ORT_API_VERSION >= 17
  auto name_holder = sess->GetOutputNameAllocated(index, allocator);
  return std::string(name_holder.get());
#else
  auto raw_name = sess->GetOutputName(index, allocator);
  // Copy the characters first, then hand the buffer back to the allocator.
  std::string name(raw_name);
  allocator->Free(allocator, raw_name);
  return name;
#endif
}
|
||||||
|
|
||||||
void GetInputNames(Ort::Session *sess, std::vector<std::string> *input_names,
|
void GetInputNames(Ort::Session *sess, std::vector<std::string> *input_names,
|
||||||
std::vector<const char *> *input_names_ptr) {
|
std::vector<const char *> *input_names_ptr) {
|
||||||
Ort::AllocatorWithDefaultOptions allocator;
|
Ort::AllocatorWithDefaultOptions allocator;
|
||||||
@@ -28,8 +58,7 @@ void GetInputNames(Ort::Session *sess, std::vector<std::string> *input_names,
|
|||||||
input_names->resize(node_count);
|
input_names->resize(node_count);
|
||||||
input_names_ptr->resize(node_count);
|
input_names_ptr->resize(node_count);
|
||||||
for (size_t i = 0; i != node_count; ++i) {
|
for (size_t i = 0; i != node_count; ++i) {
|
||||||
auto tmp = sess->GetInputNameAllocated(i, allocator);
|
(*input_names)[i] = GetInputName(sess, i, allocator);
|
||||||
(*input_names)[i] = tmp.get();
|
|
||||||
(*input_names_ptr)[i] = (*input_names)[i].c_str();
|
(*input_names_ptr)[i] = (*input_names)[i].c_str();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -41,8 +70,7 @@ void GetOutputNames(Ort::Session *sess, std::vector<std::string> *output_names,
|
|||||||
output_names->resize(node_count);
|
output_names->resize(node_count);
|
||||||
output_names_ptr->resize(node_count);
|
output_names_ptr->resize(node_count);
|
||||||
for (size_t i = 0; i != node_count; ++i) {
|
for (size_t i = 0; i != node_count; ++i) {
|
||||||
auto tmp = sess->GetOutputNameAllocated(i, allocator);
|
(*output_names)[i] = GetOutputName(sess, i, allocator);
|
||||||
(*output_names)[i] = tmp.get();
|
|
||||||
(*output_names_ptr)[i] = (*output_names)[i].c_str();
|
(*output_names_ptr)[i] = (*output_names)[i].c_str();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -78,12 +106,24 @@ Ort::Value GetEncoderOutFrame(OrtAllocator *allocator, Ort::Value *encoder_out,
|
|||||||
|
|
||||||
// Writes every custom metadata entry of `meta_data` to `os`, one
// "key=value" pair per line.
void PrintModelMetadata(std::ostream &os, const Ort::ModelMetadata &meta_data) {
  Ort::AllocatorWithDefaultOptions allocator;
#if ORT_API_VERSION >= 17
  std::vector<Ort::AllocatedStringPtr> all_keys =
      meta_data.GetCustomMetadataMapKeysAllocated(allocator);
  for (size_t k = 0; k != all_keys.size(); ++k) {
    const char *name = all_keys[k].get();
    auto value = meta_data.LookupCustomMetadataMapAllocated(name, allocator);
    os << name << "=" << value.get() << "\n";
  }
#else
  int64_t num_keys = 0;
  char **keys = meta_data.GetCustomMetadataMapKeys(allocator, num_keys);
  for (int32_t i = 0; i < num_keys; ++i) {
    char *name = keys[i];
    auto value = LookupCustomModelMetaData(meta_data, name, allocator);
    os << name << "=" << value << "\n";
    // Each key string is released individually ...
    allocator.Free(name);
  }

  // ... and then the array that held the key pointers.
  allocator.Free(keys);
#endif
}
|
||||||
|
|
||||||
Ort::Value Clone(OrtAllocator *allocator, const Ort::Value *v) {
|
Ort::Value Clone(OrtAllocator *allocator, const Ort::Value *v) {
|
||||||
@@ -361,4 +401,20 @@ std::vector<Ort::Value> Convert(std::vector<CopyableOrtValue> values) {
|
|||||||
return ans;
|
return ans;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Looks up `key` in the custom metadata map of `meta_data` and returns the
// associated value as an owning std::string.
//
// @param meta_data  Model metadata to query.
// @param key        Zero-terminated name of the entry to look up.
// @param allocator  Allocator handed to onnxruntime for the temporary copy of
//                   the value.
// @return The value, or an empty string when the lookup yields a null
//         pointer. Callers rely on .empty() to detect a missing entry.
std::string LookupCustomModelMetaData(const Ort::ModelMetadata &meta_data,
                                      const char *key,
                                      OrtAllocator *allocator) {
// Note(fangjun): We only tested 1.17.1 and 1.11.0
// For other versions, we may need to change it
#if ORT_API_VERSION >= 17
  auto v = meta_data.LookupCustomMetadataMapAllocated(key, allocator);
  // Constructing a std::string from a null pointer is undefined behavior, so
  // a null result is mapped to "" explicitly.
  const char *p = v.get();
  return p != nullptr ? std::string(p) : std::string();
#else
  auto v = meta_data.LookupCustomMetadataMap(key, allocator);
  if (v == nullptr) {
    // Nothing was returned, hence nothing to copy or free.
    return std::string();
  }

  std::string ans = v;
  allocator->Free(allocator, v);
  return ans;
#endif
}
|
||||||
|
|
||||||
} // namespace sherpa_onnx
|
} // namespace sherpa_onnx
|
||||||
|
|||||||
@@ -59,6 +59,9 @@ void GetOutputNames(Ort::Session *sess, std::vector<std::string> *output_names,
|
|||||||
Ort::Value GetEncoderOutFrame(OrtAllocator *allocator, Ort::Value *encoder_out,
|
Ort::Value GetEncoderOutFrame(OrtAllocator *allocator, Ort::Value *encoder_out,
|
||||||
int32_t t);
|
int32_t t);
|
||||||
|
|
||||||
|
std::string LookupCustomModelMetaData(const Ort::ModelMetadata &meta_data,
|
||||||
|
const char *key, OrtAllocator *allocator);
|
||||||
|
|
||||||
void PrintModelMetadata(std::ostream &os,
|
void PrintModelMetadata(std::ostream &os,
|
||||||
const Ort::ModelMetadata &meta_data); // NOLINT
|
const Ort::ModelMetadata &meta_data); // NOLINT
|
||||||
|
|
||||||
|
|||||||
@@ -60,6 +60,7 @@ Ort::SessionOptions GetSessionOptionsImpl(
|
|||||||
case Provider::kCPU:
|
case Provider::kCPU:
|
||||||
break; // nothing to do for the CPU provider
|
break; // nothing to do for the CPU provider
|
||||||
case Provider::kXnnpack: {
|
case Provider::kXnnpack: {
|
||||||
|
#if ORT_API_VERSION >= 17
|
||||||
if (std::find(available_providers.begin(), available_providers.end(),
|
if (std::find(available_providers.begin(), available_providers.end(),
|
||||||
"XnnpackExecutionProvider") != available_providers.end()) {
|
"XnnpackExecutionProvider") != available_providers.end()) {
|
||||||
sess_opts.AppendExecutionProvider("XNNPACK");
|
sess_opts.AppendExecutionProvider("XNNPACK");
|
||||||
@@ -67,6 +68,11 @@ Ort::SessionOptions GetSessionOptionsImpl(
|
|||||||
SHERPA_ONNX_LOGE("Available providers: %s. Fallback to cpu!",
|
SHERPA_ONNX_LOGE("Available providers: %s. Fallback to cpu!",
|
||||||
os.str().c_str());
|
os.str().c_str());
|
||||||
}
|
}
|
||||||
|
#else
|
||||||
|
SHERPA_ONNX_LOGE(
|
||||||
|
"Does not support xnnpack for onnxruntime: %d. Fallback to cpu!",
|
||||||
|
static_cast<int32_t>(ORT_API_VERSION));
|
||||||
|
#endif
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
case Provider::kTRT: {
|
case Provider::kTRT: {
|
||||||
|
|||||||
@@ -40,8 +40,8 @@ static ModelType GetModelType(char *model_data, size_t model_data_length,
|
|||||||
|
|
||||||
Ort::AllocatorWithDefaultOptions allocator;
|
Ort::AllocatorWithDefaultOptions allocator;
|
||||||
auto model_type =
|
auto model_type =
|
||||||
meta_data.LookupCustomMetadataMapAllocated("framework", allocator);
|
LookupCustomModelMetaData(meta_data, "framework", allocator);
|
||||||
if (!model_type) {
|
if (model_type.empty()) {
|
||||||
SHERPA_ONNX_LOGE(
|
SHERPA_ONNX_LOGE(
|
||||||
"No model_type in the metadata!\n"
|
"No model_type in the metadata!\n"
|
||||||
"Please make sure you have added metadata to the model.\n\n"
|
"Please make sure you have added metadata to the model.\n\n"
|
||||||
@@ -52,14 +52,14 @@ static ModelType GetModelType(char *model_data, size_t model_data_length,
|
|||||||
return ModelType::kUnknown;
|
return ModelType::kUnknown;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (model_type.get() == std::string("wespeaker")) {
|
if (model_type == "wespeaker") {
|
||||||
return ModelType::kWeSpeaker;
|
return ModelType::kWeSpeaker;
|
||||||
} else if (model_type.get() == std::string("3d-speaker")) {
|
} else if (model_type == "3d-speaker") {
|
||||||
return ModelType::k3dSpeaker;
|
return ModelType::k3dSpeaker;
|
||||||
} else if (model_type.get() == std::string("nemo")) {
|
} else if (model_type == "nemo") {
|
||||||
return ModelType::kNeMo;
|
return ModelType::kNeMo;
|
||||||
} else {
|
} else {
|
||||||
SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.get());
|
SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.c_str());
|
||||||
return ModelType::kUnknown;
|
return ModelType::kUnknown;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -53,7 +53,7 @@ class SpeakerEmbeddingExtractorNeMoModel::Impl {
|
|||||||
return std::move(outputs[0]);
|
return std::move(outputs[0]);
|
||||||
}
|
}
|
||||||
|
|
||||||
OrtAllocator *Allocator() const { return allocator_; }
|
OrtAllocator *Allocator() { return allocator_; }
|
||||||
|
|
||||||
const SpeakerEmbeddingExtractorNeMoModelMetaData &GetMetaData() const {
|
const SpeakerEmbeddingExtractorNeMoModelMetaData &GetMetaData() const {
|
||||||
return meta_data_;
|
return meta_data_;
|
||||||
|
|||||||
@@ -42,8 +42,8 @@ static ModelType GetModelType(char *model_data, size_t model_data_length,
|
|||||||
|
|
||||||
Ort::AllocatorWithDefaultOptions allocator;
|
Ort::AllocatorWithDefaultOptions allocator;
|
||||||
auto model_type =
|
auto model_type =
|
||||||
meta_data.LookupCustomMetadataMapAllocated("model_type", allocator);
|
LookupCustomModelMetaData(meta_data, "model_type", allocator);
|
||||||
if (!model_type) {
|
if (model_type.empty()) {
|
||||||
SHERPA_ONNX_LOGE(
|
SHERPA_ONNX_LOGE(
|
||||||
"No model_type in the metadata!\n"
|
"No model_type in the metadata!\n"
|
||||||
"Please make sure you have added metadata to the model.\n\n"
|
"Please make sure you have added metadata to the model.\n\n"
|
||||||
@@ -54,11 +54,10 @@ static ModelType GetModelType(char *model_data, size_t model_data_length,
|
|||||||
return ModelType::kUnknown;
|
return ModelType::kUnknown;
|
||||||
}
|
}
|
||||||
|
|
||||||
auto model_type_str = std::string(model_type.get());
|
if (model_type.find("whisper") == 0) {
|
||||||
if (model_type_str.find("whisper") == 0) {
|
|
||||||
return ModelType::kWhisper;
|
return ModelType::kWhisper;
|
||||||
} else {
|
} else {
|
||||||
SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.get());
|
SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.c_str());
|
||||||
return ModelType::kUnknown;
|
return ModelType::kUnknown;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -29,20 +29,19 @@ namespace {
|
|||||||
// Characters that the trimming helpers below strip by default.
const char *ws = " \t\n\r\f\v";

// Removes, in place, every trailing character of *s that occurs in t.
inline void TrimRight(std::string *s, const char *t = ws) {
  // If no character survives, find_last_not_of() yields npos; npos + 1 wraps
  // around to 0, so the whole string is erased.
  const std::string::size_type last_kept = s->find_last_not_of(t);
  s->erase(last_kept + 1);
}

// Removes, in place, every leading character of *s that occurs in t.
inline void TrimLeft(std::string *s, const char *t = ws) {
  // If no character survives, the count is npos and erase() removes
  // everything up to the end of the string.
  const std::string::size_type first_kept = s->find_first_not_of(t);
  s->erase(0, first_kept);
}

// Removes characters in t from both ends of *s: the right end first, then
// the left end.
inline void Trim(std::string *s, const char *t = ws) {
  TrimRight(s, t);
  TrimLeft(s, t);
}
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
@@ -56,7 +55,7 @@ std::unordered_map<std::string, int32_t> ReadTokens(
|
|||||||
std::string sym;
|
std::string sym;
|
||||||
int32_t id = -1;
|
int32_t id = -1;
|
||||||
while (std::getline(is, line)) {
|
while (std::getline(is, line)) {
|
||||||
Trim(line);
|
Trim(&line);
|
||||||
std::istringstream iss(line);
|
std::istringstream iss(line);
|
||||||
iss >> sym;
|
iss >> sym;
|
||||||
if (iss.eof()) {
|
if (iss.eof()) {
|
||||||
|
|||||||
Reference in New Issue
Block a user