Add C++ support for streaming NeMo CTC models. (#857)

This commit is contained in:
Fangjun Kuang
2024-05-10 16:26:43 +08:00
committed by GitHub
parent 1eb60e8711
commit 46e4e5b7ac
22 changed files with 782 additions and 41 deletions

View File

@@ -12,9 +12,11 @@ from _sherpa_onnx import (
from _sherpa_onnx import OnlineRecognizer as _Recognizer
from _sherpa_onnx import (
OnlineRecognizerConfig,
OnlineRecognizerResult,
OnlineStream,
OnlineTransducerModelConfig,
OnlineWenetCtcModelConfig,
OnlineNeMoCtcModelConfig,
OnlineZipformer2CtcModelConfig,
OnlineCtcFstDecoderConfig,
)
@@ -59,6 +61,7 @@ class OnlineRecognizer(object):
lm: str = "",
lm_scale: float = 0.1,
temperature_scale: float = 2.0,
debug: bool = False,
):
"""
Please refer to
@@ -154,6 +157,7 @@ class OnlineRecognizer(object):
num_threads=num_threads,
provider=provider,
model_type=model_type,
debug=debug,
)
feat_config = FeatureExtractorConfig(
@@ -220,6 +224,7 @@ class OnlineRecognizer(object):
rule3_min_utterance_length: float = 20.0,
decoding_method: str = "greedy_search",
provider: str = "cpu",
debug: bool = False,
):
"""
Please refer to
@@ -283,6 +288,7 @@ class OnlineRecognizer(object):
num_threads=num_threads,
provider=provider,
model_type="paraformer",
debug=debug,
)
feat_config = FeatureExtractorConfig(
@@ -324,6 +330,7 @@ class OnlineRecognizer(object):
ctc_graph: str = "",
ctc_max_active: int = 3000,
provider: str = "cpu",
debug: bool = False,
):
"""
Please refer to
@@ -386,6 +393,7 @@ class OnlineRecognizer(object):
tokens=tokens,
num_threads=num_threads,
provider=provider,
debug=debug,
)
feat_config = FeatureExtractorConfig(
@@ -417,6 +425,106 @@ class OnlineRecognizer(object):
self.config = recognizer_config
return self
@classmethod
def from_nemo_ctc(
cls,
tokens: str,
model: str,
num_threads: int = 2,
sample_rate: float = 16000,
feature_dim: int = 80,
enable_endpoint_detection: bool = False,
rule1_min_trailing_silence: float = 2.4,
rule2_min_trailing_silence: float = 1.2,
rule3_min_utterance_length: float = 20.0,
decoding_method: str = "greedy_search",
provider: str = "cpu",
debug: bool = False,
):
"""
Please refer to
`<https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models>`_
to download pre-trained models.
Args:
tokens:
Path to ``tokens.txt``. Each line in ``tokens.txt`` contains two
columns::
symbol integer_id
model:
Path to ``model.onnx``.
num_threads:
Number of threads for neural network computation.
sample_rate:
Sample rate of the training data used to train the model.
feature_dim:
Dimension of the feature used to train the model.
enable_endpoint_detection:
True to enable endpoint detection. False to disable endpoint
detection.
rule1_min_trailing_silence:
Used only when enable_endpoint_detection is True. If the duration
of trailing silence in seconds is larger than this value, we assume
an endpoint is detected.
rule2_min_trailing_silence:
Used only when enable_endpoint_detection is True. If we have decoded
something that is nonsilence and if the duration of trailing silence
in seconds is larger than this value, we assume an endpoint is
detected.
rule3_min_utterance_length:
Used only when enable_endpoint_detection is True. If the utterance
length in seconds is larger than this value, we assume an endpoint
is detected.
decoding_method:
The only valid value is greedy_search.
provider:
onnxruntime execution providers. Valid values are: cpu, cuda, coreml.
debug:
True to show meta data in the model.
"""
self = cls.__new__(cls)
_assert_file_exists(tokens)
_assert_file_exists(model)
assert num_threads > 0, num_threads
nemo_ctc_config = OnlineNeMoCtcModelConfig(
model=model,
)
model_config = OnlineModelConfig(
nemo_ctc=nemo_ctc_config,
tokens=tokens,
num_threads=num_threads,
provider=provider,
debug=debug,
)
feat_config = FeatureExtractorConfig(
sampling_rate=sample_rate,
feature_dim=feature_dim,
)
endpoint_config = EndpointConfig(
rule1_min_trailing_silence=rule1_min_trailing_silence,
rule2_min_trailing_silence=rule2_min_trailing_silence,
rule3_min_utterance_length=rule3_min_utterance_length,
)
recognizer_config = OnlineRecognizerConfig(
feat_config=feat_config,
model_config=model_config,
endpoint_config=endpoint_config,
enable_endpoint=enable_endpoint_detection,
decoding_method=decoding_method,
)
self.recognizer = _Recognizer(recognizer_config)
self.config = recognizer_config
return self
@classmethod
def from_wenet_ctc(
cls,
@@ -433,6 +541,7 @@ class OnlineRecognizer(object):
rule3_min_utterance_length: float = 20.0,
decoding_method: str = "greedy_search",
provider: str = "cpu",
debug: bool = False,
):
"""
Please refer to
@@ -497,6 +606,7 @@ class OnlineRecognizer(object):
tokens=tokens,
num_threads=num_threads,
provider=provider,
debug=debug,
)
feat_config = FeatureExtractorConfig(
@@ -537,6 +647,9 @@ class OnlineRecognizer(object):
def is_ready(self, s: OnlineStream) -> bool:
return self.recognizer.is_ready(s)
def get_result_all(self, s: OnlineStream) -> OnlineRecognizerResult:
return self.recognizer.get_result(s)
def get_result(self, s: OnlineStream) -> str:
return self.recognizer.get_result(s).text.strip()