Support streaming zipformer CTC (#496)

* Support streaming zipformer CTC

* Test online zipformer2 CTC

* Update doc of sherpa-onnx.cc

* Add Python APIs for streaming zipformer2 ctc

* Add Python API examples for streaming zipformer2 ctc

* Swift API for streaming zipformer2 CTC

* NodeJS API for streaming zipformer2 CTC

* Kotlin API for streaming zipformer2 CTC

* Golang API for streaming zipformer2 CTC

* C# API for streaming zipformer2 CTC

* Release v1.9.6
This commit is contained in:
Fangjun Kuang
2023-12-22 13:46:33 +08:00
committed by GitHub
parent 7634f5f034
commit e475e750ac
70 changed files with 1517 additions and 211 deletions

View File

@@ -27,6 +27,7 @@ pybind11_add_module(_sherpa_onnx
online-stream.cc
online-transducer-model-config.cc
online-wenet-ctc-model-config.cc
online-zipformer2-ctc-model-config.cc
sherpa-onnx.cc
silero-vad-model-config.cc
vad-model-config.cc

View File

@@ -58,6 +58,7 @@ void PybindOfflineModelConfig(py::module *m) {
.def_readwrite("debug", &PyClass::debug)
.def_readwrite("provider", &PyClass::provider)
.def_readwrite("model_type", &PyClass::model_type)
.def("validate", &PyClass::Validate)
.def("__str__", &PyClass::ToString);
}

View File

@@ -12,6 +12,7 @@
#include "sherpa-onnx/python/csrc/online-paraformer-model-config.h"
#include "sherpa-onnx/python/csrc/online-transducer-model-config.h"
#include "sherpa-onnx/python/csrc/online-wenet-ctc-model-config.h"
#include "sherpa-onnx/python/csrc/online-zipformer2-ctc-model-config.h"
namespace sherpa_onnx {
@@ -19,26 +20,31 @@ void PybindOnlineModelConfig(py::module *m) {
PybindOnlineTransducerModelConfig(m);
PybindOnlineParaformerModelConfig(m);
PybindOnlineWenetCtcModelConfig(m);
PybindOnlineZipformer2CtcModelConfig(m);
using PyClass = OnlineModelConfig;
py::class_<PyClass>(*m, "OnlineModelConfig")
.def(py::init<const OnlineTransducerModelConfig &,
const OnlineParaformerModelConfig &,
const OnlineWenetCtcModelConfig &, const std::string &,
const OnlineWenetCtcModelConfig &,
const OnlineZipformer2CtcModelConfig &, const std::string &,
int32_t, bool, const std::string &, const std::string &>(),
py::arg("transducer") = OnlineTransducerModelConfig(),
py::arg("paraformer") = OnlineParaformerModelConfig(),
py::arg("wenet_ctc") = OnlineWenetCtcModelConfig(),
py::arg("zipformer2_ctc") = OnlineZipformer2CtcModelConfig(),
py::arg("tokens"), py::arg("num_threads"), py::arg("debug") = false,
py::arg("provider") = "cpu", py::arg("model_type") = "")
.def_readwrite("transducer", &PyClass::transducer)
.def_readwrite("paraformer", &PyClass::paraformer)
.def_readwrite("wenet_ctc", &PyClass::wenet_ctc)
.def_readwrite("zipformer2_ctc", &PyClass::zipformer2_ctc)
.def_readwrite("tokens", &PyClass::tokens)
.def_readwrite("num_threads", &PyClass::num_threads)
.def_readwrite("debug", &PyClass::debug)
.def_readwrite("provider", &PyClass::provider)
.def_readwrite("model_type", &PyClass::model_type)
.def("validate", &PyClass::Validate)
.def("__str__", &PyClass::ToString);
}

View File

@@ -0,0 +1,22 @@
// sherpa-onnx/python/csrc/online-zipformer2-ctc-model-config.cc
//
// Copyright (c) 2023 Xiaomi Corporation
#include "sherpa-onnx/python/csrc/online-zipformer2-ctc-model-config.h"
#include <string>
#include <vector>
#include "sherpa-onnx/csrc/online-zipformer2-ctc-model-config.h"
namespace sherpa_onnx {
// Exposes OnlineZipformer2CtcModelConfig to Python on module `m`.
//
// The Python class is named "OnlineZipformer2CtcModelConfig" and offers:
//  - a constructor taking a single string argument named "model",
//  - a read/write attribute "model",
//  - __str__, implemented by ToString().
void PybindOnlineZipformer2CtcModelConfig(py::module *m) {
  using Config = OnlineZipformer2CtcModelConfig;

  py::class_<Config> binding(*m, "OnlineZipformer2CtcModelConfig");
  binding.def(py::init<const std::string &>(), py::arg("model"));
  binding.def_readwrite("model", &Config::model);
  binding.def("__str__", &Config::ToString);
}
} // namespace sherpa_onnx

View File

@@ -0,0 +1,16 @@
// sherpa-onnx/python/csrc/online-zipformer2-ctc-model-config.h
//
// Copyright (c) 2023 Xiaomi Corporation
#ifndef SHERPA_ONNX_PYTHON_CSRC_ONLINE_ZIPFORMER2_CTC_MODEL_CONFIG_H_
#define SHERPA_ONNX_PYTHON_CSRC_ONLINE_ZIPFORMER2_CTC_MODEL_CONFIG_H_
#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
namespace sherpa_onnx {
// Registers the Python binding for OnlineZipformer2CtcModelConfig on the
// module pointed to by `m`.
void PybindOnlineZipformer2CtcModelConfig(py::module *m);
}  // namespace sherpa_onnx
#endif // SHERPA_ONNX_PYTHON_CSRC_ONLINE_ZIPFORMER2_CTC_MODEL_CONFIG_H_

View File

@@ -8,11 +8,14 @@ from _sherpa_onnx import (
OnlineLMConfig,
OnlineModelConfig,
OnlineParaformerModelConfig,
OnlineRecognizer as _Recognizer,
)
from _sherpa_onnx import OnlineRecognizer as _Recognizer
from _sherpa_onnx import (
OnlineRecognizerConfig,
OnlineStream,
OnlineTransducerModelConfig,
OnlineWenetCtcModelConfig,
OnlineZipformer2CtcModelConfig,
)
@@ -272,6 +275,101 @@ class OnlineRecognizer(object):
self.config = recognizer_config
return self
@classmethod
def from_zipformer2_ctc(
    cls,
    tokens: str,
    model: str,
    num_threads: int = 2,
    sample_rate: float = 16000,
    feature_dim: int = 80,
    enable_endpoint_detection: bool = False,
    rule1_min_trailing_silence: float = 2.4,
    rule2_min_trailing_silence: float = 1.2,
    rule3_min_utterance_length: float = 20.0,
    decoding_method: str = "greedy_search",
    provider: str = "cpu",
):
    """
    Create an online recognizer from a streaming zipformer2 CTC model.

    Pre-trained models for different languages (e.g., Chinese, English)
    can be downloaded from
    `<https://k2-fsa.github.io/sherpa/onnx/pretrained_models/online-ctc/index.html>`_

    Args:
      tokens:
        Path to ``tokens.txt``. Each line in ``tokens.txt`` contains two
        columns::

            symbol integer_id

      model:
        Path to ``model.onnx``.
      num_threads:
        Number of threads for neural network computation. Must be
        positive.
      sample_rate:
        Sample rate of the training data used to train the model.
      feature_dim:
        Dimension of the feature used to train the model.
      enable_endpoint_detection:
        True to enable endpoint detection. False to disable endpoint
        detection.
      rule1_min_trailing_silence:
        Used only when enable_endpoint_detection is True. If the duration
        of trailing silence in seconds is larger than this value, we
        assume an endpoint is detected.
      rule2_min_trailing_silence:
        Used only when enable_endpoint_detection is True. If we have
        decoded something that is nonsilence and if the duration of
        trailing silence in seconds is larger than this value, we assume
        an endpoint is detected.
      rule3_min_utterance_length:
        Used only when enable_endpoint_detection is True. If the
        utterance length in seconds is larger than this value, we assume
        an endpoint is detected.
      decoding_method:
        The only valid value is greedy_search.
      provider:
        onnxruntime execution providers. Valid values are: cpu, cuda,
        coreml.

    Returns:
      An instance of ``cls`` whose ``recognizer`` and ``config``
      attributes have been set.
    """
    # Allocate the instance without running __init__; its attributes are
    # filled in at the end of this method.
    instance = cls.__new__(cls)

    # Validate the inputs before building any configuration object.
    _assert_file_exists(tokens)
    _assert_file_exists(model)
    assert num_threads > 0, num_threads

    # model_type is not passed here, so OnlineModelConfig uses its default.
    online_model = OnlineModelConfig(
        zipformer2_ctc=OnlineZipformer2CtcModelConfig(model=model),
        tokens=tokens,
        num_threads=num_threads,
        provider=provider,
    )

    features = FeatureExtractorConfig(
        sampling_rate=sample_rate,
        feature_dim=feature_dim,
    )

    endpointing = EndpointConfig(
        rule1_min_trailing_silence=rule1_min_trailing_silence,
        rule2_min_trailing_silence=rule2_min_trailing_silence,
        rule3_min_utterance_length=rule3_min_utterance_length,
    )

    full_config = OnlineRecognizerConfig(
        feat_config=features,
        model_config=online_model,
        endpoint_config=endpointing,
        enable_endpoint=enable_endpoint_detection,
        decoding_method=decoding_method,
    )

    instance.recognizer = _Recognizer(full_config)
    instance.config = full_config
    return instance
@classmethod
def from_wenet_ctc(
cls,
@@ -352,7 +450,6 @@ class OnlineRecognizer(object):
tokens=tokens,
num_threads=num_threads,
provider=provider,
model_type="wenet_ctc",
)
feat_config = FeatureExtractorConfig(

View File

@@ -143,6 +143,57 @@ class TestOnlineRecognizer(unittest.TestCase):
print(f"{wave_filename}\n{result}")
print("-" * 10)
def test_zipformer2_ctc(self):
    """
    Decode three test waves with the streaming zipformer2 CTC model,
    first with the int8 model file and then with the non-int8 one.

    If a model file is missing, a message is printed and the whole test
    returns. Recognition results are printed; nothing is asserted on
    them.
    """
    m = "sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13"
    model_dir = f"{d}/{m}"

    # These paths do not depend on which model file is being tested.
    tokens = f"{model_dir}/tokens.txt"
    waves = [
        f"{model_dir}/test_wavs/DEV_T0000000000.wav",
        f"{model_dir}/test_wavs/DEV_T0000000001.wav",
        f"{model_dir}/test_wavs/DEV_T0000000002.wav",
    ]

    # The int8 model is tried first, then the non-int8 one.
    model_names = (
        "ctc-epoch-20-avg-1-chunk-16-left-128.int8.onnx",
        "ctc-epoch-20-avg-1-chunk-16-left-128.onnx",
    )

    for name in model_names:
        model = f"{model_dir}/{name}"

        if not Path(model).is_file():
            print("skipping test_zipformer2_ctc()")
            return

        print(f"testing {model}")

        recognizer = sherpa_onnx.OnlineRecognizer.from_zipformer2_ctc(
            model=model,
            tokens=tokens,
            num_threads=1,
            provider="cpu",
        )

        # One stream per wave: feed the samples, then 0.2 seconds of
        # zeros as tail padding, then mark the input as finished.
        streams = []
        for wave in waves:
            stream = recognizer.create_stream()
            samples, sample_rate = read_wave(wave)
            stream.accept_waveform(sample_rate, samples)

            tail_paddings = np.zeros(int(0.2 * sample_rate), dtype=np.float32)
            stream.accept_waveform(sample_rate, tail_paddings)
            stream.input_finished()
            streams.append(stream)

        # Keep decoding in batches until no stream is ready any more.
        while True:
            ready_list = [s for s in streams if recognizer.is_ready(s)]
            if not ready_list:
                break
            recognizer.decode_streams(ready_list)

        results = [recognizer.get_result(s) for s in streams]
        for wave_filename, result in zip(waves, results):
            print(f"{wave_filename}\n{result}")
            print("-" * 10)
def test_wenet_ctc(self):
models = [
"sherpa-onnx-zh-wenet-aishell",