diff --git a/.flake8 b/.flake8 index 87510d70..4dace19e 100644 --- a/.flake8 +++ b/.flake8 @@ -1,7 +1,7 @@ [flake8] show-source=true statistics=true -max-line-length = 80 +max-line-length = 120 exclude = .git, diff --git a/.github/scripts/test-python.sh b/.github/scripts/test-python.sh index ca903f43..43b3ec37 100755 --- a/.github/scripts/test-python.sh +++ b/.github/scripts/test-python.sh @@ -30,9 +30,12 @@ ls -lh ls -lh $repo -python3 ./python-api-examples/decode-file.py \ +python3 ./python-api-examples/online-decode-files.py \ --tokens=$repo/tokens.txt \ --encoder=$repo/encoder-epoch-99-avg-1.onnx \ --decoder=$repo/decoder-epoch-99-avg-1.onnx \ --joiner=$repo/joiner-epoch-99-avg-1.onnx \ - --wave-filename=$repo/test_wavs/4.wav + $repo/test_wavs/0.wav \ + $repo/test_wavs/1.wav \ + $repo/test_wavs/2.wav \ + $repo/test_wavs/3.wav diff --git a/.gitignore b/.gitignore index 41d1a0a1..8704f21d 100644 --- a/.gitignore +++ b/.gitignore @@ -45,3 +45,4 @@ paraformer-onnxruntime-python-example run-sherpa-onnx-offline-paraformer.sh run-sherpa-onnx-offline-transducer.sh sherpa-onnx-paraformer-zh-2023-03-28 +run-offline-websocket-server-paraformer.sh diff --git a/python-api-examples/offline-websocket-client-decode-files-paralell.py b/python-api-examples/offline-websocket-client-decode-files-paralell.py new file mode 100755 index 00000000..d1d691a2 --- /dev/null +++ b/python-api-examples/offline-websocket-client-decode-files-paralell.py @@ -0,0 +1,158 @@ +#!/usr/bin/env python3 +# +# Copyright (c) 2023 Xiaomi Corporation + +""" +A websocket client for sherpa-onnx-offline-websocket-server + +This file shows how to transcribe multiple +files in parallel. We create a separate connection for transcribing each file. 
+ +Usage: + ./offline-websocket-client-decode-files-parallel.py \ + --server-addr localhost \ + --server-port 6006 \ + /path/to/foo.wav \ + /path/to/bar.wav \ + /path/to/16kHz.wav \ + /path/to/8kHz.wav + +(Note: You have to first start the server before starting the client) + +You can find the server at +https://github.com/k2-fsa/sherpa-onnx/blob/master/sherpa-onnx/csrc/offline-websocket-server.cc + +Note: The server is implemented in C++. +""" + +import argparse +import asyncio +import logging +import wave +from typing import Tuple + +try: + import websockets +except ImportError: + print("please run:") + print("") + print(" pip install websockets") + print("") + print("before you run this script") + print("") + +import numpy as np + + +def get_args(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--server-addr", + type=str, + default="localhost", + help="Address of the server", + ) + + parser.add_argument( + "--server-port", + type=int, + default=6006, + help="Port of the server", + ) + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to decode. Each file must be of WAVE" + "format with a single channel, and each sample has 16-bit, " + "i.e., int16_t. " + "The sample rate of the file can be arbitrary and does not need to " + "be 16 kHz", + ) + + return parser.parse_args() + + +def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]: + """ + Args: + wave_filename: + Path to a wave file. It should be single channel and each sample should + be 16-bit. Its sample rate does not need to be 16kHz. + Returns: + Return a tuple containing: + - A 1-D array of dtype np.float32 containing the samples, which are + normalized to the range [-1, 1]. 
+ - sample rate of the wave file + """ + + with wave.open(wave_filename) as f: + assert f.getnchannels() == 1, f.getnchannels() + assert f.getsampwidth() == 2, f.getsampwidth() # it is in bytes + num_samples = f.getnframes() + samples = f.readframes(num_samples) + samples_int16 = np.frombuffer(samples, dtype=np.int16) + samples_float32 = samples_int16.astype(np.float32) + + samples_float32 = samples_float32 / 32768 + return samples_float32, f.getframerate() + + +async def run( + server_addr: str, + server_port: int, + wave_filename: str, +): + async with websockets.connect( + f"ws://{server_addr}:{server_port}" + ) as websocket: # noqa + logging.info(f"Sending {wave_filename}") + samples, sample_rate = read_wave(wave_filename) + assert isinstance(sample_rate, int) + assert samples.dtype == np.float32, samples.dtype + assert samples.ndim == 1, samples.dim + buf = sample_rate.to_bytes(4, byteorder="little") # 4 bytes + buf += (samples.size * 4).to_bytes(4, byteorder="little") + buf += samples.tobytes() + + await websocket.send(buf) + + decoding_results = await websocket.recv() + logging.info(f"{wave_filename}\n{decoding_results}") + + # to signal that the client has sent all the data + await websocket.send("Done") + + +async def main(): + args = get_args() + logging.info(vars(args)) + + server_addr = args.server_addr + server_port = args.server_port + sound_files = args.sound_files + + all_tasks = [] + for wave_filename in sound_files: + task = asyncio.create_task( + run( + server_addr=server_addr, + server_port=server_port, + wave_filename=wave_filename, + ) + ) + all_tasks.append(task) + + await asyncio.gather(*all_tasks) + + +if __name__ == "__main__": + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" # noqa + ) + logging.basicConfig(format=formatter, level=logging.INFO) + asyncio.run(main()) diff --git a/python-api-examples/offline-websocket-client-decode-files-sequential.py 
b/python-api-examples/offline-websocket-client-decode-files-sequential.py new file mode 100755 index 00000000..935226e4 --- /dev/null +++ b/python-api-examples/offline-websocket-client-decode-files-sequential.py @@ -0,0 +1,152 @@ +#!/usr/bin/env python3 +# +# Copyright (c) 2023 Xiaomi Corporation + +""" +A websocket client for sherpa-onnx-offline-websocket-server + +This file shows how to use a single connection to transcribe multiple +files sequentially. + +Usage: + ./offline-websocket-client-decode-files-sequential.py \ + --server-addr localhost \ + --server-port 6006 \ + /path/to/foo.wav \ + /path/to/bar.wav \ + /path/to/16kHz.wav \ + /path/to/8kHz.wav + +(Note: You have to first start the server before starting the client) + +You can find the server at +https://github.com/k2-fsa/sherpa-onnx/blob/master/sherpa-onnx/csrc/offline-websocket-server.cc + +Note: The server is implemented in C++. +""" + +import argparse +import asyncio +import logging +import wave +from typing import List, Tuple + +try: + import websockets +except ImportError: + print("please run:") + print("") + print(" pip install websockets") + print("") + print("before you run this script") + print("") + +import numpy as np + + +def get_args(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--server-addr", + type=str, + default="localhost", + help="Address of the server", + ) + + parser.add_argument( + "--server-port", + type=int, + default=6006, + help="Port of the server", + ) + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to decode. Each file must be of WAVE" + "format with a single channel, and each sample has 16-bit, " + "i.e., int16_t. 
" + "The sample rate of the file can be arbitrary and does not need to " + "be 16 kHz", + ) + + return parser.parse_args() + + +def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]: + """ + Args: + wave_filename: + Path to a wave file. It should be single channel and each sample should + be 16-bit. Its sample rate does not need to be 16kHz. + Returns: + Return a tuple containing: + - A 1-D array of dtype np.float32 containing the samples, which are + normalized to the range [-1, 1]. + - sample rate of the wave file + """ + + with wave.open(wave_filename) as f: + assert f.getnchannels() == 1, f.getnchannels() + assert f.getsampwidth() == 2, f.getsampwidth() # it is in bytes + num_samples = f.getnframes() + samples = f.readframes(num_samples) + samples_int16 = np.frombuffer(samples, dtype=np.int16) + samples_float32 = samples_int16.astype(np.float32) + + samples_float32 = samples_float32 / 32768 + return samples_float32, f.getframerate() + + +async def run( + server_addr: str, + server_port: int, + sound_files: List[str], +): + async with websockets.connect( + f"ws://{server_addr}:{server_port}" + ) as websocket: # noqa + for wave_filename in sound_files: + logging.info(f"Sending {wave_filename}") + samples, sample_rate = read_wave(wave_filename) + assert isinstance(sample_rate, int) + assert samples.dtype == np.float32, samples.dtype + assert samples.ndim == 1, samples.dim + buf = sample_rate.to_bytes(4, byteorder="little") # 4 bytes + buf += (samples.size * 4).to_bytes(4, byteorder="little") + buf += samples.tobytes() + + await websocket.send(buf) + + decoding_results = await websocket.recv() + print(decoding_results) + + # to signal that the client has sent all the data + await websocket.send("Done") + + +async def main(): + args = get_args() + logging.info(vars(args)) + + server_addr = args.server_addr + server_port = args.server_port + sound_files = args.sound_files + + await run( + server_addr=server_addr, + server_port=server_port, + 
sound_files=sound_files, + ) + + +if __name__ == "__main__": + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" # noqa + ) + logging.basicConfig(format=formatter, level=logging.INFO) + asyncio.run(main()) diff --git a/python-api-examples/decode-file.py b/python-api-examples/online-decode-files.py similarity index 54% rename from python-api-examples/decode-file.py rename to python-api-examples/online-decode-files.py index f4bdd327..a61ed300 100755 --- a/python-api-examples/decode-file.py +++ b/python-api-examples/online-decode-files.py @@ -1,8 +1,15 @@ #!/usr/bin/env python3 """ -This file demonstrates how to use sherpa-onnx Python API to recognize -a single file. +This file demonstrates how to use sherpa-onnx Python API to transcribe +file(s) with a streaming model. + +Usage: + ./online-decode-files.py \ + /path/to/foo.wav \ + /path/to/bar.wav \ + /path/to/16kHz.wav \ + /path/to/8kHz.wav Please refer to https://k2-fsa.github.io/sherpa/onnx/index.html @@ -13,17 +20,12 @@ import argparse import time import wave from pathlib import Path +from typing import Tuple import numpy as np import sherpa_onnx -def assert_file_exists(filename: str): - assert Path( - filename - ).is_file(), f"{filename} does not exist!\nPlease refer to https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html to download it" - - def get_args(): parser = argparse.ArgumentParser( formatter_class=argparse.ArgumentDefaultsHelpFormatter @@ -68,26 +70,58 @@ def get_args(): ) parser.add_argument( - "--wave-filename", + "sound_files", type=str, - help="""Path to the wave filename. - Should have a single channel with 16-bit samples. - It does not need to be 16kHz. It can have any sampling rate. - """, + nargs="+", + help="The input sound file(s) to decode. Each file must be of WAVE" + "format with a single channel, and each sample has 16-bit, " + "i.e., int16_t. 
" + "The sample rate of the file can be arbitrary and does not need to " + "be 16 kHz", ) return parser.parse_args() +def assert_file_exists(filename: str): + assert Path(filename).is_file(), ( + f"{filename} does not exist!\n" + "Please refer to " + "https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html to download it" + ) + + +def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]: + """ + Args: + wave_filename: + Path to a wave file. It should be single channel and each sample should + be 16-bit. Its sample rate does not need to be 16kHz. + Returns: + Return a tuple containing: + - A 1-D array of dtype np.float32 containing the samples, which are + normalized to the range [-1, 1]. + - sample rate of the wave file + """ + + with wave.open(wave_filename) as f: + assert f.getnchannels() == 1, f.getnchannels() + assert f.getsampwidth() == 2, f.getsampwidth() # it is in bytes + num_samples = f.getnframes() + samples = f.readframes(num_samples) + samples_int16 = np.frombuffer(samples, dtype=np.int16) + samples_float32 = samples_int16.astype(np.float32) + + samples_float32 = samples_float32 / 32768 + return samples_float32, f.getframerate() + + def main(): args = get_args() assert_file_exists(args.encoder) assert_file_exists(args.decoder) assert_file_exists(args.joiner) assert_file_exists(args.tokens) - if not Path(args.wave_filename).is_file(): - print(f"{args.wave_filename} does not exist!") - return recognizer = sherpa_onnx.OnlineRecognizer( tokens=args.tokens, @@ -99,42 +133,44 @@ def main(): feature_dim=80, decoding_method=args.decoding_method, ) - with wave.open(args.wave_filename) as f: - # If the wave file has a different sampling rate from the one - # expected by the model (16 kHz in our case), we will do - # resampling inside sherpa-onnx - wave_file_sample_rate = f.getframerate() - assert f.getnchannels() == 1, f.getnchannels() - assert f.getsampwidth() == 2, f.getsampwidth() # it is in bytes - num_samples = f.getnframes() - samples = 
f.readframes(num_samples) - samples_int16 = np.frombuffer(samples, dtype=np.int16) - samples_float32 = samples_int16.astype(np.float32) - - samples_float32 = samples_float32 / 32768 - - duration = len(samples_float32) / wave_file_sample_rate - - start_time = time.time() print("Started!") + start_time = time.time() - stream = recognizer.create_stream() + streams = [] + total_duration = 0 + for wave_filename in args.sound_files: + assert_file_exists(wave_filename) + samples, sample_rate = read_wave(wave_filename) + duration = len(samples) / sample_rate + total_duration += duration - stream.accept_waveform(wave_file_sample_rate, samples_float32) + s = recognizer.create_stream() + s.accept_waveform(sample_rate, samples) - tail_paddings = np.zeros(int(0.2 * wave_file_sample_rate), dtype=np.float32) - stream.accept_waveform(wave_file_sample_rate, tail_paddings) + tail_paddings = np.zeros(int(0.2 * sample_rate), dtype=np.float32) + s.accept_waveform(sample_rate, tail_paddings) - stream.input_finished() + s.input_finished() - while recognizer.is_ready(stream): - recognizer.decode_stream(stream) + streams.append(s) - print(recognizer.get_result(stream)) - - print("Done!") + while True: + ready_list = [] + for s in streams: + if recognizer.is_ready(s): + ready_list.append(s) + if len(ready_list) == 0: + break + recognizer.decode_streams(ready_list) + results = [recognizer.get_result(s) for s in streams] end_time = time.time() + print("Done!") + + for wave_filename, result in zip(args.sound_files, results): + print(f"{wave_filename}\n{result}") + print("-" * 10) + elapsed_seconds = end_time - start_time rtf = elapsed_seconds / duration print(f"num_threads: {args.num_threads}") diff --git a/python-api-examples/online-websocket-client-decode-file.py b/python-api-examples/online-websocket-client-decode-file.py index 7c2fc05b..b03b9928 100755 --- a/python-api-examples/online-websocket-client-decode-file.py +++ b/python-api-examples/online-websocket-client-decode-file.py @@ -27,7 
+27,6 @@ https://github.com/k2-fsa/sherpa-onnx/blob/master/sherpa-onnx/csrc/online-websoc import argparse import asyncio import logging -import time import wave try: diff --git a/python-api-examples/online-websocket-client-microphone.py b/python-api-examples/online-websocket-client-microphone.py old mode 100644 new mode 100755 index 8dfc0c50..ab3d5733 --- a/python-api-examples/online-websocket-client-microphone.py +++ b/python-api-examples/online-websocket-client-microphone.py @@ -24,13 +24,12 @@ https://github.com/k2-fsa/sherpa-onnx/blob/master/sherpa-onnx/csrc/online-websoc import argparse import asyncio import sys -import time import numpy as np try: import sounddevice as sd -except ImportError as e: +except ImportError: print("Please install sounddevice first. You can use") print() print(" pip install sounddevice") @@ -134,7 +133,7 @@ async def run( await websocket.send(indata.tobytes()) decoding_results = await receive_task - print("\nFinal result is:\n{decoding_results}") + print(f"\nFinal result is:\n{decoding_results}") async def main(): diff --git a/python-api-examples/speech-recognition-from-microphone-with-endpoint-detection.py b/python-api-examples/speech-recognition-from-microphone-with-endpoint-detection.py index fbc71e7f..4c5bd633 100755 --- a/python-api-examples/speech-recognition-from-microphone-with-endpoint-detection.py +++ b/python-api-examples/speech-recognition-from-microphone-with-endpoint-detection.py @@ -13,7 +13,7 @@ from pathlib import Path try: import sounddevice as sd -except ImportError as e: +except ImportError: print("Please install sounddevice first. 
You can use") print() print(" pip install sounddevice") @@ -25,9 +25,11 @@ import sherpa_onnx def assert_file_exists(filename: str): - assert Path( - filename - ).is_file(), f"{filename} does not exist!\nPlease refer to https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html to download it" + assert Path(filename).is_file(), ( + f"{filename} does not exist!\n" + "Please refer to " + "https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html to download it" + ) def get_args(): diff --git a/python-api-examples/speech-recognition-from-microphone.py b/python-api-examples/speech-recognition-from-microphone.py index 6dc608d1..fe2b0167 100755 --- a/python-api-examples/speech-recognition-from-microphone.py +++ b/python-api-examples/speech-recognition-from-microphone.py @@ -12,7 +12,7 @@ from pathlib import Path try: import sounddevice as sd -except ImportError as e: +except ImportError: print("Please install sounddevice first. You can use") print() print(" pip install sounddevice") @@ -24,9 +24,11 @@ import sherpa_onnx def assert_file_exists(filename: str): - assert Path( - filename - ).is_file(), f"{filename} does not exist!\nPlease refer to https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html to download it" + assert Path(filename).is_file(), ( + f"{filename} does not exist!\n" + "Please refer to " + "https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html to download it" + ) def get_args(): diff --git a/sherpa-onnx/csrc/CMakeLists.txt b/sherpa-onnx/csrc/CMakeLists.txt index 30ec8016..4f594c14 100644 --- a/sherpa-onnx/csrc/CMakeLists.txt +++ b/sherpa-onnx/csrc/CMakeLists.txt @@ -128,7 +128,6 @@ if(SHERPA_ONNX_ENABLE_WEBSOCKET) ) target_link_libraries(sherpa-onnx-online-websocket-server sherpa-onnx-core) - add_executable(sherpa-onnx-online-websocket-client online-websocket-client.cc ) @@ -142,6 +141,17 @@ if(SHERPA_ONNX_ENABLE_WEBSOCKET) target_compile_options(sherpa-onnx-online-websocket-client PRIVATE 
-Wno-deprecated-declarations) endif() + # For offline websocket + add_executable(sherpa-onnx-offline-websocket-server + offline-websocket-server-impl.cc + offline-websocket-server.cc + ) + target_link_libraries(sherpa-onnx-offline-websocket-server sherpa-onnx-core) + + if(NOT WIN32) + target_link_libraries(sherpa-onnx-offline-websocket-server -pthread) + target_compile_options(sherpa-onnx-offline-websocket-server PRIVATE -Wno-deprecated-declarations) + endif() endif() diff --git a/sherpa-onnx/csrc/offline-websocket-server-impl.cc b/sherpa-onnx/csrc/offline-websocket-server-impl.cc new file mode 100644 index 00000000..88130316 --- /dev/null +++ b/sherpa-onnx/csrc/offline-websocket-server-impl.cc @@ -0,0 +1,285 @@ +// sherpa-onnx/csrc/offline-websocket-server-impl.cc +// +// Copyright (c) 2022-2023 Xiaomi Corporation + +#include "sherpa-onnx/csrc/offline-websocket-server-impl.h" + +#include + +#include "sherpa-onnx/csrc/macros.h" + +namespace sherpa_onnx { + +void OfflineWebsocketDecoderConfig::Register(ParseOptions *po) { + recognizer_config.Register(po); + + po->Register("max-batch-size", &max_batch_size, + "Max batch size for decoding."); + + po->Register( + "max-utterance-length", &max_utterance_length, + "Max utterance length in seconds. If we receive an utterance " + "longer than this value, we will reject the connection. " + "If you have enough memory, you can select a large value for it."); +} + +void OfflineWebsocketDecoderConfig::Validate() const { + if (!recognizer_config.Validate()) { + SHERPA_ONNX_LOGE("Error in recongizer config"); + exit(-1); + } + + if (max_batch_size <= 0) { + SHERPA_ONNX_LOGE("Expect --max-batch-size > 0. Given: %d", max_batch_size); + exit(-1); + } + + if (max_utterance_length <= 0) { + SHERPA_ONNX_LOGE("Expect --max-utterance-length > 0. 
Given: %f", + max_utterance_length); + exit(-1); + } +} + +OfflineWebsocketDecoder::OfflineWebsocketDecoder(OfflineWebsocketServer *server) + : config_(server->GetConfig().decoder_config), + server_(server), + recognizer_(config_.recognizer_config) {} + +void OfflineWebsocketDecoder::Push(connection_hdl hdl, ConnectionDataPtr d) { + std::lock_guard lock(mutex_); + streams_.push_back({hdl, d}); +} + +void OfflineWebsocketDecoder::Decode() { + std::unique_lock lock(mutex_); + if (streams_.empty()) { + return; + } + + int32_t size = + std::min(static_cast(streams_.size()), config_.max_batch_size); + SHERPA_ONNX_LOGE("size: %d", size); + + // We first lock the mutex for streams_, take items from it, and then + // unlock the mutex; in doing so we don't need to lock the mutex to + // access hdl and connection_data later. + std::vector handles(size); + + // Store connection_data here to prevent the data from being freed + // while we are still using it. + std::vector connection_data(size); + + std::vector samples(size); + std::vector samples_length(size); + std::vector> ss(size); + std::vector p_ss(size); + + for (int32_t i = 0; i != size; ++i) { + auto &p = streams_.front(); + handles[i] = p.first; + connection_data[i] = p.second; + streams_.pop_front(); + + auto sample_rate = connection_data[i]->sample_rate; + auto samples = + reinterpret_cast(&connection_data[i]->data[0]); + auto num_samples = connection_data[i]->expected_byte_size / sizeof(float); + auto s = recognizer_.CreateStream(); + s->AcceptWaveform(sample_rate, samples, num_samples); + + ss[i] = std::move(s); + p_ss[i] = ss[i].get(); + } + + lock.unlock(); + + // Note: DecodeStreams is thread-safe + recognizer_.DecodeStreams(p_ss.data(), size); + + for (int32_t i = 0; i != size; ++i) { + connection_hdl hdl = handles[i]; + asio::post(server_->GetConnectionContext(), + [this, hdl, text = ss[i]->GetResult().text]() { + websocketpp::lib::error_code ec; + server_->GetServer().send( + hdl, text, 
websocketpp::frame::opcode::text, ec); + if (ec) { + server_->GetServer().get_alog().write( + websocketpp::log::alevel::app, ec.message()); + } + }); + } +} + +void OfflineWebsocketServerConfig::Register(ParseOptions *po) { + decoder_config.Register(po); + po->Register("log-file", &log_file, + "Path to the log file. Logs are " + "appended to this file"); +} + +void OfflineWebsocketServerConfig::Validate() const { + decoder_config.Validate(); +} + +OfflineWebsocketServer::OfflineWebsocketServer( + asio::io_context &io_conn, // NOLINT + asio::io_context &io_work, // NOLINT + const OfflineWebsocketServerConfig &config) + : io_conn_(io_conn), + io_work_(io_work), + config_(config), + log_(config.log_file, std::ios::app), + tee_(std::cout, log_), + decoder_(this) { + SetupLog(); + + server_.init_asio(&io_conn_); + + server_.set_open_handler([this](connection_hdl hdl) { OnOpen(hdl); }); + + server_.set_close_handler([this](connection_hdl hdl) { OnClose(hdl); }); + + server_.set_message_handler( + [this](connection_hdl hdl, server::message_ptr msg) { + OnMessage(hdl, msg); + }); +} + +void OfflineWebsocketServer::SetupLog() { + server_.clear_access_channels(websocketpp::log::alevel::all); + server_.set_access_channels(websocketpp::log::alevel::connect); + server_.set_access_channels(websocketpp::log::alevel::disconnect); + + // So that it also prints to std::cout and std::cerr + server_.get_alog().set_ostream(&tee_); + server_.get_elog().set_ostream(&tee_); +} + +void OfflineWebsocketServer::OnOpen(connection_hdl hdl) { + std::lock_guard lock(mutex_); + connections_.emplace(hdl, std::make_shared()); + + SHERPA_ONNX_LOGE("Number of active connections: %d", + static_cast(connections_.size())); +} + +void OfflineWebsocketServer::OnClose(connection_hdl hdl) { + std::lock_guard lock(mutex_); + connections_.erase(hdl); + + SHERPA_ONNX_LOGE("Number of active connections: %d", + static_cast(connections_.size())); +} + +void OfflineWebsocketServer::OnMessage(connection_hdl hdl, + 
server::message_ptr msg) { + std::unique_lock lock(mutex_); + auto connection_data = connections_.find(hdl)->second; + lock.unlock(); + const std::string &payload = msg->get_payload(); + + switch (msg->get_opcode()) { + case websocketpp::frame::opcode::text: + if (payload == "Done") { + // The client will not send any more data. We can close the + // connection now. + Close(hdl, websocketpp::close::status::normal, "Done"); + } else { + Close(hdl, websocketpp::close::status::normal, + std::string("Invalid payload: ") + payload); + } + break; + + case websocketpp::frame::opcode::binary: { + auto p = reinterpret_cast(payload.data()); + + if (connection_data->expected_byte_size == 0) { + if (payload.size() < 8) { + Close(hdl, websocketpp::close::status::normal, + "Payload is too short"); + break; + } + + connection_data->sample_rate = *reinterpret_cast(p); + + connection_data->expected_byte_size = + *reinterpret_cast(p + 4); + + int32_t max_byte_size_ = decoder_.GetConfig().max_utterance_length * + connection_data->sample_rate * sizeof(float); + if (connection_data->expected_byte_size > max_byte_size_) { + float num_samples = + connection_data->expected_byte_size / sizeof(float); + + float duration = num_samples / connection_data->sample_rate; + + std::ostringstream os; + os << "Max utterance length is configured to " + << decoder_.GetConfig().max_utterance_length + << " seconds, received length is " << duration << " seconds. 
" + << "Payload is too large!"; + Close(hdl, websocketpp::close::status::message_too_big, os.str()); + break; + } + + connection_data->data.resize(connection_data->expected_byte_size); + std::copy(payload.begin() + 8, payload.end(), + connection_data->data.data()); + connection_data->cur = payload.size() - 8; + } else { + std::copy(payload.begin(), payload.end(), + connection_data->data.data() + connection_data->cur); + connection_data->cur += payload.size(); + } + + if (connection_data->expected_byte_size == connection_data->cur) { + auto d = std::make_shared(std::move(*connection_data)); + // Clear it so that we can handle the next audio file from the client. + // The client can send multiple audio files for recognition without + // the need to create another connection. + connection_data->sample_rate = 0; + connection_data->expected_byte_size = 0; + connection_data->cur = 0; + + decoder_.Push(hdl, d); + + connection_data->Clear(); + + asio::post(io_work_, [this]() { decoder_.Decode(); }); + } + break; + } + + default: + // Unexpected message, ignore it + break; + } +} + +void OfflineWebsocketServer::Close(connection_hdl hdl, + websocketpp::close::status::value code, + const std::string &reason) { + auto con = server_.get_con_from_hdl(hdl); + + std::ostringstream os; + os << "Closing " << con->get_remote_endpoint() << " with reason: " << reason + << "\n"; + + websocketpp::lib::error_code ec; + server_.close(hdl, code, reason, ec); + if (ec) { + os << "Failed to close" << con->get_remote_endpoint() << ". 
" + << ec.message() << "\n"; + } + server_.get_alog().write(websocketpp::log::alevel::app, os.str()); +} + +void OfflineWebsocketServer::Run(uint16_t port) { + server_.set_reuse_addr(true); + server_.listen(asio::ip::tcp::v4(), port); + server_.start_accept(); +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/offline-websocket-server-impl.h b/sherpa-onnx/csrc/offline-websocket-server-impl.h new file mode 100644 index 00000000..9ea88d53 --- /dev/null +++ b/sherpa-onnx/csrc/offline-websocket-server-impl.h @@ -0,0 +1,205 @@ +// sherpa-onnx/csrc/offline-websocket-server-impl.h +// +// Copyright (c) 2022-2023 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_OFFLINE_WEBSOCKET_SERVER_IMPL_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_WEBSOCKET_SERVER_IMPL_H_ + +#include +#include +#include +#include +#include +#include +#include + +#include "sherpa-onnx/csrc/offline-recognizer.h" +#include "sherpa-onnx/csrc/parse-options.h" +#include "sherpa-onnx/csrc/tee-stream.h" +#include "websocketpp/config/asio_no_tls.hpp" // TODO(fangjun): support TLS +#include "websocketpp/server.hpp" + +using server = websocketpp::server; +using connection_hdl = websocketpp::connection_hdl; + +namespace sherpa_onnx { + +/** Communication protocol + * + * The client sends a byte stream to the server. The first 4 bytes in little + * endian indicates the sample rate of the audio data that the client will send. + * The next 4 bytes in little endian indicates the total samples in bytes the + * client will send. The remaining bytes represent audio samples. Each audio + * sample is a float occupying 4 bytes and is normalized into the range + * [-1, 1]. + * + * The byte stream can be broken into arbitrary number of messages. + * We require that the first message has to be at least 8 bytes so that + * we can get `sample_rate` and `expected_byte_size` from the first message. 
+ */ +struct ConnectionData { + // Sample rate of the audio samples the client + int32_t sample_rate; + + // Number of expected bytes sent from the client + int32_t expected_byte_size = 0; + + // Number of bytes received so far + int32_t cur = 0; + + // It saves the received samples from the client. + // We will **reinterpret_cast** it to float. + // We expect that data.size() == expected_byte_size + std::vector data; + + void Clear() { + sample_rate = 0; + expected_byte_size = 0; + cur = 0; + data.clear(); + } +}; + +using ConnectionDataPtr = std::shared_ptr; + +struct OfflineWebsocketDecoderConfig { + OfflineRecognizerConfig recognizer_config; + + int32_t max_batch_size = 5; + + float max_utterance_length = 300; // seconds + + void Register(ParseOptions *po); + void Validate() const; +}; + +class OfflineWebsocketServer; + +class OfflineWebsocketDecoder { + public: + /** + * @param config Configuration for the decoder. + * @param server **Borrowed** from outside. + */ + explicit OfflineWebsocketDecoder(OfflineWebsocketServer *server); + + /** Insert received data to the queue for decoding. + * + * @param hdl A handle to the connection. We can use it to send the result + * back to the client once it finishes decoding. + * @param d The received data + */ + void Push(connection_hdl hdl, ConnectionDataPtr d); + + /** It is called by one of the work thread. + */ + void Decode(); + + const OfflineWebsocketDecoderConfig &GetConfig() const { return config_; } + + private: + OfflineWebsocketDecoderConfig config_; + + /** When we have received all the data from the client, we put it into + * this queue; the worker threads will get items from this queue for + * decoding. + * + * Number of items to take from this queue is determined by + * `--max-batch-size`. If there are not enough items in the queue, we won't + * wait and take whatever we have for decoding. 
+ */ + std::mutex mutex_; + std::deque> streams_; + + OfflineWebsocketServer *server_; // Not owned + OfflineRecognizer recognizer_; +}; + +struct OfflineWebsocketServerConfig { + OfflineWebsocketDecoderConfig decoder_config; + std::string log_file = "./log.txt"; + + void Register(ParseOptions *po); + void Validate() const; +}; + +class OfflineWebsocketServer { + public: + OfflineWebsocketServer(asio::io_context &io_conn, // NOLINT + asio::io_context &io_work, // NOLINT + const OfflineWebsocketServerConfig &config); + + asio::io_context &GetConnectionContext() { return io_conn_; } + server &GetServer() { return server_; } + + void Run(uint16_t port); + + const OfflineWebsocketServerConfig &GetConfig() const { return config_; } + + private: + void SetupLog(); + + // When a websocket client is connected, it will invoke this method + // (Not for HTTP) + void OnOpen(connection_hdl hdl); + + // When a websocket client is disconnected, it will invoke this method + void OnClose(connection_hdl hdl); + + // When a message is received from a websocket client, this method will + // be invoked. + // + // The protocol between the client and the server is as follows: + // + // (1) The client connects to the server + // (2) The client starts to send binary byte stream to the server. + // The byte stream can be broken into multiple messages or it can + // be put into a single message. + // The first message has to contain at least 8 bytes. The first + // 4 bytes in little endian contains a int32_t indicating the + // sampling rate. The next 4 bytes in little endian contains a int32_t + // indicating total number of bytes of samples the client will send. + // We assume each sample is a float containing 4 bytes and has been + // normalized to the range [-1, 1]. + // (4) When the server receives all the samples from the client, it will + // start to decode them. 
Once decoded, the server sends a text message + // to the client containing the decoded results + // (5) After receiving the decoded results from the server, if the client has + // another audio file to send, it repeats (2) and (4) + // (6) If the client has no more audio files to decode, the client sends a + // text message containing "Done" to the server and closes the connection + // (7) The server receives a text message "Done" and closes the connection + // + // Note: + // (a) All models in icefall use features extracted from audio samples + // normalized to the range [-1, 1]. Please send normalized audio samples + // if you use models from icefall. + // (b) Only sound files with a single channel are supported + // (c) Only audio samples are sent. For instance, if we want to decode + // a WAVE file, the RIFF header of the WAVE file is not sent. + void OnMessage(connection_hdl hdl, server::message_ptr msg); + + // Close a websocket connection with the given code and reason + void Close(connection_hdl hdl, websocketpp::close::status::value code, + const std::string &reason); + + private: + asio::io_context &io_conn_; + asio::io_context &io_work_; + server server_; + + std::map> + connections_; + std::mutex mutex_; + + OfflineWebsocketServerConfig config_; + + std::ofstream log_; + TeeStream tee_; + + OfflineWebsocketDecoder decoder_; +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_OFFLINE_WEBSOCKET_SERVER_IMPL_H_ diff --git a/sherpa-onnx/csrc/offline-websocket-server.cc b/sherpa-onnx/csrc/offline-websocket-server.cc new file mode 100644 index 00000000..fb7a45ea --- /dev/null +++ b/sherpa-onnx/csrc/offline-websocket-server.cc @@ -0,0 +1,120 @@ +// sherpa-onnx/csrc/offline-websocket-server.cc +// +// Copyright (c) 2022-2023 Xiaomi Corporation + +#include "asio.hpp" +#include "sherpa-onnx/csrc/macros.h" +#include "sherpa-onnx/csrc/offline-websocket-server-impl.h" +#include "sherpa-onnx/csrc/parse-options.h" + +static constexpr const char *kUsageMessage =
R"( +Automatic speech recognition with sherpa-onnx using websocket. + +Usage: + +./bin/sherpa-onnx-offline-websocket-server --help + +(1) For transducer models + +./bin/sherpa-onnx-offline-websocket-server \ + --port=6006 \ + --num-work-threads=5 \ + --tokens=/path/to/tokens.txt \ + --encoder=/path/to/encoder.onnx \ + --decoder=/path/to/decoder.onnx \ + --joiner=/path/to/joiner.onnx \ + --log-file=./log.txt \ + --max-batch-size=5 + +(2) For Paraformer + +./bin/sherpa-onnx-offline-websocket-server \ + --port=6006 \ + --num-work-threads=5 \ + --tokens=/path/to/tokens.txt \ + --paraformer=/path/to/model.onnx \ + --log-file=./log.txt \ + --max-batch-size=5 + +Please refer to +https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html +for a list of pre-trained models to download. +)"; + +int32_t main(int32_t argc, char *argv[]) { + sherpa_onnx::ParseOptions po(kUsageMessage); + + sherpa_onnx::OfflineWebsocketServerConfig config; + + // the server will listen on this port + int32_t port = 6006; + + // size of the thread pool for handling network connections + int32_t num_io_threads = 1; + + // size of the thread pool for neural network computation and decoding + int32_t num_work_threads = 3; + + po.Register("num-io-threads", &num_io_threads, + "Thread pool size for network connections."); + + po.Register("num-work-threads", &num_work_threads, + "Thread pool size for for neural network " + "computation and decoding."); + + po.Register("port", &port, "The port on which the server will listen."); + + config.Register(&po); + + if (argc == 1) { + po.PrintUsage(); + exit(EXIT_FAILURE); + } + + po.Read(argc, argv); + + if (po.NumArgs() != 0) { + SHERPA_ONNX_LOGE("Unrecognized positional arguments!"); + po.PrintUsage(); + exit(EXIT_FAILURE); + } + + config.Validate(); + + asio::io_context io_conn; // for network connections + asio::io_context io_work; // for neural network and decoding + + sherpa_onnx::OfflineWebsocketServer server(io_conn, io_work, config); + 
server.Run(port); + + SHERPA_ONNX_LOGE("Started!"); + SHERPA_ONNX_LOGE("Listening on: %d", port); + SHERPA_ONNX_LOGE("Number of work threads: %d", num_work_threads); + + // give some work to do for the io_work pool + auto work_guard = asio::make_work_guard(io_work); + + std::vector io_threads; + + // decrement since the main thread is also used for network communications + for (int32_t i = 0; i < num_io_threads - 1; ++i) { + io_threads.emplace_back([&io_conn]() { io_conn.run(); }); + } + + std::vector work_threads; + for (int32_t i = 0; i < num_work_threads; ++i) { + work_threads.emplace_back([&io_work]() { io_work.run(); }); + } + + io_conn.run(); + + for (auto &t : io_threads) { + t.join(); + } + + for (auto &t : work_threads) { + t.join(); + } + + return 0; +} diff --git a/sherpa-onnx/csrc/online-websocket-server.cc b/sherpa-onnx/csrc/online-websocket-server.cc index 274f0344..6ba7a198 100644 --- a/sherpa-onnx/csrc/online-websocket-server.cc +++ b/sherpa-onnx/csrc/online-websocket-server.cc @@ -76,6 +76,7 @@ int32_t main(int32_t argc, char *argv[]) { sherpa_onnx::OnlineWebsocketServer server(io_conn, io_work, config); server.Run(port); + SHERPA_ONNX_LOGE("Started!"); SHERPA_ONNX_LOGE("Listening on: %d", port); SHERPA_ONNX_LOGE("Number of work threads: %d", num_work_threads);