diff --git a/.github/scripts/test-python.sh b/.github/scripts/test-python.sh index 43b3ec37..cd09f785 100755 --- a/.github/scripts/test-python.sh +++ b/.github/scripts/test-python.sh @@ -8,15 +8,20 @@ log() { echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" } +mkdir -p /tmp/icefall-models +dir=/tmp/icefall-models +log "Test streaming transducer models" + +pushd $dir repo_url=https://huggingface.co/csukuangfj/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20 log "Start testing ${repo_url}" -repo=$(basename $repo_url) +repo=$dir/$(basename $repo_url) log "Download pretrained model and test-data from $repo_url" GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url -pushd $repo +cd $repo git lfs pull --include "*.onnx" popd @@ -38,4 +43,88 @@ python3 ./python-api-examples/online-decode-files.py \ $repo/test_wavs/0.wav \ $repo/test_wavs/1.wav \ $repo/test_wavs/2.wav \ - $repo/test_wavs/3.wav + $repo/test_wavs/3.wav \ + $repo/test_wavs/8k.wav + +python3 ./python-api-examples/online-decode-files.py \ + --tokens=$repo/tokens.txt \ + --encoder=$repo/encoder-epoch-99-avg-1.int8.onnx \ + --decoder=$repo/decoder-epoch-99-avg-1.int8.onnx \ + --joiner=$repo/joiner-epoch-99-avg-1.int8.onnx \ + $repo/test_wavs/0.wav \ + $repo/test_wavs/1.wav \ + $repo/test_wavs/2.wav \ + $repo/test_wavs/3.wav \ + $repo/test_wavs/8k.wav + +python3 sherpa-onnx/python/tests/test_online_recognizer.py --verbose + +log "Test non-streaming transducer models" + +pushd $dir +repo_url=https://huggingface.co/csukuangfj/sherpa-onnx-zipformer-en-2023-04-01 + +log "Start testing ${repo_url}" +repo=$dir/$(basename $repo_url) +log "Download pretrained model and test-data from $repo_url" + +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +cd $repo +git lfs pull --include "*.onnx" +popd + +ls -lh $repo + +python3 ./python-api-examples/offline-decode-files.py \ + --tokens=$repo/tokens.txt \ + --encoder=$repo/encoder-epoch-99-avg-1.onnx \ + --decoder=$repo/decoder-epoch-99-avg-1.onnx \ + --joiner=$repo/joiner-epoch-99-avg-1.onnx \ + $repo/test_wavs/0.wav \ + $repo/test_wavs/1.wav \ + $repo/test_wavs/8k.wav + +python3 ./python-api-examples/offline-decode-files.py \ + --tokens=$repo/tokens.txt \ + --encoder=$repo/encoder-epoch-99-avg-1.int8.onnx \ + --decoder=$repo/decoder-epoch-99-avg-1.int8.onnx \ + --joiner=$repo/joiner-epoch-99-avg-1.int8.onnx \ + $repo/test_wavs/0.wav \ + $repo/test_wavs/1.wav \ + $repo/test_wavs/8k.wav + +python3 sherpa-onnx/python/tests/test_offline_recognizer.py --verbose + +log "Test non-streaming paraformer models" + +pushd $dir +repo_url=https://huggingface.co/csukuangfj/sherpa-onnx-paraformer-zh-2023-03-28 + +log "Start testing ${repo_url}" +repo=$dir/$(basename $repo_url) +log "Download pretrained model and test-data from $repo_url" + +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +cd $repo +git lfs pull --include "*.onnx" +popd + +ls -lh $repo + +python3 ./python-api-examples/offline-decode-files.py \ + --tokens=$repo/tokens.txt \ + --paraformer=$repo/model.onnx \ + $repo/test_wavs/0.wav \ + $repo/test_wavs/1.wav \ + $repo/test_wavs/2.wav \ + $repo/test_wavs/8k.wav + +python3 ./python-api-examples/offline-decode-files.py \ + --tokens=$repo/tokens.txt \ + --paraformer=$repo/model.int8.onnx \ + $repo/test_wavs/0.wav \ + $repo/test_wavs/1.wav \ + $repo/test_wavs/2.wav \ + $repo/test_wavs/8k.wav + +python3 sherpa-onnx/python/tests/test_offline_recognizer.py --verbose diff --git a/.gitignore b/.gitignore index 800713d2..dcba6799 100644 --- a/.gitignore +++ b/.gitignore @@ -51,3 +51,4 @@ a.sh run-offline-websocket-client-*.sh run-sherpa-onnx-*.sh sherpa-onnx-zipformer-en-2023-03-30 +sherpa-onnx-zipformer-en-2023-04-01 diff --git a/python-api-examples/offline-decode-files.py b/python-api-examples/offline-decode-files.py old mode 100644 new mode 100755 index ed08c393..e41aa01c --- a/python-api-examples/offline-decode-files.py +++ b/python-api-examples/offline-decode-files.py @@ -46,6 +46,7 @@ from typing import Tuple import numpy as np import sherpa_onnx + def get_args(): parser = argparse.ArgumentParser( formatter_class=argparse.ArgumentDefaultsHelpFormatter @@ -165,6 +166,7 @@ def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]: samples_float32 = samples_float32 / 32768 return samples_float32, f.getframerate() + def main(): args = get_args() assert_file_exists(args.tokens) @@ -183,7 +185,7 @@ def main(): sample_rate=args.sample_rate, feature_dim=args.feature_dim, decoding_method=args.decoding_method, - debug=args.debug + debug=args.debug, ) else: assert_file_exists(args.paraformer) @@ -194,10 +196,9 @@ def main(): sample_rate=args.sample_rate, feature_dim=args.feature_dim, decoding_method=args.decoding_method, - debug=args.debug + debug=args.debug, ) - print("Started!") start_time = time.time() @@ -212,12 +213,8 @@ def main(): s = recognizer.create_stream() s.accept_waveform(sample_rate, samples) - tail_paddings = np.zeros(int(0.2 * sample_rate), dtype=np.float32) - s.accept_waveform(sample_rate, tail_paddings) - streams.append(s) - recognizer.decode_streams(streams) results = [s.result.text for s in streams] end_time = time.time() diff --git a/sherpa-onnx/csrc/features.cc b/sherpa-onnx/csrc/features.cc index bdfb9fd8..7f804684 100644 --- a/sherpa-onnx/csrc/features.cc +++ b/sherpa-onnx/csrc/features.cc @@ -18,8 +18,8 @@ namespace sherpa_onnx { void FeatureExtractorConfig::Register(ParseOptions *po) { po->Register("sample-rate", &sampling_rate, - "Sampling rate of the input waveform. Must match the one " - "expected by the model. Note: You can have a different " + "Sampling rate of the input waveform. " + "Note: You can have a different " "sample rate for the input waveform. We will do resampling " "inside the feature extractor"); diff --git a/sherpa-onnx/csrc/offline-stream.cc b/sherpa-onnx/csrc/offline-stream.cc index 28ed642f..bfb9fb64 100644 --- a/sherpa-onnx/csrc/offline-stream.cc +++ b/sherpa-onnx/csrc/offline-stream.cc @@ -17,8 +17,8 @@ namespace sherpa_onnx { void OfflineFeatureExtractorConfig::Register(ParseOptions *po) { po->Register("sample-rate", &sampling_rate, - "Sampling rate of the input waveform. Must match the one " - "expected by the model. Note: You can have a different " + "Sampling rate of the input waveform. " + "Note: You can have a different " "sample rate for the input waveform. We will do resampling " "inside the feature extractor"); diff --git a/sherpa-onnx/csrc/offline-websocket-server.cc b/sherpa-onnx/csrc/offline-websocket-server.cc index fb7a45ea..eb55413d 100644 --- a/sherpa-onnx/csrc/offline-websocket-server.cc +++ b/sherpa-onnx/csrc/offline-websocket-server.cc @@ -65,6 +65,7 @@ int32_t main(int32_t argc, char *argv[]) { po.Register("port", &port, "The port on which the server will listen."); config.Register(&po); + po.DisableOption("sample-rate"); if (argc == 1) { po.PrintUsage(); diff --git a/sherpa-onnx/python/sherpa_onnx/offline_recognizer.py b/sherpa-onnx/python/sherpa_onnx/offline_recognizer.py index f84ad977..f4371e7a 100644 --- a/sherpa-onnx/python/sherpa_onnx/offline_recognizer.py +++ b/sherpa-onnx/python/sherpa_onnx/offline_recognizer.py @@ -18,20 +18,25 @@ def _assert_file_exists(f: str): class OfflineRecognizer(object): - """A class for offline speech recognition.""" + """A class for offline speech recognition. + + Please refer to the following files for usages + - https://github.com/k2-fsa/sherpa-onnx/blob/master/sherpa-onnx/python/tests/test_offline_recognizer.py + - https://github.com/k2-fsa/sherpa-onnx/blob/master/python-api-examples/offline-decode-files.py + """ @classmethod def from_transducer( - cls, - encoder: str, - decoder: str, - joiner: str, - tokens: str, - num_threads: int, - sample_rate: int = 16000, - feature_dim: int = 80, - decoding_method: str = "greedy_search", - debug: bool = False, + cls, + encoder: str, + decoder: str, + joiner: str, + tokens: str, + num_threads: int, + sample_rate: int = 16000, + feature_dim: int = 80, + decoding_method: str = "greedy_search", + debug: bool = False, ): """ Please refer to @@ -59,7 +64,7 @@ class OfflineRecognizer(object): feature_dim: Dimension of the feature used to train the model. decoding_method: - Valid values are greedy_search, modified_beam_search. + Support only greedy_search for now. debug: True to show debug messages. """ @@ -68,14 +73,12 @@ class OfflineRecognizer(object): transducer=OfflineTransducerModelConfig( encoder_filename=encoder, decoder_filename=decoder, - joiner_filename=joiner - ), - paraformer=OfflineParaformerModelConfig( - model="" + joiner_filename=joiner, ), + paraformer=OfflineParaformerModelConfig(model=""), tokens=tokens, num_threads=num_threads, - debug=debug + debug=debug, ) feat_config = OfflineFeatureExtractorConfig( @@ -93,14 +96,14 @@ class OfflineRecognizer(object): @classmethod def from_paraformer( - cls, - paraformer: str, - tokens: str, - num_threads: int, - sample_rate: int = 16000, - feature_dim: int = 80, - decoding_method: str = "greedy_search", - debug: bool = False, + cls, + paraformer: str, + tokens: str, + num_threads: int, + sample_rate: int = 16000, + feature_dim: int = 80, + decoding_method: str = "greedy_search", + debug: bool = False, ): """ Please refer to @@ -131,16 +134,12 @@ class OfflineRecognizer(object): self = cls.__new__(cls) model_config = OfflineModelConfig( transducer=OfflineTransducerModelConfig( - encoder_filename="", - decoder_filename="", - joiner_filename="" - ), - paraformer=OfflineParaformerModelConfig( - model=paraformer + encoder_filename="", decoder_filename="", joiner_filename="" ), + paraformer=OfflineParaformerModelConfig(model=paraformer), tokens=tokens, num_threads=num_threads, - debug=debug + debug=debug, ) feat_config = OfflineFeatureExtractorConfig( @@ -164,4 +163,3 @@ class OfflineRecognizer(object): def decode_streams(self, ss: List[OfflineStream]): self.recognizer.decode_streams(ss) - diff --git a/sherpa-onnx/python/sherpa_onnx/online_recognizer.py b/sherpa-onnx/python/sherpa_onnx/online_recognizer.py index ce1a6afa..bf8ca089 100644 --- a/sherpa-onnx/python/sherpa_onnx/online_recognizer.py +++ b/sherpa-onnx/python/sherpa_onnx/online_recognizer.py @@ -17,7 +17,12 @@ def _assert_file_exists(f: str): class OnlineRecognizer(object): - """A class for streaming speech recognition.""" + """A class for streaming speech recognition. + + Please refer to the following files for usages + - https://github.com/k2-fsa/sherpa-onnx/blob/master/sherpa-onnx/python/tests/test_online_recognizer.py + - https://github.com/k2-fsa/sherpa-onnx/blob/master/python-api-examples/online-decode-files.py + """ def __init__( self, diff --git a/sherpa-onnx/python/tests/CMakeLists.txt b/sherpa-onnx/python/tests/CMakeLists.txt index c53a09f1..ff9b8c9e 100644 --- a/sherpa-onnx/python/tests/CMakeLists.txt +++ b/sherpa-onnx/python/tests/CMakeLists.txt @@ -18,6 +18,8 @@ endfunction() # please sort the files in alphabetic order set(py_test_files test_feature_extractor_config.py + test_offline_recognizer.py + test_online_recognizer.py test_online_transducer_model_config.py ) diff --git a/sherpa-onnx/python/tests/test_offline_recognizer.py b/sherpa-onnx/python/tests/test_offline_recognizer.py new file mode 100755 index 00000000..5f9924d9 --- /dev/null +++ b/sherpa-onnx/python/tests/test_offline_recognizer.py @@ -0,0 +1,201 @@ +# sherpa-onnx/python/tests/test_offline_recognizer.py +# +# Copyright (c) 2023 Xiaomi Corporation +# +# To run this single test, use +# +# ctest --verbose -R test_offline_recognizer_py + +import unittest +import wave +from pathlib import Path +from typing import Tuple + +import numpy as np +import sherpa_onnx + +d = "/tmp/icefall-models" +# Please refer to +# https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-transducer/index.html +# and +# https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-paraformer/index.html +# to download pre-trained models for testing + + +def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]: + """ + Args: + wave_filename: + Path to a wave file. It should be single channel and each sample should + be 16-bit. Its sample rate does not need to be 16kHz. + Returns: + Return a tuple containing: + - A 1-D array of dtype np.float32 containing the samples, which are + normalized to the range [-1, 1]. + - sample rate of the wave file + """ + + with wave.open(wave_filename) as f: + assert f.getnchannels() == 1, f.getnchannels() + assert f.getsampwidth() == 2, f.getsampwidth() # it is in bytes + num_samples = f.getnframes() + samples = f.readframes(num_samples) + samples_int16 = np.frombuffer(samples, dtype=np.int16) + samples_float32 = samples_int16.astype(np.float32) + + samples_float32 = samples_float32 / 32768 + return samples_float32, f.getframerate() + + +class TestOfflineRecognizer(unittest.TestCase): + def test_transducer_single_file(self): + for use_int8 in [True, False]: + if use_int8: + encoder = f"{d}/sherpa-onnx-zipformer-en-2023-04-01/encoder-epoch-99-avg-1.int8.onnx" + decoder = f"{d}/sherpa-onnx-zipformer-en-2023-04-01/decoder-epoch-99-avg-1.int8.onnx" + joiner = f"{d}/sherpa-onnx-zipformer-en-2023-04-01/joiner-epoch-99-avg-1.int8.onnx" + else: + encoder = f"{d}/sherpa-onnx-zipformer-en-2023-04-01/encoder-epoch-99-avg-1.onnx" + decoder = f"{d}/sherpa-onnx-zipformer-en-2023-04-01/decoder-epoch-99-avg-1.onnx" + joiner = f"{d}/sherpa-onnx-zipformer-en-2023-04-01/joiner-epoch-99-avg-1.onnx" + + tokens = f"{d}/sherpa-onnx-zipformer-en-2023-04-01/tokens.txt" + wave0 = f"{d}/sherpa-onnx-zipformer-en-2023-04-01/test_wavs/0.wav" + + if not Path(encoder).is_file(): + print("skipping test_transducer_single_file()") + return + + recognizer = sherpa_onnx.OfflineRecognizer.from_transducer( + encoder=encoder, + decoder=decoder, + joiner=joiner, + tokens=tokens, + num_threads=1, + ) + + s = recognizer.create_stream() + samples, sample_rate = read_wave(wave0) + s.accept_waveform(sample_rate, samples) + recognizer.decode_stream(s) + print(s.result.text) + + def test_transducer_multiple_files(self): + for use_int8 in [True, False]: + if use_int8: + encoder = f"{d}/sherpa-onnx-zipformer-en-2023-04-01/encoder-epoch-99-avg-1.int8.onnx" + decoder = f"{d}/sherpa-onnx-zipformer-en-2023-04-01/decoder-epoch-99-avg-1.int8.onnx" + joiner = f"{d}/sherpa-onnx-zipformer-en-2023-04-01/joiner-epoch-99-avg-1.int8.onnx" + else: + encoder = f"{d}/sherpa-onnx-zipformer-en-2023-04-01/encoder-epoch-99-avg-1.onnx" + decoder = f"{d}/sherpa-onnx-zipformer-en-2023-04-01/decoder-epoch-99-avg-1.onnx" + joiner = f"{d}/sherpa-onnx-zipformer-en-2023-04-01/joiner-epoch-99-avg-1.onnx" + + tokens = f"{d}/sherpa-onnx-zipformer-en-2023-04-01/tokens.txt" + wave0 = f"{d}/sherpa-onnx-zipformer-en-2023-04-01/test_wavs/0.wav" + wave1 = f"{d}/sherpa-onnx-zipformer-en-2023-04-01/test_wavs/1.wav" + wave2 = f"{d}/sherpa-onnx-zipformer-en-2023-04-01/test_wavs/8k.wav" + + if not Path(encoder).is_file(): + print("skipping test_transducer_multiple_files()") + return + + recognizer = sherpa_onnx.OfflineRecognizer.from_transducer( + encoder=encoder, + decoder=decoder, + joiner=joiner, + tokens=tokens, + num_threads=1, + ) + + s0 = recognizer.create_stream() + samples0, sample_rate0 = read_wave(wave0) + s0.accept_waveform(sample_rate0, samples0) + + s1 = recognizer.create_stream() + samples1, sample_rate1 = read_wave(wave1) + s1.accept_waveform(sample_rate1, samples1) + + s2 = recognizer.create_stream() + samples2, sample_rate2 = read_wave(wave2) + s2.accept_waveform(sample_rate2, samples2) + + recognizer.decode_streams([s0, s1, s2]) + print(s0.result.text) + print(s1.result.text) + print(s2.result.text) + + def test_paraformer_single_file(self): + for use_int8 in [True, False]: + if use_int8: + model = f"{d}/sherpa-onnx-paraformer-zh-2023-03-28/model.int8.onnx" + else: + model = f"{d}/sherpa-onnx-paraformer-zh-2023-03-28/model.onnx" + + tokens = f"{d}/sherpa-onnx-paraformer-zh-2023-03-28/tokens.txt" + wave0 = f"{d}/sherpa-onnx-paraformer-zh-2023-03-28/test_wavs/0.wav" + + if not Path(model).is_file(): + print("skipping test_paraformer_single_file()") + return + + recognizer = sherpa_onnx.OfflineRecognizer.from_paraformer( + paraformer=model, + tokens=tokens, + num_threads=1, + ) + + s = recognizer.create_stream() + samples, sample_rate = read_wave(wave0) + s.accept_waveform(sample_rate, samples) + recognizer.decode_stream(s) + print(s.result.text) + + def test_paraformer_multiple_files(self): + for use_int8 in [True, False]: + if use_int8: + model = f"{d}/sherpa-onnx-paraformer-zh-2023-03-28/model.int8.onnx" + else: + model = f"{d}/sherpa-onnx-paraformer-zh-2023-03-28/model.onnx" + + tokens = f"{d}/sherpa-onnx-paraformer-zh-2023-03-28/tokens.txt" + wave0 = f"{d}/sherpa-onnx-paraformer-zh-2023-03-28/test_wavs/0.wav" + wave1 = f"{d}/sherpa-onnx-paraformer-zh-2023-03-28/test_wavs/1.wav" + wave2 = f"{d}/sherpa-onnx-paraformer-zh-2023-03-28/test_wavs/2.wav" + wave3 = f"{d}/sherpa-onnx-paraformer-zh-2023-03-28/test_wavs/8k.wav" + + if not Path(model).is_file(): + print("skipping test_paraformer_multiple_files()") + return + + recognizer = sherpa_onnx.OfflineRecognizer.from_paraformer( + paraformer=model, + tokens=tokens, + num_threads=1, + ) + + s0 = recognizer.create_stream() + samples0, sample_rate0 = read_wave(wave0) + s0.accept_waveform(sample_rate0, samples0) + + s1 = recognizer.create_stream() + samples1, sample_rate1 = read_wave(wave1) + s1.accept_waveform(sample_rate1, samples1) + + s2 = recognizer.create_stream() + samples2, sample_rate2 = read_wave(wave2) + s2.accept_waveform(sample_rate2, samples2) + + s3 = recognizer.create_stream() + samples3, sample_rate3 = read_wave(wave3) + s3.accept_waveform(sample_rate3, samples3) + + recognizer.decode_streams([s0, s1, s2, s3]) + print(s0.result.text) + print(s1.result.text) + print(s2.result.text) + print(s3.result.text) + + +if __name__ == "__main__": + unittest.main() diff --git a/sherpa-onnx/python/tests/test_online_recognizer.py b/sherpa-onnx/python/tests/test_online_recognizer.py new file mode 100755 index 00000000..157cfd8d --- /dev/null +++ b/sherpa-onnx/python/tests/test_online_recognizer.py @@ -0,0 +1,146 @@ +# sherpa-onnx/python/tests/test_online_recognizer.py +# +# Copyright (c) 2023 Xiaomi Corporation +# +# To run this single test, use +# +# ctest --verbose -R test_online_recognizer_py + +import unittest +import wave +from pathlib import Path +from typing import Tuple + +import numpy as np +import sherpa_onnx + +d = "/tmp/icefall-models" +# Please refer to +# https://k2-fsa.github.io/sherpa/onnx/pretrained_models/online-transducer/index.html +# to download pre-trained models for testing + + +def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]: + """ + Args: + wave_filename: + Path to a wave file. It should be single channel and each sample should + be 16-bit. Its sample rate does not need to be 16kHz. + Returns: + Return a tuple containing: + - A 1-D array of dtype np.float32 containing the samples, which are + normalized to the range [-1, 1]. + - sample rate of the wave file + """ + + with wave.open(wave_filename) as f: + assert f.getnchannels() == 1, f.getnchannels() + assert f.getsampwidth() == 2, f.getsampwidth() # it is in bytes + num_samples = f.getnframes() + samples = f.readframes(num_samples) + samples_int16 = np.frombuffer(samples, dtype=np.int16) + samples_float32 = samples_int16.astype(np.float32) + + samples_float32 = samples_float32 / 32768 + return samples_float32, f.getframerate() + + +class TestOnlineRecognizer(unittest.TestCase): + def test_transducer_single_file(self): + for use_int8 in [True, False]: + if use_int8: + encoder = f"{d}/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/encoder-epoch-99-avg-1.int8.onnx" + decoder = f"{d}/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/decoder-epoch-99-avg-1.int8.onnx" + joiner = f"{d}/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/joiner-epoch-99-avg-1.int8.onnx" + else: + encoder = f"{d}/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/encoder-epoch-99-avg-1.onnx" + decoder = f"{d}/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/decoder-epoch-99-avg-1.onnx" + joiner = f"{d}/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/joiner-epoch-99-avg-1.onnx" + + tokens = f"{d}/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/tokens.txt" + wave0 = f"{d}/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/test_wavs/0.wav" + + if not Path(encoder).is_file(): + print("skipping test_transducer_single_file()") + return + + for decoding_method in ["greedy_search", "modified_beam_search"]: + recognizer = sherpa_onnx.OnlineRecognizer( + encoder=encoder, + decoder=decoder, + joiner=joiner, + tokens=tokens, + num_threads=1, + decoding_method=decoding_method, + ) + s = recognizer.create_stream() + samples, sample_rate = read_wave(wave0) + s.accept_waveform(sample_rate, samples) + + tail_paddings = np.zeros(int(0.2 * sample_rate), dtype=np.float32) + s.accept_waveform(sample_rate, tail_paddings) + + s.input_finished() + while recognizer.is_ready(s): + recognizer.decode_stream(s) + print(recognizer.get_result(s)) + + def test_transducer_multiple_files(self): + for use_int8 in [True, False]: + if use_int8: + encoder = f"{d}/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/encoder-epoch-99-avg-1.int8.onnx" + decoder = f"{d}/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/decoder-epoch-99-avg-1.int8.onnx" + joiner = f"{d}/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/joiner-epoch-99-avg-1.int8.onnx" + else: + encoder = f"{d}/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/encoder-epoch-99-avg-1.onnx" + decoder = f"{d}/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/decoder-epoch-99-avg-1.onnx" + joiner = f"{d}/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/joiner-epoch-99-avg-1.onnx" + + tokens = f"{d}/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/tokens.txt" + wave0 = f"{d}/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/test_wavs/0.wav" + wave1 = f"{d}/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/test_wavs/1.wav" + wave2 = f"{d}/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/test_wavs/2.wav" + wave3 = f"{d}/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/test_wavs/3.wav" + wave4 = f"{d}/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/test_wavs/8k.wav" + + if not Path(encoder).is_file(): + print("skipping test_transducer_multiple_files()") + return + + for decoding_method in ["greedy_search", "modified_beam_search"]: + recognizer = sherpa_onnx.OnlineRecognizer( + encoder=encoder, + decoder=decoder, + joiner=joiner, + tokens=tokens, + num_threads=1, + decoding_method=decoding_method, + ) + streams = [] + waves = [wave0, wave1, wave2, wave3, wave4] + for wave in waves: + s = recognizer.create_stream() + samples, sample_rate = read_wave(wave) + s.accept_waveform(sample_rate, samples) + + tail_paddings = np.zeros(int(0.2 * sample_rate), dtype=np.float32) + s.accept_waveform(sample_rate, tail_paddings) + s.input_finished() + streams.append(s) + + while True: + ready_list = [] + for s in streams: + if recognizer.is_ready(s): + ready_list.append(s) + if len(ready_list) == 0: + break + recognizer.decode_streams(ready_list) + results = [recognizer.get_result(s) for s in streams] + for wave_filename, result in zip(waves, results): + print(f"{wave_filename}\n{result}") + print("-" * 10) + + +if __name__ == "__main__": + unittest.main()