add endpointing for online websocket server (#294)
This commit is contained in:
@@ -26,7 +26,8 @@ namespace sherpa_onnx {
|
|||||||
static OnlineRecognizerResult Convert(const OnlineTransducerDecoderResult &src,
|
static OnlineRecognizerResult Convert(const OnlineTransducerDecoderResult &src,
|
||||||
const SymbolTable &sym_table,
|
const SymbolTable &sym_table,
|
||||||
int32_t frame_shift_ms,
|
int32_t frame_shift_ms,
|
||||||
int32_t subsampling_factor) {
|
int32_t subsampling_factor,
|
||||||
|
int32_t segment) {
|
||||||
OnlineRecognizerResult r;
|
OnlineRecognizerResult r;
|
||||||
r.tokens.reserve(src.tokens.size());
|
r.tokens.reserve(src.tokens.size());
|
||||||
r.timestamps.reserve(src.tokens.size());
|
r.timestamps.reserve(src.tokens.size());
|
||||||
@@ -44,6 +45,8 @@ static OnlineRecognizerResult Convert(const OnlineTransducerDecoderResult &src,
|
|||||||
r.timestamps.push_back(time);
|
r.timestamps.push_back(time);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
r.segment = segment;
|
||||||
|
|
||||||
return r;
|
return r;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -192,7 +195,8 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
|
|||||||
// TODO(fangjun): Remember to change these constants if needed
|
// TODO(fangjun): Remember to change these constants if needed
|
||||||
int32_t frame_shift_ms = 10;
|
int32_t frame_shift_ms = 10;
|
||||||
int32_t subsampling_factor = 4;
|
int32_t subsampling_factor = 4;
|
||||||
return Convert(decoder_result, sym_, frame_shift_ms, subsampling_factor);
|
return Convert(decoder_result, sym_, frame_shift_ms, subsampling_factor,
|
||||||
|
s->GetCurrentSegment());
|
||||||
}
|
}
|
||||||
|
|
||||||
bool IsEndpoint(OnlineStream *s) const override {
|
bool IsEndpoint(OnlineStream *s) const override {
|
||||||
@@ -213,6 +217,15 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void Reset(OnlineStream *s) const override {
|
void Reset(OnlineStream *s) const override {
|
||||||
|
{
|
||||||
|
// segment is incremented only when the last
|
||||||
|
// result is not empty
|
||||||
|
const auto &r = s->GetResult();
|
||||||
|
if (!r.tokens.empty() && r.tokens.back() != 0) {
|
||||||
|
s->GetCurrentSegment() += 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// we keep the decoder_out
|
// we keep the decoder_out
|
||||||
decoder_->UpdateDecoderOut(&s->GetResult());
|
decoder_->UpdateDecoderOut(&s->GetResult());
|
||||||
Ort::Value decoder_out = std::move(s->GetResult().decoder_out);
|
Ort::Value decoder_out = std::move(s->GetResult().decoder_out);
|
||||||
|
|||||||
@@ -43,6 +43,8 @@ class OnlineStream::Impl {
|
|||||||
|
|
||||||
int32_t &GetNumProcessedFrames() { return num_processed_frames_; }
|
int32_t &GetNumProcessedFrames() { return num_processed_frames_; }
|
||||||
|
|
||||||
|
int32_t &GetCurrentSegment() { return segment_; }
|
||||||
|
|
||||||
void SetResult(const OnlineTransducerDecoderResult &r) { result_ = r; }
|
void SetResult(const OnlineTransducerDecoderResult &r) { result_ = r; }
|
||||||
|
|
||||||
OnlineTransducerDecoderResult &GetResult() { return result_; }
|
OnlineTransducerDecoderResult &GetResult() { return result_; }
|
||||||
@@ -83,6 +85,7 @@ class OnlineStream::Impl {
|
|||||||
ContextGraphPtr context_graph_;
|
ContextGraphPtr context_graph_;
|
||||||
int32_t num_processed_frames_ = 0; // before subsampling
|
int32_t num_processed_frames_ = 0; // before subsampling
|
||||||
int32_t start_frame_index_ = 0; // never reset
|
int32_t start_frame_index_ = 0; // never reset
|
||||||
|
int32_t segment_ = 0;
|
||||||
OnlineTransducerDecoderResult result_;
|
OnlineTransducerDecoderResult result_;
|
||||||
std::vector<Ort::Value> states_;
|
std::vector<Ort::Value> states_;
|
||||||
std::vector<float> paraformer_feat_cache_;
|
std::vector<float> paraformer_feat_cache_;
|
||||||
@@ -123,6 +126,10 @@ int32_t &OnlineStream::GetNumProcessedFrames() {
|
|||||||
return impl_->GetNumProcessedFrames();
|
return impl_->GetNumProcessedFrames();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int32_t &OnlineStream::GetCurrentSegment() {
|
||||||
|
return impl_->GetCurrentSegment();
|
||||||
|
}
|
||||||
|
|
||||||
void OnlineStream::SetResult(const OnlineTransducerDecoderResult &r) {
|
void OnlineStream::SetResult(const OnlineTransducerDecoderResult &r) {
|
||||||
impl_->SetResult(r);
|
impl_->SetResult(r);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -68,6 +68,8 @@ class OnlineStream {
|
|||||||
// The returned reference is valid as long as this object is alive.
|
// The returned reference is valid as long as this object is alive.
|
||||||
int32_t &GetNumProcessedFrames();
|
int32_t &GetNumProcessedFrames();
|
||||||
|
|
||||||
|
int32_t &GetCurrentSegment();
|
||||||
|
|
||||||
void SetResult(const OnlineTransducerDecoderResult &r);
|
void SetResult(const OnlineTransducerDecoderResult &r);
|
||||||
OnlineTransducerDecoderResult &GetResult();
|
OnlineTransducerDecoderResult &GetResult();
|
||||||
|
|
||||||
|
|||||||
@@ -194,6 +194,9 @@ void OnlineWebsocketDecoder::Decode() {
|
|||||||
|
|
||||||
for (auto c : c_vec) {
|
for (auto c : c_vec) {
|
||||||
auto result = recognizer_->GetResult(c->s.get());
|
auto result = recognizer_->GetResult(c->s.get());
|
||||||
|
if (recognizer_->IsEndpoint(c->s.get())) {
|
||||||
|
recognizer_->Reset(c->s.get());
|
||||||
|
}
|
||||||
|
|
||||||
asio::post(server_->GetConnectionContext(),
|
asio::post(server_->GetConnectionContext(),
|
||||||
[this, hdl = c->hdl, str = result.AsJsonString()]() {
|
[this, hdl = c->hdl, str = result.AsJsonString()]() {
|
||||||
|
|||||||
Reference in New Issue
Block a user