add modified beam search (#69)
This commit is contained in:
@@ -14,6 +14,8 @@
|
||||
#endif
|
||||
|
||||
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||
#include "sherpa-onnx/csrc/hypothesis.h"
|
||||
#include "sherpa-onnx/csrc/online-transducer-decoder.h"
|
||||
#include "sherpa-onnx/csrc/online-transducer-model-config.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
@@ -71,9 +73,6 @@ class OnlineTransducerModel {
|
||||
Ort::Value features,
|
||||
std::vector<Ort::Value> states) = 0; // NOLINT
|
||||
|
||||
virtual Ort::Value BuildDecoderInput(
|
||||
const std::vector<OnlineTransducerDecoderResult> &results) = 0;
|
||||
|
||||
/** Run the decoder network.
|
||||
*
|
||||
* Caution: We assume there are no recurrent connections in the decoder and
|
||||
@@ -125,7 +124,13 @@ class OnlineTransducerModel {
|
||||
virtual int32_t VocabSize() const = 0;
|
||||
|
||||
virtual int32_t SubsamplingFactor() const { return 4; }
|
||||
|
||||
virtual OrtAllocator *Allocator() = 0;
|
||||
|
||||
Ort::Value BuildDecoderInput(
|
||||
const std::vector<OnlineTransducerDecoderResult> &results);
|
||||
|
||||
Ort::Value BuildDecoderInput(const std::vector<Hypothesis> &hyps);
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
Reference in New Issue
Block a user