Support streaming zipformer CTC (#496)
* Support streaming zipformer CTC * test online zipformer2 CTC * Update doc of sherpa-onnx.cc * Add Python APIs for streaming zipformer2 ctc * Add Python API examples for streaming zipformer2 ctc * Swift API for streaming zipformer2 CTC * NodeJS API for streaming zipformer2 CTC * Kotlin API for streaming zipformer2 CTC * Golang API for streaming zipformer2 CTC * C# API for streaming zipformer2 CTC * Release v1.9.6
This commit is contained in:
@@ -143,6 +143,57 @@ class TestOnlineRecognizer(unittest.TestCase):
|
||||
print(f"{wave_filename}\n{result}")
|
||||
print("-" * 10)
|
||||
|
||||
def test_zipformer2_ctc(self):
|
||||
m = "sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13"
|
||||
for use_int8 in [True, False]:
|
||||
name = (
|
||||
"ctc-epoch-20-avg-1-chunk-16-left-128.int8.onnx"
|
||||
if use_int8
|
||||
else "ctc-epoch-20-avg-1-chunk-16-left-128.onnx"
|
||||
)
|
||||
model = f"{d}/{m}/{name}"
|
||||
tokens = f"{d}/{m}/tokens.txt"
|
||||
wave0 = f"{d}/{m}/test_wavs/DEV_T0000000000.wav"
|
||||
wave1 = f"{d}/{m}/test_wavs/DEV_T0000000001.wav"
|
||||
wave2 = f"{d}/{m}/test_wavs/DEV_T0000000002.wav"
|
||||
if not Path(model).is_file():
|
||||
print("skipping test_zipformer2_ctc()")
|
||||
return
|
||||
print(f"testing {model}")
|
||||
|
||||
recognizer = sherpa_onnx.OnlineRecognizer.from_zipformer2_ctc(
|
||||
model=model,
|
||||
tokens=tokens,
|
||||
num_threads=1,
|
||||
provider="cpu",
|
||||
)
|
||||
|
||||
streams = []
|
||||
waves = [wave0, wave1, wave2]
|
||||
for wave in waves:
|
||||
s = recognizer.create_stream()
|
||||
samples, sample_rate = read_wave(wave)
|
||||
s.accept_waveform(sample_rate, samples)
|
||||
|
||||
tail_paddings = np.zeros(int(0.2 * sample_rate), dtype=np.float32)
|
||||
s.accept_waveform(sample_rate, tail_paddings)
|
||||
s.input_finished()
|
||||
streams.append(s)
|
||||
|
||||
while True:
|
||||
ready_list = []
|
||||
for s in streams:
|
||||
if recognizer.is_ready(s):
|
||||
ready_list.append(s)
|
||||
if len(ready_list) == 0:
|
||||
break
|
||||
recognizer.decode_streams(ready_list)
|
||||
|
||||
results = [recognizer.get_result(s) for s in streams]
|
||||
for wave_filename, result in zip(waves, results):
|
||||
print(f"{wave_filename}\n{result}")
|
||||
print("-" * 10)
|
||||
|
||||
def test_wenet_ctc(self):
|
||||
models = [
|
||||
"sherpa-onnx-zh-wenet-aishell",
|
||||
|
||||
Reference in New Issue
Block a user