Add endpointing (#54)
This commit is contained in:
@@ -37,7 +37,9 @@ std::string OnlineRecognizerConfig::ToString() const {
|
||||
os << "OnlineRecognizerConfig(";
|
||||
os << "feat_config=" << feat_config.ToString() << ", ";
|
||||
os << "model_config=" << model_config.ToString() << ", ";
|
||||
os << "tokens=\"" << tokens << "\")";
|
||||
os << "tokens=\"" << tokens << "\", ";
|
||||
os << "endpoint_config=" << endpoint_config.ToString() << ", ";
|
||||
os << "enable_endpoint=" << (enable_endpoint ? "True" : "False") << ")";
|
||||
|
||||
return os.str();
|
||||
}
|
||||
@@ -47,7 +49,8 @@ class OnlineRecognizer::Impl {
|
||||
explicit Impl(const OnlineRecognizerConfig &config)
|
||||
: config_(config),
|
||||
model_(OnlineTransducerModel::Create(config.model_config)),
|
||||
sym_(config.tokens) {
|
||||
sym_(config.tokens),
|
||||
endpoint_(config_.endpoint_config) {
|
||||
decoder_ =
|
||||
std::make_unique<OnlineTransducerGreedySearchDecoder>(model_.get());
|
||||
}
|
||||
@@ -64,7 +67,7 @@ class OnlineRecognizer::Impl {
|
||||
s->NumFramesReady();
|
||||
}
|
||||
|
||||
void DecodeStreams(OnlineStream **ss, int32_t n) {
|
||||
void DecodeStreams(OnlineStream **ss, int32_t n) const {
|
||||
int32_t chunk_size = model_->ChunkSize();
|
||||
int32_t chunk_shift = model_->ChunkShift();
|
||||
|
||||
@@ -111,18 +114,44 @@ class OnlineRecognizer::Impl {
|
||||
}
|
||||
}
|
||||
|
||||
OnlineRecognizerResult GetResult(OnlineStream *s) {
|
||||
OnlineRecognizerResult GetResult(OnlineStream *s) const {
|
||||
OnlineTransducerDecoderResult decoder_result = s->GetResult();
|
||||
decoder_->StripLeadingBlanks(&decoder_result);
|
||||
|
||||
return Convert(decoder_result, sym_);
|
||||
}
|
||||
|
||||
bool IsEndpoint(OnlineStream *s) const {
|
||||
if (!config_.enable_endpoint) return false;
|
||||
int32_t num_processed_frames = s->GetNumProcessedFrames();
|
||||
|
||||
// frame shift is 10 milliseconds
|
||||
float frame_shift_in_seconds = 0.01;
|
||||
|
||||
// subsampling factor is 4
|
||||
int32_t trailing_silence_frames = s->GetResult().num_trailing_blanks * 4;
|
||||
|
||||
return endpoint_.IsEndpoint(num_processed_frames, trailing_silence_frames,
|
||||
frame_shift_in_seconds);
|
||||
}
|
||||
|
||||
void Reset(OnlineStream *s) const {
|
||||
// reset result and neural network model state,
|
||||
// but keep the feature extractor state
|
||||
|
||||
// reset result
|
||||
s->SetResult(decoder_->GetEmptyResult());
|
||||
|
||||
// reset neural network model state
|
||||
s->SetStates(model_->GetEncoderInitStates());
|
||||
}
|
||||
|
||||
private:
|
||||
OnlineRecognizerConfig config_;
|
||||
std::unique_ptr<OnlineTransducerModel> model_;
|
||||
std::unique_ptr<OnlineTransducerDecoder> decoder_;
|
||||
SymbolTable sym_;
|
||||
Endpoint endpoint_;
|
||||
};
|
||||
|
||||
OnlineRecognizer::OnlineRecognizer(const OnlineRecognizerConfig &config)
|
||||
@@ -137,12 +166,18 @@ bool OnlineRecognizer::IsReady(OnlineStream *s) const {
|
||||
return impl_->IsReady(s);
|
||||
}
|
||||
|
||||
void OnlineRecognizer::DecodeStreams(OnlineStream **ss, int32_t n) {
|
||||
void OnlineRecognizer::DecodeStreams(OnlineStream **ss, int32_t n) const {
|
||||
impl_->DecodeStreams(ss, n);
|
||||
}
|
||||
|
||||
OnlineRecognizerResult OnlineRecognizer::GetResult(OnlineStream *s) {
|
||||
OnlineRecognizerResult OnlineRecognizer::GetResult(OnlineStream *s) const {
|
||||
return impl_->GetResult(s);
|
||||
}
|
||||
|
||||
bool OnlineRecognizer::IsEndpoint(OnlineStream *s) const {
|
||||
return impl_->IsEndpoint(s);
|
||||
}
|
||||
|
||||
void OnlineRecognizer::Reset(OnlineStream *s) const { impl_->Reset(s); }
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
Reference in New Issue
Block a user