// sherpa-onnx/csrc/online-zipformer2-transducer-model.h // // Copyright (c) 2023 Xiaomi Corporation #ifndef SHERPA_ONNX_CSRC_ONLINE_ZIPFORMER2_TRANSDUCER_MODEL_H_ #define SHERPA_ONNX_CSRC_ONLINE_ZIPFORMER2_TRANSDUCER_MODEL_H_ #include #include #include #include #if __ANDROID_API__ >= 9 #include "android/asset_manager.h" #include "android/asset_manager_jni.h" #endif #include "onnxruntime_cxx_api.h" // NOLINT #include "sherpa-onnx/csrc/online-model-config.h" #include "sherpa-onnx/csrc/online-transducer-model.h" namespace sherpa_onnx { class OnlineZipformer2TransducerModel : public OnlineTransducerModel { public: explicit OnlineZipformer2TransducerModel(const OnlineModelConfig &config); #if __ANDROID_API__ >= 9 OnlineZipformer2TransducerModel(AAssetManager *mgr, const OnlineModelConfig &config); #endif std::vector StackStates( const std::vector> &states) const override; std::vector> UnStackStates( const std::vector &states) const override; std::vector GetEncoderInitStates() override; std::pair> RunEncoder( Ort::Value features, std::vector states, Ort::Value processed_frames) override; Ort::Value RunDecoder(Ort::Value decoder_input) override; Ort::Value RunJoiner(Ort::Value encoder_out, Ort::Value decoder_out) override; int32_t ContextSize() const override { return context_size_; } int32_t ChunkSize() const override { return T_; } int32_t ChunkShift() const override { return decode_chunk_len_; } int32_t VocabSize() const override { return vocab_size_; } OrtAllocator *Allocator() override { return allocator_; } private: void InitEncoder(void *model_data, size_t model_data_length); void InitDecoder(void *model_data, size_t model_data_length); void InitJoiner(void *model_data, size_t model_data_length); private: Ort::Env env_; Ort::SessionOptions sess_opts_; Ort::AllocatorWithDefaultOptions allocator_; std::unique_ptr encoder_sess_; std::unique_ptr decoder_sess_; std::unique_ptr joiner_sess_; std::vector encoder_input_names_; std::vector encoder_input_names_ptr_; std::vector encoder_output_names_; std::vector encoder_output_names_ptr_; std::vector decoder_input_names_; std::vector decoder_input_names_ptr_; std::vector decoder_output_names_; std::vector decoder_output_names_ptr_; std::vector joiner_input_names_; std::vector joiner_input_names_ptr_; std::vector joiner_output_names_; std::vector joiner_output_names_ptr_; OnlineModelConfig config_; std::vector encoder_dims_; std::vector query_head_dims_; std::vector value_head_dims_; std::vector num_heads_; std::vector num_encoder_layers_; std::vector cnn_module_kernels_; std::vector left_context_len_; int32_t T_ = 0; int32_t decode_chunk_len_ = 0; int32_t context_size_ = 0; int32_t vocab_size_ = 0; }; } // namespace sherpa_onnx #endif // SHERPA_ONNX_CSRC_ONLINE_ZIPFORMER2_TRANSDUCER_MODEL_H_