Support streaming zipformer CTC (#496)

* Support streaming zipformer CTC

* test online zipformer2 CTC

* Update doc of sherpa-onnx.cc

* Add Python APIs for streaming zipformer2 ctc

* Add Python API examples for streaming zipformer2 ctc

* Swift API for streaming zipformer2 CTC

* NodeJS API for streaming zipformer2 CTC

* Kotlin API for streaming zipformer2 CTC

* Golang API for streaming zipformer2 CTC

* C# API for streaming zipformer2 CTC

* Release v1.9.6
This commit is contained in:
Fangjun Kuang
2023-12-22 13:46:33 +08:00
committed by GitHub
parent 7634f5f034
commit e475e750ac
70 changed files with 1517 additions and 211 deletions

View File

@@ -51,6 +51,13 @@ rm sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20.tar.bz2
node ./test-online-transducer.js node ./test-online-transducer.js
rm -rf sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20 rm -rf sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20
curl -LS -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2
tar xvf sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2
rm sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2
node ./test-online-zipformer2-ctc.js
rm -rf sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13
# offline tts # offline tts
curl -LS -O https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/vits-piper-en_US-amy-low.tar.bz2 curl -LS -O https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/vits-piper-en_US-amy-low.tar.bz2

View File

@@ -13,6 +13,37 @@ echo "PATH: $PATH"
which $EXE which $EXE
log "------------------------------------------------------------"
log "Run streaming Zipformer2 CTC "
log "------------------------------------------------------------"
url=https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2
repo=$(basename -s .tar.bz2 $url)
curl -SL -O $url
tar xvf $repo.tar.bz2
rm $repo.tar.bz2
log "test fp32"
time $EXE \
--debug=1 \
--zipformer2-ctc-model=$repo/ctc-epoch-20-avg-1-chunk-16-left-128.onnx \
--tokens=$repo/tokens.txt \
$repo/test_wavs/DEV_T0000000000.wav \
$repo/test_wavs/DEV_T0000000001.wav \
$repo/test_wavs/DEV_T0000000002.wav
log "test int8"
time $EXE \
--debug=1 \
--zipformer2-ctc-model=$repo/ctc-epoch-20-avg-1-chunk-16-left-128.int8.onnx \
--tokens=$repo/tokens.txt \
$repo/test_wavs/DEV_T0000000000.wav \
$repo/test_wavs/DEV_T0000000001.wav \
$repo/test_wavs/DEV_T0000000002.wav
log "------------------------------------------------------------" log "------------------------------------------------------------"
log "Run streaming Conformer CTC from WeNet" log "Run streaming Conformer CTC from WeNet"
log "------------------------------------------------------------" log "------------------------------------------------------------"

View File

@@ -8,6 +8,27 @@ log() {
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
} }
mkdir -p /tmp/icefall-models
dir=/tmp/icefall-models
pushd $dir
wget -qq https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2
tar xvf sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2
rm sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2
popd
repo=$dir/sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13
python3 ./python-api-examples/online-decode-files.py \
--tokens=$repo/tokens.txt \
--zipformer2-ctc=$repo/ctc-epoch-20-avg-1-chunk-16-left-128.onnx \
$repo/test_wavs/DEV_T0000000000.wav \
$repo/test_wavs/DEV_T0000000001.wav \
$repo/test_wavs/DEV_T0000000002.wav
python3 sherpa-onnx/python/tests/test_offline_recognizer.py --verbose
rm -rf $dir/sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13
wenet_models=( wenet_models=(
sherpa-onnx-zh-wenet-aishell sherpa-onnx-zh-wenet-aishell
sherpa-onnx-zh-wenet-aishell2 sherpa-onnx-zh-wenet-aishell2
@@ -17,8 +38,6 @@ sherpa-onnx-en-wenet-librispeech
sherpa-onnx-en-wenet-gigaspeech sherpa-onnx-en-wenet-gigaspeech
) )
mkdir -p /tmp/icefall-models
dir=/tmp/icefall-models
for name in ${wenet_models[@]}; do for name in ${wenet_models[@]}; do
repo_url=https://huggingface.co/csukuangfj/$name repo_url=https://huggingface.co/csukuangfj/$name

View File

@@ -21,6 +21,9 @@ cat /Users/fangjun/Desktop/Obama.srt
./run-tts.sh ./run-tts.sh
ls -lh ls -lh
./run-decode-file.sh
rm decode-file
sed -i.bak '20d' ./decode-file.swift
./run-decode-file.sh ./run-decode-file.sh
./run-decode-file-non-streaming.sh ./run-decode-file-non-streaming.sh

View File

@@ -22,7 +22,7 @@ jobs:
- uses: actions/checkout@v4 - uses: actions/checkout@v4
- name: Setup Python ${{ matrix.python-version }} - name: Setup Python ${{ matrix.python-version }}
uses: actions/setup-python@v2 uses: actions/setup-python@v4
with: with:
python-version: ${{ matrix.python-version }} python-version: ${{ matrix.python-version }}

View File

@@ -22,7 +22,7 @@ jobs:
- uses: actions/checkout@v4 - uses: actions/checkout@v4
- name: Setup Python ${{ matrix.python-version }} - name: Setup Python ${{ matrix.python-version }}
uses: actions/setup-python@v2 uses: actions/setup-python@v4
with: with:
python-version: ${{ matrix.python-version }} python-version: ${{ matrix.python-version }}

View File

@@ -24,7 +24,7 @@ jobs:
- uses: actions/checkout@v4 - uses: actions/checkout@v4
- name: Setup Python ${{ matrix.python-version }} - name: Setup Python ${{ matrix.python-version }}
uses: actions/setup-python@v2 uses: actions/setup-python@v4
with: with:
python-version: ${{ matrix.python-version }} python-version: ${{ matrix.python-version }}

View File

