Support zipformer CTC ASR with whisper features. (#2319)

This commit is contained in:
Fangjun Kuang
2025-06-27 00:15:11 +08:00
committed by GitHub
parent 282211c01f
commit 54bf3732d9
8 changed files with 184 additions and 37 deletions

View File

@@ -60,6 +60,8 @@ class FeatureExtractor::Impl {
explicit Impl(const FeatureExtractorConfig &config) : config_(config) {
if (config_.is_mfcc) {
InitMfcc();
} else if (config_.is_whisper) {
InitWhisper();
} else {
InitFbank();
}
@@ -92,13 +94,9 @@ class FeatureExtractor::Impl {
std::vector<float> samples;
resampler_->Resample(waveform, n, false, &samples);
if (fbank_) {
fbank_->AcceptWaveform(config_.sampling_rate, samples.data(),
samples.size());
} else {
mfcc_->AcceptWaveform(config_.sampling_rate, samples.data(),
samples.size());
}
AcceptWaveformWrapper(config_.sampling_rate, samples.data(),
samples.size());
return;
}
@@ -119,61 +117,81 @@ class FeatureExtractor::Impl {
std::vector<float> samples;
resampler_->Resample(waveform, n, false, &samples);
if (fbank_) {
fbank_->AcceptWaveform(config_.sampling_rate, samples.data(),
samples.size());
} else {
mfcc_->AcceptWaveform(config_.sampling_rate, samples.data(),
samples.size());
}
AcceptWaveformWrapper(config_.sampling_rate, samples.data(),
samples.size());
return;
}
if (fbank_) {
fbank_->AcceptWaveform(sampling_rate, waveform, n);
} else {
mfcc_->AcceptWaveform(sampling_rate, waveform, n);
}
AcceptWaveformWrapper(sampling_rate, waveform, n);
}
void InputFinished() const {
std::lock_guard<std::mutex> lock(mutex_);
fbank_->InputFinished();
if (fbank_) {
fbank_->InputFinished();
} else if (whisper_fbank_) {
whisper_fbank_->InputFinished();
} else if (mfcc_) {
mfcc_->InputFinished();
}
SHERPA_ONNX_LOGE("unreachable code");
SHERPA_ONNX_EXIT(-1);
}
int32_t NumFramesReady() const {
std::lock_guard<std::mutex> lock(mutex_);
return fbank_->NumFramesReady();
if (fbank_) {
return fbank_->NumFramesReady();
} else if (whisper_fbank_) {
return whisper_fbank_->NumFramesReady();
} else if (mfcc_) {
return mfcc_->NumFramesReady();
}
SHERPA_ONNX_LOGE("unreachable code");
SHERPA_ONNX_EXIT(-1);
return -1;
}
bool IsLastFrame(int32_t frame) const {
std::lock_guard<std::mutex> lock(mutex_);
return fbank_->IsLastFrame(frame);
if (fbank_) {
return fbank_->IsLastFrame(frame);
} else if (whisper_fbank_) {
return whisper_fbank_->IsLastFrame(frame);
} else if (mfcc_) {
return mfcc_->IsLastFrame(frame);
}
SHERPA_ONNX_LOGE("unreachable code");
SHERPA_ONNX_EXIT(-1);
return false;
}
std::vector<float> GetFrames(int32_t frame_index, int32_t n) {
std::lock_guard<std::mutex> lock(mutex_);
if (frame_index + n > fbank_->NumFramesReady()) {
SHERPA_ONNX_LOGE("%d + %d > %d\n", frame_index, n,
fbank_->NumFramesReady());
exit(-1);
if (frame_index + n > NumFramesReady()) {
SHERPA_ONNX_LOGE("%d + %d > %d\n", frame_index, n, NumFramesReady());
SHERPA_ONNX_EXIT(-1);
}
int32_t discard_num = frame_index - last_frame_index_;
if (discard_num < 0) {
SHERPA_ONNX_LOGE("last_frame_index_: %d, frame_index_: %d",
last_frame_index_, frame_index);
exit(-1);
SHERPA_ONNX_EXIT(-1);
}
fbank_->Pop(discard_num);
int32_t feature_dim = fbank_->Dim();
PopWrapper(discard_num);
int32_t feature_dim = FeatureDim();
std::vector<float> features(feature_dim * n);
float *p = features.data();
for (int32_t i = 0; i != n; ++i) {
const float *f = fbank_->GetFrame(i + frame_index);
const float *f = GetFrameWrapper(i + frame_index);
std::copy(f, f + feature_dim, p);
p += feature_dim;
}
@@ -184,10 +202,65 @@ class FeatureExtractor::Impl {
}
int32_t FeatureDim() const {
return mfcc_ ? mfcc_opts_.num_ceps : opts_.mel_opts.num_bins;
if (fbank_ || whisper_fbank_) {
return opts_.mel_opts.num_bins;
} else if (mfcc_) {
return mfcc_opts_.num_ceps;
}
SHERPA_ONNX_LOGE("unreachable code");
SHERPA_ONNX_EXIT(-1);
return -1;
}
private:
void AcceptWaveformWrapper(float sampling_rate, const float *waveform,
int32_t n) const {
if (fbank_) {
fbank_->AcceptWaveform(sampling_rate, waveform, n);
return;
} else if (whisper_fbank_) {
whisper_fbank_->AcceptWaveform(sampling_rate, waveform, n);
return;
} else if (mfcc_) {
mfcc_->AcceptWaveform(sampling_rate, waveform, n);
return;
}
SHERPA_ONNX_LOGE("unreachable code");
SHERPA_ONNX_EXIT(-1);
}
const float *GetFrameWrapper(int32_t frame_index) const {
if (fbank_) {
return fbank_->GetFrame(frame_index);
} else if (whisper_fbank_) {
return whisper_fbank_->GetFrame(frame_index);
} else if (mfcc_) {
return mfcc_->GetFrame(frame_index);
}
SHERPA_ONNX_LOGE("unreachable code");
SHERPA_ONNX_EXIT(-1);
return nullptr;
}
void PopWrapper(int32_t discard_num) const {
if (fbank_) {
fbank_->Pop(discard_num);
return;
} else if (whisper_fbank_) {
whisper_fbank_->Pop(discard_num);
return;
} else if (mfcc_) {
mfcc_->Pop(discard_num);
return;
}
SHERPA_ONNX_LOGE("unreachable code");
SHERPA_ONNX_EXIT(-1);
}
void InitFbank() {
opts_.frame_opts.dither = config_.dither;
opts_.frame_opts.snip_edges = config_.snip_edges;
@@ -208,6 +281,7 @@ class FeatureExtractor::Impl {
fbank_ = std::make_unique<knf::OnlineFbank>(opts_);
}
void InitMfcc() {
mfcc_opts_.frame_opts.dither = config_.dither;
mfcc_opts_.frame_opts.snip_edges = config_.snip_edges;
@@ -232,9 +306,23 @@ class FeatureExtractor::Impl {
mfcc_ = std::make_unique<knf::OnlineMfcc>(mfcc_opts_);
}
void InitWhisper() {
config_.normalize_samples = true;
opts_.frame_opts.samp_freq = 16000;
opts_.mel_opts.num_bins = config_.feature_dim;
knf::WhisperFeatureOptions whisper_opts;
whisper_opts.frame_opts = opts_.frame_opts;
whisper_opts.dim = config_.feature_dim;
whisper_fbank_ = std::make_unique<knf::OnlineWhisperFbank>(whisper_opts);
config_.sampling_rate = opts_.frame_opts.samp_freq;
}
private:
std::unique_ptr<knf::OnlineFbank> fbank_;
std::unique_ptr<knf::OnlineMfcc> mfcc_;
std::unique_ptr<knf::OnlineWhisperFbank> whisper_fbank_;
knf::FbankOptions opts_;
knf::MfccOptions mfcc_opts_;
FeatureExtractorConfig config_;