Support zipformer CTC ASR with whisper features. (#2319)
This commit is contained in:
@@ -60,6 +60,8 @@ class FeatureExtractor::Impl {
|
||||
explicit Impl(const FeatureExtractorConfig &config) : config_(config) {
|
||||
if (config_.is_mfcc) {
|
||||
InitMfcc();
|
||||
} else if (config_.is_whisper) {
|
||||
InitWhisper();
|
||||
} else {
|
||||
InitFbank();
|
||||
}
|
||||
@@ -92,13 +94,9 @@ class FeatureExtractor::Impl {
|
||||
|
||||
std::vector<float> samples;
|
||||
resampler_->Resample(waveform, n, false, &samples);
|
||||
if (fbank_) {
|
||||
fbank_->AcceptWaveform(config_.sampling_rate, samples.data(),
|
||||
samples.size());
|
||||
} else {
|
||||
mfcc_->AcceptWaveform(config_.sampling_rate, samples.data(),
|
||||
samples.size());
|
||||
}
|
||||
|
||||
AcceptWaveformWrapper(config_.sampling_rate, samples.data(),
|
||||
samples.size());
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -119,61 +117,81 @@ class FeatureExtractor::Impl {
|
||||
|
||||
std::vector<float> samples;
|
||||
resampler_->Resample(waveform, n, false, &samples);
|
||||
if (fbank_) {
|
||||
fbank_->AcceptWaveform(config_.sampling_rate, samples.data(),
|
||||
samples.size());
|
||||
} else {
|
||||
mfcc_->AcceptWaveform(config_.sampling_rate, samples.data(),
|
||||
samples.size());
|
||||
}
|
||||
|
||||
AcceptWaveformWrapper(config_.sampling_rate, samples.data(),
|
||||
samples.size());
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
if (fbank_) {
|
||||
fbank_->AcceptWaveform(sampling_rate, waveform, n);
|
||||
} else {
|
||||
mfcc_->AcceptWaveform(sampling_rate, waveform, n);
|
||||
}
|
||||
AcceptWaveformWrapper(sampling_rate, waveform, n);
|
||||
}
|
||||
|
||||
void InputFinished() const {
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
fbank_->InputFinished();
|
||||
if (fbank_) {
|
||||
fbank_->InputFinished();
|
||||
} else if (whisper_fbank_) {
|
||||
whisper_fbank_->InputFinished();
|
||||
} else if (mfcc_) {
|
||||
mfcc_->InputFinished();
|
||||
}
|
||||
|
||||
SHERPA_ONNX_LOGE("unreachable code");
|
||||
SHERPA_ONNX_EXIT(-1);
|
||||
}
|
||||
|
||||
int32_t NumFramesReady() const {
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
return fbank_->NumFramesReady();
|
||||
if (fbank_) {
|
||||
return fbank_->NumFramesReady();
|
||||
} else if (whisper_fbank_) {
|
||||
return whisper_fbank_->NumFramesReady();
|
||||
} else if (mfcc_) {
|
||||
return mfcc_->NumFramesReady();
|
||||
}
|
||||
SHERPA_ONNX_LOGE("unreachable code");
|
||||
SHERPA_ONNX_EXIT(-1);
|
||||
return -1;
|
||||
}
|
||||
|
||||
bool IsLastFrame(int32_t frame) const {
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
return fbank_->IsLastFrame(frame);
|
||||
if (fbank_) {
|
||||
return fbank_->IsLastFrame(frame);
|
||||
} else if (whisper_fbank_) {
|
||||
return whisper_fbank_->IsLastFrame(frame);
|
||||
} else if (mfcc_) {
|
||||
return mfcc_->IsLastFrame(frame);
|
||||
}
|
||||
|
||||
SHERPA_ONNX_LOGE("unreachable code");
|
||||
SHERPA_ONNX_EXIT(-1);
|
||||
return false;
|
||||
}
|
||||
|
||||
std::vector<float> GetFrames(int32_t frame_index, int32_t n) {
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
if (frame_index + n > fbank_->NumFramesReady()) {
|
||||
SHERPA_ONNX_LOGE("%d + %d > %d\n", frame_index, n,
|
||||
fbank_->NumFramesReady());
|
||||
exit(-1);
|
||||
if (frame_index + n > NumFramesReady()) {
|
||||
SHERPA_ONNX_LOGE("%d + %d > %d\n", frame_index, n, NumFramesReady());
|
||||
SHERPA_ONNX_EXIT(-1);
|
||||
}
|
||||
|
||||
int32_t discard_num = frame_index - last_frame_index_;
|
||||
if (discard_num < 0) {
|
||||
SHERPA_ONNX_LOGE("last_frame_index_: %d, frame_index_: %d",
|
||||
last_frame_index_, frame_index);
|
||||
exit(-1);
|
||||
SHERPA_ONNX_EXIT(-1);
|
||||
}
|
||||
fbank_->Pop(discard_num);
|
||||
|
||||
int32_t feature_dim = fbank_->Dim();
|
||||
PopWrapper(discard_num);
|
||||
|
||||
int32_t feature_dim = FeatureDim();
|
||||
std::vector<float> features(feature_dim * n);
|
||||
|
||||
float *p = features.data();
|
||||
|
||||
for (int32_t i = 0; i != n; ++i) {
|
||||
const float *f = fbank_->GetFrame(i + frame_index);
|
||||
const float *f = GetFrameWrapper(i + frame_index);
|
||||
std::copy(f, f + feature_dim, p);
|
||||
p += feature_dim;
|
||||
}
|
||||
@@ -184,10 +202,65 @@ class FeatureExtractor::Impl {
|
||||
}
|
||||
|
||||
int32_t FeatureDim() const {
|
||||
return mfcc_ ? mfcc_opts_.num_ceps : opts_.mel_opts.num_bins;
|
||||
if (fbank_ || whisper_fbank_) {
|
||||
return opts_.mel_opts.num_bins;
|
||||
} else if (mfcc_) {
|
||||
return mfcc_opts_.num_ceps;
|
||||
}
|
||||
|
||||
SHERPA_ONNX_LOGE("unreachable code");
|
||||
SHERPA_ONNX_EXIT(-1);
|
||||
return -1;
|
||||
}
|
||||
|
||||
private:
|
||||
void AcceptWaveformWrapper(float sampling_rate, const float *waveform,
|
||||
int32_t n) const {
|
||||
if (fbank_) {
|
||||
fbank_->AcceptWaveform(sampling_rate, waveform, n);
|
||||
return;
|
||||
} else if (whisper_fbank_) {
|
||||
whisper_fbank_->AcceptWaveform(sampling_rate, waveform, n);
|
||||
return;
|
||||
} else if (mfcc_) {
|
||||
mfcc_->AcceptWaveform(sampling_rate, waveform, n);
|
||||
return;
|
||||
}
|
||||
|
||||
SHERPA_ONNX_LOGE("unreachable code");
|
||||
SHERPA_ONNX_EXIT(-1);
|
||||
}
|
||||
|
||||
const float *GetFrameWrapper(int32_t frame_index) const {
|
||||
if (fbank_) {
|
||||
return fbank_->GetFrame(frame_index);
|
||||
} else if (whisper_fbank_) {
|
||||
return whisper_fbank_->GetFrame(frame_index);
|
||||
} else if (mfcc_) {
|
||||
return mfcc_->GetFrame(frame_index);
|
||||
}
|
||||
|
||||
SHERPA_ONNX_LOGE("unreachable code");
|
||||
SHERPA_ONNX_EXIT(-1);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
void PopWrapper(int32_t discard_num) const {
|
||||
if (fbank_) {
|
||||
fbank_->Pop(discard_num);
|
||||
return;
|
||||
} else if (whisper_fbank_) {
|
||||
whisper_fbank_->Pop(discard_num);
|
||||
return;
|
||||
} else if (mfcc_) {
|
||||
mfcc_->Pop(discard_num);
|
||||
return;
|
||||
}
|
||||
|
||||
SHERPA_ONNX_LOGE("unreachable code");
|
||||
SHERPA_ONNX_EXIT(-1);
|
||||
}
|
||||
|
||||
void InitFbank() {
|
||||
opts_.frame_opts.dither = config_.dither;
|
||||
opts_.frame_opts.snip_edges = config_.snip_edges;
|
||||
@@ -208,6 +281,7 @@ class FeatureExtractor::Impl {
|
||||
|
||||
fbank_ = std::make_unique<knf::OnlineFbank>(opts_);
|
||||
}
|
||||
|
||||
void InitMfcc() {
|
||||
mfcc_opts_.frame_opts.dither = config_.dither;
|
||||
mfcc_opts_.frame_opts.snip_edges = config_.snip_edges;
|
||||
@@ -232,9 +306,23 @@ class FeatureExtractor::Impl {
|
||||
mfcc_ = std::make_unique<knf::OnlineMfcc>(mfcc_opts_);
|
||||
}
|
||||
|
||||
// Creates the whisper-style feature computer (whisper_fbank_).
//
// Overrides part of the user configuration: sample normalization is
// forced on and the sampling rate is pinned to 16000, which is also
// written back to config_.sampling_rate. The feature dimension comes
// from config_.feature_dim and is mirrored into opts_.mel_opts.num_bins.
void InitWhisper() {
  // Adjust the shared options/config first.
  opts_.frame_opts.samp_freq = 16000;
  opts_.mel_opts.num_bins = config_.feature_dim;
  config_.normalize_samples = true;
  config_.sampling_rate = opts_.frame_opts.samp_freq;

  // Build the whisper-specific options from the values set above.
  knf::WhisperFeatureOptions whisper_opts;
  whisper_opts.dim = config_.feature_dim;
  whisper_opts.frame_opts = opts_.frame_opts;

  whisper_fbank_ = std::make_unique<knf::OnlineWhisperFbank>(whisper_opts);
}
|
||||
|
||||
private:
|
||||
std::unique_ptr<knf::OnlineFbank> fbank_;
|
||||
std::unique_ptr<knf::OnlineMfcc> mfcc_;
|
||||
std::unique_ptr<knf::OnlineWhisperFbank> whisper_fbank_;
|
||||
knf::FbankOptions opts_;
|
||||
knf::MfccOptions mfcc_opts_;
|
||||
FeatureExtractorConfig config_;
|
||||
|
||||
Reference in New Issue
Block a user