// sherpa-onnx/csrc/online-nemo-ctc-model.h // // Copyright (c) 2024 Xiaomi Corporation #ifndef SHERPA_ONNX_CSRC_ONLINE_NEMO_CTC_MODEL_H_ #define SHERPA_ONNX_CSRC_ONLINE_NEMO_CTC_MODEL_H_ #include #include #include #if __ANDROID_API__ >= 9 #include "android/asset_manager.h" #include "android/asset_manager_jni.h" #endif #include "onnxruntime_cxx_api.h" // NOLINT #include "sherpa-onnx/csrc/online-ctc-model.h" #include "sherpa-onnx/csrc/online-model-config.h" namespace sherpa_onnx { class OnlineNeMoCtcModel : public OnlineCtcModel { public: explicit OnlineNeMoCtcModel(const OnlineModelConfig &config); #if __ANDROID_API__ >= 9 OnlineNeMoCtcModel(AAssetManager *mgr, const OnlineModelConfig &config); #endif ~OnlineNeMoCtcModel() override; // A list of 3 tensors: // - cache_last_channel // - cache_last_time // - cache_last_channel_len std::vector GetInitStates() const override; std::vector StackStates( std::vector> states) const override; std::vector> UnStackStates( std::vector states) const override; /** * * @param x A 3-D tensor of shape (N, T, C). N has to be 1. * @param states It is from GetInitStates() or returned from this method. * * @return Return a list of tensors * - ans[0] contains log_probs, of shape (N, T, C) * - ans[1:] contains next_states */ std::vector Forward( Ort::Value x, std::vector states) const override; /** Return the vocabulary size of the model */ int32_t VocabSize() const override; /** Return an allocator for allocating memory */ OrtAllocator *Allocator() const override; // The model accepts this number of frames before subsampling as input int32_t ChunkLength() const override; // Similar to frame_shift in feature extractor, after processing // ChunkLength() frames, we advance by ChunkShift() frames // before we process the next chunk. int32_t ChunkShift() const override; bool SupportBatchProcessing() const override { return true; } private: class Impl; std::unique_ptr impl_; }; } // namespace sherpa_onnx #endif // SHERPA_ONNX_CSRC_ONLINE_NEMO_CTC_MODEL_H_