Support non-streaming WeNet CTC models. (#426)
This commit is contained in:
@@ -41,6 +41,8 @@ set(sources
|
||||
offline-transducer-model-config.cc
|
||||
offline-transducer-model.cc
|
||||
offline-transducer-modified-beam-search-decoder.cc
|
||||
offline-wenet-ctc-model-config.cc
|
||||
offline-wenet-ctc-model.cc
|
||||
offline-whisper-greedy-search-decoder.cc
|
||||
offline-whisper-model-config.cc
|
||||
offline-whisper-model.cc
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model.h"
|
||||
#include "sherpa-onnx/csrc/offline-tdnn-ctc-model.h"
|
||||
#include "sherpa-onnx/csrc/offline-wenet-ctc-model.h"
|
||||
#include "sherpa-onnx/csrc/offline-zipformer-ctc-model.h"
|
||||
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||
|
||||
@@ -21,10 +22,11 @@ enum class ModelType {
|
||||
kEncDecCTCModelBPE,
|
||||
kTdnn,
|
||||
kZipformerCtc,
|
||||
kWenetCtc,
|
||||
kUnkown,
|
||||
};
|
||||
|
||||
}
|
||||
} // namespace
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
@@ -52,6 +54,9 @@ static ModelType GetModelType(char *model_data, size_t model_data_length,
|
||||
"If you are using models from NeMo, please refer to\n"
|
||||
"https://huggingface.co/csukuangfj/"
|
||||
"sherpa-onnx-nemo-ctc-en-citrinet-512/blob/main/add-model-metadata.py"
|
||||
"If you are using models from WeNet, please refer to\n"
|
||||
"https://github.com/k2-fsa/sherpa-onnx/blob/master/scripts/wenet/"
|
||||
"run.sh\n"
|
||||
"\n"
|
||||
"for how to add metadta to model.onnx\n");
|
||||
return ModelType::kUnkown;
|
||||
@@ -63,6 +68,8 @@ static ModelType GetModelType(char *model_data, size_t model_data_length,
|
||||
return ModelType::kTdnn;
|
||||
} else if (model_type.get() == std::string("zipformer2_ctc")) {
|
||||
return ModelType::kZipformerCtc;
|
||||
} else if (model_type.get() == std::string("wenet_ctc")) {
|
||||
return ModelType::kWenetCtc;
|
||||
} else {
|
||||
SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.get());
|
||||
return ModelType::kUnkown;
|
||||
@@ -80,6 +87,8 @@ std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create(
|
||||
filename = config.tdnn.model;
|
||||
} else if (!config.zipformer_ctc.model.empty()) {
|
||||
filename = config.zipformer_ctc.model;
|
||||
} else if (!config.wenet_ctc.model.empty()) {
|
||||
filename = config.wenet_ctc.model;
|
||||
} else {
|
||||
SHERPA_ONNX_LOGE("Please specify a CTC model");
|
||||
exit(-1);
|
||||
@@ -101,6 +110,9 @@ std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create(
|
||||
case ModelType::kZipformerCtc:
|
||||
return std::make_unique<OfflineZipformerCtcModel>(config);
|
||||
break;
|
||||
case ModelType::kWenetCtc:
|
||||
return std::make_unique<OfflineWenetCtcModel>(config);
|
||||
break;
|
||||
case ModelType::kUnkown:
|
||||
SHERPA_ONNX_LOGE("Unknown model type in offline CTC!");
|
||||
return nullptr;
|
||||
@@ -122,6 +134,8 @@ std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create(
|
||||
filename = config.tdnn.model;
|
||||
} else if (!config.zipformer_ctc.model.empty()) {
|
||||
filename = config.zipformer_ctc.model;
|
||||
} else if (!config.wenet_ctc.model.empty()) {
|
||||
filename = config.wenet_ctc.model;
|
||||
} else {
|
||||
SHERPA_ONNX_LOGE("Please specify a CTC model");
|
||||
exit(-1);
|
||||
@@ -143,6 +157,9 @@ std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create(
|
||||
case ModelType::kZipformerCtc:
|
||||
return std::make_unique<OfflineZipformerCtcModel>(mgr, config);
|
||||
break;
|
||||
case ModelType::kWenetCtc:
|
||||
return std::make_unique<OfflineWenetCtcModel>(mgr, config);
|
||||
break;
|
||||
case ModelType::kUnkown:
|
||||
SHERPA_ONNX_LOGE("Unknown model type in offline CTC!");
|
||||
return nullptr;
|
||||
|
||||
@@ -63,6 +63,9 @@ class OfflineCtcModel {
|
||||
* for the features.
|
||||
*/
|
||||
virtual std::string FeatureNormalizationMethod() const { return {}; }
|
||||
|
||||
// Return true if the model supports batch size > 1
|
||||
virtual bool SupportBatchProcessing() const { return true; }
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
@@ -17,6 +17,7 @@ void OfflineModelConfig::Register(ParseOptions *po) {
|
||||
whisper.Register(po);
|
||||
tdnn.Register(po);
|
||||
zipformer_ctc.Register(po);
|
||||
wenet_ctc.Register(po);
|
||||
|
||||
po->Register("tokens", &tokens, "Path to tokens.txt");
|
||||
|
||||
@@ -67,6 +68,10 @@ bool OfflineModelConfig::Validate() const {
|
||||
return zipformer_ctc.Validate();
|
||||
}
|
||||
|
||||
if (!wenet_ctc.model.empty()) {
|
||||
return wenet_ctc.Validate();
|
||||
}
|
||||
|
||||
return transducer.Validate();
|
||||
}
|
||||
|
||||
@@ -80,6 +85,7 @@ std::string OfflineModelConfig::ToString() const {
|
||||
os << "whisper=" << whisper.ToString() << ", ";
|
||||
os << "tdnn=" << tdnn.ToString() << ", ";
|
||||
os << "zipformer_ctc=" << zipformer_ctc.ToString() << ", ";
|
||||
os << "wenet_ctc=" << wenet_ctc.ToString() << ", ";
|
||||
os << "tokens=\"" << tokens << "\", ";
|
||||
os << "num_threads=" << num_threads << ", ";
|
||||
os << "debug=" << (debug ? "True" : "False") << ", ";
|
||||
|
||||
@@ -10,6 +10,7 @@
|
||||
#include "sherpa-onnx/csrc/offline-paraformer-model-config.h"
|
||||
#include "sherpa-onnx/csrc/offline-tdnn-model-config.h"
|
||||
#include "sherpa-onnx/csrc/offline-transducer-model-config.h"
|
||||
#include "sherpa-onnx/csrc/offline-wenet-ctc-model-config.h"
|
||||
#include "sherpa-onnx/csrc/offline-whisper-model-config.h"
|
||||
#include "sherpa-onnx/csrc/offline-zipformer-ctc-model-config.h"
|
||||
|
||||
@@ -22,6 +23,7 @@ struct OfflineModelConfig {
|
||||
OfflineWhisperModelConfig whisper;
|
||||
OfflineTdnnModelConfig tdnn;
|
||||
OfflineZipformerCtcModelConfig zipformer_ctc;
|
||||
OfflineWenetCtcModelConfig wenet_ctc;
|
||||
|
||||
std::string tokens;
|
||||
int32_t num_threads = 2;
|
||||
@@ -46,6 +48,7 @@ struct OfflineModelConfig {
|
||||
const OfflineWhisperModelConfig &whisper,
|
||||
const OfflineTdnnModelConfig &tdnn,
|
||||
const OfflineZipformerCtcModelConfig &zipformer_ctc,
|
||||
const OfflineWenetCtcModelConfig &wenet_ctc,
|
||||
const std::string &tokens, int32_t num_threads, bool debug,
|
||||
const std::string &provider, const std::string &model_type)
|
||||
: transducer(transducer),
|
||||
@@ -54,6 +57,7 @@ struct OfflineModelConfig {
|
||||
whisper(whisper),
|
||||
tdnn(tdnn),
|
||||
zipformer_ctc(zipformer_ctc),
|
||||
wenet_ctc(wenet_ctc),
|
||||
tokens(tokens),
|
||||
num_threads(num_threads),
|
||||
debug(debug),
|
||||
|
||||
@@ -75,6 +75,12 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl {
|
||||
#endif
|
||||
|
||||
void Init() {
|
||||
if (!config_.model_config.wenet_ctc.model.empty()) {
|
||||
// WeNet CTC models assume input samples are in the range
|
||||
// [-32768, 32767], so we set normalize_samples to false
|
||||
config_.feat_config.normalize_samples = false;
|
||||
}
|
||||
|
||||
config_.feat_config.nemo_normalize_type =
|
||||
model_->FeatureNormalizationMethod();
|
||||
|
||||
@@ -85,10 +91,11 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl {
|
||||
config_.ctc_fst_decoder_config);
|
||||
} else if (config_.decoding_method == "greedy_search") {
|
||||
if (!symbol_table_.contains("<blk>") &&
|
||||
!symbol_table_.contains("<eps>")) {
|
||||
!symbol_table_.contains("<eps>") &&
|
||||
!symbol_table_.contains("<blank>")) {
|
||||
SHERPA_ONNX_LOGE(
|
||||
"We expect that tokens.txt contains "
|
||||
"the symbol <blk> or <eps> and its ID.");
|
||||
"the symbol <blk> or <eps> or <blank> and its ID.");
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
@@ -98,6 +105,9 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl {
|
||||
} else if (symbol_table_.contains("<eps>")) {
|
||||
// for tdnn models of the yesno recipe from icefall
|
||||
blank_id = symbol_table_["<eps>"];
|
||||
} else if (symbol_table_.contains("<blank>")) {
|
||||
// for Wenet CTC models
|
||||
blank_id = symbol_table_["<blank>"];
|
||||
}
|
||||
|
||||
decoder_ = std::make_unique<OfflineCtcGreedySearchDecoder>(blank_id);
|
||||
@@ -113,6 +123,15 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl {
|
||||
}
|
||||
|
||||
void DecodeStreams(OfflineStream **ss, int32_t n) const override {
|
||||
if (!model_->SupportBatchProcessing()) {
|
||||
// If the model does not support batch process,
|
||||
// we process each stream independently.
|
||||
for (int32_t i = 0; i != n; ++i) {
|
||||
DecodeStream(ss[i]);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
auto memory_info =
|
||||
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
|
||||
|
||||
@@ -164,6 +183,38 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl {
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
// Decode a single stream.
|
||||
// Some models do not support batch size > 1, e.g., WeNet CTC models.
|
||||
void DecodeStream(OfflineStream *s) const {
|
||||
auto memory_info =
|
||||
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
|
||||
|
||||
int32_t feat_dim = config_.feat_config.feature_dim;
|
||||
std::vector<float> f = s->GetFrames();
|
||||
|
||||
int32_t num_frames = f.size() / feat_dim;
|
||||
|
||||
std::array<int64_t, 3> shape = {1, num_frames, feat_dim};
|
||||
|
||||
Ort::Value x = Ort::Value::CreateTensor(memory_info, f.data(), f.size(),
|
||||
shape.data(), shape.size());
|
||||
|
||||
int64_t x_length_scalar = num_frames;
|
||||
std::array<int64_t, 1> x_length_shape = {1};
|
||||
Ort::Value x_length =
|
||||
Ort::Value::CreateTensor(memory_info, &x_length_scalar, 1,
|
||||
x_length_shape.data(), x_length_shape.size());
|
||||
|
||||
auto t = model_->Forward(std::move(x), std::move(x_length));
|
||||
auto results = decoder_->Decode(std::move(t[0]), std::move(t[1]));
|
||||
int32_t frame_shift_ms = 10;
|
||||
|
||||
auto r = Convert(results[0], symbol_table_, frame_shift_ms,
|
||||
model_->SubsamplingFactor());
|
||||
s->SetResult(r);
|
||||
}
|
||||
|
||||
private:
|
||||
OfflineRecognizerConfig config_;
|
||||
SymbolTable symbol_table_;
|
||||
|
||||
@@ -26,7 +26,7 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
|
||||
} else if (model_type == "paraformer") {
|
||||
return std::make_unique<OfflineRecognizerParaformerImpl>(config);
|
||||
} else if (model_type == "nemo_ctc" || model_type == "tdnn" ||
|
||||
model_type == "zipformer2_ctc") {
|
||||
model_type == "zipformer2_ctc" || model_type == "wenet_ctc") {
|
||||
return std::make_unique<OfflineRecognizerCtcImpl>(config);
|
||||
} else if (model_type == "whisper") {
|
||||
return std::make_unique<OfflineRecognizerWhisperImpl>(config);
|
||||
@@ -51,6 +51,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
|
||||
model_filename = config.model_config.tdnn.model;
|
||||
} else if (!config.model_config.zipformer_ctc.model.empty()) {
|
||||
model_filename = config.model_config.zipformer_ctc.model;
|
||||
} else if (!config.model_config.wenet_ctc.model.empty()) {
|
||||
model_filename = config.model_config.wenet_ctc.model;
|
||||
} else if (!config.model_config.whisper.encoder.empty()) {
|
||||
model_filename = config.model_config.whisper.encoder;
|
||||
} else {
|
||||
@@ -99,6 +101,10 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
|
||||
"https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/"
|
||||
"zipformer/export-onnx-ctc.py"
|
||||
"\n"
|
||||
"(6) CTC models from WeNet"
|
||||
"\n "
|
||||
"https://github.com/k2-fsa/sherpa-onnx/blob/master/scripts/wenet/run.sh"
|
||||
"\n"
|
||||
"\n");
|
||||
exit(-1);
|
||||
}
|
||||
@@ -114,7 +120,7 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
|
||||
}
|
||||
|
||||
if (model_type == "EncDecCTCModelBPE" || model_type == "tdnn" ||
|
||||
model_type == "zipformer2_ctc") {
|
||||
model_type == "zipformer2_ctc" || model_type == "wenet_ctc") {
|
||||
return std::make_unique<OfflineRecognizerCtcImpl>(config);
|
||||
}
|
||||
|
||||
@@ -130,7 +136,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
|
||||
" - EncDecCTCModelBPE models from NeMo\n"
|
||||
" - Whisper models\n"
|
||||
" - Tdnn models\n"
|
||||
" - Zipformer CTC models\n",
|
||||
" - Zipformer CTC models\n"
|
||||
" - WeNet CTC models\n",
|
||||
model_type.c_str());
|
||||
|
||||
exit(-1);
|
||||
@@ -146,7 +153,7 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
|
||||
} else if (model_type == "paraformer") {
|
||||
return std::make_unique<OfflineRecognizerParaformerImpl>(mgr, config);
|
||||
} else if (model_type == "nemo_ctc" || model_type == "tdnn" ||
|
||||
model_type == "zipformer2_ctc") {
|
||||
model_type == "zipformer2_ctc" || model_type == "wenet_ctc") {
|
||||
return std::make_unique<OfflineRecognizerCtcImpl>(mgr, config);
|
||||
} else if (model_type == "whisper") {
|
||||
return std::make_unique<OfflineRecognizerWhisperImpl>(mgr, config);
|
||||
@@ -171,6 +178,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
|
||||
model_filename = config.model_config.tdnn.model;
|
||||
} else if (!config.model_config.zipformer_ctc.model.empty()) {
|
||||
model_filename = config.model_config.zipformer_ctc.model;
|
||||
} else if (!config.model_config.wenet_ctc.model.empty()) {
|
||||
model_filename = config.model_config.wenet_ctc.model;
|
||||
} else if (!config.model_config.whisper.encoder.empty()) {
|
||||
model_filename = config.model_config.whisper.encoder;
|
||||
} else {
|
||||
@@ -219,6 +228,10 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
|
||||
"https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/"
|
||||
"zipformer/export-onnx-ctc.py"
|
||||
"\n"
|
||||
"(6) CTC models from WeNet"
|
||||
"\n "
|
||||
"https://github.com/k2-fsa/sherpa-onnx/blob/master/scripts/wenet/run.sh"
|
||||
"\n"
|
||||
"\n");
|
||||
exit(-1);
|
||||
}
|
||||
@@ -234,7 +247,7 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
|
||||
}
|
||||
|
||||
if (model_type == "EncDecCTCModelBPE" || model_type == "tdnn" ||
|
||||
model_type == "zipformer2_ctc") {
|
||||
model_type == "zipformer2_ctc" || model_type == "wenet_ctc") {
|
||||
return std::make_unique<OfflineRecognizerCtcImpl>(mgr, config);
|
||||
}
|
||||
|
||||
@@ -250,7 +263,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
|
||||
" - EncDecCTCModelBPE models from NeMo\n"
|
||||
" - Whisper models\n"
|
||||
" - Tdnn models\n"
|
||||
" - Zipformer CTC models\n",
|
||||
" - Zipformer CTC models\n"
|
||||
" - WeNet CTC models\n",
|
||||
model_type.c_str());
|
||||
|
||||
exit(-1);
|
||||
|
||||
@@ -40,7 +40,8 @@ struct OfflineFeatureExtractorConfig {
|
||||
// Feature dimension
|
||||
int32_t feature_dim = 80;
|
||||
|
||||
// Set internally by some models, e.g., paraformer sets it to false.
|
||||
// Set internally by some models, e.g., paraformer and wenet CTC models set
|
||||
// it to false.
|
||||
// This parameter is not exposed to users from the commandline
|
||||
// If true, the feature extractor expects inputs to be normalized to
|
||||
// the range [-1, 1].
|
||||
|
||||
37
sherpa-onnx/csrc/offline-wenet-ctc-model-config.cc
Normal file
37
sherpa-onnx/csrc/offline-wenet-ctc-model-config.cc
Normal file
@@ -0,0 +1,37 @@
|
||||
// sherpa-onnx/csrc/offline-wenet-ctc-model-config.cc
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
|
||||
#include "sherpa-onnx/csrc/offline-wenet-ctc-model-config.h"
|
||||
|
||||
#include "sherpa-onnx/csrc/file-utils.h"
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
void OfflineWenetCtcModelConfig::Register(ParseOptions *po) {
|
||||
po->Register(
|
||||
"wenet-ctc-model", &model,
|
||||
"Path to model.onnx from WeNet. Please see "
|
||||
"https://github.com/k2-fsa/sherpa-onnx/pull/425 for available models");
|
||||
}
|
||||
|
||||
bool OfflineWenetCtcModelConfig::Validate() const {
|
||||
if (!FileExists(model)) {
|
||||
SHERPA_ONNX_LOGE("WeNet model: %s does not exist", model.c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
std::string OfflineWenetCtcModelConfig::ToString() const {
|
||||
std::ostringstream os;
|
||||
|
||||
os << "OfflineWenetCtcModelConfig(";
|
||||
os << "model=\"" << model << "\")";
|
||||
|
||||
return os.str();
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
28
sherpa-onnx/csrc/offline-wenet-ctc-model-config.h
Normal file
28
sherpa-onnx/csrc/offline-wenet-ctc-model-config.h
Normal file
@@ -0,0 +1,28 @@
|
||||
// sherpa-onnx/csrc/offline-wenet-ctc-model-config.h
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
#ifndef SHERPA_ONNX_CSRC_OFFLINE_WENET_CTC_MODEL_CONFIG_H_
|
||||
#define SHERPA_ONNX_CSRC_OFFLINE_WENET_CTC_MODEL_CONFIG_H_
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "sherpa-onnx/csrc/parse-options.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
struct OfflineWenetCtcModelConfig {
|
||||
std::string model;
|
||||
|
||||
OfflineWenetCtcModelConfig() = default;
|
||||
explicit OfflineWenetCtcModelConfig(const std::string &model)
|
||||
: model(model) {}
|
||||
|
||||
void Register(ParseOptions *po);
|
||||
bool Validate() const;
|
||||
|
||||
std::string ToString() const;
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_OFFLINE_WENET_CTC_MODEL_CONFIG_H_
|
||||
118
sherpa-onnx/csrc/offline-wenet-ctc-model.cc
Normal file
118
sherpa-onnx/csrc/offline-wenet-ctc-model.cc
Normal file
@@ -0,0 +1,118 @@
|
||||
// sherpa-onnx/csrc/offline-wenet-ctc-model.cc
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
|
||||
#include "sherpa-onnx/csrc/offline-wenet-ctc-model.h"
|
||||
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||
#include "sherpa-onnx/csrc/session.h"
|
||||
#include "sherpa-onnx/csrc/text-utils.h"
|
||||
#include "sherpa-onnx/csrc/transpose.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
class OfflineWenetCtcModel::Impl {
|
||||
public:
|
||||
explicit Impl(const OfflineModelConfig &config)
|
||||
: config_(config),
|
||||
env_(ORT_LOGGING_LEVEL_ERROR),
|
||||
sess_opts_(GetSessionOptions(config)),
|
||||
allocator_{} {
|
||||
auto buf = ReadFile(config_.wenet_ctc.model);
|
||||
Init(buf.data(), buf.size());
|
||||
}
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
Impl(AAssetManager *mgr, const OfflineModelConfig &config)
|
||||
: config_(config),
|
||||
env_(ORT_LOGGING_LEVEL_ERROR),
|
||||
sess_opts_(GetSessionOptions(config)),
|
||||
allocator_{} {
|
||||
auto buf = ReadFile(mgr, config_.wenet_ctc.model);
|
||||
Init(buf.data(), buf.size());
|
||||
}
|
||||
#endif
|
||||
|
||||
std::vector<Ort::Value> Forward(Ort::Value features,
|
||||
Ort::Value features_length) {
|
||||
std::array<Ort::Value, 2> inputs = {std::move(features),
|
||||
std::move(features_length)};
|
||||
|
||||
return sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(),
|
||||
output_names_ptr_.data(), output_names_ptr_.size());
|
||||
}
|
||||
|
||||
int32_t VocabSize() const { return vocab_size_; }
|
||||
|
||||
int32_t SubsamplingFactor() const { return subsampling_factor_; }
|
||||
|
||||
OrtAllocator *Allocator() const { return allocator_; }
|
||||
|
||||
private:
|
||||
void Init(void *model_data, size_t model_data_length) {
|
||||
sess_ = std::make_unique<Ort::Session>(env_, model_data, model_data_length,
|
||||
sess_opts_);
|
||||
|
||||
GetInputNames(sess_.get(), &input_names_, &input_names_ptr_);
|
||||
|
||||
GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_);
|
||||
|
||||
// get meta data
|
||||
Ort::ModelMetadata meta_data = sess_->GetModelMetadata();
|
||||
if (config_.debug) {
|
||||
std::ostringstream os;
|
||||
PrintModelMetadata(os, meta_data);
|
||||
SHERPA_ONNX_LOGE("%s\n", os.str().c_str());
|
||||
}
|
||||
|
||||
Ort::AllocatorWithDefaultOptions allocator; // used in the macro below
|
||||
SHERPA_ONNX_READ_META_DATA(vocab_size_, "vocab_size");
|
||||
SHERPA_ONNX_READ_META_DATA(subsampling_factor_, "subsampling_factor");
|
||||
}
|
||||
|
||||
private:
|
||||
OfflineModelConfig config_;
|
||||
Ort::Env env_;
|
||||
Ort::SessionOptions sess_opts_;
|
||||
Ort::AllocatorWithDefaultOptions allocator_;
|
||||
|
||||
std::unique_ptr<Ort::Session> sess_;
|
||||
|
||||
std::vector<std::string> input_names_;
|
||||
std::vector<const char *> input_names_ptr_;
|
||||
|
||||
std::vector<std::string> output_names_;
|
||||
std::vector<const char *> output_names_ptr_;
|
||||
|
||||
int32_t vocab_size_ = 0;
|
||||
int32_t subsampling_factor_ = 0;
|
||||
};
|
||||
|
||||
OfflineWenetCtcModel::OfflineWenetCtcModel(const OfflineModelConfig &config)
|
||||
: impl_(std::make_unique<Impl>(config)) {}
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
OfflineWenetCtcModel::OfflineWenetCtcModel(AAssetManager *mgr,
|
||||
const OfflineModelConfig &config)
|
||||
: impl_(std::make_unique<Impl>(mgr, config)) {}
|
||||
#endif
|
||||
|
||||
OfflineWenetCtcModel::~OfflineWenetCtcModel() = default;
|
||||
|
||||
std::vector<Ort::Value> OfflineWenetCtcModel::Forward(
|
||||
Ort::Value features, Ort::Value features_length) {
|
||||
return impl_->Forward(std::move(features), std::move(features_length));
|
||||
}
|
||||
|
||||
int32_t OfflineWenetCtcModel::VocabSize() const { return impl_->VocabSize(); }
|
||||
|
||||
int32_t OfflineWenetCtcModel::SubsamplingFactor() const {
|
||||
return impl_->SubsamplingFactor();
|
||||
}
|
||||
|
||||
OrtAllocator *OfflineWenetCtcModel::Allocator() const {
|
||||
return impl_->Allocator();
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
79
sherpa-onnx/csrc/offline-wenet-ctc-model.h
Normal file
79
sherpa-onnx/csrc/offline-wenet-ctc-model.h
Normal file
@@ -0,0 +1,79 @@
|
||||
// sherpa-onnx/csrc/offline-wenet-ctc-model.h
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
#ifndef SHERPA_ONNX_CSRC_OFFLINE_WENET_CTC_MODEL_H_
|
||||
#define SHERPA_ONNX_CSRC_OFFLINE_WENET_CTC_MODEL_H_
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
#include "android/asset_manager.h"
|
||||
#include "android/asset_manager_jni.h"
|
||||
#endif
|
||||
|
||||
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||
#include "sherpa-onnx/csrc/offline-ctc-model.h"
|
||||
#include "sherpa-onnx/csrc/offline-model-config.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
/** This class implements the CTC model from WeNet.
|
||||
*
|
||||
* See
|
||||
* https://github.com/k2-fsa/sherpa-onnx/blob/master/scripts/wenet/export-onnx.py
|
||||
* https://github.com/k2-fsa/sherpa-onnx/blob/master/scripts/wenet/test-onnx.py
|
||||
* https://github.com/k2-fsa/sherpa-onnx/blob/master/scripts/wenet/run.sh
|
||||
*
|
||||
*/
|
||||
class OfflineWenetCtcModel : public OfflineCtcModel {
|
||||
public:
|
||||
explicit OfflineWenetCtcModel(const OfflineModelConfig &config);
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
OfflineWenetCtcModel(AAssetManager *mgr, const OfflineModelConfig &config);
|
||||
#endif
|
||||
|
||||
~OfflineWenetCtcModel() override;
|
||||
|
||||
/** Run the forward method of the model.
|
||||
*
|
||||
* @param features A tensor of shape (N, T, C).
|
||||
* @param features_length A 1-D tensor of shape (N,) containing number of
|
||||
* valid frames in `features` before padding.
|
||||
* Its dtype is int64_t.
|
||||
*
|
||||
* @return Return a vector containing:
|
||||
* - log_probs: A 3-D tensor of shape (N, T', vocab_size).
|
||||
* - log_probs_length A 1-D tensor of shape (N,). Its dtype is int64_t
|
||||
*/
|
||||
std::vector<Ort::Value> Forward(Ort::Value features,
|
||||
Ort::Value features_length) override;
|
||||
|
||||
/** Return the vocabulary size of the model
|
||||
*/
|
||||
int32_t VocabSize() const override;
|
||||
|
||||
/** SubsamplingFactor of the model
|
||||
*
|
||||
* For Citrinet, the subsampling factor is usually 4.
|
||||
* For Conformer CTC, the subsampling factor is usually 8.
|
||||
*/
|
||||
int32_t SubsamplingFactor() const override;
|
||||
|
||||
/** Return an allocator for allocating memory
|
||||
*/
|
||||
OrtAllocator *Allocator() const override;
|
||||
|
||||
// WeNet CTC models do not support batch size > 1
|
||||
bool SupportBatchProcessing() const override { return false; }
|
||||
|
||||
private:
|
||||
class Impl;
|
||||
std::unique_ptr<Impl> impl_;
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_OFFLINE_WENET_CTC_MODEL_H_
|
||||
@@ -17,6 +17,7 @@ pybind11_add_module(_sherpa_onnx
|
||||
offline-tts-model-config.cc
|
||||
offline-tts-vits-model-config.cc
|
||||
offline-tts.cc
|
||||
offline-wenet-ctc-model-config.cc
|
||||
offline-whisper-model-config.cc
|
||||
offline-zipformer-ctc-model-config.cc
|
||||
online-lm-config.cc
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
#include "sherpa-onnx/python/csrc/offline-paraformer-model-config.h"
|
||||
#include "sherpa-onnx/python/csrc/offline-tdnn-model-config.h"
|
||||
#include "sherpa-onnx/python/csrc/offline-transducer-model-config.h"
|
||||
#include "sherpa-onnx/python/csrc/offline-wenet-ctc-model-config.h"
|
||||
#include "sherpa-onnx/python/csrc/offline-whisper-model-config.h"
|
||||
#include "sherpa-onnx/python/csrc/offline-zipformer-ctc-model-config.h"
|
||||
|
||||
@@ -24,6 +25,7 @@ void PybindOfflineModelConfig(py::module *m) {
|
||||
PybindOfflineWhisperModelConfig(m);
|
||||
PybindOfflineTdnnModelConfig(m);
|
||||
PybindOfflineZipformerCtcModelConfig(m);
|
||||
PybindOfflineWenetCtcModelConfig(m);
|
||||
|
||||
using PyClass = OfflineModelConfig;
|
||||
py::class_<PyClass>(*m, "OfflineModelConfig")
|
||||
@@ -32,7 +34,8 @@ void PybindOfflineModelConfig(py::module *m) {
|
||||
const OfflineNemoEncDecCtcModelConfig &,
|
||||
const OfflineWhisperModelConfig &,
|
||||
const OfflineTdnnModelConfig &,
|
||||
const OfflineZipformerCtcModelConfig &, const std::string &,
|
||||
const OfflineZipformerCtcModelConfig &,
|
||||
const OfflineWenetCtcModelConfig &, const std::string &,
|
||||
int32_t, bool, const std::string &, const std::string &>(),
|
||||
py::arg("transducer") = OfflineTransducerModelConfig(),
|
||||
py::arg("paraformer") = OfflineParaformerModelConfig(),
|
||||
@@ -40,6 +43,7 @@ void PybindOfflineModelConfig(py::module *m) {
|
||||
py::arg("whisper") = OfflineWhisperModelConfig(),
|
||||
py::arg("tdnn") = OfflineTdnnModelConfig(),
|
||||
py::arg("zipformer_ctc") = OfflineZipformerCtcModelConfig(),
|
||||
py::arg("wenet_ctc") = OfflineWenetCtcModelConfig(),
|
||||
py::arg("tokens"), py::arg("num_threads"), py::arg("debug") = false,
|
||||
py::arg("provider") = "cpu", py::arg("model_type") = "")
|
||||
.def_readwrite("transducer", &PyClass::transducer)
|
||||
@@ -48,6 +52,7 @@ void PybindOfflineModelConfig(py::module *m) {
|
||||
.def_readwrite("whisper", &PyClass::whisper)
|
||||
.def_readwrite("tdnn", &PyClass::tdnn)
|
||||
.def_readwrite("zipformer_ctc", &PyClass::zipformer_ctc)
|
||||
.def_readwrite("wenet_ctc", &PyClass::wenet_ctc)
|
||||
.def_readwrite("tokens", &PyClass::tokens)
|
||||
.def_readwrite("num_threads", &PyClass::num_threads)
|
||||
.def_readwrite("debug", &PyClass::debug)
|
||||
|
||||
22
sherpa-onnx/python/csrc/offline-wenet-ctc-model-config.cc
Normal file
22
sherpa-onnx/python/csrc/offline-wenet-ctc-model-config.cc
Normal file
@@ -0,0 +1,22 @@
|
||||
// sherpa-onnx/python/csrc/offline-wenet-model-config.cc
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
|
||||
#include "sherpa-onnx/csrc/offline-wenet-ctc-model-config.h"
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "sherpa-onnx/python/csrc/offline-wenet-ctc-model-config.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
void PybindOfflineWenetCtcModelConfig(py::module *m) {
|
||||
using PyClass = OfflineWenetCtcModelConfig;
|
||||
py::class_<PyClass>(*m, "OfflineWenetCtcModelConfig")
|
||||
.def(py::init<const std::string &>(), py::arg("model"))
|
||||
.def_readwrite("model", &PyClass::model)
|
||||
.def("__str__", &PyClass::ToString);
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
16
sherpa-onnx/python/csrc/offline-wenet-ctc-model-config.h
Normal file
16
sherpa-onnx/python/csrc/offline-wenet-ctc-model-config.h
Normal file
@@ -0,0 +1,16 @@
|
||||
// sherpa-onnx/python/csrc/offline-wenet-model-config.h
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
|
||||
#ifndef SHERPA_ONNX_PYTHON_CSRC_OFFLINE_WENET_CTC_MODEL_CONFIG_H_
|
||||
#define SHERPA_ONNX_PYTHON_CSRC_OFFLINE_WENET_CTC_MODEL_CONFIG_H_
|
||||
|
||||
#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
void PybindOfflineWenetCtcModelConfig(py::module *m);
|
||||
|
||||
}
|
||||
|
||||
#endif // SHERPA_ONNX_PYTHON_CSRC_OFFLINE_WENET_CTC_MODEL_CONFIG_H_
|
||||
Reference in New Issue
Block a user