Support non-streaming WeNet CTC models. (#426)
This commit is contained in:
41
.github/scripts/test-offline-ctc.sh
vendored
41
.github/scripts/test-offline-ctc.sh
vendored
@@ -13,6 +13,47 @@ echo "PATH: $PATH"
|
|||||||
|
|
||||||
which $EXE
|
which $EXE
|
||||||
|
|
||||||
|
log "------------------------------------------------------------"
|
||||||
|
log "Run Wenet models"
|
||||||
|
log "------------------------------------------------------------"
|
||||||
|
wenet_models=(
|
||||||
|
sherpa-onnx-zh-wenet-aishell
|
||||||
|
sherpa-onnx-zh-wenet-aishell2
|
||||||
|
sherpa-onnx-zh-wenet-wenetspeech
|
||||||
|
sherpa-onnx-zh-wenet-multi-cn
|
||||||
|
sherpa-onnx-en-wenet-librispeech
|
||||||
|
sherpa-onnx-en-wenet-gigaspeech
|
||||||
|
)
|
||||||
|
for name in ${wenet_models[@]}; do
|
||||||
|
repo_url=https://huggingface.co/csukuangfj/$name
|
||||||
|
log "Start testing ${repo_url}"
|
||||||
|
repo=$(basename $repo_url)
|
||||||
|
log "Download pretrained model and test-data from $repo_url"
|
||||||
|
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
|
||||||
|
pushd $repo
|
||||||
|
git lfs pull --include "*.onnx"
|
||||||
|
ls -lh *.onnx
|
||||||
|
popd
|
||||||
|
|
||||||
|
log "test float32 models"
|
||||||
|
time $EXE \
|
||||||
|
--tokens=$repo/tokens.txt \
|
||||||
|
--wenet-ctc-model=$repo/model.onnx \
|
||||||
|
$repo/test_wavs/0.wav \
|
||||||
|
$repo/test_wavs/1.wav \
|
||||||
|
$repo/test_wavs/8k.wav
|
||||||
|
|
||||||
|
log "test int8 models"
|
||||||
|
time $EXE \
|
||||||
|
--tokens=$repo/tokens.txt \
|
||||||
|
--wenet-ctc-model=$repo/model.int8.onnx \
|
||||||
|
$repo/test_wavs/0.wav \
|
||||||
|
$repo/test_wavs/1.wav \
|
||||||
|
$repo/test_wavs/8k.wav
|
||||||
|
|
||||||
|
rm -rf $repo
|
||||||
|
done
|
||||||
|
|
||||||
log "------------------------------------------------------------"
|
log "------------------------------------------------------------"
|
||||||
log "Run tdnn yesno (Hebrew)"
|
log "Run tdnn yesno (Hebrew)"
|
||||||
log "------------------------------------------------------------"
|
log "------------------------------------------------------------"
|
||||||
|
|||||||
11
.github/workflows/export-wenet-to-onnx.yaml
vendored
11
.github/workflows/export-wenet-to-onnx.yaml
vendored
@@ -1,17 +1,6 @@
|
|||||||
name: export-wenet-to-onnx
|
name: export-wenet-to-onnx
|
||||||
|
|
||||||
on:
|
on:
|
||||||
push:
|
|
||||||
branches:
|
|
||||||
- master
|
|
||||||
paths:
|
|
||||||
- 'scripts/wenet/**'
|
|
||||||
- '.github/workflows/export-wenet-to-onnx.yaml'
|
|
||||||
pull_request:
|
|
||||||
paths:
|
|
||||||
- 'scripts/wenet/**'
|
|
||||||
- '.github/workflows/export-wenet-to-onnx.yaml'
|
|
||||||
|
|
||||||
workflow_dispatch:
|
workflow_dispatch:
|
||||||
|
|
||||||
concurrency:
|
concurrency:
|
||||||
|
|||||||
16
.github/workflows/linux.yaml
vendored
16
.github/workflows/linux.yaml
vendored
@@ -89,6 +89,14 @@ jobs:
|
|||||||
file build/bin/sherpa-onnx
|
file build/bin/sherpa-onnx
|
||||||
readelf -d build/bin/sherpa-onnx
|
readelf -d build/bin/sherpa-onnx
|
||||||
|
|
||||||
|
- name: Test offline CTC
|
||||||
|
shell: bash
|
||||||
|
run: |
|
||||||
|
export PATH=$PWD/build/bin:$PATH
|
||||||
|
export EXE=sherpa-onnx-offline
|
||||||
|
|
||||||
|
.github/scripts/test-offline-ctc.sh
|
||||||
|
|
||||||
- name: Test offline TTS
|
- name: Test offline TTS
|
||||||
shell: bash
|
shell: bash
|
||||||
run: |
|
run: |
|
||||||
@@ -115,14 +123,6 @@ jobs:
|
|||||||
|
|
||||||
.github/scripts/test-offline-whisper.sh
|
.github/scripts/test-offline-whisper.sh
|
||||||
|
|
||||||
- name: Test offline CTC
|
|
||||||
shell: bash
|
|
||||||
run: |
|
|
||||||
export PATH=$PWD/build/bin:$PATH
|
|
||||||
export EXE=sherpa-onnx-offline
|
|
||||||
|
|
||||||
.github/scripts/test-offline-ctc.sh
|
|
||||||
|
|
||||||
- name: Test offline transducer
|
- name: Test offline transducer
|
||||||
shell: bash
|
shell: bash
|
||||||
run: |
|
run: |
|
||||||
|
|||||||
@@ -172,7 +172,7 @@ def main():
|
|||||||
# https://wenet.org.cn/downloads?models=wenet&version=aishell_u2pp_conformer_exp.tar.gz
|
# https://wenet.org.cn/downloads?models=wenet&version=aishell_u2pp_conformer_exp.tar.gz
|
||||||
url = os.environ.get("WENET_URL", "")
|
url = os.environ.get("WENET_URL", "")
|
||||||
meta_data = {
|
meta_data = {
|
||||||
"model_type": "wenet-ctc",
|
"model_type": "wenet_ctc",
|
||||||
"version": "1",
|
"version": "1",
|
||||||
"model_author": "wenet",
|
"model_author": "wenet",
|
||||||
"comment": "streaming",
|
"comment": "streaming",
|
||||||
@@ -185,6 +185,7 @@ def main():
|
|||||||
"cnn_module_kernel": cnn_module_kernel,
|
"cnn_module_kernel": cnn_module_kernel,
|
||||||
"right_context": right_context,
|
"right_context": right_context,
|
||||||
"subsampling_factor": subsampling_factor,
|
"subsampling_factor": subsampling_factor,
|
||||||
|
"vocab_size": torch_model.ctc.ctc_lo.weight.shape[0],
|
||||||
}
|
}
|
||||||
add_meta_data(filename=filename, meta_data=meta_data)
|
add_meta_data(filename=filename, meta_data=meta_data)
|
||||||
|
|
||||||
|
|||||||
@@ -107,10 +107,12 @@ def main():
|
|||||||
# https://wenet.org.cn/downloads?models=wenet&version=aishell_u2pp_conformer_exp.tar.gz
|
# https://wenet.org.cn/downloads?models=wenet&version=aishell_u2pp_conformer_exp.tar.gz
|
||||||
url = os.environ.get("WENET_URL", "")
|
url = os.environ.get("WENET_URL", "")
|
||||||
meta_data = {
|
meta_data = {
|
||||||
"model_type": "wenet-ctc",
|
"model_type": "wenet_ctc",
|
||||||
"version": "1",
|
"version": "1",
|
||||||
"model_author": "wenet",
|
"model_author": "wenet",
|
||||||
"comment": "non-streaming",
|
"comment": "non-streaming",
|
||||||
|
"subsampling_factor": torch_model.encoder.embed.subsampling_rate,
|
||||||
|
"vocab_size": torch_model.ctc.ctc_lo.weight.shape[0],
|
||||||
"url": url,
|
"url": url,
|
||||||
}
|
}
|
||||||
add_meta_data(filename=filename, meta_data=meta_data)
|
add_meta_data(filename=filename, meta_data=meta_data)
|
||||||
|
|||||||
@@ -41,6 +41,8 @@ set(sources
|
|||||||
offline-transducer-model-config.cc
|
offline-transducer-model-config.cc
|
||||||
offline-transducer-model.cc
|
offline-transducer-model.cc
|
||||||
offline-transducer-modified-beam-search-decoder.cc
|
offline-transducer-modified-beam-search-decoder.cc
|
||||||
|
offline-wenet-ctc-model-config.cc
|
||||||
|
offline-wenet-ctc-model.cc
|
||||||
offline-whisper-greedy-search-decoder.cc
|
offline-whisper-greedy-search-decoder.cc
|
||||||
offline-whisper-model-config.cc
|
offline-whisper-model-config.cc
|
||||||
offline-whisper-model.cc
|
offline-whisper-model.cc
|
||||||
|
|||||||
@@ -12,6 +12,7 @@
|
|||||||
#include "sherpa-onnx/csrc/macros.h"
|
#include "sherpa-onnx/csrc/macros.h"
|
||||||
#include "sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model.h"
|
#include "sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model.h"
|
||||||
#include "sherpa-onnx/csrc/offline-tdnn-ctc-model.h"
|
#include "sherpa-onnx/csrc/offline-tdnn-ctc-model.h"
|
||||||
|
#include "sherpa-onnx/csrc/offline-wenet-ctc-model.h"
|
||||||
#include "sherpa-onnx/csrc/offline-zipformer-ctc-model.h"
|
#include "sherpa-onnx/csrc/offline-zipformer-ctc-model.h"
|
||||||
#include "sherpa-onnx/csrc/onnx-utils.h"
|
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||||
|
|
||||||
@@ -21,10 +22,11 @@ enum class ModelType {
|
|||||||
kEncDecCTCModelBPE,
|
kEncDecCTCModelBPE,
|
||||||
kTdnn,
|
kTdnn,
|
||||||
kZipformerCtc,
|
kZipformerCtc,
|
||||||
|
kWenetCtc,
|
||||||
kUnkown,
|
kUnkown,
|
||||||
};
|
};
|
||||||
|
|
||||||
}
|
} // namespace
|
||||||
|
|
||||||
namespace sherpa_onnx {
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
@@ -52,6 +54,9 @@ static ModelType GetModelType(char *model_data, size_t model_data_length,
|
|||||||
"If you are using models from NeMo, please refer to\n"
|
"If you are using models from NeMo, please refer to\n"
|
||||||
"https://huggingface.co/csukuangfj/"
|
"https://huggingface.co/csukuangfj/"
|
||||||
"sherpa-onnx-nemo-ctc-en-citrinet-512/blob/main/add-model-metadata.py"
|
"sherpa-onnx-nemo-ctc-en-citrinet-512/blob/main/add-model-metadata.py"
|
||||||
|
"If you are using models from WeNet, please refer to\n"
|
||||||
|
"https://github.com/k2-fsa/sherpa-onnx/blob/master/scripts/wenet/"
|
||||||
|
"run.sh\n"
|
||||||
"\n"
|
"\n"
|
||||||
"for how to add metadta to model.onnx\n");
|
"for how to add metadta to model.onnx\n");
|
||||||
return ModelType::kUnkown;
|
return ModelType::kUnkown;
|
||||||
@@ -63,6 +68,8 @@ static ModelType GetModelType(char *model_data, size_t model_data_length,
|
|||||||
return ModelType::kTdnn;
|
return ModelType::kTdnn;
|
||||||
} else if (model_type.get() == std::string("zipformer2_ctc")) {
|
} else if (model_type.get() == std::string("zipformer2_ctc")) {
|
||||||
return ModelType::kZipformerCtc;
|
return ModelType::kZipformerCtc;
|
||||||
|
} else if (model_type.get() == std::string("wenet_ctc")) {
|
||||||
|
return ModelType::kWenetCtc;
|
||||||
} else {
|
} else {
|
||||||
SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.get());
|
SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.get());
|
||||||
return ModelType::kUnkown;
|
return ModelType::kUnkown;
|
||||||
@@ -80,6 +87,8 @@ std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create(
|
|||||||
filename = config.tdnn.model;
|
filename = config.tdnn.model;
|
||||||
} else if (!config.zipformer_ctc.model.empty()) {
|
} else if (!config.zipformer_ctc.model.empty()) {
|
||||||
filename = config.zipformer_ctc.model;
|
filename = config.zipformer_ctc.model;
|
||||||
|
} else if (!config.wenet_ctc.model.empty()) {
|
||||||
|
filename = config.wenet_ctc.model;
|
||||||
} else {
|
} else {
|
||||||
SHERPA_ONNX_LOGE("Please specify a CTC model");
|
SHERPA_ONNX_LOGE("Please specify a CTC model");
|
||||||
exit(-1);
|
exit(-1);
|
||||||
@@ -101,6 +110,9 @@ std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create(
|
|||||||
case ModelType::kZipformerCtc:
|
case ModelType::kZipformerCtc:
|
||||||
return std::make_unique<OfflineZipformerCtcModel>(config);
|
return std::make_unique<OfflineZipformerCtcModel>(config);
|
||||||
break;
|
break;
|
||||||
|
case ModelType::kWenetCtc:
|
||||||
|
return std::make_unique<OfflineWenetCtcModel>(config);
|
||||||
|
break;
|
||||||
case ModelType::kUnkown:
|
case ModelType::kUnkown:
|
||||||
SHERPA_ONNX_LOGE("Unknown model type in offline CTC!");
|
SHERPA_ONNX_LOGE("Unknown model type in offline CTC!");
|
||||||
return nullptr;
|
return nullptr;
|
||||||
@@ -122,6 +134,8 @@ std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create(
|
|||||||
filename = config.tdnn.model;
|
filename = config.tdnn.model;
|
||||||
} else if (!config.zipformer_ctc.model.empty()) {
|
} else if (!config.zipformer_ctc.model.empty()) {
|
||||||
filename = config.zipformer_ctc.model;
|
filename = config.zipformer_ctc.model;
|
||||||
|
} else if (!config.wenet_ctc.model.empty()) {
|
||||||
|
filename = config.wenet_ctc.model;
|
||||||
} else {
|
} else {
|
||||||
SHERPA_ONNX_LOGE("Please specify a CTC model");
|
SHERPA_ONNX_LOGE("Please specify a CTC model");
|
||||||
exit(-1);
|
exit(-1);
|
||||||
@@ -143,6 +157,9 @@ std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create(
|
|||||||
case ModelType::kZipformerCtc:
|
case ModelType::kZipformerCtc:
|
||||||
return std::make_unique<OfflineZipformerCtcModel>(mgr, config);
|
return std::make_unique<OfflineZipformerCtcModel>(mgr, config);
|
||||||
break;
|
break;
|
||||||
|
case ModelType::kWenetCtc:
|
||||||
|
return std::make_unique<OfflineWenetCtcModel>(mgr, config);
|
||||||
|
break;
|
||||||
case ModelType::kUnkown:
|
case ModelType::kUnkown:
|
||||||
SHERPA_ONNX_LOGE("Unknown model type in offline CTC!");
|
SHERPA_ONNX_LOGE("Unknown model type in offline CTC!");
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
|||||||
@@ -63,6 +63,9 @@ class OfflineCtcModel {
|
|||||||
* for the features.
|
* for the features.
|
||||||
*/
|
*/
|
||||||
virtual std::string FeatureNormalizationMethod() const { return {}; }
|
virtual std::string FeatureNormalizationMethod() const { return {}; }
|
||||||
|
|
||||||
|
// Return true if the model supports batch size > 1
|
||||||
|
virtual bool SupportBatchProcessing() const { return true; }
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace sherpa_onnx
|
} // namespace sherpa_onnx
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ void OfflineModelConfig::Register(ParseOptions *po) {
|
|||||||
whisper.Register(po);
|
whisper.Register(po);
|
||||||
tdnn.Register(po);
|
tdnn.Register(po);
|
||||||
zipformer_ctc.Register(po);
|
zipformer_ctc.Register(po);
|
||||||
|
wenet_ctc.Register(po);
|
||||||
|
|
||||||
po->Register("tokens", &tokens, "Path to tokens.txt");
|
po->Register("tokens", &tokens, "Path to tokens.txt");
|
||||||
|
|
||||||
@@ -67,6 +68,10 @@ bool OfflineModelConfig::Validate() const {
|
|||||||
return zipformer_ctc.Validate();
|
return zipformer_ctc.Validate();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (!wenet_ctc.model.empty()) {
|
||||||
|
return wenet_ctc.Validate();
|
||||||
|
}
|
||||||
|
|
||||||
return transducer.Validate();
|
return transducer.Validate();
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -80,6 +85,7 @@ std::string OfflineModelConfig::ToString() const {
|
|||||||
os << "whisper=" << whisper.ToString() << ", ";
|
os << "whisper=" << whisper.ToString() << ", ";
|
||||||
os << "tdnn=" << tdnn.ToString() << ", ";
|
os << "tdnn=" << tdnn.ToString() << ", ";
|
||||||
os << "zipformer_ctc=" << zipformer_ctc.ToString() << ", ";
|
os << "zipformer_ctc=" << zipformer_ctc.ToString() << ", ";
|
||||||
|
os << "wenet_ctc=" << wenet_ctc.ToString() << ", ";
|
||||||
os << "tokens=\"" << tokens << "\", ";
|
os << "tokens=\"" << tokens << "\", ";
|
||||||
os << "num_threads=" << num_threads << ", ";
|
os << "num_threads=" << num_threads << ", ";
|
||||||
os << "debug=" << (debug ? "True" : "False") << ", ";
|
os << "debug=" << (debug ? "True" : "False") << ", ";
|
||||||
|
|||||||
@@ -10,6 +10,7 @@
|
|||||||
#include "sherpa-onnx/csrc/offline-paraformer-model-config.h"
|
#include "sherpa-onnx/csrc/offline-paraformer-model-config.h"
|
||||||
#include "sherpa-onnx/csrc/offline-tdnn-model-config.h"
|
#include "sherpa-onnx/csrc/offline-tdnn-model-config.h"
|
||||||
#include "sherpa-onnx/csrc/offline-transducer-model-config.h"
|
#include "sherpa-onnx/csrc/offline-transducer-model-config.h"
|
||||||
|
#include "sherpa-onnx/csrc/offline-wenet-ctc-model-config.h"
|
||||||
#include "sherpa-onnx/csrc/offline-whisper-model-config.h"
|
#include "sherpa-onnx/csrc/offline-whisper-model-config.h"
|
||||||
#include "sherpa-onnx/csrc/offline-zipformer-ctc-model-config.h"
|
#include "sherpa-onnx/csrc/offline-zipformer-ctc-model-config.h"
|
||||||
|
|
||||||
@@ -22,6 +23,7 @@ struct OfflineModelConfig {
|
|||||||
OfflineWhisperModelConfig whisper;
|
OfflineWhisperModelConfig whisper;
|
||||||
OfflineTdnnModelConfig tdnn;
|
OfflineTdnnModelConfig tdnn;
|
||||||
OfflineZipformerCtcModelConfig zipformer_ctc;
|
OfflineZipformerCtcModelConfig zipformer_ctc;
|
||||||
|
OfflineWenetCtcModelConfig wenet_ctc;
|
||||||
|
|
||||||
std::string tokens;
|
std::string tokens;
|
||||||
int32_t num_threads = 2;
|
int32_t num_threads = 2;
|
||||||
@@ -46,6 +48,7 @@ struct OfflineModelConfig {
|
|||||||
const OfflineWhisperModelConfig &whisper,
|
const OfflineWhisperModelConfig &whisper,
|
||||||
const OfflineTdnnModelConfig &tdnn,
|
const OfflineTdnnModelConfig &tdnn,
|
||||||
const OfflineZipformerCtcModelConfig &zipformer_ctc,
|
const OfflineZipformerCtcModelConfig &zipformer_ctc,
|
||||||
|
const OfflineWenetCtcModelConfig &wenet_ctc,
|
||||||
const std::string &tokens, int32_t num_threads, bool debug,
|
const std::string &tokens, int32_t num_threads, bool debug,
|
||||||
const std::string &provider, const std::string &model_type)
|
const std::string &provider, const std::string &model_type)
|
||||||
: transducer(transducer),
|
: transducer(transducer),
|
||||||
@@ -54,6 +57,7 @@ struct OfflineModelConfig {
|
|||||||
whisper(whisper),
|
whisper(whisper),
|
||||||
tdnn(tdnn),
|
tdnn(tdnn),
|
||||||
zipformer_ctc(zipformer_ctc),
|
zipformer_ctc(zipformer_ctc),
|
||||||
|
wenet_ctc(wenet_ctc),
|
||||||
tokens(tokens),
|
tokens(tokens),
|
||||||
num_threads(num_threads),
|
num_threads(num_threads),
|
||||||
debug(debug),
|
debug(debug),
|
||||||
|
|||||||
@@ -75,6 +75,12 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl {
|
|||||||
#endif
|
#endif
|
||||||
|
|
||||||
void Init() {
|
void Init() {
|
||||||
|
if (!config_.model_config.wenet_ctc.model.empty()) {
|
||||||
|
// WeNet CTC models assume input samples are in the range
|
||||||
|
// [-32768, 32767], so we set normalize_samples to false
|
||||||
|
config_.feat_config.normalize_samples = false;
|
||||||
|
}
|
||||||
|
|
||||||
config_.feat_config.nemo_normalize_type =
|
config_.feat_config.nemo_normalize_type =
|
||||||
model_->FeatureNormalizationMethod();
|
model_->FeatureNormalizationMethod();
|
||||||
|
|
||||||
@@ -85,10 +91,11 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl {
|
|||||||
config_.ctc_fst_decoder_config);
|
config_.ctc_fst_decoder_config);
|
||||||
} else if (config_.decoding_method == "greedy_search") {
|
} else if (config_.decoding_method == "greedy_search") {
|
||||||
if (!symbol_table_.contains("<blk>") &&
|
if (!symbol_table_.contains("<blk>") &&
|
||||||
!symbol_table_.contains("<eps>")) {
|
!symbol_table_.contains("<eps>") &&
|
||||||
|
!symbol_table_.contains("<blank>")) {
|
||||||
SHERPA_ONNX_LOGE(
|
SHERPA_ONNX_LOGE(
|
||||||
"We expect that tokens.txt contains "
|
"We expect that tokens.txt contains "
|
||||||
"the symbol <blk> or <eps> and its ID.");
|
"the symbol <blk> or <eps> or <blank> and its ID.");
|
||||||
exit(-1);
|
exit(-1);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -98,6 +105,9 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl {
|
|||||||
} else if (symbol_table_.contains("<eps>")) {
|
} else if (symbol_table_.contains("<eps>")) {
|
||||||
// for tdnn models of the yesno recipe from icefall
|
// for tdnn models of the yesno recipe from icefall
|
||||||
blank_id = symbol_table_["<eps>"];
|
blank_id = symbol_table_["<eps>"];
|
||||||
|
} else if (symbol_table_.contains("<blank>")) {
|
||||||
|
// for Wenet CTC models
|
||||||
|
blank_id = symbol_table_["<blank>"];
|
||||||
}
|
}
|
||||||
|
|
||||||
decoder_ = std::make_unique<OfflineCtcGreedySearchDecoder>(blank_id);
|
decoder_ = std::make_unique<OfflineCtcGreedySearchDecoder>(blank_id);
|
||||||
@@ -113,6 +123,15 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void DecodeStreams(OfflineStream **ss, int32_t n) const override {
|
void DecodeStreams(OfflineStream **ss, int32_t n) const override {
|
||||||
|
if (!model_->SupportBatchProcessing()) {
|
||||||
|
// If the model does not support batch process,
|
||||||
|
// we process each stream independently.
|
||||||
|
for (int32_t i = 0; i != n; ++i) {
|
||||||
|
DecodeStream(ss[i]);
|
||||||
|
}
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
auto memory_info =
|
auto memory_info =
|
||||||
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
|
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
|
||||||
|
|
||||||
@@ -164,6 +183,38 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
// Decode a single stream.
|
||||||
|
// Some models do not support batch size > 1, e.g., WeNet CTC models.
|
||||||
|
void DecodeStream(OfflineStream *s) const {
|
||||||
|
auto memory_info =
|
||||||
|
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
|
||||||
|
|
||||||
|
int32_t feat_dim = config_.feat_config.feature_dim;
|
||||||
|
std::vector<float> f = s->GetFrames();
|
||||||
|
|
||||||
|
int32_t num_frames = f.size() / feat_dim;
|
||||||
|
|
||||||
|
std::array<int64_t, 3> shape = {1, num_frames, feat_dim};
|
||||||
|
|
||||||
|
Ort::Value x = Ort::Value::CreateTensor(memory_info, f.data(), f.size(),
|
||||||
|
shape.data(), shape.size());
|
||||||
|
|
||||||
|
int64_t x_length_scalar = num_frames;
|
||||||
|
std::array<int64_t, 1> x_length_shape = {1};
|
||||||
|
Ort::Value x_length =
|
||||||
|
Ort::Value::CreateTensor(memory_info, &x_length_scalar, 1,
|
||||||
|
x_length_shape.data(), x_length_shape.size());
|
||||||
|
|
||||||
|
auto t = model_->Forward(std::move(x), std::move(x_length));
|
||||||
|
auto results = decoder_->Decode(std::move(t[0]), std::move(t[1]));
|
||||||
|
int32_t frame_shift_ms = 10;
|
||||||
|
|
||||||
|
auto r = Convert(results[0], symbol_table_, frame_shift_ms,
|
||||||
|
model_->SubsamplingFactor());
|
||||||
|
s->SetResult(r);
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
OfflineRecognizerConfig config_;
|
OfflineRecognizerConfig config_;
|
||||||
SymbolTable symbol_table_;
|
SymbolTable symbol_table_;
|
||||||
|
|||||||
@@ -26,7 +26,7 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
|
|||||||
} else if (model_type == "paraformer") {
|
} else if (model_type == "paraformer") {
|
||||||
return std::make_unique<OfflineRecognizerParaformerImpl>(config);
|
return std::make_unique<OfflineRecognizerParaformerImpl>(config);
|
||||||
} else if (model_type == "nemo_ctc" || model_type == "tdnn" ||
|
} else if (model_type == "nemo_ctc" || model_type == "tdnn" ||
|
||||||
model_type == "zipformer2_ctc") {
|
model_type == "zipformer2_ctc" || model_type == "wenet_ctc") {
|
||||||
return std::make_unique<OfflineRecognizerCtcImpl>(config);
|
return std::make_unique<OfflineRecognizerCtcImpl>(config);
|
||||||
} else if (model_type == "whisper") {
|
} else if (model_type == "whisper") {
|
||||||
return std::make_unique<OfflineRecognizerWhisperImpl>(config);
|
return std::make_unique<OfflineRecognizerWhisperImpl>(config);
|
||||||
@@ -51,6 +51,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
|
|||||||
model_filename = config.model_config.tdnn.model;
|
model_filename = config.model_config.tdnn.model;
|
||||||
} else if (!config.model_config.zipformer_ctc.model.empty()) {
|
} else if (!config.model_config.zipformer_ctc.model.empty()) {
|
||||||
model_filename = config.model_config.zipformer_ctc.model;
|
model_filename = config.model_config.zipformer_ctc.model;
|
||||||
|
} else if (!config.model_config.wenet_ctc.model.empty()) {
|
||||||
|
model_filename = config.model_config.wenet_ctc.model;
|
||||||
} else if (!config.model_config.whisper.encoder.empty()) {
|
} else if (!config.model_config.whisper.encoder.empty()) {
|
||||||
model_filename = config.model_config.whisper.encoder;
|
model_filename = config.model_config.whisper.encoder;
|
||||||
} else {
|
} else {
|
||||||
@@ -99,6 +101,10 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
|
|||||||
"https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/"
|
"https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/"
|
||||||
"zipformer/export-onnx-ctc.py"
|
"zipformer/export-onnx-ctc.py"
|
||||||
"\n"
|
"\n"
|
||||||
|
"(6) CTC models from WeNet"
|
||||||
|
"\n "
|
||||||
|
"https://github.com/k2-fsa/sherpa-onnx/blob/master/scripts/wenet/run.sh"
|
||||||
|
"\n"
|
||||||
"\n");
|
"\n");
|
||||||
exit(-1);
|
exit(-1);
|
||||||
}
|
}
|
||||||
@@ -114,7 +120,7 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (model_type == "EncDecCTCModelBPE" || model_type == "tdnn" ||
|
if (model_type == "EncDecCTCModelBPE" || model_type == "tdnn" ||
|
||||||
model_type == "zipformer2_ctc") {
|
model_type == "zipformer2_ctc" || model_type == "wenet_ctc") {
|
||||||
return std::make_unique<OfflineRecognizerCtcImpl>(config);
|
return std::make_unique<OfflineRecognizerCtcImpl>(config);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -130,7 +136,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
|
|||||||
" - EncDecCTCModelBPE models from NeMo\n"
|
" - EncDecCTCModelBPE models from NeMo\n"
|
||||||
" - Whisper models\n"
|
" - Whisper models\n"
|
||||||
" - Tdnn models\n"
|
" - Tdnn models\n"
|
||||||
" - Zipformer CTC models\n",
|
" - Zipformer CTC models\n"
|
||||||
|
" - WeNet CTC models\n",
|
||||||
model_type.c_str());
|
model_type.c_str());
|
||||||
|
|
||||||
exit(-1);
|
exit(-1);
|
||||||
@@ -146,7 +153,7 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
|
|||||||
} else if (model_type == "paraformer") {
|
} else if (model_type == "paraformer") {
|
||||||
return std::make_unique<OfflineRecognizerParaformerImpl>(mgr, config);
|
return std::make_unique<OfflineRecognizerParaformerImpl>(mgr, config);
|
||||||
} else if (model_type == "nemo_ctc" || model_type == "tdnn" ||
|
} else if (model_type == "nemo_ctc" || model_type == "tdnn" ||
|
||||||
model_type == "zipformer2_ctc") {
|
model_type == "zipformer2_ctc" || model_type == "wenet_ctc") {
|
||||||
return std::make_unique<OfflineRecognizerCtcImpl>(mgr, config);
|
return std::make_unique<OfflineRecognizerCtcImpl>(mgr, config);
|
||||||
} else if (model_type == "whisper") {
|
} else if (model_type == "whisper") {
|
||||||
return std::make_unique<OfflineRecognizerWhisperImpl>(mgr, config);
|
return std::make_unique<OfflineRecognizerWhisperImpl>(mgr, config);
|
||||||
@@ -171,6 +178,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
|
|||||||
model_filename = config.model_config.tdnn.model;
|
model_filename = config.model_config.tdnn.model;
|
||||||
} else if (!config.model_config.zipformer_ctc.model.empty()) {
|
} else if (!config.model_config.zipformer_ctc.model.empty()) {
|
||||||
model_filename = config.model_config.zipformer_ctc.model;
|
model_filename = config.model_config.zipformer_ctc.model;
|
||||||
|
} else if (!config.model_config.wenet_ctc.model.empty()) {
|
||||||
|
model_filename = config.model_config.wenet_ctc.model;
|
||||||
} else if (!config.model_config.whisper.encoder.empty()) {
|
} else if (!config.model_config.whisper.encoder.empty()) {
|
||||||
model_filename = config.model_config.whisper.encoder;
|
model_filename = config.model_config.whisper.encoder;
|
||||||
} else {
|
} else {
|
||||||
@@ -219,6 +228,10 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
|
|||||||
"https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/"
|
"https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/"
|
||||||
"zipformer/export-onnx-ctc.py"
|
"zipformer/export-onnx-ctc.py"
|
||||||
"\n"
|
"\n"
|
||||||
|
"(6) CTC models from WeNet"
|
||||||
|
"\n "
|
||||||
|
"https://github.com/k2-fsa/sherpa-onnx/blob/master/scripts/wenet/run.sh"
|
||||||
|
"\n"
|
||||||
"\n");
|
"\n");
|
||||||
exit(-1);
|
exit(-1);
|
||||||
}
|
}
|
||||||
@@ -234,7 +247,7 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (model_type == "EncDecCTCModelBPE" || model_type == "tdnn" ||
|
if (model_type == "EncDecCTCModelBPE" || model_type == "tdnn" ||
|
||||||
model_type == "zipformer2_ctc") {
|
model_type == "zipformer2_ctc" || model_type == "wenet_ctc") {
|
||||||
return std::make_unique<OfflineRecognizerCtcImpl>(mgr, config);
|
return std::make_unique<OfflineRecognizerCtcImpl>(mgr, config);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -250,7 +263,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
|
|||||||
" - EncDecCTCModelBPE models from NeMo\n"
|
" - EncDecCTCModelBPE models from NeMo\n"
|
||||||
" - Whisper models\n"
|
" - Whisper models\n"
|
||||||
" - Tdnn models\n"
|
" - Tdnn models\n"
|
||||||
" - Zipformer CTC models\n",
|
" - Zipformer CTC models\n"
|
||||||
|
" - WeNet CTC models\n",
|
||||||
model_type.c_str());
|
model_type.c_str());
|
||||||
|
|
||||||
exit(-1);
|
exit(-1);
|
||||||
|
|||||||
@@ -40,7 +40,8 @@ struct OfflineFeatureExtractorConfig {
|
|||||||
// Feature dimension
|
// Feature dimension
|
||||||
int32_t feature_dim = 80;
|
int32_t feature_dim = 80;
|
||||||
|
|
||||||
// Set internally by some models, e.g., paraformer sets it to false.
|
// Set internally by some models, e.g., paraformer and wenet CTC models set
|
||||||
|
// it to false.
|
||||||
// This parameter is not exposed to users from the commandline
|
// This parameter is not exposed to users from the commandline
|
||||||
// If true, the feature extractor expects inputs to be normalized to
|
// If true, the feature extractor expects inputs to be normalized to
|
||||||
// the range [-1, 1].
|
// the range [-1, 1].
|
||||||
|
|||||||
37
sherpa-onnx/csrc/offline-wenet-ctc-model-config.cc
Normal file
37
sherpa-onnx/csrc/offline-wenet-ctc-model-config.cc
Normal file
@@ -0,0 +1,37 @@
|
|||||||
|
// sherpa-onnx/csrc/offline-wenet-ctc-model-config.cc
|
||||||
|
//
|
||||||
|
// Copyright (c) 2023 Xiaomi Corporation
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/offline-wenet-ctc-model-config.h"
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/file-utils.h"
|
||||||
|
#include "sherpa-onnx/csrc/macros.h"
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
void OfflineWenetCtcModelConfig::Register(ParseOptions *po) {
|
||||||
|
po->Register(
|
||||||
|
"wenet-ctc-model", &model,
|
||||||
|
"Path to model.onnx from WeNet. Please see "
|
||||||
|
"https://github.com/k2-fsa/sherpa-onnx/pull/425 for available models");
|
||||||
|
}
|
||||||
|
|
||||||
|
bool OfflineWenetCtcModelConfig::Validate() const {
|
||||||
|
if (!FileExists(model)) {
|
||||||
|
SHERPA_ONNX_LOGE("WeNet model: %s does not exist", model.c_str());
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string OfflineWenetCtcModelConfig::ToString() const {
|
||||||
|
std::ostringstream os;
|
||||||
|
|
||||||
|
os << "OfflineWenetCtcModelConfig(";
|
||||||
|
os << "model=\"" << model << "\")";
|
||||||
|
|
||||||
|
return os.str();
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace sherpa_onnx
|
||||||
28
sherpa-onnx/csrc/offline-wenet-ctc-model-config.h
Normal file
28
sherpa-onnx/csrc/offline-wenet-ctc-model-config.h
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
// sherpa-onnx/csrc/offline-wenet-ctc-model-config.h
|
||||||
|
//
|
||||||
|
// Copyright (c) 2023 Xiaomi Corporation
|
||||||
|
#ifndef SHERPA_ONNX_CSRC_OFFLINE_WENET_CTC_MODEL_CONFIG_H_
|
||||||
|
#define SHERPA_ONNX_CSRC_OFFLINE_WENET_CTC_MODEL_CONFIG_H_
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/parse-options.h"
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
struct OfflineWenetCtcModelConfig {
|
||||||
|
std::string model;
|
||||||
|
|
||||||
|
OfflineWenetCtcModelConfig() = default;
|
||||||
|
explicit OfflineWenetCtcModelConfig(const std::string &model)
|
||||||
|
: model(model) {}
|
||||||
|
|
||||||
|
void Register(ParseOptions *po);
|
||||||
|
bool Validate() const;
|
||||||
|
|
||||||
|
std::string ToString() const;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace sherpa_onnx
|
||||||
|
|
||||||
|
#endif // SHERPA_ONNX_CSRC_OFFLINE_WENET_CTC_MODEL_CONFIG_H_
|
||||||
118
sherpa-onnx/csrc/offline-wenet-ctc-model.cc
Normal file
118
sherpa-onnx/csrc/offline-wenet-ctc-model.cc
Normal file
@@ -0,0 +1,118 @@
|
|||||||
|
// sherpa-onnx/csrc/offline-wenet-ctc-model.cc
|
||||||
|
//
|
||||||
|
// Copyright (c) 2023 Xiaomi Corporation
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/offline-wenet-ctc-model.h"
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/macros.h"
|
||||||
|
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||||
|
#include "sherpa-onnx/csrc/session.h"
|
||||||
|
#include "sherpa-onnx/csrc/text-utils.h"
|
||||||
|
#include "sherpa-onnx/csrc/transpose.h"
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
class OfflineWenetCtcModel::Impl {
|
||||||
|
public:
|
||||||
|
explicit Impl(const OfflineModelConfig &config)
|
||||||
|
: config_(config),
|
||||||
|
env_(ORT_LOGGING_LEVEL_ERROR),
|
||||||
|
sess_opts_(GetSessionOptions(config)),
|
||||||
|
allocator_{} {
|
||||||
|
auto buf = ReadFile(config_.wenet_ctc.model);
|
||||||
|
Init(buf.data(), buf.size());
|
||||||
|
}
|
||||||
|
|
||||||
|
#if __ANDROID_API__ >= 9
|
||||||
|
Impl(AAssetManager *mgr, const OfflineModelConfig &config)
|
||||||
|
: config_(config),
|
||||||
|
env_(ORT_LOGGING_LEVEL_ERROR),
|
||||||
|
sess_opts_(GetSessionOptions(config)),
|
||||||
|
allocator_{} {
|
||||||
|
auto buf = ReadFile(mgr, config_.wenet_ctc.model);
|
||||||
|
Init(buf.data(), buf.size());
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
std::vector<Ort::Value> Forward(Ort::Value features,
|
||||||
|
Ort::Value features_length) {
|
||||||
|
std::array<Ort::Value, 2> inputs = {std::move(features),
|
||||||
|
std::move(features_length)};
|
||||||
|
|
||||||
|
return sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(),
|
||||||
|
output_names_ptr_.data(), output_names_ptr_.size());
|
||||||
|
}
|
||||||
|
|
||||||
|
int32_t VocabSize() const { return vocab_size_; }
|
||||||
|
|
||||||
|
int32_t SubsamplingFactor() const { return subsampling_factor_; }
|
||||||
|
|
||||||
|
OrtAllocator *Allocator() const { return allocator_; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
void Init(void *model_data, size_t model_data_length) {
|
||||||
|
sess_ = std::make_unique<Ort::Session>(env_, model_data, model_data_length,
|
||||||
|
sess_opts_);
|
||||||
|
|
||||||
|
GetInputNames(sess_.get(), &input_names_, &input_names_ptr_);
|
||||||
|
|
||||||
|
GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_);
|
||||||
|
|
||||||
|
// get meta data
|
||||||
|
Ort::ModelMetadata meta_data = sess_->GetModelMetadata();
|
||||||
|
if (config_.debug) {
|
||||||
|
std::ostringstream os;
|
||||||
|
PrintModelMetadata(os, meta_data);
|
||||||
|
SHERPA_ONNX_LOGE("%s\n", os.str().c_str());
|
||||||
|
}
|
||||||
|
|
||||||
|
Ort::AllocatorWithDefaultOptions allocator; // used in the macro below
|
||||||
|
SHERPA_ONNX_READ_META_DATA(vocab_size_, "vocab_size");
|
||||||
|
SHERPA_ONNX_READ_META_DATA(subsampling_factor_, "subsampling_factor");
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
OfflineModelConfig config_;
|
||||||
|
Ort::Env env_;
|
||||||
|
Ort::SessionOptions sess_opts_;
|
||||||
|
Ort::AllocatorWithDefaultOptions allocator_;
|
||||||
|
|
||||||
|
std::unique_ptr<Ort::Session> sess_;
|
||||||
|
|
||||||
|
std::vector<std::string> input_names_;
|
||||||
|
std::vector<const char *> input_names_ptr_;
|
||||||
|
|
||||||
|
std::vector<std::string> output_names_;
|
||||||
|
std::vector<const char *> output_names_ptr_;
|
||||||
|
|
||||||
|
int32_t vocab_size_ = 0;
|
||||||
|
int32_t subsampling_factor_ = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
OfflineWenetCtcModel::OfflineWenetCtcModel(const OfflineModelConfig &config)
|
||||||
|
: impl_(std::make_unique<Impl>(config)) {}
|
||||||
|
|
||||||
|
#if __ANDROID_API__ >= 9
|
||||||
|
OfflineWenetCtcModel::OfflineWenetCtcModel(AAssetManager *mgr,
|
||||||
|
const OfflineModelConfig &config)
|
||||||
|
: impl_(std::make_unique<Impl>(mgr, config)) {}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
OfflineWenetCtcModel::~OfflineWenetCtcModel() = default;
|
||||||
|
|
||||||
|
std::vector<Ort::Value> OfflineWenetCtcModel::Forward(
|
||||||
|
Ort::Value features, Ort::Value features_length) {
|
||||||
|
return impl_->Forward(std::move(features), std::move(features_length));
|
||||||
|
}
|
||||||
|
|
||||||
|
int32_t OfflineWenetCtcModel::VocabSize() const { return impl_->VocabSize(); }
|
||||||
|
|
||||||
|
int32_t OfflineWenetCtcModel::SubsamplingFactor() const {
|
||||||
|
return impl_->SubsamplingFactor();
|
||||||
|
}
|
||||||
|
|
||||||
|
OrtAllocator *OfflineWenetCtcModel::Allocator() const {
|
||||||
|
return impl_->Allocator();
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace sherpa_onnx
|
||||||
79
sherpa-onnx/csrc/offline-wenet-ctc-model.h
Normal file
79
sherpa-onnx/csrc/offline-wenet-ctc-model.h
Normal file
@@ -0,0 +1,79 @@
|
|||||||
|
// sherpa-onnx/csrc/offline-wenet-ctc-model.h
|
||||||
|
//
|
||||||
|
// Copyright (c) 2023 Xiaomi Corporation
|
||||||
|
#ifndef SHERPA_ONNX_CSRC_OFFLINE_WENET_CTC_MODEL_H_
|
||||||
|
#define SHERPA_ONNX_CSRC_OFFLINE_WENET_CTC_MODEL_H_
|
||||||
|
#include <memory>
|
||||||
|
#include <string>
|
||||||
|
#include <utility>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#if __ANDROID_API__ >= 9
|
||||||
|
#include "android/asset_manager.h"
|
||||||
|
#include "android/asset_manager_jni.h"
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||||
|
#include "sherpa-onnx/csrc/offline-ctc-model.h"
|
||||||
|
#include "sherpa-onnx/csrc/offline-model-config.h"
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
/** This class implements the CTC model from WeNet.
|
||||||
|
*
|
||||||
|
* See
|
||||||
|
* https://github.com/k2-fsa/sherpa-onnx/blob/master/scripts/wenet/export-onnx.py
|
||||||
|
* https://github.com/k2-fsa/sherpa-onnx/blob/master/scripts/wenet/test-onnx.py
|
||||||
|
* https://github.com/k2-fsa/sherpa-onnx/blob/master/scripts/wenet/run.sh
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
class OfflineWenetCtcModel : public OfflineCtcModel {
|
||||||
|
public:
|
||||||
|
explicit OfflineWenetCtcModel(const OfflineModelConfig &config);
|
||||||
|
|
||||||
|
#if __ANDROID_API__ >= 9
|
||||||
|
OfflineWenetCtcModel(AAssetManager *mgr, const OfflineModelConfig &config);
|
||||||
|
#endif
|
||||||
|
|
||||||
|
~OfflineWenetCtcModel() override;
|
||||||
|
|
||||||
|
/** Run the forward method of the model.
|
||||||
|
*
|
||||||
|
* @param features A tensor of shape (N, T, C).
|
||||||
|
* @param features_length A 1-D tensor of shape (N,) containing number of
|
||||||
|
* valid frames in `features` before padding.
|
||||||
|
* Its dtype is int64_t.
|
||||||
|
*
|
||||||
|
* @return Return a vector containing:
|
||||||
|
* - log_probs: A 3-D tensor of shape (N, T', vocab_size).
|
||||||
|
* - log_probs_length A 1-D tensor of shape (N,). Its dtype is int64_t
|
||||||
|
*/
|
||||||
|
std::vector<Ort::Value> Forward(Ort::Value features,
|
||||||
|
Ort::Value features_length) override;
|
||||||
|
|
||||||
|
/** Return the vocabulary size of the model
|
||||||
|
*/
|
||||||
|
int32_t VocabSize() const override;
|
||||||
|
|
||||||
|
/** SubsamplingFactor of the model
|
||||||
|
*
|
||||||
|
* For Citrinet, the subsampling factor is usually 4.
|
||||||
|
* For Conformer CTC, the subsampling factor is usually 8.
|
||||||
|
*/
|
||||||
|
int32_t SubsamplingFactor() const override;
|
||||||
|
|
||||||
|
/** Return an allocator for allocating memory
|
||||||
|
*/
|
||||||
|
OrtAllocator *Allocator() const override;
|
||||||
|
|
||||||
|
// WeNet CTC models do not support batch size > 1
|
||||||
|
bool SupportBatchProcessing() const override { return false; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
class Impl;
|
||||||
|
std::unique_ptr<Impl> impl_;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace sherpa_onnx
|
||||||
|
|
||||||
|
#endif // SHERPA_ONNX_CSRC_OFFLINE_WENET_CTC_MODEL_H_
|
||||||
@@ -17,6 +17,7 @@ pybind11_add_module(_sherpa_onnx
|
|||||||
offline-tts-model-config.cc
|
offline-tts-model-config.cc
|
||||||
offline-tts-vits-model-config.cc
|
offline-tts-vits-model-config.cc
|
||||||
offline-tts.cc
|
offline-tts.cc
|
||||||
|
offline-wenet-ctc-model-config.cc
|
||||||
offline-whisper-model-config.cc
|
offline-whisper-model-config.cc
|
||||||
offline-zipformer-ctc-model-config.cc
|
offline-zipformer-ctc-model-config.cc
|
||||||
online-lm-config.cc
|
online-lm-config.cc
|
||||||
|
|||||||
@@ -12,6 +12,7 @@
|
|||||||
#include "sherpa-onnx/python/csrc/offline-paraformer-model-config.h"
|
#include "sherpa-onnx/python/csrc/offline-paraformer-model-config.h"
|
||||||
#include "sherpa-onnx/python/csrc/offline-tdnn-model-config.h"
|
#include "sherpa-onnx/python/csrc/offline-tdnn-model-config.h"
|
||||||
#include "sherpa-onnx/python/csrc/offline-transducer-model-config.h"
|
#include "sherpa-onnx/python/csrc/offline-transducer-model-config.h"
|
||||||
|
#include "sherpa-onnx/python/csrc/offline-wenet-ctc-model-config.h"
|
||||||
#include "sherpa-onnx/python/csrc/offline-whisper-model-config.h"
|
#include "sherpa-onnx/python/csrc/offline-whisper-model-config.h"
|
||||||
#include "sherpa-onnx/python/csrc/offline-zipformer-ctc-model-config.h"
|
#include "sherpa-onnx/python/csrc/offline-zipformer-ctc-model-config.h"
|
||||||
|
|
||||||
@@ -24,6 +25,7 @@ void PybindOfflineModelConfig(py::module *m) {
|
|||||||
PybindOfflineWhisperModelConfig(m);
|
PybindOfflineWhisperModelConfig(m);
|
||||||
PybindOfflineTdnnModelConfig(m);
|
PybindOfflineTdnnModelConfig(m);
|
||||||
PybindOfflineZipformerCtcModelConfig(m);
|
PybindOfflineZipformerCtcModelConfig(m);
|
||||||
|
PybindOfflineWenetCtcModelConfig(m);
|
||||||
|
|
||||||
using PyClass = OfflineModelConfig;
|
using PyClass = OfflineModelConfig;
|
||||||
py::class_<PyClass>(*m, "OfflineModelConfig")
|
py::class_<PyClass>(*m, "OfflineModelConfig")
|
||||||
@@ -32,7 +34,8 @@ void PybindOfflineModelConfig(py::module *m) {
|
|||||||
const OfflineNemoEncDecCtcModelConfig &,
|
const OfflineNemoEncDecCtcModelConfig &,
|
||||||
const OfflineWhisperModelConfig &,
|
const OfflineWhisperModelConfig &,
|
||||||
const OfflineTdnnModelConfig &,
|
const OfflineTdnnModelConfig &,
|
||||||
const OfflineZipformerCtcModelConfig &, const std::string &,
|
const OfflineZipformerCtcModelConfig &,
|
||||||
|
const OfflineWenetCtcModelConfig &, const std::string &,
|
||||||
int32_t, bool, const std::string &, const std::string &>(),
|
int32_t, bool, const std::string &, const std::string &>(),
|
||||||
py::arg("transducer") = OfflineTransducerModelConfig(),
|
py::arg("transducer") = OfflineTransducerModelConfig(),
|
||||||
py::arg("paraformer") = OfflineParaformerModelConfig(),
|
py::arg("paraformer") = OfflineParaformerModelConfig(),
|
||||||
@@ -40,6 +43,7 @@ void PybindOfflineModelConfig(py::module *m) {
|
|||||||
py::arg("whisper") = OfflineWhisperModelConfig(),
|
py::arg("whisper") = OfflineWhisperModelConfig(),
|
||||||
py::arg("tdnn") = OfflineTdnnModelConfig(),
|
py::arg("tdnn") = OfflineTdnnModelConfig(),
|
||||||
py::arg("zipformer_ctc") = OfflineZipformerCtcModelConfig(),
|
py::arg("zipformer_ctc") = OfflineZipformerCtcModelConfig(),
|
||||||
|
py::arg("wenet_ctc") = OfflineWenetCtcModelConfig(),
|
||||||
py::arg("tokens"), py::arg("num_threads"), py::arg("debug") = false,
|
py::arg("tokens"), py::arg("num_threads"), py::arg("debug") = false,
|
||||||
py::arg("provider") = "cpu", py::arg("model_type") = "")
|
py::arg("provider") = "cpu", py::arg("model_type") = "")
|
||||||
.def_readwrite("transducer", &PyClass::transducer)
|
.def_readwrite("transducer", &PyClass::transducer)
|
||||||
@@ -48,6 +52,7 @@ void PybindOfflineModelConfig(py::module *m) {
|
|||||||
.def_readwrite("whisper", &PyClass::whisper)
|
.def_readwrite("whisper", &PyClass::whisper)
|
||||||
.def_readwrite("tdnn", &PyClass::tdnn)
|
.def_readwrite("tdnn", &PyClass::tdnn)
|
||||||
.def_readwrite("zipformer_ctc", &PyClass::zipformer_ctc)
|
.def_readwrite("zipformer_ctc", &PyClass::zipformer_ctc)
|
||||||
|
.def_readwrite("wenet_ctc", &PyClass::wenet_ctc)
|
||||||
.def_readwrite("tokens", &PyClass::tokens)
|
.def_readwrite("tokens", &PyClass::tokens)
|
||||||
.def_readwrite("num_threads", &PyClass::num_threads)
|
.def_readwrite("num_threads", &PyClass::num_threads)
|
||||||
.def_readwrite("debug", &PyClass::debug)
|
.def_readwrite("debug", &PyClass::debug)
|
||||||
|
|||||||
22
sherpa-onnx/python/csrc/offline-wenet-ctc-model-config.cc
Normal file
22
sherpa-onnx/python/csrc/offline-wenet-ctc-model-config.cc
Normal file
@@ -0,0 +1,22 @@
|
|||||||
|
// sherpa-onnx/python/csrc/offline-wenet-model-config.cc
|
||||||
|
//
|
||||||
|
// Copyright (c) 2023 Xiaomi Corporation
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/offline-wenet-ctc-model-config.h"
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "sherpa-onnx/python/csrc/offline-wenet-ctc-model-config.h"
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
void PybindOfflineWenetCtcModelConfig(py::module *m) {
|
||||||
|
using PyClass = OfflineWenetCtcModelConfig;
|
||||||
|
py::class_<PyClass>(*m, "OfflineWenetCtcModelConfig")
|
||||||
|
.def(py::init<const std::string &>(), py::arg("model"))
|
||||||
|
.def_readwrite("model", &PyClass::model)
|
||||||
|
.def("__str__", &PyClass::ToString);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace sherpa_onnx
|
||||||
16
sherpa-onnx/python/csrc/offline-wenet-ctc-model-config.h
Normal file
16
sherpa-onnx/python/csrc/offline-wenet-ctc-model-config.h
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
// sherpa-onnx/python/csrc/offline-wenet-model-config.h
|
||||||
|
//
|
||||||
|
// Copyright (c) 2023 Xiaomi Corporation
|
||||||
|
|
||||||
|
#ifndef SHERPA_ONNX_PYTHON_CSRC_OFFLINE_WENET_CTC_MODEL_CONFIG_H_
|
||||||
|
#define SHERPA_ONNX_PYTHON_CSRC_OFFLINE_WENET_CTC_MODEL_CONFIG_H_
|
||||||
|
|
||||||
|
#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
void PybindOfflineWenetCtcModelConfig(py::module *m);
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif // SHERPA_ONNX_PYTHON_CSRC_OFFLINE_WENET_CTC_MODEL_CONFIG_H_
|
||||||
Reference in New Issue
Block a user