Add C++ runtime for non-streaming faster conformer transducer from NeMo. (#854)
This commit is contained in:
99
.github/scripts/test-offline-transducer.sh
vendored
99
.github/scripts/test-offline-transducer.sh
vendored
@@ -13,6 +13,105 @@ echo "PATH: $PATH"
|
|||||||
|
|
||||||
which $EXE
|
which $EXE
|
||||||
|
|
||||||
|
log "------------------------------------------------------------------------"
|
||||||
|
log "Run Nemo fast conformer hybrid transducer ctc models (transducer branch)"
|
||||||
|
log "------------------------------------------------------------------------"
|
||||||
|
|
||||||
|
url=https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-nemo-fast-conformer-transducer-be-de-en-es-fr-hr-it-pl-ru-uk-20k.tar.bz2
|
||||||
|
name=$(basename $url)
|
||||||
|
curl -SL -O $url
|
||||||
|
tar xvf $name
|
||||||
|
rm $name
|
||||||
|
repo=$(basename -s .tar.bz2 $name)
|
||||||
|
ls -lh $repo
|
||||||
|
|
||||||
|
log "test $repo"
|
||||||
|
test_wavs=(
|
||||||
|
de-german.wav
|
||||||
|
es-spanish.wav
|
||||||
|
hr-croatian.wav
|
||||||
|
po-polish.wav
|
||||||
|
uk-ukrainian.wav
|
||||||
|
en-english.wav
|
||||||
|
fr-french.wav
|
||||||
|
it-italian.wav
|
||||||
|
ru-russian.wav
|
||||||
|
)
|
||||||
|
for w in ${test_wavs[@]}; do
|
||||||
|
time $EXE \
|
||||||
|
--tokens=$repo/tokens.txt \
|
||||||
|
--encoder=$repo/encoder.onnx \
|
||||||
|
--decoder=$repo/decoder.onnx \
|
||||||
|
--joiner=$repo/joiner.onnx \
|
||||||
|
--debug=1 \
|
||||||
|
$repo/test_wavs/$w
|
||||||
|
done
|
||||||
|
|
||||||
|
rm -rf $repo
|
||||||
|
|
||||||
|
url=https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-nemo-fast-conformer-transducer-en-24500.tar.bz2
|
||||||
|
name=$(basename $url)
|
||||||
|
curl -SL -O $url
|
||||||
|
tar xvf $name
|
||||||
|
rm $name
|
||||||
|
repo=$(basename -s .tar.bz2 $name)
|
||||||
|
ls -lh $repo
|
||||||
|
|
||||||
|
log "Test $repo"
|
||||||
|
|
||||||
|
time $EXE \
|
||||||
|
--tokens=$repo/tokens.txt \
|
||||||
|
--encoder=$repo/encoder.onnx \
|
||||||
|
--decoder=$repo/decoder.onnx \
|
||||||
|
--joiner=$repo/joiner.onnx \
|
||||||
|
--debug=1 \
|
||||||
|
$repo/test_wavs/en-english.wav
|
||||||
|
|
||||||
|
rm -rf $repo
|
||||||
|
|
||||||
|
url=https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-nemo-fast-conformer-transducer-es-1424.tar.bz2
|
||||||
|
name=$(basename $url)
|
||||||
|
curl -SL -O $url
|
||||||
|
tar xvf $name
|
||||||
|
rm $name
|
||||||
|
repo=$(basename -s .tar.bz2 $name)
|
||||||
|
ls -lh $repo
|
||||||
|
|
||||||
|
log "test $repo"
|
||||||
|
|
||||||
|
time $EXE \
|
||||||
|
--tokens=$repo/tokens.txt \
|
||||||
|
--encoder=$repo/encoder.onnx \
|
||||||
|
--decoder=$repo/decoder.onnx \
|
||||||
|
--joiner=$repo/joiner.onnx \
|
||||||
|
--debug=1 \
|
||||||
|
$repo/test_wavs/es-spanish.wav
|
||||||
|
|
||||||
|
rm -rf $repo
|
||||||
|
|
||||||
|
url=https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-nemo-fast-conformer-transducer-en-de-es-fr-14288.tar.bz2
|
||||||
|
name=$(basename $url)
|
||||||
|
curl -SL -O $url
|
||||||
|
tar xvf $name
|
||||||
|
rm $name
|
||||||
|
repo=$(basename -s .tar.bz2 $name)
|
||||||
|
ls -lh $repo
|
||||||
|
|
||||||
|
log "Test $repo"
|
||||||
|
|
||||||
|
time $EXE \
|
||||||
|
--tokens=$repo/tokens.txt \
|
||||||
|
--encoder=$repo/encoder.onnx \
|
||||||
|
--decoder=$repo/decoder.onnx \
|
||||||
|
--joiner=$repo/joiner.onnx \
|
||||||
|
--debug=1 \
|
||||||
|
$repo/test_wavs/en-english.wav \
|
||||||
|
$repo/test_wavs/de-german.wav \
|
||||||
|
$repo/test_wavs/fr-french.wav \
|
||||||
|
$repo/test_wavs/es-spanish.wav
|
||||||
|
|
||||||
|
rm -rf $repo
|
||||||
|
|
||||||
log "------------------------------------------------------------"
|
log "------------------------------------------------------------"
|
||||||
log "Run Conformer transducer (English)"
|
log "Run Conformer transducer (English)"
|
||||||
log "------------------------------------------------------------"
|
log "------------------------------------------------------------"
|
||||||
|
|||||||
16
.github/workflows/linux.yaml
vendored
16
.github/workflows/linux.yaml
vendored
@@ -128,6 +128,14 @@ jobs:
|
|||||||
name: release-${{ matrix.build_type }}-with-shared-lib-${{ matrix.shared_lib }}-with-tts-${{ matrix.with_tts }}
|
name: release-${{ matrix.build_type }}-with-shared-lib-${{ matrix.shared_lib }}-with-tts-${{ matrix.with_tts }}
|
||||||
path: install/*
|
path: install/*
|
||||||
|
|
||||||
|
- name: Test offline transducer
|
||||||
|
shell: bash
|
||||||
|
run: |
|
||||||
|
export PATH=$PWD/build/bin:$PATH
|
||||||
|
export EXE=sherpa-onnx-offline
|
||||||
|
|
||||||
|
.github/scripts/test-offline-transducer.sh
|
||||||
|
|
||||||
- name: Test spoken language identification (C++ API)
|
- name: Test spoken language identification (C++ API)
|
||||||
shell: bash
|
shell: bash
|
||||||
run: |
|
run: |
|
||||||
@@ -215,14 +223,6 @@ jobs:
|
|||||||
|
|
||||||
.github/scripts/test-online-paraformer.sh
|
.github/scripts/test-online-paraformer.sh
|
||||||
|
|
||||||
- name: Test offline transducer
|
|
||||||
shell: bash
|
|
||||||
run: |
|
|
||||||
export PATH=$PWD/build/bin:$PATH
|
|
||||||
export EXE=sherpa-onnx-offline
|
|
||||||
|
|
||||||
.github/scripts/test-offline-transducer.sh
|
|
||||||
|
|
||||||
- name: Test online transducer
|
- name: Test online transducer
|
||||||
shell: bash
|
shell: bash
|
||||||
run: |
|
run: |
|
||||||
|
|||||||
16
.github/workflows/macos.yaml
vendored
16
.github/workflows/macos.yaml
vendored
@@ -107,6 +107,14 @@ jobs:
|
|||||||
otool -L build/bin/sherpa-onnx
|
otool -L build/bin/sherpa-onnx
|
||||||
otool -l build/bin/sherpa-onnx
|
otool -l build/bin/sherpa-onnx
|
||||||
|
|
||||||
|
- name: Test offline transducer
|
||||||
|
shell: bash
|
||||||
|
run: |
|
||||||
|
export PATH=$PWD/build/bin:$PATH
|
||||||
|
export EXE=sherpa-onnx-offline
|
||||||
|
|
||||||
|
.github/scripts/test-offline-transducer.sh
|
||||||
|
|
||||||
- name: Test online CTC
|
- name: Test online CTC
|
||||||
shell: bash
|
shell: bash
|
||||||
run: |
|
run: |
|
||||||
@@ -192,14 +200,6 @@ jobs:
|
|||||||
|
|
||||||
.github/scripts/test-offline-ctc.sh
|
.github/scripts/test-offline-ctc.sh
|
||||||
|
|
||||||
- name: Test offline transducer
|
|
||||||
shell: bash
|
|
||||||
run: |
|
|
||||||
export PATH=$PWD/build/bin:$PATH
|
|
||||||
export EXE=sherpa-onnx-offline
|
|
||||||
|
|
||||||
.github/scripts/test-offline-transducer.sh
|
|
||||||
|
|
||||||
- name: Test online transducer
|
- name: Test online transducer
|
||||||
shell: bash
|
shell: bash
|
||||||
run: |
|
run: |
|
||||||
|
|||||||
1
.gitignore
vendored
1
.gitignore
vendored
@@ -104,3 +104,4 @@ sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01
|
|||||||
sherpa-onnx-ced-*
|
sherpa-onnx-ced-*
|
||||||
node_modules
|
node_modules
|
||||||
package-lock.json
|
package-lock.json
|
||||||
|
sherpa-onnx-nemo-*
|
||||||
|
|||||||
68
python-api-examples/offline-nemo-ctc-decode-files.py
Executable file
68
python-api-examples/offline-nemo-ctc-decode-files.py
Executable file
@@ -0,0 +1,68 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
|
||||||
|
"""
|
||||||
|
This file shows how to use a non-streaming CTC model from NeMo
|
||||||
|
to decode files.
|
||||||
|
|
||||||
|
Please download model files from
|
||||||
|
https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models
|
||||||
|
|
||||||
|
|
||||||
|
The example model supports 10 languages and it is converted from
|
||||||
|
https://catalog.ngc.nvidia.com/orgs/nvidia/teams/nemo/models/stt_multilingual_fastconformer_hybrid_large_pc
|
||||||
|
"""
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import sherpa_onnx
|
||||||
|
import soundfile as sf
|
||||||
|
|
||||||
|
|
||||||
|
def create_recognizer():
|
||||||
|
model = "./sherpa-onnx-nemo-fast-conformer-ctc-be-de-en-es-fr-hr-it-pl-ru-uk-20k/model.onnx"
|
||||||
|
tokens = "./sherpa-onnx-nemo-fast-conformer-ctc-be-de-en-es-fr-hr-it-pl-ru-uk-20k/tokens.txt"
|
||||||
|
|
||||||
|
test_wav = "./sherpa-onnx-nemo-fast-conformer-ctc-be-de-en-es-fr-hr-it-pl-ru-uk-20k/test_wavs/de-german.wav"
|
||||||
|
# test_wav = "./sherpa-onnx-nemo-fast-conformer-ctc-be-de-en-es-fr-hr-it-pl-ru-uk-20k/test_wavs/en-english.wav"
|
||||||
|
# test_wav = "./sherpa-onnx-nemo-fast-conformer-ctc-be-de-en-es-fr-hr-it-pl-ru-uk-20k/test_wavs/es-spanish.wav"
|
||||||
|
# test_wav = "./sherpa-onnx-nemo-fast-conformer-ctc-be-de-en-es-fr-hr-it-pl-ru-uk-20k/test_wavs/fr-french.wav"
|
||||||
|
# test_wav = "./sherpa-onnx-nemo-fast-conformer-ctc-be-de-en-es-fr-hr-it-pl-ru-uk-20k/test_wavs/hr-croatian.wav"
|
||||||
|
# test_wav = "./sherpa-onnx-nemo-fast-conformer-ctc-be-de-en-es-fr-hr-it-pl-ru-uk-20k/test_wavs/it-italian.wav"
|
||||||
|
# test_wav = "./sherpa-onnx-nemo-fast-conformer-ctc-be-de-en-es-fr-hr-it-pl-ru-uk-20k/test_wavs/po-polish.wav"
|
||||||
|
# test_wav = "./sherpa-onnx-nemo-fast-conformer-ctc-be-de-en-es-fr-hr-it-pl-ru-uk-20k/test_wavs/ru-russian.wav"
|
||||||
|
# test_wav = "./sherpa-onnx-nemo-fast-conformer-ctc-be-de-en-es-fr-hr-it-pl-ru-uk-20k/test_wavs/uk-ukrainian.wav"
|
||||||
|
|
||||||
|
if not Path(model).is_file() or not Path(test_wav).is_file():
|
||||||
|
raise ValueError(
|
||||||
|
"""Please download model files from
|
||||||
|
https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
return (
|
||||||
|
sherpa_onnx.OfflineRecognizer.from_nemo_ctc(
|
||||||
|
model=model,
|
||||||
|
tokens=tokens,
|
||||||
|
debug=True,
|
||||||
|
),
|
||||||
|
test_wav,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
recognizer, wave_filename = create_recognizer()
|
||||||
|
|
||||||
|
audio, sample_rate = sf.read(wave_filename, dtype="float32", always_2d=True)
|
||||||
|
audio = audio[:, 0] # only use the first channel
|
||||||
|
|
||||||
|
# audio is a 1-D float32 numpy array normalized to the range [-1, 1]
|
||||||
|
# sample_rate does not need to be 16000 Hz
|
||||||
|
|
||||||
|
stream = recognizer.create_stream()
|
||||||
|
stream.accept_waveform(sample_rate, audio)
|
||||||
|
recognizer.decode_stream(stream)
|
||||||
|
print(wave_filename)
|
||||||
|
print(stream.result)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
73
python-api-examples/offline-nemo-transducer-decode-files.py
Executable file
73
python-api-examples/offline-nemo-transducer-decode-files.py
Executable file
@@ -0,0 +1,73 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
|
||||||
|
"""
|
||||||
|
This file shows how to use a non-streaming transducer model from NeMo
|
||||||
|
to decode files.
|
||||||
|
|
||||||
|
Please download model files from
|
||||||
|
https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models
|
||||||
|
|
||||||
|
|
||||||
|
The example model supports 10 languages and it is converted from
|
||||||
|
https://catalog.ngc.nvidia.com/orgs/nvidia/teams/nemo/models/stt_multilingual_fastconformer_hybrid_large_pc
|
||||||
|
"""
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import sherpa_onnx
|
||||||
|
import soundfile as sf
|
||||||
|
|
||||||
|
|
||||||
|
def create_recognizer():
|
||||||
|
encoder = "./sherpa-onnx-nemo-fast-conformer-transducer-be-de-en-es-fr-hr-it-pl-ru-uk-20k/encoder.onnx"
|
||||||
|
decoder = "./sherpa-onnx-nemo-fast-conformer-transducer-be-de-en-es-fr-hr-it-pl-ru-uk-20k/decoder.onnx"
|
||||||
|
joiner = "./sherpa-onnx-nemo-fast-conformer-transducer-be-de-en-es-fr-hr-it-pl-ru-uk-20k/joiner.onnx"
|
||||||
|
tokens = "./sherpa-onnx-nemo-fast-conformer-transducer-be-de-en-es-fr-hr-it-pl-ru-uk-20k/tokens.txt"
|
||||||
|
|
||||||
|
test_wav = "./sherpa-onnx-nemo-fast-conformer-transducer-be-de-en-es-fr-hr-it-pl-ru-uk-20k/test_wavs/de-german.wav"
|
||||||
|
# test_wav = "./sherpa-onnx-nemo-fast-conformer-transducer-be-de-en-es-fr-hr-it-pl-ru-uk-20k/test_wavs/en-english.wav"
|
||||||
|
# test_wav = "./sherpa-onnx-nemo-fast-conformer-transducer-be-de-en-es-fr-hr-it-pl-ru-uk-20k/test_wavs/es-spanish.wav"
|
||||||
|
# test_wav = "./sherpa-onnx-nemo-fast-conformer-transducer-be-de-en-es-fr-hr-it-pl-ru-uk-20k/test_wavs/fr-french.wav"
|
||||||
|
# test_wav = "./sherpa-onnx-nemo-fast-conformer-transducer-be-de-en-es-fr-hr-it-pl-ru-uk-20k/test_wavs/hr-croatian.wav"
|
||||||
|
# test_wav = "./sherpa-onnx-nemo-fast-conformer-transducer-be-de-en-es-fr-hr-it-pl-ru-uk-20k/test_wavs/it-italian.wav"
|
||||||
|
# test_wav = "./sherpa-onnx-nemo-fast-conformer-transducer-be-de-en-es-fr-hr-it-pl-ru-uk-20k/test_wavs/po-polish.wav"
|
||||||
|
# test_wav = "./sherpa-onnx-nemo-fast-conformer-transducer-be-de-en-es-fr-hr-it-pl-ru-uk-20k/test_wavs/ru-russian.wav"
|
||||||
|
# test_wav = "./sherpa-onnx-nemo-fast-conformer-transducer-be-de-en-es-fr-hr-it-pl-ru-uk-20k/test_wavs/uk-ukrainian.wav"
|
||||||
|
|
||||||
|
if not Path(encoder).is_file() or not Path(test_wav).is_file():
|
||||||
|
raise ValueError(
|
||||||
|
"""Please download model files from
|
||||||
|
https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
return (
|
||||||
|
sherpa_onnx.OfflineRecognizer.from_transducer(
|
||||||
|
encoder=encoder,
|
||||||
|
decoder=decoder,
|
||||||
|
joiner=joiner,
|
||||||
|
tokens=tokens,
|
||||||
|
model_type="nemo_transducer",
|
||||||
|
debug=True,
|
||||||
|
),
|
||||||
|
test_wav,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
recognizer, wave_filename = create_recognizer()
|
||||||
|
|
||||||
|
audio, sample_rate = sf.read(wave_filename, dtype="float32", always_2d=True)
|
||||||
|
audio = audio[:, 0] # only use the first channel
|
||||||
|
|
||||||
|
# audio is a 1-D float32 numpy array normalized to the range [-1, 1]
|
||||||
|
# sample_rate does not need to be 16000 Hz
|
||||||
|
|
||||||
|
stream = recognizer.create_stream()
|
||||||
|
stream.accept_waveform(sample_rate, audio)
|
||||||
|
recognizer.decode_stream(stream)
|
||||||
|
print(wave_filename)
|
||||||
|
print(stream.result)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
@@ -40,9 +40,11 @@ set(sources
|
|||||||
offline-tdnn-ctc-model.cc
|
offline-tdnn-ctc-model.cc
|
||||||
offline-tdnn-model-config.cc
|
offline-tdnn-model-config.cc
|
||||||
offline-transducer-greedy-search-decoder.cc
|
offline-transducer-greedy-search-decoder.cc
|
||||||
|
offline-transducer-greedy-search-nemo-decoder.cc
|
||||||
offline-transducer-model-config.cc
|
offline-transducer-model-config.cc
|
||||||
offline-transducer-model.cc
|
offline-transducer-model.cc
|
||||||
offline-transducer-modified-beam-search-decoder.cc
|
offline-transducer-modified-beam-search-decoder.cc
|
||||||
|
offline-transducer-nemo-model.cc
|
||||||
offline-wenet-ctc-model-config.cc
|
offline-wenet-ctc-model-config.cc
|
||||||
offline-wenet-ctc-model.cc
|
offline-wenet-ctc-model.cc
|
||||||
offline-whisper-greedy-search-decoder.cc
|
offline-whisper-greedy-search-decoder.cc
|
||||||
|
|||||||
@@ -56,6 +56,19 @@ struct FeatureExtractorConfig {
|
|||||||
bool remove_dc_offset = true; // Subtract mean of wave before FFT.
|
bool remove_dc_offset = true; // Subtract mean of wave before FFT.
|
||||||
std::string window_type = "povey"; // e.g. Hamming window
|
std::string window_type = "povey"; // e.g. Hamming window
|
||||||
|
|
||||||
|
// For models from NeMo
|
||||||
|
// This option is not exposed and is set internally when loading models.
|
||||||
|
// Possible values:
|
||||||
|
// - per_feature
|
||||||
|
// - all_features (not implemented yet)
|
||||||
|
// - fixed_mean (not implemented)
|
||||||
|
// - fixed_std (not implemented)
|
||||||
|
// - or just leave it to empty
|
||||||
|
// See
|
||||||
|
// https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/asr/parts/preprocessing/features.py#L59
|
||||||
|
// for details
|
||||||
|
std::string nemo_normalize_type;
|
||||||
|
|
||||||
std::string ToString() const;
|
std::string ToString() const;
|
||||||
|
|
||||||
void Register(ParseOptions *po);
|
void Register(ParseOptions *po);
|
||||||
|
|||||||
@@ -68,7 +68,7 @@ class KeywordSpotterTransducerImpl : public KeywordSpotterImpl {
|
|||||||
: config_(config),
|
: config_(config),
|
||||||
model_(OnlineTransducerModel::Create(config.model_config)),
|
model_(OnlineTransducerModel::Create(config.model_config)),
|
||||||
sym_(config.model_config.tokens) {
|
sym_(config.model_config.tokens) {
|
||||||
if (sym_.contains("<unk>")) {
|
if (sym_.Contains("<unk>")) {
|
||||||
unk_id_ = sym_["<unk>"];
|
unk_id_ = sym_["<unk>"];
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -87,7 +87,7 @@ class KeywordSpotterTransducerImpl : public KeywordSpotterImpl {
|
|||||||
: config_(config),
|
: config_(config),
|
||||||
model_(OnlineTransducerModel::Create(mgr, config.model_config)),
|
model_(OnlineTransducerModel::Create(mgr, config.model_config)),
|
||||||
sym_(mgr, config.model_config.tokens) {
|
sym_(mgr, config.model_config.tokens) {
|
||||||
if (sym_.contains("<unk>")) {
|
if (sym_.Contains("<unk>")) {
|
||||||
unk_id_ = sym_["<unk>"];
|
unk_id_ = sym_["<unk>"];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
// sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model.cc
|
// sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model.cc
|
||||||
//
|
//
|
||||||
// Copyright (c) 2023 Xiaomi Corporation
|
// Copyright (c) 2023-2024 Xiaomi Corporation
|
||||||
|
|
||||||
#include "sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model.h"
|
#include "sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model.h"
|
||||||
|
|
||||||
|
|||||||
@@ -38,7 +38,7 @@ static OfflineRecognitionResult Convert(const OfflineCtcDecoderResult &src,
|
|||||||
std::string text;
|
std::string text;
|
||||||
|
|
||||||
for (int32_t i = 0; i != src.tokens.size(); ++i) {
|
for (int32_t i = 0; i != src.tokens.size(); ++i) {
|
||||||
if (sym_table.contains("SIL") && src.tokens[i] == sym_table["SIL"]) {
|
if (sym_table.Contains("SIL") && src.tokens[i] == sym_table["SIL"]) {
|
||||||
// tdnn models from yesno have a SIL token, we should remove it.
|
// tdnn models from yesno have a SIL token, we should remove it.
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
@@ -103,9 +103,9 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl {
|
|||||||
decoder_ = std::make_unique<OfflineCtcFstDecoder>(
|
decoder_ = std::make_unique<OfflineCtcFstDecoder>(
|
||||||
config_.ctc_fst_decoder_config);
|
config_.ctc_fst_decoder_config);
|
||||||
} else if (config_.decoding_method == "greedy_search") {
|
} else if (config_.decoding_method == "greedy_search") {
|
||||||
if (!symbol_table_.contains("<blk>") &&
|
if (!symbol_table_.Contains("<blk>") &&
|
||||||
!symbol_table_.contains("<eps>") &&
|
!symbol_table_.Contains("<eps>") &&
|
||||||
!symbol_table_.contains("<blank>")) {
|
!symbol_table_.Contains("<blank>")) {
|
||||||
SHERPA_ONNX_LOGE(
|
SHERPA_ONNX_LOGE(
|
||||||
"We expect that tokens.txt contains "
|
"We expect that tokens.txt contains "
|
||||||
"the symbol <blk> or <eps> or <blank> and its ID.");
|
"the symbol <blk> or <eps> or <blank> and its ID.");
|
||||||
@@ -113,12 +113,12 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl {
|
|||||||
}
|
}
|
||||||
|
|
||||||
int32_t blank_id = 0;
|
int32_t blank_id = 0;
|
||||||
if (symbol_table_.contains("<blk>")) {
|
if (symbol_table_.Contains("<blk>")) {
|
||||||
blank_id = symbol_table_["<blk>"];
|
blank_id = symbol_table_["<blk>"];
|
||||||
} else if (symbol_table_.contains("<eps>")) {
|
} else if (symbol_table_.Contains("<eps>")) {
|
||||||
// for tdnn models of the yesno recipe from icefall
|
// for tdnn models of the yesno recipe from icefall
|
||||||
blank_id = symbol_table_["<eps>"];
|
blank_id = symbol_table_["<eps>"];
|
||||||
} else if (symbol_table_.contains("<blank>")) {
|
} else if (symbol_table_.Contains("<blank>")) {
|
||||||
// for Wenet CTC models
|
// for Wenet CTC models
|
||||||
blank_id = symbol_table_["<blank>"];
|
blank_id = symbol_table_["<blank>"];
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -11,6 +11,7 @@
|
|||||||
#include "sherpa-onnx/csrc/offline-recognizer-ctc-impl.h"
|
#include "sherpa-onnx/csrc/offline-recognizer-ctc-impl.h"
|
||||||
#include "sherpa-onnx/csrc/offline-recognizer-paraformer-impl.h"
|
#include "sherpa-onnx/csrc/offline-recognizer-paraformer-impl.h"
|
||||||
#include "sherpa-onnx/csrc/offline-recognizer-transducer-impl.h"
|
#include "sherpa-onnx/csrc/offline-recognizer-transducer-impl.h"
|
||||||
|
#include "sherpa-onnx/csrc/offline-recognizer-transducer-nemo-impl.h"
|
||||||
#include "sherpa-onnx/csrc/offline-recognizer-whisper-impl.h"
|
#include "sherpa-onnx/csrc/offline-recognizer-whisper-impl.h"
|
||||||
#include "sherpa-onnx/csrc/onnx-utils.h"
|
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||||
#include "sherpa-onnx/csrc/text-utils.h"
|
#include "sherpa-onnx/csrc/text-utils.h"
|
||||||
@@ -23,6 +24,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
|
|||||||
const auto &model_type = config.model_config.model_type;
|
const auto &model_type = config.model_config.model_type;
|
||||||
if (model_type == "transducer") {
|
if (model_type == "transducer") {
|
||||||
return std::make_unique<OfflineRecognizerTransducerImpl>(config);
|
return std::make_unique<OfflineRecognizerTransducerImpl>(config);
|
||||||
|
} else if (model_type == "nemo_transducer") {
|
||||||
|
return std::make_unique<OfflineRecognizerTransducerNeMoImpl>(config);
|
||||||
} else if (model_type == "paraformer") {
|
} else if (model_type == "paraformer") {
|
||||||
return std::make_unique<OfflineRecognizerParaformerImpl>(config);
|
return std::make_unique<OfflineRecognizerParaformerImpl>(config);
|
||||||
} else if (model_type == "nemo_ctc" || model_type == "tdnn" ||
|
} else if (model_type == "nemo_ctc" || model_type == "tdnn" ||
|
||||||
@@ -122,6 +125,12 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
|
|||||||
return std::make_unique<OfflineRecognizerParaformerImpl>(config);
|
return std::make_unique<OfflineRecognizerParaformerImpl>(config);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (model_type == "EncDecHybridRNNTCTCBPEModel" &&
|
||||||
|
!config.model_config.transducer.decoder_filename.empty() &&
|
||||||
|
!config.model_config.transducer.joiner_filename.empty()) {
|
||||||
|
return std::make_unique<OfflineRecognizerTransducerNeMoImpl>(config);
|
||||||
|
}
|
||||||
|
|
||||||
if (model_type == "EncDecCTCModelBPE" ||
|
if (model_type == "EncDecCTCModelBPE" ||
|
||||||
model_type == "EncDecHybridRNNTCTCBPEModel" || model_type == "tdnn" ||
|
model_type == "EncDecHybridRNNTCTCBPEModel" || model_type == "tdnn" ||
|
||||||
model_type == "zipformer2_ctc" || model_type == "wenet_ctc") {
|
model_type == "zipformer2_ctc" || model_type == "wenet_ctc") {
|
||||||
@@ -155,6 +164,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
|
|||||||
const auto &model_type = config.model_config.model_type;
|
const auto &model_type = config.model_config.model_type;
|
||||||
if (model_type == "transducer") {
|
if (model_type == "transducer") {
|
||||||
return std::make_unique<OfflineRecognizerTransducerImpl>(mgr, config);
|
return std::make_unique<OfflineRecognizerTransducerImpl>(mgr, config);
|
||||||
|
} else if (model_type == "nemo_transducer") {
|
||||||
|
return std::make_unique<OfflineRecognizerTransducerNeMoImpl>(mgr, config);
|
||||||
} else if (model_type == "paraformer") {
|
} else if (model_type == "paraformer") {
|
||||||
return std::make_unique<OfflineRecognizerParaformerImpl>(mgr, config);
|
return std::make_unique<OfflineRecognizerParaformerImpl>(mgr, config);
|
||||||
} else if (model_type == "nemo_ctc" || model_type == "tdnn" ||
|
} else if (model_type == "nemo_ctc" || model_type == "tdnn" ||
|
||||||
@@ -254,6 +265,12 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
|
|||||||
return std::make_unique<OfflineRecognizerParaformerImpl>(mgr, config);
|
return std::make_unique<OfflineRecognizerParaformerImpl>(mgr, config);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (model_type == "EncDecHybridRNNTCTCBPEModel" &&
|
||||||
|
!config.model_config.transducer.decoder_filename.empty() &&
|
||||||
|
!config.model_config.transducer.joiner_filename.empty()) {
|
||||||
|
return std::make_unique<OfflineRecognizerTransducerNeMoImpl>(mgr, config);
|
||||||
|
}
|
||||||
|
|
||||||
if (model_type == "EncDecCTCModelBPE" ||
|
if (model_type == "EncDecCTCModelBPE" ||
|
||||||
model_type == "EncDecHybridRNNTCTCBPEModel" || model_type == "tdnn" ||
|
model_type == "EncDecHybridRNNTCTCBPEModel" || model_type == "tdnn" ||
|
||||||
model_type == "zipformer2_ctc" || model_type == "wenet_ctc") {
|
model_type == "zipformer2_ctc" || model_type == "wenet_ctc") {
|
||||||
|
|||||||
182
sherpa-onnx/csrc/offline-recognizer-transducer-nemo-impl.h
Normal file
182
sherpa-onnx/csrc/offline-recognizer-transducer-nemo-impl.h
Normal file
@@ -0,0 +1,182 @@
|
|||||||
|
// sherpa-onnx/csrc/offline-recognizer-transducer-nemo-impl.h
|
||||||
|
//
|
||||||
|
// Copyright (c) 2022-2024 Xiaomi Corporation
|
||||||
|
|
||||||
|
#ifndef SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_TRANSDUCER_NEMO_IMPL_H_
|
||||||
|
#define SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_TRANSDUCER_NEMO_IMPL_H_
|
||||||
|
|
||||||
|
#include <fstream>
|
||||||
|
#include <ios>
|
||||||
|
#include <memory>
|
||||||
|
#include <regex> // NOLINT
|
||||||
|
#include <sstream>
|
||||||
|
#include <string>
|
||||||
|
#include <utility>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#if __ANDROID_API__ >= 9
|
||||||
|
#include "android/asset_manager.h"
|
||||||
|
#include "android/asset_manager_jni.h"
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/macros.h"
|
||||||
|
#include "sherpa-onnx/csrc/offline-recognizer-impl.h"
|
||||||
|
#include "sherpa-onnx/csrc/offline-recognizer.h"
|
||||||
|
#include "sherpa-onnx/csrc/offline-transducer-greedy-search-nemo-decoder.h"
|
||||||
|
#include "sherpa-onnx/csrc/offline-transducer-nemo-model.h"
|
||||||
|
#include "sherpa-onnx/csrc/pad-sequence.h"
|
||||||
|
#include "sherpa-onnx/csrc/symbol-table.h"
|
||||||
|
#include "sherpa-onnx/csrc/transpose.h"
|
||||||
|
#include "sherpa-onnx/csrc/utils.h"
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
// defined in ./offline-recognizer-transducer-impl.h
|
||||||
|
OfflineRecognitionResult Convert(const OfflineTransducerDecoderResult &src,
|
||||||
|
const SymbolTable &sym_table,
|
||||||
|
int32_t frame_shift_ms,
|
||||||
|
int32_t subsampling_factor);
|
||||||
|
|
||||||
|
class OfflineRecognizerTransducerNeMoImpl : public OfflineRecognizerImpl {
|
||||||
|
public:
|
||||||
|
explicit OfflineRecognizerTransducerNeMoImpl(
|
||||||
|
const OfflineRecognizerConfig &config)
|
||||||
|
: config_(config),
|
||||||
|
symbol_table_(config_.model_config.tokens),
|
||||||
|
model_(std::make_unique<OfflineTransducerNeMoModel>(
|
||||||
|
config_.model_config)) {
|
||||||
|
if (config_.decoding_method == "greedy_search") {
|
||||||
|
decoder_ = std::make_unique<OfflineTransducerGreedySearchNeMoDecoder>(
|
||||||
|
model_.get(), config_.blank_penalty);
|
||||||
|
} else {
|
||||||
|
SHERPA_ONNX_LOGE("Unsupported decoding method: %s",
|
||||||
|
config_.decoding_method.c_str());
|
||||||
|
exit(-1);
|
||||||
|
}
|
||||||
|
PostInit();
|
||||||
|
}
|
||||||
|
|
||||||
|
#if __ANDROID_API__ >= 9
|
||||||
|
explicit OfflineRecognizerTransducerNeMoImpl(
|
||||||
|
AAssetManager *mgr, const OfflineRecognizerConfig &config)
|
||||||
|
: config_(config),
|
||||||
|
symbol_table_(mgr, config_.model_config.tokens),
|
||||||
|
model_(std::make_unique<OfflineTransducerNeMoModel>(
|
||||||
|
mgr, config_.model_config)) {
|
||||||
|
if (config_.decoding_method == "greedy_search") {
|
||||||
|
decoder_ = std::make_unique<OfflineTransducerGreedySearchNeMoDecoder>(
|
||||||
|
model_.get(), config_.blank_penalty);
|
||||||
|
} else {
|
||||||
|
SHERPA_ONNX_LOGE("Unsupported decoding method: %s",
|
||||||
|
config_.decoding_method.c_str());
|
||||||
|
exit(-1);
|
||||||
|
}
|
||||||
|
|
||||||
|
PostInit();
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
std::unique_ptr<OfflineStream> CreateStream() const override {
|
||||||
|
return std::make_unique<OfflineStream>(config_.feat_config);
|
||||||
|
}
|
||||||
|
|
||||||
|
void DecodeStreams(OfflineStream **ss, int32_t n) const override {
|
||||||
|
auto memory_info =
|
||||||
|
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
|
||||||
|
|
||||||
|
int32_t feat_dim = ss[0]->FeatureDim();
|
||||||
|
|
||||||
|
std::vector<Ort::Value> features;
|
||||||
|
|
||||||
|
features.reserve(n);
|
||||||
|
|
||||||
|
std::vector<std::vector<float>> features_vec(n);
|
||||||
|
std::vector<int64_t> features_length_vec(n);
|
||||||
|
for (int32_t i = 0; i != n; ++i) {
|
||||||
|
auto f = ss[i]->GetFrames();
|
||||||
|
int32_t num_frames = f.size() / feat_dim;
|
||||||
|
|
||||||
|
features_length_vec[i] = num_frames;
|
||||||
|
features_vec[i] = std::move(f);
|
||||||
|
|
||||||
|
std::array<int64_t, 2> shape = {num_frames, feat_dim};
|
||||||
|
|
||||||
|
Ort::Value x = Ort::Value::CreateTensor(
|
||||||
|
memory_info, features_vec[i].data(), features_vec[i].size(),
|
||||||
|
shape.data(), shape.size());
|
||||||
|
features.push_back(std::move(x));
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<const Ort::Value *> features_pointer(n);
|
||||||
|
for (int32_t i = 0; i != n; ++i) {
|
||||||
|
features_pointer[i] = &features[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
std::array<int64_t, 1> features_length_shape = {n};
|
||||||
|
Ort::Value x_length = Ort::Value::CreateTensor(
|
||||||
|
memory_info, features_length_vec.data(), n,
|
||||||
|
features_length_shape.data(), features_length_shape.size());
|
||||||
|
|
||||||
|
Ort::Value x = PadSequence(model_->Allocator(), features_pointer, 0);
|
||||||
|
|
||||||
|
auto t = model_->RunEncoder(std::move(x), std::move(x_length));
|
||||||
|
// t[0] encoder_out, float tensor, (batch_size, dim, T)
|
||||||
|
// t[1] encoder_out_length, int64 tensor, (batch_size,)
|
||||||
|
|
||||||
|
Ort::Value encoder_out = Transpose12(model_->Allocator(), &t[0]);
|
||||||
|
|
||||||
|
auto results = decoder_->Decode(std::move(encoder_out), std::move(t[1]));
|
||||||
|
|
||||||
|
int32_t frame_shift_ms = 10;
|
||||||
|
for (int32_t i = 0; i != n; ++i) {
|
||||||
|
auto r = Convert(results[i], symbol_table_, frame_shift_ms,
|
||||||
|
model_->SubsamplingFactor());
|
||||||
|
|
||||||
|
ss[i]->SetResult(r);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
void PostInit() {
|
||||||
|
config_.feat_config.nemo_normalize_type =
|
||||||
|
model_->FeatureNormalizationMethod();
|
||||||
|
|
||||||
|
config_.feat_config.low_freq = 0;
|
||||||
|
// config_.feat_config.high_freq = 8000;
|
||||||
|
config_.feat_config.is_librosa = true;
|
||||||
|
config_.feat_config.remove_dc_offset = false;
|
||||||
|
// config_.feat_config.window_type = "hann";
|
||||||
|
config_.feat_config.dither = 0;
|
||||||
|
config_.feat_config.nemo_normalize_type =
|
||||||
|
model_->FeatureNormalizationMethod();
|
||||||
|
|
||||||
|
int32_t vocab_size = model_->VocabSize();
|
||||||
|
|
||||||
|
// check the blank ID
|
||||||
|
if (!symbol_table_.Contains("<blk>")) {
|
||||||
|
SHERPA_ONNX_LOGE("tokens.txt does not include the blank token <blk>");
|
||||||
|
exit(-1);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (symbol_table_["<blk>"] != vocab_size - 1) {
|
||||||
|
SHERPA_ONNX_LOGE("<blk> is not the last token!");
|
||||||
|
exit(-1);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (symbol_table_.NumSymbols() != vocab_size) {
|
||||||
|
SHERPA_ONNX_LOGE("number of lines in tokens.txt %d != %d (vocab_size)",
|
||||||
|
symbol_table_.NumSymbols(), vocab_size);
|
||||||
|
exit(-1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
OfflineRecognizerConfig config_;
|
||||||
|
SymbolTable symbol_table_;
|
||||||
|
std::unique_ptr<OfflineTransducerNeMoModel> model_;
|
||||||
|
std::unique_ptr<OfflineTransducerDecoder> decoder_;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace sherpa_onnx
|
||||||
|
|
||||||
|
#endif // SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_TRANSDUCER_NEMO_IMPL_H_
|
||||||
@@ -35,7 +35,7 @@ static OfflineRecognitionResult Convert(const OfflineWhisperDecoderResult &src,
|
|||||||
|
|
||||||
std::string text;
|
std::string text;
|
||||||
for (auto i : src.tokens) {
|
for (auto i : src.tokens) {
|
||||||
if (!sym_table.contains(i)) {
|
if (!sym_table.Contains(i)) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -14,6 +14,7 @@
|
|||||||
#include "android/asset_manager_jni.h"
|
#include "android/asset_manager_jni.h"
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/features.h"
|
||||||
#include "sherpa-onnx/csrc/offline-ctc-fst-decoder-config.h"
|
#include "sherpa-onnx/csrc/offline-ctc-fst-decoder-config.h"
|
||||||
#include "sherpa-onnx/csrc/offline-lm-config.h"
|
#include "sherpa-onnx/csrc/offline-lm-config.h"
|
||||||
#include "sherpa-onnx/csrc/offline-model-config.h"
|
#include "sherpa-onnx/csrc/offline-model-config.h"
|
||||||
@@ -26,7 +27,7 @@ namespace sherpa_onnx {
|
|||||||
struct OfflineRecognitionResult;
|
struct OfflineRecognitionResult;
|
||||||
|
|
||||||
struct OfflineRecognizerConfig {
|
struct OfflineRecognizerConfig {
|
||||||
OfflineFeatureExtractorConfig feat_config;
|
FeatureExtractorConfig feat_config;
|
||||||
OfflineModelConfig model_config;
|
OfflineModelConfig model_config;
|
||||||
OfflineLMConfig lm_config;
|
OfflineLMConfig lm_config;
|
||||||
OfflineCtcFstDecoderConfig ctc_fst_decoder_config;
|
OfflineCtcFstDecoderConfig ctc_fst_decoder_config;
|
||||||
@@ -44,7 +45,7 @@ struct OfflineRecognizerConfig {
|
|||||||
|
|
||||||
OfflineRecognizerConfig() = default;
|
OfflineRecognizerConfig() = default;
|
||||||
OfflineRecognizerConfig(
|
OfflineRecognizerConfig(
|
||||||
const OfflineFeatureExtractorConfig &feat_config,
|
const FeatureExtractorConfig &feat_config,
|
||||||
const OfflineModelConfig &model_config, const OfflineLMConfig &lm_config,
|
const OfflineModelConfig &model_config, const OfflineLMConfig &lm_config,
|
||||||
const OfflineCtcFstDecoderConfig &ctc_fst_decoder_config,
|
const OfflineCtcFstDecoderConfig &ctc_fst_decoder_config,
|
||||||
const std::string &decoding_method, int32_t max_active_paths,
|
const std::string &decoding_method, int32_t max_active_paths,
|
||||||
|
|||||||
@@ -52,42 +52,25 @@ static void ComputeMeanAndInvStd(const float *p, int32_t num_rows,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void OfflineFeatureExtractorConfig::Register(ParseOptions *po) {
|
|
||||||
po->Register("sample-rate", &sampling_rate,
|
|
||||||
"Sampling rate of the input waveform. "
|
|
||||||
"Note: You can have a different "
|
|
||||||
"sample rate for the input waveform. We will do resampling "
|
|
||||||
"inside the feature extractor");
|
|
||||||
|
|
||||||
po->Register("feat-dim", &feature_dim,
|
|
||||||
"Feature dimension. Must match the one expected by the model.");
|
|
||||||
}
|
|
||||||
|
|
||||||
std::string OfflineFeatureExtractorConfig::ToString() const {
|
|
||||||
std::ostringstream os;
|
|
||||||
|
|
||||||
os << "OfflineFeatureExtractorConfig(";
|
|
||||||
os << "sampling_rate=" << sampling_rate << ", ";
|
|
||||||
os << "feature_dim=" << feature_dim << ")";
|
|
||||||
|
|
||||||
return os.str();
|
|
||||||
}
|
|
||||||
|
|
||||||
class OfflineStream::Impl {
|
class OfflineStream::Impl {
|
||||||
public:
|
public:
|
||||||
explicit Impl(const OfflineFeatureExtractorConfig &config,
|
explicit Impl(const FeatureExtractorConfig &config,
|
||||||
ContextGraphPtr context_graph)
|
ContextGraphPtr context_graph)
|
||||||
: config_(config), context_graph_(context_graph) {
|
: config_(config), context_graph_(context_graph) {
|
||||||
opts_.frame_opts.dither = 0;
|
opts_.frame_opts.dither = config.dither;
|
||||||
opts_.frame_opts.snip_edges = false;
|
opts_.frame_opts.snip_edges = config.snip_edges;
|
||||||
opts_.frame_opts.samp_freq = config.sampling_rate;
|
opts_.frame_opts.samp_freq = config.sampling_rate;
|
||||||
|
opts_.frame_opts.frame_shift_ms = config.frame_shift_ms;
|
||||||
|
opts_.frame_opts.frame_length_ms = config.frame_length_ms;
|
||||||
|
opts_.frame_opts.remove_dc_offset = config.remove_dc_offset;
|
||||||
|
opts_.frame_opts.window_type = config.window_type;
|
||||||
|
|
||||||
opts_.mel_opts.num_bins = config.feature_dim;
|
opts_.mel_opts.num_bins = config.feature_dim;
|
||||||
|
|
||||||
// Please see
|
opts_.mel_opts.high_freq = config.high_freq;
|
||||||
// https://github.com/lhotse-speech/lhotse/blob/master/lhotse/features/fbank.py#L27
|
opts_.mel_opts.low_freq = config.low_freq;
|
||||||
// and
|
|
||||||
// https://github.com/k2-fsa/sherpa-onnx/issues/514
|
opts_.mel_opts.is_librosa = config.is_librosa;
|
||||||
opts_.mel_opts.high_freq = -400;
|
|
||||||
|
|
||||||
fbank_ = std::make_unique<knf::OnlineFbank>(opts_);
|
fbank_ = std::make_unique<knf::OnlineFbank>(opts_);
|
||||||
}
|
}
|
||||||
@@ -237,7 +220,7 @@ class OfflineStream::Impl {
|
|||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
OfflineFeatureExtractorConfig config_;
|
FeatureExtractorConfig config_;
|
||||||
std::unique_ptr<knf::OnlineFbank> fbank_;
|
std::unique_ptr<knf::OnlineFbank> fbank_;
|
||||||
std::unique_ptr<knf::OnlineWhisperFbank> whisper_fbank_;
|
std::unique_ptr<knf::OnlineWhisperFbank> whisper_fbank_;
|
||||||
knf::FbankOptions opts_;
|
knf::FbankOptions opts_;
|
||||||
@@ -245,9 +228,8 @@ class OfflineStream::Impl {
|
|||||||
ContextGraphPtr context_graph_;
|
ContextGraphPtr context_graph_;
|
||||||
};
|
};
|
||||||
|
|
||||||
OfflineStream::OfflineStream(
|
OfflineStream::OfflineStream(const FeatureExtractorConfig &config /*= {}*/,
|
||||||
const OfflineFeatureExtractorConfig &config /*= {}*/,
|
ContextGraphPtr context_graph /*= nullptr*/)
|
||||||
ContextGraphPtr context_graph /*= nullptr*/)
|
|
||||||
: impl_(std::make_unique<Impl>(config, context_graph)) {}
|
: impl_(std::make_unique<Impl>(config, context_graph)) {}
|
||||||
|
|
||||||
OfflineStream::OfflineStream(WhisperTag tag)
|
OfflineStream::OfflineStream(WhisperTag tag)
|
||||||
|
|||||||
@@ -11,6 +11,7 @@
|
|||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "sherpa-onnx/csrc/context-graph.h"
|
#include "sherpa-onnx/csrc/context-graph.h"
|
||||||
|
#include "sherpa-onnx/csrc/features.h"
|
||||||
#include "sherpa-onnx/csrc/parse-options.h"
|
#include "sherpa-onnx/csrc/parse-options.h"
|
||||||
|
|
||||||
namespace sherpa_onnx {
|
namespace sherpa_onnx {
|
||||||
@@ -32,46 +33,12 @@ struct OfflineRecognitionResult {
|
|||||||
std::string AsJsonString() const;
|
std::string AsJsonString() const;
|
||||||
};
|
};
|
||||||
|
|
||||||
struct OfflineFeatureExtractorConfig {
|
|
||||||
// Sampling rate used by the feature extractor. If it is different from
|
|
||||||
// the sampling rate of the input waveform, we will do resampling inside.
|
|
||||||
int32_t sampling_rate = 16000;
|
|
||||||
|
|
||||||
// Feature dimension
|
|
||||||
int32_t feature_dim = 80;
|
|
||||||
|
|
||||||
// Set internally by some models, e.g., paraformer and wenet CTC models set
|
|
||||||
// it to false.
|
|
||||||
// This parameter is not exposed to users from the commandline
|
|
||||||
// If true, the feature extractor expects inputs to be normalized to
|
|
||||||
// the range [-1, 1].
|
|
||||||
// If false, we will multiply the inputs by 32768
|
|
||||||
bool normalize_samples = true;
|
|
||||||
|
|
||||||
// For models from NeMo
|
|
||||||
// This option is not exposed and is set internally when loading models.
|
|
||||||
// Possible values:
|
|
||||||
// - per_feature
|
|
||||||
// - all_features (not implemented yet)
|
|
||||||
// - fixed_mean (not implemented)
|
|
||||||
// - fixed_std (not implemented)
|
|
||||||
// - or just leave it to empty
|
|
||||||
// See
|
|
||||||
// https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/asr/parts/preprocessing/features.py#L59
|
|
||||||
// for details
|
|
||||||
std::string nemo_normalize_type;
|
|
||||||
|
|
||||||
std::string ToString() const;
|
|
||||||
|
|
||||||
void Register(ParseOptions *po);
|
|
||||||
};
|
|
||||||
|
|
||||||
struct WhisperTag {};
|
struct WhisperTag {};
|
||||||
struct CEDTag {};
|
struct CEDTag {};
|
||||||
|
|
||||||
class OfflineStream {
|
class OfflineStream {
|
||||||
public:
|
public:
|
||||||
explicit OfflineStream(const OfflineFeatureExtractorConfig &config = {},
|
explicit OfflineStream(const FeatureExtractorConfig &config = {},
|
||||||
ContextGraphPtr context_graph = {});
|
ContextGraphPtr context_graph = {});
|
||||||
|
|
||||||
explicit OfflineStream(WhisperTag tag);
|
explicit OfflineStream(WhisperTag tag);
|
||||||
|
|||||||
@@ -14,8 +14,8 @@ namespace sherpa_onnx {
|
|||||||
|
|
||||||
class OfflineTransducerGreedySearchDecoder : public OfflineTransducerDecoder {
|
class OfflineTransducerGreedySearchDecoder : public OfflineTransducerDecoder {
|
||||||
public:
|
public:
|
||||||
explicit OfflineTransducerGreedySearchDecoder(OfflineTransducerModel *model,
|
OfflineTransducerGreedySearchDecoder(OfflineTransducerModel *model,
|
||||||
float blank_penalty)
|
float blank_penalty)
|
||||||
: model_(model), blank_penalty_(blank_penalty) {}
|
: model_(model), blank_penalty_(blank_penalty) {}
|
||||||
|
|
||||||
std::vector<OfflineTransducerDecoderResult> Decode(
|
std::vector<OfflineTransducerDecoderResult> Decode(
|
||||||
|
|||||||
@@ -0,0 +1,117 @@
|
|||||||
|
// sherpa-onnx/csrc/offline-transducer-greedy-search-nemo-decoder.cc
|
||||||
|
//
|
||||||
|
// Copyright (c) 2024 Xiaomi Corporation
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/offline-transducer-greedy-search-nemo-decoder.h"
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
|
#include <iterator>
|
||||||
|
#include <utility>
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/macros.h"
|
||||||
|
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
static std::pair<Ort::Value, Ort::Value> BuildDecoderInput(
|
||||||
|
int32_t token, OrtAllocator *allocator) {
|
||||||
|
std::array<int64_t, 2> shape{1, 1};
|
||||||
|
|
||||||
|
Ort::Value decoder_input =
|
||||||
|
Ort::Value::CreateTensor<int32_t>(allocator, shape.data(), shape.size());
|
||||||
|
|
||||||
|
std::array<int64_t, 1> length_shape{1};
|
||||||
|
Ort::Value decoder_input_length = Ort::Value::CreateTensor<int32_t>(
|
||||||
|
allocator, length_shape.data(), length_shape.size());
|
||||||
|
|
||||||
|
int32_t *p = decoder_input.GetTensorMutableData<int32_t>();
|
||||||
|
|
||||||
|
int32_t *p_length = decoder_input_length.GetTensorMutableData<int32_t>();
|
||||||
|
|
||||||
|
p[0] = token;
|
||||||
|
|
||||||
|
p_length[0] = 1;
|
||||||
|
|
||||||
|
return {std::move(decoder_input), std::move(decoder_input_length)};
|
||||||
|
}
|
||||||
|
|
||||||
|
static OfflineTransducerDecoderResult DecodeOne(
|
||||||
|
const float *p, int32_t num_rows, int32_t num_cols,
|
||||||
|
OfflineTransducerNeMoModel *model, float blank_penalty) {
|
||||||
|
auto memory_info =
|
||||||
|
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
|
||||||
|
|
||||||
|
OfflineTransducerDecoderResult ans;
|
||||||
|
|
||||||
|
int32_t vocab_size = model->VocabSize();
|
||||||
|
int32_t blank_id = vocab_size - 1;
|
||||||
|
|
||||||
|
auto decoder_input_pair = BuildDecoderInput(blank_id, model->Allocator());
|
||||||
|
|
||||||
|
std::pair<Ort::Value, std::vector<Ort::Value>> decoder_output_pair =
|
||||||
|
model->RunDecoder(std::move(decoder_input_pair.first),
|
||||||
|
std::move(decoder_input_pair.second),
|
||||||
|
model->GetDecoderInitStates(1));
|
||||||
|
|
||||||
|
std::array<int64_t, 3> encoder_shape{1, num_cols, 1};
|
||||||
|
|
||||||
|
for (int32_t t = 0; t != num_rows; ++t) {
|
||||||
|
Ort::Value cur_encoder_out = Ort::Value::CreateTensor(
|
||||||
|
memory_info, const_cast<float *>(p) + t * num_cols, num_cols,
|
||||||
|
encoder_shape.data(), encoder_shape.size());
|
||||||
|
|
||||||
|
Ort::Value logit = model->RunJoiner(std::move(cur_encoder_out),
|
||||||
|
View(&decoder_output_pair.first));
|
||||||
|
|
||||||
|
float *p_logit = logit.GetTensorMutableData<float>();
|
||||||
|
if (blank_penalty > 0) {
|
||||||
|
p_logit[blank_id] -= blank_penalty;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto y = static_cast<int32_t>(std::distance(
|
||||||
|
static_cast<const float *>(p_logit),
|
||||||
|
std::max_element(static_cast<const float *>(p_logit),
|
||||||
|
static_cast<const float *>(p_logit) + vocab_size)));
|
||||||
|
|
||||||
|
if (y != blank_id) {
|
||||||
|
ans.tokens.push_back(y);
|
||||||
|
ans.timestamps.push_back(t);
|
||||||
|
|
||||||
|
decoder_input_pair = BuildDecoderInput(y, model->Allocator());
|
||||||
|
|
||||||
|
decoder_output_pair =
|
||||||
|
model->RunDecoder(std::move(decoder_input_pair.first),
|
||||||
|
std::move(decoder_input_pair.second),
|
||||||
|
std::move(decoder_output_pair.second));
|
||||||
|
} // if (y != blank_id)
|
||||||
|
} // for (int32_t i = 0; i != num_rows; ++i)
|
||||||
|
|
||||||
|
return ans;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<OfflineTransducerDecoderResult>
|
||||||
|
OfflineTransducerGreedySearchNeMoDecoder::Decode(
|
||||||
|
Ort::Value encoder_out, Ort::Value encoder_out_length,
|
||||||
|
OfflineStream ** /*ss = nullptr*/, int32_t /*n= 0*/) {
|
||||||
|
auto shape = encoder_out.GetTensorTypeAndShapeInfo().GetShape();
|
||||||
|
|
||||||
|
int32_t batch_size = static_cast<int32_t>(shape[0]);
|
||||||
|
int32_t dim1 = static_cast<int32_t>(shape[1]);
|
||||||
|
int32_t dim2 = static_cast<int32_t>(shape[2]);
|
||||||
|
|
||||||
|
const int64_t *p_length = encoder_out_length.GetTensorData<int64_t>();
|
||||||
|
const float *p = encoder_out.GetTensorData<float>();
|
||||||
|
|
||||||
|
std::vector<OfflineTransducerDecoderResult> ans(batch_size);
|
||||||
|
|
||||||
|
for (int32_t i = 0; i != batch_size; ++i) {
|
||||||
|
const float *this_p = p + dim1 * dim2 * i;
|
||||||
|
int32_t this_len = p_length[i];
|
||||||
|
|
||||||
|
ans[i] = DecodeOne(this_p, this_len, dim2, model_, blank_penalty_);
|
||||||
|
}
|
||||||
|
|
||||||
|
return ans;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace sherpa_onnx
|
||||||
@@ -0,0 +1,33 @@
|
|||||||
|
// sherpa-onnx/csrc/offline-transducer-greedy-search-nemo-decoder.h
|
||||||
|
//
|
||||||
|
// Copyright (c) 2024 Xiaomi Corporation
|
||||||
|
|
||||||
|
#ifndef SHERPA_ONNX_CSRC_OFFLINE_TRANSDUCER_GREEDY_SEARCH_NEMO_DECODER_H_
|
||||||
|
#define SHERPA_ONNX_CSRC_OFFLINE_TRANSDUCER_GREEDY_SEARCH_NEMO_DECODER_H_
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/offline-transducer-decoder.h"
|
||||||
|
#include "sherpa-onnx/csrc/offline-transducer-nemo-model.h"
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
class OfflineTransducerGreedySearchNeMoDecoder
|
||||||
|
: public OfflineTransducerDecoder {
|
||||||
|
public:
|
||||||
|
OfflineTransducerGreedySearchNeMoDecoder(OfflineTransducerNeMoModel *model,
|
||||||
|
float blank_penalty)
|
||||||
|
: model_(model), blank_penalty_(blank_penalty) {}
|
||||||
|
|
||||||
|
std::vector<OfflineTransducerDecoderResult> Decode(
|
||||||
|
Ort::Value encoder_out, Ort::Value encoder_out_length,
|
||||||
|
OfflineStream **ss = nullptr, int32_t n = 0) override;
|
||||||
|
|
||||||
|
private:
|
||||||
|
OfflineTransducerNeMoModel *model_; // Not owned
|
||||||
|
float blank_penalty_;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace sherpa_onnx
|
||||||
|
|
||||||
|
#endif // SHERPA_ONNX_CSRC_OFFLINE_TRANSDUCER_GREEDY_SEARCH_NEMO_DECODER_H_
|
||||||
301
sherpa-onnx/csrc/offline-transducer-nemo-model.cc
Normal file
301
sherpa-onnx/csrc/offline-transducer-nemo-model.cc
Normal file
@@ -0,0 +1,301 @@
|
|||||||
|
// sherpa-onnx/csrc/offline-transducer-nemo-model.cc
|
||||||
|
//
|
||||||
|
// Copyright (c) 2024 Xiaomi Corporation
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/offline-transducer-nemo-model.h"
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
|
#include <string>
|
||||||
|
#include <utility>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/macros.h"
|
||||||
|
#include "sherpa-onnx/csrc/offline-transducer-decoder.h"
|
||||||
|
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||||
|
#include "sherpa-onnx/csrc/session.h"
|
||||||
|
#include "sherpa-onnx/csrc/transpose.h"
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
class OfflineTransducerNeMoModel::Impl {
|
||||||
|
public:
|
||||||
|
explicit Impl(const OfflineModelConfig &config)
|
||||||
|
: config_(config),
|
||||||
|
env_(ORT_LOGGING_LEVEL_WARNING),
|
||||||
|
sess_opts_(GetSessionOptions(config)),
|
||||||
|
allocator_{} {
|
||||||
|
{
|
||||||
|
auto buf = ReadFile(config.transducer.encoder_filename);
|
||||||
|
InitEncoder(buf.data(), buf.size());
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
auto buf = ReadFile(config.transducer.decoder_filename);
|
||||||
|
InitDecoder(buf.data(), buf.size());
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
auto buf = ReadFile(config.transducer.joiner_filename);
|
||||||
|
InitJoiner(buf.data(), buf.size());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#if __ANDROID_API__ >= 9
|
||||||
|
Impl(AAssetManager *mgr, const OfflineModelConfig &config)
|
||||||
|
: config_(config),
|
||||||
|
env_(ORT_LOGGING_LEVEL_WARNING),
|
||||||
|
sess_opts_(GetSessionOptions(config)),
|
||||||
|
allocator_{} {
|
||||||
|
{
|
||||||
|
auto buf = ReadFile(mgr, config.transducer.encoder_filename);
|
||||||
|
InitEncoder(buf.data(), buf.size());
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
auto buf = ReadFile(mgr, config.transducer.decoder_filename);
|
||||||
|
InitDecoder(buf.data(), buf.size());
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
auto buf = ReadFile(mgr, config.transducer.joiner_filename);
|
||||||
|
InitJoiner(buf.data(), buf.size());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
std::vector<Ort::Value> RunEncoder(Ort::Value features,
|
||||||
|
Ort::Value features_length) {
|
||||||
|
// (B, T, C) -> (B, C, T)
|
||||||
|
features = Transpose12(allocator_, &features);
|
||||||
|
|
||||||
|
std::array<Ort::Value, 2> encoder_inputs = {std::move(features),
|
||||||
|
std::move(features_length)};
|
||||||
|
|
||||||
|
auto encoder_out = encoder_sess_->Run(
|
||||||
|
{}, encoder_input_names_ptr_.data(), encoder_inputs.data(),
|
||||||
|
encoder_inputs.size(), encoder_output_names_ptr_.data(),
|
||||||
|
encoder_output_names_ptr_.size());
|
||||||
|
|
||||||
|
return encoder_out;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::pair<Ort::Value, std::vector<Ort::Value>> RunDecoder(
|
||||||
|
Ort::Value targets, Ort::Value targets_length,
|
||||||
|
std::vector<Ort::Value> states) {
|
||||||
|
std::vector<Ort::Value> decoder_inputs;
|
||||||
|
decoder_inputs.reserve(2 + states.size());
|
||||||
|
|
||||||
|
decoder_inputs.push_back(std::move(targets));
|
||||||
|
decoder_inputs.push_back(std::move(targets_length));
|
||||||
|
|
||||||
|
for (auto &s : states) {
|
||||||
|
decoder_inputs.push_back(std::move(s));
|
||||||
|
}
|
||||||
|
|
||||||
|
auto decoder_out = decoder_sess_->Run(
|
||||||
|
{}, decoder_input_names_ptr_.data(), decoder_inputs.data(),
|
||||||
|
decoder_inputs.size(), decoder_output_names_ptr_.data(),
|
||||||
|
decoder_output_names_ptr_.size());
|
||||||
|
|
||||||
|
std::vector<Ort::Value> states_next;
|
||||||
|
states_next.reserve(states.size());
|
||||||
|
|
||||||
|
// decoder_out[0]: decoder_output
|
||||||
|
// decoder_out[1]: decoder_output_length
|
||||||
|
// decoder_out[2:] states_next
|
||||||
|
|
||||||
|
for (int32_t i = 0; i != states.size(); ++i) {
|
||||||
|
states_next.push_back(std::move(decoder_out[i + 2]));
|
||||||
|
}
|
||||||
|
|
||||||
|
// we discard decoder_out[1]
|
||||||
|
return {std::move(decoder_out[0]), std::move(states_next)};
|
||||||
|
}
|
||||||
|
|
||||||
|
Ort::Value RunJoiner(Ort::Value encoder_out, Ort::Value decoder_out) {
|
||||||
|
std::array<Ort::Value, 2> joiner_input = {std::move(encoder_out),
|
||||||
|
std::move(decoder_out)};
|
||||||
|
auto logit = joiner_sess_->Run({}, joiner_input_names_ptr_.data(),
|
||||||
|
joiner_input.data(), joiner_input.size(),
|
||||||
|
joiner_output_names_ptr_.data(),
|
||||||
|
joiner_output_names_ptr_.size());
|
||||||
|
|
||||||
|
return std::move(logit[0]);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<Ort::Value> GetDecoderInitStates(int32_t batch_size) const {
|
||||||
|
std::array<int64_t, 3> s0_shape{pred_rnn_layers_, batch_size, pred_hidden_};
|
||||||
|
Ort::Value s0 = Ort::Value::CreateTensor<float>(allocator_, s0_shape.data(),
|
||||||
|
s0_shape.size());
|
||||||
|
|
||||||
|
Fill<float>(&s0, 0);
|
||||||
|
|
||||||
|
std::array<int64_t, 3> s1_shape{pred_rnn_layers_, batch_size, pred_hidden_};
|
||||||
|
|
||||||
|
Ort::Value s1 = Ort::Value::CreateTensor<float>(allocator_, s1_shape.data(),
|
||||||
|
s1_shape.size());
|
||||||
|
|
||||||
|
Fill<float>(&s1, 0);
|
||||||
|
|
||||||
|
std::vector<Ort::Value> states;
|
||||||
|
|
||||||
|
states.reserve(2);
|
||||||
|
states.push_back(std::move(s0));
|
||||||
|
states.push_back(std::move(s1));
|
||||||
|
|
||||||
|
return states;
|
||||||
|
}
|
||||||
|
|
||||||
|
int32_t SubsamplingFactor() const { return subsampling_factor_; }
|
||||||
|
int32_t VocabSize() const { return vocab_size_; }
|
||||||
|
|
||||||
|
OrtAllocator *Allocator() const { return allocator_; }
|
||||||
|
|
||||||
|
std::string FeatureNormalizationMethod() const { return normalize_type_; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
void InitEncoder(void *model_data, size_t model_data_length) {
|
||||||
|
encoder_sess_ = std::make_unique<Ort::Session>(
|
||||||
|
env_, model_data, model_data_length, sess_opts_);
|
||||||
|
|
||||||
|
GetInputNames(encoder_sess_.get(), &encoder_input_names_,
|
||||||
|
&encoder_input_names_ptr_);
|
||||||
|
|
||||||
|
GetOutputNames(encoder_sess_.get(), &encoder_output_names_,
|
||||||
|
&encoder_output_names_ptr_);
|
||||||
|
|
||||||
|
// get meta data
|
||||||
|
Ort::ModelMetadata meta_data = encoder_sess_->GetModelMetadata();
|
||||||
|
if (config_.debug) {
|
||||||
|
std::ostringstream os;
|
||||||
|
os << "---encoder---\n";
|
||||||
|
PrintModelMetadata(os, meta_data);
|
||||||
|
SHERPA_ONNX_LOGE("%s\n", os.str().c_str());
|
||||||
|
}
|
||||||
|
|
||||||
|
Ort::AllocatorWithDefaultOptions allocator; // used in the macro below
|
||||||
|
SHERPA_ONNX_READ_META_DATA(vocab_size_, "vocab_size");
|
||||||
|
|
||||||
|
// need to increase by 1 since the blank token is not included in computing
|
||||||
|
// vocab_size in NeMo.
|
||||||
|
vocab_size_ += 1;
|
||||||
|
|
||||||
|
SHERPA_ONNX_READ_META_DATA(subsampling_factor_, "subsampling_factor");
|
||||||
|
SHERPA_ONNX_READ_META_DATA_STR(normalize_type_, "normalize_type");
|
||||||
|
SHERPA_ONNX_READ_META_DATA(pred_rnn_layers_, "pred_rnn_layers");
|
||||||
|
SHERPA_ONNX_READ_META_DATA(pred_hidden_, "pred_hidden");
|
||||||
|
|
||||||
|
if (normalize_type_ == "NA") {
|
||||||
|
normalize_type_ = "";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void InitDecoder(void *model_data, size_t model_data_length) {
|
||||||
|
decoder_sess_ = std::make_unique<Ort::Session>(
|
||||||
|
env_, model_data, model_data_length, sess_opts_);
|
||||||
|
|
||||||
|
GetInputNames(decoder_sess_.get(), &decoder_input_names_,
|
||||||
|
&decoder_input_names_ptr_);
|
||||||
|
|
||||||
|
GetOutputNames(decoder_sess_.get(), &decoder_output_names_,
|
||||||
|
&decoder_output_names_ptr_);
|
||||||
|
}
|
||||||
|
|
||||||
|
void InitJoiner(void *model_data, size_t model_data_length) {
|
||||||
|
joiner_sess_ = std::make_unique<Ort::Session>(
|
||||||
|
env_, model_data, model_data_length, sess_opts_);
|
||||||
|
|
||||||
|
GetInputNames(joiner_sess_.get(), &joiner_input_names_,
|
||||||
|
&joiner_input_names_ptr_);
|
||||||
|
|
||||||
|
GetOutputNames(joiner_sess_.get(), &joiner_output_names_,
|
||||||
|
&joiner_output_names_ptr_);
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
OfflineModelConfig config_;
|
||||||
|
Ort::Env env_;
|
||||||
|
Ort::SessionOptions sess_opts_;
|
||||||
|
Ort::AllocatorWithDefaultOptions allocator_;
|
||||||
|
|
||||||
|
std::unique_ptr<Ort::Session> encoder_sess_;
|
||||||
|
std::unique_ptr<Ort::Session> decoder_sess_;
|
||||||
|
std::unique_ptr<Ort::Session> joiner_sess_;
|
||||||
|
|
||||||
|
std::vector<std::string> encoder_input_names_;
|
||||||
|
std::vector<const char *> encoder_input_names_ptr_;
|
||||||
|
|
||||||
|
std::vector<std::string> encoder_output_names_;
|
||||||
|
std::vector<const char *> encoder_output_names_ptr_;
|
||||||
|
|
||||||
|
std::vector<std::string> decoder_input_names_;
|
||||||
|
std::vector<const char *> decoder_input_names_ptr_;
|
||||||
|
|
||||||
|
std::vector<std::string> decoder_output_names_;
|
||||||
|
std::vector<const char *> decoder_output_names_ptr_;
|
||||||
|
|
||||||
|
std::vector<std::string> joiner_input_names_;
|
||||||
|
std::vector<const char *> joiner_input_names_ptr_;
|
||||||
|
|
||||||
|
std::vector<std::string> joiner_output_names_;
|
||||||
|
std::vector<const char *> joiner_output_names_ptr_;
|
||||||
|
|
||||||
|
int32_t vocab_size_ = 0;
|
||||||
|
int32_t subsampling_factor_ = 8;
|
||||||
|
std::string normalize_type_;
|
||||||
|
int32_t pred_rnn_layers_ = -1;
|
||||||
|
int32_t pred_hidden_ = -1;
|
||||||
|
};
|
||||||
|
|
||||||
|
OfflineTransducerNeMoModel::OfflineTransducerNeMoModel(
|
||||||
|
const OfflineModelConfig &config)
|
||||||
|
: impl_(std::make_unique<Impl>(config)) {}
|
||||||
|
|
||||||
|
#if __ANDROID_API__ >= 9
|
||||||
|
OfflineTransducerNeMoModel::OfflineTransducerNeMoModel(
|
||||||
|
AAssetManager *mgr, const OfflineModelConfig &config)
|
||||||
|
: impl_(std::make_unique<Impl>(mgr, config)) {}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
OfflineTransducerNeMoModel::~OfflineTransducerNeMoModel() = default;
|
||||||
|
|
||||||
|
std::vector<Ort::Value> OfflineTransducerNeMoModel::RunEncoder(
|
||||||
|
Ort::Value features, Ort::Value features_length) const {
|
||||||
|
return impl_->RunEncoder(std::move(features), std::move(features_length));
|
||||||
|
}
|
||||||
|
|
||||||
|
std::pair<Ort::Value, std::vector<Ort::Value>>
|
||||||
|
OfflineTransducerNeMoModel::RunDecoder(Ort::Value targets,
|
||||||
|
Ort::Value targets_length,
|
||||||
|
std::vector<Ort::Value> states) const {
|
||||||
|
return impl_->RunDecoder(std::move(targets), std::move(targets_length),
|
||||||
|
std::move(states));
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<Ort::Value> OfflineTransducerNeMoModel::GetDecoderInitStates(
|
||||||
|
int32_t batch_size) const {
|
||||||
|
return impl_->GetDecoderInitStates(batch_size);
|
||||||
|
}
|
||||||
|
|
||||||
|
Ort::Value OfflineTransducerNeMoModel::RunJoiner(Ort::Value encoder_out,
|
||||||
|
Ort::Value decoder_out) const {
|
||||||
|
return impl_->RunJoiner(std::move(encoder_out), std::move(decoder_out));
|
||||||
|
}
|
||||||
|
|
||||||
|
int32_t OfflineTransducerNeMoModel::SubsamplingFactor() const {
|
||||||
|
return impl_->SubsamplingFactor();
|
||||||
|
}
|
||||||
|
|
||||||
|
int32_t OfflineTransducerNeMoModel::VocabSize() const {
|
||||||
|
return impl_->VocabSize();
|
||||||
|
}
|
||||||
|
|
||||||
|
OrtAllocator *OfflineTransducerNeMoModel::Allocator() const {
|
||||||
|
return impl_->Allocator();
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string OfflineTransducerNeMoModel::FeatureNormalizationMethod() const {
|
||||||
|
return impl_->FeatureNormalizationMethod();
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace sherpa_onnx
|
||||||
103
sherpa-onnx/csrc/offline-transducer-nemo-model.h
Normal file
103
sherpa-onnx/csrc/offline-transducer-nemo-model.h
Normal file
@@ -0,0 +1,103 @@
|
|||||||
|
// sherpa-onnx/csrc/offline-transducer-nemo-model.h
|
||||||
|
//
|
||||||
|
// Copyright (c) 2024 Xiaomi Corporation
|
||||||
|
#ifndef SHERPA_ONNX_CSRC_OFFLINE_TRANSDUCER_NEMO_MODEL_H_
|
||||||
|
#define SHERPA_ONNX_CSRC_OFFLINE_TRANSDUCER_NEMO_MODEL_H_
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <string>
|
||||||
|
#include <utility>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#if __ANDROID_API__ >= 9
|
||||||
|
#include "android/asset_manager.h"
|
||||||
|
#include "android/asset_manager_jni.h"
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||||
|
#include "sherpa-onnx/csrc/offline-model-config.h"
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
// see
|
||||||
|
// https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/asr/models/hybrid_rnnt_ctc_bpe_models.py#L40
|
||||||
|
// Its decoder is stateful, not stateless.
|
||||||
|
class OfflineTransducerNeMoModel {
|
||||||
|
public:
|
||||||
|
explicit OfflineTransducerNeMoModel(const OfflineModelConfig &config);
|
||||||
|
|
||||||
|
#if __ANDROID_API__ >= 9
|
||||||
|
OfflineTransducerNeMoModel(AAssetManager *mgr,
|
||||||
|
const OfflineModelConfig &config);
|
||||||
|
#endif
|
||||||
|
|
||||||
|
~OfflineTransducerNeMoModel();
|
||||||
|
|
||||||
|
/** Run the encoder.
|
||||||
|
*
|
||||||
|
* @param features A tensor of shape (N, T, C). It is changed in-place.
|
||||||
|
* @param features_length A 1-D tensor of shape (N,) containing number of
|
||||||
|
* valid frames in `features` before padding.
|
||||||
|
* Its dtype is int64_t.
|
||||||
|
*
|
||||||
|
* @return Return a vector containing:
|
||||||
|
* - encoder_out: A 3-D tensor of shape (N, T', encoder_dim)
|
||||||
|
* - encoder_out_length: A 1-D tensor of shape (N,) containing number
|
||||||
|
* of frames in `encoder_out` before padding.
|
||||||
|
*/
|
||||||
|
std::vector<Ort::Value> RunEncoder(Ort::Value features,
|
||||||
|
Ort::Value features_length) const;
|
||||||
|
|
||||||
|
/** Run the decoder network.
|
||||||
|
*
|
||||||
|
* @param targets A int32 tensor of shape (batch_size, 1)
|
||||||
|
* @param targets_length A int32 tensor of shape (batch_size,)
|
||||||
|
* @param states The states for the decoder model.
|
||||||
|
* @return Return a vector:
|
||||||
|
* - ans[0] is the decoder_out (a float tensor)
|
||||||
|
* - ans[1] is the decoder_out_length (a int32 tensor)
|
||||||
|
* - ans[2:] is the states_next
|
||||||
|
*/
|
||||||
|
std::pair<Ort::Value, std::vector<Ort::Value>> RunDecoder(
|
||||||
|
Ort::Value targets, Ort::Value targets_length,
|
||||||
|
std::vector<Ort::Value> states) const;
|
||||||
|
|
||||||
|
std::vector<Ort::Value> GetDecoderInitStates(int32_t batch_size) const;
|
||||||
|
|
||||||
|
/** Run the joint network.
|
||||||
|
*
|
||||||
|
* @param encoder_out Output of the encoder network.
|
||||||
|
* @param decoder_out Output of the decoder network.
|
||||||
|
* @return Return a tensor of shape (N, 1, 1, vocab_size) containing logits.
|
||||||
|
*/
|
||||||
|
Ort::Value RunJoiner(Ort::Value encoder_out, Ort::Value decoder_out) const;
|
||||||
|
|
||||||
|
/** Return the subsampling factor of the model.
|
||||||
|
*/
|
||||||
|
int32_t SubsamplingFactor() const;
|
||||||
|
|
||||||
|
int32_t VocabSize() const;
|
||||||
|
|
||||||
|
/** Return an allocator for allocating memory
|
||||||
|
*/
|
||||||
|
OrtAllocator *Allocator() const;
|
||||||
|
|
||||||
|
// Possible values:
|
||||||
|
// - per_feature
|
||||||
|
// - all_features (not implemented yet)
|
||||||
|
// - fixed_mean (not implemented)
|
||||||
|
// - fixed_std (not implemented)
|
||||||
|
// - or just leave it to empty
|
||||||
|
// See
|
||||||
|
// https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/asr/parts/preprocessing/features.py#L59
|
||||||
|
// for details
|
||||||
|
std::string FeatureNormalizationMethod() const;
|
||||||
|
|
||||||
|
private:
|
||||||
|
class Impl;
|
||||||
|
std::unique_ptr<Impl> impl_;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace sherpa_onnx
|
||||||
|
|
||||||
|
#endif // SHERPA_ONNX_CSRC_OFFLINE_TRANSDUCER_NEMO_MODEL_H_
|
||||||
@@ -223,8 +223,8 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl {
|
|||||||
|
|
||||||
private:
|
private:
|
||||||
void InitDecoder() {
|
void InitDecoder() {
|
||||||
if (!sym_.contains("<blk>") && !sym_.contains("<eps>") &&
|
if (!sym_.Contains("<blk>") && !sym_.Contains("<eps>") &&
|
||||||
!sym_.contains("<blank>")) {
|
!sym_.Contains("<blank>")) {
|
||||||
SHERPA_ONNX_LOGE(
|
SHERPA_ONNX_LOGE(
|
||||||
"We expect that tokens.txt contains "
|
"We expect that tokens.txt contains "
|
||||||
"the symbol <blk> or <eps> or <blank> and its ID.");
|
"the symbol <blk> or <eps> or <blank> and its ID.");
|
||||||
@@ -232,12 +232,12 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl {
|
|||||||
}
|
}
|
||||||
|
|
||||||
int32_t blank_id = 0;
|
int32_t blank_id = 0;
|
||||||
if (sym_.contains("<blk>")) {
|
if (sym_.Contains("<blk>")) {
|
||||||
blank_id = sym_["<blk>"];
|
blank_id = sym_["<blk>"];
|
||||||
} else if (sym_.contains("<eps>")) {
|
} else if (sym_.Contains("<eps>")) {
|
||||||
// for tdnn models of the yesno recipe from icefall
|
// for tdnn models of the yesno recipe from icefall
|
||||||
blank_id = sym_["<eps>"];
|
blank_id = sym_["<eps>"];
|
||||||
} else if (sym_.contains("<blank>")) {
|
} else if (sym_.Contains("<blank>")) {
|
||||||
// for WeNet CTC models
|
// for WeNet CTC models
|
||||||
blank_id = sym_["<blank>"];
|
blank_id = sym_["<blank>"];
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -87,7 +87,7 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
|
|||||||
model_(OnlineTransducerModel::Create(config.model_config)),
|
model_(OnlineTransducerModel::Create(config.model_config)),
|
||||||
sym_(config.model_config.tokens),
|
sym_(config.model_config.tokens),
|
||||||
endpoint_(config_.endpoint_config) {
|
endpoint_(config_.endpoint_config) {
|
||||||
if (sym_.contains("<unk>")) {
|
if (sym_.Contains("<unk>")) {
|
||||||
unk_id_ = sym_["<unk>"];
|
unk_id_ = sym_["<unk>"];
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -103,19 +103,13 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
|
|||||||
}
|
}
|
||||||
|
|
||||||
decoder_ = std::make_unique<OnlineTransducerModifiedBeamSearchDecoder>(
|
decoder_ = std::make_unique<OnlineTransducerModifiedBeamSearchDecoder>(
|
||||||
model_.get(),
|
model_.get(), lm_.get(), config_.max_active_paths,
|
||||||
lm_.get(),
|
config_.lm_config.scale, unk_id_, config_.blank_penalty,
|
||||||
config_.max_active_paths,
|
|
||||||
config_.lm_config.scale,
|
|
||||||
unk_id_,
|
|
||||||
config_.blank_penalty,
|
|
||||||
config_.temperature_scale);
|
config_.temperature_scale);
|
||||||
|
|
||||||
} else if (config.decoding_method == "greedy_search") {
|
} else if (config.decoding_method == "greedy_search") {
|
||||||
decoder_ = std::make_unique<OnlineTransducerGreedySearchDecoder>(
|
decoder_ = std::make_unique<OnlineTransducerGreedySearchDecoder>(
|
||||||
model_.get(),
|
model_.get(), unk_id_, config_.blank_penalty,
|
||||||
unk_id_,
|
|
||||||
config_.blank_penalty,
|
|
||||||
config_.temperature_scale);
|
config_.temperature_scale);
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
@@ -132,7 +126,7 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
|
|||||||
model_(OnlineTransducerModel::Create(mgr, config.model_config)),
|
model_(OnlineTransducerModel::Create(mgr, config.model_config)),
|
||||||
sym_(mgr, config.model_config.tokens),
|
sym_(mgr, config.model_config.tokens),
|
||||||
endpoint_(config_.endpoint_config) {
|
endpoint_(config_.endpoint_config) {
|
||||||
if (sym_.contains("<unk>")) {
|
if (sym_.Contains("<unk>")) {
|
||||||
unk_id_ = sym_["<unk>"];
|
unk_id_ = sym_["<unk>"];
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -151,19 +145,13 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
|
|||||||
}
|
}
|
||||||
|
|
||||||
decoder_ = std::make_unique<OnlineTransducerModifiedBeamSearchDecoder>(
|
decoder_ = std::make_unique<OnlineTransducerModifiedBeamSearchDecoder>(
|
||||||
model_.get(),
|
model_.get(), lm_.get(), config_.max_active_paths,
|
||||||
lm_.get(),
|
config_.lm_config.scale, unk_id_, config_.blank_penalty,
|
||||||
config_.max_active_paths,
|
|
||||||
config_.lm_config.scale,
|
|
||||||
unk_id_,
|
|
||||||
config_.blank_penalty,
|
|
||||||
config_.temperature_scale);
|
config_.temperature_scale);
|
||||||
|
|
||||||
} else if (config.decoding_method == "greedy_search") {
|
} else if (config.decoding_method == "greedy_search") {
|
||||||
decoder_ = std::make_unique<OnlineTransducerGreedySearchDecoder>(
|
decoder_ = std::make_unique<OnlineTransducerGreedySearchDecoder>(
|
||||||
model_.get(),
|
model_.get(), unk_id_, config_.blank_penalty,
|
||||||
unk_id_,
|
|
||||||
config_.blank_penalty,
|
|
||||||
config_.temperature_scale);
|
config_.temperature_scale);
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ namespace sherpa_onnx {
|
|||||||
* It returns v[dim0_start:dim0_end, dim1_start:dim1_end, :]
|
* It returns v[dim0_start:dim0_end, dim1_start:dim1_end, :]
|
||||||
*
|
*
|
||||||
* @param allocator
|
* @param allocator
|
||||||
* @param v A 2-D tensor. Its data type is T.
|
* @param v A 3-D tensor. Its data type is T.
|
||||||
* @param dim0_start Start index of the first dimension..
|
* @param dim0_start Start index of the first dimension..
|
||||||
* @param dim0_end End index of the first dimension..
|
* @param dim0_end End index of the first dimension..
|
||||||
* @param dim1_start Start index of the second dimension.
|
* @param dim1_start Start index of the second dimension.
|
||||||
|
|||||||
@@ -100,9 +100,9 @@ int32_t SymbolTable::operator[](const std::string &sym) const {
|
|||||||
return sym2id_.at(sym);
|
return sym2id_.at(sym);
|
||||||
}
|
}
|
||||||
|
|
||||||
bool SymbolTable::contains(int32_t id) const { return id2sym_.count(id) != 0; }
|
bool SymbolTable::Contains(int32_t id) const { return id2sym_.count(id) != 0; }
|
||||||
|
|
||||||
bool SymbolTable::contains(const std::string &sym) const {
|
bool SymbolTable::Contains(const std::string &sym) const {
|
||||||
return sym2id_.count(sym) != 0;
|
return sym2id_.count(sym) != 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -40,14 +40,16 @@ class SymbolTable {
|
|||||||
int32_t operator[](const std::string &sym) const;
|
int32_t operator[](const std::string &sym) const;
|
||||||
|
|
||||||
/// Return true if there is a symbol with the given ID.
|
/// Return true if there is a symbol with the given ID.
|
||||||
bool contains(int32_t id) const;
|
bool Contains(int32_t id) const;
|
||||||
|
|
||||||
/// Return true if there is a given symbol in the symbol table.
|
/// Return true if there is a given symbol in the symbol table.
|
||||||
bool contains(const std::string &sym) const;
|
bool Contains(const std::string &sym) const;
|
||||||
|
|
||||||
// for tokens.txt from Whisper
|
// for tokens.txt from Whisper
|
||||||
void ApplyBase64Decode();
|
void ApplyBase64Decode();
|
||||||
|
|
||||||
|
int32_t NumSymbols() const { return id2sym_.size(); }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void Init(std::istream &is);
|
void Init(std::istream &is);
|
||||||
|
|
||||||
|
|||||||
@@ -49,7 +49,7 @@ static bool EncodeBase(std::istream &is, const SymbolTable &symbol_table,
|
|||||||
word = word.replace(0, 3, " ");
|
word = word.replace(0, 3, " ");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (symbol_table.contains(word)) {
|
if (symbol_table.Contains(word)) {
|
||||||
int32_t id = symbol_table[word];
|
int32_t id = symbol_table[word];
|
||||||
tmp_ids.push_back(id);
|
tmp_ids.push_back(id);
|
||||||
} else {
|
} else {
|
||||||
|
|||||||
@@ -14,10 +14,10 @@ namespace sherpa_onnx {
|
|||||||
static void PybindOfflineRecognizerConfig(py::module *m) {
|
static void PybindOfflineRecognizerConfig(py::module *m) {
|
||||||
using PyClass = OfflineRecognizerConfig;
|
using PyClass = OfflineRecognizerConfig;
|
||||||
py::class_<PyClass>(*m, "OfflineRecognizerConfig")
|
py::class_<PyClass>(*m, "OfflineRecognizerConfig")
|
||||||
.def(py::init<const OfflineFeatureExtractorConfig &,
|
.def(py::init<const FeatureExtractorConfig &, const OfflineModelConfig &,
|
||||||
const OfflineModelConfig &, const OfflineLMConfig &,
|
const OfflineLMConfig &, const OfflineCtcFstDecoderConfig &,
|
||||||
const OfflineCtcFstDecoderConfig &, const std::string &,
|
const std::string &, int32_t, const std::string &, float,
|
||||||
int32_t, const std::string &, float, float>(),
|
float>(),
|
||||||
py::arg("feat_config"), py::arg("model_config"),
|
py::arg("feat_config"), py::arg("model_config"),
|
||||||
py::arg("lm_config") = OfflineLMConfig(),
|
py::arg("lm_config") = OfflineLMConfig(),
|
||||||
py::arg("ctc_fst_decoder_config") = OfflineCtcFstDecoderConfig(),
|
py::arg("ctc_fst_decoder_config") = OfflineCtcFstDecoderConfig(),
|
||||||
|
|||||||
@@ -25,6 +25,7 @@ Args:
|
|||||||
static void PybindOfflineRecognitionResult(py::module *m) { // NOLINT
|
static void PybindOfflineRecognitionResult(py::module *m) { // NOLINT
|
||||||
using PyClass = OfflineRecognitionResult;
|
using PyClass = OfflineRecognitionResult;
|
||||||
py::class_<PyClass>(*m, "OfflineRecognitionResult")
|
py::class_<PyClass>(*m, "OfflineRecognitionResult")
|
||||||
|
.def("__str__", &PyClass::AsJsonString)
|
||||||
.def_property_readonly(
|
.def_property_readonly(
|
||||||
"text",
|
"text",
|
||||||
[](const PyClass &self) -> py::str {
|
[](const PyClass &self) -> py::str {
|
||||||
@@ -37,18 +38,7 @@ static void PybindOfflineRecognitionResult(py::module *m) { // NOLINT
|
|||||||
"timestamps", [](const PyClass &self) { return self.timestamps; });
|
"timestamps", [](const PyClass &self) { return self.timestamps; });
|
||||||
}
|
}
|
||||||
|
|
||||||
static void PybindOfflineFeatureExtractorConfig(py::module *m) {
|
|
||||||
using PyClass = OfflineFeatureExtractorConfig;
|
|
||||||
py::class_<PyClass>(*m, "OfflineFeatureExtractorConfig")
|
|
||||||
.def(py::init<int32_t, int32_t>(), py::arg("sampling_rate") = 16000,
|
|
||||||
py::arg("feature_dim") = 80)
|
|
||||||
.def_readwrite("sampling_rate", &PyClass::sampling_rate)
|
|
||||||
.def_readwrite("feature_dim", &PyClass::feature_dim)
|
|
||||||
.def("__str__", &PyClass::ToString);
|
|
||||||
}
|
|
||||||
|
|
||||||
void PybindOfflineStream(py::module *m) {
|
void PybindOfflineStream(py::module *m) {
|
||||||
PybindOfflineFeatureExtractorConfig(m);
|
|
||||||
PybindOfflineRecognitionResult(m);
|
PybindOfflineRecognitionResult(m);
|
||||||
|
|
||||||
using PyClass = OfflineStream;
|
using PyClass = OfflineStream;
|
||||||
|
|||||||
@@ -4,8 +4,8 @@ from pathlib import Path
|
|||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
from _sherpa_onnx import (
|
from _sherpa_onnx import (
|
||||||
|
FeatureExtractorConfig,
|
||||||
OfflineCtcFstDecoderConfig,
|
OfflineCtcFstDecoderConfig,
|
||||||
OfflineFeatureExtractorConfig,
|
|
||||||
OfflineModelConfig,
|
OfflineModelConfig,
|
||||||
OfflineNemoEncDecCtcModelConfig,
|
OfflineNemoEncDecCtcModelConfig,
|
||||||
OfflineParaformerModelConfig,
|
OfflineParaformerModelConfig,
|
||||||
@@ -51,6 +51,7 @@ class OfflineRecognizer(object):
|
|||||||
blank_penalty: float = 0.0,
|
blank_penalty: float = 0.0,
|
||||||
debug: bool = False,
|
debug: bool = False,
|
||||||
provider: str = "cpu",
|
provider: str = "cpu",
|
||||||
|
model_type: str = "transducer",
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Please refer to
|
Please refer to
|
||||||
@@ -106,10 +107,10 @@ class OfflineRecognizer(object):
|
|||||||
num_threads=num_threads,
|
num_threads=num_threads,
|
||||||
debug=debug,
|
debug=debug,
|
||||||
provider=provider,
|
provider=provider,
|
||||||
model_type="transducer",
|
model_type=model_type,
|
||||||
)
|
)
|
||||||
|
|
||||||
feat_config = OfflineFeatureExtractorConfig(
|
feat_config = FeatureExtractorConfig(
|
||||||
sampling_rate=sample_rate,
|
sampling_rate=sample_rate,
|
||||||
feature_dim=feature_dim,
|
feature_dim=feature_dim,
|
||||||
)
|
)
|
||||||
@@ -182,7 +183,7 @@ class OfflineRecognizer(object):
|
|||||||
model_type="paraformer",
|
model_type="paraformer",
|
||||||
)
|
)
|
||||||
|
|
||||||
feat_config = OfflineFeatureExtractorConfig(
|
feat_config = FeatureExtractorConfig(
|
||||||
sampling_rate=sample_rate,
|
sampling_rate=sample_rate,
|
||||||
feature_dim=feature_dim,
|
feature_dim=feature_dim,
|
||||||
)
|
)
|
||||||
@@ -246,7 +247,7 @@ class OfflineRecognizer(object):
|
|||||||
model_type="nemo_ctc",
|
model_type="nemo_ctc",
|
||||||
)
|
)
|
||||||
|
|
||||||
feat_config = OfflineFeatureExtractorConfig(
|
feat_config = FeatureExtractorConfig(
|
||||||
sampling_rate=sample_rate,
|
sampling_rate=sample_rate,
|
||||||
feature_dim=feature_dim,
|
feature_dim=feature_dim,
|
||||||
)
|
)
|
||||||
@@ -326,7 +327,7 @@ class OfflineRecognizer(object):
|
|||||||
model_type="whisper",
|
model_type="whisper",
|
||||||
)
|
)
|
||||||
|
|
||||||
feat_config = OfflineFeatureExtractorConfig(
|
feat_config = FeatureExtractorConfig(
|
||||||
sampling_rate=16000,
|
sampling_rate=16000,
|
||||||
feature_dim=80,
|
feature_dim=80,
|
||||||
)
|
)
|
||||||
@@ -389,7 +390,7 @@ class OfflineRecognizer(object):
|
|||||||
model_type="tdnn",
|
model_type="tdnn",
|
||||||
)
|
)
|
||||||
|
|
||||||
feat_config = OfflineFeatureExtractorConfig(
|
feat_config = FeatureExtractorConfig(
|
||||||
sampling_rate=sample_rate,
|
sampling_rate=sample_rate,
|
||||||
feature_dim=feature_dim,
|
feature_dim=feature_dim,
|
||||||
)
|
)
|
||||||
@@ -453,7 +454,7 @@ class OfflineRecognizer(object):
|
|||||||
model_type="wenet_ctc",
|
model_type="wenet_ctc",
|
||||||
)
|
)
|
||||||
|
|
||||||
feat_config = OfflineFeatureExtractorConfig(
|
feat_config = FeatureExtractorConfig(
|
||||||
sampling_rate=sample_rate,
|
sampling_rate=sample_rate,
|
||||||
feature_dim=feature_dim,
|
feature_dim=feature_dim,
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user