Support whisper large/large-v1/large-v2/large-v3 and distil-large-v2 (#1114)

This commit is contained in:
Fangjun Kuang
2024-07-12 23:47:39 +08:00
committed by GitHub
parent d928f77d0e
commit 117cd7bb8c
23 changed files with 152 additions and 85 deletions

View File

@@ -88,7 +88,9 @@ class OfflineRecognizerWhisperImpl : public OfflineRecognizerImpl {
}
std::unique_ptr<OfflineStream> CreateStream() const override {
return std::make_unique<OfflineStream>(WhisperTag{});
WhisperTag tag;
tag.dim = model_->FeatureDim();
return std::make_unique<OfflineStream>(tag);
}
void DecodeStreams(OfflineStream **ss, int32_t n) const override {

View File

@@ -97,12 +97,16 @@ class OfflineStream::Impl {
}
}
explicit Impl(WhisperTag /*tag*/) {
explicit Impl(WhisperTag tag) {
config_.normalize_samples = true;
opts_.frame_opts.samp_freq = 16000;
opts_.mel_opts.num_bins = 80; // not used
whisper_fbank_ =
std::make_unique<knf::OnlineWhisperFbank>(opts_.frame_opts);
opts_.mel_opts.num_bins = tag.dim;
knf::WhisperFeatureOptions whisper_opts;
whisper_opts.frame_opts = opts_.frame_opts;
whisper_opts.dim = tag.dim;
whisper_fbank_ = std::make_unique<knf::OnlineWhisperFbank>(whisper_opts);
config_.sampling_rate = opts_.frame_opts.samp_freq;
}

View File

@@ -35,7 +35,10 @@ struct OfflineRecognitionResult {
std::string AsJsonString() const;
};
struct WhisperTag {};
// Tag type that selects the Whisper-specific OfflineStream constructor.
struct WhisperTag {
// Feature dimension handed to the Whisper fbank extractor (it is copied into
// mel_opts.num_bins and WhisperFeatureOptions::dim). Defaults to 80; the
// recognizer overwrites it with the model's FeatureDim() before creating a
// stream.
int32_t dim = 80;
};
struct CEDTag {};
class OfflineStream {

View File

@@ -217,6 +217,8 @@ class OfflineWhisperModel::Impl {
int32_t VocabSize() const { return n_vocab_; }
int32_t FeatureDim() const { return n_mels_; }
int32_t Translate() const { return translate_; }
bool IsMultiLingual() const { return is_multilingual_; }
@@ -242,6 +244,7 @@ class OfflineWhisperModel::Impl {
}
Ort::AllocatorWithDefaultOptions allocator; // used in the macro below
SHERPA_ONNX_READ_META_DATA(n_mels_, "n_mels");
SHERPA_ONNX_READ_META_DATA(n_text_layer_, "n_text_layer");
SHERPA_ONNX_READ_META_DATA(n_text_ctx_, "n_text_ctx");
SHERPA_ONNX_READ_META_DATA(n_text_state_, "n_text_state");
@@ -316,6 +319,7 @@ class OfflineWhisperModel::Impl {
std::unordered_map<int32_t, std::string> id2lang_;
// model meta data
int32_t n_mels_ = 80;
int32_t n_text_layer_ = 0;
int32_t n_text_ctx_ = 0;
int32_t n_text_state_ = 0;
@@ -414,6 +418,8 @@ int32_t OfflineWhisperModel::TextCtx() const { return impl_->TextCtx(); }
int32_t OfflineWhisperModel::VocabSize() const { return impl_->VocabSize(); }
int32_t OfflineWhisperModel::FeatureDim() const { return impl_->FeatureDim(); }
int32_t OfflineWhisperModel::Translate() const { return impl_->Translate(); }
bool OfflineWhisperModel::IsMultiLingual() const {

View File

@@ -102,6 +102,7 @@ class OfflineWhisperModel {
int32_t SOT() const;
int32_t TextCtx() const;
int32_t VocabSize() const;
int32_t FeatureDim() const;
int32_t Translate() const;
bool IsMultiLingual() const;