Support whisper models (#238)
This commit is contained in:
@@ -86,6 +86,15 @@ class OfflineStream::Impl {
|
||||
fbank_ = std::make_unique<knf::OnlineFbank>(opts_);
|
||||
}
|
||||
|
||||
Impl(WhisperTag /*tag*/, ContextGraphPtr context_graph)
|
||||
: context_graph_(context_graph) {
|
||||
config_.normalize_samples = true;
|
||||
opts_.frame_opts.samp_freq = 16000;
|
||||
opts_.mel_opts.num_bins = 80;
|
||||
whisper_fbank_ =
|
||||
std::make_unique<knf::OnlineWhisperFbank>(opts_.frame_opts);
|
||||
}
|
||||
|
||||
void AcceptWaveform(int32_t sampling_rate, const float *waveform, int32_t n) {
|
||||
if (config_.normalize_samples) {
|
||||
AcceptWaveformImpl(sampling_rate, waveform, n);
|
||||
@@ -117,20 +126,35 @@ class OfflineStream::Impl {
|
||||
lowpass_filter_width);
|
||||
std::vector<float> samples;
|
||||
resampler->Resample(waveform, n, true, &samples);
|
||||
fbank_->AcceptWaveform(opts_.frame_opts.samp_freq, samples.data(),
|
||||
samples.size());
|
||||
fbank_->InputFinished();
|
||||
return;
|
||||
}
|
||||
|
||||
fbank_->AcceptWaveform(sampling_rate, waveform, n);
|
||||
fbank_->InputFinished();
|
||||
if (fbank_) {
|
||||
fbank_->AcceptWaveform(opts_.frame_opts.samp_freq, samples.data(),
|
||||
samples.size());
|
||||
fbank_->InputFinished();
|
||||
} else {
|
||||
whisper_fbank_->AcceptWaveform(opts_.frame_opts.samp_freq,
|
||||
samples.data(), samples.size());
|
||||
whisper_fbank_->InputFinished();
|
||||
}
|
||||
|
||||
return;
|
||||
} // if (sampling_rate != opts_.frame_opts.samp_freq)
|
||||
|
||||
if (fbank_) {
|
||||
fbank_->AcceptWaveform(sampling_rate, waveform, n);
|
||||
fbank_->InputFinished();
|
||||
} else {
|
||||
whisper_fbank_->AcceptWaveform(sampling_rate, waveform, n);
|
||||
whisper_fbank_->InputFinished();
|
||||
}
|
||||
}
|
||||
|
||||
int32_t FeatureDim() const { return opts_.mel_opts.num_bins; }
|
||||
|
||||
std::vector<float> GetFrames() const {
|
||||
int32_t n = fbank_->NumFramesReady();
|
||||
int32_t n =
|
||||
fbank_ ? fbank_->NumFramesReady() : whisper_fbank_->NumFramesReady();
|
||||
|
||||
assert(n > 0 && "Please first call AcceptWaveform()");
|
||||
|
||||
int32_t feature_dim = FeatureDim();
|
||||
@@ -140,7 +164,8 @@ class OfflineStream::Impl {
|
||||
float *p = features.data();
|
||||
|
||||
for (int32_t i = 0; i != n; ++i) {
|
||||
const float *f = fbank_->GetFrame(i);
|
||||
const float *f =
|
||||
fbank_ ? fbank_->GetFrame(i) : whisper_fbank_->GetFrame(i);
|
||||
std::copy(f, f + feature_dim, p);
|
||||
p += feature_dim;
|
||||
}
|
||||
@@ -191,6 +216,7 @@ class OfflineStream::Impl {
|
||||
private:
|
||||
OfflineFeatureExtractorConfig config_;
|
||||
std::unique_ptr<knf::OnlineFbank> fbank_;
|
||||
std::unique_ptr<knf::OnlineWhisperFbank> whisper_fbank_;
|
||||
knf::FbankOptions opts_;
|
||||
OfflineRecognitionResult r_;
|
||||
ContextGraphPtr context_graph_;
|
||||
@@ -201,6 +227,10 @@ OfflineStream::OfflineStream(
|
||||
ContextGraphPtr context_graph /*= nullptr*/)
|
||||
: impl_(std::make_unique<Impl>(config, context_graph)) {}
|
||||
|
||||
OfflineStream::OfflineStream(WhisperTag tag,
|
||||
ContextGraphPtr context_graph /*= nullptr*/)
|
||||
: impl_(std::make_unique<Impl>(tag, context_graph)) {}
|
||||
|
||||
OfflineStream::~OfflineStream() = default;
|
||||
|
||||
void OfflineStream::AcceptWaveform(int32_t sampling_rate, const float *waveform,
|
||||
|
||||
Reference in New Issue
Block a user