Add endpointing (#54)

This commit is contained in:
Fangjun Kuang
2023-02-22 15:35:55 +08:00
committed by GitHub
parent 1c6f79f096
commit 124384369a
23 changed files with 2190 additions and 21 deletions

View File

@@ -37,7 +37,9 @@ std::string OnlineRecognizerConfig::ToString() const {
os << "OnlineRecognizerConfig(";
os << "feat_config=" << feat_config.ToString() << ", ";
os << "model_config=" << model_config.ToString() << ", ";
os << "tokens=\"" << tokens << "\")";
os << "tokens=\"" << tokens << "\", ";
os << "endpoint_config=" << endpoint_config.ToString() << ", ";
os << "enable_endpoint=" << (enable_endpoint ? "True" : "False") << ")";
return os.str();
}
@@ -47,7 +49,8 @@ class OnlineRecognizer::Impl {
explicit Impl(const OnlineRecognizerConfig &config)
: config_(config),
model_(OnlineTransducerModel::Create(config.model_config)),
sym_(config.tokens) {
sym_(config.tokens),
endpoint_(config_.endpoint_config) {
decoder_ =
std::make_unique<OnlineTransducerGreedySearchDecoder>(model_.get());
}
@@ -64,7 +67,7 @@ class OnlineRecognizer::Impl {
s->NumFramesReady();
}
void DecodeStreams(OnlineStream **ss, int32_t n) {
void DecodeStreams(OnlineStream **ss, int32_t n) const {
int32_t chunk_size = model_->ChunkSize();
int32_t chunk_shift = model_->ChunkShift();
@@ -111,18 +114,44 @@ class OnlineRecognizer::Impl {
}
}
OnlineRecognizerResult GetResult(OnlineStream *s) {
OnlineRecognizerResult GetResult(OnlineStream *s) const {
OnlineTransducerDecoderResult decoder_result = s->GetResult();
decoder_->StripLeadingBlanks(&decoder_result);
return Convert(decoder_result, sym_);
}
bool IsEndpoint(OnlineStream *s) const {
if (!config_.enable_endpoint) return false;
int32_t num_processed_frames = s->GetNumProcessedFrames();
// frame shift is 10 milliseconds
float frame_shift_in_seconds = 0.01;
// subsampling factor is 4
int32_t trailing_silence_frames = s->GetResult().num_trailing_blanks * 4;
return endpoint_.IsEndpoint(num_processed_frames, trailing_silence_frames,
frame_shift_in_seconds);
}
void Reset(OnlineStream *s) const {
// reset result and neural network model state,
// but keep the feature extractor state
// reset result
s->SetResult(decoder_->GetEmptyResult());
// reset neural network model state
s->SetStates(model_->GetEncoderInitStates());
}
private:
OnlineRecognizerConfig config_;
std::unique_ptr<OnlineTransducerModel> model_;
std::unique_ptr<OnlineTransducerDecoder> decoder_;
SymbolTable sym_;
Endpoint endpoint_;
};
OnlineRecognizer::OnlineRecognizer(const OnlineRecognizerConfig &config)
@@ -137,12 +166,18 @@ bool OnlineRecognizer::IsReady(OnlineStream *s) const {
return impl_->IsReady(s);
}
void OnlineRecognizer::DecodeStreams(OnlineStream **ss, int32_t n) {
void OnlineRecognizer::DecodeStreams(OnlineStream **ss, int32_t n) const {
impl_->DecodeStreams(ss, n);
}
OnlineRecognizerResult OnlineRecognizer::GetResult(OnlineStream *s) {
OnlineRecognizerResult OnlineRecognizer::GetResult(OnlineStream *s) const {
return impl_->GetResult(s);
}
bool OnlineRecognizer::IsEndpoint(OnlineStream *s) const {
return impl_->IsEndpoint(s);
}
void OnlineRecognizer::Reset(OnlineStream *s) const { impl_->Reset(s); }
} // namespace sherpa_onnx