// sherpa-onnx/csrc/offline-whisper-model.h // // Copyright (c) 2022-2023 Xiaomi Corporation #ifndef SHERPA_ONNX_CSRC_OFFLINE_WHISPER_MODEL_H_ #define SHERPA_ONNX_CSRC_OFFLINE_WHISPER_MODEL_H_ #include #include #include #include #include #include #if __ANDROID_API__ >= 9 #include "android/asset_manager.h" #include "android/asset_manager_jni.h" #endif #include "onnxruntime_cxx_api.h" // NOLINT #include "sherpa-onnx/csrc/offline-model-config.h" namespace sherpa_onnx { class OfflineWhisperModel { public: explicit OfflineWhisperModel(const OfflineModelConfig &config); #if __ANDROID_API__ >= 9 OfflineWhisperModel(AAssetManager *mgr, const OfflineModelConfig &config); #endif ~OfflineWhisperModel(); /** Run the encoder model. * * @param features A tensor of shape (N, C, T). It is changed in-place. * C is 80 and T is 3000. * * @return Return a pair containing: * - n_layer_cross_k: A 4-D tensor of shape * (n_text_layer, N, n_audio_ctx, n_text_state) * - n_layer_cross_v: A 4-D tensor of shape * (n_text_layer, N, n_audio_ctx, n_text_state) */ std::pair ForwardEncoder(Ort::Value features) const; /** Run the decoder model. * * @param tokens A int64 tensor of shape (N, num_words) * @param n_layer_self_k_cache A 4-D tensor of shape * (n_text_layer, N, n_text_ctx, n_text_state). * @param n_layer_self_v_cache A 4-D tensor of shape * (n_text_layer, N, n_text_ctx, n_text_state). * @param n_layer_cross_k A 4-D tensor of shape * (n_text_layer, N, n_audio_ctx, n_text_state). * @param n_layer_cross_v A 4-D tensor of shape * (n_text_layer, N, n_audio_ctx, n_text_state). * @param offset A int64 tensor of shape (N,) * * @return Return a tuple containing 6 tensors: * * - logits A 3-D tensor of shape (N, num_words, vocab_size) * - out_n_layer_self_k_cache Same shape as n_layer_self_k_cache * - out_n_layer_self_v_cache Same shape as n_layer_self_v_cache * - out_n_layer_cross_k Same as n_layer_cross_k * - out_n_layer_cross_v Same as n_layer_cross_v * - out_offset Same as offset */ std::tuple ForwardDecoder(Ort::Value tokens, Ort::Value n_layer_self_k_cache, Ort::Value n_layer_self_v_cache, Ort::Value n_layer_cross_k, Ort::Value n_layer_cross_v, Ort::Value offset) const; int32_t DetectLanguage() const; /** Return the initial self kv cache in a pair * - n_layer_self_k_cache A 4-D tensor of shape * (n_text_layer, N, n_audio_ctx, n_text_state). * - n_layer_self_v_cache A 4-D tensor of shape * (n_text_layer, N, n_audio_ctx, n_text_state). */ std::pair GetInitialSelfKVCache() const; const std::vector &GetInitialTokens() const; const std::vector &GetAllLanguageIDs() const; const std::unordered_map &GetLang2ID() const; const std::unordered_map &GetID2Lang() const; /** Return an allocator for allocating memory */ OrtAllocator *Allocator() const; int32_t NoTimeStampsToken() const; int32_t EOT() const; int32_t SOT() const; int32_t TextCtx() const; int32_t VocabSize() const; int32_t Translate() const; bool IsMultiLingual() const; private: class Impl; std::unique_ptr impl_; }; } // namespace sherpa_onnx #endif // SHERPA_ONNX_CSRC_OFFLINE_WHISPER_MODEL_H_