Support zipformer CTC ASR with whisper features. (#2319)
This commit is contained in:
@@ -60,6 +60,8 @@ class FeatureExtractor::Impl {
|
|||||||
explicit Impl(const FeatureExtractorConfig &config) : config_(config) {
|
explicit Impl(const FeatureExtractorConfig &config) : config_(config) {
|
||||||
if (config_.is_mfcc) {
|
if (config_.is_mfcc) {
|
||||||
InitMfcc();
|
InitMfcc();
|
||||||
|
} else if (config_.is_whisper) {
|
||||||
|
InitWhisper();
|
||||||
} else {
|
} else {
|
||||||
InitFbank();
|
InitFbank();
|
||||||
}
|
}
|
||||||
@@ -92,13 +94,9 @@ class FeatureExtractor::Impl {
|
|||||||
|
|
||||||
std::vector<float> samples;
|
std::vector<float> samples;
|
||||||
resampler_->Resample(waveform, n, false, &samples);
|
resampler_->Resample(waveform, n, false, &samples);
|
||||||
if (fbank_) {
|
|
||||||
fbank_->AcceptWaveform(config_.sampling_rate, samples.data(),
|
AcceptWaveformWrapper(config_.sampling_rate, samples.data(),
|
||||||
samples.size());
|
samples.size());
|
||||||
} else {
|
|
||||||
mfcc_->AcceptWaveform(config_.sampling_rate, samples.data(),
|
|
||||||
samples.size());
|
|
||||||
}
|
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -119,61 +117,81 @@ class FeatureExtractor::Impl {
|
|||||||
|
|
||||||
std::vector<float> samples;
|
std::vector<float> samples;
|
||||||
resampler_->Resample(waveform, n, false, &samples);
|
resampler_->Resample(waveform, n, false, &samples);
|
||||||
if (fbank_) {
|
|
||||||
fbank_->AcceptWaveform(config_.sampling_rate, samples.data(),
|
AcceptWaveformWrapper(config_.sampling_rate, samples.data(),
|
||||||
samples.size());
|
samples.size());
|
||||||
} else {
|
|
||||||
mfcc_->AcceptWaveform(config_.sampling_rate, samples.data(),
|
|
||||||
samples.size());
|
|
||||||
}
|
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (fbank_) {
|
AcceptWaveformWrapper(sampling_rate, waveform, n);
|
||||||
fbank_->AcceptWaveform(sampling_rate, waveform, n);
|
|
||||||
} else {
|
|
||||||
mfcc_->AcceptWaveform(sampling_rate, waveform, n);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void InputFinished() const {
|
void InputFinished() const {
|
||||||
std::lock_guard<std::mutex> lock(mutex_);
|
std::lock_guard<std::mutex> lock(mutex_);
|
||||||
fbank_->InputFinished();
|
if (fbank_) {
|
||||||
|
fbank_->InputFinished();
|
||||||
|
} else if (whisper_fbank_) {
|
||||||
|
whisper_fbank_->InputFinished();
|
||||||
|
} else if (mfcc_) {
|
||||||
|
mfcc_->InputFinished();
|
||||||
|
}
|
||||||
|
|
||||||
|
SHERPA_ONNX_LOGE("unreachable code");
|
||||||
|
SHERPA_ONNX_EXIT(-1);
|
||||||
}
|
}
|
||||||
|
|
||||||
int32_t NumFramesReady() const {
|
int32_t NumFramesReady() const {
|
||||||
std::lock_guard<std::mutex> lock(mutex_);
|
if (fbank_) {
|
||||||
return fbank_->NumFramesReady();
|
return fbank_->NumFramesReady();
|
||||||
|
} else if (whisper_fbank_) {
|
||||||
|
return whisper_fbank_->NumFramesReady();
|
||||||
|
} else if (mfcc_) {
|
||||||
|
return mfcc_->NumFramesReady();
|
||||||
|
}
|
||||||
|
SHERPA_ONNX_LOGE("unreachable code");
|
||||||
|
SHERPA_ONNX_EXIT(-1);
|
||||||
|
return -1;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool IsLastFrame(int32_t frame) const {
|
bool IsLastFrame(int32_t frame) const {
|
||||||
std::lock_guard<std::mutex> lock(mutex_);
|
std::lock_guard<std::mutex> lock(mutex_);
|
||||||
return fbank_->IsLastFrame(frame);
|
if (fbank_) {
|
||||||
|
return fbank_->IsLastFrame(frame);
|
||||||
|
} else if (whisper_fbank_) {
|
||||||
|
return whisper_fbank_->IsLastFrame(frame);
|
||||||
|
} else if (mfcc_) {
|
||||||
|
return mfcc_->IsLastFrame(frame);
|
||||||
|
}
|
||||||
|
|
||||||
|
SHERPA_ONNX_LOGE("unreachable code");
|
||||||
|
SHERPA_ONNX_EXIT(-1);
|
||||||
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<float> GetFrames(int32_t frame_index, int32_t n) {
|
std::vector<float> GetFrames(int32_t frame_index, int32_t n) {
|
||||||
std::lock_guard<std::mutex> lock(mutex_);
|
std::lock_guard<std::mutex> lock(mutex_);
|
||||||
if (frame_index + n > fbank_->NumFramesReady()) {
|
if (frame_index + n > NumFramesReady()) {
|
||||||
SHERPA_ONNX_LOGE("%d + %d > %d\n", frame_index, n,
|
SHERPA_ONNX_LOGE("%d + %d > %d\n", frame_index, n, NumFramesReady());
|
||||||
fbank_->NumFramesReady());
|
SHERPA_ONNX_EXIT(-1);
|
||||||
exit(-1);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
int32_t discard_num = frame_index - last_frame_index_;
|
int32_t discard_num = frame_index - last_frame_index_;
|
||||||
if (discard_num < 0) {
|
if (discard_num < 0) {
|
||||||
SHERPA_ONNX_LOGE("last_frame_index_: %d, frame_index_: %d",
|
SHERPA_ONNX_LOGE("last_frame_index_: %d, frame_index_: %d",
|
||||||
last_frame_index_, frame_index);
|
last_frame_index_, frame_index);
|
||||||
exit(-1);
|
SHERPA_ONNX_EXIT(-1);
|
||||||
}
|
}
|
||||||
fbank_->Pop(discard_num);
|
|
||||||
|
|
||||||
int32_t feature_dim = fbank_->Dim();
|
PopWrapper(discard_num);
|
||||||
|
|
||||||
|
int32_t feature_dim = FeatureDim();
|
||||||
std::vector<float> features(feature_dim * n);
|
std::vector<float> features(feature_dim * n);
|
||||||
|
|
||||||
float *p = features.data();
|
float *p = features.data();
|
||||||
|
|
||||||
for (int32_t i = 0; i != n; ++i) {
|
for (int32_t i = 0; i != n; ++i) {
|
||||||
const float *f = fbank_->GetFrame(i + frame_index);
|
const float *f = GetFrameWrapper(i + frame_index);
|
||||||
std::copy(f, f + feature_dim, p);
|
std::copy(f, f + feature_dim, p);
|
||||||
p += feature_dim;
|
p += feature_dim;
|
||||||
}
|
}
|
||||||
@@ -184,10 +202,65 @@ class FeatureExtractor::Impl {
|
|||||||
}
|
}
|
||||||
|
|
||||||
int32_t FeatureDim() const {
|
int32_t FeatureDim() const {
|
||||||
return mfcc_ ? mfcc_opts_.num_ceps : opts_.mel_opts.num_bins;
|
if (fbank_ || whisper_fbank_) {
|
||||||
|
return opts_.mel_opts.num_bins;
|
||||||
|
} else if (mfcc_) {
|
||||||
|
return mfcc_opts_.num_ceps;
|
||||||
|
}
|
||||||
|
|
||||||
|
SHERPA_ONNX_LOGE("unreachable code");
|
||||||
|
SHERPA_ONNX_EXIT(-1);
|
||||||
|
return -1;
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
void AcceptWaveformWrapper(float sampling_rate, const float *waveform,
|
||||||
|
int32_t n) const {
|
||||||
|
if (fbank_) {
|
||||||
|
fbank_->AcceptWaveform(sampling_rate, waveform, n);
|
||||||
|
return;
|
||||||
|
} else if (whisper_fbank_) {
|
||||||
|
whisper_fbank_->AcceptWaveform(sampling_rate, waveform, n);
|
||||||
|
return;
|
||||||
|
} else if (mfcc_) {
|
||||||
|
mfcc_->AcceptWaveform(sampling_rate, waveform, n);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
SHERPA_ONNX_LOGE("unreachable code");
|
||||||
|
SHERPA_ONNX_EXIT(-1);
|
||||||
|
}
|
||||||
|
|
||||||
|
const float *GetFrameWrapper(int32_t frame_index) const {
|
||||||
|
if (fbank_) {
|
||||||
|
return fbank_->GetFrame(frame_index);
|
||||||
|
} else if (whisper_fbank_) {
|
||||||
|
return whisper_fbank_->GetFrame(frame_index);
|
||||||
|
} else if (mfcc_) {
|
||||||
|
return mfcc_->GetFrame(frame_index);
|
||||||
|
}
|
||||||
|
|
||||||
|
SHERPA_ONNX_LOGE("unreachable code");
|
||||||
|
SHERPA_ONNX_EXIT(-1);
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
void PopWrapper(int32_t discard_num) const {
|
||||||
|
if (fbank_) {
|
||||||
|
fbank_->Pop(discard_num);
|
||||||
|
return;
|
||||||
|
} else if (whisper_fbank_) {
|
||||||
|
whisper_fbank_->Pop(discard_num);
|
||||||
|
return;
|
||||||
|
} else if (mfcc_) {
|
||||||
|
mfcc_->Pop(discard_num);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
SHERPA_ONNX_LOGE("unreachable code");
|
||||||
|
SHERPA_ONNX_EXIT(-1);
|
||||||
|
}
|
||||||
|
|
||||||
void InitFbank() {
|
void InitFbank() {
|
||||||
opts_.frame_opts.dither = config_.dither;
|
opts_.frame_opts.dither = config_.dither;
|
||||||
opts_.frame_opts.snip_edges = config_.snip_edges;
|
opts_.frame_opts.snip_edges = config_.snip_edges;
|
||||||
@@ -208,6 +281,7 @@ class FeatureExtractor::Impl {
|
|||||||
|
|
||||||
fbank_ = std::make_unique<knf::OnlineFbank>(opts_);
|
fbank_ = std::make_unique<knf::OnlineFbank>(opts_);
|
||||||
}
|
}
|
||||||
|
|
||||||
void InitMfcc() {
|
void InitMfcc() {
|
||||||
mfcc_opts_.frame_opts.dither = config_.dither;
|
mfcc_opts_.frame_opts.dither = config_.dither;
|
||||||
mfcc_opts_.frame_opts.snip_edges = config_.snip_edges;
|
mfcc_opts_.frame_opts.snip_edges = config_.snip_edges;
|
||||||
@@ -232,9 +306,23 @@ class FeatureExtractor::Impl {
|
|||||||
mfcc_ = std::make_unique<knf::OnlineMfcc>(mfcc_opts_);
|
mfcc_ = std::make_unique<knf::OnlineMfcc>(mfcc_opts_);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void InitWhisper() {
|
||||||
|
config_.normalize_samples = true;
|
||||||
|
opts_.frame_opts.samp_freq = 16000;
|
||||||
|
opts_.mel_opts.num_bins = config_.feature_dim;
|
||||||
|
|
||||||
|
knf::WhisperFeatureOptions whisper_opts;
|
||||||
|
whisper_opts.frame_opts = opts_.frame_opts;
|
||||||
|
whisper_opts.dim = config_.feature_dim;
|
||||||
|
|
||||||
|
whisper_fbank_ = std::make_unique<knf::OnlineWhisperFbank>(whisper_opts);
|
||||||
|
config_.sampling_rate = opts_.frame_opts.samp_freq;
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::unique_ptr<knf::OnlineFbank> fbank_;
|
std::unique_ptr<knf::OnlineFbank> fbank_;
|
||||||
std::unique_ptr<knf::OnlineMfcc> mfcc_;
|
std::unique_ptr<knf::OnlineMfcc> mfcc_;
|
||||||
|
std::unique_ptr<knf::OnlineWhisperFbank> whisper_fbank_;
|
||||||
knf::FbankOptions opts_;
|
knf::FbankOptions opts_;
|
||||||
knf::MfccOptions mfcc_opts_;
|
knf::MfccOptions mfcc_opts_;
|
||||||
FeatureExtractorConfig config_;
|
FeatureExtractorConfig config_;
|
||||||
|
|||||||
@@ -79,6 +79,8 @@ struct FeatureExtractorConfig {
|
|||||||
|
|
||||||
bool is_mfcc = false;
|
bool is_mfcc = false;
|
||||||
|
|
||||||
|
bool is_whisper = false;
|
||||||
|
|
||||||
bool round_to_power_of_two = true;
|
bool round_to_power_of_two = true;
|
||||||
|
|
||||||
std::string ToString() const;
|
std::string ToString() const;
|
||||||
|
|||||||
@@ -77,6 +77,8 @@ class OnlineCtcModel {
|
|||||||
|
|
||||||
// Return true if the model supports batch size > 1
|
// Return true if the model supports batch size > 1
|
||||||
virtual bool SupportBatchProcessing() const { return true; }
|
virtual bool SupportBatchProcessing() const { return true; }
|
||||||
|
|
||||||
|
virtual bool UseWhisperFeature() const { return false; }
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace sherpa_onnx
|
} // namespace sherpa_onnx
|
||||||
|
|||||||
@@ -15,6 +15,7 @@
|
|||||||
|
|
||||||
#include "sherpa-onnx/csrc/file-utils.h"
|
#include "sherpa-onnx/csrc/file-utils.h"
|
||||||
#include "sherpa-onnx/csrc/macros.h"
|
#include "sherpa-onnx/csrc/macros.h"
|
||||||
|
#include "sherpa-onnx/csrc/offline-whisper-model.h"
|
||||||
#include "sherpa-onnx/csrc/online-ctc-decoder.h"
|
#include "sherpa-onnx/csrc/online-ctc-decoder.h"
|
||||||
#include "sherpa-onnx/csrc/online-ctc-fst-decoder.h"
|
#include "sherpa-onnx/csrc/online-ctc-fst-decoder.h"
|
||||||
#include "sherpa-onnx/csrc/online-ctc-greedy-search-decoder.h"
|
#include "sherpa-onnx/csrc/online-ctc-greedy-search-decoder.h"
|
||||||
@@ -91,6 +92,10 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl {
|
|||||||
config_.feat_config.normalize_samples = false;
|
config_.feat_config.normalize_samples = false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (model_->UseWhisperFeature()) {
|
||||||
|
config_.feat_config.is_whisper = true;
|
||||||
|
}
|
||||||
|
|
||||||
InitDecoder();
|
InitDecoder();
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -108,6 +113,10 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl {
|
|||||||
config_.feat_config.normalize_samples = false;
|
config_.feat_config.normalize_samples = false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (model_->UseWhisperFeature()) {
|
||||||
|
config_.feat_config.is_whisper = true;
|
||||||
|
}
|
||||||
|
|
||||||
InitDecoder();
|
InitDecoder();
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -147,6 +156,10 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl {
|
|||||||
const auto num_processed_frames = ss[i]->GetNumProcessedFrames();
|
const auto num_processed_frames = ss[i]->GetNumProcessedFrames();
|
||||||
std::vector<float> features =
|
std::vector<float> features =
|
||||||
ss[i]->GetFrames(num_processed_frames, chunk_length);
|
ss[i]->GetFrames(num_processed_frames, chunk_length);
|
||||||
|
if (config_.feat_config.is_whisper) {
|
||||||
|
OfflineWhisperModel::NormalizeFeatures(features.data(), chunk_length,
|
||||||
|
feat_dim);
|
||||||
|
}
|
||||||
|
|
||||||
// Question: should num_processed_frames include chunk_shift?
|
// Question: should num_processed_frames include chunk_shift?
|
||||||
ss[i]->GetNumProcessedFrames() += chunk_shift;
|
ss[i]->GetNumProcessedFrames() += chunk_shift;
|
||||||
@@ -287,6 +300,12 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl {
|
|||||||
const auto num_processed_frames = s->GetNumProcessedFrames();
|
const auto num_processed_frames = s->GetNumProcessedFrames();
|
||||||
std::vector<float> frames =
|
std::vector<float> frames =
|
||||||
s->GetFrames(num_processed_frames, chunk_length);
|
s->GetFrames(num_processed_frames, chunk_length);
|
||||||
|
|
||||||
|
if (config_.feat_config.is_whisper) {
|
||||||
|
OfflineWhisperModel::NormalizeFeatures(frames.data(), chunk_length,
|
||||||
|
feat_dim);
|
||||||
|
}
|
||||||
|
|
||||||
s->GetNumProcessedFrames() += chunk_shift;
|
s->GetNumProcessedFrames() += chunk_shift;
|
||||||
|
|
||||||
auto memory_info =
|
auto memory_info =
|
||||||
|
|||||||
@@ -19,34 +19,51 @@ class OnlineStream::Impl {
|
|||||||
: feat_extractor_(config), context_graph_(std::move(context_graph)) {}
|
: feat_extractor_(config), context_graph_(std::move(context_graph)) {}
|
||||||
|
|
||||||
void AcceptWaveform(int32_t sampling_rate, const float *waveform, int32_t n) {
|
void AcceptWaveform(int32_t sampling_rate, const float *waveform, int32_t n) {
|
||||||
|
std::lock_guard<std::mutex> lock(mutex_);
|
||||||
feat_extractor_.AcceptWaveform(sampling_rate, waveform, n);
|
feat_extractor_.AcceptWaveform(sampling_rate, waveform, n);
|
||||||
}
|
}
|
||||||
|
|
||||||
void InputFinished() const { feat_extractor_.InputFinished(); }
|
void InputFinished() const {
|
||||||
|
std::lock_guard<std::mutex> lock(mutex_);
|
||||||
|
feat_extractor_.InputFinished();
|
||||||
|
}
|
||||||
|
|
||||||
int32_t NumFramesReady() const {
|
int32_t NumFramesReady() const {
|
||||||
|
std::lock_guard<std::mutex> lock(mutex_);
|
||||||
return feat_extractor_.NumFramesReady() - start_frame_index_;
|
return feat_extractor_.NumFramesReady() - start_frame_index_;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool IsLastFrame(int32_t frame) const {
|
bool IsLastFrame(int32_t frame) const {
|
||||||
|
std::lock_guard<std::mutex> lock(mutex_);
|
||||||
return feat_extractor_.IsLastFrame(frame);
|
return feat_extractor_.IsLastFrame(frame);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<float> GetFrames(int32_t frame_index, int32_t n) const {
|
std::vector<float> GetFrames(int32_t frame_index, int32_t n) const {
|
||||||
|
std::lock_guard<std::mutex> lock(mutex_);
|
||||||
return feat_extractor_.GetFrames(frame_index + start_frame_index_, n);
|
return feat_extractor_.GetFrames(frame_index + start_frame_index_, n);
|
||||||
}
|
}
|
||||||
|
|
||||||
void Reset() {
|
void Reset() {
|
||||||
|
std::lock_guard<std::mutex> lock(mutex_);
|
||||||
// we don't reset the feature extractor
|
// we don't reset the feature extractor
|
||||||
start_frame_index_ += num_processed_frames_;
|
start_frame_index_ += num_processed_frames_;
|
||||||
num_processed_frames_ = 0;
|
num_processed_frames_ = 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
int32_t &GetNumProcessedFrames() { return num_processed_frames_; }
|
int32_t &GetNumProcessedFrames() {
|
||||||
|
std::lock_guard<std::mutex> lock(mutex_);
|
||||||
|
return num_processed_frames_;
|
||||||
|
}
|
||||||
|
|
||||||
int32_t GetNumFramesSinceStart() const { return start_frame_index_; }
|
int32_t GetNumFramesSinceStart() const {
|
||||||
|
std::lock_guard<std::mutex> lock(mutex_);
|
||||||
|
return start_frame_index_;
|
||||||
|
}
|
||||||
|
|
||||||
int32_t &GetCurrentSegment() { return segment_; }
|
int32_t &GetCurrentSegment() {
|
||||||
|
std::lock_guard<std::mutex> lock(mutex_);
|
||||||
|
return segment_;
|
||||||
|
}
|
||||||
|
|
||||||
void SetResult(const OnlineTransducerDecoderResult &r) { result_ = r; }
|
void SetResult(const OnlineTransducerDecoderResult &r) { result_ = r; }
|
||||||
|
|
||||||
@@ -125,6 +142,7 @@ class OnlineStream::Impl {
|
|||||||
|
|
||||||
private:
|
private:
|
||||||
FeatureExtractor feat_extractor_;
|
FeatureExtractor feat_extractor_;
|
||||||
|
mutable std::mutex mutex_;
|
||||||
/// For contextual-biasing
|
/// For contextual-biasing
|
||||||
ContextGraphPtr context_graph_;
|
ContextGraphPtr context_graph_;
|
||||||
int32_t num_processed_frames_ = 0; // before subsampling
|
int32_t num_processed_frames_ = 0; // before subsampling
|
||||||
|
|||||||
@@ -74,6 +74,8 @@ class OnlineZipformer2CtcModel::Impl {
|
|||||||
|
|
||||||
int32_t ChunkShift() const { return decode_chunk_len_; }
|
int32_t ChunkShift() const { return decode_chunk_len_; }
|
||||||
|
|
||||||
|
bool UseWhisperFeature() const { return use_whisper_feature_; }
|
||||||
|
|
||||||
OrtAllocator *Allocator() { return allocator_; }
|
OrtAllocator *Allocator() { return allocator_; }
|
||||||
|
|
||||||
// Return a vector containing 3 tensors
|
// Return a vector containing 3 tensors
|
||||||
@@ -278,6 +280,12 @@ class OnlineZipformer2CtcModel::Impl {
|
|||||||
SHERPA_ONNX_READ_META_DATA(T_, "T");
|
SHERPA_ONNX_READ_META_DATA(T_, "T");
|
||||||
SHERPA_ONNX_READ_META_DATA(decode_chunk_len_, "decode_chunk_len");
|
SHERPA_ONNX_READ_META_DATA(decode_chunk_len_, "decode_chunk_len");
|
||||||
|
|
||||||
|
std::string feature_type;
|
||||||
|
SHERPA_ONNX_READ_META_DATA_STR_WITH_DEFAULT(feature_type, "feature", "");
|
||||||
|
if (feature_type == "whisper") {
|
||||||
|
use_whisper_feature_ = true;
|
||||||
|
}
|
||||||
|
|
||||||
{
|
{
|
||||||
auto shape =
|
auto shape =
|
||||||
sess_->GetOutputTypeInfo(0).GetTensorTypeAndShapeInfo().GetShape();
|
sess_->GetOutputTypeInfo(0).GetTensorTypeAndShapeInfo().GetShape();
|
||||||
@@ -417,6 +425,10 @@ class OnlineZipformer2CtcModel::Impl {
|
|||||||
int32_t T_ = 0;
|
int32_t T_ = 0;
|
||||||
int32_t decode_chunk_len_ = 0;
|
int32_t decode_chunk_len_ = 0;
|
||||||
int32_t vocab_size_ = 0;
|
int32_t vocab_size_ = 0;
|
||||||
|
|
||||||
|
// for models from
|
||||||
|
// https://github.com/k2-fsa/icefall/blob/master/egs/multi_zh-hans/ASR/RESULTS.md#streaming-with-ctc-head
|
||||||
|
bool use_whisper_feature_ = false;
|
||||||
};
|
};
|
||||||
|
|
||||||
OnlineZipformer2CtcModel::OnlineZipformer2CtcModel(
|
OnlineZipformer2CtcModel::OnlineZipformer2CtcModel(
|
||||||
@@ -447,6 +459,10 @@ int32_t OnlineZipformer2CtcModel::ChunkShift() const {
|
|||||||
return impl_->ChunkShift();
|
return impl_->ChunkShift();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool OnlineZipformer2CtcModel::UseWhisperFeature() const {
|
||||||
|
return impl_->UseWhisperFeature();
|
||||||
|
}
|
||||||
|
|
||||||
OrtAllocator *OnlineZipformer2CtcModel::Allocator() const {
|
OrtAllocator *OnlineZipformer2CtcModel::Allocator() const {
|
||||||
return impl_->Allocator();
|
return impl_->Allocator();
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -64,6 +64,8 @@ class OnlineZipformer2CtcModel : public OnlineCtcModel {
|
|||||||
// before we process the next chunk.
|
// before we process the next chunk.
|
||||||
int32_t ChunkShift() const override;
|
int32_t ChunkShift() const override;
|
||||||
|
|
||||||
|
bool UseWhisperFeature() const override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
class Impl;
|
class Impl;
|
||||||
std::unique_ptr<Impl> impl_;
|
std::unique_ptr<Impl> impl_;
|
||||||
|
|||||||
@@ -130,7 +130,7 @@ for a list of pre-trained models to download.
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (!mic.OpenDevice(device_index, mic_sample_rate, 1, RecordCallback,
|
if (!mic.OpenDevice(device_index, mic_sample_rate, 1, RecordCallback,
|
||||||
nullptr /* user_data */)) {
|
s.get())) {
|
||||||
fprintf(stderr, "portaudio error: %d\n", device_index);
|
fprintf(stderr, "portaudio error: %d\n", device_index);
|
||||||
exit(EXIT_FAILURE);
|
exit(EXIT_FAILURE);
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user