// sherpa-onnx/csrc/offline-recognizer-whisper-impl.h // // Copyright (c) 2022-2023 Xiaomi Corporation #ifndef SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_WHISPER_IMPL_H_ #define SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_WHISPER_IMPL_H_ #include #include #include #include #include #include #include "sherpa-onnx/csrc/offline-model-config.h" #include "sherpa-onnx/csrc/offline-recognizer-impl.h" #include "sherpa-onnx/csrc/offline-recognizer.h" #include "sherpa-onnx/csrc/offline-whisper-decoder.h" #include "sherpa-onnx/csrc/offline-whisper-greedy-search-decoder.h" #include "sherpa-onnx/csrc/offline-whisper-model.h" #include "sherpa-onnx/csrc/symbol-table.h" #include "sherpa-onnx/csrc/transpose.h" namespace sherpa_onnx { static OfflineRecognitionResult Convert(const OfflineWhisperDecoderResult &src, const SymbolTable &sym_table) { OfflineRecognitionResult r; r.tokens.reserve(src.tokens.size()); std::string text; for (auto i : src.tokens) { if (!sym_table.Contains(i)) { continue; } const auto &s = sym_table[i]; text += s; r.tokens.push_back(s); } r.text = text; r.lang = src.lang; return r; } class OfflineRecognizerWhisperImpl : public OfflineRecognizerImpl { public: explicit OfflineRecognizerWhisperImpl(const OfflineRecognizerConfig &config) : OfflineRecognizerImpl(config), config_(config), symbol_table_(config_.model_config.tokens), model_(std::make_unique(config.model_config)) { Init(); } template OfflineRecognizerWhisperImpl(Manager *mgr, const OfflineRecognizerConfig &config) : OfflineRecognizerImpl(mgr, config), config_(config), symbol_table_(mgr, config_.model_config.tokens), model_( std::make_unique(mgr, config.model_config)) { Init(); } void Init() { // tokens.txt from whisper is base64 encoded, so we need to decode it symbol_table_.ApplyBase64Decode(); if (config_.decoding_method == "greedy_search") { decoder_ = std::make_unique( config_.model_config.whisper, model_.get()); } else { SHERPA_ONNX_LOGE( "Only greedy_search is supported at present for whisper. Given %s", config_.decoding_method.c_str()); exit(-1); } } std::unique_ptr CreateStream() const override { WhisperTag tag; tag.dim = model_->FeatureDim(); return std::make_unique(tag); } void DecodeStreams(OfflineStream **ss, int32_t n) const override { // batch decoding is not implemented yet for (int32_t i = 0; i != n; ++i) { DecodeStream(ss[i]); } } void SetConfig(const OfflineRecognizerConfig &config) override { config_.model_config.whisper = config.model_config.whisper; } OfflineRecognizerConfig GetConfig() const override { return config_; } private: void DecodeStream(OfflineStream *s) const { decoder_->SetConfig(config_.model_config.whisper); int32_t max_num_frames = 3000; auto memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); int32_t feat_dim = s->FeatureDim(); std::vector f = s->GetFrames(); int32_t num_frames = f.size() / feat_dim; // we use 50 here so that there will be some zero tail paddings if (num_frames >= max_num_frames - 50) { SHERPA_ONNX_LOGE( "Only waves less than 30 seconds are supported. We process only the " "first 30 seconds and discard the remaining data"); num_frames = max_num_frames - 50; } model_->NormalizeFeatures(f.data(), num_frames, feat_dim); // note that 1000 is an experience-value. // You can replace 1000 by other values, say, 100. // // Since we have removed the 30 seconds constraint, we need // tail_padding_frames so that whisper is able to detect the eot token. int32_t tail_padding_frames = 1000; if (config_.model_config.whisper.tail_paddings > 0) { tail_padding_frames = config_.model_config.whisper.tail_paddings; } int32_t actual_frames = std::min(num_frames + tail_padding_frames, max_num_frames); std::array shape{1, actual_frames, feat_dim}; Ort::Value mel = Ort::Value::CreateTensor( model_->Allocator(), shape.data(), shape.size()); float *p_mel = mel.GetTensorMutableData(); std::copy(f.data(), f.data() + num_frames * feat_dim, p_mel); std::fill_n(p_mel + num_frames * feat_dim, (actual_frames - num_frames) * feat_dim, 0); mel = Transpose12(model_->Allocator(), &mel); try { auto cross_kv = model_->ForwardEncoder(std::move(mel)); auto results = decoder_->Decode(std::move(cross_kv.first), std::move(cross_kv.second)); auto r = Convert(results[0], symbol_table_); r.text = ApplyInverseTextNormalization(std::move(r.text)); s->SetResult(r); } catch (const Ort::Exception &ex) { SHERPA_ONNX_LOGE( "\n\nCaught exception:\n\n%s\n\nReturn an empty result. Number of " "input frames: %d, Current tail " "paddings: %d. If you see a lot of such exceptions, please consider " "using a larger --whisper-tail-paddings", ex.what(), num_frames, tail_padding_frames); return; } } private: OfflineRecognizerConfig config_; SymbolTable symbol_table_; std::unique_ptr model_; std::unique_ptr decoder_; }; } // namespace sherpa_onnx #endif // SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_WHISPER_IMPL_H_