add endpointing for online websocket server (#294)

This commit is contained in:
Fangjun Kuang
2023-08-31 14:41:04 +08:00
committed by GitHub
parent 2b0152d2a2
commit a0a747a0c0
4 changed files with 27 additions and 2 deletions

View File

@@ -26,7 +26,8 @@ namespace sherpa_onnx {
static OnlineRecognizerResult Convert(const OnlineTransducerDecoderResult &src,
const SymbolTable &sym_table,
int32_t frame_shift_ms,
int32_t subsampling_factor) {
int32_t subsampling_factor,
int32_t segment) {
OnlineRecognizerResult r;
r.tokens.reserve(src.tokens.size());
r.timestamps.reserve(src.tokens.size());
@@ -44,6 +45,8 @@ static OnlineRecognizerResult Convert(const OnlineTransducerDecoderResult &src,
r.timestamps.push_back(time);
}
r.segment = segment;
return r;
}
@@ -192,7 +195,8 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
// TODO(fangjun): Remember to change these constants if needed
int32_t frame_shift_ms = 10;
int32_t subsampling_factor = 4;
return Convert(decoder_result, sym_, frame_shift_ms, subsampling_factor);
return Convert(decoder_result, sym_, frame_shift_ms, subsampling_factor,
s->GetCurrentSegment());
}
bool IsEndpoint(OnlineStream *s) const override {
@@ -213,6 +217,15 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
}
void Reset(OnlineStream *s) const override {
{
// segment is incremented only when the last
// result is not empty
const auto &r = s->GetResult();
if (!r.tokens.empty() && r.tokens.back() != 0) {
s->GetCurrentSegment() += 1;
}
}
// we keep the decoder_out
decoder_->UpdateDecoderOut(&s->GetResult());
Ort::Value decoder_out = std::move(s->GetResult().decoder_out);