#!/usr/bin/env python3
#
# Copyright (c) 2023 Xiaomi Corporation
#
# Provenance: this text arrived as a whitespace-collapsed git patch
# (new file, mode 100755, index 00000000..7c2fc05b) for
#   python-api-examples/online-websocket-client-decode-file.py
# and is restored here as plain, runnable Python.

"""
A websocket client for sherpa-onnx-online-websocket-server

Usage:
    ./online-websocket-client-decode-file.py \
      --server-addr localhost \
      --server-port 6006 \
      --seconds-per-message 0.1 \
      --samples-per-message 8000 \
      /path/to/foo.wav

(Note: You have to first start the server before starting the client)

You can find the server at
https://github.com/k2-fsa/sherpa-onnx/blob/master/sherpa-onnx/csrc/online-websocket-server.cc

Note: The server is implemented in C++.

There is also a C++ version of the client. Please see
https://github.com/k2-fsa/sherpa-onnx/blob/master/sherpa-onnx/csrc/online-websocket-client.cc
"""

import argparse
import asyncio
import logging
import sys
import time
import wave

import numpy as np

try:
    import websockets
except ImportError:
    # Do not fail at import time: read_wave() and get_args() are usable
    # without the package. run() reports the problem and exits instead.
    websockets = None


def _require_websockets() -> None:
    """Exit with installation instructions if ``websockets`` is not installed."""
    if websockets is not None:
        return

    print("please run:")
    print("")
    print("  pip install websockets")
    print("")
    print("before you run this script")
    print("")
    sys.exit(-1)


def read_wave(wave_filename: str) -> np.ndarray:
    """Load a wave file and scale its samples by 1/32768.

    Args:
      wave_filename:
        Path to a wave file. Its sampling rate has to be 16000.
        It should be single channel and each sample should be 16-bit.
    Returns:
      Return a 1-D float32 array holding ``int16_sample / 32768``.
    Raises:
      ValueError:
        If the file is not 16 kHz, single channel, 16-bit.
    """
    with wave.open(wave_filename) as f:
        # Explicit checks instead of ``assert`` so they survive ``python -O``.
        if f.getframerate() != 16000:
            raise ValueError(
                f"Expected a sampling rate of 16000. Given: {f.getframerate()}"
            )
        if f.getnchannels() != 1:
            raise ValueError(
                f"Expected a single channel. Given: {f.getnchannels()}"
            )
        if f.getsampwidth() != 2:  # it is in bytes
            raise ValueError(
                f"Expected 2 bytes per sample. Given: {f.getsampwidth()}"
            )
        frames = f.readframes(f.getnframes())

    samples_int16 = np.frombuffer(frames, dtype=np.int16)
    return samples_int16.astype(np.float32) / 32768


