Support streaming zipformer CTC (#496)

* Support streaming zipformer CTC

* test online zipformer2 CTC

* Update doc of sherpa-onnx.cc

* Add Python APIs for streaming zipformer2 ctc

* Add Python API examples for streaming zipformer2 ctc

* Swift API for streaming zipformer2 CTC

* NodeJS API for streaming zipformer2 CTC

* Kotlin API for streaming zipformer2 CTC

* Golang API for streaming zipformer2 CTC

* C# API for streaming zipformer2 CTC

* Release v1.9.6
This commit is contained in:
Fangjun Kuang
2023-12-22 13:46:33 +08:00
committed by GitHub
parent 7634f5f034
commit e475e750ac
70 changed files with 1517 additions and 211 deletions

View File

@@ -143,6 +143,57 @@ class TestOnlineRecognizer(unittest.TestCase):
print(f"{wave_filename}\n{result}")
print("-" * 10)
def test_zipformer2_ctc(self):
m = "sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13"
for use_int8 in [True, False]:
name = (
"ctc-epoch-20-avg-1-chunk-16-left-128.int8.onnx"
if use_int8
else "ctc-epoch-20-avg-1-chunk-16-left-128.onnx"
)
model = f"{d}/{m}/{name}"
tokens = f"{d}/{m}/tokens.txt"
wave0 = f"{d}/{m}/test_wavs/DEV_T0000000000.wav"
wave1 = f"{d}/{m}/test_wavs/DEV_T0000000001.wav"
wave2 = f"{d}/{m}/test_wavs/DEV_T0000000002.wav"
if not Path(model).is_file():
print("skipping test_zipformer2_ctc()")
return
print(f"testing {model}")
recognizer = sherpa_onnx.OnlineRecognizer.from_zipformer2_ctc(
model=model,
tokens=tokens,
num_threads=1,
provider="cpu",
)
streams = []
waves = [wave0, wave1, wave2]
for wave in waves:
s = recognizer.create_stream()
samples, sample_rate = read_wave(wave)
s.accept_waveform(sample_rate, samples)
tail_paddings = np.zeros(int(0.2 * sample_rate), dtype=np.float32)
s.accept_waveform(sample_rate, tail_paddings)
s.input_finished()
streams.append(s)
while True:
ready_list = []
for s in streams:
if recognizer.is_ready(s):
ready_list.append(s)
if len(ready_list) == 0:
break
recognizer.decode_streams(ready_list)
results = [recognizer.get_result(s) for s in streams]
for wave_filename, result in zip(waves, results):
print(f"{wave_filename}\n{result}")
print("-" * 10)
def test_wenet_ctc(self):
models = [
"sherpa-onnx-zh-wenet-aishell",