Add C++ runtime for Tele-AI/TeleSpeech-ASR (#970)
This commit is contained in:
@@ -39,6 +39,7 @@ set(sources
|
||||
offline-stream.cc
|
||||
offline-tdnn-ctc-model.cc
|
||||
offline-tdnn-model-config.cc
|
||||
offline-telespeech-ctc-model.cc
|
||||
offline-transducer-greedy-search-decoder.cc
|
||||
offline-transducer-greedy-search-nemo-decoder.cc
|
||||
offline-transducer-model-config.cc
|
||||
|
||||
@@ -56,22 +56,11 @@ std::string FeatureExtractorConfig::ToString() const {
|
||||
class FeatureExtractor::Impl {
|
||||
public:
|
||||
explicit Impl(const FeatureExtractorConfig &config) : config_(config) {
|
||||
opts_.frame_opts.dither = config.dither;
|
||||
opts_.frame_opts.snip_edges = config.snip_edges;
|
||||
opts_.frame_opts.samp_freq = config.sampling_rate;
|
||||
opts_.frame_opts.frame_shift_ms = config.frame_shift_ms;
|
||||
opts_.frame_opts.frame_length_ms = config.frame_length_ms;
|
||||
opts_.frame_opts.remove_dc_offset = config.remove_dc_offset;
|
||||
opts_.frame_opts.window_type = config.window_type;
|
||||
|
||||
opts_.mel_opts.num_bins = config.feature_dim;
|
||||
|
||||
opts_.mel_opts.high_freq = config.high_freq;
|
||||
opts_.mel_opts.low_freq = config.low_freq;
|
||||
|
||||
opts_.mel_opts.is_librosa = config.is_librosa;
|
||||
|
||||
fbank_ = std::make_unique<knf::OnlineFbank>(opts_);
|
||||
if (config_.is_mfcc) {
|
||||
InitMfcc();
|
||||
} else {
|
||||
InitFbank();
|
||||
}
|
||||
}
|
||||
|
||||
void AcceptWaveform(int32_t sampling_rate, const float *waveform, int32_t n) {
|
||||
@@ -101,35 +90,48 @@ class FeatureExtractor::Impl {
|
||||
|
||||
std::vector<float> samples;
|
||||
resampler_->Resample(waveform, n, false, &samples);
|
||||
fbank_->AcceptWaveform(opts_.frame_opts.samp_freq, samples.data(),
|
||||
samples.size());
|
||||
if (fbank_) {
|
||||
fbank_->AcceptWaveform(config_.sampling_rate, samples.data(),
|
||||
samples.size());
|
||||
} else {
|
||||
mfcc_->AcceptWaveform(config_.sampling_rate, samples.data(),
|
||||
samples.size());
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
if (sampling_rate != opts_.frame_opts.samp_freq) {
|
||||
if (sampling_rate != config_.sampling_rate) {
|
||||
SHERPA_ONNX_LOGE(
|
||||
"Creating a resampler:\n"
|
||||
" in_sample_rate: %d\n"
|
||||
" output_sample_rate: %d\n",
|
||||
sampling_rate, static_cast<int32_t>(opts_.frame_opts.samp_freq));
|
||||
sampling_rate, static_cast<int32_t>(config_.sampling_rate));
|
||||
|
||||
float min_freq =
|
||||
std::min<int32_t>(sampling_rate, opts_.frame_opts.samp_freq);
|
||||
float min_freq = std::min<int32_t>(sampling_rate, config_.sampling_rate);
|
||||
float lowpass_cutoff = 0.99 * 0.5 * min_freq;
|
||||
|
||||
int32_t lowpass_filter_width = 6;
|
||||
resampler_ = std::make_unique<LinearResample>(
|
||||
sampling_rate, opts_.frame_opts.samp_freq, lowpass_cutoff,
|
||||
sampling_rate, config_.sampling_rate, lowpass_cutoff,
|
||||
lowpass_filter_width);
|
||||
|
||||
std::vector<float> samples;
|
||||
resampler_->Resample(waveform, n, false, &samples);
|
||||
fbank_->AcceptWaveform(opts_.frame_opts.samp_freq, samples.data(),
|
||||
samples.size());
|
||||
if (fbank_) {
|
||||
fbank_->AcceptWaveform(config_.sampling_rate, samples.data(),
|
||||
samples.size());
|
||||
} else {
|
||||
mfcc_->AcceptWaveform(config_.sampling_rate, samples.data(),
|
||||
samples.size());
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
fbank_->AcceptWaveform(sampling_rate, waveform, n);
|
||||
if (fbank_) {
|
||||
fbank_->AcceptWaveform(sampling_rate, waveform, n);
|
||||
} else {
|
||||
mfcc_->AcceptWaveform(sampling_rate, waveform, n);
|
||||
}
|
||||
}
|
||||
|
||||
void InputFinished() const {
|
||||
@@ -179,11 +181,56 @@ class FeatureExtractor::Impl {
|
||||
return features;
|
||||
}
|
||||
|
||||
int32_t FeatureDim() const { return opts_.mel_opts.num_bins; }
|
||||
int32_t FeatureDim() const {
|
||||
return mfcc_ ? mfcc_opts_.num_ceps : opts_.mel_opts.num_bins;
|
||||
}
|
||||
|
||||
private:
|
||||
void InitFbank() {
|
||||
opts_.frame_opts.dither = config_.dither;
|
||||
opts_.frame_opts.snip_edges = config_.snip_edges;
|
||||
opts_.frame_opts.samp_freq = config_.sampling_rate;
|
||||
opts_.frame_opts.frame_shift_ms = config_.frame_shift_ms;
|
||||
opts_.frame_opts.frame_length_ms = config_.frame_length_ms;
|
||||
opts_.frame_opts.remove_dc_offset = config_.remove_dc_offset;
|
||||
opts_.frame_opts.window_type = config_.window_type;
|
||||
|
||||
opts_.mel_opts.num_bins = config_.feature_dim;
|
||||
|
||||
opts_.mel_opts.high_freq = config_.high_freq;
|
||||
opts_.mel_opts.low_freq = config_.low_freq;
|
||||
|
||||
opts_.mel_opts.is_librosa = config_.is_librosa;
|
||||
|
||||
fbank_ = std::make_unique<knf::OnlineFbank>(opts_);
|
||||
}
|
||||
void InitMfcc() {
|
||||
mfcc_opts_.frame_opts.dither = config_.dither;
|
||||
mfcc_opts_.frame_opts.snip_edges = config_.snip_edges;
|
||||
mfcc_opts_.frame_opts.samp_freq = config_.sampling_rate;
|
||||
mfcc_opts_.frame_opts.frame_shift_ms = config_.frame_shift_ms;
|
||||
mfcc_opts_.frame_opts.frame_length_ms = config_.frame_length_ms;
|
||||
mfcc_opts_.frame_opts.remove_dc_offset = config_.remove_dc_offset;
|
||||
mfcc_opts_.frame_opts.window_type = config_.window_type;
|
||||
|
||||
mfcc_opts_.mel_opts.num_bins = config_.feature_dim;
|
||||
|
||||
mfcc_opts_.mel_opts.high_freq = config_.high_freq;
|
||||
mfcc_opts_.mel_opts.low_freq = config_.low_freq;
|
||||
|
||||
mfcc_opts_.mel_opts.is_librosa = config_.is_librosa;
|
||||
|
||||
mfcc_opts_.num_ceps = config_.num_ceps;
|
||||
mfcc_opts_.use_energy = config_.use_energy;
|
||||
|
||||
mfcc_ = std::make_unique<knf::OnlineMfcc>(mfcc_opts_);
|
||||
}
|
||||
|
||||
private:
|
||||
std::unique_ptr<knf::OnlineFbank> fbank_;
|
||||
std::unique_ptr<knf::OnlineMfcc> mfcc_;
|
||||
knf::FbankOptions opts_;
|
||||
knf::MfccOptions mfcc_opts_;
|
||||
FeatureExtractorConfig config_;
|
||||
mutable std::mutex mutex_;
|
||||
std::unique_ptr<LinearResample> resampler_;
|
||||
|
||||
@@ -18,7 +18,10 @@ struct FeatureExtractorConfig {
|
||||
// the sampling rate of the input waveform, we will do resampling inside.
|
||||
int32_t sampling_rate = 16000;
|
||||
|
||||
// Feature dimension
|
||||
// num_mel_bins
|
||||
//
|
||||
// Note: for mfcc, this value is also for num_mel_bins.
|
||||
// The actual feature dimension is actually num_ceps
|
||||
int32_t feature_dim = 80;
|
||||
|
||||
// minimal frequency for Mel-filterbank, in Hz
|
||||
@@ -69,6 +72,12 @@ struct FeatureExtractorConfig {
|
||||
// for details
|
||||
std::string nemo_normalize_type;
|
||||
|
||||
// for MFCC
|
||||
int32_t num_ceps = 13;
|
||||
bool use_energy = true;
|
||||
|
||||
bool is_mfcc = false;
|
||||
|
||||
std::string ToString() const;
|
||||
|
||||
void Register(ParseOptions *po);
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model.h"
|
||||
#include "sherpa-onnx/csrc/offline-tdnn-ctc-model.h"
|
||||
#include "sherpa-onnx/csrc/offline-telespeech-ctc-model.h"
|
||||
#include "sherpa-onnx/csrc/offline-wenet-ctc-model.h"
|
||||
#include "sherpa-onnx/csrc/offline-zipformer-ctc-model.h"
|
||||
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||
@@ -24,6 +25,7 @@ enum class ModelType {
|
||||
kTdnn,
|
||||
kZipformerCtc,
|
||||
kWenetCtc,
|
||||
kTeleSpeechCtc,
|
||||
kUnknown,
|
||||
};
|
||||
|
||||
@@ -63,6 +65,9 @@ static ModelType GetModelType(char *model_data, size_t model_data_length,
|
||||
"If you are using models from WeNet, please refer to\n"
|
||||
"https://github.com/k2-fsa/sherpa-onnx/blob/master/scripts/wenet/"
|
||||
"run.sh\n"
|
||||
"If you are using models from TeleSpeech, please refer to\n"
|
||||
"https://github.com/k2-fsa/sherpa-onnx/blob/master/scripts/tele-speech/"
|
||||
"add-metadata.py"
|
||||
"\n"
|
||||
"for how to add metadta to model.onnx\n");
|
||||
return ModelType::kUnknown;
|
||||
@@ -78,6 +83,8 @@ static ModelType GetModelType(char *model_data, size_t model_data_length,
|
||||
return ModelType::kZipformerCtc;
|
||||
} else if (model_type.get() == std::string("wenet_ctc")) {
|
||||
return ModelType::kWenetCtc;
|
||||
} else if (model_type.get() == std::string("telespeech_ctc")) {
|
||||
return ModelType::kTeleSpeechCtc;
|
||||
} else {
|
||||
SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.get());
|
||||
return ModelType::kUnknown;
|
||||
@@ -97,6 +104,8 @@ std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create(
|
||||
filename = config.zipformer_ctc.model;
|
||||
} else if (!config.wenet_ctc.model.empty()) {
|
||||
filename = config.wenet_ctc.model;
|
||||
} else if (!config.telespeech_ctc.empty()) {
|
||||
filename = config.telespeech_ctc;
|
||||
} else {
|
||||
SHERPA_ONNX_LOGE("Please specify a CTC model");
|
||||
exit(-1);
|
||||
@@ -124,6 +133,9 @@ std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create(
|
||||
case ModelType::kWenetCtc:
|
||||
return std::make_unique<OfflineWenetCtcModel>(config);
|
||||
break;
|
||||
case ModelType::kTeleSpeechCtc:
|
||||
return std::make_unique<OfflineTeleSpeechCtcModel>(config);
|
||||
break;
|
||||
case ModelType::kUnknown:
|
||||
SHERPA_ONNX_LOGE("Unknown model type in offline CTC!");
|
||||
return nullptr;
|
||||
@@ -147,6 +159,8 @@ std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create(
|
||||
filename = config.zipformer_ctc.model;
|
||||
} else if (!config.wenet_ctc.model.empty()) {
|
||||
filename = config.wenet_ctc.model;
|
||||
} else if (!config.telespeech_ctc.empty()) {
|
||||
filename = config.telespeech_ctc;
|
||||
} else {
|
||||
SHERPA_ONNX_LOGE("Please specify a CTC model");
|
||||
exit(-1);
|
||||
@@ -175,6 +189,9 @@ std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create(
|
||||
case ModelType::kWenetCtc:
|
||||
return std::make_unique<OfflineWenetCtcModel>(mgr, config);
|
||||
break;
|
||||
case ModelType::kTeleSpeechCtc:
|
||||
return std::make_unique<OfflineTeleSpeechCtcModel>(mgr, config);
|
||||
break;
|
||||
case ModelType::kUnknown:
|
||||
SHERPA_ONNX_LOGE("Unknown model type in offline CTC!");
|
||||
return nullptr;
|
||||
|
||||
@@ -19,6 +19,9 @@ void OfflineModelConfig::Register(ParseOptions *po) {
|
||||
zipformer_ctc.Register(po);
|
||||
wenet_ctc.Register(po);
|
||||
|
||||
po->Register("telespeech-ctc", &telespeech_ctc,
|
||||
"Path to model.onnx for telespeech ctc");
|
||||
|
||||
po->Register("tokens", &tokens, "Path to tokens.txt");
|
||||
|
||||
po->Register("num-threads", &num_threads,
|
||||
@@ -33,7 +36,7 @@ void OfflineModelConfig::Register(ParseOptions *po) {
|
||||
po->Register("model-type", &model_type,
|
||||
"Specify it to reduce model initialization time. "
|
||||
"Valid values are: transducer, paraformer, nemo_ctc, whisper, "
|
||||
"tdnn, zipformer2_ctc"
|
||||
"tdnn, zipformer2_ctc, telespeech_ctc."
|
||||
"All other values lead to loading the model twice.");
|
||||
po->Register("modeling-unit", &modeling_unit,
|
||||
"The modeling unit of the model, commonly used units are bpe, "
|
||||
@@ -55,14 +58,14 @@ bool OfflineModelConfig::Validate() const {
|
||||
}
|
||||
|
||||
if (!FileExists(tokens)) {
|
||||
SHERPA_ONNX_LOGE("tokens: %s does not exist", tokens.c_str());
|
||||
SHERPA_ONNX_LOGE("tokens: '%s' does not exist", tokens.c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!modeling_unit.empty() &&
|
||||
(modeling_unit == "bpe" || modeling_unit == "cjkchar+bpe")) {
|
||||
if (!FileExists(bpe_vocab)) {
|
||||
SHERPA_ONNX_LOGE("bpe_vocab: %s does not exist", bpe_vocab.c_str());
|
||||
SHERPA_ONNX_LOGE("bpe_vocab: '%s' does not exist", bpe_vocab.c_str());
|
||||
return false;
|
||||
}
|
||||
}
|
||||
@@ -91,6 +94,14 @@ bool OfflineModelConfig::Validate() const {
|
||||
return wenet_ctc.Validate();
|
||||
}
|
||||
|
||||
if (!telespeech_ctc.empty() && !FileExists(telespeech_ctc)) {
|
||||
SHERPA_ONNX_LOGE("telespeech_ctc: '%s' does not exist",
|
||||
telespeech_ctc.c_str());
|
||||
return false;
|
||||
} else {
|
||||
return true;
|
||||
}
|
||||
|
||||
return transducer.Validate();
|
||||
}
|
||||
|
||||
@@ -105,6 +116,7 @@ std::string OfflineModelConfig::ToString() const {
|
||||
os << "tdnn=" << tdnn.ToString() << ", ";
|
||||
os << "zipformer_ctc=" << zipformer_ctc.ToString() << ", ";
|
||||
os << "wenet_ctc=" << wenet_ctc.ToString() << ", ";
|
||||
os << "telespeech_ctc=\"" << telespeech_ctc << "\", ";
|
||||
os << "tokens=\"" << tokens << "\", ";
|
||||
os << "num_threads=" << num_threads << ", ";
|
||||
os << "debug=" << (debug ? "True" : "False") << ", ";
|
||||
|
||||
@@ -24,6 +24,7 @@ struct OfflineModelConfig {
|
||||
OfflineTdnnModelConfig tdnn;
|
||||
OfflineZipformerCtcModelConfig zipformer_ctc;
|
||||
OfflineWenetCtcModelConfig wenet_ctc;
|
||||
std::string telespeech_ctc;
|
||||
|
||||
std::string tokens;
|
||||
int32_t num_threads = 2;
|
||||
@@ -52,6 +53,7 @@ struct OfflineModelConfig {
|
||||
const OfflineTdnnModelConfig &tdnn,
|
||||
const OfflineZipformerCtcModelConfig &zipformer_ctc,
|
||||
const OfflineWenetCtcModelConfig &wenet_ctc,
|
||||
const std::string &telespeech_ctc,
|
||||
const std::string &tokens, int32_t num_threads, bool debug,
|
||||
const std::string &provider, const std::string &model_type,
|
||||
const std::string &modeling_unit,
|
||||
@@ -63,6 +65,7 @@ struct OfflineModelConfig {
|
||||
tdnn(tdnn),
|
||||
zipformer_ctc(zipformer_ctc),
|
||||
wenet_ctc(wenet_ctc),
|
||||
telespeech_ctc(telespeech_ctc),
|
||||
tokens(tokens),
|
||||
num_threads(num_threads),
|
||||
debug(debug),
|
||||
|
||||
@@ -88,6 +88,17 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl {
|
||||
#endif
|
||||
|
||||
void Init() {
|
||||
if (!config_.model_config.telespeech_ctc.empty()) {
|
||||
config_.feat_config.snip_edges = true;
|
||||
config_.feat_config.num_ceps = 40;
|
||||
config_.feat_config.feature_dim = 40;
|
||||
config_.feat_config.low_freq = 40;
|
||||
config_.feat_config.high_freq = -200;
|
||||
config_.feat_config.use_energy = false;
|
||||
config_.feat_config.normalize_samples = false;
|
||||
config_.feat_config.is_mfcc = true;
|
||||
}
|
||||
|
||||
if (!config_.model_config.wenet_ctc.model.empty()) {
|
||||
// WeNet CTC models assume input samples are in the range
|
||||
// [-32768, 32767], so we set normalize_samples to false
|
||||
|
||||
@@ -29,7 +29,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
|
||||
} else if (model_type == "paraformer") {
|
||||
return std::make_unique<OfflineRecognizerParaformerImpl>(config);
|
||||
} else if (model_type == "nemo_ctc" || model_type == "tdnn" ||
|
||||
model_type == "zipformer2_ctc" || model_type == "wenet_ctc") {
|
||||
model_type == "zipformer2_ctc" || model_type == "wenet_ctc" ||
|
||||
model_type == "telespeech_ctc") {
|
||||
return std::make_unique<OfflineRecognizerCtcImpl>(config);
|
||||
} else if (model_type == "whisper") {
|
||||
return std::make_unique<OfflineRecognizerWhisperImpl>(config);
|
||||
@@ -53,6 +54,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
|
||||
model_filename = config.model_config.paraformer.model;
|
||||
} else if (!config.model_config.nemo_ctc.model.empty()) {
|
||||
model_filename = config.model_config.nemo_ctc.model;
|
||||
} else if (!config.model_config.telespeech_ctc.empty()) {
|
||||
model_filename = config.model_config.telespeech_ctc;
|
||||
} else if (!config.model_config.tdnn.model.empty()) {
|
||||
model_filename = config.model_config.tdnn.model;
|
||||
} else if (!config.model_config.zipformer_ctc.model.empty()) {
|
||||
@@ -111,6 +114,10 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
|
||||
"\n "
|
||||
"https://github.com/k2-fsa/sherpa-onnx/blob/master/scripts/wenet/run.sh"
|
||||
"\n"
|
||||
"(7) CTC models from TeleSpeech"
|
||||
"\n "
|
||||
"https://github.com/Tele-AI/TeleSpeech-ASR"
|
||||
"\n"
|
||||
"\n");
|
||||
exit(-1);
|
||||
}
|
||||
@@ -133,7 +140,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
|
||||
|
||||
if (model_type == "EncDecCTCModelBPE" ||
|
||||
model_type == "EncDecHybridRNNTCTCBPEModel" || model_type == "tdnn" ||
|
||||
model_type == "zipformer2_ctc" || model_type == "wenet_ctc") {
|
||||
model_type == "zipformer2_ctc" || model_type == "wenet_ctc" ||
|
||||
model_type == "telespeech_ctc") {
|
||||
return std::make_unique<OfflineRecognizerCtcImpl>(config);
|
||||
}
|
||||
|
||||
@@ -151,7 +159,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
|
||||
" - Whisper models\n"
|
||||
" - Tdnn models\n"
|
||||
" - Zipformer CTC models\n"
|
||||
" - WeNet CTC models\n",
|
||||
" - WeNet CTC models\n"
|
||||
" - TeleSpeech CTC models\n",
|
||||
model_type.c_str());
|
||||
|
||||
exit(-1);
|
||||
@@ -169,7 +178,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
|
||||
} else if (model_type == "paraformer") {
|
||||
return std::make_unique<OfflineRecognizerParaformerImpl>(mgr, config);
|
||||
} else if (model_type == "nemo_ctc" || model_type == "tdnn" ||
|
||||
model_type == "zipformer2_ctc" || model_type == "wenet_ctc") {
|
||||
model_type == "zipformer2_ctc" || model_type == "wenet_ctc" ||
|
||||
model_type == "telespeech_ctc") {
|
||||
return std::make_unique<OfflineRecognizerCtcImpl>(mgr, config);
|
||||
} else if (model_type == "whisper") {
|
||||
return std::make_unique<OfflineRecognizerWhisperImpl>(mgr, config);
|
||||
@@ -199,6 +209,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
|
||||
model_filename = config.model_config.zipformer_ctc.model;
|
||||
} else if (!config.model_config.wenet_ctc.model.empty()) {
|
||||
model_filename = config.model_config.wenet_ctc.model;
|
||||
} else if (!config.model_config.telespeech_ctc.empty()) {
|
||||
model_filename = config.model_config.telespeech_ctc;
|
||||
} else if (!config.model_config.whisper.encoder.empty()) {
|
||||
model_filename = config.model_config.whisper.encoder;
|
||||
} else {
|
||||
@@ -251,6 +263,10 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
|
||||
"\n "
|
||||
"https://github.com/k2-fsa/sherpa-onnx/blob/master/scripts/wenet/run.sh"
|
||||
"\n"
|
||||
"(7) CTC models from TeleSpeech"
|
||||
"\n "
|
||||
"https://github.com/Tele-AI/TeleSpeech-ASR"
|
||||
"\n"
|
||||
"\n");
|
||||
exit(-1);
|
||||
}
|
||||
@@ -273,7 +289,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
|
||||
|
||||
if (model_type == "EncDecCTCModelBPE" ||
|
||||
model_type == "EncDecHybridRNNTCTCBPEModel" || model_type == "tdnn" ||
|
||||
model_type == "zipformer2_ctc" || model_type == "wenet_ctc") {
|
||||
model_type == "zipformer2_ctc" || model_type == "wenet_ctc" ||
|
||||
model_type == "telespeech_ctc") {
|
||||
return std::make_unique<OfflineRecognizerCtcImpl>(mgr, config);
|
||||
}
|
||||
|
||||
@@ -291,7 +308,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
|
||||
" - Whisper models\n"
|
||||
" - Tdnn models\n"
|
||||
" - Zipformer CTC models\n"
|
||||
" - WeNet CTC models\n",
|
||||
" - WeNet CTC models\n"
|
||||
" - TeleSpeech CTC models\n",
|
||||
model_type.c_str());
|
||||
|
||||
exit(-1);
|
||||
|
||||
@@ -57,22 +57,44 @@ class OfflineStream::Impl {
|
||||
explicit Impl(const FeatureExtractorConfig &config,
|
||||
ContextGraphPtr context_graph)
|
||||
: config_(config), context_graph_(context_graph) {
|
||||
opts_.frame_opts.dither = config.dither;
|
||||
opts_.frame_opts.snip_edges = config.snip_edges;
|
||||
opts_.frame_opts.samp_freq = config.sampling_rate;
|
||||
opts_.frame_opts.frame_shift_ms = config.frame_shift_ms;
|
||||
opts_.frame_opts.frame_length_ms = config.frame_length_ms;
|
||||
opts_.frame_opts.remove_dc_offset = config.remove_dc_offset;
|
||||
opts_.frame_opts.window_type = config.window_type;
|
||||
if (config.is_mfcc) {
|
||||
mfcc_opts_.frame_opts.dither = config_.dither;
|
||||
mfcc_opts_.frame_opts.snip_edges = config_.snip_edges;
|
||||
mfcc_opts_.frame_opts.samp_freq = config_.sampling_rate;
|
||||
mfcc_opts_.frame_opts.frame_shift_ms = config_.frame_shift_ms;
|
||||
mfcc_opts_.frame_opts.frame_length_ms = config_.frame_length_ms;
|
||||
mfcc_opts_.frame_opts.remove_dc_offset = config_.remove_dc_offset;
|
||||
mfcc_opts_.frame_opts.window_type = config_.window_type;
|
||||
|
||||
opts_.mel_opts.num_bins = config.feature_dim;
|
||||
mfcc_opts_.mel_opts.num_bins = config_.feature_dim;
|
||||
|
||||
opts_.mel_opts.high_freq = config.high_freq;
|
||||
opts_.mel_opts.low_freq = config.low_freq;
|
||||
mfcc_opts_.mel_opts.high_freq = config_.high_freq;
|
||||
mfcc_opts_.mel_opts.low_freq = config_.low_freq;
|
||||
|
||||
opts_.mel_opts.is_librosa = config.is_librosa;
|
||||
mfcc_opts_.mel_opts.is_librosa = config_.is_librosa;
|
||||
|
||||
fbank_ = std::make_unique<knf::OnlineFbank>(opts_);
|
||||
mfcc_opts_.num_ceps = config_.num_ceps;
|
||||
mfcc_opts_.use_energy = config_.use_energy;
|
||||
|
||||
mfcc_ = std::make_unique<knf::OnlineMfcc>(mfcc_opts_);
|
||||
} else {
|
||||
opts_.frame_opts.dither = config.dither;
|
||||
opts_.frame_opts.snip_edges = config.snip_edges;
|
||||
opts_.frame_opts.samp_freq = config.sampling_rate;
|
||||
opts_.frame_opts.frame_shift_ms = config.frame_shift_ms;
|
||||
opts_.frame_opts.frame_length_ms = config.frame_length_ms;
|
||||
opts_.frame_opts.remove_dc_offset = config.remove_dc_offset;
|
||||
opts_.frame_opts.window_type = config.window_type;
|
||||
|
||||
opts_.mel_opts.num_bins = config.feature_dim;
|
||||
|
||||
opts_.mel_opts.high_freq = config.high_freq;
|
||||
opts_.mel_opts.low_freq = config.low_freq;
|
||||
|
||||
opts_.mel_opts.is_librosa = config.is_librosa;
|
||||
|
||||
fbank_ = std::make_unique<knf::OnlineFbank>(opts_);
|
||||
}
|
||||
}
|
||||
|
||||
explicit Impl(WhisperTag /*tag*/) {
|
||||
@@ -81,6 +103,7 @@ class OfflineStream::Impl {
|
||||
opts_.mel_opts.num_bins = 80; // not used
|
||||
whisper_fbank_ =
|
||||
std::make_unique<knf::OnlineWhisperFbank>(opts_.frame_opts);
|
||||
config_.sampling_rate = opts_.frame_opts.samp_freq;
|
||||
}
|
||||
|
||||
explicit Impl(CEDTag /*tag*/) {
|
||||
@@ -98,6 +121,8 @@ class OfflineStream::Impl {
|
||||
opts_.mel_opts.num_bins = 64;
|
||||
opts_.mel_opts.high_freq = 8000;
|
||||
|
||||
config_.sampling_rate = opts_.frame_opts.samp_freq;
|
||||
|
||||
fbank_ = std::make_unique<knf::OnlineFbank>(opts_);
|
||||
}
|
||||
|
||||
@@ -115,52 +140,60 @@ class OfflineStream::Impl {
|
||||
|
||||
void AcceptWaveformImpl(int32_t sampling_rate, const float *waveform,
|
||||
int32_t n) {
|
||||
if (sampling_rate != opts_.frame_opts.samp_freq) {
|
||||
if (sampling_rate != config_.sampling_rate) {
|
||||
SHERPA_ONNX_LOGE(
|
||||
"Creating a resampler:\n"
|
||||
" in_sample_rate: %d\n"
|
||||
" output_sample_rate: %d\n",
|
||||
sampling_rate, static_cast<int32_t>(opts_.frame_opts.samp_freq));
|
||||
sampling_rate, static_cast<int32_t>(config_.sampling_rate));
|
||||
|
||||
float min_freq =
|
||||
std::min<int32_t>(sampling_rate, opts_.frame_opts.samp_freq);
|
||||
float min_freq = std::min<int32_t>(sampling_rate, config_.sampling_rate);
|
||||
float lowpass_cutoff = 0.99 * 0.5 * min_freq;
|
||||
|
||||
int32_t lowpass_filter_width = 6;
|
||||
auto resampler = std::make_unique<LinearResample>(
|
||||
sampling_rate, opts_.frame_opts.samp_freq, lowpass_cutoff,
|
||||
sampling_rate, config_.sampling_rate, lowpass_cutoff,
|
||||
lowpass_filter_width);
|
||||
std::vector<float> samples;
|
||||
resampler->Resample(waveform, n, true, &samples);
|
||||
|
||||
if (fbank_) {
|
||||
fbank_->AcceptWaveform(opts_.frame_opts.samp_freq, samples.data(),
|
||||
fbank_->AcceptWaveform(config_.sampling_rate, samples.data(),
|
||||
samples.size());
|
||||
fbank_->InputFinished();
|
||||
} else if (mfcc_) {
|
||||
mfcc_->AcceptWaveform(config_.sampling_rate, samples.data(),
|
||||
samples.size());
|
||||
mfcc_->InputFinished();
|
||||
} else {
|
||||
whisper_fbank_->AcceptWaveform(opts_.frame_opts.samp_freq,
|
||||
samples.data(), samples.size());
|
||||
whisper_fbank_->AcceptWaveform(config_.sampling_rate, samples.data(),
|
||||
samples.size());
|
||||
whisper_fbank_->InputFinished();
|
||||
}
|
||||
|
||||
return;
|
||||
} // if (sampling_rate != opts_.frame_opts.samp_freq)
|
||||
} // if (sampling_rate != config_.sampling_rate)
|
||||
|
||||
if (fbank_) {
|
||||
fbank_->AcceptWaveform(sampling_rate, waveform, n);
|
||||
fbank_->InputFinished();
|
||||
} else if (mfcc_) {
|
||||
mfcc_->AcceptWaveform(sampling_rate, waveform, n);
|
||||
mfcc_->InputFinished();
|
||||
} else {
|
||||
whisper_fbank_->AcceptWaveform(sampling_rate, waveform, n);
|
||||
whisper_fbank_->InputFinished();
|
||||
}
|
||||
}
|
||||
|
||||
int32_t FeatureDim() const { return opts_.mel_opts.num_bins; }
|
||||
int32_t FeatureDim() const {
|
||||
return mfcc_ ? mfcc_opts_.num_ceps : opts_.mel_opts.num_bins;
|
||||
}
|
||||
|
||||
std::vector<float> GetFrames() const {
|
||||
int32_t n =
|
||||
fbank_ ? fbank_->NumFramesReady() : whisper_fbank_->NumFramesReady();
|
||||
|
||||
int32_t n = fbank_ ? fbank_->NumFramesReady()
|
||||
: mfcc_ ? mfcc_->NumFramesReady()
|
||||
: whisper_fbank_->NumFramesReady();
|
||||
assert(n > 0 && "Please first call AcceptWaveform()");
|
||||
|
||||
int32_t feature_dim = FeatureDim();
|
||||
@@ -170,8 +203,9 @@ class OfflineStream::Impl {
|
||||
float *p = features.data();
|
||||
|
||||
for (int32_t i = 0; i != n; ++i) {
|
||||
const float *f =
|
||||
fbank_ ? fbank_->GetFrame(i) : whisper_fbank_->GetFrame(i);
|
||||
const float *f = fbank_ ? fbank_->GetFrame(i)
|
||||
: mfcc_ ? mfcc_->GetFrame(i)
|
||||
: whisper_fbank_->GetFrame(i);
|
||||
std::copy(f, f + feature_dim, p);
|
||||
p += feature_dim;
|
||||
}
|
||||
@@ -222,8 +256,10 @@ class OfflineStream::Impl {
|
||||
private:
|
||||
FeatureExtractorConfig config_;
|
||||
std::unique_ptr<knf::OnlineFbank> fbank_;
|
||||
std::unique_ptr<knf::OnlineMfcc> mfcc_;
|
||||
std::unique_ptr<knf::OnlineWhisperFbank> whisper_fbank_;
|
||||
knf::FbankOptions opts_;
|
||||
knf::MfccOptions mfcc_opts_;
|
||||
OfflineRecognitionResult r_;
|
||||
ContextGraphPtr context_graph_;
|
||||
};
|
||||
|
||||
144
sherpa-onnx/csrc/offline-telespeech-ctc-model.cc
Normal file
144
sherpa-onnx/csrc/offline-telespeech-ctc-model.cc
Normal file
@@ -0,0 +1,144 @@
|
||||
// sherpa-onnx/csrc/offline-telespeech-ctc-model.cc
|
||||
//
|
||||
// Copyright (c) 2023-2024 Xiaomi Corporation
|
||||
|
||||
#include "sherpa-onnx/csrc/offline-telespeech-ctc-model.h"
|
||||
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||
#include "sherpa-onnx/csrc/session.h"
|
||||
#include "sherpa-onnx/csrc/text-utils.h"
|
||||
#include "sherpa-onnx/csrc/transpose.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
class OfflineTeleSpeechCtcModel::Impl {
|
||||
public:
|
||||
explicit Impl(const OfflineModelConfig &config)
|
||||
: config_(config),
|
||||
env_(ORT_LOGGING_LEVEL_ERROR),
|
||||
sess_opts_(GetSessionOptions(config)),
|
||||
allocator_{} {
|
||||
auto buf = ReadFile(config_.telespeech_ctc);
|
||||
Init(buf.data(), buf.size());
|
||||
}
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
Impl(AAssetManager *mgr, const OfflineModelConfig &config)
|
||||
: config_(config),
|
||||
env_(ORT_LOGGING_LEVEL_ERROR),
|
||||
sess_opts_(GetSessionOptions(config)),
|
||||
allocator_{} {
|
||||
auto buf = ReadFile(mgr, config_.telespeech_ctc);
|
||||
Init(buf.data(), buf.size());
|
||||
}
|
||||
#endif
|
||||
|
||||
std::vector<Ort::Value> Forward(Ort::Value features,
|
||||
Ort::Value /*features_length*/) {
|
||||
std::vector<int64_t> shape =
|
||||
features.GetTensorTypeAndShapeInfo().GetShape();
|
||||
|
||||
if (static_cast<int32_t>(shape[0]) != 1) {
|
||||
SHERPA_ONNX_LOGE("This model supports only batch size 1. Given %d",
|
||||
static_cast<int32_t>(shape[0]));
|
||||
}
|
||||
|
||||
auto out = sess_->Run({}, input_names_ptr_.data(), &features, 1,
|
||||
output_names_ptr_.data(), output_names_ptr_.size());
|
||||
|
||||
std::vector<int64_t> logits_shape = {1};
|
||||
Ort::Value logits_length = Ort::Value::CreateTensor<int64_t>(
|
||||
allocator_, logits_shape.data(), logits_shape.size());
|
||||
|
||||
int64_t *dst = logits_length.GetTensorMutableData<int64_t>();
|
||||
dst[0] = out[0].GetTensorTypeAndShapeInfo().GetShape()[0];
|
||||
|
||||
// (T, B, C) -> (B, T, C)
|
||||
Ort::Value logits = Transpose01(allocator_, &out[0]);
|
||||
|
||||
std::vector<Ort::Value> ans;
|
||||
ans.reserve(2);
|
||||
ans.push_back(std::move(logits));
|
||||
ans.push_back(std::move(logits_length));
|
||||
|
||||
return ans;
|
||||
}
|
||||
|
||||
int32_t VocabSize() const { return vocab_size_; }
|
||||
|
||||
int32_t SubsamplingFactor() const { return subsampling_factor_; }
|
||||
|
||||
OrtAllocator *Allocator() const { return allocator_; }
|
||||
|
||||
private:
|
||||
void Init(void *model_data, size_t model_data_length) {
|
||||
sess_ = std::make_unique<Ort::Session>(env_, model_data, model_data_length,
|
||||
sess_opts_);
|
||||
|
||||
GetInputNames(sess_.get(), &input_names_, &input_names_ptr_);
|
||||
|
||||
GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_);
|
||||
|
||||
// get meta data
|
||||
Ort::ModelMetadata meta_data = sess_->GetModelMetadata();
|
||||
if (config_.debug) {
|
||||
std::ostringstream os;
|
||||
PrintModelMetadata(os, meta_data);
|
||||
SHERPA_ONNX_LOGE("%s\n", os.str().c_str());
|
||||
}
|
||||
|
||||
{
|
||||
auto shape =
|
||||
sess_->GetOutputTypeInfo(0).GetTensorTypeAndShapeInfo().GetShape();
|
||||
vocab_size_ = shape[2];
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
OfflineModelConfig config_;
|
||||
Ort::Env env_;
|
||||
Ort::SessionOptions sess_opts_;
|
||||
Ort::AllocatorWithDefaultOptions allocator_;
|
||||
|
||||
std::unique_ptr<Ort::Session> sess_;
|
||||
|
||||
std::vector<std::string> input_names_;
|
||||
std::vector<const char *> input_names_ptr_;
|
||||
|
||||
std::vector<std::string> output_names_;
|
||||
std::vector<const char *> output_names_ptr_;
|
||||
|
||||
int32_t vocab_size_ = 0;
|
||||
int32_t subsampling_factor_ = 4;
|
||||
};
|
||||
|
||||
OfflineTeleSpeechCtcModel::OfflineTeleSpeechCtcModel(
|
||||
const OfflineModelConfig &config)
|
||||
: impl_(std::make_unique<Impl>(config)) {}
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
OfflineTeleSpeechCtcModel::OfflineTeleSpeechCtcModel(
|
||||
AAssetManager *mgr, const OfflineModelConfig &config)
|
||||
: impl_(std::make_unique<Impl>(mgr, config)) {}
|
||||
#endif
|
||||
|
||||
OfflineTeleSpeechCtcModel::~OfflineTeleSpeechCtcModel() = default;
|
||||
|
||||
std::vector<Ort::Value> OfflineTeleSpeechCtcModel::Forward(
|
||||
Ort::Value features, Ort::Value features_length) {
|
||||
return impl_->Forward(std::move(features), std::move(features_length));
|
||||
}
|
||||
|
||||
int32_t OfflineTeleSpeechCtcModel::VocabSize() const {
|
||||
return impl_->VocabSize();
|
||||
}
|
||||
int32_t OfflineTeleSpeechCtcModel::SubsamplingFactor() const {
|
||||
return impl_->SubsamplingFactor();
|
||||
}
|
||||
|
||||
OrtAllocator *OfflineTeleSpeechCtcModel::Allocator() const {
|
||||
return impl_->Allocator();
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
81
sherpa-onnx/csrc/offline-telespeech-ctc-model.h
Normal file
81
sherpa-onnx/csrc/offline-telespeech-ctc-model.h
Normal file
@@ -0,0 +1,81 @@
|
||||
// sherpa-onnx/csrc/offline-telespeech-ctc-model.h
|
||||
//
|
||||
// Copyright (c) 2024 Xiaomi Corporation
|
||||
#ifndef SHERPA_ONNX_CSRC_OFFLINE_TELESPEECH_CTC_MODEL_H_
|
||||
#define SHERPA_ONNX_CSRC_OFFLINE_TELESPEECH_CTC_MODEL_H_
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
#include "android/asset_manager.h"
|
||||
#include "android/asset_manager_jni.h"
|
||||
#endif
|
||||
|
||||
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||
#include "sherpa-onnx/csrc/offline-ctc-model.h"
|
||||
#include "sherpa-onnx/csrc/offline-model-config.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
/** This class implements the CTC model from
|
||||
* https://github.com/Tele-AI/TeleSpeech-ASR.
|
||||
*
|
||||
* See
|
||||
* https://github.com/lovemefan/telespeech-asr-python/blob/main/telespeechasr/onnx/onnx_infer.py
|
||||
* and
|
||||
* https://github.com/k2-fsa/sherpa-onnx/blob/master/scripts/tele-speech/test.py
|
||||
*/
|
||||
class OfflineTeleSpeechCtcModel : public OfflineCtcModel {
|
||||
public:
|
||||
explicit OfflineTeleSpeechCtcModel(const OfflineModelConfig &config);
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
OfflineTeleSpeechCtcModel(AAssetManager *mgr,
|
||||
const OfflineModelConfig &config);
|
||||
#endif
|
||||
|
||||
~OfflineTeleSpeechCtcModel() override;
|
||||
|
||||
/** Run the forward method of the model.
|
||||
*
|
||||
* @param features A tensor of shape (N, T, C).
|
||||
* @param features_length A 1-D tensor of shape (N,) containing number of
|
||||
* valid frames in `features` before padding.
|
||||
* Its dtype is int64_t.
|
||||
*
|
||||
* @return Return a vector containing:
|
||||
* - log_probs: A 3-D tensor of shape (N, T', vocab_size).
|
||||
* - log_probs_length: A 1-D tensor of shape (N,). Its dtype is int64_t.
|
||||
*/
|
||||
std::vector<Ort::Value> Forward(Ort::Value features,
|
||||
Ort::Value features_length) override;
|
||||
|
||||
/** Return the vocabulary size of the model
|
||||
*/
|
||||
int32_t VocabSize() const override;
|
||||
|
||||
/** SubsamplingFactor of the model
|
||||
*/
|
||||
int32_t SubsamplingFactor() const override;
|
||||
|
||||
/** Return an allocator for allocating memory
|
||||
*/
|
||||
OrtAllocator *Allocator() const override;
|
||||
|
||||
// TeleSpeech CTC models do not support batch size > 1
|
||||
bool SupportBatchProcessing() const override { return false; }
|
||||
|
||||
std::string FeatureNormalizationMethod() const override {
|
||||
return "per_feature";
|
||||
}
|
||||
|
||||
private:
|
||||
class Impl;
|
||||
std::unique_ptr<Impl> impl_;
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_OFFLINE_TELESPEECH_CTC_MODEL_H_
|
||||
@@ -66,7 +66,7 @@ bool OnlineModelConfig::Validate() const {
|
||||
if (!modeling_unit.empty() &&
|
||||
(modeling_unit == "bpe" || modeling_unit == "cjkchar+bpe")) {
|
||||
if (!FileExists(bpe_vocab)) {
|
||||
SHERPA_ONNX_LOGE("bpe_vocab: %s does not exist", bpe_vocab.c_str());
|
||||
SHERPA_ONNX_LOGE("bpe_vocab: '%s' does not exist", bpe_vocab.c_str());
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user