Support CED models (#792)

This commit is contained in:
Fangjun Kuang
2024-04-19 15:20:37 +08:00
committed by GitHub
parent d97a283dbb
commit c1608b3524
33 changed files with 605 additions and 46 deletions

View File

@@ -92,15 +92,32 @@ class OfflineStream::Impl {
fbank_ = std::make_unique<knf::OnlineFbank>(opts_);
}
Impl(WhisperTag /*tag*/, ContextGraphPtr context_graph)
: context_graph_(context_graph) {
explicit Impl(WhisperTag /*tag*/) {
config_.normalize_samples = true;
opts_.frame_opts.samp_freq = 16000;
opts_.mel_opts.num_bins = 80;
opts_.mel_opts.num_bins = 80; // not used
whisper_fbank_ =
std::make_unique<knf::OnlineWhisperFbank>(opts_.frame_opts);
}
explicit Impl(CEDTag /*tag*/) {
  // Feature extraction settings for CED models; see
  // https://github.com/RicherMans/CED/blob/main/onnx_inference_with_kaldi.py
  auto &frame = opts_.frame_opts;
  auto &mel = opts_.mel_opts;

  // Framing / windowing.
  frame.samp_freq = 16000;  // fixed to 16000
  frame.frame_length_ms = 32;
  frame.window_type = "hann";
  frame.snip_edges = false;

  // Per-frame preprocessing is switched off.
  frame.dither = 0;
  frame.preemph_coeff = 0;
  frame.remove_dc_offset = false;

  // Mel filter bank.
  mel.num_bins = 64;
  mel.high_freq = 8000;

  // Build the extractor only after every option above has been set.
  fbank_ = std::make_unique<knf::OnlineFbank>(opts_);
}
void AcceptWaveform(int32_t sampling_rate, const float *waveform, int32_t n) {
if (config_.normalize_samples) {
AcceptWaveformImpl(sampling_rate, waveform, n);
@@ -233,9 +250,10 @@ OfflineStream::OfflineStream(
ContextGraphPtr context_graph /*= nullptr*/)
: impl_(std::make_unique<Impl>(config, context_graph)) {}
OfflineStream::OfflineStream(WhisperTag tag,
ContextGraphPtr context_graph /*= {}*/)
: impl_(std::make_unique<Impl>(tag, context_graph)) {}
// Creates the stream's Impl by forwarding the Whisper tag to the
// Impl constructor that accepts a WhisperTag.
OfflineStream::OfflineStream(WhisperTag tag)
: impl_(std::make_unique<Impl>(tag)) {}
// Creates the stream's Impl by forwarding the CED tag to the
// Impl constructor that accepts a CEDTag.
OfflineStream::OfflineStream(CEDTag tag) : impl_(std::make_unique<Impl>(tag)) {}
// Defaulted destructor, defined out of line in this file.
OfflineStream::~OfflineStream() = default;