Support zipformer CTC ASR with whisper features. (#2319)

This commit is contained in:
Fangjun Kuang
2025-06-27 00:15:11 +08:00
committed by GitHub
parent 282211c01f
commit 54bf3732d9
8 changed files with 184 additions and 37 deletions

View File

@@ -15,6 +15,7 @@
#include "sherpa-onnx/csrc/file-utils.h"
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/offline-whisper-model.h"
#include "sherpa-onnx/csrc/online-ctc-decoder.h"
#include "sherpa-onnx/csrc/online-ctc-fst-decoder.h"
#include "sherpa-onnx/csrc/online-ctc-greedy-search-decoder.h"
@@ -91,6 +92,10 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl {
config_.feat_config.normalize_samples = false;
}
if (model_->UseWhisperFeature()) {
config_.feat_config.is_whisper = true;
}
InitDecoder();
}
@@ -108,6 +113,10 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl {
config_.feat_config.normalize_samples = false;
}
if (model_->UseWhisperFeature()) {
config_.feat_config.is_whisper = true;
}
InitDecoder();
}
@@ -147,6 +156,10 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl {
const auto num_processed_frames = ss[i]->GetNumProcessedFrames();
std::vector<float> features =
ss[i]->GetFrames(num_processed_frames, chunk_length);
if (config_.feat_config.is_whisper) {
OfflineWhisperModel::NormalizeFeatures(features.data(), chunk_length,
feat_dim);
}
// Question: should num_processed_frames include chunk_shift?
ss[i]->GetNumProcessedFrames() += chunk_shift;
@@ -287,6 +300,12 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl {
const auto num_processed_frames = s->GetNumProcessedFrames();
std::vector<float> frames =
s->GetFrames(num_processed_frames, chunk_length);
if (config_.feat_config.is_whisper) {
OfflineWhisperModel::NormalizeFeatures(frames.data(), chunk_length,
feat_dim);
}
s->GetNumProcessedFrames() += chunk_shift;
auto memory_info =