Add Python APIs for WeNet CTC models (#428)

Fangjun Kuang
2023-11-16 14:20:41 +08:00
committed by GitHub
parent fac4f6bc7c
commit 049fb9f451
13 changed files with 538 additions and 11 deletions

sherpa-onnx/python/sherpa_onnx/offline_recognizer.py

@@ -9,15 +9,16 @@ from _sherpa_onnx import (
    OfflineModelConfig,
    OfflineNemoEncDecCtcModelConfig,
    OfflineParaformerModelConfig,
    OfflineTdnnModelConfig,
    OfflineWhisperModelConfig,
    OfflineZipformerCtcModelConfig,
)
from _sherpa_onnx import OfflineRecognizer as _Recognizer
from _sherpa_onnx import (
    OfflineRecognizerConfig,
    OfflineStream,
    OfflineTdnnModelConfig,
    OfflineTransducerModelConfig,
    OfflineWenetCtcModelConfig,
    OfflineWhisperModelConfig,
    OfflineZipformerCtcModelConfig,
)
@@ -389,6 +390,70 @@ class OfflineRecognizer(object):
        self.config = recognizer_config
        return self

    @classmethod
    def from_wenet_ctc(
        cls,
        model: str,
        tokens: str,
        num_threads: int = 1,
        sample_rate: int = 16000,
        feature_dim: int = 80,
        decoding_method: str = "greedy_search",
        debug: bool = False,
        provider: str = "cpu",
    ):
"""
Please refer to
`<https://k2-fsa.github.io/sherpa/onnx/pretrained_models/whisper/index.html>`_
to download pre-trained models for different languages, e.g., Chinese,
English, etc.
Args:
model:
Path to ``model.onnx``.
tokens:
Path to ``tokens.txt``. Each line in ``tokens.txt`` contains two
columns::
symbol integer_id
num_threads:
Number of threads for neural network computation.
sample_rate:
Sample rate of the training data used to train the model.
feature_dim:
Dimension of the feature used to train the model.
decoding_method:
Valid values are greedy_search.
debug:
True to show debug messages.
provider:
onnxruntime execution providers. Valid values are: cpu, cuda, coreml.
"""
        self = cls.__new__(cls)
        model_config = OfflineModelConfig(
            wenet_ctc=OfflineWenetCtcModelConfig(model=model),
            tokens=tokens,
            num_threads=num_threads,
            debug=debug,
            provider=provider,
            model_type="wenet_ctc",
        )

        feat_config = OfflineFeatureExtractorConfig(
            sampling_rate=sample_rate,
            feature_dim=feature_dim,
        )

        recognizer_config = OfflineRecognizerConfig(
            feat_config=feat_config,
            model_config=model_config,
            decoding_method=decoding_method,
        )

        self.recognizer = _Recognizer(recognizer_config)
        self.config = recognizer_config
        return self

    def create_stream(self, hotwords: Optional[str] = None):
        if hotwords is None:
            return self.recognizer.create_stream()
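
A minimal end-to-end usage sketch for the new offline API (not part of this commit): the model directory and wave path below are placeholders, and soundfile is only one convenient way to load a mono wave as normalized float32 samples.

import sherpa_onnx
import soundfile as sf

# Placeholder paths: point these at any downloaded WeNet CTC model,
# e.g. sherpa-onnx-zh-wenet-wenetspeech.
recognizer = sherpa_onnx.OfflineRecognizer.from_wenet_ctc(
    model="./sherpa-onnx-zh-wenet-wenetspeech/model.onnx",
    tokens="./sherpa-onnx-zh-wenet-wenetspeech/tokens.txt",
    num_threads=1,
)

# Load a single-channel wave file as float32 samples in [-1, 1].
samples, sample_rate = sf.read("./test.wav", dtype="float32")

s = recognizer.create_stream()
s.accept_waveform(sample_rate, samples)
recognizer.decode_streams([s])
print(s.result.text)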

sherpa-onnx/python/sherpa_onnx/online_recognizer.py

@@ -12,6 +12,7 @@ from _sherpa_onnx import (
    OnlineRecognizerConfig,
    OnlineStream,
    OnlineTransducerModelConfig,
    OnlineWenetCtcModelConfig,
)
@@ -140,13 +141,13 @@ class OnlineRecognizer(object):
                "Please use --decoding-method=modified_beam_search when using "
                f"--hotwords-file. Currently given: {decoding_method}"
            )

        if lm and decoding_method != "modified_beam_search":
            raise ValueError(
                "Please use --decoding-method=modified_beam_search when using "
                f"--lm. Currently given: {decoding_method}"
            )

        lm_config = OnlineLMConfig(
            model=lm,
            scale=lm_scale,
@@ -271,6 +272,112 @@ class OnlineRecognizer(object):
        self.config = recognizer_config
        return self

    @classmethod
    def from_wenet_ctc(
        cls,
        tokens: str,
        model: str,
        chunk_size: int = 16,
        num_left_chunks: int = 4,
        num_threads: int = 2,
        sample_rate: float = 16000,
        feature_dim: int = 80,
        enable_endpoint_detection: bool = False,
        rule1_min_trailing_silence: float = 2.4,
        rule2_min_trailing_silence: float = 1.2,
        rule3_min_utterance_length: float = 20.0,
        decoding_method: str = "greedy_search",
        provider: str = "cpu",
    ):
"""
Please refer to
`<https://k2-fsa.github.io/sherpa/onnx/pretrained_models/wenet/index.html>`_
to download pre-trained models for different languages, e.g., Chinese,
English, etc.
Args:
tokens:
Path to ``tokens.txt``. Each line in ``tokens.txt`` contains two
columns::
symbol integer_id
model:
Path to ``model.onnx``.
chunk_size:
The --chunk-size parameter from WeNet.
num_left_chunks:
The --num-left-chunks parameter from WeNet.
num_threads:
Number of threads for neural network computation.
sample_rate:
Sample rate of the training data used to train the model.
feature_dim:
Dimension of the feature used to train the model.
enable_endpoint_detection:
True to enable endpoint detection. False to disable endpoint
detection.
rule1_min_trailing_silence:
Used only when enable_endpoint_detection is True. If the duration
of trailing silence in seconds is larger than this value, we assume
an endpoint is detected.
rule2_min_trailing_silence:
Used only when enable_endpoint_detection is True. If we have decoded
something that is nonsilence and if the duration of trailing silence
in seconds is larger than this value, we assume an endpoint is
detected.
rule3_min_utterance_length:
Used only when enable_endpoint_detection is True. If the utterance
length in seconds is larger than this value, we assume an endpoint
is detected.
decoding_method:
The only valid value is greedy_search.
provider:
onnxruntime execution providers. Valid values are: cpu, cuda, coreml.
"""
        self = cls.__new__(cls)
        _assert_file_exists(tokens)
        _assert_file_exists(model)

        assert num_threads > 0, num_threads

        wenet_ctc_config = OnlineWenetCtcModelConfig(
            model=model,
            chunk_size=chunk_size,
            num_left_chunks=num_left_chunks,
        )

        model_config = OnlineModelConfig(
            wenet_ctc=wenet_ctc_config,
            tokens=tokens,
            num_threads=num_threads,
            provider=provider,
            model_type="wenet_ctc",
        )

        feat_config = FeatureExtractorConfig(
            sampling_rate=sample_rate,
            feature_dim=feature_dim,
        )

        endpoint_config = EndpointConfig(
            rule1_min_trailing_silence=rule1_min_trailing_silence,
            rule2_min_trailing_silence=rule2_min_trailing_silence,
            rule3_min_utterance_length=rule3_min_utterance_length,
        )

        recognizer_config = OnlineRecognizerConfig(
            feat_config=feat_config,
            model_config=model_config,
            endpoint_config=endpoint_config,
            enable_endpoint=enable_endpoint_detection,
            decoding_method=decoding_method,
        )

        self.recognizer = _Recognizer(recognizer_config)
        self.config = recognizer_config
        return self

    def create_stream(self, hotwords: Optional[str] = None):
        if hotwords is None:
            return self.recognizer.create_stream()
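
A corresponding streaming sketch (again not part of this commit; paths are placeholders). It simulates a live source by feeding 100 ms chunks, then flushes the stream with a short tail of silence the same way the streaming test further down does.

import numpy as np
import sherpa_onnx
import soundfile as sf

# Placeholder paths: point these at a downloaded streaming WeNet CTC model,
# e.g. sherpa-onnx-en-wenet-librispeech.
recognizer = sherpa_onnx.OnlineRecognizer.from_wenet_ctc(
    tokens="./sherpa-onnx-en-wenet-librispeech/tokens.txt",
    model="./sherpa-onnx-en-wenet-librispeech/model-streaming.onnx",
)

samples, sample_rate = sf.read("./test.wav", dtype="float32")

s = recognizer.create_stream()

# Feed the audio in 100 ms chunks, decoding whenever enough frames are buffered.
chunk = int(0.1 * sample_rate)
for start in range(0, len(samples), chunk):
    s.accept_waveform(sample_rate, samples[start : start + chunk])
    while recognizer.is_ready(s):
        recognizer.decode_streams([s])

# Add 0.2 s of silence so the final frames are flushed, then drain the stream.
s.accept_waveform(sample_rate, np.zeros(int(0.2 * sample_rate), dtype=np.float32))
s.input_finished()
while recognizer.is_ready(s):
    recognizer.decode_streams([s])

print(recognizer.get_result(s))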

sherpa-onnx/python/tests/test_offline_recognizer.py

@@ -267,6 +267,53 @@ class TestOfflineRecognizer(unittest.TestCase):
        print(s1.result.text)
        print(s2.result.text)

    def test_wenet_ctc(self):
        models = [
            "sherpa-onnx-zh-wenet-aishell",
            "sherpa-onnx-zh-wenet-aishell2",
            "sherpa-onnx-zh-wenet-wenetspeech",
            "sherpa-onnx-zh-wenet-multi-cn",
            "sherpa-onnx-en-wenet-librispeech",
            "sherpa-onnx-en-wenet-gigaspeech",
        ]
        for m in models:
            for use_int8 in [True, False]:
                name = "model.int8.onnx" if use_int8 else "model.onnx"
                model = f"{d}/{m}/{name}"
                tokens = f"{d}/{m}/tokens.txt"
                wave0 = f"{d}/{m}/test_wavs/0.wav"
                wave1 = f"{d}/{m}/test_wavs/1.wav"
                wave2 = f"{d}/{m}/test_wavs/8k.wav"
                if not Path(model).is_file():
                    print("skipping test_wenet_ctc()")
                    return

                recognizer = sherpa_onnx.OfflineRecognizer.from_wenet_ctc(
                    model=model,
                    tokens=tokens,
                    num_threads=1,
                    provider="cpu",
                )

                s0 = recognizer.create_stream()
                samples0, sample_rate0 = read_wave(wave0)
                s0.accept_waveform(sample_rate0, samples0)

                s1 = recognizer.create_stream()
                samples1, sample_rate1 = read_wave(wave1)
                s1.accept_waveform(sample_rate1, samples1)

                s2 = recognizer.create_stream()
                samples2, sample_rate2 = read_wave(wave2)
                s2.accept_waveform(sample_rate2, samples2)

                recognizer.decode_streams([s0, s1, s2])
                print(s0.result.text)
                print(s1.result.text)
                print(s2.result.text)


if __name__ == "__main__":
    unittest.main()
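
Note that test_wenet_ctc relies on two helpers defined earlier in the test file and not shown in this diff: d, the directory the pre-trained models are unpacked into, and read_wave(), which returns normalized float32 samples plus the sample rate. A sketch of what they are assumed to look like (the actual definitions live in the test file's preamble):

import wave

import numpy as np

d = "/tmp"  # assumed location of the downloaded model directories


def read_wave(wave_filename: str):
    """Read a single-channel 16-bit PCM wave file and return
    (float32 samples in [-1, 1], sample_rate)."""
    with wave.open(wave_filename, "rb") as f:
        assert f.getnchannels() == 1, f.getnchannels()
        assert f.getsampwidth() == 2, f.getsampwidth()  # 16-bit PCM
        samples_int16 = np.frombuffer(f.readframes(f.getnframes()), dtype=np.int16)
        samples_float32 = samples_int16.astype(np.float32) / 32768
        return samples_float32, f.getframerate()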

sherpa-onnx/python/tests/test_online_recognizer.py

@@ -143,6 +143,64 @@ class TestOnlineRecognizer(unittest.TestCase):
        print(f"{wave_filename}\n{result}")
        print("-" * 10)

    def test_wenet_ctc(self):
        models = [
            "sherpa-onnx-zh-wenet-aishell",
            "sherpa-onnx-zh-wenet-aishell2",
            "sherpa-onnx-zh-wenet-wenetspeech",
            "sherpa-onnx-zh-wenet-multi-cn",
            "sherpa-onnx-en-wenet-librispeech",
            "sherpa-onnx-en-wenet-gigaspeech",
        ]
        for m in models:
            for use_int8 in [True, False]:
                name = (
                    "model-streaming.int8.onnx" if use_int8 else "model-streaming.onnx"
                )
                model = f"{d}/{m}/{name}"
                tokens = f"{d}/{m}/tokens.txt"
                wave0 = f"{d}/{m}/test_wavs/0.wav"
                wave1 = f"{d}/{m}/test_wavs/1.wav"
                wave2 = f"{d}/{m}/test_wavs/8k.wav"
                if not Path(model).is_file():
                    print("skipping test_wenet_ctc()")
                    return

                recognizer = sherpa_onnx.OnlineRecognizer.from_wenet_ctc(
                    model=model,
                    tokens=tokens,
                    num_threads=1,
                    provider="cpu",
                )

                streams = []
                waves = [wave0, wave1, wave2]
                for wave in waves:
                    s = recognizer.create_stream()
                    samples, sample_rate = read_wave(wave)
                    s.accept_waveform(sample_rate, samples)

                    tail_paddings = np.zeros(int(0.2 * sample_rate), dtype=np.float32)
                    s.accept_waveform(sample_rate, tail_paddings)
                    s.input_finished()
                    streams.append(s)

                while True:
                    ready_list = []
                    for s in streams:
                        if recognizer.is_ready(s):
                            ready_list.append(s)
                    if len(ready_list) == 0:
                        break
                    recognizer.decode_streams(ready_list)

                results = [recognizer.get_result(s) for s in streams]
                for wave_filename, result in zip(waves, results):
                    print(f"{wave_filename}\n{result}")
                print("-" * 10)


if __name__ == "__main__":
    unittest.main()
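
Two details of the streaming test are worth noting. The 0.2 s of zero-valued tail padding fed in before input_finished() is there to flush the model's right context: a chunk-based streaming model emits output for a frame only once enough following frames have arrived, so without the padding the last word can be cut off. And the is_ready()/decode_streams() loop batches whichever streams still have undecoded feature frames on each iteration, draining all three streams together rather than one at a time.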