Support whisper models (#238)

This commit is contained in:
Fangjun Kuang
2023-08-07 12:34:18 +08:00
committed by GitHub
parent 64efbd82af
commit 45b9d4ab37
39 changed files with 1836 additions and 52 deletions

View File

@@ -86,6 +86,15 @@ class OfflineStream::Impl {
fbank_ = std::make_unique<knf::OnlineFbank>(opts_);
}
// Constructs the stream state used for Whisper models.
//
// @param tag            Unused; only selects this overload.
// @param context_graph  Stored as-is in context_graph_.
//
// Only whisper_fbank_ is created here; fbank_ is left null, which is how the
// other methods of this class decide which extractor to use.
Impl(WhisperTag /*tag*/, ContextGraphPtr context_graph)
    : context_graph_(context_graph) {
  // Front-end settings fixed by this constructor: 16 kHz audio, 80 mel bins.
  auto &frame_opts = opts_.frame_opts;
  frame_opts.samp_freq = 16000;
  opts_.mel_opts.num_bins = 80;

  // Makes AcceptWaveform() take its `normalize_samples` branch.
  config_.normalize_samples = true;

  // The whisper extractor is configured from the frame options only.
  whisper_fbank_ = std::make_unique<knf::OnlineWhisperFbank>(frame_opts);
}
void AcceptWaveform(int32_t sampling_rate, const float *waveform, int32_t n) {
if (config_.normalize_samples) {
AcceptWaveformImpl(sampling_rate, waveform, n);
@@ -117,20 +126,35 @@ class OfflineStream::Impl {
lowpass_filter_width);
std::vector<float> samples;
resampler->Resample(waveform, n, true, &samples);
fbank_->AcceptWaveform(opts_.frame_opts.samp_freq, samples.data(),
samples.size());
fbank_->InputFinished();
return;
}
fbank_->AcceptWaveform(sampling_rate, waveform, n);
fbank_->InputFinished();
if (fbank_) {
fbank_->AcceptWaveform(opts_.frame_opts.samp_freq, samples.data(),
samples.size());
fbank_->InputFinished();
} else {
whisper_fbank_->AcceptWaveform(opts_.frame_opts.samp_freq,
samples.data(), samples.size());
whisper_fbank_->InputFinished();
}
return;
} // if (sampling_rate != opts_.frame_opts.samp_freq)
if (fbank_) {
fbank_->AcceptWaveform(sampling_rate, waveform, n);
fbank_->InputFinished();
} else {
whisper_fbank_->AcceptWaveform(sampling_rate, waveform, n);
whisper_fbank_->InputFinished();
}
}
// Returns the per-frame feature dimension, i.e. the configured number of
// mel bins in opts_.
int32_t FeatureDim() const {
  const int32_t num_mel_bins = opts_.mel_opts.num_bins;
  return num_mel_bins;
}
std::vector<float> GetFrames() const {
int32_t n = fbank_->NumFramesReady();
int32_t n =
fbank_ ? fbank_->NumFramesReady() : whisper_fbank_->NumFramesReady();
assert(n > 0 && "Please first call AcceptWaveform()");
int32_t feature_dim = FeatureDim();
@@ -140,7 +164,8 @@ class OfflineStream::Impl {
float *p = features.data();
for (int32_t i = 0; i != n; ++i) {
const float *f = fbank_->GetFrame(i);
const float *f =
fbank_ ? fbank_->GetFrame(i) : whisper_fbank_->GetFrame(i);
std::copy(f, f + feature_dim, p);
p += feature_dim;
}
@@ -191,6 +216,7 @@ class OfflineStream::Impl {
private:
OfflineFeatureExtractorConfig config_;
// Feature extractors. The WhisperTag constructor creates only
// whisper_fbank_; the code in this class tests fbank_ for null and falls
// back to whisper_fbank_ when it is not set, so exactly one of the two is
// expected to be non-null after construction.
std::unique_ptr<knf::OnlineFbank> fbank_;
std::unique_ptr<knf::OnlineWhisperFbank> whisper_fbank_;
// Feature options. frame_opts.samp_freq is the rate passed to the extractor
// after resampling, and mel_opts.num_bins is what FeatureDim() returns.
knf::FbankOptions opts_;
OfflineRecognitionResult r_;
// Set from the constructor argument of the same name.
ContextGraphPtr context_graph_;
@@ -201,6 +227,10 @@ OfflineStream::OfflineStream(
ContextGraphPtr context_graph /*= nullptr*/)
: impl_(std::make_unique<Impl>(config, context_graph)) {}
// Whisper overload: forwards `tag` and `context_graph` to the matching Impl
// constructor. The trailing comment on `context_graph` records the default
// argument, which is presumably declared in the header (not visible here).
OfflineStream::OfflineStream(WhisperTag tag,
ContextGraphPtr context_graph /*= nullptr*/)
: impl_(std::make_unique<Impl>(tag, context_graph)) {}
// Defaulted out-of-line. NOTE(review): presumably defined in this file so
// that Impl is a complete type when impl_ is destroyed -- confirm against the
// header.
OfflineStream::~OfflineStream() = default;
void OfflineStream::AcceptWaveform(int32_t sampling_rate, const float *waveform,