Support TDNN models from the yesno recipe from icefall (#262)
This commit is contained in:
44
.github/scripts/test-offline-ctc.sh
vendored
44
.github/scripts/test-offline-ctc.sh
vendored
@@ -13,6 +13,50 @@ echo "PATH: $PATH"
|
|||||||
|
|
||||||
which $EXE
|
which $EXE
|
||||||
|
|
||||||
|
log "------------------------------------------------------------"
|
||||||
|
log "Run tdnn yesno (Hebrew)"
|
||||||
|
log "------------------------------------------------------------"
|
||||||
|
repo_url=https://huggingface.co/csukuangfj/sherpa-onnx-tdnn-yesno
|
||||||
|
log "Start testing ${repo_url}"
|
||||||
|
repo=$(basename $repo_url)
|
||||||
|
log "Download pretrained model and test-data from $repo_url"
|
||||||
|
|
||||||
|
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
|
||||||
|
pushd $repo
|
||||||
|
git lfs pull --include "*.onnx"
|
||||||
|
ls -lh *.onnx
|
||||||
|
popd
|
||||||
|
|
||||||
|
log "test float32 models"
|
||||||
|
time $EXE \
|
||||||
|
--sample-rate=8000 \
|
||||||
|
--feat-dim=23 \
|
||||||
|
\
|
||||||
|
--tokens=$repo/tokens.txt \
|
||||||
|
--tdnn-model=$repo/model-epoch-14-avg-2.onnx \
|
||||||
|
$repo/test_wavs/0_0_0_1_0_0_0_1.wav \
|
||||||
|
$repo/test_wavs/0_0_1_0_0_0_1_0.wav \
|
||||||
|
$repo/test_wavs/0_0_1_0_0_1_1_1.wav \
|
||||||
|
$repo/test_wavs/0_0_1_0_1_0_0_1.wav \
|
||||||
|
$repo/test_wavs/0_0_1_1_0_0_0_1.wav \
|
||||||
|
$repo/test_wavs/0_0_1_1_0_1_1_0.wav
|
||||||
|
|
||||||
|
log "test int8 models"
|
||||||
|
time $EXE \
|
||||||
|
--sample-rate=8000 \
|
||||||
|
--feat-dim=23 \
|
||||||
|
\
|
||||||
|
--tokens=$repo/tokens.txt \
|
||||||
|
--tdnn-model=$repo/model-epoch-14-avg-2.int8.onnx \
|
||||||
|
$repo/test_wavs/0_0_0_1_0_0_0_1.wav \
|
||||||
|
$repo/test_wavs/0_0_1_0_0_0_1_0.wav \
|
||||||
|
$repo/test_wavs/0_0_1_0_0_1_1_1.wav \
|
||||||
|
$repo/test_wavs/0_0_1_0_1_0_0_1.wav \
|
||||||
|
$repo/test_wavs/0_0_1_1_0_0_0_1.wav \
|
||||||
|
$repo/test_wavs/0_0_1_1_0_1_1_0.wav
|
||||||
|
|
||||||
|
rm -rf $repo
|
||||||
|
|
||||||
log "------------------------------------------------------------"
|
log "------------------------------------------------------------"
|
||||||
log "Run Citrinet (stt_en_citrinet_512, English)"
|
log "Run Citrinet (stt_en_citrinet_512, English)"
|
||||||
log "------------------------------------------------------------"
|
log "------------------------------------------------------------"
|
||||||
|
|||||||
@@ -24,7 +24,7 @@ jobs:
|
|||||||
matrix:
|
matrix:
|
||||||
os: [ubuntu-latest, windows-latest, macos-latest]
|
os: [ubuntu-latest, windows-latest, macos-latest]
|
||||||
python-version: ["3.7", "3.8", "3.9", "3.10", "3.11"]
|
python-version: ["3.7", "3.8", "3.9", "3.10", "3.11"]
|
||||||
model_type: ["transducer", "paraformer", "nemo_ctc", "whisper"]
|
model_type: ["transducer", "paraformer", "nemo_ctc", "whisper", "tdnn"]
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v2
|
- uses: actions/checkout@v2
|
||||||
@@ -172,3 +172,41 @@ jobs:
|
|||||||
./sherpa-onnx-whisper-tiny.en/test_wavs/0.wav \
|
./sherpa-onnx-whisper-tiny.en/test_wavs/0.wav \
|
||||||
./sherpa-onnx-whisper-tiny.en/test_wavs/1.wav \
|
./sherpa-onnx-whisper-tiny.en/test_wavs/1.wav \
|
||||||
./sherpa-onnx-whisper-tiny.en/test_wavs/8k.wav
|
./sherpa-onnx-whisper-tiny.en/test_wavs/8k.wav
|
||||||
|
|
||||||
|
- name: Start server for tdnn models
|
||||||
|
if: matrix.model_type == 'tdnn'
|
||||||
|
shell: bash
|
||||||
|
run: |
|
||||||
|
GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-tdnn-yesno
|
||||||
|
cd sherpa-onnx-tdnn-yesno
|
||||||
|
git lfs pull --include "*.onnx"
|
||||||
|
cd ..
|
||||||
|
|
||||||
|
python3 ./python-api-examples/non_streaming_server.py \
|
||||||
|
--tdnn-model=./sherpa-onnx-tdnn-yesno/model-epoch-14-avg-2.onnx \
|
||||||
|
--tokens=./sherpa-onnx-tdnn-yesno/tokens.txt \
|
||||||
|
--sample-rate=8000 \
|
||||||
|
--feat-dim=23 &
|
||||||
|
|
||||||
|
echo "sleep 10 seconds to wait the server start"
|
||||||
|
sleep 10
|
||||||
|
|
||||||
|
- name: Start client for tdnn models
|
||||||
|
if: matrix.model_type == 'tdnn'
|
||||||
|
shell: bash
|
||||||
|
run: |
|
||||||
|
python3 ./python-api-examples/offline-websocket-client-decode-files-paralell.py \
|
||||||
|
./sherpa-onnx-tdnn-yesno/test_wavs/0_0_0_1_0_0_0_1.wav \
|
||||||
|
./sherpa-onnx-tdnn-yesno/test_wavs/0_0_1_0_0_0_1_0.wav \
|
||||||
|
./sherpa-onnx-tdnn-yesno/test_wavs/0_0_1_0_0_1_1_1.wav \
|
||||||
|
./sherpa-onnx-tdnn-yesno/test_wavs/0_0_1_0_1_0_0_1.wav \
|
||||||
|
./sherpa-onnx-tdnn-yesno/test_wavs/0_0_1_1_0_0_0_1.wav \
|
||||||
|
./sherpa-onnx-tdnn-yesno/test_wavs/0_0_1_1_0_1_1_0.wav
|
||||||
|
|
||||||
|
python3 ./python-api-examples/offline-websocket-client-decode-files-sequential.py \
|
||||||
|
./sherpa-onnx-tdnn-yesno/test_wavs/0_0_0_1_0_0_0_1.wav \
|
||||||
|
./sherpa-onnx-tdnn-yesno/test_wavs/0_0_1_0_0_0_1_0.wav \
|
||||||
|
./sherpa-onnx-tdnn-yesno/test_wavs/0_0_1_0_0_1_1_1.wav \
|
||||||
|
./sherpa-onnx-tdnn-yesno/test_wavs/0_0_1_0_1_0_0_1.wav \
|
||||||
|
./sherpa-onnx-tdnn-yesno/test_wavs/0_0_1_1_0_0_0_1.wav \
|
||||||
|
./sherpa-onnx-tdnn-yesno/test_wavs/0_0_1_1_0_1_1_0.wav
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
cmake_minimum_required(VERSION 3.13 FATAL_ERROR)
|
cmake_minimum_required(VERSION 3.13 FATAL_ERROR)
|
||||||
project(sherpa-onnx)
|
project(sherpa-onnx)
|
||||||
|
|
||||||
set(SHERPA_ONNX_VERSION "1.7.2")
|
set(SHERPA_ONNX_VERSION "1.7.3")
|
||||||
|
|
||||||
# Disable warning about
|
# Disable warning about
|
||||||
#
|
#
|
||||||
|
|||||||
@@ -71,6 +71,20 @@ python3 ./python-api-examples/non_streaming_server.py \
|
|||||||
--whisper-decoder=./sherpa-onnx-whisper-tiny.en/tiny.en-decoder.onnx \
|
--whisper-decoder=./sherpa-onnx-whisper-tiny.en/tiny.en-decoder.onnx \
|
||||||
--tokens=./sherpa-onnx-whisper-tiny.en/tiny.en-tokens.txt
|
--tokens=./sherpa-onnx-whisper-tiny.en/tiny.en-tokens.txt
|
||||||
|
|
||||||
|
(5) Use a tdnn model of the yesno recipe from icefall
|
||||||
|
|
||||||
|
cd /path/to/sherpa-onnx
|
||||||
|
|
||||||
|
GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-tdnn-yesno
|
||||||
|
cd sherpa-onnx-tdnn-yesno
|
||||||
|
git lfs pull --include "*.onnx"
|
||||||
|
|
||||||
|
python3 ./python-api-examples/non_streaming_server.py \
|
||||||
|
--sample-rate=8000 \
|
||||||
|
--feat-dim=23 \
|
||||||
|
--tdnn-model=./sherpa-onnx-tdnn-yesno/model-epoch-14-avg-2.onnx \
|
||||||
|
--tokens=./sherpa-onnx-tdnn-yesno/tokens.txt
|
||||||
|
|
||||||
----
|
----
|
||||||
|
|
||||||
To use a certificate so that you can use https, please use
|
To use a certificate so that you can use https, please use
|
||||||
@@ -196,6 +210,15 @@ def add_nemo_ctc_model_args(parser: argparse.ArgumentParser):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def add_tdnn_ctc_model_args(parser: argparse.ArgumentParser):
|
||||||
|
parser.add_argument(
|
||||||
|
"--tdnn-model",
|
||||||
|
default="",
|
||||||
|
type=str,
|
||||||
|
help="Path to the model.onnx for the tdnn model of the yesno recipe",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def add_whisper_model_args(parser: argparse.ArgumentParser):
|
def add_whisper_model_args(parser: argparse.ArgumentParser):
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--whisper-encoder",
|
"--whisper-encoder",
|
||||||
@@ -216,6 +239,7 @@ def add_model_args(parser: argparse.ArgumentParser):
|
|||||||
add_transducer_model_args(parser)
|
add_transducer_model_args(parser)
|
||||||
add_paraformer_model_args(parser)
|
add_paraformer_model_args(parser)
|
||||||
add_nemo_ctc_model_args(parser)
|
add_nemo_ctc_model_args(parser)
|
||||||
|
add_tdnn_ctc_model_args(parser)
|
||||||
add_whisper_model_args(parser)
|
add_whisper_model_args(parser)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@@ -730,6 +754,7 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
|
|||||||
assert len(args.nemo_ctc) == 0, args.nemo_ctc
|
assert len(args.nemo_ctc) == 0, args.nemo_ctc
|
||||||
assert len(args.whisper_encoder) == 0, args.whisper_encoder
|
assert len(args.whisper_encoder) == 0, args.whisper_encoder
|
||||||
assert len(args.whisper_decoder) == 0, args.whisper_decoder
|
assert len(args.whisper_decoder) == 0, args.whisper_decoder
|
||||||
|
assert len(args.tdnn_model) == 0, args.tdnn_model
|
||||||
|
|
||||||
assert_file_exists(args.encoder)
|
assert_file_exists(args.encoder)
|
||||||
assert_file_exists(args.decoder)
|
assert_file_exists(args.decoder)
|
||||||
@@ -750,6 +775,7 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
|
|||||||
assert len(args.nemo_ctc) == 0, args.nemo_ctc
|
assert len(args.nemo_ctc) == 0, args.nemo_ctc
|
||||||
assert len(args.whisper_encoder) == 0, args.whisper_encoder
|
assert len(args.whisper_encoder) == 0, args.whisper_encoder
|
||||||
assert len(args.whisper_decoder) == 0, args.whisper_decoder
|
assert len(args.whisper_decoder) == 0, args.whisper_decoder
|
||||||
|
assert len(args.tdnn_model) == 0, args.tdnn_model
|
||||||
|
|
||||||
assert_file_exists(args.paraformer)
|
assert_file_exists(args.paraformer)
|
||||||
|
|
||||||
@@ -764,6 +790,7 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
|
|||||||
elif args.nemo_ctc:
|
elif args.nemo_ctc:
|
||||||
assert len(args.whisper_encoder) == 0, args.whisper_encoder
|
assert len(args.whisper_encoder) == 0, args.whisper_encoder
|
||||||
assert len(args.whisper_decoder) == 0, args.whisper_decoder
|
assert len(args.whisper_decoder) == 0, args.whisper_decoder
|
||||||
|
assert len(args.tdnn_model) == 0, args.tdnn_model
|
||||||
|
|
||||||
assert_file_exists(args.nemo_ctc)
|
assert_file_exists(args.nemo_ctc)
|
||||||
|
|
||||||
@@ -776,6 +803,7 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
|
|||||||
decoding_method=args.decoding_method,
|
decoding_method=args.decoding_method,
|
||||||
)
|
)
|
||||||
elif args.whisper_encoder:
|
elif args.whisper_encoder:
|
||||||
|
assert len(args.tdnn_model) == 0, args.tdnn_model
|
||||||
assert_file_exists(args.whisper_encoder)
|
assert_file_exists(args.whisper_encoder)
|
||||||
assert_file_exists(args.whisper_decoder)
|
assert_file_exists(args.whisper_decoder)
|
||||||
|
|
||||||
@@ -786,6 +814,17 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
|
|||||||
num_threads=args.num_threads,
|
num_threads=args.num_threads,
|
||||||
decoding_method=args.decoding_method,
|
decoding_method=args.decoding_method,
|
||||||
)
|
)
|
||||||
|
elif args.tdnn_model:
|
||||||
|
assert_file_exists(args.tdnn_model)
|
||||||
|
|
||||||
|
recognizer = sherpa_onnx.OfflineRecognizer.from_tdnn_ctc(
|
||||||
|
model=args.tdnn_model,
|
||||||
|
tokens=args.tokens,
|
||||||
|
sample_rate=args.sample_rate,
|
||||||
|
feature_dim=args.feat_dim,
|
||||||
|
num_threads=args.num_threads,
|
||||||
|
decoding_method=args.decoding_method,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError("Please specify at least one model")
|
raise ValueError("Please specify at least one model")
|
||||||
|
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ This file demonstrates how to use sherpa-onnx Python API to transcribe
|
|||||||
file(s) with a non-streaming model.
|
file(s) with a non-streaming model.
|
||||||
|
|
||||||
(1) For paraformer
|
(1) For paraformer
|
||||||
|
|
||||||
./python-api-examples/offline-decode-files.py \
|
./python-api-examples/offline-decode-files.py \
|
||||||
--tokens=/path/to/tokens.txt \
|
--tokens=/path/to/tokens.txt \
|
||||||
--paraformer=/path/to/paraformer.onnx \
|
--paraformer=/path/to/paraformer.onnx \
|
||||||
@@ -20,6 +21,7 @@ file(s) with a non-streaming model.
|
|||||||
/path/to/1.wav
|
/path/to/1.wav
|
||||||
|
|
||||||
(2) For transducer models from icefall
|
(2) For transducer models from icefall
|
||||||
|
|
||||||
./python-api-examples/offline-decode-files.py \
|
./python-api-examples/offline-decode-files.py \
|
||||||
--tokens=/path/to/tokens.txt \
|
--tokens=/path/to/tokens.txt \
|
||||||
--encoder=/path/to/encoder.onnx \
|
--encoder=/path/to/encoder.onnx \
|
||||||
@@ -56,9 +58,20 @@ python3 ./python-api-examples/offline-decode-files.py \
|
|||||||
./sherpa-onnx-whisper-base.en/test_wavs/1.wav \
|
./sherpa-onnx-whisper-base.en/test_wavs/1.wav \
|
||||||
./sherpa-onnx-whisper-base.en/test_wavs/8k.wav
|
./sherpa-onnx-whisper-base.en/test_wavs/8k.wav
|
||||||
|
|
||||||
|
(5) For tdnn models of the yesno recipe from icefall
|
||||||
|
|
||||||
|
python3 ./python-api-examples/offline-decode-files.py \
|
||||||
|
--sample-rate=8000 \
|
||||||
|
--feature-dim=23 \
|
||||||
|
--tdnn-model=./sherpa-onnx-tdnn-yesno/model-epoch-14-avg-2.onnx \
|
||||||
|
--tokens=./sherpa-onnx-tdnn-yesno/tokens.txt \
|
||||||
|
./sherpa-onnx-tdnn-yesno/test_wavs/0_0_0_1_0_0_0_1.wav \
|
||||||
|
./sherpa-onnx-tdnn-yesno/test_wavs/0_0_1_0_0_0_1_0.wav \
|
||||||
|
./sherpa-onnx-tdnn-yesno/test_wavs/0_0_1_0_0_1_1_1.wav
|
||||||
|
|
||||||
Please refer to
|
Please refer to
|
||||||
https://k2-fsa.github.io/sherpa/onnx/index.html
|
https://k2-fsa.github.io/sherpa/onnx/index.html
|
||||||
to install sherpa-onnx and to download the pre-trained models
|
to install sherpa-onnx and to download non-streaming pre-trained models
|
||||||
used in this file.
|
used in this file.
|
||||||
"""
|
"""
|
||||||
import argparse
|
import argparse
|
||||||
@@ -159,6 +172,13 @@ def get_args():
|
|||||||
help="Path to the model.onnx from NeMo CTC",
|
help="Path to the model.onnx from NeMo CTC",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--tdnn-model",
|
||||||
|
default="",
|
||||||
|
type=str,
|
||||||
|
help="Path to the model.onnx for the tdnn model of the yesno recipe",
|
||||||
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--num-threads",
|
"--num-threads",
|
||||||
type=int,
|
type=int,
|
||||||
@@ -285,6 +305,7 @@ def main():
|
|||||||
assert len(args.nemo_ctc) == 0, args.nemo_ctc
|
assert len(args.nemo_ctc) == 0, args.nemo_ctc
|
||||||
assert len(args.whisper_encoder) == 0, args.whisper_encoder
|
assert len(args.whisper_encoder) == 0, args.whisper_encoder
|
||||||
assert len(args.whisper_decoder) == 0, args.whisper_decoder
|
assert len(args.whisper_decoder) == 0, args.whisper_decoder
|
||||||
|
assert len(args.tdnn_model) == 0, args.tdnn_model
|
||||||
|
|
||||||
contexts = [x.strip().upper() for x in args.contexts.split("/") if x.strip()]
|
contexts = [x.strip().upper() for x in args.contexts.split("/") if x.strip()]
|
||||||
if contexts:
|
if contexts:
|
||||||
@@ -311,6 +332,7 @@ def main():
|
|||||||
assert len(args.nemo_ctc) == 0, args.nemo_ctc
|
assert len(args.nemo_ctc) == 0, args.nemo_ctc
|
||||||
assert len(args.whisper_encoder) == 0, args.whisper_encoder
|
assert len(args.whisper_encoder) == 0, args.whisper_encoder
|
||||||
assert len(args.whisper_decoder) == 0, args.whisper_decoder
|
assert len(args.whisper_decoder) == 0, args.whisper_decoder
|
||||||
|
assert len(args.tdnn_model) == 0, args.tdnn_model
|
||||||
|
|
||||||
assert_file_exists(args.paraformer)
|
assert_file_exists(args.paraformer)
|
||||||
|
|
||||||
@@ -326,6 +348,7 @@ def main():
|
|||||||
elif args.nemo_ctc:
|
elif args.nemo_ctc:
|
||||||
assert len(args.whisper_encoder) == 0, args.whisper_encoder
|
assert len(args.whisper_encoder) == 0, args.whisper_encoder
|
||||||
assert len(args.whisper_decoder) == 0, args.whisper_decoder
|
assert len(args.whisper_decoder) == 0, args.whisper_decoder
|
||||||
|
assert len(args.tdnn_model) == 0, args.tdnn_model
|
||||||
|
|
||||||
assert_file_exists(args.nemo_ctc)
|
assert_file_exists(args.nemo_ctc)
|
||||||
|
|
||||||
@@ -339,6 +362,7 @@ def main():
|
|||||||
debug=args.debug,
|
debug=args.debug,
|
||||||
)
|
)
|
||||||
elif args.whisper_encoder:
|
elif args.whisper_encoder:
|
||||||
|
assert len(args.tdnn_model) == 0, args.tdnn_model
|
||||||
assert_file_exists(args.whisper_encoder)
|
assert_file_exists(args.whisper_encoder)
|
||||||
assert_file_exists(args.whisper_decoder)
|
assert_file_exists(args.whisper_decoder)
|
||||||
|
|
||||||
@@ -347,6 +371,20 @@ def main():
|
|||||||
decoder=args.whisper_decoder,
|
decoder=args.whisper_decoder,
|
||||||
tokens=args.tokens,
|
tokens=args.tokens,
|
||||||
num_threads=args.num_threads,
|
num_threads=args.num_threads,
|
||||||
|
sample_rate=args.sample_rate,
|
||||||
|
feature_dim=args.feature_dim,
|
||||||
|
decoding_method=args.decoding_method,
|
||||||
|
debug=args.debug,
|
||||||
|
)
|
||||||
|
elif args.tdnn_model:
|
||||||
|
assert_file_exists(args.tdnn_model)
|
||||||
|
|
||||||
|
recognizer = sherpa_onnx.OfflineRecognizer.from_tdnn_ctc(
|
||||||
|
model=args.tdnn_model,
|
||||||
|
tokens=args.tokens,
|
||||||
|
sample_rate=args.sample_rate,
|
||||||
|
feature_dim=args.feature_dim,
|
||||||
|
num_threads=args.num_threads,
|
||||||
decoding_method=args.decoding_method,
|
decoding_method=args.decoding_method,
|
||||||
debug=args.debug,
|
debug=args.debug,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -97,20 +97,18 @@ function onFileChange() {
|
|||||||
console.log('file.type ' + file.type);
|
console.log('file.type ' + file.type);
|
||||||
console.log('file.size ' + file.size);
|
console.log('file.size ' + file.size);
|
||||||
|
|
||||||
|
let audioCtx = new AudioContext({sampleRate: 16000});
|
||||||
|
|
||||||
let reader = new FileReader();
|
let reader = new FileReader();
|
||||||
reader.onload = function() {
|
reader.onload = function() {
|
||||||
console.log('reading file!');
|
console.log('reading file!');
|
||||||
let view = new Int16Array(reader.result);
|
audioCtx.decodeAudioData(reader.result, decodedDone);
|
||||||
// we assume the input file is a wav file.
|
};
|
||||||
// TODO: add some checks here.
|
|
||||||
let int16_samples = view.subarray(22); // header has 44 bytes == 22 shorts
|
|
||||||
let num_samples = int16_samples.length;
|
|
||||||
let float32_samples = new Float32Array(num_samples);
|
|
||||||
console.log('num_samples ' + num_samples)
|
|
||||||
|
|
||||||
for (let i = 0; i < num_samples; ++i) {
|
function decodedDone(decoded) {
|
||||||
float32_samples[i] = int16_samples[i] / 32768.
|
let typedArray = new Float32Array(decoded.length);
|
||||||
}
|
let float32_samples = decoded.getChannelData(0);
|
||||||
|
let buf = float32_samples.buffer
|
||||||
|
|
||||||
// Send 1024 audio samples per request.
|
// Send 1024 audio samples per request.
|
||||||
//
|
//
|
||||||
@@ -119,14 +117,13 @@ function onFileChange() {
|
|||||||
// (2) There is a limit on the number of bytes in the payload that can be
|
// (2) There is a limit on the number of bytes in the payload that can be
|
||||||
// sent by websocket, which is 1MB, I think. We can send a large
|
// sent by websocket, which is 1MB, I think. We can send a large
|
||||||
// audio file for decoding in this approach.
|
// audio file for decoding in this approach.
|
||||||
let buf = float32_samples.buffer
|
|
||||||
let n = 1024 * 4; // send this number of bytes per request.
|
let n = 1024 * 4; // send this number of bytes per request.
|
||||||
console.log('buf length, ' + buf.byteLength);
|
console.log('buf length, ' + buf.byteLength);
|
||||||
send_header(buf.byteLength);
|
send_header(buf.byteLength);
|
||||||
for (let start = 0; start < buf.byteLength; start += n) {
|
for (let start = 0; start < buf.byteLength; start += n) {
|
||||||
socket.send(buf.slice(start, start + n));
|
socket.send(buf.slice(start, start + n));
|
||||||
}
|
}
|
||||||
};
|
}
|
||||||
|
|
||||||
reader.readAsArrayBuffer(file);
|
reader.readAsArrayBuffer(file);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -32,6 +32,8 @@ set(sources
|
|||||||
offline-recognizer.cc
|
offline-recognizer.cc
|
||||||
offline-rnn-lm.cc
|
offline-rnn-lm.cc
|
||||||
offline-stream.cc
|
offline-stream.cc
|
||||||
|
offline-tdnn-ctc-model.cc
|
||||||
|
offline-tdnn-model-config.cc
|
||||||
offline-transducer-greedy-search-decoder.cc
|
offline-transducer-greedy-search-decoder.cc
|
||||||
offline-transducer-model-config.cc
|
offline-transducer-model-config.cc
|
||||||
offline-transducer-model.cc
|
offline-transducer-model.cc
|
||||||
|
|||||||
@@ -11,12 +11,14 @@
|
|||||||
|
|
||||||
#include "sherpa-onnx/csrc/macros.h"
|
#include "sherpa-onnx/csrc/macros.h"
|
||||||
#include "sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model.h"
|
#include "sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model.h"
|
||||||
|
#include "sherpa-onnx/csrc/offline-tdnn-ctc-model.h"
|
||||||
#include "sherpa-onnx/csrc/onnx-utils.h"
|
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
enum class ModelType {
|
enum class ModelType {
|
||||||
kEncDecCTCModelBPE,
|
kEncDecCTCModelBPE,
|
||||||
|
kTdnn,
|
||||||
kUnkown,
|
kUnkown,
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -55,6 +57,8 @@ static ModelType GetModelType(char *model_data, size_t model_data_length,
|
|||||||
|
|
||||||
if (model_type.get() == std::string("EncDecCTCModelBPE")) {
|
if (model_type.get() == std::string("EncDecCTCModelBPE")) {
|
||||||
return ModelType::kEncDecCTCModelBPE;
|
return ModelType::kEncDecCTCModelBPE;
|
||||||
|
} else if (model_type.get() == std::string("tdnn")) {
|
||||||
|
return ModelType::kTdnn;
|
||||||
} else {
|
} else {
|
||||||
SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.get());
|
SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.get());
|
||||||
return ModelType::kUnkown;
|
return ModelType::kUnkown;
|
||||||
@@ -65,8 +69,18 @@ std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create(
|
|||||||
const OfflineModelConfig &config) {
|
const OfflineModelConfig &config) {
|
||||||
ModelType model_type = ModelType::kUnkown;
|
ModelType model_type = ModelType::kUnkown;
|
||||||
|
|
||||||
|
std::string filename;
|
||||||
|
if (!config.nemo_ctc.model.empty()) {
|
||||||
|
filename = config.nemo_ctc.model;
|
||||||
|
} else if (!config.tdnn.model.empty()) {
|
||||||
|
filename = config.tdnn.model;
|
||||||
|
} else {
|
||||||
|
SHERPA_ONNX_LOGE("Please specify a CTC model");
|
||||||
|
exit(-1);
|
||||||
|
}
|
||||||
|
|
||||||
{
|
{
|
||||||
auto buffer = ReadFile(config.nemo_ctc.model);
|
auto buffer = ReadFile(filename);
|
||||||
|
|
||||||
model_type = GetModelType(buffer.data(), buffer.size(), config.debug);
|
model_type = GetModelType(buffer.data(), buffer.size(), config.debug);
|
||||||
}
|
}
|
||||||
@@ -75,6 +89,9 @@ std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create(
|
|||||||
case ModelType::kEncDecCTCModelBPE:
|
case ModelType::kEncDecCTCModelBPE:
|
||||||
return std::make_unique<OfflineNemoEncDecCtcModel>(config);
|
return std::make_unique<OfflineNemoEncDecCtcModel>(config);
|
||||||
break;
|
break;
|
||||||
|
case ModelType::kTdnn:
|
||||||
|
return std::make_unique<OfflineTdnnCtcModel>(config);
|
||||||
|
break;
|
||||||
case ModelType::kUnkown:
|
case ModelType::kUnkown:
|
||||||
SHERPA_ONNX_LOGE("Unknown model type in offline CTC!");
|
SHERPA_ONNX_LOGE("Unknown model type in offline CTC!");
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
|||||||
@@ -39,10 +39,10 @@ class OfflineCtcModel {
|
|||||||
|
|
||||||
/** SubsamplingFactor of the model
|
/** SubsamplingFactor of the model
|
||||||
*
|
*
|
||||||
* For Citrinet, the subsampling factor is usually 4.
|
* For NeMo Citrinet, the subsampling factor is usually 4.
|
||||||
* For Conformer CTC, the subsampling factor is usually 8.
|
* For NeMo Conformer CTC, the subsampling factor is usually 8.
|
||||||
*/
|
*/
|
||||||
virtual int32_t SubsamplingFactor() const = 0;
|
virtual int32_t SubsamplingFactor() const { return 1; }
|
||||||
|
|
||||||
/** Return an allocator for allocating memory
|
/** Return an allocator for allocating memory
|
||||||
*/
|
*/
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ void OfflineModelConfig::Register(ParseOptions *po) {
|
|||||||
paraformer.Register(po);
|
paraformer.Register(po);
|
||||||
nemo_ctc.Register(po);
|
nemo_ctc.Register(po);
|
||||||
whisper.Register(po);
|
whisper.Register(po);
|
||||||
|
tdnn.Register(po);
|
||||||
|
|
||||||
po->Register("tokens", &tokens, "Path to tokens.txt");
|
po->Register("tokens", &tokens, "Path to tokens.txt");
|
||||||
|
|
||||||
@@ -29,7 +30,8 @@ void OfflineModelConfig::Register(ParseOptions *po) {
|
|||||||
|
|
||||||
po->Register("model-type", &model_type,
|
po->Register("model-type", &model_type,
|
||||||
"Specify it to reduce model initialization time. "
|
"Specify it to reduce model initialization time. "
|
||||||
"Valid values are: transducer, paraformer, nemo_ctc, whisper."
|
"Valid values are: transducer, paraformer, nemo_ctc, whisper, "
|
||||||
|
"tdnn."
|
||||||
"All other values lead to loading the model twice.");
|
"All other values lead to loading the model twice.");
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -56,6 +58,10 @@ bool OfflineModelConfig::Validate() const {
|
|||||||
return whisper.Validate();
|
return whisper.Validate();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (!tdnn.model.empty()) {
|
||||||
|
return tdnn.Validate();
|
||||||
|
}
|
||||||
|
|
||||||
return transducer.Validate();
|
return transducer.Validate();
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -67,6 +73,7 @@ std::string OfflineModelConfig::ToString() const {
|
|||||||
os << "paraformer=" << paraformer.ToString() << ", ";
|
os << "paraformer=" << paraformer.ToString() << ", ";
|
||||||
os << "nemo_ctc=" << nemo_ctc.ToString() << ", ";
|
os << "nemo_ctc=" << nemo_ctc.ToString() << ", ";
|
||||||
os << "whisper=" << whisper.ToString() << ", ";
|
os << "whisper=" << whisper.ToString() << ", ";
|
||||||
|
os << "tdnn=" << tdnn.ToString() << ", ";
|
||||||
os << "tokens=\"" << tokens << "\", ";
|
os << "tokens=\"" << tokens << "\", ";
|
||||||
os << "num_threads=" << num_threads << ", ";
|
os << "num_threads=" << num_threads << ", ";
|
||||||
os << "debug=" << (debug ? "True" : "False") << ", ";
|
os << "debug=" << (debug ? "True" : "False") << ", ";
|
||||||
|
|||||||
@@ -8,6 +8,7 @@
|
|||||||
|
|
||||||
#include "sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model-config.h"
|
#include "sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model-config.h"
|
||||||
#include "sherpa-onnx/csrc/offline-paraformer-model-config.h"
|
#include "sherpa-onnx/csrc/offline-paraformer-model-config.h"
|
||||||
|
#include "sherpa-onnx/csrc/offline-tdnn-model-config.h"
|
||||||
#include "sherpa-onnx/csrc/offline-transducer-model-config.h"
|
#include "sherpa-onnx/csrc/offline-transducer-model-config.h"
|
||||||
#include "sherpa-onnx/csrc/offline-whisper-model-config.h"
|
#include "sherpa-onnx/csrc/offline-whisper-model-config.h"
|
||||||
|
|
||||||
@@ -18,6 +19,7 @@ struct OfflineModelConfig {
|
|||||||
OfflineParaformerModelConfig paraformer;
|
OfflineParaformerModelConfig paraformer;
|
||||||
OfflineNemoEncDecCtcModelConfig nemo_ctc;
|
OfflineNemoEncDecCtcModelConfig nemo_ctc;
|
||||||
OfflineWhisperModelConfig whisper;
|
OfflineWhisperModelConfig whisper;
|
||||||
|
OfflineTdnnModelConfig tdnn;
|
||||||
|
|
||||||
std::string tokens;
|
std::string tokens;
|
||||||
int32_t num_threads = 2;
|
int32_t num_threads = 2;
|
||||||
@@ -40,12 +42,14 @@ struct OfflineModelConfig {
|
|||||||
const OfflineParaformerModelConfig ¶former,
|
const OfflineParaformerModelConfig ¶former,
|
||||||
const OfflineNemoEncDecCtcModelConfig &nemo_ctc,
|
const OfflineNemoEncDecCtcModelConfig &nemo_ctc,
|
||||||
const OfflineWhisperModelConfig &whisper,
|
const OfflineWhisperModelConfig &whisper,
|
||||||
|
const OfflineTdnnModelConfig &tdnn,
|
||||||
const std::string &tokens, int32_t num_threads, bool debug,
|
const std::string &tokens, int32_t num_threads, bool debug,
|
||||||
const std::string &provider, const std::string &model_type)
|
const std::string &provider, const std::string &model_type)
|
||||||
: transducer(transducer),
|
: transducer(transducer),
|
||||||
paraformer(paraformer),
|
paraformer(paraformer),
|
||||||
nemo_ctc(nemo_ctc),
|
nemo_ctc(nemo_ctc),
|
||||||
whisper(whisper),
|
whisper(whisper),
|
||||||
|
tdnn(tdnn),
|
||||||
tokens(tokens),
|
tokens(tokens),
|
||||||
num_threads(num_threads),
|
num_threads(num_threads),
|
||||||
debug(debug),
|
debug(debug),
|
||||||
|
|||||||
@@ -27,6 +27,10 @@ static OfflineRecognitionResult Convert(const OfflineCtcDecoderResult &src,
|
|||||||
std::string text;
|
std::string text;
|
||||||
|
|
||||||
for (int32_t i = 0; i != src.tokens.size(); ++i) {
|
for (int32_t i = 0; i != src.tokens.size(); ++i) {
|
||||||
|
if (sym_table.contains("SIL") && src.tokens[i] == sym_table["SIL"]) {
|
||||||
|
// tdnn models from yesno have a SIL token, we should remove it.
|
||||||
|
continue;
|
||||||
|
}
|
||||||
auto sym = sym_table[src.tokens[i]];
|
auto sym = sym_table[src.tokens[i]];
|
||||||
text.append(sym);
|
text.append(sym);
|
||||||
r.tokens.push_back(std::move(sym));
|
r.tokens.push_back(std::move(sym));
|
||||||
@@ -46,14 +50,22 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl {
|
|||||||
model_->FeatureNormalizationMethod();
|
model_->FeatureNormalizationMethod();
|
||||||
|
|
||||||
if (config.decoding_method == "greedy_search") {
|
if (config.decoding_method == "greedy_search") {
|
||||||
if (!symbol_table_.contains("<blk>")) {
|
if (!symbol_table_.contains("<blk>") &&
|
||||||
|
!symbol_table_.contains("<eps>")) {
|
||||||
SHERPA_ONNX_LOGE(
|
SHERPA_ONNX_LOGE(
|
||||||
"We expect that tokens.txt contains "
|
"We expect that tokens.txt contains "
|
||||||
"the symbol <blk> and its ID.");
|
"the symbol <blk> or <eps> and its ID.");
|
||||||
exit(-1);
|
exit(-1);
|
||||||
}
|
}
|
||||||
|
|
||||||
int32_t blank_id = symbol_table_["<blk>"];
|
int32_t blank_id = 0;
|
||||||
|
if (symbol_table_.contains("<blk>")) {
|
||||||
|
blank_id = symbol_table_["<blk>"];
|
||||||
|
} else if (symbol_table_.contains("<eps>")) {
|
||||||
|
// for tdnn models of the yesno recipe from icefall
|
||||||
|
blank_id = symbol_table_["<eps>"];
|
||||||
|
}
|
||||||
|
|
||||||
decoder_ = std::make_unique<OfflineCtcGreedySearchDecoder>(blank_id);
|
decoder_ = std::make_unique<OfflineCtcGreedySearchDecoder>(blank_id);
|
||||||
} else {
|
} else {
|
||||||
SHERPA_ONNX_LOGE("Only greedy_search is supported at present. Given %s",
|
SHERPA_ONNX_LOGE("Only greedy_search is supported at present. Given %s",
|
||||||
|
|||||||
@@ -27,6 +27,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
|
|||||||
return std::make_unique<OfflineRecognizerParaformerImpl>(config);
|
return std::make_unique<OfflineRecognizerParaformerImpl>(config);
|
||||||
} else if (model_type == "nemo_ctc") {
|
} else if (model_type == "nemo_ctc") {
|
||||||
return std::make_unique<OfflineRecognizerCtcImpl>(config);
|
return std::make_unique<OfflineRecognizerCtcImpl>(config);
|
||||||
|
} else if (model_type == "tdnn") {
|
||||||
|
return std::make_unique<OfflineRecognizerCtcImpl>(config);
|
||||||
} else if (model_type == "whisper") {
|
} else if (model_type == "whisper") {
|
||||||
return std::make_unique<OfflineRecognizerWhisperImpl>(config);
|
return std::make_unique<OfflineRecognizerWhisperImpl>(config);
|
||||||
} else {
|
} else {
|
||||||
@@ -46,6 +48,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
|
|||||||
model_filename = config.model_config.paraformer.model;
|
model_filename = config.model_config.paraformer.model;
|
||||||
} else if (!config.model_config.nemo_ctc.model.empty()) {
|
} else if (!config.model_config.nemo_ctc.model.empty()) {
|
||||||
model_filename = config.model_config.nemo_ctc.model;
|
model_filename = config.model_config.nemo_ctc.model;
|
||||||
|
} else if (!config.model_config.tdnn.model.empty()) {
|
||||||
|
model_filename = config.model_config.tdnn.model;
|
||||||
} else if (!config.model_config.whisper.encoder.empty()) {
|
} else if (!config.model_config.whisper.encoder.empty()) {
|
||||||
model_filename = config.model_config.whisper.encoder;
|
model_filename = config.model_config.whisper.encoder;
|
||||||
} else {
|
} else {
|
||||||
@@ -84,6 +88,11 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
|
|||||||
"paraformer-onnxruntime-python-example/blob/main/add-model-metadata.py"
|
"paraformer-onnxruntime-python-example/blob/main/add-model-metadata.py"
|
||||||
"\n "
|
"\n "
|
||||||
"(3) Whisper"
|
"(3) Whisper"
|
||||||
|
"\n "
|
||||||
|
"(4) Tdnn models of the yesno recipe from icefall"
|
||||||
|
"\n "
|
||||||
|
"https://github.com/k2-fsa/icefall/tree/master/egs/yesno/ASR/tdnn"
|
||||||
|
"\n"
|
||||||
"\n");
|
"\n");
|
||||||
exit(-1);
|
exit(-1);
|
||||||
}
|
}
|
||||||
@@ -102,6 +111,10 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
|
|||||||
return std::make_unique<OfflineRecognizerCtcImpl>(config);
|
return std::make_unique<OfflineRecognizerCtcImpl>(config);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (model_type == "tdnn") {
|
||||||
|
return std::make_unique<OfflineRecognizerCtcImpl>(config);
|
||||||
|
}
|
||||||
|
|
||||||
if (strncmp(model_type.c_str(), "whisper", 7) == 0) {
|
if (strncmp(model_type.c_str(), "whisper", 7) == 0) {
|
||||||
return std::make_unique<OfflineRecognizerWhisperImpl>(config);
|
return std::make_unique<OfflineRecognizerWhisperImpl>(config);
|
||||||
}
|
}
|
||||||
@@ -112,7 +125,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
|
|||||||
" - Non-streaming transducer models from icefall\n"
|
" - Non-streaming transducer models from icefall\n"
|
||||||
" - Non-streaming Paraformer models from FunASR\n"
|
" - Non-streaming Paraformer models from FunASR\n"
|
||||||
" - EncDecCTCModelBPE models from NeMo\n"
|
" - EncDecCTCModelBPE models from NeMo\n"
|
||||||
" - Whisper models\n",
|
" - Whisper models\n"
|
||||||
|
" - Tdnn models\n",
|
||||||
model_type.c_str());
|
model_type.c_str());
|
||||||
|
|
||||||
exit(-1);
|
exit(-1);
|
||||||
|
|||||||
106
sherpa-onnx/csrc/offline-tdnn-ctc-model.cc
Normal file
106
sherpa-onnx/csrc/offline-tdnn-ctc-model.cc
Normal file
@@ -0,0 +1,106 @@
|
|||||||
|
// sherpa-onnx/csrc/offline-tdnn-ctc-model.cc
|
||||||
|
//
|
||||||
|
// Copyright (c) 2023 Xiaomi Corporation
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/offline-tdnn-ctc-model.h"
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/macros.h"
|
||||||
|
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||||
|
#include "sherpa-onnx/csrc/session.h"
|
||||||
|
#include "sherpa-onnx/csrc/text-utils.h"
|
||||||
|
#include "sherpa-onnx/csrc/transpose.h"
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
class OfflineTdnnCtcModel::Impl {
|
||||||
|
public:
|
||||||
|
explicit Impl(const OfflineModelConfig &config)
|
||||||
|
: config_(config),
|
||||||
|
env_(ORT_LOGGING_LEVEL_ERROR),
|
||||||
|
sess_opts_(GetSessionOptions(config)),
|
||||||
|
allocator_{} {
|
||||||
|
Init();
|
||||||
|
}
|
||||||
|
|
||||||
|
std::pair<Ort::Value, Ort::Value> Forward(Ort::Value features) {
|
||||||
|
auto nnet_out =
|
||||||
|
sess_->Run({}, input_names_ptr_.data(), &features, 1,
|
||||||
|
output_names_ptr_.data(), output_names_ptr_.size());
|
||||||
|
|
||||||
|
std::vector<int64_t> nnet_out_shape =
|
||||||
|
nnet_out[0].GetTensorTypeAndShapeInfo().GetShape();
|
||||||
|
|
||||||
|
std::vector<int64_t> out_length_vec(nnet_out_shape[0], nnet_out_shape[1]);
|
||||||
|
std::vector<int64_t> out_length_shape(1, nnet_out_shape[0]);
|
||||||
|
|
||||||
|
auto memory_info =
|
||||||
|
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
|
||||||
|
|
||||||
|
Ort::Value nnet_out_length = Ort::Value::CreateTensor(
|
||||||
|
memory_info, out_length_vec.data(), out_length_vec.size(),
|
||||||
|
out_length_shape.data(), out_length_shape.size());
|
||||||
|
|
||||||
|
return {std::move(nnet_out[0]), Clone(Allocator(), &nnet_out_length)};
|
||||||
|
}
|
||||||
|
|
||||||
|
int32_t VocabSize() const { return vocab_size_; }
|
||||||
|
|
||||||
|
OrtAllocator *Allocator() const { return allocator_; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
void Init() {
|
||||||
|
auto buf = ReadFile(config_.tdnn.model);
|
||||||
|
|
||||||
|
sess_ = std::make_unique<Ort::Session>(env_, buf.data(), buf.size(),
|
||||||
|
sess_opts_);
|
||||||
|
|
||||||
|
GetInputNames(sess_.get(), &input_names_, &input_names_ptr_);
|
||||||
|
|
||||||
|
GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_);
|
||||||
|
|
||||||
|
// get meta data
|
||||||
|
Ort::ModelMetadata meta_data = sess_->GetModelMetadata();
|
||||||
|
if (config_.debug) {
|
||||||
|
std::ostringstream os;
|
||||||
|
PrintModelMetadata(os, meta_data);
|
||||||
|
SHERPA_ONNX_LOGE("%s\n", os.str().c_str());
|
||||||
|
}
|
||||||
|
|
||||||
|
Ort::AllocatorWithDefaultOptions allocator; // used in the macro below
|
||||||
|
SHERPA_ONNX_READ_META_DATA(vocab_size_, "vocab_size");
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
OfflineModelConfig config_;
|
||||||
|
Ort::Env env_;
|
||||||
|
Ort::SessionOptions sess_opts_;
|
||||||
|
Ort::AllocatorWithDefaultOptions allocator_;
|
||||||
|
|
||||||
|
std::unique_ptr<Ort::Session> sess_;
|
||||||
|
|
||||||
|
std::vector<std::string> input_names_;
|
||||||
|
std::vector<const char *> input_names_ptr_;
|
||||||
|
|
||||||
|
std::vector<std::string> output_names_;
|
||||||
|
std::vector<const char *> output_names_ptr_;
|
||||||
|
|
||||||
|
int32_t vocab_size_ = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
OfflineTdnnCtcModel::OfflineTdnnCtcModel(const OfflineModelConfig &config)
|
||||||
|
: impl_(std::make_unique<Impl>(config)) {}
|
||||||
|
|
||||||
|
OfflineTdnnCtcModel::~OfflineTdnnCtcModel() = default;
|
||||||
|
|
||||||
|
std::pair<Ort::Value, Ort::Value> OfflineTdnnCtcModel::Forward(
|
||||||
|
Ort::Value features, Ort::Value /*features_length*/) {
|
||||||
|
return impl_->Forward(std::move(features));
|
||||||
|
}
|
||||||
|
|
||||||
|
int32_t OfflineTdnnCtcModel::VocabSize() const { return impl_->VocabSize(); }
|
||||||
|
|
||||||
|
OrtAllocator *OfflineTdnnCtcModel::Allocator() const {
|
||||||
|
return impl_->Allocator();
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace sherpa_onnx
|
||||||
56
sherpa-onnx/csrc/offline-tdnn-ctc-model.h
Normal file
56
sherpa-onnx/csrc/offline-tdnn-ctc-model.h
Normal file
@@ -0,0 +1,56 @@
|
|||||||
|
// sherpa-onnx/csrc/offline-tdnn-ctc-model.h
|
||||||
|
//
|
||||||
|
// Copyright (c) 2023 Xiaomi Corporation
|
||||||
|
#ifndef SHERPA_ONNX_CSRC_OFFLINE_TDNN_CTC_MODEL_H_
|
||||||
|
#define SHERPA_ONNX_CSRC_OFFLINE_TDNN_CTC_MODEL_H_
|
||||||
|
#include <memory>
|
||||||
|
#include <string>
|
||||||
|
#include <utility>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||||
|
#include "sherpa-onnx/csrc/offline-ctc-model.h"
|
||||||
|
#include "sherpa-onnx/csrc/offline-model-config.h"
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
/** This class implements the tdnn model of the yesno recipe from icefall.
|
||||||
|
*
|
||||||
|
* See
|
||||||
|
* https://github.com/k2-fsa/icefall/tree/master/egs/yesno/ASR/tdnn
|
||||||
|
*/
|
||||||
|
class OfflineTdnnCtcModel : public OfflineCtcModel {
|
||||||
|
public:
|
||||||
|
explicit OfflineTdnnCtcModel(const OfflineModelConfig &config);
|
||||||
|
~OfflineTdnnCtcModel() override;
|
||||||
|
|
||||||
|
/** Run the forward method of the model.
|
||||||
|
*
|
||||||
|
* @param features A tensor of shape (N, T, C). It is changed in-place.
|
||||||
|
* @param features_length A 1-D tensor of shape (N,) containing number of
|
||||||
|
* valid frames in `features` before padding.
|
||||||
|
* Its dtype is int64_t.
|
||||||
|
*
|
||||||
|
* @return Return a pair containing:
|
||||||
|
* - log_probs: A 3-D tensor of shape (N, T', vocab_size).
|
||||||
|
* - log_probs_length A 1-D tensor of shape (N,). Its dtype is int64_t
|
||||||
|
*/
|
||||||
|
std::pair<Ort::Value, Ort::Value> Forward(
|
||||||
|
Ort::Value features, Ort::Value /*features_length*/) override;
|
||||||
|
|
||||||
|
/** Return the vocabulary size of the model
|
||||||
|
*/
|
||||||
|
int32_t VocabSize() const override;
|
||||||
|
|
||||||
|
/** Return an allocator for allocating memory
|
||||||
|
*/
|
||||||
|
OrtAllocator *Allocator() const override;
|
||||||
|
|
||||||
|
private:
|
||||||
|
class Impl;
|
||||||
|
std::unique_ptr<Impl> impl_;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace sherpa_onnx
|
||||||
|
|
||||||
|
#endif // SHERPA_ONNX_CSRC_OFFLINE_TDNN_CTC_MODEL_H_
|
||||||
34
sherpa-onnx/csrc/offline-tdnn-model-config.cc
Normal file
34
sherpa-onnx/csrc/offline-tdnn-model-config.cc
Normal file
@@ -0,0 +1,34 @@
|
|||||||
|
// sherpa-onnx/csrc/offline-tdnn-model-config.cc
|
||||||
|
//
|
||||||
|
// Copyright (c) 2023 Xiaomi Corporation
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/offline-tdnn-model-config.h"
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/file-utils.h"
|
||||||
|
#include "sherpa-onnx/csrc/macros.h"
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
void OfflineTdnnModelConfig::Register(ParseOptions *po) {
|
||||||
|
po->Register("tdnn-model", &model, "Path to onnx model");
|
||||||
|
}
|
||||||
|
|
||||||
|
bool OfflineTdnnModelConfig::Validate() const {
|
||||||
|
if (!FileExists(model)) {
|
||||||
|
SHERPA_ONNX_LOGE("tdnn model file %s does not exist", model.c_str());
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string OfflineTdnnModelConfig::ToString() const {
|
||||||
|
std::ostringstream os;
|
||||||
|
|
||||||
|
os << "OfflineTdnnModelConfig(";
|
||||||
|
os << "model=\"" << model << "\")";
|
||||||
|
|
||||||
|
return os.str();
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace sherpa_onnx
|
||||||
28
sherpa-onnx/csrc/offline-tdnn-model-config.h
Normal file
28
sherpa-onnx/csrc/offline-tdnn-model-config.h
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
// sherpa-onnx/csrc/offline-tdnn-model-config.h
|
||||||
|
//
|
||||||
|
// Copyright (c) 2023 Xiaomi Corporation
|
||||||
|
#ifndef SHERPA_ONNX_CSRC_OFFLINE_TDNN_MODEL_CONFIG_H_
|
||||||
|
#define SHERPA_ONNX_CSRC_OFFLINE_TDNN_MODEL_CONFIG_H_
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/parse-options.h"
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
// for https://github.com/k2-fsa/icefall/tree/master/egs/yesno/ASR/tdnn
|
||||||
|
struct OfflineTdnnModelConfig {
|
||||||
|
std::string model;
|
||||||
|
|
||||||
|
OfflineTdnnModelConfig() = default;
|
||||||
|
explicit OfflineTdnnModelConfig(const std::string &model) : model(model) {}
|
||||||
|
|
||||||
|
void Register(ParseOptions *po);
|
||||||
|
bool Validate() const;
|
||||||
|
|
||||||
|
std::string ToString() const;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace sherpa_onnx
|
||||||
|
|
||||||
|
#endif // SHERPA_ONNX_CSRC_OFFLINE_TDNN_MODEL_CONFIG_H_
|
||||||
@@ -14,10 +14,14 @@
|
|||||||
|
|
||||||
int main(int32_t argc, char *argv[]) {
|
int main(int32_t argc, char *argv[]) {
|
||||||
const char *kUsageMessage = R"usage(
|
const char *kUsageMessage = R"usage(
|
||||||
|
Speech recognition using non-streaming models with sherpa-onnx.
|
||||||
|
|
||||||
Usage:
|
Usage:
|
||||||
|
|
||||||
(1) Transducer from icefall
|
(1) Transducer from icefall
|
||||||
|
|
||||||
|
See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-transducer/index.html
|
||||||
|
|
||||||
./bin/sherpa-onnx-offline \
|
./bin/sherpa-onnx-offline \
|
||||||
--tokens=/path/to/tokens.txt \
|
--tokens=/path/to/tokens.txt \
|
||||||
--encoder=/path/to/encoder.onnx \
|
--encoder=/path/to/encoder.onnx \
|
||||||
@@ -30,6 +34,8 @@ Usage:
|
|||||||
|
|
||||||
(2) Paraformer from FunASR
|
(2) Paraformer from FunASR
|
||||||
|
|
||||||
|
See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-paraformer/index.html
|
||||||
|
|
||||||
./bin/sherpa-onnx-offline \
|
./bin/sherpa-onnx-offline \
|
||||||
--tokens=/path/to/tokens.txt \
|
--tokens=/path/to/tokens.txt \
|
||||||
--paraformer=/path/to/model.onnx \
|
--paraformer=/path/to/model.onnx \
|
||||||
@@ -39,6 +45,8 @@ Usage:
|
|||||||
|
|
||||||
(3) Whisper models
|
(3) Whisper models
|
||||||
|
|
||||||
|
See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/whisper/tiny.en.html
|
||||||
|
|
||||||
./bin/sherpa-onnx-offline \
|
./bin/sherpa-onnx-offline \
|
||||||
--whisper-encoder=./sherpa-onnx-whisper-base.en/base.en-encoder.int8.onnx \
|
--whisper-encoder=./sherpa-onnx-whisper-base.en/base.en-encoder.int8.onnx \
|
||||||
--whisper-decoder=./sherpa-onnx-whisper-base.en/base.en-decoder.int8.onnx \
|
--whisper-decoder=./sherpa-onnx-whisper-base.en/base.en-decoder.int8.onnx \
|
||||||
@@ -46,6 +54,31 @@ Usage:
|
|||||||
--num-threads=1 \
|
--num-threads=1 \
|
||||||
/path/to/foo.wav [bar.wav foobar.wav ...]
|
/path/to/foo.wav [bar.wav foobar.wav ...]
|
||||||
|
|
||||||
|
(4) NeMo CTC models
|
||||||
|
|
||||||
|
See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-ctc/index.html
|
||||||
|
|
||||||
|
./bin/sherpa-onnx-offline \
|
||||||
|
--tokens=./sherpa-onnx-nemo-ctc-en-conformer-medium/tokens.txt \
|
||||||
|
--nemo-ctc-model=./sherpa-onnx-nemo-ctc-en-conformer-medium/model.onnx \
|
||||||
|
--num-threads=2 \
|
||||||
|
--decoding-method=greedy_search \
|
||||||
|
--debug=false \
|
||||||
|
./sherpa-onnx-nemo-ctc-en-conformer-medium/test_wavs/0.wav \
|
||||||
|
./sherpa-onnx-nemo-ctc-en-conformer-medium/test_wavs/1.wav \
|
||||||
|
./sherpa-onnx-nemo-ctc-en-conformer-medium/test_wavs/8k.wav
|
||||||
|
|
||||||
|
(5) TDNN CTC model for the yesno recipe from icefall
|
||||||
|
|
||||||
|
See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-ctc/yesno/index.html
|
||||||
|
//
|
||||||
|
./build/bin/sherpa-onnx-offline \
|
||||||
|
--sample-rate=8000 \
|
||||||
|
--feat-dim=23 \
|
||||||
|
--tokens=./sherpa-onnx-tdnn-yesno/tokens.txt \
|
||||||
|
--tdnn-model=./sherpa-onnx-tdnn-yesno/model-epoch-14-avg-2.onnx \
|
||||||
|
./sherpa-onnx-tdnn-yesno/test_wavs/0_0_0_1_0_0_0_1.wav \
|
||||||
|
./sherpa-onnx-tdnn-yesno/test_wavs/0_0_1_0_0_0_1_0.wav
|
||||||
|
|
||||||
Note: It supports decoding multiple files in batches
|
Note: It supports decoding multiple files in batches
|
||||||
|
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ pybind11_add_module(_sherpa_onnx
|
|||||||
offline-paraformer-model-config.cc
|
offline-paraformer-model-config.cc
|
||||||
offline-recognizer.cc
|
offline-recognizer.cc
|
||||||
offline-stream.cc
|
offline-stream.cc
|
||||||
|
offline-tdnn-model-config.cc
|
||||||
offline-transducer-model-config.cc
|
offline-transducer-model-config.cc
|
||||||
offline-whisper-model-config.cc
|
offline-whisper-model-config.cc
|
||||||
online-lm-config.cc
|
online-lm-config.cc
|
||||||
|
|||||||
@@ -10,6 +10,7 @@
|
|||||||
#include "sherpa-onnx/csrc/offline-model-config.h"
|
#include "sherpa-onnx/csrc/offline-model-config.h"
|
||||||
#include "sherpa-onnx/python/csrc/offline-nemo-enc-dec-ctc-model-config.h"
|
#include "sherpa-onnx/python/csrc/offline-nemo-enc-dec-ctc-model-config.h"
|
||||||
#include "sherpa-onnx/python/csrc/offline-paraformer-model-config.h"
|
#include "sherpa-onnx/python/csrc/offline-paraformer-model-config.h"
|
||||||
|
#include "sherpa-onnx/python/csrc/offline-tdnn-model-config.h"
|
||||||
#include "sherpa-onnx/python/csrc/offline-transducer-model-config.h"
|
#include "sherpa-onnx/python/csrc/offline-transducer-model-config.h"
|
||||||
#include "sherpa-onnx/python/csrc/offline-whisper-model-config.h"
|
#include "sherpa-onnx/python/csrc/offline-whisper-model-config.h"
|
||||||
|
|
||||||
@@ -20,24 +21,28 @@ void PybindOfflineModelConfig(py::module *m) {
|
|||||||
PybindOfflineParaformerModelConfig(m);
|
PybindOfflineParaformerModelConfig(m);
|
||||||
PybindOfflineNemoEncDecCtcModelConfig(m);
|
PybindOfflineNemoEncDecCtcModelConfig(m);
|
||||||
PybindOfflineWhisperModelConfig(m);
|
PybindOfflineWhisperModelConfig(m);
|
||||||
|
PybindOfflineTdnnModelConfig(m);
|
||||||
|
|
||||||
using PyClass = OfflineModelConfig;
|
using PyClass = OfflineModelConfig;
|
||||||
py::class_<PyClass>(*m, "OfflineModelConfig")
|
py::class_<PyClass>(*m, "OfflineModelConfig")
|
||||||
.def(py::init<const OfflineTransducerModelConfig &,
|
.def(py::init<const OfflineTransducerModelConfig &,
|
||||||
const OfflineParaformerModelConfig &,
|
const OfflineParaformerModelConfig &,
|
||||||
const OfflineNemoEncDecCtcModelConfig &,
|
const OfflineNemoEncDecCtcModelConfig &,
|
||||||
const OfflineWhisperModelConfig &, const std::string &,
|
const OfflineWhisperModelConfig &,
|
||||||
|
const OfflineTdnnModelConfig &, const std::string &,
|
||||||
int32_t, bool, const std::string &, const std::string &>(),
|
int32_t, bool, const std::string &, const std::string &>(),
|
||||||
py::arg("transducer") = OfflineTransducerModelConfig(),
|
py::arg("transducer") = OfflineTransducerModelConfig(),
|
||||||
py::arg("paraformer") = OfflineParaformerModelConfig(),
|
py::arg("paraformer") = OfflineParaformerModelConfig(),
|
||||||
py::arg("nemo_ctc") = OfflineNemoEncDecCtcModelConfig(),
|
py::arg("nemo_ctc") = OfflineNemoEncDecCtcModelConfig(),
|
||||||
py::arg("whisper") = OfflineWhisperModelConfig(), py::arg("tokens"),
|
py::arg("whisper") = OfflineWhisperModelConfig(),
|
||||||
|
py::arg("tdnn") = OfflineTdnnModelConfig(), py::arg("tokens"),
|
||||||
py::arg("num_threads"), py::arg("debug") = false,
|
py::arg("num_threads"), py::arg("debug") = false,
|
||||||
py::arg("provider") = "cpu", py::arg("model_type") = "")
|
py::arg("provider") = "cpu", py::arg("model_type") = "")
|
||||||
.def_readwrite("transducer", &PyClass::transducer)
|
.def_readwrite("transducer", &PyClass::transducer)
|
||||||
.def_readwrite("paraformer", &PyClass::paraformer)
|
.def_readwrite("paraformer", &PyClass::paraformer)
|
||||||
.def_readwrite("nemo_ctc", &PyClass::nemo_ctc)
|
.def_readwrite("nemo_ctc", &PyClass::nemo_ctc)
|
||||||
.def_readwrite("whisper", &PyClass::whisper)
|
.def_readwrite("whisper", &PyClass::whisper)
|
||||||
|
.def_readwrite("tdnn", &PyClass::tdnn)
|
||||||
.def_readwrite("tokens", &PyClass::tokens)
|
.def_readwrite("tokens", &PyClass::tokens)
|
||||||
.def_readwrite("num_threads", &PyClass::num_threads)
|
.def_readwrite("num_threads", &PyClass::num_threads)
|
||||||
.def_readwrite("debug", &PyClass::debug)
|
.def_readwrite("debug", &PyClass::debug)
|
||||||
|
|||||||
22
sherpa-onnx/python/csrc/offline-tdnn-model-config.cc
Normal file
22
sherpa-onnx/python/csrc/offline-tdnn-model-config.cc
Normal file
@@ -0,0 +1,22 @@
|
|||||||
|
// sherpa-onnx/python/csrc/offline-tdnn-model-config.cc
|
||||||
|
//
|
||||||
|
// Copyright (c) 2023 Xiaomi Corporation
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/offline-tdnn-model-config.h"
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "sherpa-onnx/python/csrc/offline-tdnn-model-config.h"
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
void PybindOfflineTdnnModelConfig(py::module *m) {
|
||||||
|
using PyClass = OfflineTdnnModelConfig;
|
||||||
|
py::class_<PyClass>(*m, "OfflineTdnnModelConfig")
|
||||||
|
.def(py::init<const std::string &>(), py::arg("model"))
|
||||||
|
.def_readwrite("model", &PyClass::model)
|
||||||
|
.def("__str__", &PyClass::ToString);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace sherpa_onnx
|
||||||
16
sherpa-onnx/python/csrc/offline-tdnn-model-config.h
Normal file
16
sherpa-onnx/python/csrc/offline-tdnn-model-config.h
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
// sherpa-onnx/python/csrc/offline-tdnn-model-config.h
|
||||||
|
//
|
||||||
|
// Copyright (c) 2023 Xiaomi Corporation
|
||||||
|
|
||||||
|
#ifndef SHERPA_ONNX_PYTHON_CSRC_OFFLINE_TDNN_MODEL_CONFIG_H_
|
||||||
|
#define SHERPA_ONNX_PYTHON_CSRC_OFFLINE_TDNN_MODEL_CONFIG_H_
|
||||||
|
|
||||||
|
#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
void PybindOfflineTdnnModelConfig(py::module *m);
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif // SHERPA_ONNX_PYTHON_CSRC_OFFLINE_TDNN_MODEL_CONFIG_H_
|
||||||
@@ -8,6 +8,7 @@ from _sherpa_onnx import (
|
|||||||
OfflineModelConfig,
|
OfflineModelConfig,
|
||||||
OfflineNemoEncDecCtcModelConfig,
|
OfflineNemoEncDecCtcModelConfig,
|
||||||
OfflineParaformerModelConfig,
|
OfflineParaformerModelConfig,
|
||||||
|
OfflineTdnnModelConfig,
|
||||||
OfflineWhisperModelConfig,
|
OfflineWhisperModelConfig,
|
||||||
)
|
)
|
||||||
from _sherpa_onnx import OfflineRecognizer as _Recognizer
|
from _sherpa_onnx import OfflineRecognizer as _Recognizer
|
||||||
@@ -37,7 +38,7 @@ class OfflineRecognizer(object):
|
|||||||
decoder: str,
|
decoder: str,
|
||||||
joiner: str,
|
joiner: str,
|
||||||
tokens: str,
|
tokens: str,
|
||||||
num_threads: int,
|
num_threads: int = 1,
|
||||||
sample_rate: int = 16000,
|
sample_rate: int = 16000,
|
||||||
feature_dim: int = 80,
|
feature_dim: int = 80,
|
||||||
decoding_method: str = "greedy_search",
|
decoding_method: str = "greedy_search",
|
||||||
@@ -48,7 +49,7 @@ class OfflineRecognizer(object):
|
|||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Please refer to
|
Please refer to
|
||||||
`<https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html>`_
|
`<https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-transducer/index.html>`_
|
||||||
to download pre-trained models for different languages, e.g., Chinese,
|
to download pre-trained models for different languages, e.g., Chinese,
|
||||||
English, etc.
|
English, etc.
|
||||||
|
|
||||||
@@ -115,7 +116,7 @@ class OfflineRecognizer(object):
|
|||||||
cls,
|
cls,
|
||||||
paraformer: str,
|
paraformer: str,
|
||||||
tokens: str,
|
tokens: str,
|
||||||
num_threads: int,
|
num_threads: int = 1,
|
||||||
sample_rate: int = 16000,
|
sample_rate: int = 16000,
|
||||||
feature_dim: int = 80,
|
feature_dim: int = 80,
|
||||||
decoding_method: str = "greedy_search",
|
decoding_method: str = "greedy_search",
|
||||||
@@ -124,9 +125,8 @@ class OfflineRecognizer(object):
|
|||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Please refer to
|
Please refer to
|
||||||
`<https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html>`_
|
`<https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-paraformer/index.html>`_
|
||||||
to download pre-trained models for different languages, e.g., Chinese,
|
to download pre-trained models.
|
||||||
English, etc.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
tokens:
|
tokens:
|
||||||
@@ -179,7 +179,7 @@ class OfflineRecognizer(object):
|
|||||||
cls,
|
cls,
|
||||||
model: str,
|
model: str,
|
||||||
tokens: str,
|
tokens: str,
|
||||||
num_threads: int,
|
num_threads: int = 1,
|
||||||
sample_rate: int = 16000,
|
sample_rate: int = 16000,
|
||||||
feature_dim: int = 80,
|
feature_dim: int = 80,
|
||||||
decoding_method: str = "greedy_search",
|
decoding_method: str = "greedy_search",
|
||||||
@@ -188,7 +188,7 @@ class OfflineRecognizer(object):
|
|||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Please refer to
|
Please refer to
|
||||||
`<https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html>`_
|
`<https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-ctc/nemo/index.html>`_
|
||||||
to download pre-trained models for different languages, e.g., Chinese,
|
to download pre-trained models for different languages, e.g., Chinese,
|
||||||
English, etc.
|
English, etc.
|
||||||
|
|
||||||
@@ -244,14 +244,14 @@ class OfflineRecognizer(object):
|
|||||||
encoder: str,
|
encoder: str,
|
||||||
decoder: str,
|
decoder: str,
|
||||||
tokens: str,
|
tokens: str,
|
||||||
num_threads: int,
|
num_threads: int = 1,
|
||||||
decoding_method: str = "greedy_search",
|
decoding_method: str = "greedy_search",
|
||||||
debug: bool = False,
|
debug: bool = False,
|
||||||
provider: str = "cpu",
|
provider: str = "cpu",
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Please refer to
|
Please refer to
|
||||||
`<https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html>`_
|
`<https://k2-fsa.github.io/sherpa/onnx/pretrained_models/whisper/index.html>`_
|
||||||
to download pre-trained models for different kinds of whisper models,
|
to download pre-trained models for different kinds of whisper models,
|
||||||
e.g., tiny, tiny.en, base, base.en, etc.
|
e.g., tiny, tiny.en, base, base.en, etc.
|
||||||
|
|
||||||
@@ -301,6 +301,69 @@ class OfflineRecognizer(object):
|
|||||||
self.config = recognizer_config
|
self.config = recognizer_config
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_tdnn_ctc(
|
||||||
|
cls,
|
||||||
|
model: str,
|
||||||
|
tokens: str,
|
||||||
|
num_threads: int = 1,
|
||||||
|
sample_rate: int = 8000,
|
||||||
|
feature_dim: int = 23,
|
||||||
|
decoding_method: str = "greedy_search",
|
||||||
|
debug: bool = False,
|
||||||
|
provider: str = "cpu",
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Please refer to
|
||||||
|
`<https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-ctc/yesno/index.html>`_
|
||||||
|
to download pre-trained models.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model:
|
||||||
|
Path to ``model.onnx``.
|
||||||
|
tokens:
|
||||||
|
Path to ``tokens.txt``. Each line in ``tokens.txt`` contains two
|
||||||
|
columns::
|
||||||
|
|
||||||
|
symbol integer_id
|
||||||
|
|
||||||
|
num_threads:
|
||||||
|
Number of threads for neural network computation.
|
||||||
|
sample_rate:
|
||||||
|
Sample rate of the training data used to train the model.
|
||||||
|
feature_dim:
|
||||||
|
Dimension of the feature used to train the model.
|
||||||
|
decoding_method:
|
||||||
|
Valid values are greedy_search.
|
||||||
|
debug:
|
||||||
|
True to show debug messages.
|
||||||
|
provider:
|
||||||
|
onnxruntime execution providers. Valid values are: cpu, cuda, coreml.
|
||||||
|
"""
|
||||||
|
self = cls.__new__(cls)
|
||||||
|
model_config = OfflineModelConfig(
|
||||||
|
tdnn=OfflineTdnnModelConfig(model=model),
|
||||||
|
tokens=tokens,
|
||||||
|
num_threads=num_threads,
|
||||||
|
debug=debug,
|
||||||
|
provider=provider,
|
||||||
|
model_type="tdnn",
|
||||||
|
)
|
||||||
|
|
||||||
|
feat_config = OfflineFeatureExtractorConfig(
|
||||||
|
sampling_rate=sample_rate,
|
||||||
|
feature_dim=feature_dim,
|
||||||
|
)
|
||||||
|
|
||||||
|
recognizer_config = OfflineRecognizerConfig(
|
||||||
|
feat_config=feat_config,
|
||||||
|
model_config=model_config,
|
||||||
|
decoding_method=decoding_method,
|
||||||
|
)
|
||||||
|
self.recognizer = _Recognizer(recognizer_config)
|
||||||
|
self.config = recognizer_config
|
||||||
|
return self
|
||||||
|
|
||||||
def create_stream(self, contexts_list: Optional[List[List[int]]] = None):
|
def create_stream(self, contexts_list: Optional[List[List[int]]] = None):
|
||||||
if contexts_list is None:
|
if contexts_list is None:
|
||||||
return self.recognizer.create_stream()
|
return self.recognizer.create_stream()
|
||||||
|
|||||||
Reference in New Issue
Block a user