diff --git a/sherpa-onnx/csrc/features.cc b/sherpa-onnx/csrc/features.cc index 95c6f82a..82645849 100644 --- a/sherpa-onnx/csrc/features.cc +++ b/sherpa-onnx/csrc/features.cc @@ -60,6 +60,8 @@ class FeatureExtractor::Impl { explicit Impl(const FeatureExtractorConfig &config) : config_(config) { if (config_.is_mfcc) { InitMfcc(); + } else if (config_.is_whisper) { + InitWhisper(); } else { InitFbank(); } @@ -92,13 +94,9 @@ class FeatureExtractor::Impl { std::vector samples; resampler_->Resample(waveform, n, false, &samples); - if (fbank_) { - fbank_->AcceptWaveform(config_.sampling_rate, samples.data(), - samples.size()); - } else { - mfcc_->AcceptWaveform(config_.sampling_rate, samples.data(), - samples.size()); - } + + AcceptWaveformWrapper(config_.sampling_rate, samples.data(), + samples.size()); return; } @@ -119,61 +117,81 @@ class FeatureExtractor::Impl { std::vector samples; resampler_->Resample(waveform, n, false, &samples); - if (fbank_) { - fbank_->AcceptWaveform(config_.sampling_rate, samples.data(), - samples.size()); - } else { - mfcc_->AcceptWaveform(config_.sampling_rate, samples.data(), - samples.size()); - } + + AcceptWaveformWrapper(config_.sampling_rate, samples.data(), + samples.size()); + return; } - if (fbank_) { - fbank_->AcceptWaveform(sampling_rate, waveform, n); - } else { - mfcc_->AcceptWaveform(sampling_rate, waveform, n); - } + AcceptWaveformWrapper(sampling_rate, waveform, n); } void InputFinished() const { std::lock_guard lock(mutex_); - fbank_->InputFinished(); + if (fbank_) { + fbank_->InputFinished(); + } else if (whisper_fbank_) { + whisper_fbank_->InputFinished(); + } else if (mfcc_) { + mfcc_->InputFinished(); + } + + SHERPA_ONNX_LOGE("unreachable code"); + SHERPA_ONNX_EXIT(-1); } int32_t NumFramesReady() const { - std::lock_guard lock(mutex_); - return fbank_->NumFramesReady(); + if (fbank_) { + return fbank_->NumFramesReady(); + } else if (whisper_fbank_) { + return whisper_fbank_->NumFramesReady(); + } else if (mfcc_) { + return mfcc_->NumFramesReady(); + } + SHERPA_ONNX_LOGE("unreachable code"); + SHERPA_ONNX_EXIT(-1); + return -1; } bool IsLastFrame(int32_t frame) const { std::lock_guard lock(mutex_); - return fbank_->IsLastFrame(frame); + if (fbank_) { + return fbank_->IsLastFrame(frame); + } else if (whisper_fbank_) { + return whisper_fbank_->IsLastFrame(frame); + } else if (mfcc_) { + return mfcc_->IsLastFrame(frame); + } + + SHERPA_ONNX_LOGE("unreachable code"); + SHERPA_ONNX_EXIT(-1); + return false; } std::vector GetFrames(int32_t frame_index, int32_t n) { std::lock_guard lock(mutex_); - if (frame_index + n > fbank_->NumFramesReady()) { - SHERPA_ONNX_LOGE("%d + %d > %d\n", frame_index, n, - fbank_->NumFramesReady()); - exit(-1); + if (frame_index + n > NumFramesReady()) { + SHERPA_ONNX_LOGE("%d + %d > %d\n", frame_index, n, NumFramesReady()); + SHERPA_ONNX_EXIT(-1); } int32_t discard_num = frame_index - last_frame_index_; if (discard_num < 0) { SHERPA_ONNX_LOGE("last_frame_index_: %d, frame_index_: %d", last_frame_index_, frame_index); - exit(-1); + SHERPA_ONNX_EXIT(-1); } - fbank_->Pop(discard_num); - int32_t feature_dim = fbank_->Dim(); + PopWrapper(discard_num); + + int32_t feature_dim = FeatureDim(); std::vector features(feature_dim * n); float *p = features.data(); for (int32_t i = 0; i != n; ++i) { - const float *f = fbank_->GetFrame(i + frame_index); + const float *f = GetFrameWrapper(i + frame_index); std::copy(f, f + feature_dim, p); p += feature_dim; } @@ -184,10 +202,65 @@ class FeatureExtractor::Impl { } int32_t FeatureDim() const { - return mfcc_ ? mfcc_opts_.num_ceps : opts_.mel_opts.num_bins; + if (fbank_ || whisper_fbank_) { + return opts_.mel_opts.num_bins; + } else if (mfcc_) { + return mfcc_opts_.num_ceps; + } + + SHERPA_ONNX_LOGE("unreachable code"); + SHERPA_ONNX_EXIT(-1); + return -1; } private: + void AcceptWaveformWrapper(float sampling_rate, const float *waveform, + int32_t n) const { + if (fbank_) { + fbank_->AcceptWaveform(sampling_rate, waveform, n); + return; + } else if (whisper_fbank_) { + whisper_fbank_->AcceptWaveform(sampling_rate, waveform, n); + return; + } else if (mfcc_) { + mfcc_->AcceptWaveform(sampling_rate, waveform, n); + return; + } + + SHERPA_ONNX_LOGE("unreachable code"); + SHERPA_ONNX_EXIT(-1); + } + + const float *GetFrameWrapper(int32_t frame_index) const { + if (fbank_) { + return fbank_->GetFrame(frame_index); + } else if (whisper_fbank_) { + return whisper_fbank_->GetFrame(frame_index); + } else if (mfcc_) { + return mfcc_->GetFrame(frame_index); + } + + SHERPA_ONNX_LOGE("unreachable code"); + SHERPA_ONNX_EXIT(-1); + return nullptr; + } + + void PopWrapper(int32_t discard_num) const { + if (fbank_) { + fbank_->Pop(discard_num); + return; + } else if (whisper_fbank_) { + whisper_fbank_->Pop(discard_num); + return; + } else if (mfcc_) { + mfcc_->Pop(discard_num); + return; + } + + SHERPA_ONNX_LOGE("unreachable code"); + SHERPA_ONNX_EXIT(-1); + } + void InitFbank() { opts_.frame_opts.dither = config_.dither; opts_.frame_opts.snip_edges = config_.snip_edges; @@ -208,6 +281,7 @@ class FeatureExtractor::Impl { fbank_ = std::make_unique(opts_); } + void InitMfcc() { mfcc_opts_.frame_opts.dither = config_.dither; mfcc_opts_.frame_opts.snip_edges = config_.snip_edges; @@ -232,9 +306,23 @@ class FeatureExtractor::Impl { mfcc_ = std::make_unique(mfcc_opts_); } + void InitWhisper() { + config_.normalize_samples = true; + opts_.frame_opts.samp_freq = 16000; + opts_.mel_opts.num_bins = config_.feature_dim; + + knf::WhisperFeatureOptions whisper_opts; + whisper_opts.frame_opts = opts_.frame_opts; + whisper_opts.dim = config_.feature_dim; + + whisper_fbank_ = std::make_unique(whisper_opts); + config_.sampling_rate = opts_.frame_opts.samp_freq; + } + private: std::unique_ptr fbank_; std::unique_ptr mfcc_; + std::unique_ptr whisper_fbank_; knf::FbankOptions opts_; knf::MfccOptions mfcc_opts_; FeatureExtractorConfig config_; diff --git a/sherpa-onnx/csrc/features.h b/sherpa-onnx/csrc/features.h index d10b486b..b8e99320 100644 --- a/sherpa-onnx/csrc/features.h +++ b/sherpa-onnx/csrc/features.h @@ -79,6 +79,8 @@ struct FeatureExtractorConfig { bool is_mfcc = false; + bool is_whisper = false; + bool round_to_power_of_two = true; std::string ToString() const; diff --git a/sherpa-onnx/csrc/online-ctc-model.h b/sherpa-onnx/csrc/online-ctc-model.h index bd01bc54..6886a3c6 100644 --- a/sherpa-onnx/csrc/online-ctc-model.h +++ b/sherpa-onnx/csrc/online-ctc-model.h @@ -77,6 +77,8 @@ class OnlineCtcModel { // Return true if the model supports batch size > 1 virtual bool SupportBatchProcessing() const { return true; } + + virtual bool UseWhisperFeature() const { return false; } }; } // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/online-recognizer-ctc-impl.h b/sherpa-onnx/csrc/online-recognizer-ctc-impl.h index a675fd72..42b7fd25 100644 --- a/sherpa-onnx/csrc/online-recognizer-ctc-impl.h +++ b/sherpa-onnx/csrc/online-recognizer-ctc-impl.h @@ -15,6 +15,7 @@ #include "sherpa-onnx/csrc/file-utils.h" #include "sherpa-onnx/csrc/macros.h" +#include "sherpa-onnx/csrc/offline-whisper-model.h" #include "sherpa-onnx/csrc/online-ctc-decoder.h" #include "sherpa-onnx/csrc/online-ctc-fst-decoder.h" #include "sherpa-onnx/csrc/online-ctc-greedy-search-decoder.h" @@ -91,6 +92,10 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl { config_.feat_config.normalize_samples = false; } + if (model_->UseWhisperFeature()) { + config_.feat_config.is_whisper = true; + } + InitDecoder(); } @@ -108,6 +113,10 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl { config_.feat_config.normalize_samples = false; } + if (model_->UseWhisperFeature()) { + config_.feat_config.is_whisper = true; + } + InitDecoder(); } @@ -147,6 +156,10 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl { const auto num_processed_frames = ss[i]->GetNumProcessedFrames(); std::vector features = ss[i]->GetFrames(num_processed_frames, chunk_length); + if (config_.feat_config.is_whisper) { + OfflineWhisperModel::NormalizeFeatures(features.data(), chunk_length, + feat_dim); + } // Question: should num_processed_frames include chunk_shift? ss[i]->GetNumProcessedFrames() += chunk_shift; @@ -287,6 +300,12 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl { const auto num_processed_frames = s->GetNumProcessedFrames(); std::vector frames = s->GetFrames(num_processed_frames, chunk_length); + + if (config_.feat_config.is_whisper) { + OfflineWhisperModel::NormalizeFeatures(frames.data(), chunk_length, + feat_dim); + } + s->GetNumProcessedFrames() += chunk_shift; auto memory_info = diff --git a/sherpa-onnx/csrc/online-stream.cc b/sherpa-onnx/csrc/online-stream.cc index 972133c7..f7abaa66 100644 --- a/sherpa-onnx/csrc/online-stream.cc +++ b/sherpa-onnx/csrc/online-stream.cc @@ -19,34 +19,51 @@ class OnlineStream::Impl { : feat_extractor_(config), context_graph_(std::move(context_graph)) {} void AcceptWaveform(int32_t sampling_rate, const float *waveform, int32_t n) { + std::lock_guard lock(mutex_); feat_extractor_.AcceptWaveform(sampling_rate, waveform, n); } - void InputFinished() const { feat_extractor_.InputFinished(); } + void InputFinished() const { + std::lock_guard lock(mutex_); + feat_extractor_.InputFinished(); + } int32_t NumFramesReady() const { + std::lock_guard lock(mutex_); return feat_extractor_.NumFramesReady() - start_frame_index_; } bool IsLastFrame(int32_t frame) const { + std::lock_guard lock(mutex_); return feat_extractor_.IsLastFrame(frame); } std::vector GetFrames(int32_t frame_index, int32_t n) const { + std::lock_guard lock(mutex_); return feat_extractor_.GetFrames(frame_index + start_frame_index_, n); } void Reset() { + std::lock_guard lock(mutex_); // we don't reset the feature extractor start_frame_index_ += num_processed_frames_; num_processed_frames_ = 0; } - int32_t &GetNumProcessedFrames() { return num_processed_frames_; } + int32_t &GetNumProcessedFrames() { + std::lock_guard lock(mutex_); + return num_processed_frames_; + } - int32_t GetNumFramesSinceStart() const { return start_frame_index_; } + int32_t GetNumFramesSinceStart() const { + std::lock_guard lock(mutex_); + return start_frame_index_; + } - int32_t &GetCurrentSegment() { return segment_; } + int32_t &GetCurrentSegment() { + std::lock_guard lock(mutex_); + return segment_; + } void SetResult(const OnlineTransducerDecoderResult &r) { result_ = r; } @@ -125,6 +142,7 @@ class OnlineStream::Impl { private: FeatureExtractor feat_extractor_; + mutable std::mutex mutex_; /// For contextual-biasing ContextGraphPtr context_graph_; int32_t num_processed_frames_ = 0; // before subsampling diff --git a/sherpa-onnx/csrc/online-zipformer2-ctc-model.cc b/sherpa-onnx/csrc/online-zipformer2-ctc-model.cc index 1cda9a62..f7cccc43 100644 --- a/sherpa-onnx/csrc/online-zipformer2-ctc-model.cc +++ b/sherpa-onnx/csrc/online-zipformer2-ctc-model.cc @@ -74,6 +74,8 @@ class OnlineZipformer2CtcModel::Impl { int32_t ChunkShift() const { return decode_chunk_len_; } + bool UseWhisperFeature() const { return use_whisper_feature_; } + OrtAllocator *Allocator() { return allocator_; } // Return a vector containing 3 tensors @@ -278,6 +280,12 @@ class OnlineZipformer2CtcModel::Impl { SHERPA_ONNX_READ_META_DATA(T_, "T"); SHERPA_ONNX_READ_META_DATA(decode_chunk_len_, "decode_chunk_len"); + std::string feature_type; + SHERPA_ONNX_READ_META_DATA_STR_WITH_DEFAULT(feature_type, "feature", ""); + if (feature_type == "whisper") { + use_whisper_feature_ = true; + } + { auto shape = sess_->GetOutputTypeInfo(0).GetTensorTypeAndShapeInfo().GetShape(); @@ -417,6 +425,10 @@ class OnlineZipformer2CtcModel::Impl { int32_t T_ = 0; int32_t decode_chunk_len_ = 0; int32_t vocab_size_ = 0; + + // for models from + // https://github.com/k2-fsa/icefall/blob/master/egs/multi_zh-hans/ASR/RESULTS.md#streaming-with-ctc-head + bool use_whisper_feature_ = false; }; OnlineZipformer2CtcModel::OnlineZipformer2CtcModel( @@ -447,6 +459,10 @@ int32_t OnlineZipformer2CtcModel::ChunkShift() const { return impl_->ChunkShift(); } +bool OnlineZipformer2CtcModel::UseWhisperFeature() const { + return impl_->UseWhisperFeature(); +} + OrtAllocator *OnlineZipformer2CtcModel::Allocator() const { return impl_->Allocator(); } diff --git a/sherpa-onnx/csrc/online-zipformer2-ctc-model.h b/sherpa-onnx/csrc/online-zipformer2-ctc-model.h index 32ddf212..3cbd4cc7 100644 --- a/sherpa-onnx/csrc/online-zipformer2-ctc-model.h +++ b/sherpa-onnx/csrc/online-zipformer2-ctc-model.h @@ -64,6 +64,8 @@ class OnlineZipformer2CtcModel : public OnlineCtcModel { // before we process the next chunk. int32_t ChunkShift() const override; + bool UseWhisperFeature() const override; + private: class Impl; std::unique_ptr impl_; diff --git a/sherpa-onnx/csrc/sherpa-onnx-microphone.cc b/sherpa-onnx/csrc/sherpa-onnx-microphone.cc index a3c1294a..22f22124 100644 --- a/sherpa-onnx/csrc/sherpa-onnx-microphone.cc +++ b/sherpa-onnx/csrc/sherpa-onnx-microphone.cc @@ -130,7 +130,7 @@ for a list of pre-trained models to download. } if (!mic.OpenDevice(device_index, mic_sample_rate, 1, RecordCallback, - nullptr /* user_data */)) { + s.get())) { fprintf(stderr, "portaudio error: %d\n", device_index); exit(EXIT_FAILURE); }