def get_args(argv=None):
    """Parse the command line.

    Args:
      argv:
        Optional list of argument strings. ``None`` (the default) makes
        argparse read ``sys.argv[1:]``.
    Returns:
      The ``argparse.Namespace`` with ``server_addr``, ``server_port``,
      ``samples_per_message``, ``seconds_per_message`` and ``sound_file``.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--server-addr",
        type=str,
        default="localhost",
        help="Address of the server",
    )

    parser.add_argument(
        "--server-port",
        type=int,
        default=6006,
        help="Port of the server",
    )

    parser.add_argument(
        "--samples-per-message",
        type=int,
        default=8000,
        help="Number of samples per message",
    )

    parser.add_argument(
        "--seconds-per-message",
        type=float,
        default=0.1,
        help="We will simulate that the duration of two messages is of this value",
    )

    parser.add_argument(
        "sound_file",
        type=str,
        help="The input sound file. Must be wave with a single channel, 16kHz "
        "sampling rate, 16-bit of each sample.",
    )

    return parser.parse_args(argv)


async def receive_results(socket):
    """Log every message from ``socket`` until ``"Done!"`` arrives.

    Args:
      socket:
        Any async-iterable of messages; in this script it is the connection
        object yielded by ``websockets.connect``.
    Returns:
      The last message received before ``"Done!"``, or before the stream
      ended if ``"Done!"`` never arrived. ``""`` if there was none.
    """
    last_message = ""
    async for message in socket:
        if message == "Done!":
            break
        last_message = message
        logging.info(message)

    # Also reached when the peer closes without sending "Done!", so the
    # caller gets the latest partial result rather than None.
    return last_message


async def run(
    server_addr: str,
    server_port: int,
    wave_filename: str,
    samples_per_message: int,
    seconds_per_message: float,
):
    """Send a wave file to the server in chunks and log the recognition result.

    Args:
      server_addr:
        Host name or IP address of the server.
      server_port:
        Port of the server.
      wave_filename:
        Path to the wave file to send; see :func:`read_wave` for its format.
      samples_per_message:
        Number of samples per binary message. Must be positive.
      seconds_per_message:
        Seconds to sleep after sending each message.
    Raises:
      ValueError:
        If ``samples_per_message`` is not positive, or the wave file has an
        unsupported format.
    """
    if samples_per_message <= 0:
        # A non-positive step would never advance through the samples.
        raise ValueError(
            f"samples_per_message must be positive. Given: {samples_per_message}"
        )
    _require_websockets()

    data = read_wave(wave_filename)

    async with websockets.connect(
        f"ws://{server_addr}:{server_port}"
    ) as websocket:  # noqa
        logging.info(f"Sending {wave_filename}")

        # Receive concurrently so results are logged while audio is still
        # being sent.
        receive_task = asyncio.create_task(receive_results(websocket))

        for start in range(0, data.shape[0], samples_per_message):
            chunk = data[start : start + samples_per_message]

            await websocket.send(chunk.tobytes())

            await asyncio.sleep(seconds_per_message)  # in seconds

        # to signal that the client has sent all the data
        await websocket.send("Done")

        decoding_results = await receive_task
        logging.info(f"\nFinal result is:\n{decoding_results}")


async def main():
    """Parse the command line and decode the given file."""
    args = get_args()
    logging.info(vars(args))

    await run(
        server_addr=args.server_addr,
        server_port=args.server_port,
        wave_filename=args.sound_file,
        samples_per_message=args.samples_per_message,
        seconds_per_message=args.seconds_per_message,
    )


if __name__ == "__main__":
    formatter = (
        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"  # noqa
    )
    logging.basicConfig(format=formatter, level=logging.INFO)
    asyncio.run(main())

# The collapsed patch continued on the same line with the header of the next
# file, kept here for reference:
#   diff --git a/python-api-examples/online-websocket-client-microphone.py
#              b/python-api-examples/online-websocket-client-microphone.py
#   new file mode 100755
#   index
#!/usr/bin/env python3
#
# Copyright (c) 2023 Xiaomi Corporation
#
# Provenance: this text arrived as a whitespace-collapsed git patch
# (new file, index 00000000..4d8966c4, "--- /dev/null") for
#   python-api-examples/online-websocket-client-microphone.py
# and is restored here as plain, runnable Python.

"""
A websocket client for sherpa-onnx-online-websocket-server

Usage:
    ./online-websocket-client-microphone.py \
      --server-addr localhost \
      --server-port 6006

(Note: You have to first start the server before starting the client)

You can find the server at
https://github.com/k2-fsa/sherpa-onnx/blob/master/sherpa-onnx/csrc/online-websocket-server.cc

Note: The server is implemented in C++.

There is also a C++ version of the client. Please see
https://github.com/k2-fsa/sherpa-onnx/blob/master/sherpa-onnx/csrc/online-websocket-client.cc
"""

import argparse
import asyncio
import sys  # needed by the sys.exit(-1) calls in the import guards below
import time

import numpy as np

try:
    import sounddevice as sd
except ImportError:
    print("Please install sounddevice first. You can use")
    print()
    print("  pip install sounddevice")
    print()
    print("to install it")
    sys.exit(-1)

try:
    import websockets
except ImportError:
    print("please run:")
    print("")
    print("  pip install websockets")
    print("")
    print("before you run this script")
    print("")
    sys.exit(-1)


def get_args(argv=None):
    """Parse the command line.

    Args:
      argv:
        Optional list of argument strings. ``None`` (the default) makes
        argparse read ``sys.argv[1:]``.
    Returns:
      The ``argparse.Namespace`` with ``server_addr`` and ``server_port``.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--server-addr",
        type=str,
        default="localhost",
        help="Address of the server",
    )

    parser.add_argument(
        "--server-port",
        type=int,
        default=6006,
        help="Port of the server",
    )

    return parser.parse_args(argv)


