Begin to support CTC models (#119)
Please see https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-ctc/nemo/index.html for a list of pre-trained CTC models from NeMo.
This commit is contained in:
47
.github/scripts/test-offline-ctc.sh
vendored
Executable file
47
.github/scripts/test-offline-ctc.sh
vendored
Executable file
@@ -0,0 +1,47 @@
|
|||||||
|
#!/usr/bin/env bash
|
||||||
|
|
||||||
|
set -e
|
||||||
|
|
||||||
|
log() {
|
||||||
|
# This function is from espnet
|
||||||
|
local fname=${BASH_SOURCE[1]##*/}
|
||||||
|
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
|
||||||
|
}
|
||||||
|
|
||||||
|
echo "EXE is $EXE"
|
||||||
|
echo "PATH: $PATH"
|
||||||
|
|
||||||
|
which $EXE
|
||||||
|
|
||||||
|
log "------------------------------------------------------------"
|
||||||
|
log "Run Citrinet (stt_en_citrinet_512, English)"
|
||||||
|
log "------------------------------------------------------------"
|
||||||
|
|
||||||
|
repo_url=http://huggingface.co/csukuangfj/sherpa-onnx-nemo-ctc-en-citrinet-512
|
||||||
|
log "Start testing ${repo_url}"
|
||||||
|
repo=$(basename $repo_url)
|
||||||
|
log "Download pretrained model and test-data from $repo_url"
|
||||||
|
|
||||||
|
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
|
||||||
|
pushd $repo
|
||||||
|
git lfs pull --include "*.onnx"
|
||||||
|
ls -lh *.onnx
|
||||||
|
popd
|
||||||
|
|
||||||
|
time $EXE \
|
||||||
|
--tokens=$repo/tokens.txt \
|
||||||
|
--nemo-ctc-model=$repo/model.onnx \
|
||||||
|
--num-threads=2 \
|
||||||
|
$repo/test_wavs/0.wav \
|
||||||
|
$repo/test_wavs/1.wav \
|
||||||
|
$repo/test_wavs/8k.wav
|
||||||
|
|
||||||
|
time $EXE \
|
||||||
|
--tokens=$repo/tokens.txt \
|
||||||
|
--nemo-ctc-model=$repo/model.int8.onnx \
|
||||||
|
--num-threads=2 \
|
||||||
|
$repo/test_wavs/0.wav \
|
||||||
|
$repo/test_wavs/1.wav \
|
||||||
|
$repo/test_wavs/8k.wav
|
||||||
|
|
||||||
|
rm -rf $repo
|
||||||
38
.github/scripts/test-python.sh
vendored
38
.github/scripts/test-python.sh
vendored
@@ -95,6 +95,8 @@ python3 ./python-api-examples/offline-decode-files.py \
|
|||||||
|
|
||||||
python3 sherpa-onnx/python/tests/test_offline_recognizer.py --verbose
|
python3 sherpa-onnx/python/tests/test_offline_recognizer.py --verbose
|
||||||
|
|
||||||
|
rm -rf $repo
|
||||||
|
|
||||||
log "Test non-streaming paraformer models"
|
log "Test non-streaming paraformer models"
|
||||||
|
|
||||||
pushd $dir
|
pushd $dir
|
||||||
@@ -128,3 +130,39 @@ python3 ./python-api-examples/offline-decode-files.py \
|
|||||||
$repo/test_wavs/8k.wav
|
$repo/test_wavs/8k.wav
|
||||||
|
|
||||||
python3 sherpa-onnx/python/tests/test_offline_recognizer.py --verbose
|
python3 sherpa-onnx/python/tests/test_offline_recognizer.py --verbose
|
||||||
|
|
||||||
|
rm -rf $repo
|
||||||
|
|
||||||
|
log "Test non-streaming NeMo CTC models"
|
||||||
|
|
||||||
|
pushd $dir
|
||||||
|
repo_url=http://huggingface.co/csukuangfj/sherpa-onnx-nemo-ctc-en-citrinet-512
|
||||||
|
|
||||||
|
log "Start testing ${repo_url}"
|
||||||
|
repo=$dir/$(basename $repo_url)
|
||||||
|
log "Download pretrained model and test-data from $repo_url"
|
||||||
|
|
||||||
|
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
|
||||||
|
cd $repo
|
||||||
|
git lfs pull --include "*.onnx"
|
||||||
|
popd
|
||||||
|
|
||||||
|
ls -lh $repo
|
||||||
|
|
||||||
|
python3 ./python-api-examples/offline-decode-files.py \
|
||||||
|
--tokens=$repo/tokens.txt \
|
||||||
|
--nemo-ctc=$repo/model.onnx \
|
||||||
|
$repo/test_wavs/0.wav \
|
||||||
|
$repo/test_wavs/1.wav \
|
||||||
|
$repo/test_wavs/8k.wav
|
||||||
|
|
||||||
|
python3 ./python-api-examples/offline-decode-files.py \
|
||||||
|
--tokens=$repo/tokens.txt \
|
||||||
|
--nemo-ctc=$repo/model.int8.onnx \
|
||||||
|
$repo/test_wavs/0.wav \
|
||||||
|
$repo/test_wavs/1.wav \
|
||||||
|
$repo/test_wavs/8k.wav
|
||||||
|
|
||||||
|
python3 sherpa-onnx/python/tests/test_offline_recognizer.py --verbose
|
||||||
|
|
||||||
|
rm -rf $repo
|
||||||
|
|||||||
10
.github/workflows/linux.yaml
vendored
10
.github/workflows/linux.yaml
vendored
@@ -8,6 +8,7 @@ on:
|
|||||||
- '.github/workflows/linux.yaml'
|
- '.github/workflows/linux.yaml'
|
||||||
- '.github/scripts/test-online-transducer.sh'
|
- '.github/scripts/test-online-transducer.sh'
|
||||||
- '.github/scripts/test-offline-transducer.sh'
|
- '.github/scripts/test-offline-transducer.sh'
|
||||||
|
- '.github/scripts/test-offline-ctc.sh'
|
||||||
- 'CMakeLists.txt'
|
- 'CMakeLists.txt'
|
||||||
- 'cmake/**'
|
- 'cmake/**'
|
||||||
- 'sherpa-onnx/csrc/*'
|
- 'sherpa-onnx/csrc/*'
|
||||||
@@ -20,6 +21,7 @@ on:
|
|||||||
- '.github/workflows/linux.yaml'
|
- '.github/workflows/linux.yaml'
|
||||||
- '.github/scripts/test-online-transducer.sh'
|
- '.github/scripts/test-online-transducer.sh'
|
||||||
- '.github/scripts/test-offline-transducer.sh'
|
- '.github/scripts/test-offline-transducer.sh'
|
||||||
|
- '.github/scripts/test-offline-ctc.sh'
|
||||||
- 'CMakeLists.txt'
|
- 'CMakeLists.txt'
|
||||||
- 'cmake/**'
|
- 'cmake/**'
|
||||||
- 'sherpa-onnx/csrc/*'
|
- 'sherpa-onnx/csrc/*'
|
||||||
@@ -68,6 +70,14 @@ jobs:
|
|||||||
file build/bin/sherpa-onnx
|
file build/bin/sherpa-onnx
|
||||||
readelf -d build/bin/sherpa-onnx
|
readelf -d build/bin/sherpa-onnx
|
||||||
|
|
||||||
|
- name: Test offline CTC
|
||||||
|
shell: bash
|
||||||
|
run: |
|
||||||
|
export PATH=$PWD/build/bin:$PATH
|
||||||
|
export EXE=sherpa-onnx-offline
|
||||||
|
|
||||||
|
.github/scripts/test-offline-ctc.sh
|
||||||
|
|
||||||
- name: Test offline transducer
|
- name: Test offline transducer
|
||||||
shell: bash
|
shell: bash
|
||||||
run: |
|
run: |
|
||||||
|
|||||||
10
.github/workflows/macos.yaml
vendored
10
.github/workflows/macos.yaml
vendored
@@ -8,6 +8,7 @@ on:
|
|||||||
- '.github/workflows/macos.yaml'
|
- '.github/workflows/macos.yaml'
|
||||||
- '.github/scripts/test-online-transducer.sh'
|
- '.github/scripts/test-online-transducer.sh'
|
||||||
- '.github/scripts/test-offline-transducer.sh'
|
- '.github/scripts/test-offline-transducer.sh'
|
||||||
|
- '.github/scripts/test-offline-ctc.sh'
|
||||||
- 'CMakeLists.txt'
|
- 'CMakeLists.txt'
|
||||||
- 'cmake/**'
|
- 'cmake/**'
|
||||||
- 'sherpa-onnx/csrc/*'
|
- 'sherpa-onnx/csrc/*'
|
||||||
@@ -18,6 +19,7 @@ on:
|
|||||||
- '.github/workflows/macos.yaml'
|
- '.github/workflows/macos.yaml'
|
||||||
- '.github/scripts/test-online-transducer.sh'
|
- '.github/scripts/test-online-transducer.sh'
|
||||||
- '.github/scripts/test-offline-transducer.sh'
|
- '.github/scripts/test-offline-transducer.sh'
|
||||||
|
- '.github/scripts/test-offline-ctc.sh'
|
||||||
- 'CMakeLists.txt'
|
- 'CMakeLists.txt'
|
||||||
- 'cmake/**'
|
- 'cmake/**'
|
||||||
- 'sherpa-onnx/csrc/*'
|
- 'sherpa-onnx/csrc/*'
|
||||||
@@ -67,6 +69,14 @@ jobs:
|
|||||||
otool -L build/bin/sherpa-onnx
|
otool -L build/bin/sherpa-onnx
|
||||||
otool -l build/bin/sherpa-onnx
|
otool -l build/bin/sherpa-onnx
|
||||||
|
|
||||||
|
- name: Test offline CTC
|
||||||
|
shell: bash
|
||||||
|
run: |
|
||||||
|
export PATH=$PWD/build/bin:$PATH
|
||||||
|
export EXE=sherpa-onnx-offline
|
||||||
|
|
||||||
|
.github/scripts/test-offline-ctc.sh
|
||||||
|
|
||||||
- name: Test offline transducer
|
- name: Test offline transducer
|
||||||
shell: bash
|
shell: bash
|
||||||
run: |
|
run: |
|
||||||
|
|||||||
10
.github/workflows/windows-x64.yaml
vendored
10
.github/workflows/windows-x64.yaml
vendored
@@ -8,6 +8,7 @@ on:
|
|||||||
- '.github/workflows/windows-x64.yaml'
|
- '.github/workflows/windows-x64.yaml'
|
||||||
- '.github/scripts/test-online-transducer.sh'
|
- '.github/scripts/test-online-transducer.sh'
|
||||||
- '.github/scripts/test-offline-transducer.sh'
|
- '.github/scripts/test-offline-transducer.sh'
|
||||||
|
- '.github/scripts/test-offline-ctc.sh'
|
||||||
- 'CMakeLists.txt'
|
- 'CMakeLists.txt'
|
||||||
- 'cmake/**'
|
- 'cmake/**'
|
||||||
- 'sherpa-onnx/csrc/*'
|
- 'sherpa-onnx/csrc/*'
|
||||||
@@ -18,6 +19,7 @@ on:
|
|||||||
- '.github/workflows/windows-x64.yaml'
|
- '.github/workflows/windows-x64.yaml'
|
||||||
- '.github/scripts/test-online-transducer.sh'
|
- '.github/scripts/test-online-transducer.sh'
|
||||||
- '.github/scripts/test-offline-transducer.sh'
|
- '.github/scripts/test-offline-transducer.sh'
|
||||||
|
- '.github/scripts/test-offline-ctc.sh'
|
||||||
- 'CMakeLists.txt'
|
- 'CMakeLists.txt'
|
||||||
- 'cmake/**'
|
- 'cmake/**'
|
||||||
- 'sherpa-onnx/csrc/*'
|
- 'sherpa-onnx/csrc/*'
|
||||||
@@ -73,6 +75,14 @@ jobs:
|
|||||||
|
|
||||||
ls -lh ./bin/Release/sherpa-onnx.exe
|
ls -lh ./bin/Release/sherpa-onnx.exe
|
||||||
|
|
||||||
|
- name: Test offline CTC for windows x64
|
||||||
|
shell: bash
|
||||||
|
run: |
|
||||||
|
export PATH=$PWD/build/bin/Release:$PATH
|
||||||
|
export EXE=sherpa-onnx-offline.exe
|
||||||
|
|
||||||
|
.github/scripts/test-offline-ctc.sh
|
||||||
|
|
||||||
- name: Test offline transducer for Windows x64
|
- name: Test offline transducer for Windows x64
|
||||||
shell: bash
|
shell: bash
|
||||||
run: |
|
run: |
|
||||||
|
|||||||
11
.github/workflows/windows-x86.yaml
vendored
11
.github/workflows/windows-x86.yaml
vendored
@@ -8,6 +8,7 @@ on:
|
|||||||
- '.github/workflows/windows-x86.yaml'
|
- '.github/workflows/windows-x86.yaml'
|
||||||
- '.github/scripts/test-online-transducer.sh'
|
- '.github/scripts/test-online-transducer.sh'
|
||||||
- '.github/scripts/test-offline-transducer.sh'
|
- '.github/scripts/test-offline-transducer.sh'
|
||||||
|
- '.github/scripts/test-offline-ctc.sh'
|
||||||
- 'CMakeLists.txt'
|
- 'CMakeLists.txt'
|
||||||
- 'cmake/**'
|
- 'cmake/**'
|
||||||
- 'sherpa-onnx/csrc/*'
|
- 'sherpa-onnx/csrc/*'
|
||||||
@@ -18,6 +19,7 @@ on:
|
|||||||
- '.github/workflows/windows-x86.yaml'
|
- '.github/workflows/windows-x86.yaml'
|
||||||
- '.github/scripts/test-online-transducer.sh'
|
- '.github/scripts/test-online-transducer.sh'
|
||||||
- '.github/scripts/test-offline-transducer.sh'
|
- '.github/scripts/test-offline-transducer.sh'
|
||||||
|
- '.github/scripts/test-offline-ctc.sh'
|
||||||
- 'CMakeLists.txt'
|
- 'CMakeLists.txt'
|
||||||
- 'cmake/**'
|
- 'cmake/**'
|
||||||
- 'sherpa-onnx/csrc/*'
|
- 'sherpa-onnx/csrc/*'
|
||||||
@@ -31,6 +33,7 @@ permissions:
|
|||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
windows_x86:
|
windows_x86:
|
||||||
|
if: false # disable windows x86 CI for now
|
||||||
runs-on: ${{ matrix.os }}
|
runs-on: ${{ matrix.os }}
|
||||||
name: ${{ matrix.vs-version }}
|
name: ${{ matrix.vs-version }}
|
||||||
strategy:
|
strategy:
|
||||||
@@ -73,6 +76,14 @@ jobs:
|
|||||||
|
|
||||||
ls -lh ./bin/Release/sherpa-onnx.exe
|
ls -lh ./bin/Release/sherpa-onnx.exe
|
||||||
|
|
||||||
|
- name: Test offline CTC for windows x86
|
||||||
|
shell: bash
|
||||||
|
run: |
|
||||||
|
export PATH=$PWD/build/bin/Release:$PATH
|
||||||
|
export EXE=sherpa-onnx-offline.exe
|
||||||
|
|
||||||
|
.github/scripts/test-offline-ctc.sh
|
||||||
|
|
||||||
- name: Test offline transducer for Windows x86
|
- name: Test offline transducer for Windows x86
|
||||||
shell: bash
|
shell: bash
|
||||||
run: |
|
run: |
|
||||||
|
|||||||
3
.gitignore
vendored
3
.gitignore
vendored
@@ -52,3 +52,6 @@ run-offline-websocket-client-*.sh
|
|||||||
run-sherpa-onnx-*.sh
|
run-sherpa-onnx-*.sh
|
||||||
sherpa-onnx-zipformer-en-2023-03-30
|
sherpa-onnx-zipformer-en-2023-03-30
|
||||||
sherpa-onnx-zipformer-en-2023-04-01
|
sherpa-onnx-zipformer-en-2023-04-01
|
||||||
|
run-offline-decode-files.sh
|
||||||
|
sherpa-onnx-nemo-ctc-en-citrinet-512
|
||||||
|
run-offline-decode-files-nemo-ctc.sh
|
||||||
|
|||||||
@@ -6,7 +6,7 @@
|
|||||||
This file demonstrates how to use sherpa-onnx Python API to transcribe
|
This file demonstrates how to use sherpa-onnx Python API to transcribe
|
||||||
file(s) with a non-streaming model.
|
file(s) with a non-streaming model.
|
||||||
|
|
||||||
paraformer Usage:
|
(1) For paraformer
|
||||||
./python-api-examples/offline-decode-files.py \
|
./python-api-examples/offline-decode-files.py \
|
||||||
--tokens=/path/to/tokens.txt \
|
--tokens=/path/to/tokens.txt \
|
||||||
--paraformer=/path/to/paraformer.onnx \
|
--paraformer=/path/to/paraformer.onnx \
|
||||||
@@ -18,7 +18,7 @@ paraformer Usage:
|
|||||||
/path/to/0.wav \
|
/path/to/0.wav \
|
||||||
/path/to/1.wav
|
/path/to/1.wav
|
||||||
|
|
||||||
transducer Usage:
|
(2) For transducer models from icefall
|
||||||
./python-api-examples/offline-decode-files.py \
|
./python-api-examples/offline-decode-files.py \
|
||||||
--tokens=/path/to/tokens.txt \
|
--tokens=/path/to/tokens.txt \
|
||||||
--encoder=/path/to/encoder.onnx \
|
--encoder=/path/to/encoder.onnx \
|
||||||
@@ -32,6 +32,8 @@ transducer Usage:
|
|||||||
/path/to/0.wav \
|
/path/to/0.wav \
|
||||||
/path/to/1.wav
|
/path/to/1.wav
|
||||||
|
|
||||||
|
(3) For CTC models from NeMo
|
||||||
|
|
||||||
Please refer to
|
Please refer to
|
||||||
https://k2-fsa.github.io/sherpa/onnx/index.html
|
https://k2-fsa.github.io/sherpa/onnx/index.html
|
||||||
to install sherpa-onnx and to download the pre-trained models
|
to install sherpa-onnx and to download the pre-trained models
|
||||||
@@ -83,7 +85,14 @@ def get_args():
|
|||||||
"--paraformer",
|
"--paraformer",
|
||||||
default="",
|
default="",
|
||||||
type=str,
|
type=str,
|
||||||
help="Path to the paraformer model",
|
help="Path to the model.onnx from Paraformer",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--nemo-ctc",
|
||||||
|
default="",
|
||||||
|
type=str,
|
||||||
|
help="Path to the model.onnx from NeMo CTC",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@@ -171,11 +180,14 @@ def main():
|
|||||||
args = get_args()
|
args = get_args()
|
||||||
assert_file_exists(args.tokens)
|
assert_file_exists(args.tokens)
|
||||||
assert args.num_threads > 0, args.num_threads
|
assert args.num_threads > 0, args.num_threads
|
||||||
if len(args.encoder) > 0:
|
if args.encoder:
|
||||||
|
assert len(args.paraformer) == 0, args.paraformer
|
||||||
|
assert len(args.nemo_ctc) == 0, args.nemo_ctc
|
||||||
|
|
||||||
assert_file_exists(args.encoder)
|
assert_file_exists(args.encoder)
|
||||||
assert_file_exists(args.decoder)
|
assert_file_exists(args.decoder)
|
||||||
assert_file_exists(args.joiner)
|
assert_file_exists(args.joiner)
|
||||||
assert len(args.paraformer) == 0, args.paraformer
|
|
||||||
recognizer = sherpa_onnx.OfflineRecognizer.from_transducer(
|
recognizer = sherpa_onnx.OfflineRecognizer.from_transducer(
|
||||||
encoder=args.encoder,
|
encoder=args.encoder,
|
||||||
decoder=args.decoder,
|
decoder=args.decoder,
|
||||||
@@ -187,8 +199,10 @@ def main():
|
|||||||
decoding_method=args.decoding_method,
|
decoding_method=args.decoding_method,
|
||||||
debug=args.debug,
|
debug=args.debug,
|
||||||
)
|
)
|
||||||
else:
|
elif args.paraformer:
|
||||||
|
assert len(args.nemo_ctc) == 0, args.nemo_ctc
|
||||||
assert_file_exists(args.paraformer)
|
assert_file_exists(args.paraformer)
|
||||||
|
|
||||||
recognizer = sherpa_onnx.OfflineRecognizer.from_paraformer(
|
recognizer = sherpa_onnx.OfflineRecognizer.from_paraformer(
|
||||||
paraformer=args.paraformer,
|
paraformer=args.paraformer,
|
||||||
tokens=args.tokens,
|
tokens=args.tokens,
|
||||||
@@ -198,6 +212,19 @@ def main():
|
|||||||
decoding_method=args.decoding_method,
|
decoding_method=args.decoding_method,
|
||||||
debug=args.debug,
|
debug=args.debug,
|
||||||
)
|
)
|
||||||
|
elif args.nemo_ctc:
|
||||||
|
recognizer = sherpa_onnx.OfflineRecognizer.from_nemo_ctc(
|
||||||
|
model=args.nemo_ctc,
|
||||||
|
tokens=args.tokens,
|
||||||
|
num_threads=args.num_threads,
|
||||||
|
sample_rate=args.sample_rate,
|
||||||
|
feature_dim=args.feature_dim,
|
||||||
|
decoding_method=args.decoding_method,
|
||||||
|
debug=args.debug,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
print("Please specify at least one model")
|
||||||
|
return
|
||||||
|
|
||||||
print("Started!")
|
print("Started!")
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
@@ -225,12 +252,14 @@ def main():
|
|||||||
print("-" * 10)
|
print("-" * 10)
|
||||||
|
|
||||||
elapsed_seconds = end_time - start_time
|
elapsed_seconds = end_time - start_time
|
||||||
rtf = elapsed_seconds / duration
|
rtf = elapsed_seconds / total_duration
|
||||||
print(f"num_threads: {args.num_threads}")
|
print(f"num_threads: {args.num_threads}")
|
||||||
print(f"decoding_method: {args.decoding_method}")
|
print(f"decoding_method: {args.decoding_method}")
|
||||||
print(f"Wave duration: {duration:.3f} s")
|
print(f"Wave duration: {total_duration:.3f} s")
|
||||||
print(f"Elapsed time: {elapsed_seconds:.3f} s")
|
print(f"Elapsed time: {elapsed_seconds:.3f} s")
|
||||||
print(f"Real time factor (RTF): {elapsed_seconds:.3f}/{duration:.3f} = {rtf:.3f}")
|
print(
|
||||||
|
f"Real time factor (RTF): {elapsed_seconds:.3f}/{total_duration:.3f} = {rtf:.3f}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
@@ -172,12 +172,14 @@ def main():
|
|||||||
print("-" * 10)
|
print("-" * 10)
|
||||||
|
|
||||||
elapsed_seconds = end_time - start_time
|
elapsed_seconds = end_time - start_time
|
||||||
rtf = elapsed_seconds / duration
|
rtf = elapsed_seconds / total_duration
|
||||||
print(f"num_threads: {args.num_threads}")
|
print(f"num_threads: {args.num_threads}")
|
||||||
print(f"decoding_method: {args.decoding_method}")
|
print(f"decoding_method: {args.decoding_method}")
|
||||||
print(f"Wave duration: {duration:.3f} s")
|
print(f"Wave duration: {total_duration:.3f} s")
|
||||||
print(f"Elapsed time: {elapsed_seconds:.3f} s")
|
print(f"Elapsed time: {elapsed_seconds:.3f} s")
|
||||||
print(f"Real time factor (RTF): {elapsed_seconds:.3f}/{duration:.3f} = {rtf:.3f}")
|
print(
|
||||||
|
f"Real time factor (RTF): {elapsed_seconds:.3f}/{total_duration:.3f} = {rtf:.3f}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
@@ -16,7 +16,11 @@ set(sources
|
|||||||
features.cc
|
features.cc
|
||||||
file-utils.cc
|
file-utils.cc
|
||||||
hypothesis.cc
|
hypothesis.cc
|
||||||
|
offline-ctc-greedy-search-decoder.cc
|
||||||
|
offline-ctc-model.cc
|
||||||
offline-model-config.cc
|
offline-model-config.cc
|
||||||
|
offline-nemo-enc-dec-ctc-model-config.cc
|
||||||
|
offline-nemo-enc-dec-ctc-model.cc
|
||||||
offline-paraformer-greedy-search-decoder.cc
|
offline-paraformer-greedy-search-decoder.cc
|
||||||
offline-paraformer-model-config.cc
|
offline-paraformer-model-config.cc
|
||||||
offline-paraformer-model.cc
|
offline-paraformer-model.cc
|
||||||
|
|||||||
@@ -11,15 +11,19 @@
|
|||||||
#include "android/log.h"
|
#include "android/log.h"
|
||||||
#define SHERPA_ONNX_LOGE(...) \
|
#define SHERPA_ONNX_LOGE(...) \
|
||||||
do { \
|
do { \
|
||||||
|
fprintf(stderr, "%s:%s:%d ", __FILE__, __func__, \
|
||||||
|
static_cast<int>(__LINE__)); \
|
||||||
fprintf(stderr, ##__VA_ARGS__); \
|
fprintf(stderr, ##__VA_ARGS__); \
|
||||||
fprintf(stderr, "\n"); \
|
fprintf(stderr, "\n"); \
|
||||||
__android_log_print(ANDROID_LOG_WARN, "sherpa-onnx", ##__VA_ARGS__); \
|
__android_log_print(ANDROID_LOG_WARN, "sherpa-onnx", ##__VA_ARGS__); \
|
||||||
} while (0)
|
} while (0)
|
||||||
#else
|
#else
|
||||||
#define SHERPA_ONNX_LOGE(...) \
|
#define SHERPA_ONNX_LOGE(...) \
|
||||||
do { \
|
do { \
|
||||||
fprintf(stderr, ##__VA_ARGS__); \
|
fprintf(stderr, "%s:%s:%d ", __FILE__, __func__, \
|
||||||
fprintf(stderr, "\n"); \
|
static_cast<int>(__LINE__)); \
|
||||||
|
fprintf(stderr, ##__VA_ARGS__); \
|
||||||
|
fprintf(stderr, "\n"); \
|
||||||
} while (0)
|
} while (0)
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
|||||||
42
sherpa-onnx/csrc/offline-ctc-decoder.h
Normal file
42
sherpa-onnx/csrc/offline-ctc-decoder.h
Normal file
@@ -0,0 +1,42 @@
|
|||||||
|
// sherpa-onnx/csrc/offline-ctc-decoder.h
|
||||||
|
//
|
||||||
|
// Copyright (c) 2023 Xiaomi Corporation
|
||||||
|
|
||||||
|
#ifndef SHERPA_ONNX_CSRC_OFFLINE_CTC_DECODER_H_
|
||||||
|
#define SHERPA_ONNX_CSRC_OFFLINE_CTC_DECODER_H_
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
struct OfflineCtcDecoderResult {
|
||||||
|
/// The decoded token IDs
|
||||||
|
std::vector<int64_t> tokens;
|
||||||
|
|
||||||
|
/// timestamps[i] contains the output frame index where tokens[i] is decoded.
|
||||||
|
/// Note: The index is after subsampling
|
||||||
|
std::vector<int32_t> timestamps;
|
||||||
|
};
|
||||||
|
|
||||||
|
class OfflineCtcDecoder {
|
||||||
|
public:
|
||||||
|
virtual ~OfflineCtcDecoder() = default;
|
||||||
|
|
||||||
|
/** Run CTC decoding given the output from the encoder model.
|
||||||
|
*
|
||||||
|
* @param log_probs A 3-D tensor of shape (N, T, vocab_size) containing
|
||||||
|
* lob_probs.
|
||||||
|
* @param log_probs_length A 1-D tensor of shape (N,) containing number
|
||||||
|
* of valid frames in log_probs before padding.
|
||||||
|
*
|
||||||
|
* @return Return a vector of size `N` containing the decoded results.
|
||||||
|
*/
|
||||||
|
virtual std::vector<OfflineCtcDecoderResult> Decode(
|
||||||
|
Ort::Value log_probs, Ort::Value log_probs_length) = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace sherpa_onnx
|
||||||
|
|
||||||
|
#endif // SHERPA_ONNX_CSRC_OFFLINE_CTC_DECODER_H_
|
||||||
54
sherpa-onnx/csrc/offline-ctc-greedy-search-decoder.cc
Normal file
54
sherpa-onnx/csrc/offline-ctc-greedy-search-decoder.cc
Normal file
@@ -0,0 +1,54 @@
|
|||||||
|
// sherpa-onnx/csrc/offline-ctc-greedy-search-decoder.h
|
||||||
|
//
|
||||||
|
// Copyright (c) 2023 Xiaomi Corporation
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/offline-ctc-greedy-search-decoder.h"
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
|
#include <utility>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/macros.h"
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
std::vector<OfflineCtcDecoderResult> OfflineCtcGreedySearchDecoder::Decode(
|
||||||
|
Ort::Value log_probs, Ort::Value log_probs_length) {
|
||||||
|
std::vector<int64_t> shape = log_probs.GetTensorTypeAndShapeInfo().GetShape();
|
||||||
|
int32_t batch_size = static_cast<int32_t>(shape[0]);
|
||||||
|
int32_t num_frames = static_cast<int32_t>(shape[1]);
|
||||||
|
int32_t vocab_size = static_cast<int32_t>(shape[2]);
|
||||||
|
|
||||||
|
const int64_t *p_log_probs_length = log_probs_length.GetTensorData<int64_t>();
|
||||||
|
|
||||||
|
std::vector<OfflineCtcDecoderResult> ans;
|
||||||
|
ans.reserve(batch_size);
|
||||||
|
|
||||||
|
for (int32_t b = 0; b != batch_size; ++b) {
|
||||||
|
const float *p_log_probs =
|
||||||
|
log_probs.GetTensorData<float>() + b * num_frames * vocab_size;
|
||||||
|
|
||||||
|
OfflineCtcDecoderResult r;
|
||||||
|
int64_t prev_id = -1;
|
||||||
|
|
||||||
|
for (int32_t t = 0; t != static_cast<int32_t>(p_log_probs_length[b]); ++t) {
|
||||||
|
auto y = static_cast<int64_t>(std::distance(
|
||||||
|
static_cast<const float *>(p_log_probs),
|
||||||
|
std::max_element(
|
||||||
|
static_cast<const float *>(p_log_probs),
|
||||||
|
static_cast<const float *>(p_log_probs) + vocab_size)));
|
||||||
|
p_log_probs += vocab_size;
|
||||||
|
|
||||||
|
if (y != blank_id_ && y != prev_id) {
|
||||||
|
r.tokens.push_back(y);
|
||||||
|
r.timestamps.push_back(t);
|
||||||
|
prev_id = y;
|
||||||
|
}
|
||||||
|
} // for (int32_t t = 0; ...)
|
||||||
|
|
||||||
|
ans.push_back(std::move(r));
|
||||||
|
}
|
||||||
|
return ans;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace sherpa_onnx
|
||||||
28
sherpa-onnx/csrc/offline-ctc-greedy-search-decoder.h
Normal file
28
sherpa-onnx/csrc/offline-ctc-greedy-search-decoder.h
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
// sherpa-onnx/csrc/offline-ctc-greedy-search-decoder.h
|
||||||
|
//
|
||||||
|
// Copyright (c) 2023 Xiaomi Corporation
|
||||||
|
|
||||||
|
#ifndef SHERPA_ONNX_CSRC_OFFLINE_CTC_GREEDY_SEARCH_DECODER_H_
|
||||||
|
#define SHERPA_ONNX_CSRC_OFFLINE_CTC_GREEDY_SEARCH_DECODER_H_
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/offline-ctc-decoder.h"
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
class OfflineCtcGreedySearchDecoder : public OfflineCtcDecoder {
|
||||||
|
public:
|
||||||
|
explicit OfflineCtcGreedySearchDecoder(int32_t blank_id)
|
||||||
|
: blank_id_(blank_id) {}
|
||||||
|
|
||||||
|
std::vector<OfflineCtcDecoderResult> Decode(
|
||||||
|
Ort::Value log_probs, Ort::Value log_probs_length) override;
|
||||||
|
|
||||||
|
private:
|
||||||
|
int32_t blank_id_;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace sherpa_onnx
|
||||||
|
|
||||||
|
#endif // SHERPA_ONNX_CSRC_OFFLINE_CTC_GREEDY_SEARCH_DECODER_H_
|
||||||
86
sherpa-onnx/csrc/offline-ctc-model.cc
Normal file
86
sherpa-onnx/csrc/offline-ctc-model.cc
Normal file
@@ -0,0 +1,86 @@
|
|||||||
|
// sherpa-onnx/csrc/offline-ctc-model.cc
|
||||||
|
//
|
||||||
|
// Copyright (c) 2022-2023 Xiaomi Corporation
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/offline-ctc-model.h"
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
|
#include <memory>
|
||||||
|
#include <sstream>
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/macros.h"
|
||||||
|
#include "sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model.h"
|
||||||
|
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
enum class ModelType {
|
||||||
|
kEncDecCTCModelBPE,
|
||||||
|
kUnkown,
|
||||||
|
};
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
static ModelType GetModelType(char *model_data, size_t model_data_length,
|
||||||
|
bool debug) {
|
||||||
|
Ort::Env env(ORT_LOGGING_LEVEL_WARNING);
|
||||||
|
Ort::SessionOptions sess_opts;
|
||||||
|
|
||||||
|
auto sess = std::make_unique<Ort::Session>(env, model_data, model_data_length,
|
||||||
|
sess_opts);
|
||||||
|
|
||||||
|
Ort::ModelMetadata meta_data = sess->GetModelMetadata();
|
||||||
|
if (debug) {
|
||||||
|
std::ostringstream os;
|
||||||
|
PrintModelMetadata(os, meta_data);
|
||||||
|
SHERPA_ONNX_LOGE("%s", os.str().c_str());
|
||||||
|
}
|
||||||
|
|
||||||
|
Ort::AllocatorWithDefaultOptions allocator;
|
||||||
|
auto model_type =
|
||||||
|
meta_data.LookupCustomMetadataMapAllocated("model_type", allocator);
|
||||||
|
if (!model_type) {
|
||||||
|
SHERPA_ONNX_LOGE(
|
||||||
|
"No model_type in the metadata!\n"
|
||||||
|
"If you are using models from NeMo, please refer to\n"
|
||||||
|
"https://huggingface.co/csukuangfj/"
|
||||||
|
"sherpa-onnx-nemo-ctc-en-citrinet-512/blob/main/add-model-metadata.py"
|
||||||
|
"\n"
|
||||||
|
"for how to add metadta to model.onnx\n");
|
||||||
|
return ModelType::kUnkown;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (model_type.get() == std::string("EncDecCTCModelBPE")) {
|
||||||
|
return ModelType::kEncDecCTCModelBPE;
|
||||||
|
} else {
|
||||||
|
SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.get());
|
||||||
|
return ModelType::kUnkown;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create(
|
||||||
|
const OfflineModelConfig &config) {
|
||||||
|
ModelType model_type = ModelType::kUnkown;
|
||||||
|
|
||||||
|
{
|
||||||
|
auto buffer = ReadFile(config.nemo_ctc.model);
|
||||||
|
|
||||||
|
model_type = GetModelType(buffer.data(), buffer.size(), config.debug);
|
||||||
|
}
|
||||||
|
|
||||||
|
switch (model_type) {
|
||||||
|
case ModelType::kEncDecCTCModelBPE:
|
||||||
|
return std::make_unique<OfflineNemoEncDecCtcModel>(config);
|
||||||
|
break;
|
||||||
|
case ModelType::kUnkown:
|
||||||
|
SHERPA_ONNX_LOGE("Unknown model type in offline CTC!");
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace sherpa_onnx
|
||||||
59
sherpa-onnx/csrc/offline-ctc-model.h
Normal file
59
sherpa-onnx/csrc/offline-ctc-model.h
Normal file
@@ -0,0 +1,59 @@
|
|||||||
|
// sherpa-onnx/csrc/offline-ctc-model.h
|
||||||
|
//
|
||||||
|
// Copyright (c) 2022-2023 Xiaomi Corporation
|
||||||
|
#ifndef SHERPA_ONNX_CSRC_OFFLINE_CTC_MODEL_H_
|
||||||
|
#define SHERPA_ONNX_CSRC_OFFLINE_CTC_MODEL_H_
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <string>
|
||||||
|
#include <utility>
|
||||||
|
|
||||||
|
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||||
|
#include "sherpa-onnx/csrc/offline-model-config.h"
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
class OfflineCtcModel {
|
||||||
|
public:
|
||||||
|
virtual ~OfflineCtcModel() = default;
|
||||||
|
static std::unique_ptr<OfflineCtcModel> Create(
|
||||||
|
const OfflineModelConfig &config);
|
||||||
|
|
||||||
|
/** Run the forward method of the model.
|
||||||
|
*
|
||||||
|
* @param features A tensor of shape (N, T, C). It is changed in-place.
|
||||||
|
* @param features_length A 1-D tensor of shape (N,) containing number of
|
||||||
|
* valid frames in `features` before padding.
|
||||||
|
* Its dtype is int64_t.
|
||||||
|
*
|
||||||
|
* @return Return a pair containing:
|
||||||
|
* - log_probs: A 3-D tensor of shape (N, T', vocab_size).
|
||||||
|
* - log_probs_length A 1-D tensor of shape (N,). Its dtype is int64_t
|
||||||
|
*/
|
||||||
|
virtual std::pair<Ort::Value, Ort::Value> Forward(
|
||||||
|
Ort::Value features, Ort::Value features_length) = 0;
|
||||||
|
|
||||||
|
/** Return the vocabulary size of the model
|
||||||
|
*/
|
||||||
|
virtual int32_t VocabSize() const = 0;
|
||||||
|
|
||||||
|
/** SubsamplingFactor of the model
|
||||||
|
*
|
||||||
|
* For Citrinet, the subsampling factor is usually 4.
|
||||||
|
* For Conformer CTC, the subsampling factor is usually 8.
|
||||||
|
*/
|
||||||
|
virtual int32_t SubsamplingFactor() const = 0;
|
||||||
|
|
||||||
|
/** Return an allocator for allocating memory
|
||||||
|
*/
|
||||||
|
virtual OrtAllocator *Allocator() const = 0;
|
||||||
|
|
||||||
|
/** For some models, e.g., those from NeMo, they require some preprocessing
|
||||||
|
* for the features.
|
||||||
|
*/
|
||||||
|
virtual std::string FeatureNormalizationMethod() const { return {}; }
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace sherpa_onnx
|
||||||
|
|
||||||
|
#endif // SHERPA_ONNX_CSRC_OFFLINE_CTC_MODEL_H_
|
||||||
@@ -13,6 +13,7 @@ namespace sherpa_onnx {
|
|||||||
void OfflineModelConfig::Register(ParseOptions *po) {
|
void OfflineModelConfig::Register(ParseOptions *po) {
|
||||||
transducer.Register(po);
|
transducer.Register(po);
|
||||||
paraformer.Register(po);
|
paraformer.Register(po);
|
||||||
|
nemo_ctc.Register(po);
|
||||||
|
|
||||||
po->Register("tokens", &tokens, "Path to tokens.txt");
|
po->Register("tokens", &tokens, "Path to tokens.txt");
|
||||||
|
|
||||||
@@ -38,6 +39,10 @@ bool OfflineModelConfig::Validate() const {
|
|||||||
return paraformer.Validate();
|
return paraformer.Validate();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (!nemo_ctc.model.empty()) {
|
||||||
|
return nemo_ctc.Validate();
|
||||||
|
}
|
||||||
|
|
||||||
return transducer.Validate();
|
return transducer.Validate();
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -47,6 +52,7 @@ std::string OfflineModelConfig::ToString() const {
|
|||||||
os << "OfflineModelConfig(";
|
os << "OfflineModelConfig(";
|
||||||
os << "transducer=" << transducer.ToString() << ", ";
|
os << "transducer=" << transducer.ToString() << ", ";
|
||||||
os << "paraformer=" << paraformer.ToString() << ", ";
|
os << "paraformer=" << paraformer.ToString() << ", ";
|
||||||
|
os << "nemo_ctc=" << nemo_ctc.ToString() << ", ";
|
||||||
os << "tokens=\"" << tokens << "\", ";
|
os << "tokens=\"" << tokens << "\", ";
|
||||||
os << "num_threads=" << num_threads << ", ";
|
os << "num_threads=" << num_threads << ", ";
|
||||||
os << "debug=" << (debug ? "True" : "False") << ")";
|
os << "debug=" << (debug ? "True" : "False") << ")";
|
||||||
|
|||||||
@@ -6,6 +6,7 @@
|
|||||||
|
|
||||||
#include <string>
|
#include <string>
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model-config.h"
|
||||||
#include "sherpa-onnx/csrc/offline-paraformer-model-config.h"
|
#include "sherpa-onnx/csrc/offline-paraformer-model-config.h"
|
||||||
#include "sherpa-onnx/csrc/offline-transducer-model-config.h"
|
#include "sherpa-onnx/csrc/offline-transducer-model-config.h"
|
||||||
|
|
||||||
@@ -14,6 +15,7 @@ namespace sherpa_onnx {
|
|||||||
struct OfflineModelConfig {
|
struct OfflineModelConfig {
|
||||||
OfflineTransducerModelConfig transducer;
|
OfflineTransducerModelConfig transducer;
|
||||||
OfflineParaformerModelConfig paraformer;
|
OfflineParaformerModelConfig paraformer;
|
||||||
|
OfflineNemoEncDecCtcModelConfig nemo_ctc;
|
||||||
|
|
||||||
std::string tokens;
|
std::string tokens;
|
||||||
int32_t num_threads = 2;
|
int32_t num_threads = 2;
|
||||||
@@ -22,9 +24,11 @@ struct OfflineModelConfig {
|
|||||||
OfflineModelConfig() = default;
|
OfflineModelConfig() = default;
|
||||||
OfflineModelConfig(const OfflineTransducerModelConfig &transducer,
|
OfflineModelConfig(const OfflineTransducerModelConfig &transducer,
|
||||||
const OfflineParaformerModelConfig ¶former,
|
const OfflineParaformerModelConfig ¶former,
|
||||||
|
const OfflineNemoEncDecCtcModelConfig &nemo_ctc,
|
||||||
const std::string &tokens, int32_t num_threads, bool debug)
|
const std::string &tokens, int32_t num_threads, bool debug)
|
||||||
: transducer(transducer),
|
: transducer(transducer),
|
||||||
paraformer(paraformer),
|
paraformer(paraformer),
|
||||||
|
nemo_ctc(nemo_ctc),
|
||||||
tokens(tokens),
|
tokens(tokens),
|
||||||
num_threads(num_threads),
|
num_threads(num_threads),
|
||||||
debug(debug) {}
|
debug(debug) {}
|
||||||
|
|||||||
35
sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model-config.cc
Normal file
35
sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model-config.cc
Normal file
@@ -0,0 +1,35 @@
|
|||||||
|
// sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model-config.cc
|
||||||
|
//
|
||||||
|
// Copyright (c) 2023 Xiaomi Corporation
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model-config.h"
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/file-utils.h"
|
||||||
|
#include "sherpa-onnx/csrc/macros.h"
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
void OfflineNemoEncDecCtcModelConfig::Register(ParseOptions *po) {
|
||||||
|
po->Register("nemo-ctc-model", &model,
|
||||||
|
"Path to model.onnx of Nemo EncDecCtcModel.");
|
||||||
|
}
|
||||||
|
|
||||||
|
bool OfflineNemoEncDecCtcModelConfig::Validate() const {
|
||||||
|
if (!FileExists(model)) {
|
||||||
|
SHERPA_ONNX_LOGE("%s does not exist", model.c_str());
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string OfflineNemoEncDecCtcModelConfig::ToString() const {
|
||||||
|
std::ostringstream os;
|
||||||
|
|
||||||
|
os << "OfflineNemoEncDecCtcModelConfig(";
|
||||||
|
os << "model=\"" << model << "\")";
|
||||||
|
|
||||||
|
return os.str();
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace sherpa_onnx
|
||||||
28
sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model-config.h
Normal file
28
sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model-config.h
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
// sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model-config.h
|
||||||
|
//
|
||||||
|
// Copyright (c) 2023 Xiaomi Corporation
|
||||||
|
#ifndef SHERPA_ONNX_CSRC_OFFLINE_NEMO_ENC_DEC_CTC_MODEL_CONFIG_H_
|
||||||
|
#define SHERPA_ONNX_CSRC_OFFLINE_NEMO_ENC_DEC_CTC_MODEL_CONFIG_H_
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/parse-options.h"
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
struct OfflineNemoEncDecCtcModelConfig {
|
||||||
|
std::string model;
|
||||||
|
|
||||||
|
OfflineNemoEncDecCtcModelConfig() = default;
|
||||||
|
explicit OfflineNemoEncDecCtcModelConfig(const std::string &model)
|
||||||
|
: model(model) {}
|
||||||
|
|
||||||
|
void Register(ParseOptions *po);
|
||||||
|
bool Validate() const;
|
||||||
|
|
||||||
|
std::string ToString() const;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace sherpa_onnx
|
||||||
|
|
||||||
|
#endif // SHERPA_ONNX_CSRC_OFFLINE_NEMO_ENC_DEC_CTC_MODEL_CONFIG_H_
|
||||||
131
sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model.cc
Normal file
131
sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model.cc
Normal file
@@ -0,0 +1,131 @@
|
|||||||
|
// sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model.cc
|
||||||
|
//
|
||||||
|
// Copyright (c) 2023 Xiaomi Corporation
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model.h"
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/macros.h"
|
||||||
|
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||||
|
#include "sherpa-onnx/csrc/text-utils.h"
|
||||||
|
#include "sherpa-onnx/csrc/transpose.h"
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
class OfflineNemoEncDecCtcModel::Impl {
|
||||||
|
public:
|
||||||
|
explicit Impl(const OfflineModelConfig &config)
|
||||||
|
: config_(config),
|
||||||
|
env_(ORT_LOGGING_LEVEL_ERROR),
|
||||||
|
sess_opts_{},
|
||||||
|
allocator_{} {
|
||||||
|
sess_opts_.SetIntraOpNumThreads(config_.num_threads);
|
||||||
|
sess_opts_.SetInterOpNumThreads(config_.num_threads);
|
||||||
|
|
||||||
|
Init();
|
||||||
|
}
|
||||||
|
|
||||||
|
std::pair<Ort::Value, Ort::Value> Forward(Ort::Value features,
|
||||||
|
Ort::Value features_length) {
|
||||||
|
std::vector<int64_t> shape =
|
||||||
|
features_length.GetTensorTypeAndShapeInfo().GetShape();
|
||||||
|
|
||||||
|
Ort::Value out_features_length = Ort::Value::CreateTensor<int64_t>(
|
||||||
|
allocator_, shape.data(), shape.size());
|
||||||
|
|
||||||
|
const int64_t *src = features_length.GetTensorData<int64_t>();
|
||||||
|
int64_t *dst = out_features_length.GetTensorMutableData<int64_t>();
|
||||||
|
for (int64_t i = 0; i != shape[0]; ++i) {
|
||||||
|
dst[i] = src[i] / subsampling_factor_;
|
||||||
|
}
|
||||||
|
|
||||||
|
// (B, T, C) -> (B, C, T)
|
||||||
|
features = Transpose12(allocator_, &features);
|
||||||
|
|
||||||
|
std::array<Ort::Value, 2> inputs = {std::move(features),
|
||||||
|
std::move(features_length)};
|
||||||
|
auto out =
|
||||||
|
sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(),
|
||||||
|
output_names_ptr_.data(), output_names_ptr_.size());
|
||||||
|
|
||||||
|
return {std::move(out[0]), std::move(out_features_length)};
|
||||||
|
}
|
||||||
|
|
||||||
|
int32_t VocabSize() const { return vocab_size_; }
|
||||||
|
|
||||||
|
int32_t SubsamplingFactor() const { return subsampling_factor_; }
|
||||||
|
|
||||||
|
OrtAllocator *Allocator() const { return allocator_; }
|
||||||
|
|
||||||
|
std::string FeatureNormalizationMethod() const { return normalize_type_; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
void Init() {
|
||||||
|
auto buf = ReadFile(config_.nemo_ctc.model);
|
||||||
|
|
||||||
|
sess_ = std::make_unique<Ort::Session>(env_, buf.data(), buf.size(),
|
||||||
|
sess_opts_);
|
||||||
|
|
||||||
|
GetInputNames(sess_.get(), &input_names_, &input_names_ptr_);
|
||||||
|
|
||||||
|
GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_);
|
||||||
|
|
||||||
|
// get meta data
|
||||||
|
Ort::ModelMetadata meta_data = sess_->GetModelMetadata();
|
||||||
|
if (config_.debug) {
|
||||||
|
std::ostringstream os;
|
||||||
|
PrintModelMetadata(os, meta_data);
|
||||||
|
SHERPA_ONNX_LOGE("%s\n", os.str().c_str());
|
||||||
|
}
|
||||||
|
|
||||||
|
Ort::AllocatorWithDefaultOptions allocator; // used in the macro below
|
||||||
|
SHERPA_ONNX_READ_META_DATA(vocab_size_, "vocab_size");
|
||||||
|
SHERPA_ONNX_READ_META_DATA(subsampling_factor_, "subsampling_factor");
|
||||||
|
SHERPA_ONNX_READ_META_DATA_STR(normalize_type_, "normalize_type");
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
OfflineModelConfig config_;
|
||||||
|
Ort::Env env_;
|
||||||
|
Ort::SessionOptions sess_opts_;
|
||||||
|
Ort::AllocatorWithDefaultOptions allocator_;
|
||||||
|
|
||||||
|
std::unique_ptr<Ort::Session> sess_;
|
||||||
|
|
||||||
|
std::vector<std::string> input_names_;
|
||||||
|
std::vector<const char *> input_names_ptr_;
|
||||||
|
|
||||||
|
std::vector<std::string> output_names_;
|
||||||
|
std::vector<const char *> output_names_ptr_;
|
||||||
|
|
||||||
|
int32_t vocab_size_ = 0;
|
||||||
|
int32_t subsampling_factor_ = 0;
|
||||||
|
std::string normalize_type_;
|
||||||
|
};
|
||||||
|
|
||||||
|
OfflineNemoEncDecCtcModel::OfflineNemoEncDecCtcModel(
|
||||||
|
const OfflineModelConfig &config)
|
||||||
|
: impl_(std::make_unique<Impl>(config)) {}
|
||||||
|
|
||||||
|
OfflineNemoEncDecCtcModel::~OfflineNemoEncDecCtcModel() = default;
|
||||||
|
|
||||||
|
std::pair<Ort::Value, Ort::Value> OfflineNemoEncDecCtcModel::Forward(
|
||||||
|
Ort::Value features, Ort::Value features_length) {
|
||||||
|
return impl_->Forward(std::move(features), std::move(features_length));
|
||||||
|
}
|
||||||
|
|
||||||
|
int32_t OfflineNemoEncDecCtcModel::VocabSize() const {
|
||||||
|
return impl_->VocabSize();
|
||||||
|
}
|
||||||
|
int32_t OfflineNemoEncDecCtcModel::SubsamplingFactor() const {
|
||||||
|
return impl_->SubsamplingFactor();
|
||||||
|
}
|
||||||
|
|
||||||
|
OrtAllocator *OfflineNemoEncDecCtcModel::Allocator() const {
|
||||||
|
return impl_->Allocator();
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string OfflineNemoEncDecCtcModel::FeatureNormalizationMethod() const {
|
||||||
|
return impl_->FeatureNormalizationMethod();
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace sherpa_onnx
|
||||||
75
sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model.h
Normal file
75
sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model.h
Normal file
@@ -0,0 +1,75 @@
|
|||||||
|
// sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model.h
|
||||||
|
//
|
||||||
|
// Copyright (c) 2023 Xiaomi Corporation
|
||||||
|
#ifndef SHERPA_ONNX_CSRC_OFFLINE_NEMO_ENC_DEC_CTC_MODEL_H_
|
||||||
|
#define SHERPA_ONNX_CSRC_OFFLINE_NEMO_ENC_DEC_CTC_MODEL_H_
|
||||||
|
#include <memory>
|
||||||
|
#include <string>
|
||||||
|
#include <utility>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||||
|
#include "sherpa-onnx/csrc/offline-ctc-model.h"
|
||||||
|
#include "sherpa-onnx/csrc/offline-model-config.h"
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
/** This class implements the EncDecCTCModelBPE model from NeMo.
|
||||||
|
*
|
||||||
|
* See
|
||||||
|
* https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/asr/models/ctc_bpe_models.py
|
||||||
|
* https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/asr/models/ctc_models.py
|
||||||
|
*/
|
||||||
|
class OfflineNemoEncDecCtcModel : public OfflineCtcModel {
|
||||||
|
public:
|
||||||
|
explicit OfflineNemoEncDecCtcModel(const OfflineModelConfig &config);
|
||||||
|
~OfflineNemoEncDecCtcModel() override;
|
||||||
|
|
||||||
|
/** Run the forward method of the model.
|
||||||
|
*
|
||||||
|
* @param features A tensor of shape (N, T, C). It is changed in-place.
|
||||||
|
* @param features_length A 1-D tensor of shape (N,) containing number of
|
||||||
|
* valid frames in `features` before padding.
|
||||||
|
* Its dtype is int64_t.
|
||||||
|
*
|
||||||
|
* @return Return a pair containing:
|
||||||
|
* - log_probs: A 3-D tensor of shape (N, T', vocab_size).
|
||||||
|
* - log_probs_length A 1-D tensor of shape (N,). Its dtype is int64_t
|
||||||
|
*/
|
||||||
|
std::pair<Ort::Value, Ort::Value> Forward(
|
||||||
|
Ort::Value features, Ort::Value features_length) override;
|
||||||
|
|
||||||
|
/** Return the vocabulary size of the model
|
||||||
|
*/
|
||||||
|
int32_t VocabSize() const override;
|
||||||
|
|
||||||
|
/** SubsamplingFactor of the model
|
||||||
|
*
|
||||||
|
* For Citrinet, the subsampling factor is usually 4.
|
||||||
|
* For Conformer CTC, the subsampling factor is usually 8.
|
||||||
|
*/
|
||||||
|
int32_t SubsamplingFactor() const override;
|
||||||
|
|
||||||
|
/** Return an allocator for allocating memory
|
||||||
|
*/
|
||||||
|
OrtAllocator *Allocator() const override;
|
||||||
|
|
||||||
|
// Possible values:
|
||||||
|
// - per_feature
|
||||||
|
// - all_features (not implemented yet)
|
||||||
|
// - fixed_mean (not implemented)
|
||||||
|
// - fixed_std (not implemented)
|
||||||
|
// - or just leave it to empty
|
||||||
|
// See
|
||||||
|
// https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/asr/parts/preprocessing/features.py#L59
|
||||||
|
// for details
|
||||||
|
std::string FeatureNormalizationMethod() const override;
|
||||||
|
|
||||||
|
private:
|
||||||
|
class Impl;
|
||||||
|
std::unique_ptr<Impl> impl_;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace sherpa_onnx
|
||||||
|
|
||||||
|
#endif // SHERPA_ONNX_CSRC_OFFLINE_NEMO_ENC_DEC_CTC_MODEL_H_
|
||||||
128
sherpa-onnx/csrc/offline-recognizer-ctc-impl.h
Normal file
128
sherpa-onnx/csrc/offline-recognizer-ctc-impl.h
Normal file
@@ -0,0 +1,128 @@
|
|||||||
|
// sherpa-onnx/csrc/offline-recognizer-ctc-impl.h
|
||||||
|
//
|
||||||
|
// Copyright (c) 2022-2023 Xiaomi Corporation
|
||||||
|
|
||||||
|
#ifndef SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_CTC_IMPL_H_
|
||||||
|
#define SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_CTC_IMPL_H_
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <string>
|
||||||
|
#include <utility>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/offline-ctc-decoder.h"
|
||||||
|
#include "sherpa-onnx/csrc/offline-ctc-greedy-search-decoder.h"
|
||||||
|
#include "sherpa-onnx/csrc/offline-ctc-model.h"
|
||||||
|
#include "sherpa-onnx/csrc/offline-recognizer-impl.h"
|
||||||
|
#include "sherpa-onnx/csrc/pad-sequence.h"
|
||||||
|
#include "sherpa-onnx/csrc/symbol-table.h"
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
static OfflineRecognitionResult Convert(const OfflineCtcDecoderResult &src,
|
||||||
|
const SymbolTable &sym_table) {
|
||||||
|
OfflineRecognitionResult r;
|
||||||
|
r.tokens.reserve(src.tokens.size());
|
||||||
|
|
||||||
|
std::string text;
|
||||||
|
|
||||||
|
for (int32_t i = 0; i != src.tokens.size(); ++i) {
|
||||||
|
auto sym = sym_table[src.tokens[i]];
|
||||||
|
text.append(sym);
|
||||||
|
r.tokens.push_back(std::move(sym));
|
||||||
|
}
|
||||||
|
r.text = std::move(text);
|
||||||
|
|
||||||
|
return r;
|
||||||
|
}
|
||||||
|
|
||||||
|
class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl {
|
||||||
|
public:
|
||||||
|
explicit OfflineRecognizerCtcImpl(const OfflineRecognizerConfig &config)
|
||||||
|
: config_(config),
|
||||||
|
symbol_table_(config_.model_config.tokens),
|
||||||
|
model_(OfflineCtcModel::Create(config_.model_config)) {
|
||||||
|
config_.feat_config.nemo_normalize_type =
|
||||||
|
model_->FeatureNormalizationMethod();
|
||||||
|
|
||||||
|
if (config.decoding_method == "greedy_search") {
|
||||||
|
if (!symbol_table_.contains("<blk>")) {
|
||||||
|
SHERPA_ONNX_LOGE(
|
||||||
|
"We expect that tokens.txt contains "
|
||||||
|
"the symbol <blk> and its ID.");
|
||||||
|
exit(-1);
|
||||||
|
}
|
||||||
|
|
||||||
|
int32_t blank_id = symbol_table_["<blk>"];
|
||||||
|
decoder_ = std::make_unique<OfflineCtcGreedySearchDecoder>(blank_id);
|
||||||
|
} else {
|
||||||
|
SHERPA_ONNX_LOGE("Only greedy_search is supported at present. Given %s",
|
||||||
|
config.decoding_method.c_str());
|
||||||
|
exit(-1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
std::unique_ptr<OfflineStream> CreateStream() const override {
|
||||||
|
return std::make_unique<OfflineStream>(config_.feat_config);
|
||||||
|
}
|
||||||
|
|
||||||
|
void DecodeStreams(OfflineStream **ss, int32_t n) const override {
|
||||||
|
auto memory_info =
|
||||||
|
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
|
||||||
|
|
||||||
|
int32_t feat_dim = config_.feat_config.feature_dim;
|
||||||
|
|
||||||
|
std::vector<Ort::Value> features;
|
||||||
|
features.reserve(n);
|
||||||
|
|
||||||
|
std::vector<std::vector<float>> features_vec(n);
|
||||||
|
std::vector<int64_t> features_length_vec(n);
|
||||||
|
|
||||||
|
for (int32_t i = 0; i != n; ++i) {
|
||||||
|
std::vector<float> f = ss[i]->GetFrames();
|
||||||
|
|
||||||
|
int32_t num_frames = f.size() / feat_dim;
|
||||||
|
features_vec[i] = std::move(f);
|
||||||
|
|
||||||
|
features_length_vec[i] = num_frames;
|
||||||
|
|
||||||
|
std::array<int64_t, 2> shape = {num_frames, feat_dim};
|
||||||
|
|
||||||
|
Ort::Value x = Ort::Value::CreateTensor(
|
||||||
|
memory_info, features_vec[i].data(), features_vec[i].size(),
|
||||||
|
shape.data(), shape.size());
|
||||||
|
features.push_back(std::move(x));
|
||||||
|
} // for (int32_t i = 0; i != n; ++i)
|
||||||
|
|
||||||
|
std::vector<const Ort::Value *> features_pointer(n);
|
||||||
|
for (int32_t i = 0; i != n; ++i) {
|
||||||
|
features_pointer[i] = &features[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
std::array<int64_t, 1> features_length_shape = {n};
|
||||||
|
Ort::Value x_length = Ort::Value::CreateTensor(
|
||||||
|
memory_info, features_length_vec.data(), n,
|
||||||
|
features_length_shape.data(), features_length_shape.size());
|
||||||
|
|
||||||
|
Ort::Value x = PadSequence(model_->Allocator(), features_pointer,
|
||||||
|
-23.025850929940457f);
|
||||||
|
auto t = model_->Forward(std::move(x), std::move(x_length));
|
||||||
|
|
||||||
|
auto results = decoder_->Decode(std::move(t.first), std::move(t.second));
|
||||||
|
|
||||||
|
for (int32_t i = 0; i != n; ++i) {
|
||||||
|
auto r = Convert(results[i], symbol_table_);
|
||||||
|
ss[i]->SetResult(r);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
OfflineRecognizerConfig config_;
|
||||||
|
SymbolTable symbol_table_;
|
||||||
|
std::unique_ptr<OfflineCtcModel> model_;
|
||||||
|
std::unique_ptr<OfflineCtcDecoder> decoder_;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace sherpa_onnx
|
||||||
|
|
||||||
|
#endif // SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_CTC_IMPL_H_
|
||||||
@@ -8,6 +8,7 @@
|
|||||||
|
|
||||||
#include "onnxruntime_cxx_api.h" // NOLINT
|
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||||
#include "sherpa-onnx/csrc/macros.h"
|
#include "sherpa-onnx/csrc/macros.h"
|
||||||
|
#include "sherpa-onnx/csrc/offline-recognizer-ctc-impl.h"
|
||||||
#include "sherpa-onnx/csrc/offline-recognizer-paraformer-impl.h"
|
#include "sherpa-onnx/csrc/offline-recognizer-paraformer-impl.h"
|
||||||
#include "sherpa-onnx/csrc/offline-recognizer-transducer-impl.h"
|
#include "sherpa-onnx/csrc/offline-recognizer-transducer-impl.h"
|
||||||
#include "sherpa-onnx/csrc/onnx-utils.h"
|
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||||
@@ -25,6 +26,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
|
|||||||
model_filename = config.model_config.transducer.encoder_filename;
|
model_filename = config.model_config.transducer.encoder_filename;
|
||||||
} else if (!config.model_config.paraformer.model.empty()) {
|
} else if (!config.model_config.paraformer.model.empty()) {
|
||||||
model_filename = config.model_config.paraformer.model;
|
model_filename = config.model_config.paraformer.model;
|
||||||
|
} else if (!config.model_config.nemo_ctc.model.empty()) {
|
||||||
|
model_filename = config.model_config.nemo_ctc.model;
|
||||||
} else {
|
} else {
|
||||||
SHERPA_ONNX_LOGE("Please provide a model");
|
SHERPA_ONNX_LOGE("Please provide a model");
|
||||||
exit(-1);
|
exit(-1);
|
||||||
@@ -39,8 +42,30 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
|
|||||||
|
|
||||||
Ort::AllocatorWithDefaultOptions allocator; // used in the macro below
|
Ort::AllocatorWithDefaultOptions allocator; // used in the macro below
|
||||||
|
|
||||||
std::string model_type;
|
auto model_type_ptr =
|
||||||
SHERPA_ONNX_READ_META_DATA_STR(model_type, "model_type");
|
meta_data.LookupCustomMetadataMapAllocated("model_type", allocator);
|
||||||
|
if (!model_type_ptr) {
|
||||||
|
SHERPA_ONNX_LOGE(
|
||||||
|
"No model_type in the metadata!\n\n"
|
||||||
|
"Please refer to the following URLs to add metadata"
|
||||||
|
"\n"
|
||||||
|
"(0) Transducer models from icefall"
|
||||||
|
"\n "
|
||||||
|
"https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/"
|
||||||
|
"pruned_transducer_stateless7/export-onnx.py#L303"
|
||||||
|
"\n"
|
||||||
|
"(1) Nemo CTC models\n "
|
||||||
|
"https://huggingface.co/csukuangfj/"
|
||||||
|
"sherpa-onnx-nemo-ctc-en-citrinet-512/blob/main/add-model-metadata.py"
|
||||||
|
"\n"
|
||||||
|
"(2) Paraformer"
|
||||||
|
"\n "
|
||||||
|
"https://huggingface.co/csukuangfj/"
|
||||||
|
"paraformer-onnxruntime-python-example/blob/main/add-model-metadata.py"
|
||||||
|
"\n");
|
||||||
|
exit(-1);
|
||||||
|
}
|
||||||
|
std::string model_type(model_type_ptr.get());
|
||||||
|
|
||||||
if (model_type == "conformer" || model_type == "zipformer") {
|
if (model_type == "conformer" || model_type == "zipformer") {
|
||||||
return std::make_unique<OfflineRecognizerTransducerImpl>(config);
|
return std::make_unique<OfflineRecognizerTransducerImpl>(config);
|
||||||
@@ -50,11 +75,16 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
|
|||||||
return std::make_unique<OfflineRecognizerParaformerImpl>(config);
|
return std::make_unique<OfflineRecognizerParaformerImpl>(config);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (model_type == "EncDecCTCModelBPE") {
|
||||||
|
return std::make_unique<OfflineRecognizerCtcImpl>(config);
|
||||||
|
}
|
||||||
|
|
||||||
SHERPA_ONNX_LOGE(
|
SHERPA_ONNX_LOGE(
|
||||||
"\nUnsupported model_type: %s\n"
|
"\nUnsupported model_type: %s\n"
|
||||||
"We support only the following model types at present: \n"
|
"We support only the following model types at present: \n"
|
||||||
" - transducer models from icefall\n"
|
" - Non-streaming transducer models from icefall\n"
|
||||||
" - Paraformer models from FunASR\n",
|
" - Non-streaming Paraformer models from FunASR\n"
|
||||||
|
" - EncDecCTCModelBPE models from NeMo\n",
|
||||||
model_type.c_str());
|
model_type.c_str());
|
||||||
|
|
||||||
exit(-1);
|
exit(-1);
|
||||||
|
|||||||
@@ -7,6 +7,7 @@
|
|||||||
#include <assert.h>
|
#include <assert.h>
|
||||||
|
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
|
#include <cmath>
|
||||||
|
|
||||||
#include "kaldi-native-fbank/csrc/online-feature.h"
|
#include "kaldi-native-fbank/csrc/online-feature.h"
|
||||||
#include "sherpa-onnx/csrc/macros.h"
|
#include "sherpa-onnx/csrc/macros.h"
|
||||||
@@ -15,6 +16,41 @@
|
|||||||
|
|
||||||
namespace sherpa_onnx {
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
/* Compute mean and inverse stddev over rows.
|
||||||
|
*
|
||||||
|
* @param p A pointer to a 2-d array of shape (num_rows, num_cols)
|
||||||
|
* @param num_rows Number of rows
|
||||||
|
* @param num_cols Number of columns
|
||||||
|
* @param mean On return, it contains p.mean(axis=0)
|
||||||
|
* @param inv_stddev On return, it contains 1/p.std(axis=0)
|
||||||
|
*/
|
||||||
|
static void ComputeMeanAndInvStd(const float *p, int32_t num_rows,
|
||||||
|
int32_t num_cols, std::vector<float> *mean,
|
||||||
|
std::vector<float> *inv_stddev) {
|
||||||
|
std::vector<float> sum(num_cols);
|
||||||
|
std::vector<float> sum_sq(num_cols);
|
||||||
|
|
||||||
|
for (int32_t i = 0; i != num_rows; ++i) {
|
||||||
|
for (int32_t c = 0; c != num_cols; ++c) {
|
||||||
|
auto t = p[c];
|
||||||
|
sum[c] += t;
|
||||||
|
sum_sq[c] += t * t;
|
||||||
|
}
|
||||||
|
p += num_cols;
|
||||||
|
}
|
||||||
|
|
||||||
|
mean->resize(num_cols);
|
||||||
|
inv_stddev->resize(num_cols);
|
||||||
|
|
||||||
|
for (int32_t i = 0; i != num_cols; ++i) {
|
||||||
|
auto t = sum[i] / num_rows;
|
||||||
|
(*mean)[i] = t;
|
||||||
|
|
||||||
|
float stddev = std::sqrt(sum_sq[i] / num_rows - t * t);
|
||||||
|
(*inv_stddev)[i] = 1.0f / (stddev + 1e-5f);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
void OfflineFeatureExtractorConfig::Register(ParseOptions *po) {
|
void OfflineFeatureExtractorConfig::Register(ParseOptions *po) {
|
||||||
po->Register("sample-rate", &sampling_rate,
|
po->Register("sample-rate", &sampling_rate,
|
||||||
"Sampling rate of the input waveform. "
|
"Sampling rate of the input waveform. "
|
||||||
@@ -106,6 +142,8 @@ class OfflineStream::Impl {
|
|||||||
p += feature_dim;
|
p += feature_dim;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
NemoNormalizeFeatures(features.data(), n, feature_dim);
|
||||||
|
|
||||||
return features;
|
return features;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -113,6 +151,38 @@ class OfflineStream::Impl {
|
|||||||
|
|
||||||
const OfflineRecognitionResult &GetResult() const { return r_; }
|
const OfflineRecognitionResult &GetResult() const { return r_; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
void NemoNormalizeFeatures(float *p, int32_t num_frames,
|
||||||
|
int32_t feature_dim) const {
|
||||||
|
if (config_.nemo_normalize_type.empty()) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (config_.nemo_normalize_type != "per_feature") {
|
||||||
|
SHERPA_ONNX_LOGE(
|
||||||
|
"Only normalize_type=per_feature is implemented. Given: %s",
|
||||||
|
config_.nemo_normalize_type.c_str());
|
||||||
|
exit(-1);
|
||||||
|
}
|
||||||
|
|
||||||
|
NemoNormalizePerFeature(p, num_frames, feature_dim);
|
||||||
|
}
|
||||||
|
|
||||||
|
static void NemoNormalizePerFeature(float *p, int32_t num_frames,
|
||||||
|
int32_t feature_dim) {
|
||||||
|
std::vector<float> mean;
|
||||||
|
std::vector<float> inv_stddev;
|
||||||
|
|
||||||
|
ComputeMeanAndInvStd(p, num_frames, feature_dim, &mean, &inv_stddev);
|
||||||
|
|
||||||
|
for (int32_t n = 0; n != num_frames; ++n) {
|
||||||
|
for (int32_t i = 0; i != feature_dim; ++i) {
|
||||||
|
p[i] = (p[i] - mean[i]) * inv_stddev[i];
|
||||||
|
}
|
||||||
|
p += feature_dim;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
OfflineFeatureExtractorConfig config_;
|
OfflineFeatureExtractorConfig config_;
|
||||||
std::unique_ptr<knf::OnlineFbank> fbank_;
|
std::unique_ptr<knf::OnlineFbank> fbank_;
|
||||||
|
|||||||
@@ -37,13 +37,26 @@ struct OfflineFeatureExtractorConfig {
|
|||||||
// Feature dimension
|
// Feature dimension
|
||||||
int32_t feature_dim = 80;
|
int32_t feature_dim = 80;
|
||||||
|
|
||||||
// Set internally by some models, e.g., paraformer
|
// Set internally by some models, e.g., paraformer sets it to false.
|
||||||
// This parameter is not exposed to users from the commandline
|
// This parameter is not exposed to users from the commandline
|
||||||
// If true, the feature extractor expects inputs to be normalized to
|
// If true, the feature extractor expects inputs to be normalized to
|
||||||
// the range [-1, 1].
|
// the range [-1, 1].
|
||||||
// If false, we will multiply the inputs by 32768
|
// If false, we will multiply the inputs by 32768
|
||||||
bool normalize_samples = true;
|
bool normalize_samples = true;
|
||||||
|
|
||||||
|
// For models from NeMo
|
||||||
|
// This option is not exposed and is set internally when loading models.
|
||||||
|
// Possible values:
|
||||||
|
// - per_feature
|
||||||
|
// - all_features (not implemented yet)
|
||||||
|
// - fixed_mean (not implemented)
|
||||||
|
// - fixed_std (not implemented)
|
||||||
|
// - or just leave it to empty
|
||||||
|
// See
|
||||||
|
// https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/asr/parts/preprocessing/features.py#L59
|
||||||
|
// for details
|
||||||
|
std::string nemo_normalize_type;
|
||||||
|
|
||||||
std::string ToString() const;
|
std::string ToString() const;
|
||||||
|
|
||||||
void Register(ParseOptions *po);
|
void Register(ParseOptions *po);
|
||||||
|
|||||||
@@ -14,10 +14,12 @@
|
|||||||
#include <sstream>
|
#include <sstream>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/macros.h"
|
||||||
#include "sherpa-onnx/csrc/online-lstm-transducer-model.h"
|
#include "sherpa-onnx/csrc/online-lstm-transducer-model.h"
|
||||||
#include "sherpa-onnx/csrc/online-zipformer-transducer-model.h"
|
#include "sherpa-onnx/csrc/online-zipformer-transducer-model.h"
|
||||||
#include "sherpa-onnx/csrc/onnx-utils.h"
|
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||||
namespace sherpa_onnx {
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
enum class ModelType {
|
enum class ModelType {
|
||||||
kLstm,
|
kLstm,
|
||||||
@@ -25,6 +27,10 @@ enum class ModelType {
|
|||||||
kUnkown,
|
kUnkown,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
static ModelType GetModelType(char *model_data, size_t model_data_length,
|
static ModelType GetModelType(char *model_data, size_t model_data_length,
|
||||||
bool debug) {
|
bool debug) {
|
||||||
Ort::Env env(ORT_LOGGING_LEVEL_WARNING);
|
Ort::Env env(ORT_LOGGING_LEVEL_WARNING);
|
||||||
@@ -37,14 +43,17 @@ static ModelType GetModelType(char *model_data, size_t model_data_length,
|
|||||||
if (debug) {
|
if (debug) {
|
||||||
std::ostringstream os;
|
std::ostringstream os;
|
||||||
PrintModelMetadata(os, meta_data);
|
PrintModelMetadata(os, meta_data);
|
||||||
fprintf(stderr, "%s\n", os.str().c_str());
|
SHERPA_ONNX_LOGE("%s", os.str().c_str());
|
||||||
}
|
}
|
||||||
|
|
||||||
Ort::AllocatorWithDefaultOptions allocator;
|
Ort::AllocatorWithDefaultOptions allocator;
|
||||||
auto model_type =
|
auto model_type =
|
||||||
meta_data.LookupCustomMetadataMapAllocated("model_type", allocator);
|
meta_data.LookupCustomMetadataMapAllocated("model_type", allocator);
|
||||||
if (!model_type) {
|
if (!model_type) {
|
||||||
fprintf(stderr, "No model_type in the metadata!\n");
|
SHERPA_ONNX_LOGE(
|
||||||
|
"No model_type in the metadata!\n"
|
||||||
|
"Please make sure you are using the latest export-onnx.py from icefall "
|
||||||
|
"to export your transducer models");
|
||||||
return ModelType::kUnkown;
|
return ModelType::kUnkown;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -53,7 +62,7 @@ static ModelType GetModelType(char *model_data, size_t model_data_length,
|
|||||||
} else if (model_type.get() == std::string("zipformer")) {
|
} else if (model_type.get() == std::string("zipformer")) {
|
||||||
return ModelType::kZipformer;
|
return ModelType::kZipformer;
|
||||||
} else {
|
} else {
|
||||||
fprintf(stderr, "Unsupported model_type: %s\n", model_type.get());
|
SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.get());
|
||||||
return ModelType::kUnkown;
|
return ModelType::kUnkown;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -74,6 +83,7 @@ std::unique_ptr<OnlineTransducerModel> OnlineTransducerModel::Create(
|
|||||||
case ModelType::kZipformer:
|
case ModelType::kZipformer:
|
||||||
return std::make_unique<OnlineZipformerTransducerModel>(config);
|
return std::make_unique<OnlineZipformerTransducerModel>(config);
|
||||||
case ModelType::kUnkown:
|
case ModelType::kUnkown:
|
||||||
|
SHERPA_ONNX_LOGE("Unknown model type in online transducer!");
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -127,6 +137,7 @@ std::unique_ptr<OnlineTransducerModel> OnlineTransducerModel::Create(
|
|||||||
case ModelType::kZipformer:
|
case ModelType::kZipformer:
|
||||||
return std::make_unique<OnlineZipformerTransducerModel>(mgr, config);
|
return std::make_unique<OnlineZipformerTransducerModel>(mgr, config);
|
||||||
case ModelType::kUnkown:
|
case ModelType::kUnkown:
|
||||||
|
SHERPA_ONNX_LOGE("Unknown model type in online transducer!");
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -35,4 +35,28 @@ TEST(Tranpose, Tranpose01) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(Tranpose, Tranpose12) {
|
||||||
|
Ort::AllocatorWithDefaultOptions allocator;
|
||||||
|
std::array<int64_t, 3> shape{3, 2, 5};
|
||||||
|
Ort::Value v =
|
||||||
|
Ort::Value::CreateTensor<float>(allocator, shape.data(), shape.size());
|
||||||
|
float *p = v.GetTensorMutableData<float>();
|
||||||
|
|
||||||
|
std::iota(p, p + shape[0] * shape[1] * shape[2], 0);
|
||||||
|
|
||||||
|
auto ans = Transpose12(allocator, &v);
|
||||||
|
auto v2 = Transpose12(allocator, &ans);
|
||||||
|
|
||||||
|
Print3D(&v);
|
||||||
|
Print3D(&ans);
|
||||||
|
Print3D(&v2);
|
||||||
|
|
||||||
|
const float *q = v2.GetTensorData<float>();
|
||||||
|
|
||||||
|
for (int32_t i = 0; i != static_cast<int32_t>(shape[0] * shape[1] * shape[2]);
|
||||||
|
++i) {
|
||||||
|
EXPECT_EQ(p[i], q[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace sherpa_onnx
|
} // namespace sherpa_onnx
|
||||||
|
|||||||
@@ -17,8 +17,8 @@ Ort::Value Transpose01(OrtAllocator *allocator, const Ort::Value *v) {
|
|||||||
assert(shape.size() == 3);
|
assert(shape.size() == 3);
|
||||||
|
|
||||||
std::array<int64_t, 3> ans_shape{shape[1], shape[0], shape[2]};
|
std::array<int64_t, 3> ans_shape{shape[1], shape[0], shape[2]};
|
||||||
Ort::Value ans = Ort::Value::CreateTensor<float>(allocator, ans_shape.data(),
|
Ort::Value ans = Ort::Value::CreateTensor<T>(allocator, ans_shape.data(),
|
||||||
ans_shape.size());
|
ans_shape.size());
|
||||||
|
|
||||||
T *dst = ans.GetTensorMutableData<T>();
|
T *dst = ans.GetTensorMutableData<T>();
|
||||||
auto plane_offset = shape[1] * shape[2];
|
auto plane_offset = shape[1] * shape[2];
|
||||||
@@ -35,7 +35,32 @@ Ort::Value Transpose01(OrtAllocator *allocator, const Ort::Value *v) {
|
|||||||
return ans;
|
return ans;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename T /*= float*/>
|
||||||
|
Ort::Value Transpose12(OrtAllocator *allocator, const Ort::Value *v) {
|
||||||
|
std::vector<int64_t> shape = v->GetTensorTypeAndShapeInfo().GetShape();
|
||||||
|
assert(shape.size() == 3);
|
||||||
|
|
||||||
|
std::array<int64_t, 3> ans_shape{shape[0], shape[2], shape[1]};
|
||||||
|
Ort::Value ans = Ort::Value::CreateTensor<T>(allocator, ans_shape.data(),
|
||||||
|
ans_shape.size());
|
||||||
|
T *dst = ans.GetTensorMutableData<T>();
|
||||||
|
auto row_stride = shape[2];
|
||||||
|
for (int64_t b = 0; b != ans_shape[0]; ++b) {
|
||||||
|
const T *src = v->GetTensorData<T>() + b * shape[1] * shape[2];
|
||||||
|
for (int64_t i = 0; i != ans_shape[1]; ++i) {
|
||||||
|
for (int64_t k = 0; k != ans_shape[2]; ++k, ++dst) {
|
||||||
|
*dst = (src + k * row_stride)[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return ans;
|
||||||
|
}
|
||||||
|
|
||||||
template Ort::Value Transpose01<float>(OrtAllocator *allocator,
|
template Ort::Value Transpose01<float>(OrtAllocator *allocator,
|
||||||
const Ort::Value *v);
|
const Ort::Value *v);
|
||||||
|
|
||||||
|
template Ort::Value Transpose12<float>(OrtAllocator *allocator,
|
||||||
|
const Ort::Value *v);
|
||||||
|
|
||||||
} // namespace sherpa_onnx
|
} // namespace sherpa_onnx
|
||||||
|
|||||||
@@ -10,13 +10,23 @@ namespace sherpa_onnx {
|
|||||||
/** Transpose a 3-D tensor from shape (B, T, C) to (T, B, C).
|
/** Transpose a 3-D tensor from shape (B, T, C) to (T, B, C).
|
||||||
*
|
*
|
||||||
* @param allocator
|
* @param allocator
|
||||||
* @param v A 3-D tensor of shape (B, T, C). Its dataype is T.
|
* @param v A 3-D tensor of shape (B, T, C). Its dataype is type.
|
||||||
*
|
*
|
||||||
* @return Return a 3-D tensor of shape (T, B, C). Its datatype is T.
|
* @return Return a 3-D tensor of shape (T, B, C). Its datatype is type.
|
||||||
*/
|
*/
|
||||||
template <typename T = float>
|
template <typename type = float>
|
||||||
Ort::Value Transpose01(OrtAllocator *allocator, const Ort::Value *v);
|
Ort::Value Transpose01(OrtAllocator *allocator, const Ort::Value *v);
|
||||||
|
|
||||||
|
/** Transpose a 3-D tensor from shape (B, T, C) to (B, C, T).
|
||||||
|
*
|
||||||
|
* @param allocator
|
||||||
|
* @param v A 3-D tensor of shape (B, T, C). Its dataype is type.
|
||||||
|
*
|
||||||
|
* @return Return a 3-D tensor of shape (B, C, T). Its datatype is type.
|
||||||
|
*/
|
||||||
|
template <typename type = float>
|
||||||
|
Ort::Value Transpose12(OrtAllocator *allocator, const Ort::Value *v);
|
||||||
|
|
||||||
} // namespace sherpa_onnx
|
} // namespace sherpa_onnx
|
||||||
|
|
||||||
#endif // SHERPA_ONNX_CSRC_TRANSPOSE_H_
|
#endif // SHERPA_ONNX_CSRC_TRANSPOSE_H_
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ pybind11_add_module(_sherpa_onnx
|
|||||||
endpoint.cc
|
endpoint.cc
|
||||||
features.cc
|
features.cc
|
||||||
offline-model-config.cc
|
offline-model-config.cc
|
||||||
|
offline-nemo-enc-dec-ctc-model-config.cc
|
||||||
offline-paraformer-model-config.cc
|
offline-paraformer-model-config.cc
|
||||||
offline-recognizer.cc
|
offline-recognizer.cc
|
||||||
offline-stream.cc
|
offline-stream.cc
|
||||||
|
|||||||
@@ -7,26 +7,31 @@
|
|||||||
#include <string>
|
#include <string>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "sherpa-onnx/python/csrc/offline-transducer-model-config.h"
|
|
||||||
#include "sherpa-onnx/python/csrc/offline-paraformer-model-config.h"
|
|
||||||
|
|
||||||
#include "sherpa-onnx/csrc/offline-model-config.h"
|
#include "sherpa-onnx/csrc/offline-model-config.h"
|
||||||
|
#include "sherpa-onnx/python/csrc/offline-nemo-enc-dec-ctc-model-config.h"
|
||||||
|
#include "sherpa-onnx/python/csrc/offline-paraformer-model-config.h"
|
||||||
|
#include "sherpa-onnx/python/csrc/offline-transducer-model-config.h"
|
||||||
|
|
||||||
namespace sherpa_onnx {
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
void PybindOfflineModelConfig(py::module *m) {
|
void PybindOfflineModelConfig(py::module *m) {
|
||||||
PybindOfflineTransducerModelConfig(m);
|
PybindOfflineTransducerModelConfig(m);
|
||||||
PybindOfflineParaformerModelConfig(m);
|
PybindOfflineParaformerModelConfig(m);
|
||||||
|
PybindOfflineNemoEncDecCtcModelConfig(m);
|
||||||
|
|
||||||
using PyClass = OfflineModelConfig;
|
using PyClass = OfflineModelConfig;
|
||||||
py::class_<PyClass>(*m, "OfflineModelConfig")
|
py::class_<PyClass>(*m, "OfflineModelConfig")
|
||||||
.def(py::init<OfflineTransducerModelConfig &,
|
.def(py::init<const OfflineTransducerModelConfig &,
|
||||||
OfflineParaformerModelConfig &,
|
const OfflineParaformerModelConfig &,
|
||||||
const std::string &, int32_t, bool>(),
|
const OfflineNemoEncDecCtcModelConfig &,
|
||||||
py::arg("transducer"), py::arg("paraformer"), py::arg("tokens"),
|
const std::string &, int32_t, bool>(),
|
||||||
py::arg("num_threads"), py::arg("debug") = false)
|
py::arg("transducer") = OfflineTransducerModelConfig(),
|
||||||
|
py::arg("paraformer") = OfflineParaformerModelConfig(),
|
||||||
|
py::arg("nemo_ctc") = OfflineNemoEncDecCtcModelConfig(),
|
||||||
|
py::arg("tokens"), py::arg("num_threads"), py::arg("debug") = false)
|
||||||
.def_readwrite("transducer", &PyClass::transducer)
|
.def_readwrite("transducer", &PyClass::transducer)
|
||||||
.def_readwrite("paraformer", &PyClass::paraformer)
|
.def_readwrite("paraformer", &PyClass::paraformer)
|
||||||
|
.def_readwrite("nemo_ctc", &PyClass::nemo_ctc)
|
||||||
.def_readwrite("tokens", &PyClass::tokens)
|
.def_readwrite("tokens", &PyClass::tokens)
|
||||||
.def_readwrite("num_threads", &PyClass::num_threads)
|
.def_readwrite("num_threads", &PyClass::num_threads)
|
||||||
.def_readwrite("debug", &PyClass::debug)
|
.def_readwrite("debug", &PyClass::debug)
|
||||||
|
|||||||
@@ -0,0 +1,22 @@
|
|||||||
|
// sherpa-onnx/python/csrc/offline-nemo-enc-dec-ctc-model-config.cc
|
||||||
|
//
|
||||||
|
// Copyright (c) 2023 Xiaomi Corporation
|
||||||
|
|
||||||
|
#include "sherpa-onnx/python/csrc/offline-nemo-enc-dec-ctc-model-config.h"
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model-config.h"
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
void PybindOfflineNemoEncDecCtcModelConfig(py::module *m) {
|
||||||
|
using PyClass = OfflineNemoEncDecCtcModelConfig;
|
||||||
|
py::class_<PyClass>(*m, "OfflineNemoEncDecCtcModelConfig")
|
||||||
|
.def(py::init<const std::string &>(), py::arg("model"))
|
||||||
|
.def_readwrite("model", &PyClass::model)
|
||||||
|
.def("__str__", &PyClass::ToString);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace sherpa_onnx
|
||||||
@@ -0,0 +1,16 @@
|
|||||||
|
// sherpa-onnx/python/csrc/offline-nemo-enc-dec-ctc-model-config.h
|
||||||
|
//
|
||||||
|
// Copyright (c) 2023 Xiaomi Corporation
|
||||||
|
|
||||||
|
#ifndef SHERPA_ONNX_PYTHON_CSRC_OFFLINE_NEMO_ENC_DEC_CTC_MODEL_CONFIG_H_
|
||||||
|
#define SHERPA_ONNX_PYTHON_CSRC_OFFLINE_NEMO_ENC_DEC_CTC_MODEL_CONFIG_H_
|
||||||
|
|
||||||
|
#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
void PybindOfflineNemoEncDecCtcModelConfig(py::module *m);
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif // SHERPA_ONNX_PYTHON_CSRC_OFFLINE_NEMO_ENC_DEC_CTC_MODEL_CONFIG_H_
|
||||||
@@ -4,7 +4,6 @@
|
|||||||
|
|
||||||
#include "sherpa-onnx/python/csrc/offline-paraformer-model-config.h"
|
#include "sherpa-onnx/python/csrc/offline-paraformer-model-config.h"
|
||||||
|
|
||||||
|
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
@@ -15,8 +14,7 @@ namespace sherpa_onnx {
|
|||||||
void PybindOfflineParaformerModelConfig(py::module *m) {
|
void PybindOfflineParaformerModelConfig(py::module *m) {
|
||||||
using PyClass = OfflineParaformerModelConfig;
|
using PyClass = OfflineParaformerModelConfig;
|
||||||
py::class_<PyClass>(*m, "OfflineParaformerModelConfig")
|
py::class_<PyClass>(*m, "OfflineParaformerModelConfig")
|
||||||
.def(py::init<const std::string &>(),
|
.def(py::init<const std::string &>(), py::arg("model"))
|
||||||
py::arg("model"))
|
|
||||||
.def_readwrite("model", &PyClass::model)
|
.def_readwrite("model", &PyClass::model)
|
||||||
.def("__str__", &PyClass::ToString);
|
.def("__str__", &PyClass::ToString);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -11,8 +11,6 @@
|
|||||||
|
|
||||||
namespace sherpa_onnx {
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
static void PybindOfflineRecognizerConfig(py::module *m) {
|
static void PybindOfflineRecognizerConfig(py::module *m) {
|
||||||
using PyClass = OfflineRecognizerConfig;
|
using PyClass = OfflineRecognizerConfig;
|
||||||
py::class_<PyClass>(*m, "OfflineRecognizerConfig")
|
py::class_<PyClass>(*m, "OfflineRecognizerConfig")
|
||||||
|
|||||||
@@ -31,7 +31,6 @@ static void PybindOfflineRecognitionResult(py::module *m) { // NOLINT
|
|||||||
"timestamps", [](const PyClass &self) { return self.timestamps; });
|
"timestamps", [](const PyClass &self) { return self.timestamps; });
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
static void PybindOfflineFeatureExtractorConfig(py::module *m) {
|
static void PybindOfflineFeatureExtractorConfig(py::module *m) {
|
||||||
using PyClass = OfflineFeatureExtractorConfig;
|
using PyClass = OfflineFeatureExtractorConfig;
|
||||||
py::class_<PyClass>(*m, "OfflineFeatureExtractorConfig")
|
py::class_<PyClass>(*m, "OfflineFeatureExtractorConfig")
|
||||||
@@ -42,7 +41,6 @@ static void PybindOfflineFeatureExtractorConfig(py::module *m) {
|
|||||||
.def("__str__", &PyClass::ToString);
|
.def("__str__", &PyClass::ToString);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
void PybindOfflineStream(py::module *m) {
|
void PybindOfflineStream(py::module *m) {
|
||||||
PybindOfflineFeatureExtractorConfig(m);
|
PybindOfflineFeatureExtractorConfig(m);
|
||||||
PybindOfflineRecognitionResult(m);
|
PybindOfflineRecognitionResult(m);
|
||||||
@@ -55,7 +53,7 @@ void PybindOfflineStream(py::module *m) {
|
|||||||
self.AcceptWaveform(sample_rate, waveform.data(), waveform.size());
|
self.AcceptWaveform(sample_rate, waveform.data(), waveform.size());
|
||||||
},
|
},
|
||||||
py::arg("sample_rate"), py::arg("waveform"), kAcceptWaveformUsage)
|
py::arg("sample_rate"), py::arg("waveform"), kAcceptWaveformUsage)
|
||||||
.def_property_readonly("result", &PyClass::GetResult);
|
.def_property_readonly("result", &PyClass::GetResult);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace sherpa_onnx
|
} // namespace sherpa_onnx
|
||||||
|
|||||||
@@ -7,16 +7,13 @@
|
|||||||
#include "sherpa-onnx/python/csrc/display.h"
|
#include "sherpa-onnx/python/csrc/display.h"
|
||||||
#include "sherpa-onnx/python/csrc/endpoint.h"
|
#include "sherpa-onnx/python/csrc/endpoint.h"
|
||||||
#include "sherpa-onnx/python/csrc/features.h"
|
#include "sherpa-onnx/python/csrc/features.h"
|
||||||
|
#include "sherpa-onnx/python/csrc/offline-model-config.h"
|
||||||
|
#include "sherpa-onnx/python/csrc/offline-recognizer.h"
|
||||||
|
#include "sherpa-onnx/python/csrc/offline-stream.h"
|
||||||
#include "sherpa-onnx/python/csrc/online-recognizer.h"
|
#include "sherpa-onnx/python/csrc/online-recognizer.h"
|
||||||
#include "sherpa-onnx/python/csrc/online-stream.h"
|
#include "sherpa-onnx/python/csrc/online-stream.h"
|
||||||
#include "sherpa-onnx/python/csrc/online-transducer-model-config.h"
|
#include "sherpa-onnx/python/csrc/online-transducer-model-config.h"
|
||||||
|
|
||||||
#include "sherpa-onnx/python/csrc/offline-model-config.h"
|
|
||||||
#include "sherpa-onnx/python/csrc/offline-paraformer-model-config.h"
|
|
||||||
#include "sherpa-onnx/python/csrc/offline-recognizer.h"
|
|
||||||
#include "sherpa-onnx/python/csrc/offline-stream.h"
|
|
||||||
#include "sherpa-onnx/python/csrc/offline-transducer-model-config.h"
|
|
||||||
|
|
||||||
namespace sherpa_onnx {
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
PYBIND11_MODULE(_sherpa_onnx, m) {
|
PYBIND11_MODULE(_sherpa_onnx, m) {
|
||||||
|
|||||||
@@ -4,12 +4,15 @@ from typing import List
|
|||||||
|
|
||||||
from _sherpa_onnx import (
|
from _sherpa_onnx import (
|
||||||
OfflineFeatureExtractorConfig,
|
OfflineFeatureExtractorConfig,
|
||||||
OfflineRecognizer as _Recognizer,
|
OfflineModelConfig,
|
||||||
|
OfflineNemoEncDecCtcModelConfig,
|
||||||
|
OfflineParaformerModelConfig,
|
||||||
|
)
|
||||||
|
from _sherpa_onnx import OfflineRecognizer as _Recognizer
|
||||||
|
from _sherpa_onnx import (
|
||||||
OfflineRecognizerConfig,
|
OfflineRecognizerConfig,
|
||||||
OfflineStream,
|
OfflineStream,
|
||||||
OfflineModelConfig,
|
|
||||||
OfflineTransducerModelConfig,
|
OfflineTransducerModelConfig,
|
||||||
OfflineParaformerModelConfig,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -75,7 +78,6 @@ class OfflineRecognizer(object):
|
|||||||
decoder_filename=decoder,
|
decoder_filename=decoder,
|
||||||
joiner_filename=joiner,
|
joiner_filename=joiner,
|
||||||
),
|
),
|
||||||
paraformer=OfflineParaformerModelConfig(model=""),
|
|
||||||
tokens=tokens,
|
tokens=tokens,
|
||||||
num_threads=num_threads,
|
num_threads=num_threads,
|
||||||
debug=debug,
|
debug=debug,
|
||||||
@@ -119,7 +121,7 @@ class OfflineRecognizer(object):
|
|||||||
symbol integer_id
|
symbol integer_id
|
||||||
|
|
||||||
paraformer:
|
paraformer:
|
||||||
Path to ``paraformer.onnx``.
|
Path to ``model.onnx``.
|
||||||
num_threads:
|
num_threads:
|
||||||
Number of threads for neural network computation.
|
Number of threads for neural network computation.
|
||||||
sample_rate:
|
sample_rate:
|
||||||
@@ -133,9 +135,6 @@ class OfflineRecognizer(object):
|
|||||||
"""
|
"""
|
||||||
self = cls.__new__(cls)
|
self = cls.__new__(cls)
|
||||||
model_config = OfflineModelConfig(
|
model_config = OfflineModelConfig(
|
||||||
transducer=OfflineTransducerModelConfig(
|
|
||||||
encoder_filename="", decoder_filename="", joiner_filename=""
|
|
||||||
),
|
|
||||||
paraformer=OfflineParaformerModelConfig(model=paraformer),
|
paraformer=OfflineParaformerModelConfig(model=paraformer),
|
||||||
tokens=tokens,
|
tokens=tokens,
|
||||||
num_threads=num_threads,
|
num_threads=num_threads,
|
||||||
@@ -155,6 +154,64 @@ class OfflineRecognizer(object):
|
|||||||
self.recognizer = _Recognizer(recognizer_config)
|
self.recognizer = _Recognizer(recognizer_config)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_nemo_ctc(
|
||||||
|
cls,
|
||||||
|
model: str,
|
||||||
|
tokens: str,
|
||||||
|
num_threads: int,
|
||||||
|
sample_rate: int = 16000,
|
||||||
|
feature_dim: int = 80,
|
||||||
|
decoding_method: str = "greedy_search",
|
||||||
|
debug: bool = False,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Please refer to
|
||||||
|
`<https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html>`_
|
||||||
|
to download pre-trained models for different languages, e.g., Chinese,
|
||||||
|
English, etc.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tokens:
|
||||||
|
Path to ``tokens.txt``. Each line in ``tokens.txt`` contains two
|
||||||
|
columns::
|
||||||
|
|
||||||
|
symbol integer_id
|
||||||
|
|
||||||
|
model:
|
||||||
|
Path to ``model.onnx``.
|
||||||
|
num_threads:
|
||||||
|
Number of threads for neural network computation.
|
||||||
|
sample_rate:
|
||||||
|
Sample rate of the training data used to train the model.
|
||||||
|
feature_dim:
|
||||||
|
Dimension of the feature used to train the model.
|
||||||
|
decoding_method:
|
||||||
|
Valid values are greedy_search, modified_beam_search.
|
||||||
|
debug:
|
||||||
|
True to show debug messages.
|
||||||
|
"""
|
||||||
|
self = cls.__new__(cls)
|
||||||
|
model_config = OfflineModelConfig(
|
||||||
|
nemo_ctc=OfflineNemoEncDecCtcModelConfig(model=model),
|
||||||
|
tokens=tokens,
|
||||||
|
num_threads=num_threads,
|
||||||
|
debug=debug,
|
||||||
|
)
|
||||||
|
|
||||||
|
feat_config = OfflineFeatureExtractorConfig(
|
||||||
|
sampling_rate=sample_rate,
|
||||||
|
feature_dim=feature_dim,
|
||||||
|
)
|
||||||
|
|
||||||
|
recognizer_config = OfflineRecognizerConfig(
|
||||||
|
feat_config=feat_config,
|
||||||
|
model_config=model_config,
|
||||||
|
decoding_method=decoding_method,
|
||||||
|
)
|
||||||
|
self.recognizer = _Recognizer(recognizer_config)
|
||||||
|
return self
|
||||||
|
|
||||||
def create_stream(self):
|
def create_stream(self):
|
||||||
return self.recognizer.create_stream()
|
return self.recognizer.create_stream()
|
||||||
|
|
||||||
|
|||||||
@@ -196,6 +196,71 @@ class TestOfflineRecognizer(unittest.TestCase):
|
|||||||
print(s2.result.text)
|
print(s2.result.text)
|
||||||
print(s3.result.text)
|
print(s3.result.text)
|
||||||
|
|
||||||
|
def test_nemo_ctc_single_file(self):
|
||||||
|
for use_int8 in [True, False]:
|
||||||
|
if use_int8:
|
||||||
|
model = f"{d}/sherpa-onnx-nemo-ctc-en-citrinet-512/model.int8.onnx"
|
||||||
|
else:
|
||||||
|
model = f"{d}/sherpa-onnx-nemo-ctc-en-citrinet-512/model.onnx"
|
||||||
|
|
||||||
|
tokens = f"{d}/sherpa-onnx-nemo-ctc-en-citrinet-512/tokens.txt"
|
||||||
|
wave0 = f"{d}/sherpa-onnx-nemo-ctc-en-citrinet-512/test_wavs/0.wav"
|
||||||
|
|
||||||
|
if not Path(model).is_file():
|
||||||
|
print("skipping test_nemo_ctc_single_file()")
|
||||||
|
return
|
||||||
|
|
||||||
|
recognizer = sherpa_onnx.OfflineRecognizer.from_nemo_ctc(
|
||||||
|
model=model,
|
||||||
|
tokens=tokens,
|
||||||
|
num_threads=1,
|
||||||
|
)
|
||||||
|
|
||||||
|
s = recognizer.create_stream()
|
||||||
|
samples, sample_rate = read_wave(wave0)
|
||||||
|
s.accept_waveform(sample_rate, samples)
|
||||||
|
recognizer.decode_stream(s)
|
||||||
|
print(s.result.text)
|
||||||
|
|
||||||
|
def test_nemo_ctc_multiple_files(self):
|
||||||
|
for use_int8 in [True, False]:
|
||||||
|
if use_int8:
|
||||||
|
model = f"{d}/sherpa-onnx-nemo-ctc-en-citrinet-512/model.int8.onnx"
|
||||||
|
else:
|
||||||
|
model = f"{d}/sherpa-onnx-nemo-ctc-en-citrinet-512/model.onnx"
|
||||||
|
|
||||||
|
tokens = f"{d}/sherpa-onnx-nemo-ctc-en-citrinet-512/tokens.txt"
|
||||||
|
wave0 = f"{d}/sherpa-onnx-nemo-ctc-en-citrinet-512/test_wavs/0.wav"
|
||||||
|
wave1 = f"{d}/sherpa-onnx-nemo-ctc-en-citrinet-512/test_wavs/1.wav"
|
||||||
|
wave2 = f"{d}/sherpa-onnx-nemo-ctc-en-citrinet-512/test_wavs/8k.wav"
|
||||||
|
|
||||||
|
if not Path(model).is_file():
|
||||||
|
print("skipping test_nemo_ctc_multiple_files()")
|
||||||
|
return
|
||||||
|
|
||||||
|
recognizer = sherpa_onnx.OfflineRecognizer.from_nemo_ctc(
|
||||||
|
model=model,
|
||||||
|
tokens=tokens,
|
||||||
|
num_threads=1,
|
||||||
|
)
|
||||||
|
|
||||||
|
s0 = recognizer.create_stream()
|
||||||
|
samples0, sample_rate0 = read_wave(wave0)
|
||||||
|
s0.accept_waveform(sample_rate0, samples0)
|
||||||
|
|
||||||
|
s1 = recognizer.create_stream()
|
||||||
|
samples1, sample_rate1 = read_wave(wave1)
|
||||||
|
s1.accept_waveform(sample_rate1, samples1)
|
||||||
|
|
||||||
|
s2 = recognizer.create_stream()
|
||||||
|
samples2, sample_rate2 = read_wave(wave2)
|
||||||
|
s2.accept_waveform(sample_rate2, samples2)
|
||||||
|
|
||||||
|
recognizer.decode_streams([s0, s1, s2])
|
||||||
|
print(s0.result.text)
|
||||||
|
print(s1.result.text)
|
||||||
|
print(s2.result.text)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|||||||
Reference in New Issue
Block a user