Add endpointing (#54)
This commit is contained in:
@@ -2,12 +2,13 @@ from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
from _sherpa_onnx import (
|
||||
EndpointConfig,
|
||||
FeatureExtractorConfig,
|
||||
OnlineRecognizer as _Recognizer,
|
||||
OnlineRecognizerConfig,
|
||||
OnlineStream,
|
||||
OnlineTransducerModelConfig,
|
||||
FeatureExtractorConfig,
|
||||
OnlineRecognizerConfig,
|
||||
)
|
||||
from _sherpa_onnx import OnlineRecognizer as _Recognizer
|
||||
|
||||
|
||||
def _assert_file_exists(f: str):
|
||||
@@ -26,6 +27,10 @@ class OnlineRecognizer(object):
|
||||
num_threads: int = 4,
|
||||
sample_rate: float = 16000,
|
||||
feature_dim: int = 80,
|
||||
enable_endpoint_detection: bool = False,
|
||||
rule1_min_trailing_silence: int = 2.4,
|
||||
rule2_min_trailing_silence: int = 1.2,
|
||||
rule3_min_utterance_length: int = 20,
|
||||
):
|
||||
"""
|
||||
Please refer to
|
||||
@@ -52,6 +57,22 @@ class OnlineRecognizer(object):
|
||||
Sample rate of the training data used to train the model.
|
||||
feature_dim:
|
||||
Dimension of the feature used to train the model.
|
||||
enable_endpoint_detection:
|
||||
True to enable endpoint detection. False to disable endpoint
|
||||
detection.
|
||||
rule1_min_trailing_silence:
|
||||
Used only when enable_endpoint_detection is True. If the duration
|
||||
of trailing silence in seconds is larger than this value, we assume
|
||||
an endpoint is detected.
|
||||
rule2_min_trailing_silence:
|
||||
Used only when enable_endpoint_detection is True. If we have decoded
|
||||
something that is nonsilence and if the duration of trailing silence
|
||||
in seconds is larger than this value, we assume an endpoint is
|
||||
detected.
|
||||
rule3_min_utterance_length:
|
||||
Used only when enable_endpoint_detection is True. If the utterance
|
||||
length in seconds is larger than this value, we assume an endpoint
|
||||
is detected.
|
||||
"""
|
||||
_assert_file_exists(tokens)
|
||||
_assert_file_exists(encoder)
|
||||
@@ -72,10 +93,18 @@ class OnlineRecognizer(object):
|
||||
feature_dim=feature_dim,
|
||||
)
|
||||
|
||||
endpoint_config = EndpointConfig(
|
||||
rule1_min_trailing_silence=rule1_min_trailing_silence,
|
||||
rule2_min_trailing_silence=rule2_min_trailing_silence,
|
||||
rule3_min_utterance_length=rule3_min_utterance_length,
|
||||
)
|
||||
|
||||
recognizer_config = OnlineRecognizerConfig(
|
||||
feat_config=feat_config,
|
||||
model_config=model_config,
|
||||
tokens=tokens,
|
||||
endpoint_config=endpoint_config,
|
||||
enable_endpoint=enable_endpoint_detection,
|
||||
)
|
||||
|
||||
self.recognizer = _Recognizer(recognizer_config)
|
||||
@@ -93,4 +122,10 @@ class OnlineRecognizer(object):
|
||||
return self.recognizer.is_ready(s)
|
||||
|
||||
def get_result(self, s: OnlineStream) -> str:
|
||||
return self.recognizer.get_result(s).text
|
||||
return self.recognizer.get_result(s).text.strip()
|
||||
|
||||
def is_endpoint(self, s: OnlineStream) -> bool:
|
||||
return self.recognizer.is_endpoint(s)
|
||||
|
||||
def reset(self, s: OnlineStream) -> bool:
|
||||
return self.recognizer.reset(s)
|
||||
|
||||
Reference in New Issue
Block a user