Support batch greedy search decoding (#30)

This commit is contained in:
Fangjun Kuang
2023-02-19 15:04:24 +08:00
committed by GitHub
parent ebc3b47fb8
commit 8acc059b3f
5 changed files with 181 additions and 68 deletions

View File

@@ -13,6 +13,8 @@
namespace sherpa_onnx {
class OnlineTransducerDecoderResult;
class OnlineTransducerModel {
public:
virtual ~OnlineTransducerModel() = default;
@@ -27,8 +29,8 @@ class OnlineTransducerModel {
* @param states states[i] contains the state for the i-th utterance.
* @return Return a single value representing the batched state.
*/
virtual Ort::Value StackStates(
const std::vector<Ort::Value> &states) const = 0;
virtual std::vector<Ort::Value> StackStates(
const std::vector<std::vector<Ort::Value>> &states) const = 0;
/** Unstack a batch state into a list of individual states.
*
@@ -37,7 +39,8 @@ class OnlineTransducerModel {
* @param states A batched state.
* @return ans[i] contains the state for the i-th utterance.
*/
virtual std::vector<Ort::Value> UnStackStates(Ort::Value states) const = 0;
virtual std::vector<std::vector<Ort::Value>> UnStackStates(
const std::vector<Ort::Value> &states) const = 0;
/** Get the initial encoder states.
*
@@ -58,7 +61,8 @@ class OnlineTransducerModel {
Ort::Value features,
std::vector<Ort::Value> &states) = 0; // NOLINT
virtual Ort::Value BuildDecoderInput(const std::vector<int64_t> &hyp) = 0;
virtual Ort::Value BuildDecoderInput(
const std::vector<OnlineTransducerDecoderResult> &results) = 0;
/** Run the decoder network.
*
@@ -111,6 +115,7 @@ class OnlineTransducerModel {
virtual int32_t VocabSize() const = 0;
virtual int32_t SubsamplingFactor() const { return 4; }
virtual OrtAllocator *Allocator() = 0;
};
} // namespace sherpa_onnx