@@ -107,6 +107,14 @@ jobs:
name: release-static name: release-static
path: build/bin/* path: build/bin/*
- name: Test online CTC
shell: bash
run: |
export PATH=$PWD/build/bin:$PATH
export EXE=sherpa-onnx
.github/scripts/test-online-ctc.sh
- name: Test offline Whisper - name: Test offline Whisper
shell: bash shell: bash
run: | run: |
@@ -117,14 +125,6 @@ jobs:
.github/scripts/test-offline-whisper.sh .github/scripts/test-offline-whisper.sh
- name: Test online CTC
shell: bash
run: |
export PATH=$PWD/build/bin:$PATH
export EXE=sherpa-onnx
.github/scripts/test-online-ctc.sh
- name: Test offline CTC - name: Test offline CTC
shell: bash shell: bash
run: | run: |

View File

@@ -25,7 +25,7 @@ jobs:
fetch-depth: 0 fetch-depth: 0
- name: Setup Python ${{ matrix.python-version }} - name: Setup Python ${{ matrix.python-version }}
uses: actions/setup-python@v2 uses: actions/setup-python@v4
with: with:
python-version: ${{ matrix.python-version }} python-version: ${{ matrix.python-version }}

View File

@@ -55,7 +55,7 @@ jobs:
key: ${{ matrix.os }}-python-${{ matrix.python-version }} key: ${{ matrix.os }}-python-${{ matrix.python-version }}
- name: Setup Python - name: Setup Python
uses: actions/setup-python@v2 uses: actions/setup-python@v4
with: with:
python-version: ${{ matrix.python-version }} python-version: ${{ matrix.python-version }}

View File

@@ -49,7 +49,7 @@ jobs:
fetch-depth: 0 fetch-depth: 0
- name: Setup Python ${{ matrix.python-version }} - name: Setup Python ${{ matrix.python-version }}
uses: actions/setup-python@v1 uses: actions/setup-python@v4
with: with:
python-version: ${{ matrix.python-version }} python-version: ${{ matrix.python-version }}

View File

@@ -29,7 +29,7 @@ jobs:
fetch-depth: 0 fetch-depth: 0
- name: Setup Python ${{ matrix.python-version }} - name: Setup Python ${{ matrix.python-version }}
uses: actions/setup-python@v2 uses: actions/setup-python@v4
with: with:
python-version: ${{ matrix.python-version }} python-version: ${{ matrix.python-version }}

View File

@@ -61,7 +61,7 @@ jobs:
strategy: strategy:
fail-fast: false fail-fast: false
matrix: matrix:
os: [ubuntu-latest, macos-latest] os: [ubuntu-latest, macos-latest] #, windows-latest]
python-version: ["3.8"] python-version: ["3.8"]
steps: steps:
@@ -70,7 +70,7 @@ jobs:
fetch-depth: 0 fetch-depth: 0
- name: Setup Python ${{ matrix.python-version }} - name: Setup Python ${{ matrix.python-version }}
uses: actions/setup-python@v2 uses: actions/setup-python@v4
with: with:
python-version: ${{ matrix.python-version }} python-version: ${{ matrix.python-version }}
@@ -143,6 +143,7 @@ jobs:
cd dotnet-examples/ cd dotnet-examples/
cd online-decode-files cd online-decode-files
./run-zipformer2-ctc.sh
./run-transducer.sh ./run-transducer.sh
./run-paraformer.sh ./run-paraformer.sh

View File

@@ -53,7 +53,7 @@ jobs:
mkdir build mkdir build
cd build cd build
cmake -DCMAKE_C_COMPILER_LAUNCHER=ccache -DCMAKE_CXX_COMPILER_LAUNCHER=ccache -DBUILD_SHARED_LIBS=ON -DSHERPA_ONNX_ENABLE_PORTAUDIO=OFF -DSHERPA_ONNX_ENABLE_WEBSOCKET=OFF .. cmake -DCMAKE_C_COMPILER_LAUNCHER=ccache -DCMAKE_CXX_COMPILER_LAUNCHER=ccache -DBUILD_SHARED_LIBS=ON -DSHERPA_ONNX_ENABLE_PORTAUDIO=OFF -DSHERPA_ONNX_ENABLE_WEBSOCKET=OFF ..
make -j make -j1
cp -v _deps/onnxruntime-src/lib/libonnxruntime*dylib ./lib/ cp -v _deps/onnxruntime-src/lib/libonnxruntime*dylib ./lib/
cd ../scripts/go/_internal/ cd ../scripts/go/_internal/
@@ -153,6 +153,14 @@ jobs:
git lfs install git lfs install
echo "Test zipformer2 CTC"
wget -qq https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2
tar xvf sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2
rm sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2
./run-zipformer2-ctc.sh
rm -rf sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13
echo "Test transducer" echo "Test transducer"
git clone https://huggingface.co/csukuangfj/sherpa-onnx-streaming-zipformer-en-2023-06-26 git clone https://huggingface.co/csukuangfj/sherpa-onnx-streaming-zipformer-en-2023-06-26
./run-transducer.sh ./run-transducer.sh

View File

@@ -34,7 +34,7 @@ jobs:
fetch-depth: 0 fetch-depth: 0
- name: Setup Python ${{ matrix.python-version }} - name: Setup Python ${{ matrix.python-version }}
uses: actions/setup-python@v2 uses: actions/setup-python@v4
with: with:
python-version: ${{ matrix.python-version }} python-version: ${{ matrix.python-version }}

View File

@@ -52,7 +52,7 @@ jobs:
ls -lh install/lib ls -lh install/lib
- name: Setup Python ${{ matrix.python-version }} - name: Setup Python ${{ matrix.python-version }}
uses: actions/setup-python@v2 uses: actions/setup-python@v4
with: with:
python-version: ${{ matrix.python-version }} python-version: ${{ matrix.python-version }}

View File

@@ -40,7 +40,7 @@ jobs:
fetch-depth: 0 fetch-depth: 0
- name: Setup Python ${{ matrix.python-version }} - name: Setup Python ${{ matrix.python-version }}
uses: actions/setup-python@v2 uses: actions/setup-python@v4
with: with:
python-version: ${{ matrix.python-version }} python-version: ${{ matrix.python-version }}

View File

@@ -38,7 +38,7 @@ jobs:
key: ${{ matrix.os }}-python-${{ matrix.python-version }} key: ${{ matrix.os }}-python-${{ matrix.python-version }}
- name: Setup Python ${{ matrix.python-version }} - name: Setup Python ${{ matrix.python-version }}
uses: actions/setup-python@v2 uses: actions/setup-python@v4
with: with:
python-version: ${{ matrix.python-version }} python-version: ${{ matrix.python-version }}

View File

@@ -25,7 +25,7 @@ jobs:
matrix: matrix:
os: [ubuntu-latest, windows-latest, macos-latest] os: [ubuntu-latest, windows-latest, macos-latest]
python-version: ["3.7", "3.8", "3.9", "3.10", "3.11"] python-version: ["3.7", "3.8", "3.9", "3.10", "3.11"]
model_type: ["transducer", "paraformer"] model_type: ["transducer", "paraformer", "zipformer2-ctc"]
steps: steps:
- uses: actions/checkout@v4 - uses: actions/checkout@v4
@@ -38,7 +38,7 @@ jobs:
key: ${{ matrix.os }}-python-${{ matrix.python-version }} key: ${{ matrix.os }}-python-${{ matrix.python-version }}
- name: Setup Python ${{ matrix.python-version }} - name: Setup Python ${{ matrix.python-version }}
uses: actions/setup-python@v2 uses: actions/setup-python@v4
with: with:
python-version: ${{ matrix.python-version }} python-version: ${{ matrix.python-version }}
@@ -57,6 +57,26 @@ jobs:
python3 -m pip install --no-deps --verbose . python3 -m pip install --no-deps --verbose .
python3 -m pip install websockets python3 -m pip install websockets
- name: Start server for zipformer2 CTC models
if: matrix.model_type == 'zipformer2-ctc'
shell: bash
run: |
curl -O -L https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2
tar xvf sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2
rm sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2
python3 ./python-api-examples/streaming_server.py \
--zipformer2-ctc ./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/ctc-epoch-20-avg-1-chunk-16-left-128.onnx \
--tokens=./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/tokens.txt &
echo "sleep 10 seconds to wait the server start"
sleep 10
- name: Start client for zipformer2 CTC models
if: matrix.model_type == 'zipformer2-ctc'
shell: bash
run: |
python3 ./python-api-examples/online-websocket-client-decode-file.py \
./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/test_wavs/DEV_T0000000000.wav
- name: Start server for transducer models - name: Start server for transducer models
if: matrix.model_type == 'transducer' if: matrix.model_type == 'transducer'

View File

@@ -1,7 +1,7 @@
cmake_minimum_required(VERSION 3.13 FATAL_ERROR) cmake_minimum_required(VERSION 3.13 FATAL_ERROR)
project(sherpa-onnx) project(sherpa-onnx)
set(SHERPA_ONNX_VERSION "1.9.4") set(SHERPA_ONNX_VERSION "1.9.6")
# Disable warning about # Disable warning about
# #

View File

@@ -26,9 +26,14 @@ data class OnlineParaformerModelConfig(
var decoder: String = "", var decoder: String = "",
) )
data class OnlineZipformer2CtcModelConfig(
var model: String = "",
)
data class OnlineModelConfig( data class OnlineModelConfig(
var transducer: OnlineTransducerModelConfig = OnlineTransducerModelConfig(), var transducer: OnlineTransducerModelConfig = OnlineTransducerModelConfig(),
var paraformer: OnlineParaformerModelConfig = OnlineParaformerModelConfig(), var paraformer: OnlineParaformerModelConfig = OnlineParaformerModelConfig(),
var zipformer2Ctc: OnlineZipformer2CtcModelConfig = OnlineZipformer2CtcModelConfig(),
var tokens: String, var tokens: String,
var numThreads: Int = 1, var numThreads: Int = 1,
var debug: Boolean = false, var debug: Boolean = false,

View File

@@ -1,2 +1,3 @@
bin bin
obj obj
!*.sh

View File

@@ -38,6 +38,9 @@ class OnlineDecodeFiles
[Option("paraformer-decoder", Required = false, HelpText = "Path to paraformer decoder.onnx")] [Option("paraformer-decoder", Required = false, HelpText = "Path to paraformer decoder.onnx")]
public string ParaformerDecoder { get; set; } public string ParaformerDecoder { get; set; }
[Option("zipformer2-ctc", Required = false, HelpText = "Path to zipformer2 CTC onnx model")]
public string Zipformer2Ctc { get; set; }
[Option("num-threads", Required = false, Default = 1, HelpText = "Number of threads for computation")] [Option("num-threads", Required = false, Default = 1, HelpText = "Number of threads for computation")]
public int NumThreads { get; set; } public int NumThreads { get; set; }
@@ -107,7 +110,19 @@ dotnet run \
--files ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/test_wavs/0.wav \ --files ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/test_wavs/0.wav \
./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/test_wavs/1.wav ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/test_wavs/1.wav
(2) Streaming Paraformer models (2) Streaming Zipformer2 Ctc models
dotnet run -c Release \
--tokens ./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/tokens.txt \
--zipformer2-ctc ./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/ctc-epoch-20-avg-1-chunk-16-left-128.onnx \
--files ./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/test_wavs/DEV_T0000000000.wav \
./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/test_wavs/DEV_T0000000001.wav \
./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/test_wavs/DEV_T0000000002.wav \
./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/test_wavs/TEST_MEETING_T0000000113.wav \
./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/test_wavs/TEST_MEETING_T0000000219.wav \
./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/test_wavs/TEST_MEETING_T0000000351.wav
(3) Streaming Paraformer models
dotnet run \ dotnet run \
--tokens=./sherpa-onnx-streaming-paraformer-bilingual-zh-en/tokens.txt \ --tokens=./sherpa-onnx-streaming-paraformer-bilingual-zh-en/tokens.txt \
--paraformer-encoder=./sherpa-onnx-streaming-paraformer-bilingual-zh-en/encoder.int8.onnx \ --paraformer-encoder=./sherpa-onnx-streaming-paraformer-bilingual-zh-en/encoder.int8.onnx \
@@ -121,6 +136,7 @@ dotnet run \
Please refer to Please refer to
https://k2-fsa.github.io/sherpa/onnx/pretrained_models/online-transducer/index.html https://k2-fsa.github.io/sherpa/onnx/pretrained_models/online-transducer/index.html
https://k2-fsa.github.io/sherpa/onnx/pretrained_models/online-paraformer/index.html https://k2-fsa.github.io/sherpa/onnx/pretrained_models/online-paraformer/index.html
https://k2-fsa.github.io/sherpa/onnx/pretrained_models/online-ctc/index.html
to download pre-trained streaming models. to download pre-trained streaming models.
"; ";
@@ -150,6 +166,8 @@ to download pre-trained streaming models.
config.ModelConfig.Paraformer.Encoder = options.ParaformerEncoder; config.ModelConfig.Paraformer.Encoder = options.ParaformerEncoder;
config.ModelConfig.Paraformer.Decoder = options.ParaformerDecoder; config.ModelConfig.Paraformer.Decoder = options.ParaformerDecoder;
config.ModelConfig.Zipformer2Ctc.Model = options.Zipformer2Ctc;
config.ModelConfig.Tokens = options.Tokens; config.ModelConfig.Tokens = options.Tokens;
config.ModelConfig.Provider = options.Provider; config.ModelConfig.Provider = options.Provider;
config.ModelConfig.NumThreads = options.NumThreads; config.ModelConfig.NumThreads = options.NumThreads;

View File

@@ -0,0 +1,21 @@
#!/usr/bin/env bash
# Please refer to
# https://k2-fsa.github.io/sherpa/onnx/pretrained_models/online-ctc/zipformer-ctc-models.html#sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13-chinese
# to download the model files
if [ ! -d ./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13 ]; then
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2
tar xvf sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2
rm sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2
fi
dotnet run -c Release \
--tokens ./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/tokens.txt \
--zipformer2-ctc ./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/ctc-epoch-20-avg-1-chunk-16-left-128.onnx \
--files ./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/test_wavs/DEV_T0000000000.wav \
./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/test_wavs/DEV_T0000000001.wav \
./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/test_wavs/DEV_T0000000002.wav \
./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/test_wavs/TEST_MEETING_T0000000113.wav \
./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/test_wavs/TEST_MEETING_T0000000219.wav \
./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/test_wavs/TEST_MEETING_T0000000351.wav

1
go-api-examples/.gitignore vendored Normal file
View File

@@ -0,0 +1 @@
!*.sh

View File

@@ -22,6 +22,7 @@ func main() {
flag.StringVar(&config.ModelConfig.Transducer.Joiner, "joiner", "", "Path to the transducer joiner model") flag.StringVar(&config.ModelConfig.Transducer.Joiner, "joiner", "", "Path to the transducer joiner model")
flag.StringVar(&config.ModelConfig.Paraformer.Encoder, "paraformer-encoder", "", "Path to the paraformer encoder model") flag.StringVar(&config.ModelConfig.Paraformer.Encoder, "paraformer-encoder", "", "Path to the paraformer encoder model")
flag.StringVar(&config.ModelConfig.Paraformer.Decoder, "paraformer-decoder", "", "Path to the paraformer decoder model") flag.StringVar(&config.ModelConfig.Paraformer.Decoder, "paraformer-decoder", "", "Path to the paraformer decoder model")
flag.StringVar(&config.ModelConfig.Zipformer2Ctc.Model, "zipformer2-ctc", "", "Path to the zipformer2 CTC model")
flag.StringVar(&config.ModelConfig.Tokens, "tokens", "", "Path to the tokens file") flag.StringVar(&config.ModelConfig.Tokens, "tokens", "", "Path to the tokens file")
flag.IntVar(&config.ModelConfig.NumThreads, "num-threads", 1, "Number of threads for computing") flag.IntVar(&config.ModelConfig.NumThreads, "num-threads", 1, "Number of threads for computing")
flag.IntVar(&config.ModelConfig.Debug, "debug", 0, "Whether to show debug message") flag.IntVar(&config.ModelConfig.Debug, "debug", 0, "Whether to show debug message")

View File

@@ -0,0 +1,13 @@
#!/usr/bin/env bash
# Please refer to
# https://k2-fsa.github.io/sherpa/onnx/pretrained_models/online-ctc/zipformer-ctc-models.html#sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13-chinese
# to download the model
# before you run this script.
#
# You can switch to a different online model if you need
./streaming-decode-files \
--zipformer2-ctc ./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/ctc-epoch-20-avg-1-chunk-16-left-128.onnx \
--tokens ./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/tokens.txt \
./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/test_wavs/DEV_T0000000000.wav

View File

@@ -8,7 +8,8 @@ fun callback(samples: FloatArray): Unit {
fun main() { fun main() {
testTts() testTts()
testAsr() testAsr("transducer")
testAsr("zipformer2-ctc")
} }
fun testTts() { fun testTts() {
@@ -30,25 +31,43 @@ fun testTts() {
audio.save(filename="test-en.wav") audio.save(filename="test-en.wav")
} }
fun testAsr() { fun testAsr(type: String) {
var featConfig = FeatureConfig( var featConfig = FeatureConfig(
sampleRate = 16000, sampleRate = 16000,
featureDim = 80, featureDim = 80,
) )
// please refer to var waveFilename: String
// https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html var modelConfig: OnlineModelConfig = when (type) {
// to dowload pre-trained models "transducer" -> {
var modelConfig = OnlineModelConfig( waveFilename = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/test_wavs/0.wav"
transducer = OnlineTransducerModelConfig( // please refer to
encoder = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/encoder-epoch-99-avg-1.onnx", // https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html
decoder = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/decoder-epoch-99-avg-1.onnx", // to download pre-trained models
joiner = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/joiner-epoch-99-avg-1.onnx", OnlineModelConfig(
), transducer = OnlineTransducerModelConfig(
tokens = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/tokens.txt", encoder = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/encoder-epoch-99-avg-1.onnx",
numThreads = 1, decoder = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/decoder-epoch-99-avg-1.onnx",
debug = false, joiner = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/joiner-epoch-99-avg-1.onnx",
) ),
tokens = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/tokens.txt",
numThreads = 1,
debug = false,
)
}
"zipformer2-ctc" -> {
waveFilename = "./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/test_wavs/DEV_T0000000000.wav"
OnlineModelConfig(
zipformer2Ctc = OnlineZipformer2CtcModelConfig(
model = "./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/ctc-epoch-20-avg-1-chunk-16-left-128.onnx",
),
tokens = "./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/tokens.txt",
numThreads = 1,
debug = false,
)
}
else -> throw IllegalArgumentException(type)
}
var endpointConfig = EndpointConfig() var endpointConfig = EndpointConfig()
@@ -69,7 +88,7 @@ fun testAsr() {
) )
var objArray = WaveReader.readWaveFromFile( var objArray = WaveReader.readWaveFromFile(
filename = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/test_wavs/0.wav", filename = waveFilename,
) )
var samples: FloatArray = objArray[0] as FloatArray var samples: FloatArray = objArray[0] as FloatArray
var sampleRate: Int = objArray[1] as Int var sampleRate: Int = objArray[1] as Int

View File

@@ -34,6 +34,12 @@ if [ ! -f ./sherpa-onnx-streaming-zipformer-en-2023-02-21/tokens.txt ]; then
git clone https://huggingface.co/csukuangfj/sherpa-onnx-streaming-zipformer-en-2023-02-21 git clone https://huggingface.co/csukuangfj/sherpa-onnx-streaming-zipformer-en-2023-02-21
fi fi
if [ ! -d ./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13 ]; then
wget -q https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2
tar xvf sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2
rm sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2
fi
if [ ! -f ./vits-piper-en_US-amy-low/en_US-amy-low.onnx ]; then if [ ! -f ./vits-piper-en_US-amy-low/en_US-amy-low.onnx ]; then
wget -q https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/vits-piper-en_US-amy-low.tar.bz2 wget -q https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/vits-piper-en_US-amy-low.tar.bz2
tar xf vits-piper-en_US-amy-low.tar.bz2 tar xf vits-piper-en_US-amy-low.tar.bz2

View File

@@ -85,7 +85,7 @@ npm install wav naudiodon2
how to decode a file with a NeMo CTC model. In the code we use how to decode a file with a NeMo CTC model. In the code we use
[stt_en_conformer_ctc_small](https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-ctc/nemo/english.html#stt-en-conformer-ctc-small). [stt_en_conformer_ctc_small](https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-ctc/nemo/english.html#stt-en-conformer-ctc-small).
You can use the following command run it: You can use the following command to run it:
```bash ```bash
wget -q https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-nemo-ctc-en-conformer-small.tar.bz2 wget -q https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-nemo-ctc-en-conformer-small.tar.bz2
@@ -99,7 +99,7 @@ node ./test-offline-nemo-ctc.js
how to decode a file with a non-streaming Paraformer model. In the code we use how to decode a file with a non-streaming Paraformer model. In the code we use
[sherpa-onnx-paraformer-zh-2023-03-28](https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-paraformer/paraformer-models.html#csukuangfj-sherpa-onnx-paraformer-zh-2023-03-28-chinese). [sherpa-onnx-paraformer-zh-2023-03-28](https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-paraformer/paraformer-models.html#csukuangfj-sherpa-onnx-paraformer-zh-2023-03-28-chinese).
You can use the following command run it: You can use the following command to run it:
```bash ```bash
wget -q https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-paraformer-zh-2023-03-28.tar.bz2 wget -q https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-paraformer-zh-2023-03-28.tar.bz2
@@ -113,7 +113,7 @@ node ./test-offline-paraformer.js
how to decode a file with a non-streaming transducer model. In the code we use how to decode a file with a non-streaming transducer model. In the code we use
[sherpa-onnx-zipformer-en-2023-06-26](https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-transducer/zipformer-transducer-models.html#csukuangfj-sherpa-onnx-zipformer-en-2023-06-26-english). [sherpa-onnx-zipformer-en-2023-06-26](https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-transducer/zipformer-transducer-models.html#csukuangfj-sherpa-onnx-zipformer-en-2023-06-26-english).
You can use the following command run it: You can use the following command to run it:
```bash ```bash
wget -q https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-zipformer-en-2023-06-26.tar.bz2 wget -q https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-zipformer-en-2023-06-26.tar.bz2
@@ -126,7 +126,7 @@ node ./test-offline-transducer.js
how to decode a file with a Whisper model. In the code we use how to decode a file with a Whisper model. In the code we use
[sherpa-onnx-whisper-tiny.en](https://k2-fsa.github.io/sherpa/onnx/pretrained_models/whisper/tiny.en.html). [sherpa-onnx-whisper-tiny.en](https://k2-fsa.github.io/sherpa/onnx/pretrained_models/whisper/tiny.en.html).
You can use the following command run it: You can use the following command to run it:
```bash ```bash
wget -q https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-whisper-tiny.en.tar.bz2 wget -q https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-whisper-tiny.en.tar.bz2
@@ -140,7 +140,7 @@ demonstrates how to do real-time speech recognition from microphone
with a streaming Paraformer model. In the code we use with a streaming Paraformer model. In the code we use
[sherpa-onnx-streaming-paraformer-bilingual-zh-en](https://k2-fsa.github.io/sherpa/onnx/pretrained_models/online-paraformer/paraformer-models.html#csukuangfj-sherpa-onnx-streaming-paraformer-bilingual-zh-en-chinese-english). [sherpa-onnx-streaming-paraformer-bilingual-zh-en](https://k2-fsa.github.io/sherpa/onnx/pretrained_models/online-paraformer/paraformer-models.html#csukuangfj-sherpa-onnx-streaming-paraformer-bilingual-zh-en-chinese-english).
You can use the following command run it: You can use the following command to run it:
```bash ```bash
wget -q https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-paraformer-bilingual-zh-en.tar.bz2 wget -q https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-paraformer-bilingual-zh-en.tar.bz2
@@ -153,7 +153,7 @@ node ./test-online-paraformer-microphone.js
how to decode a file using a streaming Paraformer model. In the code we use how to decode a file using a streaming Paraformer model. In the code we use
[sherpa-onnx-streaming-paraformer-bilingual-zh-en](https://k2-fsa.github.io/sherpa/onnx/pretrained_models/online-paraformer/paraformer-models.html#csukuangfj-sherpa-onnx-streaming-paraformer-bilingual-zh-en-chinese-english). [sherpa-onnx-streaming-paraformer-bilingual-zh-en](https://k2-fsa.github.io/sherpa/onnx/pretrained_models/online-paraformer/paraformer-models.html#csukuangfj-sherpa-onnx-streaming-paraformer-bilingual-zh-en-chinese-english).
You can use the following command run it: You can use the following command to run it:
```bash ```bash
wget -q https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-paraformer-bilingual-zh-en.tar.bz2 wget -q https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-paraformer-bilingual-zh-en.tar.bz2
@@ -167,7 +167,7 @@ demonstrates how to do real-time speech recognition with microphone using a stre
we use [sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20](https://k2-fsa.github.io/sherpa/onnx/pretrained_models/online-transducer/zipformer-transducer-models.html#csukuangfj-sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20-bilingual-chinese-english). we use [sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20](https://k2-fsa.github.io/sherpa/onnx/pretrained_models/online-transducer/zipformer-transducer-models.html#csukuangfj-sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20-bilingual-chinese-english).
You can use the following command run it: You can use the following command to run it:
```bash ```bash
wget -q https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20.tar.bz2 wget -q https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20.tar.bz2
@@ -180,7 +180,7 @@ node ./test-online-transducer-microphone.js
how to decode a file using a streaming transducer model. In the code how to decode a file using a streaming transducer model. In the code
we use [sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20](https://k2-fsa.github.io/sherpa/onnx/pretrained_models/online-transducer/zipformer-transducer-models.html#csukuangfj-sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20-bilingual-chinese-english). we use [sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20](https://k2-fsa.github.io/sherpa/onnx/pretrained_models/online-transducer/zipformer-transducer-models.html#csukuangfj-sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20-bilingual-chinese-english).
You can use the following command run it: You can use the following command to run it:
```bash ```bash
wget -q https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20.tar.bz2 wget -q https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20.tar.bz2
@@ -188,13 +188,26 @@ tar xvf sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20.tar.bz2
node ./test-online-transducer.js node ./test-online-transducer.js
``` ```
## ./test-online-zipformer2-ctc.js
[./test-online-zipformer2-ctc.js](./test-online-zipformer2-ctc.js) demonstrates
how to decode a file using a streaming zipformer2 CTC model. In the code
we use [sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13](https://k2-fsa.github.io/sherpa/onnx/pretrained_models/online-ctc/zipformer-ctc-models.html#sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13-chinese).
You can use the following command to run it:
```bash
wget -q https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2
tar xvf sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2
node ./test-online-zipformer2-ctc.js
```
## ./test-vad-microphone-offline-paraformer.js ## ./test-vad-microphone-offline-paraformer.js
[./test-vad-microphone-offline-paraformer.js](./test-vad-microphone-offline-paraformer.js) [./test-vad-microphone-offline-paraformer.js](./test-vad-microphone-offline-paraformer.js)
demonstrates how to use [silero-vad](https://github.com/snakers4/silero-vad) demonstrates how to use [silero-vad](https://github.com/snakers4/silero-vad)
with non-streaming Paraformer for speech recognition from microphone. with non-streaming Paraformer for speech recognition from microphone.
You can use the following command run it: You can use the following command to run it:
```bash ```bash
wget -q https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/silero_vad.onnx wget -q https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/silero_vad.onnx
@@ -209,7 +222,7 @@ node ./test-vad-microphone-offline-paraformer.js
demonstrates how to use [silero-vad](https://github.com/snakers4/silero-vad) demonstrates how to use [silero-vad](https://github.com/snakers4/silero-vad)
with a non-streaming transducer model for speech recognition from microphone. with a non-streaming transducer model for speech recognition from microphone.
You can use the following command run it: You can use the following command to run it:
```bash ```bash
wget -q https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/silero_vad.onnx wget -q https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/silero_vad.onnx
@@ -224,7 +237,7 @@ node ./test-vad-microphone-offline-transducer.js
demonstrates how to use [silero-vad](https://github.com/snakers4/silero-vad) demonstrates how to use [silero-vad](https://github.com/snakers4/silero-vad)
with whisper for speech recognition from microphone. with whisper for speech recognition from microphone.
You can use the following command run it: You can use the following command to run it:
```bash ```bash
wget -q https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/silero_vad.onnx wget -q https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/silero_vad.onnx
@@ -238,7 +251,7 @@ node ./test-vad-microphone-offline-whisper.js
[./test-vad-microphone.js](./test-vad-microphone.js) [./test-vad-microphone.js](./test-vad-microphone.js)
demonstrates how to use [silero-vad](https://github.com/snakers4/silero-vad). demonstrates how to use [silero-vad](https://github.com/snakers4/silero-vad).
You can use the following command run it: You can use the following command to run it:
```bash ```bash
wget -q https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/silero_vad.onnx wget -q https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/silero_vad.onnx

View File

@@ -0,0 +1,97 @@
// Copyright (c) 2023 Xiaomi Corporation (authors: Fangjun Kuang)
//
const fs = require('fs');
const {Readable} = require('stream');
const wav = require('wav');
const sherpa_onnx = require('sherpa-onnx');
// Build an online (streaming) recognizer for the zipformer2 CTC model.
// Returns a configured sherpa_onnx.OnlineRecognizer; the caller owns it
// and must call .free() when done.
function createRecognizer() {
  // Feature extractor settings: 16 kHz input, 80-dim fbank features.
  const featConfig = new sherpa_onnx.FeatureConfig();
  featConfig.sampleRate = 16000;
  featConfig.featureDim = 80;

  // test online recognizer
  const zipformer2Ctc = new sherpa_onnx.OnlineZipformer2CtcModelConfig();
  zipformer2Ctc.model =
      './sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/ctc-epoch-20-avg-1-chunk-16-left-128.onnx';

  const tokens =
      './sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/tokens.txt';

  const modelConfig = new sherpa_onnx.OnlineModelConfig();
  modelConfig.zipformer2Ctc = zipformer2Ctc;
  modelConfig.tokens = tokens;

  const recognizerConfig = new sherpa_onnx.OnlineRecognizerConfig();
  recognizerConfig.featConfig = featConfig;
  recognizerConfig.modelConfig = modelConfig;
  recognizerConfig.decodingMethod = 'greedy_search';

  // Fix: declare with `const` instead of assigning an undeclared name,
  // which leaked an implicit global (and throws in strict mode).
  const recognizer = new sherpa_onnx.OnlineRecognizer(recognizerConfig);
  return recognizer;
}
// File-scope state shared by the event handlers below.
// NOTE(review): `recognizer` and `stream` are assigned without `let`/`const`,
// so they become implicit globals — works in sloppy mode but would throw in
// strict mode; presumably intentional for this example. TODO confirm.
recognizer = createRecognizer();
stream = recognizer.createStream();

// Test wave file shipped with the pre-trained model.
const waveFilename =
    './sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/test_wavs/DEV_T0000000000.wav';

// wav.Reader parses the RIFF header and emits raw PCM; wrapping it in a
// Readable lets us consume the decoded samples with 'readable' events.
const reader = new wav.Reader();
const readable = new Readable().wrap(reader);
// Feed one chunk of float samples to the streaming recognizer, run the
// decoder for every frame that is ready, and print the partial result.
function decode(samples) {
  stream.acceptWaveform(recognizer.config.featConfig.sampleRate, samples);

  // Drain all complete chunks currently buffered in the stream.
  while (recognizer.isReady(stream)) {
    recognizer.decode(stream);
  }

  console.log(recognizer.getResult(stream).text);
}
// Validate the wave header: the recognizer expects 16 kHz, 16-bit, mono PCM.
reader.on('format', ({audioFormat, bitDepth, channels, sampleRate}) => {
  if (sampleRate != recognizer.config.featConfig.sampleRate) {
    throw new Error(`Only support sampleRate ${
        recognizer.config.featConfig.sampleRate}. Given ${sampleRate}`);
  }

  if (audioFormat != 1) {
    throw new Error(`Only support PCM format. Given ${audioFormat}`);
  }

  if (channels != 1) {
    // Fix: the destructured variable is `channels`; the original interpolated
    // `${channel}`, an undefined name, so this branch raised a ReferenceError
    // instead of the intended error message.
    throw new Error(`Only a single channel. Given ${channels}`);
  }

  if (bitDepth != 16) {
    throw new Error(`Only support 16-bit samples. Given ${bitDepth}`);
  }
});
// Stream the wave file into the wav parser in 4 KiB chunks. When the file
// has been fully piped, push tail padding and release native resources.
// NOTE(review): 'finish' may fire while 'readable' data is still pending;
// the example assumes the reader has drained by then — confirm upstream.
fs.createReadStream(waveFilename, {'highWaterMark': 4096})
    .pipe(reader)
    .on('finish', (err) => {
      // 0.5 s of silence as tail padding so trailing words are flushed.
      const paddingLength = recognizer.config.featConfig.sampleRate * 0.5;
      decode(new Float32Array(paddingLength));

      stream.free();
      recognizer.free();
    });
// Consume decoded PCM chunks as they become available.
readable.on('readable', () => {
  let chunk;
  while ((chunk = readable.read()) != null) {
    // Reinterpret the chunk's underlying bytes as 16-bit signed samples.
    const int16Samples = new Int16Array(
        chunk.buffer, chunk.byteOffset,
        chunk.length / Int16Array.BYTES_PER_ELEMENT);

    // Normalize to floats in [-1, 1) as expected by acceptWaveform().
    decode(Float32Array.from(int16Samples, (s) => s / 32768.0));
  }
});

View File

@@ -37,7 +37,20 @@ git lfs pull --include "*.onnx"
./sherpa-onnx-streaming-paraformer-bilingual-zh-en/test_wavs/3.wav \ ./sherpa-onnx-streaming-paraformer-bilingual-zh-en/test_wavs/3.wav \
./sherpa-onnx-streaming-paraformer-bilingual-zh-en/test_wavs/8k.wav ./sherpa-onnx-streaming-paraformer-bilingual-zh-en/test_wavs/8k.wav
(3) Streaming Conformer CTC from WeNet (3) Streaming Zipformer2 CTC
wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2
tar xvf sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2
rm sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2
ls -lh sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13
./python-api-examples/online-decode-files.py \
--zipformer2-ctc=./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/ctc-epoch-20-avg-1-chunk-16-left-128.onnx \
--tokens=./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/tokens.txt \
./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/test_wavs/DEV_T0000000000.wav \
./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/test_wavs/DEV_T0000000001.wav
(4) Streaming Conformer CTC from WeNet
GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-zh-wenet-wenetspeech GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-zh-wenet-wenetspeech
cd sherpa-onnx-zh-wenet-wenetspeech cd sherpa-onnx-zh-wenet-wenetspeech
@@ -51,12 +64,9 @@ git lfs pull --include "*.onnx"
./sherpa-onnx-zh-wenet-wenetspeech/test_wavs/8k.wav ./sherpa-onnx-zh-wenet-wenetspeech/test_wavs/8k.wav
Please refer to Please refer to
https://k2-fsa.github.io/sherpa/onnx/index.html https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html
and to download streaming pre-trained models.
https://k2-fsa.github.io/sherpa/onnx/pretrained_models/wenet/index.html
to install sherpa-onnx and to download streaming pre-trained models.
""" """
import argparse import argparse
import time import time
@@ -97,6 +107,12 @@ def get_args():
help="Path to the transducer joiner model", help="Path to the transducer joiner model",
) )
parser.add_argument(
"--zipformer2-ctc",
type=str,
help="Path to the zipformer2 ctc model",
)
parser.add_argument( parser.add_argument(
"--paraformer-encoder", "--paraformer-encoder",
type=str, type=str,
@@ -112,7 +128,7 @@ def get_args():
parser.add_argument( parser.add_argument(
"--wenet-ctc", "--wenet-ctc",
type=str, type=str,
help="Path to the wenet ctc model model", help="Path to the wenet ctc model",
) )
parser.add_argument( parser.add_argument(
@@ -275,6 +291,16 @@ def main():
hotwords_file=args.hotwords_file, hotwords_file=args.hotwords_file,
hotwords_score=args.hotwords_score, hotwords_score=args.hotwords_score,
) )
elif args.zipformer2_ctc:
recognizer = sherpa_onnx.OnlineRecognizer.from_zipformer2_ctc(
tokens=args.tokens,
model=args.zipformer2_ctc,
num_threads=args.num_threads,
provider=args.provider,
sample_rate=16000,
feature_dim=80,
decoding_method="greedy_search",
)
elif args.paraformer_encoder: elif args.paraformer_encoder:
recognizer = sherpa_onnx.OnlineRecognizer.from_paraformer( recognizer = sherpa_onnx.OnlineRecognizer.from_paraformer(
tokens=args.tokens, tokens=args.tokens,

View File

@@ -25,6 +25,7 @@ https://github.com/k2-fsa/sherpa-onnx/blob/master/sherpa-onnx/csrc/online-websoc
import argparse import argparse
import asyncio import asyncio
import json
import logging import logging
import wave import wave
@@ -112,7 +113,7 @@ async def receive_results(socket: websockets.WebSocketServerProtocol):
async for message in socket: async for message in socket:
if message != "Done!": if message != "Done!":
last_message = message last_message = message
logging.info(message) logging.info(json.loads(message))
else: else:
break break
return last_message return last_message
@@ -151,7 +152,7 @@ async def run(
await websocket.send("Done") await websocket.send("Done")
decoding_results = await receive_task decoding_results = await receive_task
logging.info(f"\nFinal result is:\n{decoding_results}") logging.info(f"\nFinal result is:\n{json.loads(decoding_results)}")
async def main(): async def main():

View File

@@ -137,6 +137,12 @@ def add_model_args(parser: argparse.ArgumentParser):
help="Path to the transducer joiner model.", help="Path to the transducer joiner model.",
) )
parser.add_argument(
"--zipformer2-ctc",
type=str,
help="Path to the model file from zipformer2 ctc",
)
parser.add_argument( parser.add_argument(
"--wenet-ctc", "--wenet-ctc",
type=str, type=str,
@@ -405,6 +411,20 @@ def create_recognizer(args) -> sherpa_onnx.OnlineRecognizer:
rule3_min_utterance_length=args.rule3_min_utterance_length, rule3_min_utterance_length=args.rule3_min_utterance_length,
provider=args.provider, provider=args.provider,
) )
elif args.zipformer2_ctc:
recognizer = sherpa_onnx.OnlineRecognizer.from_zipformer2_ctc(
tokens=args.tokens,
model=args.zipformer2_ctc,
num_threads=args.num_threads,
sample_rate=args.sample_rate,
feature_dim=args.feat_dim,
decoding_method=args.decoding_method,
enable_endpoint_detection=args.use_endpoint != 0,
rule1_min_trailing_silence=args.rule1_min_trailing_silence,
rule2_min_trailing_silence=args.rule2_min_trailing_silence,
rule3_min_utterance_length=args.rule3_min_utterance_length,
provider=args.provider,
)
elif args.wenet_ctc: elif args.wenet_ctc:
recognizer = sherpa_onnx.OnlineRecognizer.from_wenet_ctc( recognizer = sherpa_onnx.OnlineRecognizer.from_wenet_ctc(
tokens=args.tokens, tokens=args.tokens,
@@ -748,6 +768,8 @@ def check_args(args):
assert args.paraformer_encoder is None, args.paraformer_encoder assert args.paraformer_encoder is None, args.paraformer_encoder
assert args.paraformer_decoder is None, args.paraformer_decoder assert args.paraformer_decoder is None, args.paraformer_decoder
assert args.zipformer2_ctc is None, args.zipformer2_ctc
assert args.wenet_ctc is None, args.wenet_ctc
elif args.paraformer_encoder: elif args.paraformer_encoder:
assert Path( assert Path(
args.paraformer_encoder args.paraformer_encoder
@@ -756,6 +778,10 @@ def check_args(args):
assert Path( assert Path(
args.paraformer_decoder args.paraformer_decoder
).is_file(), f"{args.paraformer_decoder} does not exist" ).is_file(), f"{args.paraformer_decoder} does not exist"
elif args.zipformer2_ctc:
assert Path(
args.zipformer2_ctc
).is_file(), f"{args.zipformer2_ctc} does not exist"
elif args.wenet_ctc: elif args.wenet_ctc:
assert Path(args.wenet_ctc).is_file(), f"{args.wenet_ctc} does not exist" assert Path(args.wenet_ctc).is_file(), f"{args.wenet_ctc} does not exist"
else: else:

View File

@@ -50,6 +50,18 @@ namespace SherpaOnnx
public string Decoder; public string Decoder;
} }
[StructLayout(LayoutKind.Sequential)]
public struct OnlineZipformer2CtcModelConfig
{
public OnlineZipformer2CtcModelConfig()
{
Model = "";
}
[MarshalAs(UnmanagedType.LPStr)]
public string Model;
}
[StructLayout(LayoutKind.Sequential)] [StructLayout(LayoutKind.Sequential)]
public struct OnlineModelConfig public struct OnlineModelConfig
{ {
@@ -57,6 +69,7 @@ namespace SherpaOnnx
{ {
Transducer = new OnlineTransducerModelConfig(); Transducer = new OnlineTransducerModelConfig();
Paraformer = new OnlineParaformerModelConfig(); Paraformer = new OnlineParaformerModelConfig();
Zipformer2Ctc = new OnlineZipformer2CtcModelConfig();
Tokens = ""; Tokens = "";
NumThreads = 1; NumThreads = 1;
Provider = "cpu"; Provider = "cpu";
@@ -66,6 +79,7 @@ namespace SherpaOnnx
public OnlineTransducerModelConfig Transducer; public OnlineTransducerModelConfig Transducer;
public OnlineParaformerModelConfig Paraformer; public OnlineParaformerModelConfig Paraformer;
public OnlineZipformer2CtcModelConfig Zipformer2Ctc;
[MarshalAs(UnmanagedType.LPStr)] [MarshalAs(UnmanagedType.LPStr)]
public string Tokens; public string Tokens;

View File

@@ -0,0 +1 @@
../../../../go-api-examples/streaming-decode-files/run-zipformer2-ctc.sh

View File

@@ -65,6 +65,13 @@ type OnlineParaformerModelConfig struct {
Decoder string // Path to the decoder model. Decoder string // Path to the decoder model.
} }
// Please refer to
// https://k2-fsa.github.io/sherpa/onnx/pretrained_models/online-ctc/index.html
// to download pre-trained models
type OnlineZipformer2CtcModelConfig struct {
Model string // Path to the onnx model
}
// Configuration for online/streaming models // Configuration for online/streaming models
// //
// Please refer to // Please refer to
@@ -72,13 +79,14 @@ type OnlineParaformerModelConfig struct {
// https://k2-fsa.github.io/sherpa/onnx/pretrained_models/online-paraformer/index.html // https://k2-fsa.github.io/sherpa/onnx/pretrained_models/online-paraformer/index.html
// to download pre-trained models // to download pre-trained models
type OnlineModelConfig struct { type OnlineModelConfig struct {
Transducer OnlineTransducerModelConfig Transducer OnlineTransducerModelConfig
Paraformer OnlineParaformerModelConfig Paraformer OnlineParaformerModelConfig
Tokens string // Path to tokens.txt Zipformer2Ctc OnlineZipformer2CtcModelConfig
NumThreads int // Number of threads to use for neural network computation Tokens string // Path to tokens.txt
Provider string // Optional. Valid values are: cpu, cuda, coreml NumThreads int // Number of threads to use for neural network computation
Debug int // 1 to show model meta information while loading it. Provider string // Optional. Valid values are: cpu, cuda, coreml
ModelType string // Optional. You can specify it for faster model initialization Debug int // 1 to show model meta information while loading it.
ModelType string // Optional. You can specify it for faster model initialization
} }
// Configuration for the feature extractor // Configuration for the feature extractor
@@ -157,6 +165,9 @@ func NewOnlineRecognizer(config *OnlineRecognizerConfig) *OnlineRecognizer {
c.model_config.paraformer.decoder = C.CString(config.ModelConfig.Paraformer.Decoder) c.model_config.paraformer.decoder = C.CString(config.ModelConfig.Paraformer.Decoder)
defer C.free(unsafe.Pointer(c.model_config.paraformer.decoder)) defer C.free(unsafe.Pointer(c.model_config.paraformer.decoder))
c.model_config.zipformer2_ctc.model = C.CString(config.ModelConfig.Zipformer2Ctc.Model)
defer C.free(unsafe.Pointer(c.model_config.zipformer2_ctc.model))
c.model_config.tokens = C.CString(config.ModelConfig.Tokens) c.model_config.tokens = C.CString(config.ModelConfig.Tokens)
defer C.free(unsafe.Pointer(c.model_config.tokens)) defer C.free(unsafe.Pointer(c.model_config.tokens))

View File

@@ -41,9 +41,14 @@ const SherpaOnnxOnlineParaformerModelConfig = StructType({
"decoder" : cstring, "decoder" : cstring,
}); });
const SherpaOnnxOnlineZipformer2CtcModelConfig = StructType({
"model" : cstring,
});
const SherpaOnnxOnlineModelConfig = StructType({ const SherpaOnnxOnlineModelConfig = StructType({
"transducer" : SherpaOnnxOnlineTransducerModelConfig, "transducer" : SherpaOnnxOnlineTransducerModelConfig,
"paraformer" : SherpaOnnxOnlineParaformerModelConfig, "paraformer" : SherpaOnnxOnlineParaformerModelConfig,
"zipformer2Ctc" : SherpaOnnxOnlineZipformer2CtcModelConfig,
"tokens" : cstring, "tokens" : cstring,
"numThreads" : int32_t, "numThreads" : int32_t,
"provider" : cstring, "provider" : cstring,
@@ -663,6 +668,7 @@ const OnlineModelConfig = SherpaOnnxOnlineModelConfig;
const FeatureConfig = SherpaOnnxFeatureConfig; const FeatureConfig = SherpaOnnxFeatureConfig;
const OnlineRecognizerConfig = SherpaOnnxOnlineRecognizerConfig; const OnlineRecognizerConfig = SherpaOnnxOnlineRecognizerConfig;
const OnlineParaformerModelConfig = SherpaOnnxOnlineParaformerModelConfig; const OnlineParaformerModelConfig = SherpaOnnxOnlineParaformerModelConfig;
const OnlineZipformer2CtcModelConfig = SherpaOnnxOnlineZipformer2CtcModelConfig;
// offline asr // offline asr
const OfflineTransducerModelConfig = SherpaOnnxOfflineTransducerModelConfig; const OfflineTransducerModelConfig = SherpaOnnxOfflineTransducerModelConfig;
@@ -692,6 +698,7 @@ module.exports = {
OnlineRecognizer, OnlineRecognizer,
OnlineStream, OnlineStream,
OnlineParaformerModelConfig, OnlineParaformerModelConfig,
OnlineZipformer2CtcModelConfig,
// offline asr // offline asr
OfflineRecognizer, OfflineRecognizer,

View File

@@ -54,6 +54,9 @@ SherpaOnnxOnlineRecognizer *CreateOnlineRecognizer(
recognizer_config.model_config.paraformer.decoder = recognizer_config.model_config.paraformer.decoder =
SHERPA_ONNX_OR(config->model_config.paraformer.decoder, ""); SHERPA_ONNX_OR(config->model_config.paraformer.decoder, "");
recognizer_config.model_config.zipformer2_ctc.model =
SHERPA_ONNX_OR(config->model_config.zipformer2_ctc.model, "");
recognizer_config.model_config.tokens = recognizer_config.model_config.tokens =
SHERPA_ONNX_OR(config->model_config.tokens, ""); SHERPA_ONNX_OR(config->model_config.tokens, "");
recognizer_config.model_config.num_threads = recognizer_config.model_config.num_threads =

View File

@@ -66,9 +66,17 @@ SHERPA_ONNX_API typedef struct SherpaOnnxOnlineParaformerModelConfig {
const char *decoder; const char *decoder;
} SherpaOnnxOnlineParaformerModelConfig; } SherpaOnnxOnlineParaformerModelConfig;
SHERPA_ONNX_API typedef struct SherpaOnnxModelConfig { // Please visit
// https://k2-fsa.github.io/sherpa/onnx/pretrained_models/online-ctc/zipformer-ctc-models.html#
// to download pre-trained streaming zipformer2 ctc models
SHERPA_ONNX_API typedef struct SherpaOnnxOnlineZipformer2CtcModelConfig {
const char *model;
} SherpaOnnxOnlineZipformer2CtcModelConfig;
SHERPA_ONNX_API typedef struct SherpaOnnxOnlineModelConfig {
SherpaOnnxOnlineTransducerModelConfig transducer; SherpaOnnxOnlineTransducerModelConfig transducer;
SherpaOnnxOnlineParaformerModelConfig paraformer; SherpaOnnxOnlineParaformerModelConfig paraformer;
SherpaOnnxOnlineZipformer2CtcModelConfig zipformer2_ctc;
const char *tokens; const char *tokens;
int32_t num_threads; int32_t num_threads;
const char *provider; const char *provider;

View File

@@ -70,6 +70,8 @@ set(sources
online-wenet-ctc-model-config.cc online-wenet-ctc-model-config.cc
online-wenet-ctc-model.cc online-wenet-ctc-model.cc
online-zipformer-transducer-model.cc online-zipformer-transducer-model.cc
online-zipformer2-ctc-model-config.cc
online-zipformer2-ctc-model.cc
online-zipformer2-transducer-model.cc online-zipformer2-transducer-model.cc
onnx-utils.cc onnx-utils.cc
packed-sequence.cc packed-sequence.cc

View File

@@ -12,6 +12,9 @@
namespace sherpa_onnx { namespace sherpa_onnx {
struct OnlineCtcDecoderResult { struct OnlineCtcDecoderResult {
/// Number of frames after subsampling we have decoded so far
int32_t frame_offset = 0;
/// The decoded token IDs /// The decoded token IDs
std::vector<int64_t> tokens; std::vector<int64_t> tokens;

View File

@@ -49,12 +49,17 @@ void OnlineCtcGreedySearchDecoder::Decode(
if (y != blank_id_ && y != prev_id) { if (y != blank_id_ && y != prev_id) {
r.tokens.push_back(y); r.tokens.push_back(y);
r.timestamps.push_back(t); r.timestamps.push_back(t + r.frame_offset);
} }
prev_id = y; prev_id = y;
} // for (int32_t t = 0; t != num_frames; ++t) { } // for (int32_t t = 0; t != num_frames; ++t) {
} // for (int32_t b = 0; b != batch_size; ++b) } // for (int32_t b = 0; b != batch_size; ++b)
// Update frame_offset
for (auto &r : *results) {
r.frame_offset += num_frames;
}
} }
} // namespace sherpa_onnx } // namespace sherpa_onnx

View File

@@ -11,127 +11,35 @@
#include "sherpa-onnx/csrc/macros.h" #include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/online-wenet-ctc-model.h" #include "sherpa-onnx/csrc/online-wenet-ctc-model.h"
#include "sherpa-onnx/csrc/online-zipformer2-ctc-model.h"
#include "sherpa-onnx/csrc/onnx-utils.h" #include "sherpa-onnx/csrc/onnx-utils.h"
namespace {
enum class ModelType {
kZipformerCtc,
kWenetCtc,
kUnkown,
};
} // namespace
namespace sherpa_onnx { namespace sherpa_onnx {
static ModelType GetModelType(char *model_data, size_t model_data_length,
bool debug) {
Ort::Env env(ORT_LOGGING_LEVEL_WARNING);
Ort::SessionOptions sess_opts;
auto sess = std::make_unique<Ort::Session>(env, model_data, model_data_length,
sess_opts);
Ort::ModelMetadata meta_data = sess->GetModelMetadata();
if (debug) {
std::ostringstream os;
PrintModelMetadata(os, meta_data);
SHERPA_ONNX_LOGE("%s", os.str().c_str());
}
Ort::AllocatorWithDefaultOptions allocator;
auto model_type =
meta_data.LookupCustomMetadataMapAllocated("model_type", allocator);
if (!model_type) {
SHERPA_ONNX_LOGE(
"No model_type in the metadata!\n"
"If you are using models from WeNet, please refer to\n"
"https://github.com/k2-fsa/sherpa-onnx/blob/master/scripts/wenet/"
"run.sh\n"
"\n"
"for how to add metadta to model.onnx\n");
return ModelType::kUnkown;
}
if (model_type.get() == std::string("zipformer2")) {
return ModelType::kZipformerCtc;
} else if (model_type.get() == std::string("wenet_ctc")) {
return ModelType::kWenetCtc;
} else {
SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.get());
return ModelType::kUnkown;
}
}
std::unique_ptr<OnlineCtcModel> OnlineCtcModel::Create( std::unique_ptr<OnlineCtcModel> OnlineCtcModel::Create(
const OnlineModelConfig &config) { const OnlineModelConfig &config) {
ModelType model_type = ModelType::kUnkown;
std::string filename;
if (!config.wenet_ctc.model.empty()) { if (!config.wenet_ctc.model.empty()) {
filename = config.wenet_ctc.model; return std::make_unique<OnlineWenetCtcModel>(config);
} else if (!config.zipformer2_ctc.model.empty()) {
return std::make_unique<OnlineZipformer2CtcModel>(config);
} else { } else {
SHERPA_ONNX_LOGE("Please specify a CTC model"); SHERPA_ONNX_LOGE("Please specify a CTC model");
exit(-1); exit(-1);
} }
{
auto buffer = ReadFile(filename);
model_type = GetModelType(buffer.data(), buffer.size(), config.debug);
}
switch (model_type) {
case ModelType::kZipformerCtc:
return nullptr;
// return std::make_unique<OnlineZipformerCtcModel>(config);
break;
case ModelType::kWenetCtc:
return std::make_unique<OnlineWenetCtcModel>(config);
break;
case ModelType::kUnkown:
SHERPA_ONNX_LOGE("Unknown model type in online CTC!");
return nullptr;
}
return nullptr;
} }
#if __ANDROID_API__ >= 9 #if __ANDROID_API__ >= 9
std::unique_ptr<OnlineCtcModel> OnlineCtcModel::Create( std::unique_ptr<OnlineCtcModel> OnlineCtcModel::Create(
AAssetManager *mgr, const OnlineModelConfig &config) { AAssetManager *mgr, const OnlineModelConfig &config) {
ModelType model_type = ModelType::kUnkown;
std::string filename;
if (!config.wenet_ctc.model.empty()) { if (!config.wenet_ctc.model.empty()) {
filename = config.wenet_ctc.model; return std::make_unique<OnlineWenetCtcModel>(mgr, config);
} else if (!config.zipformer2_ctc.model.empty()) {
return std::make_unique<OnlineZipformer2CtcModel>(mgr, config);
} else { } else {
SHERPA_ONNX_LOGE("Please specify a CTC model"); SHERPA_ONNX_LOGE("Please specify a CTC model");
exit(-1); exit(-1);
} }
{
auto buffer = ReadFile(mgr, filename);
model_type = GetModelType(buffer.data(), buffer.size(), config.debug);
}
switch (model_type) {
case ModelType::kZipformerCtc:
return nullptr;
// return std::make_unique<OnlineZipformerCtcModel>(mgr, config);
break;
case ModelType::kWenetCtc:
return std::make_unique<OnlineWenetCtcModel>(mgr, config);
break;
case ModelType::kUnkown:
SHERPA_ONNX_LOGE("Unknown model type in online CTC!");
return nullptr;
}
return nullptr;
} }
#endif #endif

View File

@@ -33,6 +33,26 @@ class OnlineCtcModel {
// Return a list of tensors containing the initial states // Return a list of tensors containing the initial states
virtual std::vector<Ort::Value> GetInitStates() const = 0; virtual std::vector<Ort::Value> GetInitStates() const = 0;
/** Stack a list of individual states into a batch.
*
* It is the inverse operation of `UnStackStates`.
*
* @param states states[i] contains the state for the i-th utterance.
* @return Return a single value representing the batched state.
*/
virtual std::vector<Ort::Value> StackStates(
std::vector<std::vector<Ort::Value>> states) const = 0;
/** Unstack a batch state into a list of individual states.
*
* It is the inverse operation of `StackStates`.
*
* @param states A batched state.
* @return ans[i] contains the state for the i-th utterance.
*/
virtual std::vector<std::vector<Ort::Value>> UnStackStates(
std::vector<Ort::Value> states) const = 0;
/** /**
* *
* @param x A 3-D tensor of shape (N, T, C). N has to be 1. * @param x A 3-D tensor of shape (N, T, C). N has to be 1.
@@ -60,6 +80,9 @@ class OnlineCtcModel {
// ChunkLength() frames, we advance by ChunkShift() frames // ChunkLength() frames, we advance by ChunkShift() frames
// before we process the next chunk. // before we process the next chunk.
virtual int32_t ChunkShift() const = 0; virtual int32_t ChunkShift() const = 0;
// Return true if the model supports batch size > 1
virtual bool SupportBatchProcessing() const { return true; }
}; };
} // namespace sherpa_onnx } // namespace sherpa_onnx

View File

@@ -14,6 +14,7 @@ void OnlineModelConfig::Register(ParseOptions *po) {
transducer.Register(po); transducer.Register(po);
paraformer.Register(po); paraformer.Register(po);
wenet_ctc.Register(po); wenet_ctc.Register(po);
zipformer2_ctc.Register(po);
po->Register("tokens", &tokens, "Path to tokens.txt"); po->Register("tokens", &tokens, "Path to tokens.txt");
@@ -26,10 +27,11 @@ void OnlineModelConfig::Register(ParseOptions *po) {
po->Register("provider", &provider, po->Register("provider", &provider,
"Specify a provider to use: cpu, cuda, coreml"); "Specify a provider to use: cpu, cuda, coreml");
po->Register("model-type", &model_type, po->Register(
"Specify it to reduce model initialization time. " "model-type", &model_type,
"Valid values are: conformer, lstm, zipformer, zipformer2." "Specify it to reduce model initialization time. "
"All other values lead to loading the model twice."); "Valid values are: conformer, lstm, zipformer, zipformer2, wenet_ctc"
"All other values lead to loading the model twice.");
} }
bool OnlineModelConfig::Validate() const { bool OnlineModelConfig::Validate() const {
@@ -51,6 +53,10 @@ bool OnlineModelConfig::Validate() const {
return wenet_ctc.Validate(); return wenet_ctc.Validate();
} }
if (!zipformer2_ctc.model.empty()) {
return zipformer2_ctc.Validate();
}
return transducer.Validate(); return transducer.Validate();
} }
@@ -61,6 +67,7 @@ std::string OnlineModelConfig::ToString() const {
os << "transducer=" << transducer.ToString() << ", "; os << "transducer=" << transducer.ToString() << ", ";
os << "paraformer=" << paraformer.ToString() << ", "; os << "paraformer=" << paraformer.ToString() << ", ";
os << "wenet_ctc=" << wenet_ctc.ToString() << ", "; os << "wenet_ctc=" << wenet_ctc.ToString() << ", ";
os << "zipformer2_ctc=" << zipformer2_ctc.ToString() << ", ";
os << "tokens=\"" << tokens << "\", "; os << "tokens=\"" << tokens << "\", ";
os << "num_threads=" << num_threads << ", "; os << "num_threads=" << num_threads << ", ";
os << "debug=" << (debug ? "True" : "False") << ", "; os << "debug=" << (debug ? "True" : "False") << ", ";

View File

@@ -9,6 +9,7 @@
#include "sherpa-onnx/csrc/online-paraformer-model-config.h" #include "sherpa-onnx/csrc/online-paraformer-model-config.h"
#include "sherpa-onnx/csrc/online-transducer-model-config.h" #include "sherpa-onnx/csrc/online-transducer-model-config.h"
#include "sherpa-onnx/csrc/online-wenet-ctc-model-config.h" #include "sherpa-onnx/csrc/online-wenet-ctc-model-config.h"
#include "sherpa-onnx/csrc/online-zipformer2-ctc-model-config.h"
namespace sherpa_onnx { namespace sherpa_onnx {
@@ -16,6 +17,7 @@ struct OnlineModelConfig {
OnlineTransducerModelConfig transducer; OnlineTransducerModelConfig transducer;
OnlineParaformerModelConfig paraformer; OnlineParaformerModelConfig paraformer;
OnlineWenetCtcModelConfig wenet_ctc; OnlineWenetCtcModelConfig wenet_ctc;
OnlineZipformer2CtcModelConfig zipformer2_ctc;
std::string tokens; std::string tokens;
int32_t num_threads = 1; int32_t num_threads = 1;
bool debug = false; bool debug = false;
@@ -25,7 +27,8 @@ struct OnlineModelConfig {
// - conformer, conformer transducer from icefall // - conformer, conformer transducer from icefall
// - lstm, lstm transducer from icefall // - lstm, lstm transducer from icefall
// - zipformer, zipformer transducer from icefall // - zipformer, zipformer transducer from icefall
// - zipformer2, zipformer2 transducer from icefall // - zipformer2, zipformer2 transducer or CTC from icefall
// - wenet_ctc, wenet CTC model
// //
// All other values are invalid and lead to loading the model twice. // All other values are invalid and lead to loading the model twice.
std::string model_type; std::string model_type;
@@ -34,11 +37,13 @@ struct OnlineModelConfig {
OnlineModelConfig(const OnlineTransducerModelConfig &transducer, OnlineModelConfig(const OnlineTransducerModelConfig &transducer,
const OnlineParaformerModelConfig &paraformer, const OnlineParaformerModelConfig &paraformer,
const OnlineWenetCtcModelConfig &wenet_ctc, const OnlineWenetCtcModelConfig &wenet_ctc,
const OnlineZipformer2CtcModelConfig &zipformer2_ctc,
const std::string &tokens, int32_t num_threads, bool debug, const std::string &tokens, int32_t num_threads, bool debug,
const std::string &provider, const std::string &model_type) const std::string &provider, const std::string &model_type)
: transducer(transducer), : transducer(transducer),
paraformer(paraformer), paraformer(paraformer),
wenet_ctc(wenet_ctc), wenet_ctc(wenet_ctc),
zipformer2_ctc(zipformer2_ctc),
tokens(tokens), tokens(tokens),
num_threads(num_threads), num_threads(num_threads),
debug(debug), debug(debug),

View File

@@ -96,8 +96,67 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl {
} }
void DecodeStreams(OnlineStream **ss, int32_t n) const override { void DecodeStreams(OnlineStream **ss, int32_t n) const override {
if (n == 1 || !model_->SupportBatchProcessing()) {
for (int32_t i = 0; i != n; ++i) {
DecodeStream(ss[i]);
}
return;
}
// batch processing
int32_t chunk_length = model_->ChunkLength();
int32_t chunk_shift = model_->ChunkShift();
int32_t feat_dim = ss[0]->FeatureDim();
std::vector<OnlineCtcDecoderResult> results(n);
std::vector<float> features_vec(n * chunk_length * feat_dim);
std::vector<std::vector<Ort::Value>> states_vec(n);
std::vector<int64_t> all_processed_frames(n);
for (int32_t i = 0; i != n; ++i) { for (int32_t i = 0; i != n; ++i) {
DecodeStream(ss[i]); const auto num_processed_frames = ss[i]->GetNumProcessedFrames();
std::vector<float> features =
ss[i]->GetFrames(num_processed_frames, chunk_length);
// Question: should num_processed_frames include chunk_shift?
ss[i]->GetNumProcessedFrames() += chunk_shift;
std::copy(features.begin(), features.end(),
features_vec.data() + i * chunk_length * feat_dim);
results[i] = std::move(ss[i]->GetCtcResult());
states_vec[i] = std::move(ss[i]->GetStates());
all_processed_frames[i] = num_processed_frames;
}
auto memory_info =
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
std::array<int64_t, 3> x_shape{n, chunk_length, feat_dim};
Ort::Value x = Ort::Value::CreateTensor(memory_info, features_vec.data(),
features_vec.size(), x_shape.data(),
x_shape.size());
auto states = model_->StackStates(std::move(states_vec));
int32_t num_states = states.size();
auto out = model_->Forward(std::move(x), std::move(states));
std::vector<Ort::Value> out_states;
out_states.reserve(num_states);
for (int32_t k = 1; k != num_states + 1; ++k) {
out_states.push_back(std::move(out[k]));
}
std::vector<std::vector<Ort::Value>> next_states =
model_->UnStackStates(std::move(out_states));
decoder_->Decode(std::move(out[0]), &results);
for (int32_t k = 0; k != n; ++k) {
ss[k]->SetCtcResult(results[k]);
ss[k]->SetStates(std::move(next_states[k]));
} }
} }

View File

@@ -20,7 +20,8 @@ std::unique_ptr<OnlineRecognizerImpl> OnlineRecognizerImpl::Create(
return std::make_unique<OnlineRecognizerParaformerImpl>(config); return std::make_unique<OnlineRecognizerParaformerImpl>(config);
} }
if (!config.model_config.wenet_ctc.model.empty()) { if (!config.model_config.wenet_ctc.model.empty() ||
!config.model_config.zipformer2_ctc.model.empty()) {
return std::make_unique<OnlineRecognizerCtcImpl>(config); return std::make_unique<OnlineRecognizerCtcImpl>(config);
} }
@@ -39,7 +40,8 @@ std::unique_ptr<OnlineRecognizerImpl> OnlineRecognizerImpl::Create(
return std::make_unique<OnlineRecognizerParaformerImpl>(mgr, config); return std::make_unique<OnlineRecognizerParaformerImpl>(mgr, config);
} }
if (!config.model_config.wenet_ctc.model.empty()) { if (!config.model_config.wenet_ctc.model.empty() ||
!config.model_config.zipformer2_ctc.model.empty()) {
return std::make_unique<OnlineRecognizerCtcImpl>(mgr, config); return std::make_unique<OnlineRecognizerCtcImpl>(mgr, config);
} }

View File

@@ -1,4 +1,4 @@
// sherpa-onnx/csrc/online-paraformer-model.cc // sherpa-onnx/csrc/online-wenet-ctc-model.cc
// //
// Copyright (c) 2023 Xiaomi Corporation // Copyright (c) 2023 Xiaomi Corporation
@@ -239,4 +239,21 @@ std::vector<Ort::Value> OnlineWenetCtcModel::GetInitStates() const {
return impl_->GetInitStates(); return impl_->GetInitStates();
} }
std::vector<Ort::Value> OnlineWenetCtcModel::StackStates(
std::vector<std::vector<Ort::Value>> states) const {
if (states.size() != 1) {
SHERPA_ONNX_LOGE("wenet CTC model supports only batch_size==1. Given: %d",
static_cast<int32_t>(states.size()));
}
return std::move(states[0]);
}
std::vector<std::vector<Ort::Value>> OnlineWenetCtcModel::UnStackStates(
std::vector<Ort::Value> states) const {
std::vector<std::vector<Ort::Value>> ans(1);
ans[0] = std::move(states);
return ans;
}
} // namespace sherpa_onnx } // namespace sherpa_onnx

View File

@@ -35,6 +35,12 @@ class OnlineWenetCtcModel : public OnlineCtcModel {
// - offset // - offset
std::vector<Ort::Value> GetInitStates() const override; std::vector<Ort::Value> GetInitStates() const override;
std::vector<Ort::Value> StackStates(
std::vector<std::vector<Ort::Value>> states) const override;
std::vector<std::vector<Ort::Value>> UnStackStates(
std::vector<Ort::Value> states) const override;
/** /**
* *
* @param x A 3-D tensor of shape (N, T, C). N has to be 1. * @param x A 3-D tensor of shape (N, T, C). N has to be 1.
@@ -63,6 +69,8 @@ class OnlineWenetCtcModel : public OnlineCtcModel {
// before we process the next chunk. // before we process the next chunk.
int32_t ChunkShift() const override; int32_t ChunkShift() const override;
bool SupportBatchProcessing() const override { return false; }
private: private:
class Impl; class Impl;
std::unique_ptr<Impl> impl_; std::unique_ptr<Impl> impl_;

View File

@@ -0,0 +1,41 @@
// sherpa-onnx/csrc/online-zipformer2-ctc-model-config.cc
//
// Copyright (c) 2023 Xiaomi Corporation
#include "sherpa-onnx/csrc/online-zipformer2-ctc-model-config.h"
#include "sherpa-onnx/csrc/file-utils.h"
#include "sherpa-onnx/csrc/macros.h"
namespace sherpa_onnx {
void OnlineZipformer2CtcModelConfig::Register(ParseOptions *po) {
po->Register("zipformer2-ctc-model", &model,
"Path to CTC model.onnx. See also "
"https://github.com/k2-fsa/icefall/pull/1413");
}
bool OnlineZipformer2CtcModelConfig::Validate() const {
if (model.empty()) {
SHERPA_ONNX_LOGE("--zipformer2-ctc-model is empty!");
return false;
}
if (!FileExists(model)) {
SHERPA_ONNX_LOGE("--zipformer2-ctc-model %s does not exist", model.c_str());
return false;
}
return true;
}
std::string OnlineZipformer2CtcModelConfig::ToString() const {
std::ostringstream os;
os << "OnlineZipformer2CtcModelConfig(";
os << "model=\"" << model << "\")";
return os.str();
}
} // namespace sherpa_onnx

View File

@@ -0,0 +1,29 @@
// sherpa-onnx/csrc/online-zipformer2-ctc-model-config.h
//
// Copyright (c) 2023 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_ONLINE_ZIPFORMER2_CTC_MODEL_CONFIG_H_
#define SHERPA_ONNX_CSRC_ONLINE_ZIPFORMER2_CTC_MODEL_CONFIG_H_
#include <string>
#include "sherpa-onnx/csrc/parse-options.h"
namespace sherpa_onnx {
struct OnlineZipformer2CtcModelConfig {
std::string model;
OnlineZipformer2CtcModelConfig() = default;
explicit OnlineZipformer2CtcModelConfig(const std::string &model)
: model(model) {}
void Register(ParseOptions *po);
bool Validate() const;
std::string ToString() const;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_ONLINE_ZIPFORMER2_CTC_MODEL_CONFIG_H_

View File

@@ -0,0 +1,464 @@
// sherpa-onnx/csrc/online-zipformer2-ctc-model.cc
//
// Copyright (c) 2023 Xiaomi Corporation
#include "sherpa-onnx/csrc/online-zipformer2-ctc-model.h"
#include <assert.h>
#include <math.h>
#include <algorithm>
#include <cmath>
#include <numeric>
#include <string>
#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif
#include "sherpa-onnx/csrc/cat.h"
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
#include "sherpa-onnx/csrc/session.h"
#include "sherpa-onnx/csrc/text-utils.h"
#include "sherpa-onnx/csrc/unbind.h"
namespace sherpa_onnx {
class OnlineZipformer2CtcModel::Impl {
public:
explicit Impl(const OnlineModelConfig &config)
: config_(config),
env_(ORT_LOGGING_LEVEL_ERROR),
sess_opts_(GetSessionOptions(config)),
allocator_{} {
{
auto buf = ReadFile(config.zipformer2_ctc.model);
Init(buf.data(), buf.size());
}
}
#if __ANDROID_API__ >= 9
Impl(AAssetManager *mgr, const OnlineModelConfig &config)
: config_(config),
env_(ORT_LOGGING_LEVEL_WARNING),
sess_opts_(GetSessionOptions(config)),
allocator_{} {
{
auto buf = ReadFile(mgr, config.zipformer2_ctc.model);
Init(buf.data(), buf.size());
}
}
#endif
std::vector<Ort::Value> Forward(Ort::Value features,
std::vector<Ort::Value> states) {
std::vector<Ort::Value> inputs;
inputs.reserve(1 + states.size());
inputs.push_back(std::move(features));
for (auto &v : states) {
inputs.push_back(std::move(v));
}
return sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(),
output_names_ptr_.data(), output_names_ptr_.size());
}
int32_t VocabSize() const { return vocab_size_; }
int32_t ChunkLength() const { return T_; }
int32_t ChunkShift() const { return decode_chunk_len_; }
OrtAllocator *Allocator() const { return allocator_; }
// Return a vector containing 3 tensors
// - attn_cache
// - conv_cache
// - offset
std::vector<Ort::Value> GetInitStates() {
std::vector<Ort::Value> ans;
ans.reserve(initial_states_.size());
for (auto &s : initial_states_) {
ans.push_back(View(&s));
}
return ans;
}
std::vector<Ort::Value> StackStates(
std::vector<std::vector<Ort::Value>> states) const {
int32_t batch_size = static_cast<int32_t>(states.size());
int32_t num_encoders = static_cast<int32_t>(num_encoder_layers_.size());
std::vector<const Ort::Value *> buf(batch_size);
std::vector<Ort::Value> ans;
int32_t num_states = static_cast<int32_t>(states[0].size());
ans.reserve(num_states);
for (int32_t i = 0; i != (num_states - 2) / 6; ++i) {
{
for (int32_t n = 0; n != batch_size; ++n) {
buf[n] = &states[n][6 * i];
}
auto v = Cat(allocator_, buf, 1);
ans.push_back(std::move(v));
}
{
for (int32_t n = 0; n != batch_size; ++n) {
buf[n] = &states[n][6 * i + 1];
}
auto v = Cat(allocator_, buf, 1);
ans.push_back(std::move(v));
}
{
for (int32_t n = 0; n != batch_size; ++n) {
buf[n] = &states[n][6 * i + 2];
}
auto v = Cat(allocator_, buf, 1);
ans.push_back(std::move(v));
}
{
for (int32_t n = 0; n != batch_size; ++n) {
buf[n] = &states[n][6 * i + 3];
}
auto v = Cat(allocator_, buf, 1);
ans.push_back(std::move(v));
}
{
for (int32_t n = 0; n != batch_size; ++n) {
buf[n] = &states[n][6 * i + 4];
}
auto v = Cat(allocator_, buf, 0);
ans.push_back(std::move(v));
}
{
for (int32_t n = 0; n != batch_size; ++n) {
buf[n] = &states[n][6 * i + 5];
}
auto v = Cat(allocator_, buf, 0);
ans.push_back(std::move(v));
}
}
{
for (int32_t n = 0; n != batch_size; ++n) {
buf[n] = &states[n][num_states - 2];
}
auto v = Cat(allocator_, buf, 0);
ans.push_back(std::move(v));
}
{
for (int32_t n = 0; n != batch_size; ++n) {
buf[n] = &states[n][num_states - 1];
}
auto v = Cat<int64_t>(allocator_, buf, 0);
ans.push_back(std::move(v));
}
return ans;
}
std::vector<std::vector<Ort::Value>> UnStackStates(
std::vector<Ort::Value> states) const {
int32_t m = std::accumulate(num_encoder_layers_.begin(),
num_encoder_layers_.end(), 0);
assert(states.size() == m * 6 + 2);
int32_t batch_size = states[0].GetTensorTypeAndShapeInfo().GetShape()[1];
int32_t num_encoders = num_encoder_layers_.size();
std::vector<std::vector<Ort::Value>> ans;
ans.resize(batch_size);
for (int32_t i = 0; i != m; ++i) {
{
auto v = Unbind(allocator_, &states[i * 6], 1);
assert(v.size() == batch_size);
for (int32_t n = 0; n != batch_size; ++n) {
ans[n].push_back(std::move(v[n]));
}
}
{
auto v = Unbind(allocator_, &states[i * 6 + 1], 1);
assert(v.size() == batch_size);
for (int32_t n = 0; n != batch_size; ++n) {
ans[n].push_back(std::move(v[n]));
}
}
{
auto v = Unbind(allocator_, &states[i * 6 + 2], 1);
assert(v.size() == batch_size);
for (int32_t n = 0; n != batch_size; ++n) {
ans[n].push_back(std::move(v[n]));
}
}
{
auto v = Unbind(allocator_, &states[i * 6 + 3], 1);
assert(v.size() == batch_size);
for (int32_t n = 0; n != batch_size; ++n) {
ans[n].push_back(std::move(v[n]));
}
}
{
auto v = Unbind(allocator_, &states[i * 6 + 4], 0);
assert(v.size() == batch_size);
for (int32_t n = 0; n != batch_size; ++n) {
ans[n].push_back(std::move(v[n]));
}
}
{
auto v = Unbind(allocator_, &states[i * 6 + 5], 0);
assert(v.size() == batch_size);
for (int32_t n = 0; n != batch_size; ++n) {
ans[n].push_back(std::move(v[n]));
}
}
}
{
auto v = Unbind(allocator_, &states[m * 6], 0);
assert(v.size() == batch_size);
for (int32_t n = 0; n != batch_size; ++n) {
ans[n].push_back(std::move(v[n]));
}
}
{
auto v = Unbind<int64_t>(allocator_, &states[m * 6 + 1], 0);
assert(v.size() == batch_size);
for (int32_t n = 0; n != batch_size; ++n) {
ans[n].push_back(std::move(v[n]));
}
}
return ans;
}
private:
void Init(void *model_data, size_t model_data_length) {
sess_ = std::make_unique<Ort::Session>(env_, model_data, model_data_length,
sess_opts_);
GetInputNames(sess_.get(), &input_names_, &input_names_ptr_);
GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_);
// get meta data
Ort::ModelMetadata meta_data = sess_->GetModelMetadata();
if (config_.debug) {
std::ostringstream os;
os << "---zipformer2_ctc---\n";
PrintModelMetadata(os, meta_data);
SHERPA_ONNX_LOGE("%s", os.str().c_str());
}
Ort::AllocatorWithDefaultOptions allocator; // used in the macro below
SHERPA_ONNX_READ_META_DATA_VEC(encoder_dims_, "encoder_dims");
SHERPA_ONNX_READ_META_DATA_VEC(query_head_dims_, "query_head_dims");
SHERPA_ONNX_READ_META_DATA_VEC(value_head_dims_, "value_head_dims");
SHERPA_ONNX_READ_META_DATA_VEC(num_heads_, "num_heads");
SHERPA_ONNX_READ_META_DATA_VEC(num_encoder_layers_, "num_encoder_layers");
SHERPA_ONNX_READ_META_DATA_VEC(cnn_module_kernels_, "cnn_module_kernels");
SHERPA_ONNX_READ_META_DATA_VEC(left_context_len_, "left_context_len");
SHERPA_ONNX_READ_META_DATA(T_, "T");
SHERPA_ONNX_READ_META_DATA(decode_chunk_len_, "decode_chunk_len");
{
auto shape =
sess_->GetOutputTypeInfo(0).GetTensorTypeAndShapeInfo().GetShape();
vocab_size_ = shape[2];
}
if (config_.debug) {
auto print = [](const std::vector<int32_t> &v, const char *name) {
fprintf(stderr, "%s: ", name);
for (auto i : v) {
fprintf(stderr, "%d ", i);
}
fprintf(stderr, "\n");
};
print(encoder_dims_, "encoder_dims");
print(query_head_dims_, "query_head_dims");
print(value_head_dims_, "value_head_dims");
print(num_heads_, "num_heads");
print(num_encoder_layers_, "num_encoder_layers");
print(cnn_module_kernels_, "cnn_module_kernels");
print(left_context_len_, "left_context_len");
SHERPA_ONNX_LOGE("T: %d", T_);
SHERPA_ONNX_LOGE("decode_chunk_len_: %d", decode_chunk_len_);
SHERPA_ONNX_LOGE("vocab_size_: %d", vocab_size_);
}
InitStates();
}
void InitStates() {
int32_t n = static_cast<int32_t>(encoder_dims_.size());
int32_t m = std::accumulate(num_encoder_layers_.begin(),
num_encoder_layers_.end(), 0);
initial_states_.reserve(m * 6 + 2);
for (int32_t i = 0; i != n; ++i) {
int32_t num_layers = num_encoder_layers_[i];
int32_t key_dim = query_head_dims_[i] * num_heads_[i];
int32_t value_dim = value_head_dims_[i] * num_heads_[i];
int32_t nonlin_attn_head_dim = 3 * encoder_dims_[i] / 4;
for (int32_t j = 0; j != num_layers; ++j) {
{
std::array<int64_t, 3> s{left_context_len_[i], 1, key_dim};
auto v =
Ort::Value::CreateTensor<float>(allocator_, s.data(), s.size());
Fill(&v, 0);
initial_states_.push_back(std::move(v));
}
{
std::array<int64_t, 4> s{1, 1, left_context_len_[i],
nonlin_attn_head_dim};
auto v =
Ort::Value::CreateTensor<float>(allocator_, s.data(), s.size());
Fill(&v, 0);
initial_states_.push_back(std::move(v));
}
{
std::array<int64_t, 3> s{left_context_len_[i], 1, value_dim};
auto v =
Ort::Value::CreateTensor<float>(allocator_, s.data(), s.size());
Fill(&v, 0);
initial_states_.push_back(std::move(v));
}
{
std::array<int64_t, 3> s{left_context_len_[i], 1, value_dim};
auto v =
Ort::Value::CreateTensor<float>(allocator_, s.data(), s.size());
Fill(&v, 0);
initial_states_.push_back(std::move(v));
}
{
std::array<int64_t, 3> s{1, encoder_dims_[i],
cnn_module_kernels_[i] / 2};
auto v =
Ort::Value::CreateTensor<float>(allocator_, s.data(), s.size());
Fill(&v, 0);
initial_states_.push_back(std::move(v));
}
{
std::array<int64_t, 3> s{1, encoder_dims_[i],
cnn_module_kernels_[i] / 2};
auto v =
Ort::Value::CreateTensor<float>(allocator_, s.data(), s.size());
Fill(&v, 0);
initial_states_.push_back(std::move(v));
}
}
}
{
std::array<int64_t, 4> s{1, 128, 3, 19};
auto v = Ort::Value::CreateTensor<float>(allocator_, s.data(), s.size());
Fill(&v, 0);
initial_states_.push_back(std::move(v));
}
{
std::array<int64_t, 1> s{1};
auto v =
Ort::Value::CreateTensor<int64_t>(allocator_, s.data(), s.size());
Fill<int64_t>(&v, 0);
initial_states_.push_back(std::move(v));
}
}
private:
OnlineModelConfig config_;
Ort::Env env_;
Ort::SessionOptions sess_opts_;
Ort::AllocatorWithDefaultOptions allocator_;
std::unique_ptr<Ort::Session> sess_;
std::vector<std::string> input_names_;
std::vector<const char *> input_names_ptr_;
std::vector<std::string> output_names_;
std::vector<const char *> output_names_ptr_;
std::vector<Ort::Value> initial_states_;
std::vector<int32_t> encoder_dims_;
std::vector<int32_t> query_head_dims_;
std::vector<int32_t> value_head_dims_;
std::vector<int32_t> num_heads_;
std::vector<int32_t> num_encoder_layers_;
std::vector<int32_t> cnn_module_kernels_;
std::vector<int32_t> left_context_len_;
int32_t T_ = 0;
int32_t decode_chunk_len_ = 0;
int32_t vocab_size_ = 0;
};
OnlineZipformer2CtcModel::OnlineZipformer2CtcModel(
const OnlineModelConfig &config)
: impl_(std::make_unique<Impl>(config)) {}
#if __ANDROID_API__ >= 9
OnlineZipformer2CtcModel::OnlineZipformer2CtcModel(
AAssetManager *mgr, const OnlineModelConfig &config)
: impl_(std::make_unique<Impl>(mgr, config)) {}
#endif
OnlineZipformer2CtcModel::~OnlineZipformer2CtcModel() = default;
std::vector<Ort::Value> OnlineZipformer2CtcModel::Forward(
Ort::Value x, std::vector<Ort::Value> states) const {
return impl_->Forward(std::move(x), std::move(states));
}
int32_t OnlineZipformer2CtcModel::VocabSize() const {
return impl_->VocabSize();
}
int32_t OnlineZipformer2CtcModel::ChunkLength() const {
return impl_->ChunkLength();
}
int32_t OnlineZipformer2CtcModel::ChunkShift() const {
return impl_->ChunkShift();
}
OrtAllocator *OnlineZipformer2CtcModel::Allocator() const {
return impl_->Allocator();
}
std::vector<Ort::Value> OnlineZipformer2CtcModel::GetInitStates() const {
return impl_->GetInitStates();
}
std::vector<Ort::Value> OnlineZipformer2CtcModel::StackStates(
std::vector<std::vector<Ort::Value>> states) const {
return impl_->StackStates(std::move(states));
}
std::vector<std::vector<Ort::Value>> OnlineZipformer2CtcModel::UnStackStates(
std::vector<Ort::Value> states) const {
return impl_->UnStackStates(std::move(states));
}
} // namespace sherpa_onnx

View File

@@ -0,0 +1,80 @@
// sherpa-onnx/csrc/online-zipformer2-ctc-model.h
//
// Copyright (c) 2023 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_ONLINE_ZIPFORMER2_CTC_MODEL_H_
#define SHERPA_ONNX_CSRC_ONLINE_ZIPFORMER2_CTC_MODEL_H_
#include <memory>
#include <utility>
#include <vector>
#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif
#include "onnxruntime_cxx_api.h" // NOLINT
#include "sherpa-onnx/csrc/online-ctc-model.h"
#include "sherpa-onnx/csrc/online-model-config.h"
namespace sherpa_onnx {
class OnlineZipformer2CtcModel : public OnlineCtcModel {
public:
explicit OnlineZipformer2CtcModel(const OnlineModelConfig &config);
#if __ANDROID_API__ >= 9
OnlineZipformer2CtcModel(AAssetManager *mgr, const OnlineModelConfig &config);
#endif
~OnlineZipformer2CtcModel() override;
// A list of tensors.
// See also
// https://github.com/k2-fsa/icefall/pull/1413
// and
// https://github.com/k2-fsa/icefall/pull/1415
std::vector<Ort::Value> GetInitStates() const override;
std::vector<Ort::Value> StackStates(
std::vector<std::vector<Ort::Value>> states) const override;
std::vector<std::vector<Ort::Value>> UnStackStates(
std::vector<Ort::Value> states) const override;
/**
*
* @param x A 3-D tensor of shape (N, T, C). N has to be 1.
* @param states It is from GetInitStates() or returned from this method.
*
* @return Return a list of tensors
* - ans[0] contains log_probs, of shape (N, T, C)
* - ans[1:] contains next_states
*/
std::vector<Ort::Value> Forward(
Ort::Value x, std::vector<Ort::Value> states) const override;
/** Return the vocabulary size of the model
*/
int32_t VocabSize() const override;
/** Return an allocator for allocating memory
*/
OrtAllocator *Allocator() const override;
// The model accepts this number of frames before subsampling as input
int32_t ChunkLength() const override;
// Similar to frame_shift in feature extractor, after processing
// ChunkLength() frames, we advance by ChunkShift() frames
// before we process the next chunk.
int32_t ChunkShift() const override;
private:
class Impl;
std::unique_ptr<Impl> impl_;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_ONLINE_ZIPFORMER2_CTC_MODEL_H_

