Add endpointing (#54)

This commit is contained in:
Fangjun Kuang
2023-02-22 15:35:55 +08:00
committed by GitHub
parent 1c6f79f096
commit 124384369a
23 changed files with 2190 additions and 21 deletions

View File

@@ -2,12 +2,13 @@ from pathlib import Path
from typing import List
from _sherpa_onnx import (
EndpointConfig,
FeatureExtractorConfig,
OnlineRecognizer as _Recognizer,
OnlineRecognizerConfig,
OnlineStream,
OnlineTransducerModelConfig,
FeatureExtractorConfig,
OnlineRecognizerConfig,
)
from _sherpa_onnx import OnlineRecognizer as _Recognizer
def _assert_file_exists(f: str):
@@ -26,6 +27,10 @@ class OnlineRecognizer(object):
num_threads: int = 4,
sample_rate: float = 16000,
feature_dim: int = 80,
enable_endpoint_detection: bool = False,
rule1_min_trailing_silence: int = 2.4,
rule2_min_trailing_silence: int = 1.2,
rule3_min_utterance_length: int = 20,
):
"""
Please refer to
@@ -52,6 +57,22 @@ class OnlineRecognizer(object):
Sample rate of the training data used to train the model.
feature_dim:
Dimension of the feature used to train the model.
enable_endpoint_detection:
True to enable endpoint detection. False to disable endpoint
detection.
rule1_min_trailing_silence:
Used only when enable_endpoint_detection is True. If the duration
of trailing silence in seconds is larger than this value, we assume
an endpoint is detected.
rule2_min_trailing_silence:
Used only when enable_endpoint_detection is True. If we have decoded
something that is nonsilence and if the duration of trailing silence
in seconds is larger than this value, we assume an endpoint is
detected.
rule3_min_utterance_length:
Used only when enable_endpoint_detection is True. If the utterance
length in seconds is larger than this value, we assume an endpoint
is detected.
"""
_assert_file_exists(tokens)
_assert_file_exists(encoder)
@@ -72,10 +93,18 @@ class OnlineRecognizer(object):
feature_dim=feature_dim,
)
endpoint_config = EndpointConfig(
rule1_min_trailing_silence=rule1_min_trailing_silence,
rule2_min_trailing_silence=rule2_min_trailing_silence,
rule3_min_utterance_length=rule3_min_utterance_length,
)
recognizer_config = OnlineRecognizerConfig(
feat_config=feat_config,
model_config=model_config,
tokens=tokens,
endpoint_config=endpoint_config,
enable_endpoint=enable_endpoint_detection,
)
self.recognizer = _Recognizer(recognizer_config)
@@ -93,4 +122,10 @@ class OnlineRecognizer(object):
return self.recognizer.is_ready(s)
def get_result(self, s: OnlineStream) -> str:
return self.recognizer.get_result(s).text
return self.recognizer.get_result(s).text.strip()
def is_endpoint(self, s: OnlineStream) -> bool:
return self.recognizer.is_endpoint(s)
def reset(self, s: OnlineStream) -> bool:
return self.recognizer.reset(s)