Add Python APIs for WeNet CTC models (#428)

This commit is contained in:
Fangjun Kuang
2023-11-16 14:20:41 +08:00
committed by GitHub
parent fac4f6bc7c
commit 049fb9f451
13 changed files with 538 additions and 11 deletions

View File

@@ -143,6 +143,64 @@ class TestOnlineRecognizer(unittest.TestCase):
print(f"{wave_filename}\n{result}")
print("-" * 10)
def test_wenet_ctc(self):
models = [
"sherpa-onnx-zh-wenet-aishell",
"sherpa-onnx-zh-wenet-aishell2",
"sherpa-onnx-zh-wenet-wenetspeech",
"sherpa-onnx-zh-wenet-multi-cn",
"sherpa-onnx-en-wenet-librispeech",
"sherpa-onnx-en-wenet-gigaspeech",
]
for m in models:
for use_int8 in [True, False]:
name = (
"model-streaming.int8.onnx" if use_int8 else "model-streaming.onnx"
)
model = f"{d}/{m}/{name}"
tokens = f"{d}/{m}/tokens.txt"
wave0 = f"{d}/{m}/test_wavs/0.wav"
wave1 = f"{d}/{m}/test_wavs/1.wav"
wave2 = f"{d}/{m}/test_wavs/8k.wav"
if not Path(model).is_file():
print("skipping test_wenet_ctc()")
return
recognizer = sherpa_onnx.OnlineRecognizer.from_wenet_ctc(
model=model,
tokens=tokens,
num_threads=1,
provider="cpu",
)
streams = []
waves = [wave0, wave1, wave2]
for wave in waves:
s = recognizer.create_stream()
samples, sample_rate = read_wave(wave)
s.accept_waveform(sample_rate, samples)
tail_paddings = np.zeros(int(0.2 * sample_rate), dtype=np.float32)
s.accept_waveform(sample_rate, tail_paddings)
s.input_finished()
streams.append(s)
while True:
ready_list = []
for s in streams:
if recognizer.is_ready(s):
ready_list.append(s)
if len(ready_list) == 0:
break
recognizer.decode_streams(ready_list)
results = [recognizer.get_result(s) for s in streams]
for wave_filename, result in zip(waves, results):
print(f"{wave_filename}\n{result}")
print("-" * 10)
if __name__ == "__main__":
unittest.main()