Support batch greedy search decoding (#30)
This commit is contained in:
@@ -13,6 +13,8 @@
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
class OnlineTransducerDecoderResult;
|
||||
|
||||
class OnlineTransducerModel {
|
||||
public:
|
||||
virtual ~OnlineTransducerModel() = default;
|
||||
@@ -27,8 +29,8 @@ class OnlineTransducerModel {
|
||||
* @param states states[i] contains the state for the i-th utterance.
|
||||
* @return Return a single value representing the batched state.
|
||||
*/
|
||||
virtual Ort::Value StackStates(
|
||||
const std::vector<Ort::Value> &states) const = 0;
|
||||
virtual std::vector<Ort::Value> StackStates(
|
||||
const std::vector<std::vector<Ort::Value>> &states) const = 0;
|
||||
|
||||
/** Unstack a batch state into a list of individual states.
|
||||
*
|
||||
@@ -37,7 +39,8 @@ class OnlineTransducerModel {
|
||||
* @param states A batched state.
|
||||
* @return ans[i] contains the state for the i-th utterance.
|
||||
*/
|
||||
virtual std::vector<Ort::Value> UnStackStates(Ort::Value states) const = 0;
|
||||
virtual std::vector<std::vector<Ort::Value>> UnStackStates(
|
||||
const std::vector<Ort::Value> &states) const = 0;
|
||||
|
||||
/** Get the initial encoder states.
|
||||
*
|
||||
@@ -58,7 +61,8 @@ class OnlineTransducerModel {
|
||||
Ort::Value features,
|
||||
std::vector<Ort::Value> &states) = 0; // NOLINT
|
||||
|
||||
virtual Ort::Value BuildDecoderInput(const std::vector<int64_t> &hyp) = 0;
|
||||
virtual Ort::Value BuildDecoderInput(
|
||||
const std::vector<OnlineTransducerDecoderResult> &results) = 0;
|
||||
|
||||
/** Run the decoder network.
|
||||
*
|
||||
@@ -111,6 +115,7 @@ class OnlineTransducerModel {
|
||||
virtual int32_t VocabSize() const = 0;
|
||||
|
||||
virtual int32_t SubsamplingFactor() const { return 4; }
|
||||
virtual OrtAllocator *Allocator() = 0;
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
Reference in New Issue
Block a user