add endpointing for online websocket server (#294)

This commit is contained in:
Fangjun Kuang
2023-08-31 14:41:04 +08:00
committed by GitHub
parent 2b0152d2a2
commit a0a747a0c0
4 changed files with 27 additions and 2 deletions

View File

@@ -26,7 +26,8 @@ namespace sherpa_onnx {
static OnlineRecognizerResult Convert(const OnlineTransducerDecoderResult &src, static OnlineRecognizerResult Convert(const OnlineTransducerDecoderResult &src,
const SymbolTable &sym_table, const SymbolTable &sym_table,
int32_t frame_shift_ms, int32_t frame_shift_ms,
int32_t subsampling_factor) { int32_t subsampling_factor,
int32_t segment) {
OnlineRecognizerResult r; OnlineRecognizerResult r;
r.tokens.reserve(src.tokens.size()); r.tokens.reserve(src.tokens.size());
r.timestamps.reserve(src.tokens.size()); r.timestamps.reserve(src.tokens.size());
@@ -44,6 +45,8 @@ static OnlineRecognizerResult Convert(const OnlineTransducerDecoderResult &src,
r.timestamps.push_back(time); r.timestamps.push_back(time);
} }
r.segment = segment;
return r; return r;
} }
@@ -192,7 +195,8 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
// TODO(fangjun): Remember to change these constants if needed // TODO(fangjun): Remember to change these constants if needed
int32_t frame_shift_ms = 10; int32_t frame_shift_ms = 10;
int32_t subsampling_factor = 4; int32_t subsampling_factor = 4;
return Convert(decoder_result, sym_, frame_shift_ms, subsampling_factor); return Convert(decoder_result, sym_, frame_shift_ms, subsampling_factor,
s->GetCurrentSegment());
} }
bool IsEndpoint(OnlineStream *s) const override { bool IsEndpoint(OnlineStream *s) const override {
@@ -213,6 +217,15 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
} }
void Reset(OnlineStream *s) const override { void Reset(OnlineStream *s) const override {
{
// segment is incremented only when the last
// result is not empty
const auto &r = s->GetResult();
if (!r.tokens.empty() && r.tokens.back() != 0) {
s->GetCurrentSegment() += 1;
}
}
// we keep the decoder_out // we keep the decoder_out
decoder_->UpdateDecoderOut(&s->GetResult()); decoder_->UpdateDecoderOut(&s->GetResult());
Ort::Value decoder_out = std::move(s->GetResult().decoder_out); Ort::Value decoder_out = std::move(s->GetResult().decoder_out);

View File

@@ -43,6 +43,8 @@ class OnlineStream::Impl {
int32_t &GetNumProcessedFrames() { return num_processed_frames_; } int32_t &GetNumProcessedFrames() { return num_processed_frames_; }
int32_t &GetCurrentSegment() { return segment_; }
void SetResult(const OnlineTransducerDecoderResult &r) { result_ = r; } void SetResult(const OnlineTransducerDecoderResult &r) { result_ = r; }
OnlineTransducerDecoderResult &GetResult() { return result_; } OnlineTransducerDecoderResult &GetResult() { return result_; }
@@ -83,6 +85,7 @@ class OnlineStream::Impl {
ContextGraphPtr context_graph_; ContextGraphPtr context_graph_;
int32_t num_processed_frames_ = 0; // before subsampling int32_t num_processed_frames_ = 0; // before subsampling
int32_t start_frame_index_ = 0; // never reset int32_t start_frame_index_ = 0; // never reset
int32_t segment_ = 0;
OnlineTransducerDecoderResult result_; OnlineTransducerDecoderResult result_;
std::vector<Ort::Value> states_; std::vector<Ort::Value> states_;
std::vector<float> paraformer_feat_cache_; std::vector<float> paraformer_feat_cache_;
@@ -123,6 +126,10 @@ int32_t &OnlineStream::GetNumProcessedFrames() {
return impl_->GetNumProcessedFrames(); return impl_->GetNumProcessedFrames();
} }
int32_t &OnlineStream::GetCurrentSegment() {
return impl_->GetCurrentSegment();
}
void OnlineStream::SetResult(const OnlineTransducerDecoderResult &r) { void OnlineStream::SetResult(const OnlineTransducerDecoderResult &r) {
impl_->SetResult(r); impl_->SetResult(r);
} }

View File

@@ -68,6 +68,8 @@ class OnlineStream {
// The returned reference is valid as long as this object is alive. // The returned reference is valid as long as this object is alive.
int32_t &GetNumProcessedFrames(); int32_t &GetNumProcessedFrames();
int32_t &GetCurrentSegment();
void SetResult(const OnlineTransducerDecoderResult &r); void SetResult(const OnlineTransducerDecoderResult &r);
OnlineTransducerDecoderResult &GetResult(); OnlineTransducerDecoderResult &GetResult();

View File

@@ -194,6 +194,9 @@ void OnlineWebsocketDecoder::Decode() {
for (auto c : c_vec) { for (auto c : c_vec) {
auto result = recognizer_->GetResult(c->s.get()); auto result = recognizer_->GetResult(c->s.get());
if (recognizer_->IsEndpoint(c->s.get())) {
recognizer_->Reset(c->s.get());
}
asio::post(server_->GetConnectionContext(), asio::post(server_->GetConnectionContext(),
[this, hdl = c->hdl, str = result.AsJsonString()]() { [this, hdl = c->hdl, str = result.AsJsonString()]() {