Support zipformer CTC ASR with whisper features. (#2319)
This commit is contained in:
@@ -60,6 +60,8 @@ class FeatureExtractor::Impl {
|
||||
explicit Impl(const FeatureExtractorConfig &config) : config_(config) {
|
||||
if (config_.is_mfcc) {
|
||||
InitMfcc();
|
||||
} else if (config_.is_whisper) {
|
||||
InitWhisper();
|
||||
} else {
|
||||
InitFbank();
|
||||
}
|
||||
@@ -92,13 +94,9 @@ class FeatureExtractor::Impl {
|
||||
|
||||
std::vector<float> samples;
|
||||
resampler_->Resample(waveform, n, false, &samples);
|
||||
if (fbank_) {
|
||||
fbank_->AcceptWaveform(config_.sampling_rate, samples.data(),
|
||||
samples.size());
|
||||
} else {
|
||||
mfcc_->AcceptWaveform(config_.sampling_rate, samples.data(),
|
||||
samples.size());
|
||||
}
|
||||
|
||||
AcceptWaveformWrapper(config_.sampling_rate, samples.data(),
|
||||
samples.size());
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -119,61 +117,81 @@ class FeatureExtractor::Impl {
|
||||
|
||||
std::vector<float> samples;
|
||||
resampler_->Resample(waveform, n, false, &samples);
|
||||
if (fbank_) {
|
||||
fbank_->AcceptWaveform(config_.sampling_rate, samples.data(),
|
||||
samples.size());
|
||||
} else {
|
||||
mfcc_->AcceptWaveform(config_.sampling_rate, samples.data(),
|
||||
samples.size());
|
||||
}
|
||||
|
||||
AcceptWaveformWrapper(config_.sampling_rate, samples.data(),
|
||||
samples.size());
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
if (fbank_) {
|
||||
fbank_->AcceptWaveform(sampling_rate, waveform, n);
|
||||
} else {
|
||||
mfcc_->AcceptWaveform(sampling_rate, waveform, n);
|
||||
}
|
||||
AcceptWaveformWrapper(sampling_rate, waveform, n);
|
||||
}
|
||||
|
||||
void InputFinished() const {
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
fbank_->InputFinished();
|
||||
if (fbank_) {
|
||||
fbank_->InputFinished();
|
||||
} else if (whisper_fbank_) {
|
||||
whisper_fbank_->InputFinished();
|
||||
} else if (mfcc_) {
|
||||
mfcc_->InputFinished();
|
||||
}
|
||||
|
||||
SHERPA_ONNX_LOGE("unreachable code");
|
||||
SHERPA_ONNX_EXIT(-1);
|
||||
}
|
||||
|
||||
int32_t NumFramesReady() const {
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
return fbank_->NumFramesReady();
|
||||
if (fbank_) {
|
||||
return fbank_->NumFramesReady();
|
||||
} else if (whisper_fbank_) {
|
||||
return whisper_fbank_->NumFramesReady();
|
||||
} else if (mfcc_) {
|
||||
return mfcc_->NumFramesReady();
|
||||
}
|
||||
SHERPA_ONNX_LOGE("unreachable code");
|
||||
SHERPA_ONNX_EXIT(-1);
|
||||
return -1;
|
||||
}
|
||||
|
||||
bool IsLastFrame(int32_t frame) const {
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
return fbank_->IsLastFrame(frame);
|
||||
if (fbank_) {
|
||||
return fbank_->IsLastFrame(frame);
|
||||
} else if (whisper_fbank_) {
|
||||
return whisper_fbank_->IsLastFrame(frame);
|
||||
} else if (mfcc_) {
|
||||
return mfcc_->IsLastFrame(frame);
|
||||
}
|
||||
|
||||
SHERPA_ONNX_LOGE("unreachable code");
|
||||
SHERPA_ONNX_EXIT(-1);
|
||||
return false;
|
||||
}
|
||||
|
||||
std::vector<float> GetFrames(int32_t frame_index, int32_t n) {
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
if (frame_index + n > fbank_->NumFramesReady()) {
|
||||
SHERPA_ONNX_LOGE("%d + %d > %d\n", frame_index, n,
|
||||
fbank_->NumFramesReady());
|
||||
exit(-1);
|
||||
if (frame_index + n > NumFramesReady()) {
|
||||
SHERPA_ONNX_LOGE("%d + %d > %d\n", frame_index, n, NumFramesReady());
|
||||
SHERPA_ONNX_EXIT(-1);
|
||||
}
|
||||
|
||||
int32_t discard_num = frame_index - last_frame_index_;
|
||||
if (discard_num < 0) {
|
||||
SHERPA_ONNX_LOGE("last_frame_index_: %d, frame_index_: %d",
|
||||
last_frame_index_, frame_index);
|
||||
exit(-1);
|
||||
SHERPA_ONNX_EXIT(-1);
|
||||
}
|
||||
fbank_->Pop(discard_num);
|
||||
|
||||
int32_t feature_dim = fbank_->Dim();
|
||||
PopWrapper(discard_num);
|
||||
|
||||
int32_t feature_dim = FeatureDim();
|
||||
std::vector<float> features(feature_dim * n);
|
||||
|
||||
float *p = features.data();
|
||||
|
||||
for (int32_t i = 0; i != n; ++i) {
|
||||
const float *f = fbank_->GetFrame(i + frame_index);
|
||||
const float *f = GetFrameWrapper(i + frame_index);
|
||||
std::copy(f, f + feature_dim, p);
|
||||
p += feature_dim;
|
||||
}
|
||||
@@ -184,10 +202,65 @@ class FeatureExtractor::Impl {
|
||||
}
|
||||
|
||||
int32_t FeatureDim() const {
|
||||
return mfcc_ ? mfcc_opts_.num_ceps : opts_.mel_opts.num_bins;
|
||||
if (fbank_ || whisper_fbank_) {
|
||||
return opts_.mel_opts.num_bins;
|
||||
} else if (mfcc_) {
|
||||
return mfcc_opts_.num_ceps;
|
||||
}
|
||||
|
||||
SHERPA_ONNX_LOGE("unreachable code");
|
||||
SHERPA_ONNX_EXIT(-1);
|
||||
return -1;
|
||||
}
|
||||
|
||||
private:
|
||||
void AcceptWaveformWrapper(float sampling_rate, const float *waveform,
|
||||
int32_t n) const {
|
||||
if (fbank_) {
|
||||
fbank_->AcceptWaveform(sampling_rate, waveform, n);
|
||||
return;
|
||||
} else if (whisper_fbank_) {
|
||||
whisper_fbank_->AcceptWaveform(sampling_rate, waveform, n);
|
||||
return;
|
||||
} else if (mfcc_) {
|
||||
mfcc_->AcceptWaveform(sampling_rate, waveform, n);
|
||||
return;
|
||||
}
|
||||
|
||||
SHERPA_ONNX_LOGE("unreachable code");
|
||||
SHERPA_ONNX_EXIT(-1);
|
||||
}
|
||||
|
||||
const float *GetFrameWrapper(int32_t frame_index) const {
|
||||
if (fbank_) {
|
||||
return fbank_->GetFrame(frame_index);
|
||||
} else if (whisper_fbank_) {
|
||||
return whisper_fbank_->GetFrame(frame_index);
|
||||
} else if (mfcc_) {
|
||||
return mfcc_->GetFrame(frame_index);
|
||||
}
|
||||
|
||||
SHERPA_ONNX_LOGE("unreachable code");
|
||||
SHERPA_ONNX_EXIT(-1);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
void PopWrapper(int32_t discard_num) const {
|
||||
if (fbank_) {
|
||||
fbank_->Pop(discard_num);
|
||||
return;
|
||||
} else if (whisper_fbank_) {
|
||||
whisper_fbank_->Pop(discard_num);
|
||||
return;
|
||||
} else if (mfcc_) {
|
||||
mfcc_->Pop(discard_num);
|
||||
return;
|
||||
}
|
||||
|
||||
SHERPA_ONNX_LOGE("unreachable code");
|
||||
SHERPA_ONNX_EXIT(-1);
|
||||
}
|
||||
|
||||
void InitFbank() {
|
||||
opts_.frame_opts.dither = config_.dither;
|
||||
opts_.frame_opts.snip_edges = config_.snip_edges;
|
||||
@@ -208,6 +281,7 @@ class FeatureExtractor::Impl {
|
||||
|
||||
fbank_ = std::make_unique<knf::OnlineFbank>(opts_);
|
||||
}
|
||||
|
||||
void InitMfcc() {
|
||||
mfcc_opts_.frame_opts.dither = config_.dither;
|
||||
mfcc_opts_.frame_opts.snip_edges = config_.snip_edges;
|
||||
@@ -232,9 +306,23 @@ class FeatureExtractor::Impl {
|
||||
mfcc_ = std::make_unique<knf::OnlineMfcc>(mfcc_opts_);
|
||||
}
|
||||
|
||||
void InitWhisper() {
|
||||
config_.normalize_samples = true;
|
||||
opts_.frame_opts.samp_freq = 16000;
|
||||
opts_.mel_opts.num_bins = config_.feature_dim;
|
||||
|
||||
knf::WhisperFeatureOptions whisper_opts;
|
||||
whisper_opts.frame_opts = opts_.frame_opts;
|
||||
whisper_opts.dim = config_.feature_dim;
|
||||
|
||||
whisper_fbank_ = std::make_unique<knf::OnlineWhisperFbank>(whisper_opts);
|
||||
config_.sampling_rate = opts_.frame_opts.samp_freq;
|
||||
}
|
||||
|
||||
private:
|
||||
std::unique_ptr<knf::OnlineFbank> fbank_;
|
||||
std::unique_ptr<knf::OnlineMfcc> mfcc_;
|
||||
std::unique_ptr<knf::OnlineWhisperFbank> whisper_fbank_;
|
||||
knf::FbankOptions opts_;
|
||||
knf::MfccOptions mfcc_opts_;
|
||||
FeatureExtractorConfig config_;
|
||||
|
||||
@@ -79,6 +79,8 @@ struct FeatureExtractorConfig {
|
||||
|
||||
bool is_mfcc = false;
|
||||
|
||||
bool is_whisper = false;
|
||||
|
||||
bool round_to_power_of_two = true;
|
||||
|
||||
std::string ToString() const;
|
||||
|
||||
@@ -77,6 +77,8 @@ class OnlineCtcModel {
|
||||
|
||||
// Return true if the model supports batch size > 1
|
||||
virtual bool SupportBatchProcessing() const { return true; }
|
||||
|
||||
virtual bool UseWhisperFeature() const { return false; }
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
|
||||
#include "sherpa-onnx/csrc/file-utils.h"
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/csrc/offline-whisper-model.h"
|
||||
#include "sherpa-onnx/csrc/online-ctc-decoder.h"
|
||||
#include "sherpa-onnx/csrc/online-ctc-fst-decoder.h"
|
||||
#include "sherpa-onnx/csrc/online-ctc-greedy-search-decoder.h"
|
||||
@@ -91,6 +92,10 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl {
|
||||
config_.feat_config.normalize_samples = false;
|
||||
}
|
||||
|
||||
if (model_->UseWhisperFeature()) {
|
||||
config_.feat_config.is_whisper = true;
|
||||
}
|
||||
|
||||
InitDecoder();
|
||||
}
|
||||
|
||||
@@ -108,6 +113,10 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl {
|
||||
config_.feat_config.normalize_samples = false;
|
||||
}
|
||||
|
||||
if (model_->UseWhisperFeature()) {
|
||||
config_.feat_config.is_whisper = true;
|
||||
}
|
||||
|
||||
InitDecoder();
|
||||
}
|
||||
|
||||
@@ -147,6 +156,10 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl {
|
||||
const auto num_processed_frames = ss[i]->GetNumProcessedFrames();
|
||||
std::vector<float> features =
|
||||
ss[i]->GetFrames(num_processed_frames, chunk_length);
|
||||
if (config_.feat_config.is_whisper) {
|
||||
OfflineWhisperModel::NormalizeFeatures(features.data(), chunk_length,
|
||||
feat_dim);
|
||||
}
|
||||
|
||||
// Question: should num_processed_frames include chunk_shift?
|
||||
ss[i]->GetNumProcessedFrames() += chunk_shift;
|
||||
@@ -287,6 +300,12 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl {
|
||||
const auto num_processed_frames = s->GetNumProcessedFrames();
|
||||
std::vector<float> frames =
|
||||
s->GetFrames(num_processed_frames, chunk_length);
|
||||
|
||||
if (config_.feat_config.is_whisper) {
|
||||
OfflineWhisperModel::NormalizeFeatures(frames.data(), chunk_length,
|
||||
feat_dim);
|
||||
}
|
||||
|
||||
s->GetNumProcessedFrames() += chunk_shift;
|
||||
|
||||
auto memory_info =
|
||||
|
||||
@@ -19,34 +19,51 @@ class OnlineStream::Impl {
|
||||
: feat_extractor_(config), context_graph_(std::move(context_graph)) {}
|
||||
|
||||
void AcceptWaveform(int32_t sampling_rate, const float *waveform, int32_t n) {
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
feat_extractor_.AcceptWaveform(sampling_rate, waveform, n);
|
||||
}
|
||||
|
||||
void InputFinished() const { feat_extractor_.InputFinished(); }
|
||||
void InputFinished() const {
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
feat_extractor_.InputFinished();
|
||||
}
|
||||
|
||||
int32_t NumFramesReady() const {
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
return feat_extractor_.NumFramesReady() - start_frame_index_;
|
||||
}
|
||||
|
||||
bool IsLastFrame(int32_t frame) const {
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
return feat_extractor_.IsLastFrame(frame);
|
||||
}
|
||||
|
||||
std::vector<float> GetFrames(int32_t frame_index, int32_t n) const {
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
return feat_extractor_.GetFrames(frame_index + start_frame_index_, n);
|
||||
}
|
||||
|
||||
void Reset() {
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
// we don't reset the feature extractor
|
||||
start_frame_index_ += num_processed_frames_;
|
||||
num_processed_frames_ = 0;
|
||||
}
|
||||
|
||||
int32_t &GetNumProcessedFrames() { return num_processed_frames_; }
|
||||
int32_t &GetNumProcessedFrames() {
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
return num_processed_frames_;
|
||||
}
|
||||
|
||||
int32_t GetNumFramesSinceStart() const { return start_frame_index_; }
|
||||
int32_t GetNumFramesSinceStart() const {
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
return start_frame_index_;
|
||||
}
|
||||
|
||||
int32_t &GetCurrentSegment() { return segment_; }
|
||||
int32_t &GetCurrentSegment() {
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
return segment_;
|
||||
}
|
||||
|
||||
void SetResult(const OnlineTransducerDecoderResult &r) { result_ = r; }
|
||||
|
||||
@@ -125,6 +142,7 @@ class OnlineStream::Impl {
|
||||
|
||||
private:
|
||||
FeatureExtractor feat_extractor_;
|
||||
mutable std::mutex mutex_;
|
||||
/// For contextual-biasing
|
||||
ContextGraphPtr context_graph_;
|
||||
int32_t num_processed_frames_ = 0; // before subsampling
|
||||
|
||||
@@ -74,6 +74,8 @@ class OnlineZipformer2CtcModel::Impl {
|
||||
|
||||
int32_t ChunkShift() const { return decode_chunk_len_; }
|
||||
|
||||
bool UseWhisperFeature() const { return use_whisper_feature_; }
|
||||
|
||||
OrtAllocator *Allocator() { return allocator_; }
|
||||
|
||||
// Return a vector containing 3 tensors
|
||||
@@ -278,6 +280,12 @@ class OnlineZipformer2CtcModel::Impl {
|
||||
SHERPA_ONNX_READ_META_DATA(T_, "T");
|
||||
SHERPA_ONNX_READ_META_DATA(decode_chunk_len_, "decode_chunk_len");
|
||||
|
||||
std::string feature_type;
|
||||
SHERPA_ONNX_READ_META_DATA_STR_WITH_DEFAULT(feature_type, "feature", "");
|
||||
if (feature_type == "whisper") {
|
||||
use_whisper_feature_ = true;
|
||||
}
|
||||
|
||||
{
|
||||
auto shape =
|
||||
sess_->GetOutputTypeInfo(0).GetTensorTypeAndShapeInfo().GetShape();
|
||||
@@ -417,6 +425,10 @@ class OnlineZipformer2CtcModel::Impl {
|
||||
int32_t T_ = 0;
|
||||
int32_t decode_chunk_len_ = 0;
|
||||
int32_t vocab_size_ = 0;
|
||||
|
||||
// for models from
|
||||
// https://github.com/k2-fsa/icefall/blob/master/egs/multi_zh-hans/ASR/RESULTS.md#streaming-with-ctc-head
|
||||
bool use_whisper_feature_ = false;
|
||||
};
|
||||
|
||||
OnlineZipformer2CtcModel::OnlineZipformer2CtcModel(
|
||||
@@ -447,6 +459,10 @@ int32_t OnlineZipformer2CtcModel::ChunkShift() const {
|
||||
return impl_->ChunkShift();
|
||||
}
|
||||
|
||||
bool OnlineZipformer2CtcModel::UseWhisperFeature() const {
|
||||
return impl_->UseWhisperFeature();
|
||||
}
|
||||
|
||||
OrtAllocator *OnlineZipformer2CtcModel::Allocator() const {
|
||||
return impl_->Allocator();
|
||||
}
|
||||
|
||||
@@ -64,6 +64,8 @@ class OnlineZipformer2CtcModel : public OnlineCtcModel {
|
||||
// before we process the next chunk.
|
||||
int32_t ChunkShift() const override;
|
||||
|
||||
bool UseWhisperFeature() const override;
|
||||
|
||||
private:
|
||||
class Impl;
|
||||
std::unique_ptr<Impl> impl_;
|
||||
|
||||
@@ -130,7 +130,7 @@ for a list of pre-trained models to download.
|
||||
}
|
||||
|
||||
if (!mic.OpenDevice(device_index, mic_sample_rate, 1, RecordCallback,
|
||||
nullptr /* user_data */)) {
|
||||
s.get())) {
|
||||
fprintf(stderr, "portaudio error: %d\n", device_index);
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user