add endpointing for online websocket server (#294)
This commit is contained in:
@@ -26,7 +26,8 @@ namespace sherpa_onnx {
|
||||
static OnlineRecognizerResult Convert(const OnlineTransducerDecoderResult &src,
|
||||
const SymbolTable &sym_table,
|
||||
int32_t frame_shift_ms,
|
||||
int32_t subsampling_factor) {
|
||||
int32_t subsampling_factor,
|
||||
int32_t segment) {
|
||||
OnlineRecognizerResult r;
|
||||
r.tokens.reserve(src.tokens.size());
|
||||
r.timestamps.reserve(src.tokens.size());
|
||||
@@ -44,6 +45,8 @@ static OnlineRecognizerResult Convert(const OnlineTransducerDecoderResult &src,
|
||||
r.timestamps.push_back(time);
|
||||
}
|
||||
|
||||
r.segment = segment;
|
||||
|
||||
return r;
|
||||
}
|
||||
|
||||
@@ -192,7 +195,8 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
|
||||
// TODO(fangjun): Remember to change these constants if needed
|
||||
int32_t frame_shift_ms = 10;
|
||||
int32_t subsampling_factor = 4;
|
||||
return Convert(decoder_result, sym_, frame_shift_ms, subsampling_factor);
|
||||
return Convert(decoder_result, sym_, frame_shift_ms, subsampling_factor,
|
||||
s->GetCurrentSegment());
|
||||
}
|
||||
|
||||
bool IsEndpoint(OnlineStream *s) const override {
|
||||
@@ -213,6 +217,15 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
|
||||
}
|
||||
|
||||
void Reset(OnlineStream *s) const override {
|
||||
{
|
||||
// segment is incremented only when the last
|
||||
// result is not empty
|
||||
const auto &r = s->GetResult();
|
||||
if (!r.tokens.empty() && r.tokens.back() != 0) {
|
||||
s->GetCurrentSegment() += 1;
|
||||
}
|
||||
}
|
||||
|
||||
// we keep the decoder_out
|
||||
decoder_->UpdateDecoderOut(&s->GetResult());
|
||||
Ort::Value decoder_out = std::move(s->GetResult().decoder_out);
|
||||
|
||||
Reference in New Issue
Block a user