Add C++ runtime for Tele-AI/TeleSpeech-ASR (#970)

This commit is contained in:
Fangjun Kuang
2024-06-05 00:26:40 +08:00
committed by GitHub
parent f8dbc10146
commit fd5a0d1e00
52 changed files with 1052 additions and 145 deletions

View File

@@ -366,6 +366,9 @@ SherpaOnnxOfflineRecognizer *CreateOfflineRecognizer(
recognizer_config.model_config.bpe_vocab =
SHERPA_ONNX_OR(config->model_config.bpe_vocab, "");
recognizer_config.model_config.telespeech_ctc =
SHERPA_ONNX_OR(config->model_config.telespeech_ctc, "");
recognizer_config.lm_config.model =
SHERPA_ONNX_OR(config->lm_config.model, "");
recognizer_config.lm_config.scale =

View File

@@ -395,6 +395,7 @@ SHERPA_ONNX_API typedef struct SherpaOnnxOfflineModelConfig {
// - cjkchar+bpe
const char *modeling_unit;
const char *bpe_vocab;
const char *telespeech_ctc;
} SherpaOnnxOfflineModelConfig;
SHERPA_ONNX_API typedef struct SherpaOnnxOfflineRecognizerConfig {

View File

@@ -39,6 +39,7 @@ set(sources
offline-stream.cc
offline-tdnn-ctc-model.cc
offline-tdnn-model-config.cc
offline-telespeech-ctc-model.cc
offline-transducer-greedy-search-decoder.cc
offline-transducer-greedy-search-nemo-decoder.cc
offline-transducer-model-config.cc

View File

@@ -56,22 +56,11 @@ std::string FeatureExtractorConfig::ToString() const {
class FeatureExtractor::Impl {
public:
explicit Impl(const FeatureExtractorConfig &config) : config_(config) {
opts_.frame_opts.dither = config.dither;
opts_.frame_opts.snip_edges = config.snip_edges;
opts_.frame_opts.samp_freq = config.sampling_rate;
opts_.frame_opts.frame_shift_ms = config.frame_shift_ms;
opts_.frame_opts.frame_length_ms = config.frame_length_ms;
opts_.frame_opts.remove_dc_offset = config.remove_dc_offset;
opts_.frame_opts.window_type = config.window_type;
opts_.mel_opts.num_bins = config.feature_dim;
opts_.mel_opts.high_freq = config.high_freq;
opts_.mel_opts.low_freq = config.low_freq;
opts_.mel_opts.is_librosa = config.is_librosa;
fbank_ = std::make_unique<knf::OnlineFbank>(opts_);
if (config_.is_mfcc) {
InitMfcc();
} else {
InitFbank();
}
}
void AcceptWaveform(int32_t sampling_rate, const float *waveform, int32_t n) {
@@ -101,35 +90,48 @@ class FeatureExtractor::Impl {
std::vector<float> samples;
resampler_->Resample(waveform, n, false, &samples);
fbank_->AcceptWaveform(opts_.frame_opts.samp_freq, samples.data(),
samples.size());
if (fbank_) {
fbank_->AcceptWaveform(config_.sampling_rate, samples.data(),
samples.size());
} else {
mfcc_->AcceptWaveform(config_.sampling_rate, samples.data(),
samples.size());
}
return;
}
if (sampling_rate != opts_.frame_opts.samp_freq) {
if (sampling_rate != config_.sampling_rate) {
SHERPA_ONNX_LOGE(
"Creating a resampler:\n"
" in_sample_rate: %d\n"
" output_sample_rate: %d\n",
sampling_rate, static_cast<int32_t>(opts_.frame_opts.samp_freq));
sampling_rate, static_cast<int32_t>(config_.sampling_rate));
float min_freq =
std::min<int32_t>(sampling_rate, opts_.frame_opts.samp_freq);
float min_freq = std::min<int32_t>(sampling_rate, config_.sampling_rate);
float lowpass_cutoff = 0.99 * 0.5 * min_freq;
int32_t lowpass_filter_width = 6;
resampler_ = std::make_unique<LinearResample>(
sampling_rate, opts_.frame_opts.samp_freq, lowpass_cutoff,
sampling_rate, config_.sampling_rate, lowpass_cutoff,
lowpass_filter_width);
std::vector<float> samples;
resampler_->Resample(waveform, n, false, &samples);
fbank_->AcceptWaveform(opts_.frame_opts.samp_freq, samples.data(),
samples.size());
if (fbank_) {
fbank_->AcceptWaveform(config_.sampling_rate, samples.data(),
samples.size());
} else {
mfcc_->AcceptWaveform(config_.sampling_rate, samples.data(),
samples.size());
}
return;
}
fbank_->AcceptWaveform(sampling_rate, waveform, n);
if (fbank_) {
fbank_->AcceptWaveform(sampling_rate, waveform, n);
} else {
mfcc_->AcceptWaveform(sampling_rate, waveform, n);
}
}
void InputFinished() const {
@@ -179,11 +181,56 @@ class FeatureExtractor::Impl {
return features;
}
int32_t FeatureDim() const { return opts_.mel_opts.num_bins; }
int32_t FeatureDim() const {
return mfcc_ ? mfcc_opts_.num_ceps : opts_.mel_opts.num_bins;
}
private:
void InitFbank() {
opts_.frame_opts.dither = config_.dither;
opts_.frame_opts.snip_edges = config_.snip_edges;
opts_.frame_opts.samp_freq = config_.sampling_rate;
opts_.frame_opts.frame_shift_ms = config_.frame_shift_ms;
opts_.frame_opts.frame_length_ms = config_.frame_length_ms;
opts_.frame_opts.remove_dc_offset = config_.remove_dc_offset;
opts_.frame_opts.window_type = config_.window_type;
opts_.mel_opts.num_bins = config_.feature_dim;
opts_.mel_opts.high_freq = config_.high_freq;
opts_.mel_opts.low_freq = config_.low_freq;
opts_.mel_opts.is_librosa = config_.is_librosa;
fbank_ = std::make_unique<knf::OnlineFbank>(opts_);
}
void InitMfcc() {
mfcc_opts_.frame_opts.dither = config_.dither;
mfcc_opts_.frame_opts.snip_edges = config_.snip_edges;
mfcc_opts_.frame_opts.samp_freq = config_.sampling_rate;
mfcc_opts_.frame_opts.frame_shift_ms = config_.frame_shift_ms;
mfcc_opts_.frame_opts.frame_length_ms = config_.frame_length_ms;
mfcc_opts_.frame_opts.remove_dc_offset = config_.remove_dc_offset;
mfcc_opts_.frame_opts.window_type = config_.window_type;
mfcc_opts_.mel_opts.num_bins = config_.feature_dim;
mfcc_opts_.mel_opts.high_freq = config_.high_freq;
mfcc_opts_.mel_opts.low_freq = config_.low_freq;
mfcc_opts_.mel_opts.is_librosa = config_.is_librosa;
mfcc_opts_.num_ceps = config_.num_ceps;
mfcc_opts_.use_energy = config_.use_energy;
mfcc_ = std::make_unique<knf::OnlineMfcc>(mfcc_opts_);
}
private:
std::unique_ptr<knf::OnlineFbank> fbank_;
std::unique_ptr<knf::OnlineMfcc> mfcc_;
knf::FbankOptions opts_;
knf::MfccOptions mfcc_opts_;
FeatureExtractorConfig config_;
mutable std::mutex mutex_;
std::unique_ptr<LinearResample> resampler_;

View File

@@ -18,7 +18,10 @@ struct FeatureExtractorConfig {
// the sampling rate of the input waveform, we will do resampling inside.
int32_t sampling_rate = 16000;
// Feature dimension
// num_mel_bins
//
// Note: for MFCC, this value is also used as num_mel_bins.
// The actual feature dimension is actually num_ceps.
int32_t feature_dim = 80;
// minimal frequency for Mel-filterbank, in Hz
@@ -69,6 +72,12 @@ struct FeatureExtractorConfig {
// for details
std::string nemo_normalize_type;
// for MFCC
int32_t num_ceps = 13;
bool use_energy = true;
bool is_mfcc = false;
std::string ToString() const;
void Register(ParseOptions *po);

View File

@@ -12,6 +12,7 @@
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model.h"
#include "sherpa-onnx/csrc/offline-tdnn-ctc-model.h"
#include "sherpa-onnx/csrc/offline-telespeech-ctc-model.h"
#include "sherpa-onnx/csrc/offline-wenet-ctc-model.h"
#include "sherpa-onnx/csrc/offline-zipformer-ctc-model.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
@@ -24,6 +25,7 @@ enum class ModelType {
kTdnn,
kZipformerCtc,
kWenetCtc,
kTeleSpeechCtc,
kUnknown,
};
@@ -63,6 +65,9 @@ static ModelType GetModelType(char *model_data, size_t model_data_length,
"If you are using models from WeNet, please refer to\n"
"https://github.com/k2-fsa/sherpa-onnx/blob/master/scripts/wenet/"
"run.sh\n"
"If you are using models from TeleSpeech, please refer to\n"
"https://github.com/k2-fsa/sherpa-onnx/blob/master/scripts/tele-speech/"
"add-metadata.py"
"\n"
"for how to add metadta to model.onnx\n");
return ModelType::kUnknown;
@@ -78,6 +83,8 @@ static ModelType GetModelType(char *model_data, size_t model_data_length,
return ModelType::kZipformerCtc;
} else if (model_type.get() == std::string("wenet_ctc")) {
return ModelType::kWenetCtc;
} else if (model_type.get() == std::string("telespeech_ctc")) {
return ModelType::kTeleSpeechCtc;
} else {
SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.get());
return ModelType::kUnknown;
@@ -97,6 +104,8 @@ std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create(
filename = config.zipformer_ctc.model;
} else if (!config.wenet_ctc.model.empty()) {
filename = config.wenet_ctc.model;
} else if (!config.telespeech_ctc.empty()) {
filename = config.telespeech_ctc;
} else {
SHERPA_ONNX_LOGE("Please specify a CTC model");
exit(-1);
@@ -124,6 +133,9 @@ std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create(
case ModelType::kWenetCtc:
return std::make_unique<OfflineWenetCtcModel>(config);
break;
case ModelType::kTeleSpeechCtc:
return std::make_unique<OfflineTeleSpeechCtcModel>(config);
break;
case ModelType::kUnknown:
SHERPA_ONNX_LOGE("Unknown model type in offline CTC!");
return nullptr;
@@ -147,6 +159,8 @@ std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create(
filename = config.zipformer_ctc.model;
} else if (!config.wenet_ctc.model.empty()) {
filename = config.wenet_ctc.model;
} else if (!config.telespeech_ctc.empty()) {
filename = config.telespeech_ctc;
} else {
SHERPA_ONNX_LOGE("Please specify a CTC model");
exit(-1);
@@ -175,6 +189,9 @@ std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create(
case ModelType::kWenetCtc:
return std::make_unique<OfflineWenetCtcModel>(mgr, config);
break;
case ModelType::kTeleSpeechCtc:
return std::make_unique<OfflineTeleSpeechCtcModel>(mgr, config);
break;
case ModelType::kUnknown:
SHERPA_ONNX_LOGE("Unknown model type in offline CTC!");
return nullptr;

View File

@@ -19,6 +19,9 @@ void OfflineModelConfig::Register(ParseOptions *po) {
zipformer_ctc.Register(po);
wenet_ctc.Register(po);
po->Register("telespeech-ctc", &telespeech_ctc,
"Path to model.onnx for telespeech ctc");
po->Register("tokens", &tokens, "Path to tokens.txt");
po->Register("num-threads", &num_threads,
@@ -33,7 +36,7 @@ void OfflineModelConfig::Register(ParseOptions *po) {
po->Register("model-type", &model_type,
"Specify it to reduce model initialization time. "
"Valid values are: transducer, paraformer, nemo_ctc, whisper, "
"tdnn, zipformer2_ctc"
"tdnn, zipformer2_ctc, telespeech_ctc."
"All other values lead to loading the model twice.");
po->Register("modeling-unit", &modeling_unit,
"The modeling unit of the model, commonly used units are bpe, "
@@ -55,14 +58,14 @@ bool OfflineModelConfig::Validate() const {
}
if (!FileExists(tokens)) {
SHERPA_ONNX_LOGE("tokens: %s does not exist", tokens.c_str());
SHERPA_ONNX_LOGE("tokens: '%s' does not exist", tokens.c_str());
return false;
}
if (!modeling_unit.empty() &&
(modeling_unit == "bpe" || modeling_unit == "cjkchar+bpe")) {
if (!FileExists(bpe_vocab)) {
SHERPA_ONNX_LOGE("bpe_vocab: %s does not exist", bpe_vocab.c_str());
SHERPA_ONNX_LOGE("bpe_vocab: '%s' does not exist", bpe_vocab.c_str());
return false;
}
}
@@ -91,6 +94,14 @@ bool OfflineModelConfig::Validate() const {
return wenet_ctc.Validate();
}
if (!telespeech_ctc.empty() && !FileExists(telespeech_ctc)) {
SHERPA_ONNX_LOGE("telespeech_ctc: '%s' does not exist",
telespeech_ctc.c_str());
return false;
} else {
return true;
}
return transducer.Validate();
}
@@ -105,6 +116,7 @@ std::string OfflineModelConfig::ToString() const {
os << "tdnn=" << tdnn.ToString() << ", ";
os << "zipformer_ctc=" << zipformer_ctc.ToString() << ", ";
os << "wenet_ctc=" << wenet_ctc.ToString() << ", ";
os << "telespeech_ctc=\"" << telespeech_ctc << "\", ";
os << "tokens=\"" << tokens << "\", ";
os << "num_threads=" << num_threads << ", ";
os << "debug=" << (debug ? "True" : "False") << ", ";

View File

@@ -24,6 +24,7 @@ struct OfflineModelConfig {
OfflineTdnnModelConfig tdnn;
OfflineZipformerCtcModelConfig zipformer_ctc;
OfflineWenetCtcModelConfig wenet_ctc;
std::string telespeech_ctc;
std::string tokens;
int32_t num_threads = 2;
@@ -52,6 +53,7 @@ struct OfflineModelConfig {
const OfflineTdnnModelConfig &tdnn,
const OfflineZipformerCtcModelConfig &zipformer_ctc,
const OfflineWenetCtcModelConfig &wenet_ctc,
const std::string &telespeech_ctc,
const std::string &tokens, int32_t num_threads, bool debug,
const std::string &provider, const std::string &model_type,
const std::string &modeling_unit,
@@ -63,6 +65,7 @@ struct OfflineModelConfig {
tdnn(tdnn),
zipformer_ctc(zipformer_ctc),
wenet_ctc(wenet_ctc),
telespeech_ctc(telespeech_ctc),
tokens(tokens),
num_threads(num_threads),
debug(debug),

View File

@@ -88,6 +88,17 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl {
#endif
void Init() {
if (!config_.model_config.telespeech_ctc.empty()) {
config_.feat_config.snip_edges = true;
config_.feat_config.num_ceps = 40;
config_.feat_config.feature_dim = 40;
config_.feat_config.low_freq = 40;
config_.feat_config.high_freq = -200;
config_.feat_config.use_energy = false;
config_.feat_config.normalize_samples = false;
config_.feat_config.is_mfcc = true;
}
if (!config_.model_config.wenet_ctc.model.empty()) {
// WeNet CTC models assume input samples are in the range
// [-32768, 32767], so we set normalize_samples to false

View File

@@ -29,7 +29,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
} else if (model_type == "paraformer") {
return std::make_unique<OfflineRecognizerParaformerImpl>(config);
} else if (model_type == "nemo_ctc" || model_type == "tdnn" ||
model_type == "zipformer2_ctc" || model_type == "wenet_ctc") {
model_type == "zipformer2_ctc" || model_type == "wenet_ctc" ||
model_type == "telespeech_ctc") {
return std::make_unique<OfflineRecognizerCtcImpl>(config);
} else if (model_type == "whisper") {
return std::make_unique<OfflineRecognizerWhisperImpl>(config);
@@ -53,6 +54,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
model_filename = config.model_config.paraformer.model;
} else if (!config.model_config.nemo_ctc.model.empty()) {
model_filename = config.model_config.nemo_ctc.model;
} else if (!config.model_config.telespeech_ctc.empty()) {
model_filename = config.model_config.telespeech_ctc;
} else if (!config.model_config.tdnn.model.empty()) {
model_filename = config.model_config.tdnn.model;
} else if (!config.model_config.zipformer_ctc.model.empty()) {
@@ -111,6 +114,10 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
"\n "
"https://github.com/k2-fsa/sherpa-onnx/blob/master/scripts/wenet/run.sh"
"\n"
"(7) CTC models from TeleSpeech"
"\n "
"https://github.com/Tele-AI/TeleSpeech-ASR"
"\n"
"\n");
exit(-1);
}
@@ -133,7 +140,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
if (model_type == "EncDecCTCModelBPE" ||
model_type == "EncDecHybridRNNTCTCBPEModel" || model_type == "tdnn" ||
model_type == "zipformer2_ctc" || model_type == "wenet_ctc") {
model_type == "zipformer2_ctc" || model_type == "wenet_ctc" ||
model_type == "telespeech_ctc") {
return std::make_unique<OfflineRecognizerCtcImpl>(config);
}
@@ -151,7 +159,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
" - Whisper models\n"
" - Tdnn models\n"
" - Zipformer CTC models\n"
" - WeNet CTC models\n",
" - WeNet CTC models\n"
" - TeleSpeech CTC models\n",
model_type.c_str());
exit(-1);
@@ -169,7 +178,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
} else if (model_type == "paraformer") {
return std::make_unique<OfflineRecognizerParaformerImpl>(mgr, config);
} else if (model_type == "nemo_ctc" || model_type == "tdnn" ||
model_type == "zipformer2_ctc" || model_type == "wenet_ctc") {
model_type == "zipformer2_ctc" || model_type == "wenet_ctc" ||
model_type == "telespeech_ctc") {
return std::make_unique<OfflineRecognizerCtcImpl>(mgr, config);
} else if (model_type == "whisper") {
return std::make_unique<OfflineRecognizerWhisperImpl>(mgr, config);
@@ -199,6 +209,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
model_filename = config.model_config.zipformer_ctc.model;
} else if (!config.model_config.wenet_ctc.model.empty()) {
model_filename = config.model_config.wenet_ctc.model;
} else if (!config.model_config.telespeech_ctc.empty()) {
model_filename = config.model_config.telespeech_ctc;
} else if (!config.model_config.whisper.encoder.empty()) {
model_filename = config.model_config.whisper.encoder;
} else {
@@ -251,6 +263,10 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
"\n "
"https://github.com/k2-fsa/sherpa-onnx/blob/master/scripts/wenet/run.sh"
"\n"
"(7) CTC models from TeleSpeech"
"\n "
"https://github.com/Tele-AI/TeleSpeech-ASR"
"\n"
"\n");
exit(-1);
}
@@ -273,7 +289,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
if (model_type == "EncDecCTCModelBPE" ||
model_type == "EncDecHybridRNNTCTCBPEModel" || model_type == "tdnn" ||
model_type == "zipformer2_ctc" || model_type == "wenet_ctc") {
model_type == "zipformer2_ctc" || model_type == "wenet_ctc" ||
model_type == "telespeech_ctc") {
return std::make_unique<OfflineRecognizerCtcImpl>(mgr, config);
}
@@ -291,7 +308,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
" - Whisper models\n"
" - Tdnn models\n"
" - Zipformer CTC models\n"
" - WeNet CTC models\n",
" - WeNet CTC models\n"
" - TeleSpeech CTC models\n",
model_type.c_str());
exit(-1);

View File

@@ -57,22 +57,44 @@ class OfflineStream::Impl {
explicit Impl(const FeatureExtractorConfig &config,
ContextGraphPtr context_graph)
: config_(config), context_graph_(context_graph) {
opts_.frame_opts.dither = config.dither;
opts_.frame_opts.snip_edges = config.snip_edges;
opts_.frame_opts.samp_freq = config.sampling_rate;
opts_.frame_opts.frame_shift_ms = config.frame_shift_ms;
opts_.frame_opts.frame_length_ms = config.frame_length_ms;
opts_.frame_opts.remove_dc_offset = config.remove_dc_offset;
opts_.frame_opts.window_type = config.window_type;
if (config.is_mfcc) {
mfcc_opts_.frame_opts.dither = config_.dither;
mfcc_opts_.frame_opts.snip_edges = config_.snip_edges;
mfcc_opts_.frame_opts.samp_freq = config_.sampling_rate;
mfcc_opts_.frame_opts.frame_shift_ms = config_.frame_shift_ms;
mfcc_opts_.frame_opts.frame_length_ms = config_.frame_length_ms;
mfcc_opts_.frame_opts.remove_dc_offset = config_.remove_dc_offset;
mfcc_opts_.frame_opts.window_type = config_.window_type;
opts_.mel_opts.num_bins = config.feature_dim;
mfcc_opts_.mel_opts.num_bins = config_.feature_dim;
opts_.mel_opts.high_freq = config.high_freq;
opts_.mel_opts.low_freq = config.low_freq;
mfcc_opts_.mel_opts.high_freq = config_.high_freq;
mfcc_opts_.mel_opts.low_freq = config_.low_freq;
opts_.mel_opts.is_librosa = config.is_librosa;
mfcc_opts_.mel_opts.is_librosa = config_.is_librosa;
fbank_ = std::make_unique<knf::OnlineFbank>(opts_);
mfcc_opts_.num_ceps = config_.num_ceps;
mfcc_opts_.use_energy = config_.use_energy;
mfcc_ = std::make_unique<knf::OnlineMfcc>(mfcc_opts_);
} else {
opts_.frame_opts.dither = config.dither;
opts_.frame_opts.snip_edges = config.snip_edges;
opts_.frame_opts.samp_freq = config.sampling_rate;
opts_.frame_opts.frame_shift_ms = config.frame_shift_ms;
opts_.frame_opts.frame_length_ms = config.frame_length_ms;
opts_.frame_opts.remove_dc_offset = config.remove_dc_offset;
opts_.frame_opts.window_type = config.window_type;
opts_.mel_opts.num_bins = config.feature_dim;
opts_.mel_opts.high_freq = config.high_freq;
opts_.mel_opts.low_freq = config.low_freq;
opts_.mel_opts.is_librosa = config.is_librosa;
fbank_ = std::make_unique<knf::OnlineFbank>(opts_);
}
}
explicit Impl(WhisperTag /*tag*/) {
@@ -81,6 +103,7 @@ class OfflineStream::Impl {
opts_.mel_opts.num_bins = 80; // not used
whisper_fbank_ =
std::make_unique<knf::OnlineWhisperFbank>(opts_.frame_opts);
config_.sampling_rate = opts_.frame_opts.samp_freq;
}
explicit Impl(CEDTag /*tag*/) {
@@ -98,6 +121,8 @@ class OfflineStream::Impl {
opts_.mel_opts.num_bins = 64;
opts_.mel_opts.high_freq = 8000;
config_.sampling_rate = opts_.frame_opts.samp_freq;
fbank_ = std::make_unique<knf::OnlineFbank>(opts_);
}
@@ -115,52 +140,60 @@ class OfflineStream::Impl {
void AcceptWaveformImpl(int32_t sampling_rate, const float *waveform,
int32_t n) {
if (sampling_rate != opts_.frame_opts.samp_freq) {
if (sampling_rate != config_.sampling_rate) {
SHERPA_ONNX_LOGE(
"Creating a resampler:\n"
" in_sample_rate: %d\n"
" output_sample_rate: %d\n",
sampling_rate, static_cast<int32_t>(opts_.frame_opts.samp_freq));
sampling_rate, static_cast<int32_t>(config_.sampling_rate));
float min_freq =
std::min<int32_t>(sampling_rate, opts_.frame_opts.samp_freq);
float min_freq = std::min<int32_t>(sampling_rate, config_.sampling_rate);
float lowpass_cutoff = 0.99 * 0.5 * min_freq;
int32_t lowpass_filter_width = 6;
auto resampler = std::make_unique<LinearResample>(
sampling_rate, opts_.frame_opts.samp_freq, lowpass_cutoff,
sampling_rate, config_.sampling_rate, lowpass_cutoff,
lowpass_filter_width);
std::vector<float> samples;
resampler->Resample(waveform, n, true, &samples);
if (fbank_) {
fbank_->AcceptWaveform(opts_.frame_opts.samp_freq, samples.data(),
fbank_->AcceptWaveform(config_.sampling_rate, samples.data(),
samples.size());
fbank_->InputFinished();
} else if (mfcc_) {
mfcc_->AcceptWaveform(config_.sampling_rate, samples.data(),
samples.size());
mfcc_->InputFinished();
} else {
whisper_fbank_->AcceptWaveform(opts_.frame_opts.samp_freq,
samples.data(), samples.size());
whisper_fbank_->AcceptWaveform(config_.sampling_rate, samples.data(),
samples.size());
whisper_fbank_->InputFinished();
}
return;
} // if (sampling_rate != opts_.frame_opts.samp_freq)
} // if (sampling_rate != config_.sampling_rate)
if (fbank_) {
fbank_->AcceptWaveform(sampling_rate, waveform, n);
fbank_->InputFinished();
} else if (mfcc_) {
mfcc_->AcceptWaveform(sampling_rate, waveform, n);
mfcc_->InputFinished();
} else {
whisper_fbank_->AcceptWaveform(sampling_rate, waveform, n);
whisper_fbank_->InputFinished();
}
}
int32_t FeatureDim() const { return opts_.mel_opts.num_bins; }
int32_t FeatureDim() const {
return mfcc_ ? mfcc_opts_.num_ceps : opts_.mel_opts.num_bins;
}
std::vector<float> GetFrames() const {
int32_t n =
fbank_ ? fbank_->NumFramesReady() : whisper_fbank_->NumFramesReady();
int32_t n = fbank_ ? fbank_->NumFramesReady()
: mfcc_ ? mfcc_->NumFramesReady()
: whisper_fbank_->NumFramesReady();
assert(n > 0 && "Please first call AcceptWaveform()");
int32_t feature_dim = FeatureDim();
@@ -170,8 +203,9 @@ class OfflineStream::Impl {
float *p = features.data();
for (int32_t i = 0; i != n; ++i) {
const float *f =
fbank_ ? fbank_->GetFrame(i) : whisper_fbank_->GetFrame(i);
const float *f = fbank_ ? fbank_->GetFrame(i)
: mfcc_ ? mfcc_->GetFrame(i)
: whisper_fbank_->GetFrame(i);
std::copy(f, f + feature_dim, p);
p += feature_dim;
}
@@ -222,8 +256,10 @@ class OfflineStream::Impl {
private:
FeatureExtractorConfig config_;
std::unique_ptr<knf::OnlineFbank> fbank_;
std::unique_ptr<knf::OnlineMfcc> mfcc_;
std::unique_ptr<knf::OnlineWhisperFbank> whisper_fbank_;
knf::FbankOptions opts_;
knf::MfccOptions mfcc_opts_;
OfflineRecognitionResult r_;
ContextGraphPtr context_graph_;
};

View File

@@ -0,0 +1,144 @@
// sherpa-onnx/csrc/offline-telespeech-ctc-model.cc
//
// Copyright (c) 2023-2024 Xiaomi Corporation
#include "sherpa-onnx/csrc/offline-telespeech-ctc-model.h"
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
#include "sherpa-onnx/csrc/session.h"
#include "sherpa-onnx/csrc/text-utils.h"
#include "sherpa-onnx/csrc/transpose.h"
namespace sherpa_onnx {
// Pimpl implementation for OfflineTeleSpeechCtcModel.
//
// Loads the TeleSpeech CTC onnx model pointed to by
// config.telespeech_ctc and runs inference with onnxruntime.
class OfflineTeleSpeechCtcModel::Impl {
public:
// Load the model from the file path in config.telespeech_ctc.
explicit Impl(const OfflineModelConfig &config)
: config_(config),
env_(ORT_LOGGING_LEVEL_ERROR),
sess_opts_(GetSessionOptions(config)),
allocator_{} {
auto buf = ReadFile(config_.telespeech_ctc);
Init(buf.data(), buf.size());
}
#if __ANDROID_API__ >= 9
// Android variant: the model file is read through the asset manager.
Impl(AAssetManager *mgr, const OfflineModelConfig &config)
: config_(config),
env_(ORT_LOGGING_LEVEL_ERROR),
sess_opts_(GetSessionOptions(config)),
allocator_{} {
auto buf = ReadFile(mgr, config_.telespeech_ctc);
Init(buf.data(), buf.size());
}
#endif
// Run the model on one utterance.
//
// @param features A 3-D tensor. Only batch size 1 is supported.
// @param features_length Unused; this model takes no length input.
// @return {logits of shape (B, T, C),
//          logits_length, a 1-element int64 tensor holding T}.
std::vector<Ort::Value> Forward(Ort::Value features,
Ort::Value /*features_length*/) {
std::vector<int64_t> shape =
features.GetTensorTypeAndShapeInfo().GetShape();
if (static_cast<int32_t>(shape[0]) != 1) {
// NOTE(review): this only logs and then proceeds with the
// unsupported batch size; consider failing fast instead.
SHERPA_ONNX_LOGE("This model supports only batch size 1. Given %d",
static_cast<int32_t>(shape[0]));
}
auto out = sess_->Run({}, input_names_ptr_.data(), &features, 1,
output_names_ptr_.data(), output_names_ptr_.size());
// The model has no length output, so synthesize a 1-element length
// tensor from dim 0 (the time axis) of the model output.
std::vector<int64_t> logits_shape = {1};
Ort::Value logits_length = Ort::Value::CreateTensor<int64_t>(
allocator_, logits_shape.data(), logits_shape.size());
int64_t *dst = logits_length.GetTensorMutableData<int64_t>();
dst[0] = out[0].GetTensorTypeAndShapeInfo().GetShape()[0];
// (T, B, C) -> (B, T, C)
Ort::Value logits = Transpose01(allocator_, &out[0]);
std::vector<Ort::Value> ans;
ans.reserve(2);
ans.push_back(std::move(logits));
ans.push_back(std::move(logits_length));
return ans;
}
int32_t VocabSize() const { return vocab_size_; }
int32_t SubsamplingFactor() const { return subsampling_factor_; }
OrtAllocator *Allocator() const { return allocator_; }
private:
// Create the onnxruntime session, cache input/output names, and read
// the vocabulary size from the model's output shape.
void Init(void *model_data, size_t model_data_length) {
sess_ = std::make_unique<Ort::Session>(env_, model_data, model_data_length,
sess_opts_);
GetInputNames(sess_.get(), &input_names_, &input_names_ptr_);
GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_);
// get meta data
Ort::ModelMetadata meta_data = sess_->GetModelMetadata();
if (config_.debug) {
std::ostringstream os;
PrintModelMetadata(os, meta_data);
SHERPA_ONNX_LOGE("%s\n", os.str().c_str());
}
{
// The vocabulary size is the last dimension of the model output.
auto shape =
sess_->GetOutputTypeInfo(0).GetTensorTypeAndShapeInfo().GetShape();
vocab_size_ = shape[2];
}
}
private:
OfflineModelConfig config_;
Ort::Env env_;
Ort::SessionOptions sess_opts_;
Ort::AllocatorWithDefaultOptions allocator_;
std::unique_ptr<Ort::Session> sess_;
std::vector<std::string> input_names_;
std::vector<const char *> input_names_ptr_;
std::vector<std::string> output_names_;
std::vector<const char *> output_names_ptr_;
int32_t vocab_size_ = 0;
// Hard-coded; presumably matches the model's time-axis downsampling —
// TODO confirm against the exported model's metadata.
int32_t subsampling_factor_ = 4;
};
// Public API of OfflineTeleSpeechCtcModel: every method simply
// delegates to the pimpl (Impl) defined above.
OfflineTeleSpeechCtcModel::OfflineTeleSpeechCtcModel(
const OfflineModelConfig &config)
: impl_(std::make_unique<Impl>(config)) {}
#if __ANDROID_API__ >= 9
// Android variant: forwards the asset manager so the model file can be
// read from the APK assets.
OfflineTeleSpeechCtcModel::OfflineTeleSpeechCtcModel(
AAssetManager *mgr, const OfflineModelConfig &config)
: impl_(std::make_unique<Impl>(mgr, config)) {}
#endif
// Out-of-line so unique_ptr<Impl> can be destroyed with a complete type.
OfflineTeleSpeechCtcModel::~OfflineTeleSpeechCtcModel() = default;
std::vector<Ort::Value> OfflineTeleSpeechCtcModel::Forward(
Ort::Value features, Ort::Value features_length) {
return impl_->Forward(std::move(features), std::move(features_length));
}
int32_t OfflineTeleSpeechCtcModel::VocabSize() const {
return impl_->VocabSize();
}
int32_t OfflineTeleSpeechCtcModel::SubsamplingFactor() const {
return impl_->SubsamplingFactor();
}
OrtAllocator *OfflineTeleSpeechCtcModel::Allocator() const {
return impl_->Allocator();
}
} // namespace sherpa_onnx

View File

@@ -0,0 +1,81 @@
// sherpa-onnx/csrc/offline-telespeech-ctc-model.h
//
// Copyright (c) 2024 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_OFFLINE_TELESPEECH_CTC_MODEL_H_
#define SHERPA_ONNX_CSRC_OFFLINE_TELESPEECH_CTC_MODEL_H_
#include <memory>
#include <string>
#include <utility>
#include <vector>
#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif
#include "onnxruntime_cxx_api.h" // NOLINT
#include "sherpa-onnx/csrc/offline-ctc-model.h"
#include "sherpa-onnx/csrc/offline-model-config.h"
namespace sherpa_onnx {
/** This class implements the CTC model from
* https://github.com/Tele-AI/TeleSpeech-ASR.
*
* See
* https://github.com/lovemefan/telespeech-asr-python/blob/main/telespeechasr/onnx/onnx_infer.py
* and
* https://github.com/k2-fsa/sherpa-onnx/blob/master/scripts/tele-speech/test.py
*/
class OfflineTeleSpeechCtcModel : public OfflineCtcModel {
public:
explicit OfflineTeleSpeechCtcModel(const OfflineModelConfig &config);
#if __ANDROID_API__ >= 9
// Android variant: the model file is read via the asset manager.
OfflineTeleSpeechCtcModel(AAssetManager *mgr,
const OfflineModelConfig &config);
#endif
~OfflineTeleSpeechCtcModel() override;
/** Run the forward method of the model.
*
* @param features A tensor of shape (N, T, C). Note: only N == 1 is
*                 supported; see SupportBatchProcessing().
* @param features_length A 1-D tensor of shape (N,) containing number of
*                        valid frames in `features` before padding.
*                        Its dtype is int64_t.
*
* @return Return a vector containing:
* - log_probs: A 3-D tensor of shape (N, T', vocab_size).
* - log_probs_length: A 1-D tensor of shape (N,). Its dtype is int64_t.
*/
std::vector<Ort::Value> Forward(Ort::Value features,
Ort::Value features_length) override;
/** Return the vocabulary size of the model.
*/
int32_t VocabSize() const override;
/** Return the subsampling factor of the model.
*/
int32_t SubsamplingFactor() const override;
/** Return an allocator for allocating memory.
*/
OrtAllocator *Allocator() const override;
// TeleSpeech CTC models do not support batch size > 1
bool SupportBatchProcessing() const override { return false; }
// Tells the caller to apply "per_feature" normalization to the input
// features before running Forward().
std::string FeatureNormalizationMethod() const override {
return "per_feature";
}
private:
class Impl;  // pimpl: keeps onnxruntime details out of this header
std::unique_ptr<Impl> impl_;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_OFFLINE_TELESPEECH_CTC_MODEL_H_

View File

@@ -66,7 +66,7 @@ bool OnlineModelConfig::Validate() const {
if (!modeling_unit.empty() &&
(modeling_unit == "bpe" || modeling_unit == "cjkchar+bpe")) {
if (!FileExists(bpe_vocab)) {
SHERPA_ONNX_LOGE("bpe_vocab: %s does not exist", bpe_vocab.c_str());
SHERPA_ONNX_LOGE("bpe_vocab: '%s' does not exist", bpe_vocab.c_str());
return false;
}
}

View File

@@ -7,6 +7,7 @@ public class OfflineModelConfig {
private final OfflineParaformerModelConfig paraformer;
private final OfflineWhisperModelConfig whisper;
private final OfflineNemoEncDecCtcModelConfig nemo;
private final String teleSpeech;
private final String tokens;
private final int numThreads;
private final boolean debug;
@@ -21,6 +22,7 @@ public class OfflineModelConfig {
this.paraformer = builder.paraformer;
this.whisper = builder.whisper;
this.nemo = builder.nemo;
this.teleSpeech = builder.teleSpeech;
this.tokens = builder.tokens;
this.numThreads = builder.numThreads;
this.debug = builder.debug;
@@ -74,11 +76,16 @@ public class OfflineModelConfig {
return bpeVocab;
}
public String getTeleSpeech() {
return teleSpeech;
}
public static class Builder {
private OfflineParaformerModelConfig paraformer = OfflineParaformerModelConfig.builder().build();
private OfflineTransducerModelConfig transducer = OfflineTransducerModelConfig.builder().build();
private OfflineWhisperModelConfig whisper = OfflineWhisperModelConfig.builder().build();
private OfflineNemoEncDecCtcModelConfig nemo = OfflineNemoEncDecCtcModelConfig.builder().build();
private String teleSpeech = "";
private String tokens = "";
private int numThreads = 1;
private boolean debug = true;
@@ -106,6 +113,12 @@ public class OfflineModelConfig {
return this;
}
public Builder setTeleSpeech(String teleSpeech) {
this.teleSpeech = teleSpeech;
return this;
}
public Builder setWhisper(OfflineWhisperModelConfig whisper) {
this.whisper = whisper;
return this;

View File

@@ -172,6 +172,12 @@ static OfflineRecognizerConfig GetOfflineConfig(JNIEnv *env, jobject config) {
ans.model_config.nemo_ctc.model = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(model_config_cls, "teleSpeech", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(model_config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model_config.telespeech_ctc = p;
env->ReleaseStringUTFChars(s, p);
return ans;
}

View File

@@ -35,6 +35,7 @@ data class OfflineModelConfig(
var paraformer: OfflineParaformerModelConfig = OfflineParaformerModelConfig(),
var whisper: OfflineWhisperModelConfig = OfflineWhisperModelConfig(),
var nemo: OfflineNemoEncDecCtcModelConfig = OfflineNemoEncDecCtcModelConfig(),
var teleSpeech: String = "",
var numThreads: Int = 1,
var debug: Boolean = false,
var provider: String = "cpu",
@@ -272,6 +273,15 @@ fun getOfflineModelConfig(type: Int): OfflineModelConfig? {
tokens = "$modelDir/tokens.txt",
)
}
11 -> {
val modelDir = "sherpa-onnx-telespeech-ctc-int8-zh-2024-06-04"
return OfflineModelConfig(
teleSpeech = "$modelDir/model.int8.onnx",
tokens = "$modelDir/tokens.txt",
modelType = "tele_speech",
)
}
}
return null
}

View File

@@ -29,25 +29,27 @@ void PybindOfflineModelConfig(py::module *m) {
using PyClass = OfflineModelConfig;
py::class_<PyClass>(*m, "OfflineModelConfig")
.def(py::init<const OfflineTransducerModelConfig &,
const OfflineParaformerModelConfig &,
const OfflineNemoEncDecCtcModelConfig &,
const OfflineWhisperModelConfig &,
const OfflineTdnnModelConfig &,
const OfflineZipformerCtcModelConfig &,
const OfflineWenetCtcModelConfig &, const std::string &,
int32_t, bool, const std::string &, const std::string &,
const std::string &, const std::string &>(),
py::arg("transducer") = OfflineTransducerModelConfig(),
py::arg("paraformer") = OfflineParaformerModelConfig(),
py::arg("nemo_ctc") = OfflineNemoEncDecCtcModelConfig(),
py::arg("whisper") = OfflineWhisperModelConfig(),
py::arg("tdnn") = OfflineTdnnModelConfig(),
py::arg("zipformer_ctc") = OfflineZipformerCtcModelConfig(),
py::arg("wenet_ctc") = OfflineWenetCtcModelConfig(),
py::arg("tokens"), py::arg("num_threads"), py::arg("debug") = false,
py::arg("provider") = "cpu", py::arg("model_type") = "",
py::arg("modeling_unit") = "cjkchar", py::arg("bpe_vocab") = "")
.def(
py::init<
const OfflineTransducerModelConfig &,
const OfflineParaformerModelConfig &,
const OfflineNemoEncDecCtcModelConfig &,
const OfflineWhisperModelConfig &, const OfflineTdnnModelConfig &,
const OfflineZipformerCtcModelConfig &,
const OfflineWenetCtcModelConfig &, const std::string &,
const std::string &, int32_t, bool, const std::string &,
const std::string &, const std::string &, const std::string &>(),
py::arg("transducer") = OfflineTransducerModelConfig(),
py::arg("paraformer") = OfflineParaformerModelConfig(),
py::arg("nemo_ctc") = OfflineNemoEncDecCtcModelConfig(),
py::arg("whisper") = OfflineWhisperModelConfig(),
py::arg("tdnn") = OfflineTdnnModelConfig(),
py::arg("zipformer_ctc") = OfflineZipformerCtcModelConfig(),
py::arg("wenet_ctc") = OfflineWenetCtcModelConfig(),
py::arg("telespeech_ctc") = "", py::arg("tokens"),
py::arg("num_threads"), py::arg("debug") = false,
py::arg("provider") = "cpu", py::arg("model_type") = "",
py::arg("modeling_unit") = "cjkchar", py::arg("bpe_vocab") = "")
.def_readwrite("transducer", &PyClass::transducer)
.def_readwrite("paraformer", &PyClass::paraformer)
.def_readwrite("nemo_ctc", &PyClass::nemo_ctc)
@@ -55,6 +57,7 @@ void PybindOfflineModelConfig(py::module *m) {
.def_readwrite("tdnn", &PyClass::tdnn)
.def_readwrite("zipformer_ctc", &PyClass::zipformer_ctc)
.def_readwrite("wenet_ctc", &PyClass::wenet_ctc)
.def_readwrite("telespeech_ctc", &PyClass::telespeech_ctc)
.def_readwrite("tokens", &PyClass::tokens)
.def_readwrite("num_threads", &PyClass::num_threads)
.def_readwrite("debug", &PyClass::debug)

View File

@@ -211,6 +211,71 @@ class OfflineRecognizer(object):
self.config = recognizer_config
return self
@classmethod
def from_telespeech_ctc(
cls,
model: str,
tokens: str,
num_threads: int = 1,
sample_rate: int = 16000,
feature_dim: int = 40,
decoding_method: str = "greedy_search",
debug: bool = False,
provider: str = "cpu",
):
"""
Please refer to
`<https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models>`_
to download pre-trained models.
Args:
model:
Path to ``model.onnx``.
tokens:
Path to ``tokens.txt``. Each line in ``tokens.txt`` contains two
columns::
symbol integer_id
num_threads:
Number of threads for neural network computation.
sample_rate:
Sample rate of the training data used to train the model. It is
ignored and is hard-coded in C++ to 40.
feature_dim:
Dimension of the feature used to train the model. It is ignored
and is hard-coded in C++ to 40.
decoding_method:
Valid values are greedy_search.
debug:
True to show debug messages.
provider:
onnxruntime execution providers. Valid values are: cpu, cuda, coreml.
"""
self = cls.__new__(cls)
model_config = OfflineModelConfig(
telespeech_ctc=model,
tokens=tokens,
num_threads=num_threads,
debug=debug,
provider=provider,
model_type="nemo_ctc",
)
feat_config = FeatureExtractorConfig(
sampling_rate=sample_rate,
feature_dim=feature_dim,
)
recognizer_config = OfflineRecognizerConfig(
feat_config=feat_config,
model_config=model_config,
decoding_method=decoding_method,
)
self.recognizer = _Recognizer(recognizer_config)
self.config = recognizer_config
return self
@classmethod
def from_nemo_ctc(
cls,