This repository has been archived on 2025-08-26. You can view files and clone it, but cannot push or open issues or pull requests.
Files
enginex_bi_series-sherpa-onnx/sherpa-onnx/csrc/online-zipformer-transducer-model.h
Jingzhao Ou 0992063de8 Stack and streaming conformer support (#141)
* added csrc/stack.cc

* stack: added checks

* added copyright info

* passed cpp style checks

* formatted code

* added partial support for the streaming conformer model (not verified)

* code lint

* made more progress with streaming conformer support (not working yet)

* passed style check

* changes as suggested by @csukuangfj

* added some debug info

* fixed style check

* Use Cat to replace Stack

* remove debug statements

---------

Co-authored-by: Jingzhao Ou (jou2019) <jou2019@cisco.com>
Co-authored-by: Fangjun Kuang <csukuangfj@gmail.com>
2023-05-11 14:30:39 +08:00

108 lines
3.3 KiB
C++

// sherpa-onnx/csrc/online-zipformer-transducer-model.h
//
// Copyright (c) 2023 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_ONLINE_ZIPFORMER_TRANSDUCER_MODEL_H_
#define SHERPA_ONNX_CSRC_ONLINE_ZIPFORMER_TRANSDUCER_MODEL_H_
#include <cstddef>
#include <cstdint>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif

#include "onnxruntime_cxx_api.h" // NOLINT
#include "sherpa-onnx/csrc/online-transducer-model-config.h"
#include "sherpa-onnx/csrc/online-transducer-model.h"
namespace sherpa_onnx {
class OnlineZipformerTransducerModel : public OnlineTransducerModel {
public:
explicit OnlineZipformerTransducerModel(
const OnlineTransducerModelConfig &config);
#if __ANDROID_API__ >= 9
OnlineZipformerTransducerModel(AAssetManager *mgr,
const OnlineTransducerModelConfig &config);
#endif
std::vector<Ort::Value> StackStates(
const std::vector<std::vector<Ort::Value>> &states) const override;
std::vector<std::vector<Ort::Value>> UnStackStates(
const std::vector<Ort::Value> &states) const override;
std::vector<Ort::Value> GetEncoderInitStates() override;
std::pair<Ort::Value, std::vector<Ort::Value>> RunEncoder(
Ort::Value features, std::vector<Ort::Value> states,
Ort::Value processed_frames) override;
Ort::Value RunDecoder(Ort::Value decoder_input) override;
Ort::Value RunJoiner(Ort::Value encoder_out, Ort::Value decoder_out) override;
int32_t ContextSize() const override { return context_size_; }
int32_t ChunkSize() const override { return T_; }
int32_t ChunkShift() const override { return decode_chunk_len_; }
int32_t VocabSize() const override { return vocab_size_; }
OrtAllocator *Allocator() override { return allocator_; }
private:
void InitEncoder(void *model_data, size_t model_data_length);
void InitDecoder(void *model_data, size_t model_data_length);
void InitJoiner(void *model_data, size_t model_data_length);
private:
Ort::Env env_;
Ort::SessionOptions sess_opts_;
Ort::AllocatorWithDefaultOptions allocator_;
std::unique_ptr<Ort::Session> encoder_sess_;
std::unique_ptr<Ort::Session> decoder_sess_;
std::unique_ptr<Ort::Session> joiner_sess_;
std::vector<std::string> encoder_input_names_;
std::vector<const char *> encoder_input_names_ptr_;
std::vector<std::string> encoder_output_names_;
std::vector<const char *> encoder_output_names_ptr_;
std::vector<std::string> decoder_input_names_;
std::vector<const char *> decoder_input_names_ptr_;
std::vector<std::string> decoder_output_names_;
std::vector<const char *> decoder_output_names_ptr_;
std::vector<std::string> joiner_input_names_;
std::vector<const char *> joiner_input_names_ptr_;
std::vector<std::string> joiner_output_names_;
std::vector<const char *> joiner_output_names_ptr_;
OnlineTransducerModelConfig config_;
std::vector<int32_t> encoder_dims_;
std::vector<int32_t> attention_dims_;
std::vector<int32_t> num_encoder_layers_;
std::vector<int32_t> cnn_module_kernels_;
std::vector<int32_t> left_context_len_;
int32_t T_ = 0;
int32_t decode_chunk_len_ = 0;
int32_t context_size_ = 0;
int32_t vocab_size_ = 0;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_ONLINE_ZIPFORMER_TRANSDUCER_MODEL_H_