Support contextual-biasing for streaming model (#184)
* Support contextual-biasing for streaming model * The whole pipeline runs normally * Fix comments
This commit is contained in:
@@ -9,6 +9,7 @@
|
||||
|
||||
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||
#include "sherpa-onnx/csrc/hypothesis.h"
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
@@ -45,6 +46,7 @@ struct OnlineTransducerDecoderResult {
|
||||
OnlineTransducerDecoderResult &&other);
|
||||
};
|
||||
|
||||
class OnlineStream;
|
||||
class OnlineTransducerDecoder {
|
||||
public:
|
||||
virtual ~OnlineTransducerDecoder() = default;
|
||||
@@ -76,6 +78,26 @@ class OnlineTransducerDecoder {
|
||||
virtual void Decode(Ort::Value encoder_out,
|
||||
std::vector<OnlineTransducerDecoderResult> *result) = 0;
|
||||
|
||||
/** Run transducer beam search given the output from the encoder model.
|
||||
*
|
||||
* Note: Currently this interface is for contextual-biasing feature which
|
||||
* needs a ContextGraph owned by the OnlineStream.
|
||||
*
|
||||
* @param encoder_out A 3-D tensor of shape (N, T, joiner_dim)
|
||||
* @param ss A list of OnlineStreams.
|
||||
* @param result It is modified in-place.
|
||||
*
|
||||
* @note There is no need to pass encoder_out_length here since for the
|
||||
* online decoding case, each utterance has the same number of frames
|
||||
* and there are no paddings.
|
||||
*/
|
||||
virtual void Decode(Ort::Value encoder_out, OnlineStream **ss,
|
||||
std::vector<OnlineTransducerDecoderResult> *result) {
|
||||
SHERPA_ONNX_LOGE(
|
||||
"This interface is for OnlineTransducerModifiedBeamSearchDecoder.");
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
// used for endpointing. We need to keep decoder_out after reset
|
||||
virtual void UpdateDecoderOut(OnlineTransducerDecoderResult *result) {}
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user