View File

@@ -26,6 +26,8 @@ int main(int32_t argc, char *argv[]) {
const char *kUsageMessage = R"usage( const char *kUsageMessage = R"usage(
Usage: Usage:
(1) Streaming transducer
./bin/sherpa-onnx \ ./bin/sherpa-onnx \
--tokens=/path/to/tokens.txt \ --tokens=/path/to/tokens.txt \
--encoder=/path/to/encoder.onnx \ --encoder=/path/to/encoder.onnx \
@@ -36,6 +38,30 @@ Usage:
--decoding-method=greedy_search \ --decoding-method=greedy_search \
/path/to/foo.wav [bar.wav foobar.wav ...] /path/to/foo.wav [bar.wav foobar.wav ...]
(2) Streaming zipformer2 CTC
wget -q https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2
tar xvf sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2
./bin/sherpa-onnx \
--debug=1 \
--zipformer2-ctc-model=./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/ctc-epoch-20-avg-1-chunk-16-left-128.int8.onnx \
--tokens=./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/tokens.txt \
./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/test_wavs/DEV_T0000000000.wav \
./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/test_wavs/DEV_T0000000001.wav \
./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/test_wavs/DEV_T0000000002.wav
(3) Streaming paraformer
wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-paraformer-bilingual-zh-en.tar.bz2
tar xvf sherpa-onnx-streaming-paraformer-bilingual-zh-en.tar.bz2
./bin/sherpa-onnx \
--tokens=./sherpa-onnx-streaming-paraformer-bilingual-zh-en/tokens.txt \
--paraformer-encoder=./sherpa-onnx-streaming-paraformer-bilingual-zh-en/encoder.onnx \
--paraformer-decoder=./sherpa-onnx-streaming-paraformer-bilingual-zh-en/decoder.onnx \
./sherpa-onnx-streaming-paraformer-bilingual-zh-en/test_wavs/0.wav
Note: It supports decoding multiple files in batches Note: It supports decoding multiple files in batches
Default value for num_threads is 2. Default value for num_threads is 2.

View File

@@ -8,9 +8,6 @@
#include <fstream> #include <fstream>
#include <sstream> #include <sstream>
#include "sherpa-onnx/csrc/base64-decode.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
#if __ANDROID_API__ >= 9 #if __ANDROID_API__ >= 9
#include <strstream> #include <strstream>
@@ -18,6 +15,9 @@
#include "android/asset_manager_jni.h" #include "android/asset_manager_jni.h"
#endif #endif
#include "sherpa-onnx/csrc/base64-decode.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
namespace sherpa_onnx { namespace sherpa_onnx {
SymbolTable::SymbolTable(const std::string &filename) { SymbolTable::SymbolTable(const std::string &filename) {

View File

@@ -262,22 +262,34 @@ static OnlineRecognizerConfig GetConfig(JNIEnv *env, jobject config) {
fid = env->GetFieldID(model_config_cls, "paraformer", fid = env->GetFieldID(model_config_cls, "paraformer",
"Lcom/k2fsa/sherpa/onnx/OnlineParaformerModelConfig;"); "Lcom/k2fsa/sherpa/onnx/OnlineParaformerModelConfig;");
jobject paraformer_config = env->GetObjectField(model_config, fid); jobject paraformer_config = env->GetObjectField(model_config, fid);
jclass paraformer_config_config_cls = env->GetObjectClass(paraformer_config); jclass paraformer_config_cls = env->GetObjectClass(paraformer_config);
fid = env->GetFieldID(paraformer_config_config_cls, "encoder", fid = env->GetFieldID(paraformer_config_cls, "encoder", "Ljava/lang/String;");
"Ljava/lang/String;");
s = (jstring)env->GetObjectField(paraformer_config, fid); s = (jstring)env->GetObjectField(paraformer_config, fid);
p = env->GetStringUTFChars(s, nullptr); p = env->GetStringUTFChars(s, nullptr);
ans.model_config.paraformer.encoder = p; ans.model_config.paraformer.encoder = p;
env->ReleaseStringUTFChars(s, p); env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(paraformer_config_config_cls, "decoder", fid = env->GetFieldID(paraformer_config_cls, "decoder", "Ljava/lang/String;");
"Ljava/lang/String;");
s = (jstring)env->GetObjectField(paraformer_config, fid); s = (jstring)env->GetObjectField(paraformer_config, fid);
p = env->GetStringUTFChars(s, nullptr); p = env->GetStringUTFChars(s, nullptr);
ans.model_config.paraformer.decoder = p; ans.model_config.paraformer.decoder = p;
env->ReleaseStringUTFChars(s, p); env->ReleaseStringUTFChars(s, p);
// streaming zipformer2 CTC
fid =
env->GetFieldID(model_config_cls, "zipformer2Ctc",
"Lcom/k2fsa/sherpa/onnx/OnlineZipformer2CtcModelConfig;");
jobject zipformer2_ctc_config = env->GetObjectField(model_config, fid);
jclass zipformer2_ctc_config_cls = env->GetObjectClass(zipformer2_ctc_config);
fid =
env->GetFieldID(zipformer2_ctc_config_cls, "model", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(zipformer2_ctc_config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model_config.zipformer2_ctc.model = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(model_config_cls, "tokens", "Ljava/lang/String;"); fid = env->GetFieldID(model_config_cls, "tokens", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(model_config, fid); s = (jstring)env->GetObjectField(model_config, fid);
p = env->GetStringUTFChars(s, nullptr); p = env->GetStringUTFChars(s, nullptr);

View File

@@ -27,6 +27,7 @@ pybind11_add_module(_sherpa_onnx
online-stream.cc online-stream.cc
online-transducer-model-config.cc online-transducer-model-config.cc
online-wenet-ctc-model-config.cc online-wenet-ctc-model-config.cc
online-zipformer2-ctc-model-config.cc
sherpa-onnx.cc sherpa-onnx.cc
silero-vad-model-config.cc silero-vad-model-config.cc
vad-model-config.cc vad-model-config.cc

View File

@@ -58,6 +58,7 @@ void PybindOfflineModelConfig(py::module *m) {
.def_readwrite("debug", &PyClass::debug) .def_readwrite("debug", &PyClass::debug)
.def_readwrite("provider", &PyClass::provider) .def_readwrite("provider", &PyClass::provider)
.def_readwrite("model_type", &PyClass::model_type) .def_readwrite("model_type", &PyClass::model_type)
.def("validate", &PyClass::Validate)
.def("__str__", &PyClass::ToString); .def("__str__", &PyClass::ToString);
} }

View File

@@ -12,6 +12,7 @@
#include "sherpa-onnx/python/csrc/online-paraformer-model-config.h" #include "sherpa-onnx/python/csrc/online-paraformer-model-config.h"
#include "sherpa-onnx/python/csrc/online-transducer-model-config.h" #include "sherpa-onnx/python/csrc/online-transducer-model-config.h"
#include "sherpa-onnx/python/csrc/online-wenet-ctc-model-config.h" #include "sherpa-onnx/python/csrc/online-wenet-ctc-model-config.h"
#include "sherpa-onnx/python/csrc/online-zipformer2-ctc-model-config.h"
namespace sherpa_onnx { namespace sherpa_onnx {
@@ -19,26 +20,31 @@ void PybindOnlineModelConfig(py::module *m) {
PybindOnlineTransducerModelConfig(m); PybindOnlineTransducerModelConfig(m);
PybindOnlineParaformerModelConfig(m); PybindOnlineParaformerModelConfig(m);
PybindOnlineWenetCtcModelConfig(m); PybindOnlineWenetCtcModelConfig(m);
PybindOnlineZipformer2CtcModelConfig(m);
using PyClass = OnlineModelConfig; using PyClass = OnlineModelConfig;
py::class_<PyClass>(*m, "OnlineModelConfig") py::class_<PyClass>(*m, "OnlineModelConfig")
.def(py::init<const OnlineTransducerModelConfig &, .def(py::init<const OnlineTransducerModelConfig &,
const OnlineParaformerModelConfig &, const OnlineParaformerModelConfig &,
const OnlineWenetCtcModelConfig &, const std::string &, const OnlineWenetCtcModelConfig &,
const OnlineZipformer2CtcModelConfig &, const std::string &,
int32_t, bool, const std::string &, const std::string &>(), int32_t, bool, const std::string &, const std::string &>(),
py::arg("transducer") = OnlineTransducerModelConfig(), py::arg("transducer") = OnlineTransducerModelConfig(),
py::arg("paraformer") = OnlineParaformerModelConfig(), py::arg("paraformer") = OnlineParaformerModelConfig(),
py::arg("wenet_ctc") = OnlineWenetCtcModelConfig(), py::arg("wenet_ctc") = OnlineWenetCtcModelConfig(),
py::arg("zipformer2_ctc") = OnlineZipformer2CtcModelConfig(),
py::arg("tokens"), py::arg("num_threads"), py::arg("debug") = false, py::arg("tokens"), py::arg("num_threads"), py::arg("debug") = false,
py::arg("provider") = "cpu", py::arg("model_type") = "") py::arg("provider") = "cpu", py::arg("model_type") = "")
.def_readwrite("transducer", &PyClass::transducer) .def_readwrite("transducer", &PyClass::transducer)
.def_readwrite("paraformer", &PyClass::paraformer) .def_readwrite("paraformer", &PyClass::paraformer)
.def_readwrite("wenet_ctc", &PyClass::wenet_ctc) .def_readwrite("wenet_ctc", &PyClass::wenet_ctc)
.def_readwrite("zipformer2_ctc", &PyClass::zipformer2_ctc)
.def_readwrite("tokens", &PyClass::tokens) .def_readwrite("tokens", &PyClass::tokens)
.def_readwrite("num_threads", &PyClass::num_threads) .def_readwrite("num_threads", &PyClass::num_threads)
.def_readwrite("debug", &PyClass::debug) .def_readwrite("debug", &PyClass::debug)
.def_readwrite("provider", &PyClass::provider) .def_readwrite("provider", &PyClass::provider)
.def_readwrite("model_type", &PyClass::model_type) .def_readwrite("model_type", &PyClass::model_type)
.def("validate", &PyClass::Validate)
.def("__str__", &PyClass::ToString); .def("__str__", &PyClass::ToString);
} }

View File

@@ -0,0 +1,22 @@
// sherpa-onnx/python/csrc/online-zipformer2-ctc-model-config.cc
//
// Copyright (c) 2023 Xiaomi Corporation
#include "sherpa-onnx/python/csrc/online-zipformer2-ctc-model-config.h"
#include <string>
#include <vector>
#include "sherpa-onnx/csrc/online-zipformer2-ctc-model-config.h"
namespace sherpa_onnx {
void PybindOnlineZipformer2CtcModelConfig(py::module *m) {
using PyClass = OnlineZipformer2CtcModelConfig;
py::class_<PyClass>(*m, "OnlineZipformer2CtcModelConfig")
.def(py::init<const std::string &>(), py::arg("model"))
.def_readwrite("model", &PyClass::model)
.def("__str__", &PyClass::ToString);
}
} // namespace sherpa_onnx

View File

@@ -0,0 +1,16 @@
// sherpa-onnx/python/csrc/online-zipformer2-ctc-model-config.h
//
// Copyright (c) 2023 Xiaomi Corporation
#ifndef SHERPA_ONNX_PYTHON_CSRC_ONLINE_ZIPFORMER2_CTC_MODEL_CONFIG_H_
#define SHERPA_ONNX_PYTHON_CSRC_ONLINE_ZIPFORMER2_CTC_MODEL_CONFIG_H_
#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
namespace sherpa_onnx {
void PybindOnlineZipformer2CtcModelConfig(py::module *m);
}
#endif // SHERPA_ONNX_PYTHON_CSRC_ONLINE_ZIPFORMER2_CTC_MODEL_CONFIG_H_

View File

@@ -8,11 +8,14 @@ from _sherpa_onnx import (
OnlineLMConfig, OnlineLMConfig,
OnlineModelConfig, OnlineModelConfig,
OnlineParaformerModelConfig, OnlineParaformerModelConfig,
OnlineRecognizer as _Recognizer, )
from _sherpa_onnx import OnlineRecognizer as _Recognizer
from _sherpa_onnx import (
OnlineRecognizerConfig, OnlineRecognizerConfig,
OnlineStream, OnlineStream,
OnlineTransducerModelConfig, OnlineTransducerModelConfig,
OnlineWenetCtcModelConfig, OnlineWenetCtcModelConfig,
OnlineZipformer2CtcModelConfig,
) )
@@ -272,6 +275,101 @@ class OnlineRecognizer(object):
self.config = recognizer_config self.config = recognizer_config
return self return self
@classmethod
def from_zipformer2_ctc(
    cls,
    tokens: str,
    model: str,
    num_threads: int = 2,
    sample_rate: float = 16000,
    feature_dim: int = 80,
    enable_endpoint_detection: bool = False,
    rule1_min_trailing_silence: float = 2.4,
    rule2_min_trailing_silence: float = 1.2,
    rule3_min_utterance_length: float = 20.0,
    decoding_method: str = "greedy_search",
    provider: str = "cpu",
):
    """Create an online recognizer backed by a streaming zipformer2 CTC model.

    Pre-trained models for different languages (Chinese, English, etc.) can
    be downloaded from
    `<https://k2-fsa.github.io/sherpa/onnx/pretrained_models/online-ctc/index.html>`_

    Args:
      tokens:
        Path to ``tokens.txt``. Each line contains two columns::

            symbol integer_id
      model:
        Path to ``model.onnx``.
      num_threads:
        Number of threads for neural network computation.
      sample_rate:
        Sample rate of the training data used to train the model.
      feature_dim:
        Dimension of the feature used to train the model.
      enable_endpoint_detection:
        True to enable endpoint detection. False to disable endpoint
        detection.
      rule1_min_trailing_silence:
        Used only when enable_endpoint_detection is True. An endpoint is
        assumed once the trailing silence exceeds this many seconds.
      rule2_min_trailing_silence:
        Used only when enable_endpoint_detection is True. After something
        that is nonsilence has been decoded, an endpoint is assumed once
        the trailing silence exceeds this many seconds.
      rule3_min_utterance_length:
        Used only when enable_endpoint_detection is True. An endpoint is
        assumed once the utterance length exceeds this many seconds.
      decoding_method:
        The only valid value is greedy_search.
      provider:
        onnxruntime execution providers. Valid values are: cpu, cuda, coreml.
    """
    _assert_file_exists(tokens)
    _assert_file_exists(model)
    assert num_threads > 0, num_threads

    # Only the zipformer2_ctc branch of the model config is populated;
    # the other model types keep their defaults.
    model_config = OnlineModelConfig(
        zipformer2_ctc=OnlineZipformer2CtcModelConfig(model=model),
        tokens=tokens,
        num_threads=num_threads,
        provider=provider,
    )
    feat_config = FeatureExtractorConfig(
        sampling_rate=sample_rate,
        feature_dim=feature_dim,
    )
    endpoint_config = EndpointConfig(
        rule1_min_trailing_silence=rule1_min_trailing_silence,
        rule2_min_trailing_silence=rule2_min_trailing_silence,
        rule3_min_utterance_length=rule3_min_utterance_length,
    )
    config = OnlineRecognizerConfig(
        feat_config=feat_config,
        model_config=model_config,
        endpoint_config=endpoint_config,
        enable_endpoint=enable_endpoint_detection,
        decoding_method=decoding_method,
    )

    # Bypass __init__ (which serves a different construction path) and
    # attach the configured native recognizer directly.
    self = cls.__new__(cls)
    self.recognizer = _Recognizer(config)
    self.config = config
    return self
@classmethod @classmethod
def from_wenet_ctc( def from_wenet_ctc(
cls, cls,
@@ -352,7 +450,6 @@ class OnlineRecognizer(object):
tokens=tokens, tokens=tokens,
num_threads=num_threads, num_threads=num_threads,
provider=provider, provider=provider,
model_type="wenet_ctc",
) )
feat_config = FeatureExtractorConfig( feat_config = FeatureExtractorConfig(

View File

@@ -143,6 +143,57 @@ class TestOnlineRecognizer(unittest.TestCase):
print(f"{wave_filename}\n{result}") print(f"{wave_filename}\n{result}")
print("-" * 10) print("-" * 10)
def test_zipformer2_ctc(self):
    """Decode three test waves with the streaming zipformer2 CTC model.

    Both the int8-quantized and the float32 model are exercised. The test
    returns silently when the model directory has not been downloaded.
    """
    repo = "sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13"
    for use_int8 in [True, False]:
        suffix = "int8.onnx" if use_int8 else "onnx"
        model = f"{d}/{repo}/ctc-epoch-20-avg-1-chunk-16-left-128.{suffix}"
        tokens = f"{d}/{repo}/tokens.txt"
        waves = [
            f"{d}/{repo}/test_wavs/DEV_T0000000000.wav",
            f"{d}/{repo}/test_wavs/DEV_T0000000001.wav",
            f"{d}/{repo}/test_wavs/DEV_T0000000002.wav",
        ]
        if not Path(model).is_file():
            print("skipping test_zipformer2_ctc()")
            return
        print(f"testing {model}")
        recognizer = sherpa_onnx.OnlineRecognizer.from_zipformer2_ctc(
            model=model,
            tokens=tokens,
            num_threads=1,
            provider="cpu",
        )
        streams = []
        for wave in waves:
            s = recognizer.create_stream()
            samples, sample_rate = read_wave(wave)
            s.accept_waveform(sample_rate, samples)
            # 0.2 s of trailing silence flushes the last frames through
            # the model before input_finished().
            tail_paddings = np.zeros(int(0.2 * sample_rate), dtype=np.float32)
            s.accept_waveform(sample_rate, tail_paddings)
            s.input_finished()
            streams.append(s)
        # Batch-decode all streams until none has pending frames.
        while True:
            ready = [s for s in streams if recognizer.is_ready(s)]
            if not ready:
                break
            recognizer.decode_streams(ready)
        results = [recognizer.get_result(s) for s in streams]
        for wave_filename, result in zip(waves, results):
            print(f"{wave_filename}\n{result}")
            print("-" * 10)
def test_wenet_ctc(self): def test_wenet_ctc(self):
models = [ models = [
"sherpa-onnx-zh-wenet-aishell", "sherpa-onnx-zh-wenet-aishell",

View File

@@ -5,3 +5,4 @@ tts
vits-vctk vits-vctk
sherpa-onnx-paraformer-zh-2023-09-14 sherpa-onnx-paraformer-zh-2023-09-14
!*.sh !*.sh
*.bak

View File

@@ -60,6 +60,14 @@ func sherpaOnnxOnlineParaformerModelConfig(
) )
} }
/// Return an instance of SherpaOnnxOnlineZipformer2CtcModelConfig.
///
/// - Parameter model: Path to the streaming zipformer2 CTC model, e.g.
///   `model.onnx`. The string is converted to a C pointer for the C API.
func sherpaOnnxOnlineZipformer2CtcModelConfig(
  model: String = ""
) -> SherpaOnnxOnlineZipformer2CtcModelConfig {
  return SherpaOnnxOnlineZipformer2CtcModelConfig(
    model: toCPointer(model)
  )
}
/// Return an instance of SherpaOnnxOnlineModelConfig. /// Return an instance of SherpaOnnxOnlineModelConfig.
/// ///
/// Please refer to /// Please refer to
@@ -75,6 +83,8 @@ func sherpaOnnxOnlineModelConfig(
tokens: String, tokens: String,
transducer: SherpaOnnxOnlineTransducerModelConfig = sherpaOnnxOnlineTransducerModelConfig(), transducer: SherpaOnnxOnlineTransducerModelConfig = sherpaOnnxOnlineTransducerModelConfig(),
paraformer: SherpaOnnxOnlineParaformerModelConfig = sherpaOnnxOnlineParaformerModelConfig(), paraformer: SherpaOnnxOnlineParaformerModelConfig = sherpaOnnxOnlineParaformerModelConfig(),
zipformer2Ctc: SherpaOnnxOnlineZipformer2CtcModelConfig =
sherpaOnnxOnlineZipformer2CtcModelConfig(),
numThreads: Int = 1, numThreads: Int = 1,
provider: String = "cpu", provider: String = "cpu",
debug: Int = 0, debug: Int = 0,
@@ -83,6 +93,7 @@ func sherpaOnnxOnlineModelConfig(
return SherpaOnnxOnlineModelConfig( return SherpaOnnxOnlineModelConfig(
transducer: transducer, transducer: transducer,
paraformer: paraformer, paraformer: paraformer,
zipformer2_ctc: zipformer2Ctc,
tokens: toCPointer(tokens), tokens: toCPointer(tokens),
num_threads: Int32(numThreads), num_threads: Int32(numThreads),
provider: toCPointer(provider), provider: toCPointer(provider),

View File

@@ -13,24 +13,47 @@ extension AVAudioPCMBuffer {
} }
func run() { func run() {
let encoder = var modelConfig: SherpaOnnxOnlineModelConfig
"./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/encoder-epoch-99-avg-1.onnx" var modelType = "zipformer2-ctc"
let decoder = var filePath: String
"./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/decoder-epoch-99-avg-1.onnx"
let joiner =
"./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/joiner-epoch-99-avg-1.onnx"
let tokens = "./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/tokens.txt"
let transducerConfig = sherpaOnnxOnlineTransducerModelConfig( modelType = "transducer"
encoder: encoder,
decoder: decoder,
joiner: joiner
)
let modelConfig = sherpaOnnxOnlineModelConfig( if modelType == "transducer" {
tokens: tokens, filePath = "./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/test_wavs/1.wav"
transducer: transducerConfig let encoder =
) "./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/encoder-epoch-99-avg-1.onnx"
let decoder =
"./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/decoder-epoch-99-avg-1.onnx"
let joiner =
"./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/joiner-epoch-99-avg-1.onnx"
let tokens = "./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/tokens.txt"
let transducerConfig = sherpaOnnxOnlineTransducerModelConfig(
encoder: encoder,
decoder: decoder,
joiner: joiner
)
modelConfig = sherpaOnnxOnlineModelConfig(
tokens: tokens,
transducer: transducerConfig
)
} else {
filePath =
"./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/test_wavs/DEV_T0000000000.wav"
let model =
"./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/ctc-epoch-20-avg-1-chunk-16-left-128.onnx"
let tokens = "./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/tokens.txt"
let zipfomer2CtcModelConfig = sherpaOnnxOnlineZipformer2CtcModelConfig(
model: model
)
modelConfig = sherpaOnnxOnlineModelConfig(
tokens: tokens,
zipformer2Ctc: zipfomer2CtcModelConfig
)
}
let featConfig = sherpaOnnxFeatureConfig( let featConfig = sherpaOnnxFeatureConfig(
sampleRate: 16000, sampleRate: 16000,
@@ -43,7 +66,6 @@ func run() {
let recognizer = SherpaOnnxRecognizer(config: &config) let recognizer = SherpaOnnxRecognizer(config: &config)
let filePath = "./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/test_wavs/1.wav"
let fileURL: NSURL = NSURL(fileURLWithPath: filePath) let fileURL: NSURL = NSURL(fileURLWithPath: filePath)
let audioFile = try! AVAudioFile(forReading: fileURL as URL) let audioFile = try! AVAudioFile(forReading: fileURL as URL)

View File

@@ -20,6 +20,12 @@ if [ ! -d ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20 ]; then
rm sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20.tar.bz2 rm sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20.tar.bz2
fi fi
# Download and unpack the streaming zipformer2 CTC test model unless the
# directory is already present from a previous run.
if [ ! -d ./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13 ]; then
wget -q https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2
tar xvf sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2
rm sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2
fi
if [ ! -e ./decode-file ]; then if [ ! -e ./decode-file ]; then
# Note: We use -lc++ to link against libc++ instead of libstdc++ # Note: We use -lc++ to link against libc++ instead of libstdc++
swiftc \ swiftc \

View File

@@ -22,7 +22,7 @@ if [ ! -d ./sherpa-onnx-whisper-tiny.en ]; then
fi fi
if [ ! -f ./silero_vad.onnx ]; then if [ ! -f ./silero_vad.onnx ]; then
echo "downloading silero_vad" echo "downloading silero_vad"
wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/silero_vad.onnx wget -q https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/silero_vad.onnx
fi fi
if [ ! -e ./generate-subtitles ]; then if [ ! -e ./generate-subtitles ]; then