diff --git a/sherpa-onnx/csrc/online-recognizer-transducer-impl.h b/sherpa-onnx/csrc/online-recognizer-transducer-impl.h index 625d02b1..27f58687 100644 --- a/sherpa-onnx/csrc/online-recognizer-transducer-impl.h +++ b/sherpa-onnx/csrc/online-recognizer-transducer-impl.h @@ -26,7 +26,8 @@ namespace sherpa_onnx { static OnlineRecognizerResult Convert(const OnlineTransducerDecoderResult &src, const SymbolTable &sym_table, int32_t frame_shift_ms, - int32_t subsampling_factor) { + int32_t subsampling_factor, + int32_t segment) { OnlineRecognizerResult r; r.tokens.reserve(src.tokens.size()); r.timestamps.reserve(src.tokens.size()); @@ -44,6 +45,8 @@ static OnlineRecognizerResult Convert(const OnlineTransducerDecoderResult &src, r.timestamps.push_back(time); } + r.segment = segment; + return r; } @@ -192,7 +195,8 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { // TODO(fangjun): Remember to change these constants if needed int32_t frame_shift_ms = 10; int32_t subsampling_factor = 4; - return Convert(decoder_result, sym_, frame_shift_ms, subsampling_factor); + return Convert(decoder_result, sym_, frame_shift_ms, subsampling_factor, + s->GetCurrentSegment()); } bool IsEndpoint(OnlineStream *s) const override { @@ -213,6 +217,15 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { } void Reset(OnlineStream *s) const override { + { + // segment is incremented only when the last + // result is not empty + const auto &r = s->GetResult(); + if (!r.tokens.empty() && r.tokens.back() != 0) { + s->GetCurrentSegment() += 1; + } + } + // we keep the decoder_out decoder_->UpdateDecoderOut(&s->GetResult()); Ort::Value decoder_out = std::move(s->GetResult().decoder_out); diff --git a/sherpa-onnx/csrc/online-stream.cc b/sherpa-onnx/csrc/online-stream.cc index 8960ed13..39dfd796 100644 --- a/sherpa-onnx/csrc/online-stream.cc +++ b/sherpa-onnx/csrc/online-stream.cc @@ -43,6 +43,8 @@ class OnlineStream::Impl { int32_t &GetNumProcessedFrames() { return num_processed_frames_; } + int32_t &GetCurrentSegment() { return segment_; } + void SetResult(const OnlineTransducerDecoderResult &r) { result_ = r; } OnlineTransducerDecoderResult &GetResult() { return result_; } @@ -83,6 +85,7 @@ class OnlineStream::Impl { ContextGraphPtr context_graph_; int32_t num_processed_frames_ = 0; // before subsampling int32_t start_frame_index_ = 0; // never reset + int32_t segment_ = 0; OnlineTransducerDecoderResult result_; std::vector states_; std::vector paraformer_feat_cache_; @@ -123,6 +126,10 @@ int32_t &OnlineStream::GetNumProcessedFrames() { return impl_->GetNumProcessedFrames(); } +int32_t &OnlineStream::GetCurrentSegment() { + return impl_->GetCurrentSegment(); +} + void OnlineStream::SetResult(const OnlineTransducerDecoderResult &r) { impl_->SetResult(r); } diff --git a/sherpa-onnx/csrc/online-stream.h b/sherpa-onnx/csrc/online-stream.h index ae920c1d..6b7a96c4 100644 --- a/sherpa-onnx/csrc/online-stream.h +++ b/sherpa-onnx/csrc/online-stream.h @@ -68,6 +68,8 @@ class OnlineStream { // The returned reference is valid as long as this object is alive. int32_t &GetNumProcessedFrames(); + int32_t &GetCurrentSegment(); + void SetResult(const OnlineTransducerDecoderResult &r); OnlineTransducerDecoderResult &GetResult(); diff --git a/sherpa-onnx/csrc/online-websocket-server-impl.cc b/sherpa-onnx/csrc/online-websocket-server-impl.cc index a62bef25..9c265de8 100644 --- a/sherpa-onnx/csrc/online-websocket-server-impl.cc +++ b/sherpa-onnx/csrc/online-websocket-server-impl.cc @@ -194,6 +194,9 @@ void OnlineWebsocketDecoder::Decode() { for (auto c : c_vec) { auto result = recognizer_->GetResult(c->s.get()); + if (recognizer_->IsEndpoint(c->s.get())) { + recognizer_->Reset(c->s.get()); + } asio::post(server_->GetConnectionContext(), [this, hdl = c->hdl, str = result.AsJsonString()]() {