Refactor online recognizer (#250)

* Refactor online recognizer.

Make it easier to support other streaming models.

Note that this is a breaking change for the Python API:
code that previously called `sherpa_onnx.OnlineRecognizer()` must now
call `sherpa_onnx.OnlineRecognizer.from_transducer()` instead.
This commit is contained in:
Fangjun Kuang
2023-08-09 20:27:31 +08:00
committed by GitHub
parent 6061318e3f
commit 79c2ce5dd4
40 changed files with 670 additions and 480 deletions

View File

@@ -15,6 +15,7 @@
#include "onnxruntime_cxx_api.h" // NOLINT
#include "sherpa-onnx/csrc/hypothesis.h"
#include "sherpa-onnx/csrc/online-model-config.h"
#include "sherpa-onnx/csrc/online-transducer-decoder.h"
#include "sherpa-onnx/csrc/online-transducer-model-config.h"
@@ -27,11 +28,11 @@ class OnlineTransducerModel {
virtual ~OnlineTransducerModel() = default;
static std::unique_ptr<OnlineTransducerModel> Create(
const OnlineTransducerModelConfig &config);
const OnlineModelConfig &config);
#if __ANDROID_API__ >= 9
static std::unique_ptr<OnlineTransducerModel> Create(
AAssetManager *mgr, const OnlineTransducerModelConfig &config);
AAssetManager *mgr, const OnlineModelConfig &config);
#endif
/** Stack a list of individual states into a batch.
@@ -64,15 +65,15 @@ class OnlineTransducerModel {
*
* @param features A tensor of shape (N, T, C). It is changed in-place.
* @param states Encoder state of the previous chunk. It is changed in-place.
* @param processed_frames Processed frames before subsampling. It is a 1-D tensor with data type int64_t.
* @param processed_frames Processed frames before subsampling. It is a 1-D
* tensor with data type int64_t.
*
* @return Return a pair containing:
* - encoder_out, a tensor of shape (N, T', encoder_out_dim)
* - next_states, the encoder states for the next chunk.
*/
virtual std::pair<Ort::Value, std::vector<Ort::Value>> RunEncoder(
Ort::Value features,
std::vector<Ort::Value> states,
Ort::Value features, std::vector<Ort::Value> states,
Ort::Value processed_frames) = 0; // NOLINT
/** Run the decoder network.