Refactor online recognizer (#250)

* Refactor online recognizer.

Make it easier to support other streaming models.

Note that it is a breaking change for the Python API.
`sherpa_onnx.OnlineRecognizer()` used before should be
replaced by `sherpa_onnx.OnlineRecognizer.from_transducer()`.
This commit is contained in:
Fangjun Kuang
2023-08-09 20:27:31 +08:00
committed by GitHub
parent 6061318e3f
commit 79c2ce5dd4
40 changed files with 670 additions and 480 deletions

View File

@@ -15,19 +15,18 @@
#endif
#include "onnxruntime_cxx_api.h" // NOLINT
#include "sherpa-onnx/csrc/online-transducer-model-config.h"
#include "sherpa-onnx/csrc/online-model-config.h"
#include "sherpa-onnx/csrc/online-transducer-model.h"
namespace sherpa_onnx {
class OnlineZipformer2TransducerModel : public OnlineTransducerModel {
public:
explicit OnlineZipformer2TransducerModel(
const OnlineTransducerModelConfig &config);
explicit OnlineZipformer2TransducerModel(const OnlineModelConfig &config);
#if __ANDROID_API__ >= 9
OnlineZipformer2TransducerModel(AAssetManager *mgr,
const OnlineTransducerModelConfig &config);
const OnlineModelConfig &config);
#endif
std::vector<Ort::Value> StackStates(
@@ -87,7 +86,7 @@ class OnlineZipformer2TransducerModel : public OnlineTransducerModel {
std::vector<std::string> joiner_output_names_;
std::vector<const char *> joiner_output_names_ptr_;
OnlineTransducerModelConfig config_;
OnlineModelConfig config_;
std::vector<int32_t> encoder_dims_;
std::vector<int32_t> query_head_dims_;