Add C++ runtime for Tele-AI/TeleSpeech-ASR (#970)

This commit is contained in:
Fangjun Kuang
2024-06-05 00:26:40 +08:00
committed by GitHub
parent f8dbc10146
commit fd5a0d1e00
52 changed files with 1052 additions and 145 deletions

View File

@@ -39,6 +39,7 @@ set(sources
offline-stream.cc
offline-tdnn-ctc-model.cc
offline-tdnn-model-config.cc
offline-telespeech-ctc-model.cc
offline-transducer-greedy-search-decoder.cc
offline-transducer-greedy-search-nemo-decoder.cc
offline-transducer-model-config.cc

View File

@@ -56,22 +56,11 @@ std::string FeatureExtractorConfig::ToString() const {
class FeatureExtractor::Impl {
public:
explicit Impl(const FeatureExtractorConfig &config) : config_(config) {
opts_.frame_opts.dither = config.dither;
opts_.frame_opts.snip_edges = config.snip_edges;
opts_.frame_opts.samp_freq = config.sampling_rate;
opts_.frame_opts.frame_shift_ms = config.frame_shift_ms;
opts_.frame_opts.frame_length_ms = config.frame_length_ms;
opts_.frame_opts.remove_dc_offset = config.remove_dc_offset;
opts_.frame_opts.window_type = config.window_type;
opts_.mel_opts.num_bins = config.feature_dim;
opts_.mel_opts.high_freq = config.high_freq;
opts_.mel_opts.low_freq = config.low_freq;
opts_.mel_opts.is_librosa = config.is_librosa;
fbank_ = std::make_unique<knf::OnlineFbank>(opts_);
if (config_.is_mfcc) {
InitMfcc();
} else {
InitFbank();
}
}
void AcceptWaveform(int32_t sampling_rate, const float *waveform, int32_t n) {
@@ -101,35 +90,48 @@ class FeatureExtractor::Impl {
std::vector<float> samples;
resampler_->Resample(waveform, n, false, &samples);
fbank_->AcceptWaveform(opts_.frame_opts.samp_freq, samples.data(),
samples.size());
if (fbank_) {
fbank_->AcceptWaveform(config_.sampling_rate, samples.data(),
samples.size());
} else {
mfcc_->AcceptWaveform(config_.sampling_rate, samples.data(),
samples.size());
}
return;
}
if (sampling_rate != opts_.frame_opts.samp_freq) {
if (sampling_rate != config_.sampling_rate) {
SHERPA_ONNX_LOGE(
"Creating a resampler:\n"
" in_sample_rate: %d\n"
" output_sample_rate: %d\n",
sampling_rate, static_cast<int32_t>(opts_.frame_opts.samp_freq));
sampling_rate, static_cast<int32_t>(config_.sampling_rate));
float min_freq =
std::min<int32_t>(sampling_rate, opts_.frame_opts.samp_freq);
float min_freq = std::min<int32_t>(sampling_rate, config_.sampling_rate);
float lowpass_cutoff = 0.99 * 0.5 * min_freq;
int32_t lowpass_filter_width = 6;
resampler_ = std::make_unique<LinearResample>(
sampling_rate, opts_.frame_opts.samp_freq, lowpass_cutoff,
sampling_rate, config_.sampling_rate, lowpass_cutoff,
lowpass_filter_width);
std::vector<float> samples;
resampler_->Resample(waveform, n, false, &samples);
fbank_->AcceptWaveform(opts_.frame_opts.samp_freq, samples.data(),
samples.size());
if (fbank_) {
fbank_->AcceptWaveform(config_.sampling_rate, samples.data(),
samples.size());
} else {
mfcc_->AcceptWaveform(config_.sampling_rate, samples.data(),
samples.size());
}
return;
}
fbank_->AcceptWaveform(sampling_rate, waveform, n);
if (fbank_) {
fbank_->AcceptWaveform(sampling_rate, waveform, n);
} else {
mfcc_->AcceptWaveform(sampling_rate, waveform, n);
}
}
void InputFinished() const {
@@ -179,11 +181,56 @@ class FeatureExtractor::Impl {
return features;
}
int32_t FeatureDim() const { return opts_.mel_opts.num_bins; }
int32_t FeatureDim() const {
return mfcc_ ? mfcc_opts_.num_ceps : opts_.mel_opts.num_bins;
}
private:
void InitFbank() {
opts_.frame_opts.dither = config_.dither;
opts_.frame_opts.snip_edges = config_.snip_edges;
opts_.frame_opts.samp_freq = config_.sampling_rate;
opts_.frame_opts.frame_shift_ms = config_.frame_shift_ms;
opts_.frame_opts.frame_length_ms = config_.frame_length_ms;
opts_.frame_opts.remove_dc_offset = config_.remove_dc_offset;
opts_.frame_opts.window_type = config_.window_type;
opts_.mel_opts.num_bins = config_.feature_dim;
opts_.mel_opts.high_freq = config_.high_freq;
opts_.mel_opts.low_freq = config_.low_freq;
opts_.mel_opts.is_librosa = config_.is_librosa;
fbank_ = std::make_unique<knf::OnlineFbank>(opts_);
}
void InitMfcc() {
mfcc_opts_.frame_opts.dither = config_.dither;
mfcc_opts_.frame_opts.snip_edges = config_.snip_edges;
mfcc_opts_.frame_opts.samp_freq = config_.sampling_rate;
mfcc_opts_.frame_opts.frame_shift_ms = config_.frame_shift_ms;
mfcc_opts_.frame_opts.frame_length_ms = config_.frame_length_ms;
mfcc_opts_.frame_opts.remove_dc_offset = config_.remove_dc_offset;
mfcc_opts_.frame_opts.window_type = config_.window_type;
mfcc_opts_.mel_opts.num_bins = config_.feature_dim;
mfcc_opts_.mel_opts.high_freq = config_.high_freq;
mfcc_opts_.mel_opts.low_freq = config_.low_freq;
mfcc_opts_.mel_opts.is_librosa = config_.is_librosa;
mfcc_opts_.num_ceps = config_.num_ceps;
mfcc_opts_.use_energy = config_.use_energy;
mfcc_ = std::make_unique<knf::OnlineMfcc>(mfcc_opts_);
}
private:
std::unique_ptr<knf::OnlineFbank> fbank_;
std::unique_ptr<knf::OnlineMfcc> mfcc_;
knf::FbankOptions opts_;
knf::MfccOptions mfcc_opts_;
FeatureExtractorConfig config_;
mutable std::mutex mutex_;
std::unique_ptr<LinearResample> resampler_;

View File

@@ -18,7 +18,10 @@ struct FeatureExtractorConfig {
// the sampling rate of the input waveform, we will do resampling inside.
int32_t sampling_rate = 16000;
// Feature dimension
// num_mel_bins
//
// Note: for MFCC this value is also used as num_mel_bins.
// The actual feature dimension is num_ceps.
int32_t feature_dim = 80;
// minimal frequency for Mel-filterbank, in Hz
@@ -69,6 +72,12 @@ struct FeatureExtractorConfig {
// for details
std::string nemo_normalize_type;
// for MFCC
int32_t num_ceps = 13;
bool use_energy = true;
bool is_mfcc = false;
std::string ToString() const;
void Register(ParseOptions *po);

View File

@@ -12,6 +12,7 @@
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model.h"
#include "sherpa-onnx/csrc/offline-tdnn-ctc-model.h"
#include "sherpa-onnx/csrc/offline-telespeech-ctc-model.h"
#include "sherpa-onnx/csrc/offline-wenet-ctc-model.h"
#include "sherpa-onnx/csrc/offline-zipformer-ctc-model.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
@@ -24,6 +25,7 @@ enum class ModelType {
kTdnn,
kZipformerCtc,
kWenetCtc,
kTeleSpeechCtc,
kUnknown,
};
@@ -63,6 +65,9 @@ static ModelType GetModelType(char *model_data, size_t model_data_length,
"If you are using models from WeNet, please refer to\n"
"https://github.com/k2-fsa/sherpa-onnx/blob/master/scripts/wenet/"
"run.sh\n"
"If you are using models from TeleSpeech, please refer to\n"
"https://github.com/k2-fsa/sherpa-onnx/blob/master/scripts/tele-speech/"
"add-metadata.py"
"\n"
"for how to add metadata to model.onnx\n");
return ModelType::kUnknown;
@@ -78,6 +83,8 @@ static ModelType GetModelType(char *model_data, size_t model_data_length,
return ModelType::kZipformerCtc;
} else if (model_type.get() == std::string("wenet_ctc")) {
return ModelType::kWenetCtc;
} else if (model_type.get() == std::string("telespeech_ctc")) {
return ModelType::kTeleSpeechCtc;
} else {
SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.get());
return ModelType::kUnknown;
@@ -97,6 +104,8 @@ std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create(
filename = config.zipformer_ctc.model;
} else if (!config.wenet_ctc.model.empty()) {
filename = config.wenet_ctc.model;
} else if (!config.telespeech_ctc.empty()) {
filename = config.telespeech_ctc;
} else {
SHERPA_ONNX_LOGE("Please specify a CTC model");
exit(-1);
@@ -124,6 +133,9 @@ std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create(
case ModelType::kWenetCtc:
return std::make_unique<OfflineWenetCtcModel>(config);
break;
case ModelType::kTeleSpeechCtc:
return std::make_unique<OfflineTeleSpeechCtcModel>(config);
break;
case ModelType::kUnknown:
SHERPA_ONNX_LOGE("Unknown model type in offline CTC!");
return nullptr;
@@ -147,6 +159,8 @@ std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create(
filename = config.zipformer_ctc.model;
} else if (!config.wenet_ctc.model.empty()) {
filename = config.wenet_ctc.model;
} else if (!config.telespeech_ctc.empty()) {
filename = config.telespeech_ctc;
} else {
SHERPA_ONNX_LOGE("Please specify a CTC model");
exit(-1);
@@ -175,6 +189,9 @@ std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create(
case ModelType::kWenetCtc:
return std::make_unique<OfflineWenetCtcModel>(mgr, config);
break;
case ModelType::kTeleSpeechCtc:
return std::make_unique<OfflineTeleSpeechCtcModel>(mgr, config);
break;
case ModelType::kUnknown:
SHERPA_ONNX_LOGE("Unknown model type in offline CTC!");
return nullptr;

View File

@@ -19,6 +19,9 @@ void OfflineModelConfig::Register(ParseOptions *po) {
zipformer_ctc.Register(po);
wenet_ctc.Register(po);
po->Register("telespeech-ctc", &telespeech_ctc,
"Path to model.onnx for telespeech ctc");
po->Register("tokens", &tokens, "Path to tokens.txt");
po->Register("num-threads", &num_threads,
@@ -33,7 +36,7 @@ void OfflineModelConfig::Register(ParseOptions *po) {
po->Register("model-type", &model_type,
"Specify it to reduce model initialization time. "
"Valid values are: transducer, paraformer, nemo_ctc, whisper, "
"tdnn, zipformer2_ctc"
"tdnn, zipformer2_ctc, telespeech_ctc. "
"All other values lead to loading the model twice.");
po->Register("modeling-unit", &modeling_unit,
"The modeling unit of the model, commonly used units are bpe, "
@@ -55,14 +58,14 @@ bool OfflineModelConfig::Validate() const {
}
if (!FileExists(tokens)) {
SHERPA_ONNX_LOGE("tokens: %s does not exist", tokens.c_str());
SHERPA_ONNX_LOGE("tokens: '%s' does not exist", tokens.c_str());
return false;
}
if (!modeling_unit.empty() &&
(modeling_unit == "bpe" || modeling_unit == "cjkchar+bpe")) {
if (!FileExists(bpe_vocab)) {
SHERPA_ONNX_LOGE("bpe_vocab: %s does not exist", bpe_vocab.c_str());
SHERPA_ONNX_LOGE("bpe_vocab: '%s' does not exist", bpe_vocab.c_str());
return false;
}
}
@@ -91,6 +94,14 @@ bool OfflineModelConfig::Validate() const {
return wenet_ctc.Validate();
}
if (!telespeech_ctc.empty() && !FileExists(telespeech_ctc)) {
SHERPA_ONNX_LOGE("telespeech_ctc: '%s' does not exist",
telespeech_ctc.c_str());
return false;
} else {
return true;
}
return transducer.Validate();
}
@@ -105,6 +116,7 @@ std::string OfflineModelConfig::ToString() const {
os << "tdnn=" << tdnn.ToString() << ", ";
os << "zipformer_ctc=" << zipformer_ctc.ToString() << ", ";
os << "wenet_ctc=" << wenet_ctc.ToString() << ", ";
os << "telespeech_ctc=\"" << telespeech_ctc << "\", ";
os << "tokens=\"" << tokens << "\", ";
os << "num_threads=" << num_threads << ", ";
os << "debug=" << (debug ? "True" : "False") << ", ";

View File

@@ -24,6 +24,7 @@ struct OfflineModelConfig {
OfflineTdnnModelConfig tdnn;
OfflineZipformerCtcModelConfig zipformer_ctc;
OfflineWenetCtcModelConfig wenet_ctc;
std::string telespeech_ctc;
std::string tokens;
int32_t num_threads = 2;
@@ -52,6 +53,7 @@ struct OfflineModelConfig {
const OfflineTdnnModelConfig &tdnn,
const OfflineZipformerCtcModelConfig &zipformer_ctc,
const OfflineWenetCtcModelConfig &wenet_ctc,
const std::string &telespeech_ctc,
const std::string &tokens, int32_t num_threads, bool debug,
const std::string &provider, const std::string &model_type,
const std::string &modeling_unit,
@@ -63,6 +65,7 @@ struct OfflineModelConfig {
tdnn(tdnn),
zipformer_ctc(zipformer_ctc),
wenet_ctc(wenet_ctc),
telespeech_ctc(telespeech_ctc),
tokens(tokens),
num_threads(num_threads),
debug(debug),

View File

@@ -88,6 +88,17 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl {
#endif
void Init() {
if (!config_.model_config.telespeech_ctc.empty()) {
config_.feat_config.snip_edges = true;
config_.feat_config.num_ceps = 40;
config_.feat_config.feature_dim = 40;
config_.feat_config.low_freq = 40;
config_.feat_config.high_freq = -200;
config_.feat_config.use_energy = false;
config_.feat_config.normalize_samples = false;
config_.feat_config.is_mfcc = true;
}
if (!config_.model_config.wenet_ctc.model.empty()) {
// WeNet CTC models assume input samples are in the range
// [-32768, 32767], so we set normalize_samples to false

View File

@@ -29,7 +29,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
} else if (model_type == "paraformer") {
return std::make_unique<OfflineRecognizerParaformerImpl>(config);
} else if (model_type == "nemo_ctc" || model_type == "tdnn" ||
model_type == "zipformer2_ctc" || model_type == "wenet_ctc") {
model_type == "zipformer2_ctc" || model_type == "wenet_ctc" ||
model_type == "telespeech_ctc") {
return std::make_unique<OfflineRecognizerCtcImpl>(config);
} else if (model_type == "whisper") {
return std::make_unique<OfflineRecognizerWhisperImpl>(config);
@@ -53,6 +54,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
model_filename = config.model_config.paraformer.model;
} else if (!config.model_config.nemo_ctc.model.empty()) {
model_filename = config.model_config.nemo_ctc.model;
} else if (!config.model_config.telespeech_ctc.empty()) {
model_filename = config.model_config.telespeech_ctc;
} else if (!config.model_config.tdnn.model.empty()) {
model_filename = config.model_config.tdnn.model;
} else if (!config.model_config.zipformer_ctc.model.empty()) {
@@ -111,6 +114,10 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
"\n "
"https://github.com/k2-fsa/sherpa-onnx/blob/master/scripts/wenet/run.sh"
"\n"
"(7) CTC models from TeleSpeech"
"\n "
"https://github.com/Tele-AI/TeleSpeech-ASR"
"\n"
"\n");
exit(-1);
}
@@ -133,7 +140,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
if (model_type == "EncDecCTCModelBPE" ||
model_type == "EncDecHybridRNNTCTCBPEModel" || model_type == "tdnn" ||
model_type == "zipformer2_ctc" || model_type == "wenet_ctc") {
model_type == "zipformer2_ctc" || model_type == "wenet_ctc" ||
model_type == "telespeech_ctc") {
return std::make_unique<OfflineRecognizerCtcImpl>(config);
}
@@ -151,7 +159,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
" - Whisper models\n"
" - Tdnn models\n"
" - Zipformer CTC models\n"
" - WeNet CTC models\n",
" - WeNet CTC models\n"
" - TeleSpeech CTC models\n",
model_type.c_str());
exit(-1);
@@ -169,7 +178,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
} else if (model_type == "paraformer") {
return std::make_unique<OfflineRecognizerParaformerImpl>(mgr, config);
} else if (model_type == "nemo_ctc" || model_type == "tdnn" ||
model_type == "zipformer2_ctc" || model_type == "wenet_ctc") {
model_type == "zipformer2_ctc" || model_type == "wenet_ctc" ||
model_type == "telespeech_ctc") {
return std::make_unique<OfflineRecognizerCtcImpl>(mgr, config);
} else if (model_type == "whisper") {
return std::make_unique<OfflineRecognizerWhisperImpl>(mgr, config);
@@ -199,6 +209,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
model_filename = config.model_config.zipformer_ctc.model;
} else if (!config.model_config.wenet_ctc.model.empty()) {
model_filename = config.model_config.wenet_ctc.model;
} else if (!config.model_config.telespeech_ctc.empty()) {
model_filename = config.model_config.telespeech_ctc;
} else if (!config.model_config.whisper.encoder.empty()) {
model_filename = config.model_config.whisper.encoder;
} else {
@@ -251,6 +263,10 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
"\n "
"https://github.com/k2-fsa/sherpa-onnx/blob/master/scripts/wenet/run.sh"
"\n"
"(7) CTC models from TeleSpeech"
"\n "
"https://github.com/Tele-AI/TeleSpeech-ASR"
"\n"
"\n");
exit(-1);
}
@@ -273,7 +289,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
if (model_type == "EncDecCTCModelBPE" ||
model_type == "EncDecHybridRNNTCTCBPEModel" || model_type == "tdnn" ||
model_type == "zipformer2_ctc" || model_type == "wenet_ctc") {
model_type == "zipformer2_ctc" || model_type == "wenet_ctc" ||
model_type == "telespeech_ctc") {
return std::make_unique<OfflineRecognizerCtcImpl>(mgr, config);
}
@@ -291,7 +308,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
" - Whisper models\n"
" - Tdnn models\n"
" - Zipformer CTC models\n"
" - WeNet CTC models\n",
" - WeNet CTC models\n"
" - TeleSpeech CTC models\n",
model_type.c_str());
exit(-1);

View File

@@ -57,22 +57,44 @@ class OfflineStream::Impl {
explicit Impl(const FeatureExtractorConfig &config,
ContextGraphPtr context_graph)
: config_(config), context_graph_(context_graph) {
opts_.frame_opts.dither = config.dither;
opts_.frame_opts.snip_edges = config.snip_edges;
opts_.frame_opts.samp_freq = config.sampling_rate;
opts_.frame_opts.frame_shift_ms = config.frame_shift_ms;
opts_.frame_opts.frame_length_ms = config.frame_length_ms;
opts_.frame_opts.remove_dc_offset = config.remove_dc_offset;
opts_.frame_opts.window_type = config.window_type;
if (config.is_mfcc) {
mfcc_opts_.frame_opts.dither = config_.dither;
mfcc_opts_.frame_opts.snip_edges = config_.snip_edges;
mfcc_opts_.frame_opts.samp_freq = config_.sampling_rate;
mfcc_opts_.frame_opts.frame_shift_ms = config_.frame_shift_ms;
mfcc_opts_.frame_opts.frame_length_ms = config_.frame_length_ms;
mfcc_opts_.frame_opts.remove_dc_offset = config_.remove_dc_offset;
mfcc_opts_.frame_opts.window_type = config_.window_type;
opts_.mel_opts.num_bins = config.feature_dim;
mfcc_opts_.mel_opts.num_bins = config_.feature_dim;
opts_.mel_opts.high_freq = config.high_freq;
opts_.mel_opts.low_freq = config.low_freq;
mfcc_opts_.mel_opts.high_freq = config_.high_freq;
mfcc_opts_.mel_opts.low_freq = config_.low_freq;
opts_.mel_opts.is_librosa = config.is_librosa;
mfcc_opts_.mel_opts.is_librosa = config_.is_librosa;
fbank_ = std::make_unique<knf::OnlineFbank>(opts_);
mfcc_opts_.num_ceps = config_.num_ceps;
mfcc_opts_.use_energy = config_.use_energy;
mfcc_ = std::make_unique<knf::OnlineMfcc>(mfcc_opts_);
} else {
opts_.frame_opts.dither = config.dither;
opts_.frame_opts.snip_edges = config.snip_edges;
opts_.frame_opts.samp_freq = config.sampling_rate;
opts_.frame_opts.frame_shift_ms = config.frame_shift_ms;
opts_.frame_opts.frame_length_ms = config.frame_length_ms;
opts_.frame_opts.remove_dc_offset = config.remove_dc_offset;
opts_.frame_opts.window_type = config.window_type;
opts_.mel_opts.num_bins = config.feature_dim;
opts_.mel_opts.high_freq = config.high_freq;
opts_.mel_opts.low_freq = config.low_freq;
opts_.mel_opts.is_librosa = config.is_librosa;
fbank_ = std::make_unique<knf::OnlineFbank>(opts_);
}
}
explicit Impl(WhisperTag /*tag*/) {
@@ -81,6 +103,7 @@ class OfflineStream::Impl {
opts_.mel_opts.num_bins = 80; // not used
whisper_fbank_ =
std::make_unique<knf::OnlineWhisperFbank>(opts_.frame_opts);
config_.sampling_rate = opts_.frame_opts.samp_freq;
}
explicit Impl(CEDTag /*tag*/) {
@@ -98,6 +121,8 @@ class OfflineStream::Impl {
opts_.mel_opts.num_bins = 64;
opts_.mel_opts.high_freq = 8000;
config_.sampling_rate = opts_.frame_opts.samp_freq;
fbank_ = std::make_unique<knf::OnlineFbank>(opts_);
}
@@ -115,52 +140,60 @@ class OfflineStream::Impl {
void AcceptWaveformImpl(int32_t sampling_rate, const float *waveform,
int32_t n) {
if (sampling_rate != opts_.frame_opts.samp_freq) {
if (sampling_rate != config_.sampling_rate) {
SHERPA_ONNX_LOGE(
"Creating a resampler:\n"
" in_sample_rate: %d\n"
" output_sample_rate: %d\n",
sampling_rate, static_cast<int32_t>(opts_.frame_opts.samp_freq));
sampling_rate, static_cast<int32_t>(config_.sampling_rate));
float min_freq =
std::min<int32_t>(sampling_rate, opts_.frame_opts.samp_freq);
float min_freq = std::min<int32_t>(sampling_rate, config_.sampling_rate);
float lowpass_cutoff = 0.99 * 0.5 * min_freq;
int32_t lowpass_filter_width = 6;
auto resampler = std::make_unique<LinearResample>(
sampling_rate, opts_.frame_opts.samp_freq, lowpass_cutoff,
sampling_rate, config_.sampling_rate, lowpass_cutoff,
lowpass_filter_width);
std::vector<float> samples;
resampler->Resample(waveform, n, true, &samples);
if (fbank_) {
fbank_->AcceptWaveform(opts_.frame_opts.samp_freq, samples.data(),
fbank_->AcceptWaveform(config_.sampling_rate, samples.data(),
samples.size());
fbank_->InputFinished();
} else if (mfcc_) {
mfcc_->AcceptWaveform(config_.sampling_rate, samples.data(),
samples.size());
mfcc_->InputFinished();
} else {
whisper_fbank_->AcceptWaveform(opts_.frame_opts.samp_freq,
samples.data(), samples.size());
whisper_fbank_->AcceptWaveform(config_.sampling_rate, samples.data(),
samples.size());
whisper_fbank_->InputFinished();
}
return;
} // if (sampling_rate != opts_.frame_opts.samp_freq)
} // if (sampling_rate != config_.sampling_rate)
if (fbank_) {
fbank_->AcceptWaveform(sampling_rate, waveform, n);
fbank_->InputFinished();
} else if (mfcc_) {
mfcc_->AcceptWaveform(sampling_rate, waveform, n);
mfcc_->InputFinished();
} else {
whisper_fbank_->AcceptWaveform(sampling_rate, waveform, n);
whisper_fbank_->InputFinished();
}
}
int32_t FeatureDim() const { return opts_.mel_opts.num_bins; }
int32_t FeatureDim() const {
return mfcc_ ? mfcc_opts_.num_ceps : opts_.mel_opts.num_bins;
}
std::vector<float> GetFrames() const {
int32_t n =
fbank_ ? fbank_->NumFramesReady() : whisper_fbank_->NumFramesReady();
int32_t n = fbank_ ? fbank_->NumFramesReady()
: mfcc_ ? mfcc_->NumFramesReady()
: whisper_fbank_->NumFramesReady();
assert(n > 0 && "Please first call AcceptWaveform()");
int32_t feature_dim = FeatureDim();
@@ -170,8 +203,9 @@ class OfflineStream::Impl {
float *p = features.data();
for (int32_t i = 0; i != n; ++i) {
const float *f =
fbank_ ? fbank_->GetFrame(i) : whisper_fbank_->GetFrame(i);
const float *f = fbank_ ? fbank_->GetFrame(i)
: mfcc_ ? mfcc_->GetFrame(i)
: whisper_fbank_->GetFrame(i);
std::copy(f, f + feature_dim, p);
p += feature_dim;
}
@@ -222,8 +256,10 @@ class OfflineStream::Impl {
private:
FeatureExtractorConfig config_;
std::unique_ptr<knf::OnlineFbank> fbank_;
std::unique_ptr<knf::OnlineMfcc> mfcc_;
std::unique_ptr<knf::OnlineWhisperFbank> whisper_fbank_;
knf::FbankOptions opts_;
knf::MfccOptions mfcc_opts_;
OfflineRecognitionResult r_;
ContextGraphPtr context_graph_;
};

View File

@@ -0,0 +1,144 @@
// sherpa-onnx/csrc/offline-telespeech-ctc-model.cc
//
// Copyright (c) 2023-2024 Xiaomi Corporation
#include "sherpa-onnx/csrc/offline-telespeech-ctc-model.h"
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
#include "sherpa-onnx/csrc/session.h"
#include "sherpa-onnx/csrc/text-utils.h"
#include "sherpa-onnx/csrc/transpose.h"
namespace sherpa_onnx {
// Pimpl implementation: owns the onnxruntime session for the TeleSpeech
// CTC model and runs inference on it.
class OfflineTeleSpeechCtcModel::Impl {
public:
// Load the model from the file path given in config.telespeech_ctc.
explicit Impl(const OfflineModelConfig &config)
: config_(config),
env_(ORT_LOGGING_LEVEL_ERROR),
sess_opts_(GetSessionOptions(config)),
allocator_{} {
auto buf = ReadFile(config_.telespeech_ctc);
Init(buf.data(), buf.size());
}
#if __ANDROID_API__ >= 9
// Android variant: load the model via the APK asset manager.
Impl(AAssetManager *mgr, const OfflineModelConfig &config)
: config_(config),
env_(ORT_LOGGING_LEVEL_ERROR),
sess_opts_(GetSessionOptions(config)),
allocator_{} {
auto buf = ReadFile(mgr, config_.telespeech_ctc);
Init(buf.data(), buf.size());
}
#endif
// Run the model on one utterance.
//
// @param features A 3-D tensor. Only batch size 1 is supported; other
//                 batch sizes are logged as an error but execution
//                 still continues (output is then meaningless).
// @param features_length Ignored by this model.
// @return {logits, logits_length}: logits is (B, T', C) after the
//         transpose below; logits_length is a 1-D int64 tensor of
//         shape (1,) holding T'.
std::vector<Ort::Value> Forward(Ort::Value features,
Ort::Value /*features_length*/) {
std::vector<int64_t> shape =
features.GetTensorTypeAndShapeInfo().GetShape();
// NOTE(review): this only logs and does not bail out on an
// unsupported batch size -- consider returning an error instead.
if (static_cast<int32_t>(shape[0]) != 1) {
SHERPA_ONNX_LOGE("This model supports only batch size 1. Given %d",
static_cast<int32_t>(shape[0]));
}
auto out = sess_->Run({}, input_names_ptr_.data(), &features, 1,
output_names_ptr_.data(), output_names_ptr_.size());
// The model emits (T, B, C); dim 0 of the output is the number of
// output frames, which we report as logits_length.
std::vector<int64_t> logits_shape = {1};
Ort::Value logits_length = Ort::Value::CreateTensor<int64_t>(
allocator_, logits_shape.data(), logits_shape.size());
int64_t *dst = logits_length.GetTensorMutableData<int64_t>();
dst[0] = out[0].GetTensorTypeAndShapeInfo().GetShape()[0];
// (T, B, C) -> (B, T, C)
Ort::Value logits = Transpose01(allocator_, &out[0]);
std::vector<Ort::Value> ans;
ans.reserve(2);
ans.push_back(std::move(logits));
ans.push_back(std::move(logits_length));
return ans;
}
int32_t VocabSize() const { return vocab_size_; }
int32_t SubsamplingFactor() const { return subsampling_factor_; }
OrtAllocator *Allocator() const { return allocator_; }
private:
// Create the onnxruntime session from an in-memory model buffer and
// derive the vocabulary size from the model's declared output shape.
void Init(void *model_data, size_t model_data_length) {
sess_ = std::make_unique<Ort::Session>(env_, model_data, model_data_length,
sess_opts_);
GetInputNames(sess_.get(), &input_names_, &input_names_ptr_);
GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_);
// get meta data
Ort::ModelMetadata meta_data = sess_->GetModelMetadata();
if (config_.debug) {
std::ostringstream os;
PrintModelMetadata(os, meta_data);
SHERPA_ONNX_LOGE("%s\n", os.str().c_str());
}
{
// Output is (T, B, C); dim 2 is the vocabulary size.
// NOTE(review): assumes dim 2 is static in the exported model --
// a dynamic dimension would yield -1 here; confirm.
auto shape =
sess_->GetOutputTypeInfo(0).GetTensorTypeAndShapeInfo().GetShape();
vocab_size_ = shape[2];
}
}
private:
OfflineModelConfig config_;
Ort::Env env_;
Ort::SessionOptions sess_opts_;
Ort::AllocatorWithDefaultOptions allocator_;
std::unique_ptr<Ort::Session> sess_;
std::vector<std::string> input_names_;
std::vector<const char *> input_names_ptr_;
std::vector<std::string> output_names_;
std::vector<const char *> output_names_ptr_;
int32_t vocab_size_ = 0;
// Frames of input per output frame; fixed to 4 for this model.
int32_t subsampling_factor_ = 4;
};
// Public API: every method below simply forwards to the Impl above.
OfflineTeleSpeechCtcModel::OfflineTeleSpeechCtcModel(
const OfflineModelConfig &config)
: impl_(std::make_unique<Impl>(config)) {}
#if __ANDROID_API__ >= 9
OfflineTeleSpeechCtcModel::OfflineTeleSpeechCtcModel(
AAssetManager *mgr, const OfflineModelConfig &config)
: impl_(std::make_unique<Impl>(mgr, config)) {}
#endif
// Defined out-of-line so Impl is a complete type when destroyed.
OfflineTeleSpeechCtcModel::~OfflineTeleSpeechCtcModel() = default;
std::vector<Ort::Value> OfflineTeleSpeechCtcModel::Forward(
Ort::Value features, Ort::Value features_length) {
return impl_->Forward(std::move(features), std::move(features_length));
}
int32_t OfflineTeleSpeechCtcModel::VocabSize() const {
return impl_->VocabSize();
}
int32_t OfflineTeleSpeechCtcModel::SubsamplingFactor() const {
return impl_->SubsamplingFactor();
}
OrtAllocator *OfflineTeleSpeechCtcModel::Allocator() const {
return impl_->Allocator();
}
} // namespace sherpa_onnx

View File

@@ -0,0 +1,81 @@
// sherpa-onnx/csrc/offline-telespeech-ctc-model.h
//
// Copyright (c) 2024 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_OFFLINE_TELESPEECH_CTC_MODEL_H_
#define SHERPA_ONNX_CSRC_OFFLINE_TELESPEECH_CTC_MODEL_H_
#include <memory>
#include <string>
#include <utility>
#include <vector>
#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif
#include "onnxruntime_cxx_api.h" // NOLINT
#include "sherpa-onnx/csrc/offline-ctc-model.h"
#include "sherpa-onnx/csrc/offline-model-config.h"
namespace sherpa_onnx {
/** This class implements the CTC model from
* https://github.com/Tele-AI/TeleSpeech-ASR.
*
* See
* https://github.com/lovemefan/telespeech-asr-python/blob/main/telespeechasr/onnx/onnx_infer.py
* and
* https://github.com/k2-fsa/sherpa-onnx/blob/master/scripts/tele-speech/test.py
*/
class OfflineTeleSpeechCtcModel : public OfflineCtcModel {
public:
// Load the model from the path given in config.telespeech_ctc.
explicit OfflineTeleSpeechCtcModel(const OfflineModelConfig &config);
#if __ANDROID_API__ >= 9
// Android variant: load the model via the APK asset manager.
OfflineTeleSpeechCtcModel(AAssetManager *mgr,
const OfflineModelConfig &config);
#endif
~OfflineTeleSpeechCtcModel() override;
/** Run the forward method of the model.
*
* @param features A tensor of shape (N, T, C).
* @param features_length A 1-D tensor of shape (N,) containing number of
* valid frames in `features` before padding.
* Its dtype is int64_t.
*
* @return Return a vector containing:
* - log_probs: A 3-D tensor of shape (N, T', vocab_size).
* - log_probs_length A 1-D tensor of shape (N,). Its dtype is int64_t
*/
std::vector<Ort::Value> Forward(Ort::Value features,
Ort::Value features_length) override;
/** Return the vocabulary size of the model
*/
int32_t VocabSize() const override;
/** SubsamplingFactor of the model
*/
int32_t SubsamplingFactor() const override;
/** Return an allocator for allocating memory
*/
OrtAllocator *Allocator() const override;
// TeleSpeech CTC models do not support batch size > 1
bool SupportBatchProcessing() const override { return false; }
// Features are normalized per feature dimension before being fed to
// the model.
std::string FeatureNormalizationMethod() const override {
return "per_feature";
}
private:
// Pimpl: hides the onnxruntime details from this header.
class Impl;
std::unique_ptr<Impl> impl_;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_OFFLINE_TELESPEECH_CTC_MODEL_H_

View File

@@ -66,7 +66,7 @@ bool OnlineModelConfig::Validate() const {
if (!modeling_unit.empty() &&
(modeling_unit == "bpe" || modeling_unit == "cjkchar+bpe")) {
if (!FileExists(bpe_vocab)) {
SHERPA_ONNX_LOGE("bpe_vocab: %s does not exist", bpe_vocab.c_str());
SHERPA_ONNX_LOGE("bpe_vocab: '%s' does not exist", bpe_vocab.c_str());
return false;
}
}