async def inputstream_generator(channels=1):
    """Generator that yields blocks of input data as NumPy arrays.

    See https://python-sounddevice.readthedocs.io/en/0.4.6/examples.html#creating-an-asyncio-generator-for-audio-blocks

    Args:
      channels:
        Number of input channels to open on the stream.
    Yields:
      ``(indata, status)`` tuples: a copy of the block passed to the stream
      callback and the status passed with it. The stream is opened with
      dtype float32, 16000 Hz and a block size of 0.05 seconds.
    """
    q_in = asyncio.Queue()
    # We are inside a coroutine, so a loop is guaranteed to be running.
    loop = asyncio.get_running_loop()

    def callback(indata, frame_count, time_info, status):
        # Hand the block to the event loop via call_soon_threadsafe; copy it
        # first so the queued data does not alias the callback's buffer.
        loop.call_soon_threadsafe(q_in.put_nowait, (indata.copy(), status))

    devices = sd.query_devices()
    print(devices)
    default_input_device_idx = sd.default.device[0]
    print(f'Use default device: {devices[default_input_device_idx]["name"]}')
    print()
    print("Started! Please speak")

    stream = sd.InputStream(
        callback=callback,
        channels=channels,
        dtype="float32",
        samplerate=16000,
        blocksize=int(0.05 * 16000),  # 0.05 seconds
    )
    with stream:
        # No exit condition: this ends only by cancellation or an exception.
        while True:
            indata, status = await q_in.get()
            yield indata, status


async def receive_results(socket):
    """Print each new non-empty message from ``socket`` until ``"Done!"``.

    Args:
      socket:
        Any async-iterable of messages; in this script it is the connection
        object yielded by ``websockets.connect``.
    Returns:
      The last message received before ``"Done!"``, or before the stream
      ended if ``"Done!"`` never arrived. ``""`` if there was none.
    """
    last_message = ""
    async for message in socket:
        if message == "Done!":
            break

        # Print only when the text changed, so a repeated message is not
        # echoed again.
        if last_message != message:
            last_message = message

            if last_message:
                print(last_message)

    # Also reached when the peer closes without sending "Done!".
    return last_message


async def run(
    server_addr: str,
    server_port: int,
):
    """Stream microphone audio to the server and print what it sends back.

    Args:
      server_addr:
        Host name or IP address of the server.
      server_port:
        Port of the server.
    """
    async with websockets.connect(
        f"ws://{server_addr}:{server_port}"
    ) as websocket:  # noqa
        receive_task = asyncio.create_task(receive_results(websocket))
        print("Started! Please Speak")

        async for indata, status in inputstream_generator():
            if status:
                print(status)
            # Flatten the block to 1-D and make it contiguous before
            # serialising it to raw bytes.
            indata = indata.reshape(-1)
            indata = np.ascontiguousarray(indata)
            await websocket.send(indata.tobytes())

        # inputstream_generator() has no exit condition, so this point is
        # only reached if the generator itself stops.
        decoding_results = await receive_task
        print(f"\nFinal result is:\n{decoding_results}")


async def main():
    """Parse the command line and start streaming."""
    args = get_args()
    print(vars(args))

    await run(
        server_addr=args.server_addr,
        server_port=args.server_port,
    )


if __name__ == "__main__":
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        print("\nCaught Ctrl + C. Exiting")

# The collapsed patch ended with this hunk for a different file, kept here
# for reference:
#
#   diff --git a/sherpa-onnx/python/sherpa_onnx/online_recognizer.py b/sherpa-onnx/python/sherpa_onnx/online_recognizer.py
#   index 92fd53c1..265f0d1e 100644
#   --- a/sherpa-onnx/python/sherpa_onnx/online_recognizer.py
#   +++ b/sherpa-onnx/python/sherpa_onnx/online_recognizer.py
#   @@ -1,3 +1,4 @@
#   +# Copyright (c) 2023 Xiaomi Corporation
#    from pathlib import Path
#    from typing import List