Add C++ runtime for non-streaming faster conformer transducer from NeMo. (#854)
This commit is contained in:
99
.github/scripts/test-offline-transducer.sh
vendored
99
.github/scripts/test-offline-transducer.sh
vendored
@@ -13,6 +13,105 @@ echo "PATH: $PATH"
|
||||
|
||||
which $EXE
|
||||
|
||||
log "------------------------------------------------------------------------"
|
||||
log "Run Nemo fast conformer hybrid transducer ctc models (transducer branch)"
|
||||
log "------------------------------------------------------------------------"
|
||||
|
||||
url=https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-nemo-fast-conformer-transducer-be-de-en-es-fr-hr-it-pl-ru-uk-20k.tar.bz2
|
||||
name=$(basename $url)
|
||||
curl -SL -O $url
|
||||
tar xvf $name
|
||||
rm $name
|
||||
repo=$(basename -s .tar.bz2 $name)
|
||||
ls -lh $repo
|
||||
|
||||
log "test $repo"
|
||||
test_wavs=(
|
||||
de-german.wav
|
||||
es-spanish.wav
|
||||
hr-croatian.wav
|
||||
po-polish.wav
|
||||
uk-ukrainian.wav
|
||||
en-english.wav
|
||||
fr-french.wav
|
||||
it-italian.wav
|
||||
ru-russian.wav
|
||||
)
|
||||
for w in ${test_wavs[@]}; do
|
||||
time $EXE \
|
||||
--tokens=$repo/tokens.txt \
|
||||
--encoder=$repo/encoder.onnx \
|
||||
--decoder=$repo/decoder.onnx \
|
||||
--joiner=$repo/joiner.onnx \
|
||||
--debug=1 \
|
||||
$repo/test_wavs/$w
|
||||
done
|
||||
|
||||
rm -rf $repo
|
||||
|
||||
url=https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-nemo-fast-conformer-transducer-en-24500.tar.bz2
|
||||
name=$(basename $url)
|
||||
curl -SL -O $url
|
||||
tar xvf $name
|
||||
rm $name
|
||||
repo=$(basename -s .tar.bz2 $name)
|
||||
ls -lh $repo
|
||||
|
||||
log "Test $repo"
|
||||
|
||||
time $EXE \
|
||||
--tokens=$repo/tokens.txt \
|
||||
--encoder=$repo/encoder.onnx \
|
||||
--decoder=$repo/decoder.onnx \
|
||||
--joiner=$repo/joiner.onnx \
|
||||
--debug=1 \
|
||||
$repo/test_wavs/en-english.wav
|
||||
|
||||
rm -rf $repo
|
||||
|
||||
url=https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-nemo-fast-conformer-transducer-es-1424.tar.bz2
|
||||
name=$(basename $url)
|
||||
curl -SL -O $url
|
||||
tar xvf $name
|
||||
rm $name
|
||||
repo=$(basename -s .tar.bz2 $name)
|
||||
ls -lh $repo
|
||||
|
||||
log "test $repo"
|
||||
|
||||
time $EXE \
|
||||
--tokens=$repo/tokens.txt \
|
||||
--encoder=$repo/encoder.onnx \
|
||||
--decoder=$repo/decoder.onnx \
|
||||
--joiner=$repo/joiner.onnx \
|
||||
--debug=1 \
|
||||
$repo/test_wavs/es-spanish.wav
|
||||
|
||||
rm -rf $repo
|
||||
|
||||
url=https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-nemo-fast-conformer-transducer-en-de-es-fr-14288.tar.bz2
|
||||
name=$(basename $url)
|
||||
curl -SL -O $url
|
||||
tar xvf $name
|
||||
rm $name
|
||||
repo=$(basename -s .tar.bz2 $name)
|
||||
ls -lh $repo
|
||||
|
||||
log "Test $repo"
|
||||
|
||||
time $EXE \
|
||||
--tokens=$repo/tokens.txt \
|
||||
--encoder=$repo/encoder.onnx \
|
||||
--decoder=$repo/decoder.onnx \
|
||||
--joiner=$repo/joiner.onnx \
|
||||
--debug=1 \
|
||||
$repo/test_wavs/en-english.wav \
|
||||
$repo/test_wavs/de-german.wav \
|
||||
$repo/test_wavs/fr-french.wav \
|
||||
$repo/test_wavs/es-spanish.wav
|
||||
|
||||
rm -rf $repo
|
||||
|
||||
log "------------------------------------------------------------"
|
||||
log "Run Conformer transducer (English)"
|
||||
log "------------------------------------------------------------"
|
||||
|
||||
16
.github/workflows/linux.yaml
vendored
16
.github/workflows/linux.yaml
vendored
@@ -128,6 +128,14 @@ jobs:
|
||||
name: release-${{ matrix.build_type }}-with-shared-lib-${{ matrix.shared_lib }}-with-tts-${{ matrix.with_tts }}
|
||||
path: install/*
|
||||
|
||||
- name: Test offline transducer
|
||||
shell: bash
|
||||
run: |
|
||||
export PATH=$PWD/build/bin:$PATH
|
||||
export EXE=sherpa-onnx-offline
|
||||
|
||||
.github/scripts/test-offline-transducer.sh
|
||||
|
||||
- name: Test spoken language identification (C++ API)
|
||||
shell: bash
|
||||
run: |
|
||||
@@ -215,14 +223,6 @@ jobs:
|
||||
|
||||
.github/scripts/test-online-paraformer.sh
|
||||
|
||||
- name: Test offline transducer
|
||||
shell: bash
|
||||
run: |
|
||||
export PATH=$PWD/build/bin:$PATH
|
||||
export EXE=sherpa-onnx-offline
|
||||
|
||||
.github/scripts/test-offline-transducer.sh
|
||||
|
||||
- name: Test online transducer
|
||||
shell: bash
|
||||
run: |
|
||||
|
||||
16
.github/workflows/macos.yaml
vendored
16
.github/workflows/macos.yaml
vendored
@@ -107,6 +107,14 @@ jobs:
|
||||
otool -L build/bin/sherpa-onnx
|
||||
otool -l build/bin/sherpa-onnx
|
||||
|
||||
- name: Test offline transducer
|
||||
shell: bash
|
||||
run: |
|
||||
export PATH=$PWD/build/bin:$PATH
|
||||
export EXE=sherpa-onnx-offline
|
||||
|
||||
.github/scripts/test-offline-transducer.sh
|
||||
|
||||
- name: Test online CTC
|
||||
shell: bash
|
||||
run: |
|
||||
@@ -192,14 +200,6 @@ jobs:
|
||||
|
||||
.github/scripts/test-offline-ctc.sh
|
||||
|
||||
- name: Test offline transducer
|
||||
shell: bash
|
||||
run: |
|
||||
export PATH=$PWD/build/bin:$PATH
|
||||
export EXE=sherpa-onnx-offline
|
||||
|
||||
.github/scripts/test-offline-transducer.sh
|
||||
|
||||
- name: Test online transducer
|
||||
shell: bash
|
||||
run: |
|
||||
|
||||
1
.gitignore
vendored
1
.gitignore
vendored
@@ -104,3 +104,4 @@ sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01
|
||||
sherpa-onnx-ced-*
|
||||
node_modules
|
||||
package-lock.json
|
||||
sherpa-onnx-nemo-*
|
||||
|
||||
68
python-api-examples/offline-nemo-ctc-decode-files.py
Executable file
68
python-api-examples/offline-nemo-ctc-decode-files.py
Executable file
@@ -0,0 +1,68 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
"""
|
||||
This file shows how to use a non-streaming CTC model from NeMo
|
||||
to decode files.
|
||||
|
||||
Please download model files from
|
||||
https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models
|
||||
|
||||
|
||||
The example model supports 10 languages and it is converted from
|
||||
https://catalog.ngc.nvidia.com/orgs/nvidia/teams/nemo/models/stt_multilingual_fastconformer_hybrid_large_pc
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import sherpa_onnx
|
||||
import soundfile as sf
|
||||
|
||||
|
||||
def create_recognizer():
|
||||
model = "./sherpa-onnx-nemo-fast-conformer-ctc-be-de-en-es-fr-hr-it-pl-ru-uk-20k/model.onnx"
|
||||
tokens = "./sherpa-onnx-nemo-fast-conformer-ctc-be-de-en-es-fr-hr-it-pl-ru-uk-20k/tokens.txt"
|
||||
|
||||
test_wav = "./sherpa-onnx-nemo-fast-conformer-ctc-be-de-en-es-fr-hr-it-pl-ru-uk-20k/test_wavs/de-german.wav"
|
||||
# test_wav = "./sherpa-onnx-nemo-fast-conformer-ctc-be-de-en-es-fr-hr-it-pl-ru-uk-20k/test_wavs/en-english.wav"
|
||||
# test_wav = "./sherpa-onnx-nemo-fast-conformer-ctc-be-de-en-es-fr-hr-it-pl-ru-uk-20k/test_wavs/es-spanish.wav"
|
||||
# test_wav = "./sherpa-onnx-nemo-fast-conformer-ctc-be-de-en-es-fr-hr-it-pl-ru-uk-20k/test_wavs/fr-french.wav"
|
||||
# test_wav = "./sherpa-onnx-nemo-fast-conformer-ctc-be-de-en-es-fr-hr-it-pl-ru-uk-20k/test_wavs/hr-croatian.wav"
|
||||
# test_wav = "./sherpa-onnx-nemo-fast-conformer-ctc-be-de-en-es-fr-hr-it-pl-ru-uk-20k/test_wavs/it-italian.wav"
|
||||
# test_wav = "./sherpa-onnx-nemo-fast-conformer-ctc-be-de-en-es-fr-hr-it-pl-ru-uk-20k/test_wavs/po-polish.wav"
|
||||
# test_wav = "./sherpa-onnx-nemo-fast-conformer-ctc-be-de-en-es-fr-hr-it-pl-ru-uk-20k/test_wavs/ru-russian.wav"
|
||||
# test_wav = "./sherpa-onnx-nemo-fast-conformer-ctc-be-de-en-es-fr-hr-it-pl-ru-uk-20k/test_wavs/uk-ukrainian.wav"
|
||||
|
||||
if not Path(model).is_file() or not Path(test_wav).is_file():
|
||||
raise ValueError(
|
||||
"""Please download model files from
|
||||
https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models
|
||||
"""
|
||||
)
|
||||
return (
|
||||
sherpa_onnx.OfflineRecognizer.from_nemo_ctc(
|
||||
model=model,
|
||||
tokens=tokens,
|
||||
debug=True,
|
||||
),
|
||||
test_wav,
|
||||
)
|
||||
|
||||
|
||||
def main():
|
||||
recognizer, wave_filename = create_recognizer()
|
||||
|
||||
audio, sample_rate = sf.read(wave_filename, dtype="float32", always_2d=True)
|
||||
audio = audio[:, 0] # only use the first channel
|
||||
|
||||
# audio is a 1-D float32 numpy array normalized to the range [-1, 1]
|
||||
# sample_rate does not need to be 16000 Hz
|
||||
|
||||
stream = recognizer.create_stream()
|
||||
stream.accept_waveform(sample_rate, audio)
|
||||
recognizer.decode_stream(stream)
|
||||
print(wave_filename)
|
||||
print(stream.result)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
73
python-api-examples/offline-nemo-transducer-decode-files.py
Executable file
73
python-api-examples/offline-nemo-transducer-decode-files.py
Executable file
@@ -0,0 +1,73 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
"""
|
||||
This file shows how to use a non-streaming transducer model from NeMo
|
||||
to decode files.
|
||||
|
||||
Please download model files from
|
||||
https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models
|
||||
|
||||
|
||||
The example model supports 10 languages and it is converted from
|
||||
https://catalog.ngc.nvidia.com/orgs/nvidia/teams/nemo/models/stt_multilingual_fastconformer_hybrid_large_pc
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import sherpa_onnx
|
||||
import soundfile as sf
|
||||
|
||||
|
||||
def create_recognizer():
|
||||
encoder = "./sherpa-onnx-nemo-fast-conformer-transducer-be-de-en-es-fr-hr-it-pl-ru-uk-20k/encoder.onnx"
|
||||
decoder = "./sherpa-onnx-nemo-fast-conformer-transducer-be-de-en-es-fr-hr-it-pl-ru-uk-20k/decoder.onnx"
|
||||
joiner = "./sherpa-onnx-nemo-fast-conformer-transducer-be-de-en-es-fr-hr-it-pl-ru-uk-20k/joiner.onnx"
|
||||
tokens = "./sherpa-onnx-nemo-fast-conformer-transducer-be-de-en-es-fr-hr-it-pl-ru-uk-20k/tokens.txt"
|
||||
|
||||
test_wav = "./sherpa-onnx-nemo-fast-conformer-transducer-be-de-en-es-fr-hr-it-pl-ru-uk-20k/test_wavs/de-german.wav"
|
||||
# test_wav = "./sherpa-onnx-nemo-fast-conformer-transducer-be-de-en-es-fr-hr-it-pl-ru-uk-20k/test_wavs/en-english.wav"
|
||||
# test_wav = "./sherpa-onnx-nemo-fast-conformer-transducer-be-de-en-es-fr-hr-it-pl-ru-uk-20k/test_wavs/es-spanish.wav"
|
||||
# test_wav = "./sherpa-onnx-nemo-fast-conformer-transducer-be-de-en-es-fr-hr-it-pl-ru-uk-20k/test_wavs/fr-french.wav"
|
||||
# test_wav = "./sherpa-onnx-nemo-fast-conformer-transducer-be-de-en-es-fr-hr-it-pl-ru-uk-20k/test_wavs/hr-croatian.wav"
|
||||
# test_wav = "./sherpa-onnx-nemo-fast-conformer-transducer-be-de-en-es-fr-hr-it-pl-ru-uk-20k/test_wavs/it-italian.wav"
|
||||
# test_wav = "./sherpa-onnx-nemo-fast-conformer-transducer-be-de-en-es-fr-hr-it-pl-ru-uk-20k/test_wavs/po-polish.wav"
|
||||
# test_wav = "./sherpa-onnx-nemo-fast-conformer-transducer-be-de-en-es-fr-hr-it-pl-ru-uk-20k/test_wavs/ru-russian.wav"
|
||||
# test_wav = "./sherpa-onnx-nemo-fast-conformer-transducer-be-de-en-es-fr-hr-it-pl-ru-uk-20k/test_wavs/uk-ukrainian.wav"
|
||||
|
||||
if not Path(encoder).is_file() or not Path(test_wav).is_file():
|
||||
raise ValueError(
|
||||
"""Please download model files from
|
||||
https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models
|
||||
"""
|
||||
)
|
||||
return (
|
||||
sherpa_onnx.OfflineRecognizer.from_transducer(
|
||||
encoder=encoder,
|
||||
decoder=decoder,
|
||||
joiner=joiner,
|
||||
tokens=tokens,
|
||||
model_type="nemo_transducer",
|
||||
debug=True,
|
||||
),
|
||||
test_wav,
|
||||
)
|
||||
|
||||
|
||||
def main():
|
||||
recognizer, wave_filename = create_recognizer()
|
||||
|
||||
audio, sample_rate = sf.read(wave_filename, dtype="float32", always_2d=True)
|
||||
audio = audio[:, 0] # only use the first channel
|
||||
|
||||
# audio is a 1-D float32 numpy array normalized to the range [-1, 1]
|
||||
# sample_rate does not need to be 16000 Hz
|
||||
|
||||
stream = recognizer.create_stream()
|
||||
stream.accept_waveform(sample_rate, audio)
|
||||
recognizer.decode_stream(stream)
|
||||
print(wave_filename)
|
||||
print(stream.result)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -40,9 +40,11 @@ set(sources
|
||||
offline-tdnn-ctc-model.cc
|
||||
offline-tdnn-model-config.cc
|
||||
offline-transducer-greedy-search-decoder.cc
|
||||
offline-transducer-greedy-search-nemo-decoder.cc
|
||||
offline-transducer-model-config.cc
|
||||
offline-transducer-model.cc
|
||||
offline-transducer-modified-beam-search-decoder.cc
|
||||
offline-transducer-nemo-model.cc
|
||||
offline-wenet-ctc-model-config.cc
|
||||
offline-wenet-ctc-model.cc
|
||||
offline-whisper-greedy-search-decoder.cc
|
||||
|
||||
@@ -56,6 +56,19 @@ struct FeatureExtractorConfig {
|
||||
bool remove_dc_offset = true; // Subtract mean of wave before FFT.
|
||||
std::string window_type = "povey"; // e.g. Hamming window
|
||||
|
||||
// For models from NeMo
|
||||
// This option is not exposed and is set internally when loading models.
|
||||
// Possible values:
|
||||
// - per_feature
|
||||
// - all_features (not implemented yet)
|
||||
// - fixed_mean (not implemented)
|
||||
// - fixed_std (not implemented)
|
||||
// - or just leave it to empty
|
||||
// See
|
||||
// https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/asr/parts/preprocessing/features.py#L59
|
||||
// for details
|
||||
std::string nemo_normalize_type;
|
||||
|
||||
std::string ToString() const;
|
||||
|
||||
void Register(ParseOptions *po);
|
||||
|
||||
@@ -68,7 +68,7 @@ class KeywordSpotterTransducerImpl : public KeywordSpotterImpl {
|
||||
: config_(config),
|
||||
model_(OnlineTransducerModel::Create(config.model_config)),
|
||||
sym_(config.model_config.tokens) {
|
||||
if (sym_.contains("<unk>")) {
|
||||
if (sym_.Contains("<unk>")) {
|
||||
unk_id_ = sym_["<unk>"];
|
||||
}
|
||||
|
||||
@@ -87,7 +87,7 @@ class KeywordSpotterTransducerImpl : public KeywordSpotterImpl {
|
||||
: config_(config),
|
||||
model_(OnlineTransducerModel::Create(mgr, config.model_config)),
|
||||
sym_(mgr, config.model_config.tokens) {
|
||||
if (sym_.contains("<unk>")) {
|
||||
if (sym_.Contains("<unk>")) {
|
||||
unk_id_ = sym_["<unk>"];
|
||||
}
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
// sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model.cc
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
// Copyright (c) 2023-2024 Xiaomi Corporation
|
||||
|
||||
#include "sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model.h"
|
||||
|
||||
|
||||
@@ -38,7 +38,7 @@ static OfflineRecognitionResult Convert(const OfflineCtcDecoderResult &src,
|
||||
std::string text;
|
||||
|
||||
for (int32_t i = 0; i != src.tokens.size(); ++i) {
|
||||
if (sym_table.contains("SIL") && src.tokens[i] == sym_table["SIL"]) {
|
||||
if (sym_table.Contains("SIL") && src.tokens[i] == sym_table["SIL"]) {
|
||||
// tdnn models from yesno have a SIL token, we should remove it.
|
||||
continue;
|
||||
}
|
||||
@@ -103,9 +103,9 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl {
|
||||
decoder_ = std::make_unique<OfflineCtcFstDecoder>(
|
||||
config_.ctc_fst_decoder_config);
|
||||
} else if (config_.decoding_method == "greedy_search") {
|
||||
if (!symbol_table_.contains("<blk>") &&
|
||||
!symbol_table_.contains("<eps>") &&
|
||||
!symbol_table_.contains("<blank>")) {
|
||||
if (!symbol_table_.Contains("<blk>") &&
|
||||
!symbol_table_.Contains("<eps>") &&
|
||||
!symbol_table_.Contains("<blank>")) {
|
||||
SHERPA_ONNX_LOGE(
|
||||
"We expect that tokens.txt contains "
|
||||
"the symbol <blk> or <eps> or <blank> and its ID.");
|
||||
@@ -113,12 +113,12 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl {
|
||||
}
|
||||
|
||||
int32_t blank_id = 0;
|
||||
if (symbol_table_.contains("<blk>")) {
|
||||
if (symbol_table_.Contains("<blk>")) {
|
||||
blank_id = symbol_table_["<blk>"];
|
||||
} else if (symbol_table_.contains("<eps>")) {
|
||||
} else if (symbol_table_.Contains("<eps>")) {
|
||||
// for tdnn models of the yesno recipe from icefall
|
||||
blank_id = symbol_table_["<eps>"];
|
||||
} else if (symbol_table_.contains("<blank>")) {
|
||||
} else if (symbol_table_.Contains("<blank>")) {
|
||||
// for Wenet CTC models
|
||||
blank_id = symbol_table_["<blank>"];
|
||||
}
|
||||
|
||||
@@ -11,6 +11,7 @@
|
||||
#include "sherpa-onnx/csrc/offline-recognizer-ctc-impl.h"
|
||||
#include "sherpa-onnx/csrc/offline-recognizer-paraformer-impl.h"
|
||||
#include "sherpa-onnx/csrc/offline-recognizer-transducer-impl.h"
|
||||
#include "sherpa-onnx/csrc/offline-recognizer-transducer-nemo-impl.h"
|
||||
#include "sherpa-onnx/csrc/offline-recognizer-whisper-impl.h"
|
||||
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||
#include "sherpa-onnx/csrc/text-utils.h"
|
||||
@@ -23,6 +24,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
|
||||
const auto &model_type = config.model_config.model_type;
|
||||
if (model_type == "transducer") {
|
||||
return std::make_unique<OfflineRecognizerTransducerImpl>(config);
|
||||
} else if (model_type == "nemo_transducer") {
|
||||
return std::make_unique<OfflineRecognizerTransducerNeMoImpl>(config);
|
||||
} else if (model_type == "paraformer") {
|
||||
return std::make_unique<OfflineRecognizerParaformerImpl>(config);
|
||||
} else if (model_type == "nemo_ctc" || model_type == "tdnn" ||
|
||||
@@ -122,6 +125,12 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
|
||||
return std::make_unique<OfflineRecognizerParaformerImpl>(config);
|
||||
}
|
||||
|
||||
if (model_type == "EncDecHybridRNNTCTCBPEModel" &&
|
||||
!config.model_config.transducer.decoder_filename.empty() &&
|
||||
!config.model_config.transducer.joiner_filename.empty()) {
|
||||
return std::make_unique<OfflineRecognizerTransducerNeMoImpl>(config);
|
||||
}
|
||||
|
||||
if (model_type == "EncDecCTCModelBPE" ||
|
||||
model_type == "EncDecHybridRNNTCTCBPEModel" || model_type == "tdnn" ||
|
||||
model_type == "zipformer2_ctc" || model_type == "wenet_ctc") {
|
||||
@@ -155,6 +164,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
|
||||
const auto &model_type = config.model_config.model_type;
|
||||
if (model_type == "transducer") {
|
||||
return std::make_unique<OfflineRecognizerTransducerImpl>(mgr, config);
|
||||
} else if (model_type == "nemo_transducer") {
|
||||
return std::make_unique<OfflineRecognizerTransducerNeMoImpl>(mgr, config);
|
||||
} else if (model_type == "paraformer") {
|
||||
return std::make_unique<OfflineRecognizerParaformerImpl>(mgr, config);
|
||||
} else if (model_type == "nemo_ctc" || model_type == "tdnn" ||
|
||||
@@ -254,6 +265,12 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
|
||||
return std::make_unique<OfflineRecognizerParaformerImpl>(mgr, config);
|
||||
}
|
||||
|
||||
if (model_type == "EncDecHybridRNNTCTCBPEModel" &&
|
||||
!config.model_config.transducer.decoder_filename.empty() &&
|
||||
!config.model_config.transducer.joiner_filename.empty()) {
|
||||
return std::make_unique<OfflineRecognizerTransducerNeMoImpl>(mgr, config);
|
||||
}
|
||||
|
||||
if (model_type == "EncDecCTCModelBPE" ||
|
||||
model_type == "EncDecHybridRNNTCTCBPEModel" || model_type == "tdnn" ||
|
||||
model_type == "zipformer2_ctc" || model_type == "wenet_ctc") {
|
||||
|
||||
182
sherpa-onnx/csrc/offline-recognizer-transducer-nemo-impl.h
Normal file
182
sherpa-onnx/csrc/offline-recognizer-transducer-nemo-impl.h
Normal file
@@ -0,0 +1,182 @@
|
||||
// sherpa-onnx/csrc/offline-recognizer-transducer-nemo-impl.h
|
||||
//
|
||||
// Copyright (c) 2022-2024 Xiaomi Corporation
|
||||
|
||||
#ifndef SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_TRANSDUCER_NEMO_IMPL_H_
|
||||
#define SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_TRANSDUCER_NEMO_IMPL_H_
|
||||
|
||||
#include <fstream>
|
||||
#include <ios>
|
||||
#include <memory>
|
||||
#include <regex> // NOLINT
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
#include "android/asset_manager.h"
|
||||
#include "android/asset_manager_jni.h"
|
||||
#endif
|
||||
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/csrc/offline-recognizer-impl.h"
|
||||
#include "sherpa-onnx/csrc/offline-recognizer.h"
|
||||
#include "sherpa-onnx/csrc/offline-transducer-greedy-search-nemo-decoder.h"
|
||||
#include "sherpa-onnx/csrc/offline-transducer-nemo-model.h"
|
||||
#include "sherpa-onnx/csrc/pad-sequence.h"
|
||||
#include "sherpa-onnx/csrc/symbol-table.h"
|
||||
#include "sherpa-onnx/csrc/transpose.h"
|
||||
#include "sherpa-onnx/csrc/utils.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
// defined in ./offline-recognizer-transducer-impl.h
|
||||
OfflineRecognitionResult Convert(const OfflineTransducerDecoderResult &src,
|
||||
const SymbolTable &sym_table,
|
||||
int32_t frame_shift_ms,
|
||||
int32_t subsampling_factor);
|
||||
|
||||
class OfflineRecognizerTransducerNeMoImpl : public OfflineRecognizerImpl {
|
||||
public:
|
||||
explicit OfflineRecognizerTransducerNeMoImpl(
|
||||
const OfflineRecognizerConfig &config)
|
||||
: config_(config),
|
||||
symbol_table_(config_.model_config.tokens),
|
||||
model_(std::make_unique<OfflineTransducerNeMoModel>(
|
||||
config_.model_config)) {
|
||||
if (config_.decoding_method == "greedy_search") {
|
||||
decoder_ = std::make_unique<OfflineTransducerGreedySearchNeMoDecoder>(
|
||||
model_.get(), config_.blank_penalty);
|
||||
} else {
|
||||
SHERPA_ONNX_LOGE("Unsupported decoding method: %s",
|
||||
config_.decoding_method.c_str());
|
||||
exit(-1);
|
||||
}
|
||||
PostInit();
|
||||
}
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
explicit OfflineRecognizerTransducerNeMoImpl(
|
||||
AAssetManager *mgr, const OfflineRecognizerConfig &config)
|
||||
: config_(config),
|
||||
symbol_table_(mgr, config_.model_config.tokens),
|
||||
model_(std::make_unique<OfflineTransducerNeMoModel>(
|
||||
mgr, config_.model_config)) {
|
||||
if (config_.decoding_method == "greedy_search") {
|
||||
decoder_ = std::make_unique<OfflineTransducerGreedySearchNeMoDecoder>(
|
||||
model_.get(), config_.blank_penalty);
|
||||
} else {
|
||||
SHERPA_ONNX_LOGE("Unsupported decoding method: %s",
|
||||
config_.decoding_method.c_str());
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
PostInit();
|
||||
}
|
||||
#endif
|
||||
|
||||
std::unique_ptr<OfflineStream> CreateStream() const override {
|
||||
return std::make_unique<OfflineStream>(config_.feat_config);
|
||||
}
|
||||
|
||||
void DecodeStreams(OfflineStream **ss, int32_t n) const override {
|
||||
auto memory_info =
|
||||
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
|
||||
|
||||
int32_t feat_dim = ss[0]->FeatureDim();
|
||||
|
||||
std::vector<Ort::Value> features;
|
||||
|
||||
features.reserve(n);
|
||||
|
||||
std::vector<std::vector<float>> features_vec(n);
|
||||
std::vector<int64_t> features_length_vec(n);
|
||||
for (int32_t i = 0; i != n; ++i) {
|
||||
auto f = ss[i]->GetFrames();
|
||||
int32_t num_frames = f.size() / feat_dim;
|
||||
|
||||
features_length_vec[i] = num_frames;
|
||||
features_vec[i] = std::move(f);
|
||||
|
||||
std::array<int64_t, 2> shape = {num_frames, feat_dim};
|
||||
|
||||
Ort::Value x = Ort::Value::CreateTensor(
|
||||
memory_info, features_vec[i].data(), features_vec[i].size(),
|
||||
shape.data(), shape.size());
|
||||
features.push_back(std::move(x));
|
||||
}
|
||||
|
||||
std::vector<const Ort::Value *> features_pointer(n);
|
||||
for (int32_t i = 0; i != n; ++i) {
|
||||
features_pointer[i] = &features[i];
|
||||
}
|
||||
|
||||
std::array<int64_t, 1> features_length_shape = {n};
|
||||
Ort::Value x_length = Ort::Value::CreateTensor(
|
||||
memory_info, features_length_vec.data(), n,
|
||||
features_length_shape.data(), features_length_shape.size());
|
||||
|
||||
Ort::Value x = PadSequence(model_->Allocator(), features_pointer, 0);
|
||||
|
||||
auto t = model_->RunEncoder(std::move(x), std::move(x_length));
|
||||
// t[0] encoder_out, float tensor, (batch_size, dim, T)
|
||||
// t[1] encoder_out_length, int64 tensor, (batch_size,)
|
||||
|
||||
Ort::Value encoder_out = Transpose12(model_->Allocator(), &t[0]);
|
||||
|
||||
auto results = decoder_->Decode(std::move(encoder_out), std::move(t[1]));
|
||||
|
||||
int32_t frame_shift_ms = 10;
|
||||
for (int32_t i = 0; i != n; ++i) {
|
||||
auto r = Convert(results[i], symbol_table_, frame_shift_ms,
|
||||
model_->SubsamplingFactor());
|
||||
|
||||
ss[i]->SetResult(r);
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
void PostInit() {
|
||||
config_.feat_config.nemo_normalize_type =
|
||||
model_->FeatureNormalizationMethod();
|
||||
|
||||
config_.feat_config.low_freq = 0;
|
||||
// config_.feat_config.high_freq = 8000;
|
||||
config_.feat_config.is_librosa = true;
|
||||
config_.feat_config.remove_dc_offset = false;
|
||||
// config_.feat_config.window_type = "hann";
|
||||
config_.feat_config.dither = 0;
|
||||
config_.feat_config.nemo_normalize_type =
|
||||
model_->FeatureNormalizationMethod();
|
||||
|
||||
int32_t vocab_size = model_->VocabSize();
|
||||
|
||||
// check the blank ID
|
||||
if (!symbol_table_.Contains("<blk>")) {
|
||||
SHERPA_ONNX_LOGE("tokens.txt does not include the blank token <blk>");
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
if (symbol_table_["<blk>"] != vocab_size - 1) {
|
||||
SHERPA_ONNX_LOGE("<blk> is not the last token!");
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
if (symbol_table_.NumSymbols() != vocab_size) {
|
||||
SHERPA_ONNX_LOGE("number of lines in tokens.txt %d != %d (vocab_size)",
|
||||
symbol_table_.NumSymbols(), vocab_size);
|
||||
exit(-1);
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
OfflineRecognizerConfig config_;
|
||||
SymbolTable symbol_table_;
|
||||
std::unique_ptr<OfflineTransducerNeMoModel> model_;
|
||||
std::unique_ptr<OfflineTransducerDecoder> decoder_;
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_TRANSDUCER_NEMO_IMPL_H_
|
||||
@@ -35,7 +35,7 @@ static OfflineRecognitionResult Convert(const OfflineWhisperDecoderResult &src,
|
||||
|
||||
std::string text;
|
||||
for (auto i : src.tokens) {
|
||||
if (!sym_table.contains(i)) {
|
||||
if (!sym_table.Contains(i)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
#include "android/asset_manager_jni.h"
|
||||
#endif
|
||||
|
||||
#include "sherpa-onnx/csrc/features.h"
|
||||
#include "sherpa-onnx/csrc/offline-ctc-fst-decoder-config.h"
|
||||
#include "sherpa-onnx/csrc/offline-lm-config.h"
|
||||
#include "sherpa-onnx/csrc/offline-model-config.h"
|
||||
@@ -26,7 +27,7 @@ namespace sherpa_onnx {
|
||||
struct OfflineRecognitionResult;
|
||||
|
||||
struct OfflineRecognizerConfig {
|
||||
OfflineFeatureExtractorConfig feat_config;
|
||||
FeatureExtractorConfig feat_config;
|
||||
OfflineModelConfig model_config;
|
||||
OfflineLMConfig lm_config;
|
||||
OfflineCtcFstDecoderConfig ctc_fst_decoder_config;
|
||||
@@ -44,7 +45,7 @@ struct OfflineRecognizerConfig {
|
||||
|
||||
OfflineRecognizerConfig() = default;
|
||||
OfflineRecognizerConfig(
|
||||
const OfflineFeatureExtractorConfig &feat_config,
|
||||
const FeatureExtractorConfig &feat_config,
|
||||
const OfflineModelConfig &model_config, const OfflineLMConfig &lm_config,
|
||||
const OfflineCtcFstDecoderConfig &ctc_fst_decoder_config,
|
||||
const std::string &decoding_method, int32_t max_active_paths,
|
||||
|
||||
@@ -52,42 +52,25 @@ static void ComputeMeanAndInvStd(const float *p, int32_t num_rows,
|
||||
}
|
||||
}
|
||||
|
||||
void OfflineFeatureExtractorConfig::Register(ParseOptions *po) {
|
||||
po->Register("sample-rate", &sampling_rate,
|
||||
"Sampling rate of the input waveform. "
|
||||
"Note: You can have a different "
|
||||
"sample rate for the input waveform. We will do resampling "
|
||||
"inside the feature extractor");
|
||||
|
||||
po->Register("feat-dim", &feature_dim,
|
||||
"Feature dimension. Must match the one expected by the model.");
|
||||
}
|
||||
|
||||
std::string OfflineFeatureExtractorConfig::ToString() const {
|
||||
std::ostringstream os;
|
||||
|
||||
os << "OfflineFeatureExtractorConfig(";
|
||||
os << "sampling_rate=" << sampling_rate << ", ";
|
||||
os << "feature_dim=" << feature_dim << ")";
|
||||
|
||||
return os.str();
|
||||
}
|
||||
|
||||
class OfflineStream::Impl {
|
||||
public:
|
||||
explicit Impl(const OfflineFeatureExtractorConfig &config,
|
||||
explicit Impl(const FeatureExtractorConfig &config,
|
||||
ContextGraphPtr context_graph)
|
||||
: config_(config), context_graph_(context_graph) {
|
||||
opts_.frame_opts.dither = 0;
|
||||
opts_.frame_opts.snip_edges = false;
|
||||
opts_.frame_opts.dither = config.dither;
|
||||
opts_.frame_opts.snip_edges = config.snip_edges;
|
||||
opts_.frame_opts.samp_freq = config.sampling_rate;
|
||||
opts_.frame_opts.frame_shift_ms = config.frame_shift_ms;
|
||||
opts_.frame_opts.frame_length_ms = config.frame_length_ms;
|
||||
opts_.frame_opts.remove_dc_offset = config.remove_dc_offset;
|
||||
opts_.frame_opts.window_type = config.window_type;
|
||||
|
||||
opts_.mel_opts.num_bins = config.feature_dim;
|
||||
|
||||
// Please see
|
||||
// https://github.com/lhotse-speech/lhotse/blob/master/lhotse/features/fbank.py#L27
|
||||
// and
|
||||
// https://github.com/k2-fsa/sherpa-onnx/issues/514
|
||||
opts_.mel_opts.high_freq = -400;
|
||||
opts_.mel_opts.high_freq = config.high_freq;
|
||||
opts_.mel_opts.low_freq = config.low_freq;
|
||||
|
||||
opts_.mel_opts.is_librosa = config.is_librosa;
|
||||
|
||||
fbank_ = std::make_unique<knf::OnlineFbank>(opts_);
|
||||
}
|
||||
@@ -237,7 +220,7 @@ class OfflineStream::Impl {
|
||||
}
|
||||
|
||||
private:
|
||||
OfflineFeatureExtractorConfig config_;
|
||||
FeatureExtractorConfig config_;
|
||||
std::unique_ptr<knf::OnlineFbank> fbank_;
|
||||
std::unique_ptr<knf::OnlineWhisperFbank> whisper_fbank_;
|
||||
knf::FbankOptions opts_;
|
||||
@@ -245,9 +228,8 @@ class OfflineStream::Impl {
|
||||
ContextGraphPtr context_graph_;
|
||||
};
|
||||
|
||||
OfflineStream::OfflineStream(
|
||||
const OfflineFeatureExtractorConfig &config /*= {}*/,
|
||||
ContextGraphPtr context_graph /*= nullptr*/)
|
||||
OfflineStream::OfflineStream(const FeatureExtractorConfig &config /*= {}*/,
|
||||
ContextGraphPtr context_graph /*= nullptr*/)
|
||||
: impl_(std::make_unique<Impl>(config, context_graph)) {}
|
||||
|
||||
OfflineStream::OfflineStream(WhisperTag tag)
|
||||
|
||||
@@ -11,6 +11,7 @@
|
||||
#include <vector>
|
||||
|
||||
#include "sherpa-onnx/csrc/context-graph.h"
|
||||
#include "sherpa-onnx/csrc/features.h"
|
||||
#include "sherpa-onnx/csrc/parse-options.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
@@ -32,46 +33,12 @@ struct OfflineRecognitionResult {
|
||||
std::string AsJsonString() const;
|
||||
};
|
||||
|
||||
struct OfflineFeatureExtractorConfig {
|
||||
// Sampling rate used by the feature extractor. If it is different from
|
||||
// the sampling rate of the input waveform, we will do resampling inside.
|
||||
int32_t sampling_rate = 16000;
|
||||
|
||||
// Feature dimension
|
||||
int32_t feature_dim = 80;
|
||||
|
||||
// Set internally by some models, e.g., paraformer and wenet CTC models set
|
||||
// it to false.
|
||||
// This parameter is not exposed to users from the commandline
|
||||
// If true, the feature extractor expects inputs to be normalized to
|
||||
// the range [-1, 1].
|
||||
// If false, we will multiply the inputs by 32768
|
||||
bool normalize_samples = true;
|
||||
|
||||
// For models from NeMo
|
||||
// This option is not exposed and is set internally when loading models.
|
||||
// Possible values:
|
||||
// - per_feature
|
||||
// - all_features (not implemented yet)
|
||||
// - fixed_mean (not implemented)
|
||||
// - fixed_std (not implemented)
|
||||
// - or just leave it to empty
|
||||
// See
|
||||
// https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/asr/parts/preprocessing/features.py#L59
|
||||
// for details
|
||||
std::string nemo_normalize_type;
|
||||
|
||||
std::string ToString() const;
|
||||
|
||||
void Register(ParseOptions *po);
|
||||
};
|
||||
|
||||
struct WhisperTag {};
|
||||
struct CEDTag {};
|
||||
|
||||
class OfflineStream {
|
||||
public:
|
||||
explicit OfflineStream(const OfflineFeatureExtractorConfig &config = {},
|
||||
explicit OfflineStream(const FeatureExtractorConfig &config = {},
|
||||
ContextGraphPtr context_graph = {});
|
||||
|
||||
explicit OfflineStream(WhisperTag tag);
|
||||
|
||||
@@ -14,8 +14,8 @@ namespace sherpa_onnx {
|
||||
|
||||
class OfflineTransducerGreedySearchDecoder : public OfflineTransducerDecoder {
|
||||
public:
|
||||
explicit OfflineTransducerGreedySearchDecoder(OfflineTransducerModel *model,
|
||||
float blank_penalty)
|
||||
OfflineTransducerGreedySearchDecoder(OfflineTransducerModel *model,
|
||||
float blank_penalty)
|
||||
: model_(model), blank_penalty_(blank_penalty) {}
|
||||
|
||||
std::vector<OfflineTransducerDecoderResult> Decode(
|
||||
|
||||
@@ -0,0 +1,117 @@
|
||||
// sherpa-onnx/csrc/offline-transducer-greedy-search-nemo-decoder.cc
|
||||
//
|
||||
// Copyright (c) 2024 Xiaomi Corporation
|
||||
|
||||
#include "sherpa-onnx/csrc/offline-transducer-greedy-search-nemo-decoder.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <iterator>
|
||||
#include <utility>
|
||||
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
static std::pair<Ort::Value, Ort::Value> BuildDecoderInput(
|
||||
int32_t token, OrtAllocator *allocator) {
|
||||
std::array<int64_t, 2> shape{1, 1};
|
||||
|
||||
Ort::Value decoder_input =
|
||||
Ort::Value::CreateTensor<int32_t>(allocator, shape.data(), shape.size());
|
||||
|
||||
std::array<int64_t, 1> length_shape{1};
|
||||
Ort::Value decoder_input_length = Ort::Value::CreateTensor<int32_t>(
|
||||
allocator, length_shape.data(), length_shape.size());
|
||||
|
||||
int32_t *p = decoder_input.GetTensorMutableData<int32_t>();
|
||||
|
||||
int32_t *p_length = decoder_input_length.GetTensorMutableData<int32_t>();
|
||||
|
||||
p[0] = token;
|
||||
|
||||
p_length[0] = 1;
|
||||
|
||||
return {std::move(decoder_input), std::move(decoder_input_length)};
|
||||
}
|
||||
|
||||
static OfflineTransducerDecoderResult DecodeOne(
|
||||
const float *p, int32_t num_rows, int32_t num_cols,
|
||||
OfflineTransducerNeMoModel *model, float blank_penalty) {
|
||||
auto memory_info =
|
||||
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
|
||||
|
||||
OfflineTransducerDecoderResult ans;
|
||||
|
||||
int32_t vocab_size = model->VocabSize();
|
||||
int32_t blank_id = vocab_size - 1;
|
||||
|
||||
auto decoder_input_pair = BuildDecoderInput(blank_id, model->Allocator());
|
||||
|
||||
std::pair<Ort::Value, std::vector<Ort::Value>> decoder_output_pair =
|
||||
model->RunDecoder(std::move(decoder_input_pair.first),
|
||||
std::move(decoder_input_pair.second),
|
||||
model->GetDecoderInitStates(1));
|
||||
|
||||
std::array<int64_t, 3> encoder_shape{1, num_cols, 1};
|
||||
|
||||
for (int32_t t = 0; t != num_rows; ++t) {
|
||||
Ort::Value cur_encoder_out = Ort::Value::CreateTensor(
|
||||
memory_info, const_cast<float *>(p) + t * num_cols, num_cols,
|
||||
encoder_shape.data(), encoder_shape.size());
|
||||
|
||||
Ort::Value logit = model->RunJoiner(std::move(cur_encoder_out),
|
||||
View(&decoder_output_pair.first));
|
||||
|
||||
float *p_logit = logit.GetTensorMutableData<float>();
|
||||
if (blank_penalty > 0) {
|
||||
p_logit[blank_id] -= blank_penalty;
|
||||
}
|
||||
|
||||
auto y = static_cast<int32_t>(std::distance(
|
||||
static_cast<const float *>(p_logit),
|
||||
std::max_element(static_cast<const float *>(p_logit),
|
||||
static_cast<const float *>(p_logit) + vocab_size)));
|
||||
|
||||
if (y != blank_id) {
|
||||
ans.tokens.push_back(y);
|
||||
ans.timestamps.push_back(t);
|
||||
|
||||
decoder_input_pair = BuildDecoderInput(y, model->Allocator());
|
||||
|
||||
decoder_output_pair =
|
||||
model->RunDecoder(std::move(decoder_input_pair.first),
|
||||
std::move(decoder_input_pair.second),
|
||||
std::move(decoder_output_pair.second));
|
||||
} // if (y != blank_id)
|
||||
} // for (int32_t i = 0; i != num_rows; ++i)
|
||||
|
||||
return ans;
|
||||
}
|
||||
|
||||
std::vector<OfflineTransducerDecoderResult>
|
||||
OfflineTransducerGreedySearchNeMoDecoder::Decode(
|
||||
Ort::Value encoder_out, Ort::Value encoder_out_length,
|
||||
OfflineStream ** /*ss = nullptr*/, int32_t /*n= 0*/) {
|
||||
auto shape = encoder_out.GetTensorTypeAndShapeInfo().GetShape();
|
||||
|
||||
int32_t batch_size = static_cast<int32_t>(shape[0]);
|
||||
int32_t dim1 = static_cast<int32_t>(shape[1]);
|
||||
int32_t dim2 = static_cast<int32_t>(shape[2]);
|
||||
|
||||
const int64_t *p_length = encoder_out_length.GetTensorData<int64_t>();
|
||||
const float *p = encoder_out.GetTensorData<float>();
|
||||
|
||||
std::vector<OfflineTransducerDecoderResult> ans(batch_size);
|
||||
|
||||
for (int32_t i = 0; i != batch_size; ++i) {
|
||||
const float *this_p = p + dim1 * dim2 * i;
|
||||
int32_t this_len = p_length[i];
|
||||
|
||||
ans[i] = DecodeOne(this_p, this_len, dim2, model_, blank_penalty_);
|
||||
}
|
||||
|
||||
return ans;
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
@@ -0,0 +1,33 @@
|
||||
// sherpa-onnx/csrc/offline-transducer-greedy-search-nemo-decoder.h
|
||||
//
|
||||
// Copyright (c) 2024 Xiaomi Corporation
|
||||
|
||||
#ifndef SHERPA_ONNX_CSRC_OFFLINE_TRANSDUCER_GREEDY_SEARCH_NEMO_DECODER_H_
|
||||
#define SHERPA_ONNX_CSRC_OFFLINE_TRANSDUCER_GREEDY_SEARCH_NEMO_DECODER_H_
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "sherpa-onnx/csrc/offline-transducer-decoder.h"
|
||||
#include "sherpa-onnx/csrc/offline-transducer-nemo-model.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
class OfflineTransducerGreedySearchNeMoDecoder
|
||||
: public OfflineTransducerDecoder {
|
||||
public:
|
||||
OfflineTransducerGreedySearchNeMoDecoder(OfflineTransducerNeMoModel *model,
|
||||
float blank_penalty)
|
||||
: model_(model), blank_penalty_(blank_penalty) {}
|
||||
|
||||
std::vector<OfflineTransducerDecoderResult> Decode(
|
||||
Ort::Value encoder_out, Ort::Value encoder_out_length,
|
||||
OfflineStream **ss = nullptr, int32_t n = 0) override;
|
||||
|
||||
private:
|
||||
OfflineTransducerNeMoModel *model_; // Not owned
|
||||
float blank_penalty_;
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_OFFLINE_TRANSDUCER_GREEDY_SEARCH_NEMO_DECODER_H_
|
||||
301
sherpa-onnx/csrc/offline-transducer-nemo-model.cc
Normal file
301
sherpa-onnx/csrc/offline-transducer-nemo-model.cc
Normal file
@@ -0,0 +1,301 @@
|
||||
// sherpa-onnx/csrc/offline-transducer-nemo-model.cc
|
||||
//
|
||||
// Copyright (c) 2024 Xiaomi Corporation
|
||||
|
||||
#include "sherpa-onnx/csrc/offline-transducer-nemo-model.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/csrc/offline-transducer-decoder.h"
|
||||
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||
#include "sherpa-onnx/csrc/session.h"
|
||||
#include "sherpa-onnx/csrc/transpose.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
class OfflineTransducerNeMoModel::Impl {
|
||||
public:
|
||||
explicit Impl(const OfflineModelConfig &config)
|
||||
: config_(config),
|
||||
env_(ORT_LOGGING_LEVEL_WARNING),
|
||||
sess_opts_(GetSessionOptions(config)),
|
||||
allocator_{} {
|
||||
{
|
||||
auto buf = ReadFile(config.transducer.encoder_filename);
|
||||
InitEncoder(buf.data(), buf.size());
|
||||
}
|
||||
|
||||
{
|
||||
auto buf = ReadFile(config.transducer.decoder_filename);
|
||||
InitDecoder(buf.data(), buf.size());
|
||||
}
|
||||
|
||||
{
|
||||
auto buf = ReadFile(config.transducer.joiner_filename);
|
||||
InitJoiner(buf.data(), buf.size());
|
||||
}
|
||||
}
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
Impl(AAssetManager *mgr, const OfflineModelConfig &config)
|
||||
: config_(config),
|
||||
env_(ORT_LOGGING_LEVEL_WARNING),
|
||||
sess_opts_(GetSessionOptions(config)),
|
||||
allocator_{} {
|
||||
{
|
||||
auto buf = ReadFile(mgr, config.transducer.encoder_filename);
|
||||
InitEncoder(buf.data(), buf.size());
|
||||
}
|
||||
|
||||
{
|
||||
auto buf = ReadFile(mgr, config.transducer.decoder_filename);
|
||||
InitDecoder(buf.data(), buf.size());
|
||||
}
|
||||
|
||||
{
|
||||
auto buf = ReadFile(mgr, config.transducer.joiner_filename);
|
||||
InitJoiner(buf.data(), buf.size());
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
std::vector<Ort::Value> RunEncoder(Ort::Value features,
|
||||
Ort::Value features_length) {
|
||||
// (B, T, C) -> (B, C, T)
|
||||
features = Transpose12(allocator_, &features);
|
||||
|
||||
std::array<Ort::Value, 2> encoder_inputs = {std::move(features),
|
||||
std::move(features_length)};
|
||||
|
||||
auto encoder_out = encoder_sess_->Run(
|
||||
{}, encoder_input_names_ptr_.data(), encoder_inputs.data(),
|
||||
encoder_inputs.size(), encoder_output_names_ptr_.data(),
|
||||
encoder_output_names_ptr_.size());
|
||||
|
||||
return encoder_out;
|
||||
}
|
||||
|
||||
std::pair<Ort::Value, std::vector<Ort::Value>> RunDecoder(
|
||||
Ort::Value targets, Ort::Value targets_length,
|
||||
std::vector<Ort::Value> states) {
|
||||
std::vector<Ort::Value> decoder_inputs;
|
||||
decoder_inputs.reserve(2 + states.size());
|
||||
|
||||
decoder_inputs.push_back(std::move(targets));
|
||||
decoder_inputs.push_back(std::move(targets_length));
|
||||
|
||||
for (auto &s : states) {
|
||||
decoder_inputs.push_back(std::move(s));
|
||||
}
|
||||
|
||||
auto decoder_out = decoder_sess_->Run(
|
||||
{}, decoder_input_names_ptr_.data(), decoder_inputs.data(),
|
||||
decoder_inputs.size(), decoder_output_names_ptr_.data(),
|
||||
decoder_output_names_ptr_.size());
|
||||
|
||||
std::vector<Ort::Value> states_next;
|
||||
states_next.reserve(states.size());
|
||||
|
||||
// decoder_out[0]: decoder_output
|
||||
// decoder_out[1]: decoder_output_length
|
||||
// decoder_out[2:] states_next
|
||||
|
||||
for (int32_t i = 0; i != states.size(); ++i) {
|
||||
states_next.push_back(std::move(decoder_out[i + 2]));
|
||||
}
|
||||
|
||||
// we discard decoder_out[1]
|
||||
return {std::move(decoder_out[0]), std::move(states_next)};
|
||||
}
|
||||
|
||||
Ort::Value RunJoiner(Ort::Value encoder_out, Ort::Value decoder_out) {
|
||||
std::array<Ort::Value, 2> joiner_input = {std::move(encoder_out),
|
||||
std::move(decoder_out)};
|
||||
auto logit = joiner_sess_->Run({}, joiner_input_names_ptr_.data(),
|
||||
joiner_input.data(), joiner_input.size(),
|
||||
joiner_output_names_ptr_.data(),
|
||||
joiner_output_names_ptr_.size());
|
||||
|
||||
return std::move(logit[0]);
|
||||
}
|
||||
|
||||
std::vector<Ort::Value> GetDecoderInitStates(int32_t batch_size) const {
|
||||
std::array<int64_t, 3> s0_shape{pred_rnn_layers_, batch_size, pred_hidden_};
|
||||
Ort::Value s0 = Ort::Value::CreateTensor<float>(allocator_, s0_shape.data(),
|
||||
s0_shape.size());
|
||||
|
||||
Fill<float>(&s0, 0);
|
||||
|
||||
std::array<int64_t, 3> s1_shape{pred_rnn_layers_, batch_size, pred_hidden_};
|
||||
|
||||
Ort::Value s1 = Ort::Value::CreateTensor<float>(allocator_, s1_shape.data(),
|
||||
s1_shape.size());
|
||||
|
||||
Fill<float>(&s1, 0);
|
||||
|
||||
std::vector<Ort::Value> states;
|
||||
|
||||
states.reserve(2);
|
||||
states.push_back(std::move(s0));
|
||||
states.push_back(std::move(s1));
|
||||
|
||||
return states;
|
||||
}
|
||||
|
||||
int32_t SubsamplingFactor() const { return subsampling_factor_; }
|
||||
int32_t VocabSize() const { return vocab_size_; }
|
||||
|
||||
OrtAllocator *Allocator() const { return allocator_; }
|
||||
|
||||
std::string FeatureNormalizationMethod() const { return normalize_type_; }
|
||||
|
||||
private:
|
||||
void InitEncoder(void *model_data, size_t model_data_length) {
|
||||
encoder_sess_ = std::make_unique<Ort::Session>(
|
||||
env_, model_data, model_data_length, sess_opts_);
|
||||
|
||||
GetInputNames(encoder_sess_.get(), &encoder_input_names_,
|
||||
&encoder_input_names_ptr_);
|
||||
|
||||
GetOutputNames(encoder_sess_.get(), &encoder_output_names_,
|
||||
&encoder_output_names_ptr_);
|
||||
|
||||
// get meta data
|
||||
Ort::ModelMetadata meta_data = encoder_sess_->GetModelMetadata();
|
||||
if (config_.debug) {
|
||||
std::ostringstream os;
|
||||
os << "---encoder---\n";
|
||||
PrintModelMetadata(os, meta_data);
|
||||
SHERPA_ONNX_LOGE("%s\n", os.str().c_str());
|
||||
}
|
||||
|
||||
Ort::AllocatorWithDefaultOptions allocator; // used in the macro below
|
||||
SHERPA_ONNX_READ_META_DATA(vocab_size_, "vocab_size");
|
||||
|
||||
// need to increase by 1 since the blank token is not included in computing
|
||||
// vocab_size in NeMo.
|
||||
vocab_size_ += 1;
|
||||
|
||||
SHERPA_ONNX_READ_META_DATA(subsampling_factor_, "subsampling_factor");
|
||||
SHERPA_ONNX_READ_META_DATA_STR(normalize_type_, "normalize_type");
|
||||
SHERPA_ONNX_READ_META_DATA(pred_rnn_layers_, "pred_rnn_layers");
|
||||
SHERPA_ONNX_READ_META_DATA(pred_hidden_, "pred_hidden");
|
||||
|
||||
if (normalize_type_ == "NA") {
|
||||
normalize_type_ = "";
|
||||
}
|
||||
}
|
||||
|
||||
void InitDecoder(void *model_data, size_t model_data_length) {
|
||||
decoder_sess_ = std::make_unique<Ort::Session>(
|
||||
env_, model_data, model_data_length, sess_opts_);
|
||||
|
||||
GetInputNames(decoder_sess_.get(), &decoder_input_names_,
|
||||
&decoder_input_names_ptr_);
|
||||
|
||||
GetOutputNames(decoder_sess_.get(), &decoder_output_names_,
|
||||
&decoder_output_names_ptr_);
|
||||
}
|
||||
|
||||
void InitJoiner(void *model_data, size_t model_data_length) {
|
||||
joiner_sess_ = std::make_unique<Ort::Session>(
|
||||
env_, model_data, model_data_length, sess_opts_);
|
||||
|
||||
GetInputNames(joiner_sess_.get(), &joiner_input_names_,
|
||||
&joiner_input_names_ptr_);
|
||||
|
||||
GetOutputNames(joiner_sess_.get(), &joiner_output_names_,
|
||||
&joiner_output_names_ptr_);
|
||||
}
|
||||
|
||||
private:
|
||||
OfflineModelConfig config_;
|
||||
Ort::Env env_;
|
||||
Ort::SessionOptions sess_opts_;
|
||||
Ort::AllocatorWithDefaultOptions allocator_;
|
||||
|
||||
std::unique_ptr<Ort::Session> encoder_sess_;
|
||||
std::unique_ptr<Ort::Session> decoder_sess_;
|
||||
std::unique_ptr<Ort::Session> joiner_sess_;
|
||||
|
||||
std::vector<std::string> encoder_input_names_;
|
||||
std::vector<const char *> encoder_input_names_ptr_;
|
||||
|
||||
std::vector<std::string> encoder_output_names_;
|
||||
std::vector<const char *> encoder_output_names_ptr_;
|
||||
|
||||
std::vector<std::string> decoder_input_names_;
|
||||
std::vector<const char *> decoder_input_names_ptr_;
|
||||
|
||||
std::vector<std::string> decoder_output_names_;
|
||||
std::vector<const char *> decoder_output_names_ptr_;
|
||||
|
||||
std::vector<std::string> joiner_input_names_;
|
||||
std::vector<const char *> joiner_input_names_ptr_;
|
||||
|
||||
std::vector<std::string> joiner_output_names_;
|
||||
std::vector<const char *> joiner_output_names_ptr_;
|
||||
|
||||
int32_t vocab_size_ = 0;
|
||||
int32_t subsampling_factor_ = 8;
|
||||
std::string normalize_type_;
|
||||
int32_t pred_rnn_layers_ = -1;
|
||||
int32_t pred_hidden_ = -1;
|
||||
};
|
||||
|
||||
OfflineTransducerNeMoModel::OfflineTransducerNeMoModel(
|
||||
const OfflineModelConfig &config)
|
||||
: impl_(std::make_unique<Impl>(config)) {}
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
OfflineTransducerNeMoModel::OfflineTransducerNeMoModel(
|
||||
AAssetManager *mgr, const OfflineModelConfig &config)
|
||||
: impl_(std::make_unique<Impl>(mgr, config)) {}
|
||||
#endif
|
||||
|
||||
OfflineTransducerNeMoModel::~OfflineTransducerNeMoModel() = default;
|
||||
|
||||
std::vector<Ort::Value> OfflineTransducerNeMoModel::RunEncoder(
|
||||
Ort::Value features, Ort::Value features_length) const {
|
||||
return impl_->RunEncoder(std::move(features), std::move(features_length));
|
||||
}
|
||||
|
||||
std::pair<Ort::Value, std::vector<Ort::Value>>
|
||||
OfflineTransducerNeMoModel::RunDecoder(Ort::Value targets,
|
||||
Ort::Value targets_length,
|
||||
std::vector<Ort::Value> states) const {
|
||||
return impl_->RunDecoder(std::move(targets), std::move(targets_length),
|
||||
std::move(states));
|
||||
}
|
||||
|
||||
std::vector<Ort::Value> OfflineTransducerNeMoModel::GetDecoderInitStates(
|
||||
int32_t batch_size) const {
|
||||
return impl_->GetDecoderInitStates(batch_size);
|
||||
}
|
||||
|
||||
Ort::Value OfflineTransducerNeMoModel::RunJoiner(Ort::Value encoder_out,
|
||||
Ort::Value decoder_out) const {
|
||||
return impl_->RunJoiner(std::move(encoder_out), std::move(decoder_out));
|
||||
}
|
||||
|
||||
int32_t OfflineTransducerNeMoModel::SubsamplingFactor() const {
|
||||
return impl_->SubsamplingFactor();
|
||||
}
|
||||
|
||||
int32_t OfflineTransducerNeMoModel::VocabSize() const {
|
||||
return impl_->VocabSize();
|
||||
}
|
||||
|
||||
OrtAllocator *OfflineTransducerNeMoModel::Allocator() const {
|
||||
return impl_->Allocator();
|
||||
}
|
||||
|
||||
std::string OfflineTransducerNeMoModel::FeatureNormalizationMethod() const {
|
||||
return impl_->FeatureNormalizationMethod();
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
103
sherpa-onnx/csrc/offline-transducer-nemo-model.h
Normal file
103
sherpa-onnx/csrc/offline-transducer-nemo-model.h
Normal file
@@ -0,0 +1,103 @@
|
||||
// sherpa-onnx/csrc/offline-transducer-nemo-model.h
|
||||
//
|
||||
// Copyright (c) 2024 Xiaomi Corporation
|
||||
#ifndef SHERPA_ONNX_CSRC_OFFLINE_TRANSDUCER_NEMO_MODEL_H_
|
||||
#define SHERPA_ONNX_CSRC_OFFLINE_TRANSDUCER_NEMO_MODEL_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
#include "android/asset_manager.h"
|
||||
#include "android/asset_manager_jni.h"
|
||||
#endif
|
||||
|
||||
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||
#include "sherpa-onnx/csrc/offline-model-config.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
// see
|
||||
// https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/asr/models/hybrid_rnnt_ctc_bpe_models.py#L40
|
||||
// Its decoder is stateful, not stateless.
|
||||
class OfflineTransducerNeMoModel {
|
||||
public:
|
||||
explicit OfflineTransducerNeMoModel(const OfflineModelConfig &config);
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
OfflineTransducerNeMoModel(AAssetManager *mgr,
|
||||
const OfflineModelConfig &config);
|
||||
#endif
|
||||
|
||||
~OfflineTransducerNeMoModel();
|
||||
|
||||
/** Run the encoder.
|
||||
*
|
||||
* @param features A tensor of shape (N, T, C). It is changed in-place.
|
||||
* @param features_length A 1-D tensor of shape (N,) containing number of
|
||||
* valid frames in `features` before padding.
|
||||
* Its dtype is int64_t.
|
||||
*
|
||||
* @return Return a vector containing:
|
||||
* - encoder_out: A 3-D tensor of shape (N, T', encoder_dim)
|
||||
* - encoder_out_length: A 1-D tensor of shape (N,) containing number
|
||||
* of frames in `encoder_out` before padding.
|
||||
*/
|
||||
std::vector<Ort::Value> RunEncoder(Ort::Value features,
|
||||
Ort::Value features_length) const;
|
||||
|
||||
/** Run the decoder network.
|
||||
*
|
||||
* @param targets A int32 tensor of shape (batch_size, 1)
|
||||
* @param targets_length A int32 tensor of shape (batch_size,)
|
||||
* @param states The states for the decoder model.
|
||||
* @return Return a vector:
|
||||
* - ans[0] is the decoder_out (a float tensor)
|
||||
* - ans[1] is the decoder_out_length (a int32 tensor)
|
||||
* - ans[2:] is the states_next
|
||||
*/
|
||||
std::pair<Ort::Value, std::vector<Ort::Value>> RunDecoder(
|
||||
Ort::Value targets, Ort::Value targets_length,
|
||||
std::vector<Ort::Value> states) const;
|
||||
|
||||
std::vector<Ort::Value> GetDecoderInitStates(int32_t batch_size) const;
|
||||
|
||||
/** Run the joint network.
|
||||
*
|
||||
* @param encoder_out Output of the encoder network.
|
||||
* @param decoder_out Output of the decoder network.
|
||||
* @return Return a tensor of shape (N, 1, 1, vocab_size) containing logits.
|
||||
*/
|
||||
Ort::Value RunJoiner(Ort::Value encoder_out, Ort::Value decoder_out) const;
|
||||
|
||||
/** Return the subsampling factor of the model.
|
||||
*/
|
||||
int32_t SubsamplingFactor() const;
|
||||
|
||||
int32_t VocabSize() const;
|
||||
|
||||
/** Return an allocator for allocating memory
|
||||
*/
|
||||
OrtAllocator *Allocator() const;
|
||||
|
||||
// Possible values:
|
||||
// - per_feature
|
||||
// - all_features (not implemented yet)
|
||||
// - fixed_mean (not implemented)
|
||||
// - fixed_std (not implemented)
|
||||
// - or just leave it to empty
|
||||
// See
|
||||
// https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/asr/parts/preprocessing/features.py#L59
|
||||
// for details
|
||||
std::string FeatureNormalizationMethod() const;
|
||||
|
||||
private:
|
||||
class Impl;
|
||||
std::unique_ptr<Impl> impl_;
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_OFFLINE_TRANSDUCER_NEMO_MODEL_H_
|
||||
@@ -223,8 +223,8 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl {
|
||||
|
||||
private:
|
||||
void InitDecoder() {
|
||||
if (!sym_.contains("<blk>") && !sym_.contains("<eps>") &&
|
||||
!sym_.contains("<blank>")) {
|
||||
if (!sym_.Contains("<blk>") && !sym_.Contains("<eps>") &&
|
||||
!sym_.Contains("<blank>")) {
|
||||
SHERPA_ONNX_LOGE(
|
||||
"We expect that tokens.txt contains "
|
||||
"the symbol <blk> or <eps> or <blank> and its ID.");
|
||||
@@ -232,12 +232,12 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl {
|
||||
}
|
||||
|
||||
int32_t blank_id = 0;
|
||||
if (sym_.contains("<blk>")) {
|
||||
if (sym_.Contains("<blk>")) {
|
||||
blank_id = sym_["<blk>"];
|
||||
} else if (sym_.contains("<eps>")) {
|
||||
} else if (sym_.Contains("<eps>")) {
|
||||
// for tdnn models of the yesno recipe from icefall
|
||||
blank_id = sym_["<eps>"];
|
||||
} else if (sym_.contains("<blank>")) {
|
||||
} else if (sym_.Contains("<blank>")) {
|
||||
// for WeNet CTC models
|
||||
blank_id = sym_["<blank>"];
|
||||
}
|
||||
|
||||
@@ -87,7 +87,7 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
|
||||
model_(OnlineTransducerModel::Create(config.model_config)),
|
||||
sym_(config.model_config.tokens),
|
||||
endpoint_(config_.endpoint_config) {
|
||||
if (sym_.contains("<unk>")) {
|
||||
if (sym_.Contains("<unk>")) {
|
||||
unk_id_ = sym_["<unk>"];
|
||||
}
|
||||
|
||||
@@ -103,19 +103,13 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
|
||||
}
|
||||
|
||||
decoder_ = std::make_unique<OnlineTransducerModifiedBeamSearchDecoder>(
|
||||
model_.get(),
|
||||
lm_.get(),
|
||||
config_.max_active_paths,
|
||||
config_.lm_config.scale,
|
||||
unk_id_,
|
||||
config_.blank_penalty,
|
||||
model_.get(), lm_.get(), config_.max_active_paths,
|
||||
config_.lm_config.scale, unk_id_, config_.blank_penalty,
|
||||
config_.temperature_scale);
|
||||
|
||||
} else if (config.decoding_method == "greedy_search") {
|
||||
decoder_ = std::make_unique<OnlineTransducerGreedySearchDecoder>(
|
||||
model_.get(),
|
||||
unk_id_,
|
||||
config_.blank_penalty,
|
||||
model_.get(), unk_id_, config_.blank_penalty,
|
||||
config_.temperature_scale);
|
||||
|
||||
} else {
|
||||
@@ -132,7 +126,7 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
|
||||
model_(OnlineTransducerModel::Create(mgr, config.model_config)),
|
||||
sym_(mgr, config.model_config.tokens),
|
||||
endpoint_(config_.endpoint_config) {
|
||||
if (sym_.contains("<unk>")) {
|
||||
if (sym_.Contains("<unk>")) {
|
||||
unk_id_ = sym_["<unk>"];
|
||||
}
|
||||
|
||||
@@ -151,19 +145,13 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
|
||||
}
|
||||
|
||||
decoder_ = std::make_unique<OnlineTransducerModifiedBeamSearchDecoder>(
|
||||
model_.get(),
|
||||
lm_.get(),
|
||||
config_.max_active_paths,
|
||||
config_.lm_config.scale,
|
||||
unk_id_,
|
||||
config_.blank_penalty,
|
||||
model_.get(), lm_.get(), config_.max_active_paths,
|
||||
config_.lm_config.scale, unk_id_, config_.blank_penalty,
|
||||
config_.temperature_scale);
|
||||
|
||||
} else if (config.decoding_method == "greedy_search") {
|
||||
decoder_ = std::make_unique<OnlineTransducerGreedySearchDecoder>(
|
||||
model_.get(),
|
||||
unk_id_,
|
||||
config_.blank_penalty,
|
||||
model_.get(), unk_id_, config_.blank_penalty,
|
||||
config_.temperature_scale);
|
||||
|
||||
} else {
|
||||
|
||||
@@ -13,7 +13,7 @@ namespace sherpa_onnx {
|
||||
* It returns v[dim0_start:dim0_end, dim1_start:dim1_end, :]
|
||||
*
|
||||
* @param allocator
|
||||
* @param v A 2-D tensor. Its data type is T.
|
||||
* @param v A 3-D tensor. Its data type is T.
|
||||
* @param dim0_start Start index of the first dimension..
|
||||
* @param dim0_end End index of the first dimension..
|
||||
* @param dim1_start Start index of the second dimension.
|
||||
|
||||
@@ -100,9 +100,9 @@ int32_t SymbolTable::operator[](const std::string &sym) const {
|
||||
return sym2id_.at(sym);
|
||||
}
|
||||
|
||||
bool SymbolTable::contains(int32_t id) const { return id2sym_.count(id) != 0; }
|
||||
bool SymbolTable::Contains(int32_t id) const { return id2sym_.count(id) != 0; }
|
||||
|
||||
bool SymbolTable::contains(const std::string &sym) const {
|
||||
bool SymbolTable::Contains(const std::string &sym) const {
|
||||
return sym2id_.count(sym) != 0;
|
||||
}
|
||||
|
||||
|
||||
@@ -40,14 +40,16 @@ class SymbolTable {
|
||||
int32_t operator[](const std::string &sym) const;
|
||||
|
||||
/// Return true if there is a symbol with the given ID.
|
||||
bool contains(int32_t id) const;
|
||||
bool Contains(int32_t id) const;
|
||||
|
||||
/// Return true if there is a given symbol in the symbol table.
|
||||
bool contains(const std::string &sym) const;
|
||||
bool Contains(const std::string &sym) const;
|
||||
|
||||
// for tokens.txt from Whisper
|
||||
void ApplyBase64Decode();
|
||||
|
||||
int32_t NumSymbols() const { return id2sym_.size(); }
|
||||
|
||||
private:
|
||||
void Init(std::istream &is);
|
||||
|
||||
|
||||
@@ -49,7 +49,7 @@ static bool EncodeBase(std::istream &is, const SymbolTable &symbol_table,
|
||||
word = word.replace(0, 3, " ");
|
||||
}
|
||||
}
|
||||
if (symbol_table.contains(word)) {
|
||||
if (symbol_table.Contains(word)) {
|
||||
int32_t id = symbol_table[word];
|
||||
tmp_ids.push_back(id);
|
||||
} else {
|
||||
|
||||
@@ -14,10 +14,10 @@ namespace sherpa_onnx {
|
||||
static void PybindOfflineRecognizerConfig(py::module *m) {
|
||||
using PyClass = OfflineRecognizerConfig;
|
||||
py::class_<PyClass>(*m, "OfflineRecognizerConfig")
|
||||
.def(py::init<const OfflineFeatureExtractorConfig &,
|
||||
const OfflineModelConfig &, const OfflineLMConfig &,
|
||||
const OfflineCtcFstDecoderConfig &, const std::string &,
|
||||
int32_t, const std::string &, float, float>(),
|
||||
.def(py::init<const FeatureExtractorConfig &, const OfflineModelConfig &,
|
||||
const OfflineLMConfig &, const OfflineCtcFstDecoderConfig &,
|
||||
const std::string &, int32_t, const std::string &, float,
|
||||
float>(),
|
||||
py::arg("feat_config"), py::arg("model_config"),
|
||||
py::arg("lm_config") = OfflineLMConfig(),
|
||||
py::arg("ctc_fst_decoder_config") = OfflineCtcFstDecoderConfig(),
|
||||
|
||||
@@ -25,6 +25,7 @@ Args:
|
||||
static void PybindOfflineRecognitionResult(py::module *m) { // NOLINT
|
||||
using PyClass = OfflineRecognitionResult;
|
||||
py::class_<PyClass>(*m, "OfflineRecognitionResult")
|
||||
.def("__str__", &PyClass::AsJsonString)
|
||||
.def_property_readonly(
|
||||
"text",
|
||||
[](const PyClass &self) -> py::str {
|
||||
@@ -37,18 +38,7 @@ static void PybindOfflineRecognitionResult(py::module *m) { // NOLINT
|
||||
"timestamps", [](const PyClass &self) { return self.timestamps; });
|
||||
}
|
||||
|
||||
static void PybindOfflineFeatureExtractorConfig(py::module *m) {
|
||||
using PyClass = OfflineFeatureExtractorConfig;
|
||||
py::class_<PyClass>(*m, "OfflineFeatureExtractorConfig")
|
||||
.def(py::init<int32_t, int32_t>(), py::arg("sampling_rate") = 16000,
|
||||
py::arg("feature_dim") = 80)
|
||||
.def_readwrite("sampling_rate", &PyClass::sampling_rate)
|
||||
.def_readwrite("feature_dim", &PyClass::feature_dim)
|
||||
.def("__str__", &PyClass::ToString);
|
||||
}
|
||||
|
||||
void PybindOfflineStream(py::module *m) {
|
||||
PybindOfflineFeatureExtractorConfig(m);
|
||||
PybindOfflineRecognitionResult(m);
|
||||
|
||||
using PyClass = OfflineStream;
|
||||
|
||||
@@ -4,8 +4,8 @@ from pathlib import Path
|
||||
from typing import List, Optional
|
||||
|
||||
from _sherpa_onnx import (
|
||||
FeatureExtractorConfig,
|
||||
OfflineCtcFstDecoderConfig,
|
||||
OfflineFeatureExtractorConfig,
|
||||
OfflineModelConfig,
|
||||
OfflineNemoEncDecCtcModelConfig,
|
||||
OfflineParaformerModelConfig,
|
||||
@@ -51,6 +51,7 @@ class OfflineRecognizer(object):
|
||||
blank_penalty: float = 0.0,
|
||||
debug: bool = False,
|
||||
provider: str = "cpu",
|
||||
model_type: str = "transducer",
|
||||
):
|
||||
"""
|
||||
Please refer to
|
||||
@@ -106,10 +107,10 @@ class OfflineRecognizer(object):
|
||||
num_threads=num_threads,
|
||||
debug=debug,
|
||||
provider=provider,
|
||||
model_type="transducer",
|
||||
model_type=model_type,
|
||||
)
|
||||
|
||||
feat_config = OfflineFeatureExtractorConfig(
|
||||
feat_config = FeatureExtractorConfig(
|
||||
sampling_rate=sample_rate,
|
||||
feature_dim=feature_dim,
|
||||
)
|
||||
@@ -182,7 +183,7 @@ class OfflineRecognizer(object):
|
||||
model_type="paraformer",
|
||||
)
|
||||
|
||||
feat_config = OfflineFeatureExtractorConfig(
|
||||
feat_config = FeatureExtractorConfig(
|
||||
sampling_rate=sample_rate,
|
||||
feature_dim=feature_dim,
|
||||
)
|
||||
@@ -246,7 +247,7 @@ class OfflineRecognizer(object):
|
||||
model_type="nemo_ctc",
|
||||
)
|
||||
|
||||
feat_config = OfflineFeatureExtractorConfig(
|
||||
feat_config = FeatureExtractorConfig(
|
||||
sampling_rate=sample_rate,
|
||||
feature_dim=feature_dim,
|
||||
)
|
||||
@@ -326,7 +327,7 @@ class OfflineRecognizer(object):
|
||||
model_type="whisper",
|
||||
)
|
||||
|
||||
feat_config = OfflineFeatureExtractorConfig(
|
||||
feat_config = FeatureExtractorConfig(
|
||||
sampling_rate=16000,
|
||||
feature_dim=80,
|
||||
)
|
||||
@@ -389,7 +390,7 @@ class OfflineRecognizer(object):
|
||||
model_type="tdnn",
|
||||
)
|
||||
|
||||
feat_config = OfflineFeatureExtractorConfig(
|
||||
feat_config = FeatureExtractorConfig(
|
||||
sampling_rate=sample_rate,
|
||||
feature_dim=feature_dim,
|
||||
)
|
||||
@@ -453,7 +454,7 @@ class OfflineRecognizer(object):
|
||||
model_type="wenet_ctc",
|
||||
)
|
||||
|
||||
feat_config = OfflineFeatureExtractorConfig(
|
||||
feat_config = FeatureExtractorConfig(
|
||||
sampling_rate=sample_rate,
|
||||
feature_dim=feature_dim,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user