Begin to support CTC models (#119)

Please see https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-ctc/nemo/index.html for a list of pre-trained CTC models from NeMo.
This commit is contained in:
Fangjun Kuang
2023-04-07 23:11:34 +08:00
committed by GitHub
parent 9ac747248b
commit 80060c276d
40 changed files with 1244 additions and 60 deletions

View File

@@ -4,12 +4,15 @@ from typing import List
from _sherpa_onnx import (
OfflineFeatureExtractorConfig,
OfflineRecognizer as _Recognizer,
OfflineModelConfig,
OfflineNemoEncDecCtcModelConfig,
OfflineParaformerModelConfig,
)
from _sherpa_onnx import OfflineRecognizer as _Recognizer
from _sherpa_onnx import (
OfflineRecognizerConfig,
OfflineStream,
OfflineModelConfig,
OfflineTransducerModelConfig,
OfflineParaformerModelConfig,
)
@@ -75,7 +78,6 @@ class OfflineRecognizer(object):
decoder_filename=decoder,
joiner_filename=joiner,
),
paraformer=OfflineParaformerModelConfig(model=""),
tokens=tokens,
num_threads=num_threads,
debug=debug,
@@ -119,7 +121,7 @@ class OfflineRecognizer(object):
symbol integer_id
paraformer:
Path to ``paraformer.onnx``.
Path to ``model.onnx``.
num_threads:
Number of threads for neural network computation.
sample_rate:
@@ -133,9 +135,6 @@ class OfflineRecognizer(object):
"""
self = cls.__new__(cls)
model_config = OfflineModelConfig(
transducer=OfflineTransducerModelConfig(
encoder_filename="", decoder_filename="", joiner_filename=""
),
paraformer=OfflineParaformerModelConfig(model=paraformer),
tokens=tokens,
num_threads=num_threads,
@@ -155,6 +154,64 @@ class OfflineRecognizer(object):
self.recognizer = _Recognizer(recognizer_config)
return self
@classmethod
def from_nemo_ctc(
cls,
model: str,
tokens: str,
num_threads: int,
sample_rate: int = 16000,
feature_dim: int = 80,
decoding_method: str = "greedy_search",
debug: bool = False,
):
"""
Please refer to
`<https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html>`_
to download pre-trained models for different languages, e.g., Chinese,
English, etc.
Args:
tokens:
Path to ``tokens.txt``. Each line in ``tokens.txt`` contains two
columns::
symbol integer_id
model:
Path to ``model.onnx``.
num_threads:
Number of threads for neural network computation.
sample_rate:
Sample rate of the training data used to train the model.
feature_dim:
Dimension of the feature used to train the model.
decoding_method:
Valid values are greedy_search, modified_beam_search.
debug:
True to show debug messages.
"""
self = cls.__new__(cls)
model_config = OfflineModelConfig(
nemo_ctc=OfflineNemoEncDecCtcModelConfig(model=model),
tokens=tokens,
num_threads=num_threads,
debug=debug,
)
feat_config = OfflineFeatureExtractorConfig(
sampling_rate=sample_rate,
feature_dim=feature_dim,
)
recognizer_config = OfflineRecognizerConfig(
feat_config=feat_config,
model_config=model_config,
decoding_method=decoding_method,
)
self.recognizer = _Recognizer(recognizer_config)
return self
def create_stream(self):
return self.recognizer.create_stream()