Refactor online recognizer (#250)
* Refactor the online recognizer to make it easier to support other streaming models. Note that this is a breaking change for the Python API: code that previously called `sherpa_onnx.OnlineRecognizer()` must now call `sherpa_onnx.OnlineRecognizer.from_transducer()` instead.
This commit is contained in:
@@ -15,6 +15,7 @@
|
||||
|
||||
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||
#include "sherpa-onnx/csrc/hypothesis.h"
|
||||
#include "sherpa-onnx/csrc/online-model-config.h"
|
||||
#include "sherpa-onnx/csrc/online-transducer-decoder.h"
|
||||
#include "sherpa-onnx/csrc/online-transducer-model-config.h"
|
||||
|
||||
@@ -27,11 +28,11 @@ class OnlineTransducerModel {
|
||||
virtual ~OnlineTransducerModel() = default;
|
||||
|
||||
static std::unique_ptr<OnlineTransducerModel> Create(
|
||||
const OnlineTransducerModelConfig &config);
|
||||
const OnlineModelConfig &config);
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
static std::unique_ptr<OnlineTransducerModel> Create(
|
||||
AAssetManager *mgr, const OnlineTransducerModelConfig &config);
|
||||
AAssetManager *mgr, const OnlineModelConfig &config);
|
||||
#endif
|
||||
|
||||
/** Stack a list of individual states into a batch.
|
||||
@@ -64,15 +65,15 @@ class OnlineTransducerModel {
|
||||
*
|
||||
* @param features A tensor of shape (N, T, C). It is changed in-place.
|
||||
* @param states Encoder state of the previous chunk. It is changed in-place.
|
||||
* @param processed_frames Processed frames before subsampling. It is a 1-D tensor with data type int64_t.
|
||||
* @param processed_frames Processed frames before subsampling. It is a 1-D
|
||||
* tensor with data type int64_t.
|
||||
*
|
||||
* @return Return a tuple containing:
|
||||
* - encoder_out, a tensor of shape (N, T', encoder_out_dim)
|
||||
* - next_states, the encoder states for the next chunk.
|
||||
*/
|
||||
virtual std::pair<Ort::Value, std::vector<Ort::Value>> RunEncoder(
|
||||
Ort::Value features,
|
||||
std::vector<Ort::Value> states,
|
||||
Ort::Value features, std::vector<Ort::Value> states,
|
||||
Ort::Value processed_frames) = 0; // NOLINT
|
||||
|
||||
/** Run the decoder network.
|
||||
|
||||
Reference in New Issue
Block a user