add offline websocket server/client (#98)

This commit is contained in:
Fangjun Kuang
2023-03-29 21:48:45 +08:00
committed by GitHub
parent 5e5620ea23
commit 6707ec4124
15 changed files with 1032 additions and 59 deletions

View File

@@ -1,7 +1,7 @@
[flake8] [flake8]
show-source=true show-source=true
statistics=true statistics=true
max-line-length = 80 max-line-length = 120
exclude = exclude =
.git, .git,

View File

@@ -30,9 +30,12 @@ ls -lh
ls -lh $repo ls -lh $repo
python3 ./python-api-examples/decode-file.py \ python3 ./python-api-examples/online-decode-files.py \
--tokens=$repo/tokens.txt \ --tokens=$repo/tokens.txt \
--encoder=$repo/encoder-epoch-99-avg-1.onnx \ --encoder=$repo/encoder-epoch-99-avg-1.onnx \
--decoder=$repo/decoder-epoch-99-avg-1.onnx \ --decoder=$repo/decoder-epoch-99-avg-1.onnx \
--joiner=$repo/joiner-epoch-99-avg-1.onnx \ --joiner=$repo/joiner-epoch-99-avg-1.onnx \
--wave-filename=$repo/test_wavs/4.wav $repo/test_wavs/0.wav \
$repo/test_wavs/1.wav \
$repo/test_wavs/2.wav \
$repo/test_wavs/3.wav

1
.gitignore vendored
View File

@@ -45,3 +45,4 @@ paraformer-onnxruntime-python-example
run-sherpa-onnx-offline-paraformer.sh run-sherpa-onnx-offline-paraformer.sh
run-sherpa-onnx-offline-transducer.sh run-sherpa-onnx-offline-transducer.sh
sherpa-onnx-paraformer-zh-2023-03-28 sherpa-onnx-paraformer-zh-2023-03-28
run-offline-websocket-server-paraformer.sh

View File

@@ -0,0 +1,158 @@
#!/usr/bin/env python3
#
# Copyright (c) 2023 Xiaomi Corporation
"""
A websocket client for sherpa-onnx-offline-websocket-server
This file shows how to transcribe multiple
files in parallel. We create a separate connection for transcribing each file.
Usage:
./offline-websocket-client-decode-files-parallel.py \
--server-addr localhost \
--server-port 6006 \
/path/to/foo.wav \
/path/to/bar.wav \
/path/to/16kHz.wav \
/path/to/8kHz.wav
(Note: You have to first start the server before starting the client)
You can find the server at
https://github.com/k2-fsa/sherpa-onnx/blob/master/sherpa-onnx/csrc/offline-websocket-server.cc
Note: The server is implemented in C++.
"""
import argparse
import asyncio
import logging
import wave
from typing import Tuple
try:
import websockets
except ImportError:
print("please run:")
print("")
print(" pip install websockets")
print("")
print("before you run this script")
print("")
import numpy as np
def get_args():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--server-addr",
type=str,
default="localhost",
help="Address of the server",
)
parser.add_argument(
"--server-port",
type=int,
default=6006,
help="Port of the server",
)
parser.add_argument(
"sound_files",
type=str,
nargs="+",
help="The input sound file(s) to decode. Each file must be of WAVE"
"format with a single channel, and each sample has 16-bit, "
"i.e., int16_t. "
"The sample rate of the file can be arbitrary and does not need to "
"be 16 kHz",
)
return parser.parse_args()
def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]:
"""
Args:
wave_filename:
Path to a wave file. It should be single channel and each sample should
be 16-bit. Its sample rate does not need to be 16kHz.
Returns:
Return a tuple containing:
- A 1-D array of dtype np.float32 containing the samples, which are
normalized to the range [-1, 1].
- sample rate of the wave file
"""
with wave.open(wave_filename) as f:
assert f.getnchannels() == 1, f.getnchannels()
assert f.getsampwidth() == 2, f.getsampwidth() # it is in bytes
num_samples = f.getnframes()
samples = f.readframes(num_samples)
samples_int16 = np.frombuffer(samples, dtype=np.int16)
samples_float32 = samples_int16.astype(np.float32)
samples_float32 = samples_float32 / 32768
return samples_float32, f.getframerate()
async def run(
server_addr: str,
server_port: int,
wave_filename: str,
):
async with websockets.connect(
f"ws://{server_addr}:{server_port}"
) as websocket: # noqa
logging.info(f"Sending {wave_filename}")
samples, sample_rate = read_wave(wave_filename)
assert isinstance(sample_rate, int)
assert samples.dtype == np.float32, samples.dtype
assert samples.ndim == 1, samples.dim
buf = sample_rate.to_bytes(4, byteorder="little") # 4 bytes
buf += (samples.size * 4).to_bytes(4, byteorder="little")
buf += samples.tobytes()
await websocket.send(buf)
decoding_results = await websocket.recv()
logging.info(f"{wave_filename}\n{decoding_results}")
# to signal that the client has sent all the data
await websocket.send("Done")
async def main():
args = get_args()
logging.info(vars(args))
server_addr = args.server_addr
server_port = args.server_port
sound_files = args.sound_files
all_tasks = []
for wave_filename in sound_files:
task = asyncio.create_task(
run(
server_addr=server_addr,
server_port=server_port,
wave_filename=wave_filename,
)
)
all_tasks.append(task)
await asyncio.gather(*all_tasks)
if __name__ == "__main__":
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" # noqa
)
logging.basicConfig(format=formatter, level=logging.INFO)
asyncio.run(main())

View File

@@ -0,0 +1,152 @@
#!/usr/bin/env python3
#
# Copyright (c) 2023 Xiaomi Corporation
"""
A websocket client for sherpa-onnx-offline-websocket-server
This file shows how to use a single connection to transcribe multiple
files sequentially.
Usage:
./offline-websocket-client-decode-files-sequential.py \
--server-addr localhost \
--server-port 6006 \
/path/to/foo.wav \
/path/to/bar.wav \
/path/to/16kHz.wav \
/path/to/8kHz.wav
(Note: You have to first start the server before starting the client)
You can find the server at
https://github.com/k2-fsa/sherpa-onnx/blob/master/sherpa-onnx/csrc/offline-websocket-server.cc
Note: The server is implemented in C++.
"""
import argparse
import asyncio
import logging
import wave
from typing import List, Tuple
try:
import websockets
except ImportError:
print("please run:")
print("")
print(" pip install websockets")
print("")
print("before you run this script")
print("")
import numpy as np
def get_args():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--server-addr",
type=str,
default="localhost",
help="Address of the server",
)
parser.add_argument(
"--server-port",
type=int,
default=6006,
help="Port of the server",
)
parser.add_argument(
"sound_files",
type=str,
nargs="+",
help="The input sound file(s) to decode. Each file must be of WAVE"
"format with a single channel, and each sample has 16-bit, "
"i.e., int16_t. "
"The sample rate of the file can be arbitrary and does not need to "
"be 16 kHz",
)
return parser.parse_args()
def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]:
"""
Args:
wave_filename:
Path to a wave file. It should be single channel and each sample should
be 16-bit. Its sample rate does not need to be 16kHz.
Returns:
Return a tuple containing:
- A 1-D array of dtype np.float32 containing the samples, which are
normalized to the range [-1, 1].
- sample rate of the wave file
"""
with wave.open(wave_filename) as f:
assert f.getnchannels() == 1, f.getnchannels()
assert f.getsampwidth() == 2, f.getsampwidth() # it is in bytes
num_samples = f.getnframes()
samples = f.readframes(num_samples)
samples_int16 = np.frombuffer(samples, dtype=np.int16)
samples_float32 = samples_int16.astype(np.float32)
samples_float32 = samples_float32 / 32768
return samples_float32, f.getframerate()
async def run(
server_addr: str,
server_port: int,
sound_files: List[str],
):
async with websockets.connect(
f"ws://{server_addr}:{server_port}"
) as websocket: # noqa
for wave_filename in sound_files:
logging.info(f"Sending {wave_filename}")
samples, sample_rate = read_wave(wave_filename)
assert isinstance(sample_rate, int)
assert samples.dtype == np.float32, samples.dtype
assert samples.ndim == 1, samples.dim
buf = sample_rate.to_bytes(4, byteorder="little") # 4 bytes
buf += (samples.size * 4).to_bytes(4, byteorder="little")
buf += samples.tobytes()
await websocket.send(buf)
decoding_results = await websocket.recv()
print(decoding_results)
# to signal that the client has sent all the data
await websocket.send("Done")
async def main():
args = get_args()
logging.info(vars(args))
server_addr = args.server_addr
server_port = args.server_port
sound_files = args.sound_files
await run(
server_addr=server_addr,
server_port=server_port,
sound_files=sound_files,
)
if __name__ == "__main__":
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" # noqa
)
logging.basicConfig(format=formatter, level=logging.INFO)
asyncio.run(main())

View File

@@ -1,8 +1,15 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
""" """
This file demonstrates how to use sherpa-onnx Python API to recognize This file demonstrates how to use sherpa-onnx Python API to transcribe
a single file. file(s) with a streaming model.
Usage:
./online-decode-files.py \
/path/to/foo.wav \
/path/to/bar.wav \
/path/to/16kHz.wav \
/path/to/8kHz.wav
Please refer to Please refer to
https://k2-fsa.github.io/sherpa/onnx/index.html https://k2-fsa.github.io/sherpa/onnx/index.html
@@ -13,17 +20,12 @@ import argparse
import time import time
import wave import wave
from pathlib import Path from pathlib import Path
from typing import Tuple
import numpy as np import numpy as np
import sherpa_onnx import sherpa_onnx
def assert_file_exists(filename: str):
assert Path(
filename
).is_file(), f"{filename} does not exist!\nPlease refer to https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html to download it"
def get_args(): def get_args():
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter formatter_class=argparse.ArgumentDefaultsHelpFormatter
@@ -68,26 +70,58 @@ def get_args():
) )
parser.add_argument( parser.add_argument(
"--wave-filename", "sound_files",
type=str, type=str,
help="""Path to the wave filename. nargs="+",
Should have a single channel with 16-bit samples. help="The input sound file(s) to decode. Each file must be of WAVE"
It does not need to be 16kHz. It can have any sampling rate. "format with a single channel, and each sample has 16-bit, "
""", "i.e., int16_t. "
"The sample rate of the file can be arbitrary and does not need to "
"be 16 kHz",
) )
return parser.parse_args() return parser.parse_args()
def assert_file_exists(filename: str):
assert Path(filename).is_file(), (
f"{filename} does not exist!\n"
"Please refer to "
"https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html to download it"
)
def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]:
"""
Args:
wave_filename:
Path to a wave file. It should be single channel and each sample should
be 16-bit. Its sample rate does not need to be 16kHz.
Returns:
Return a tuple containing:
- A 1-D array of dtype np.float32 containing the samples, which are
normalized to the range [-1, 1].
- sample rate of the wave file
"""
with wave.open(wave_filename) as f:
assert f.getnchannels() == 1, f.getnchannels()
assert f.getsampwidth() == 2, f.getsampwidth() # it is in bytes
num_samples = f.getnframes()
samples = f.readframes(num_samples)
samples_int16 = np.frombuffer(samples, dtype=np.int16)
samples_float32 = samples_int16.astype(np.float32)
samples_float32 = samples_float32 / 32768
return samples_float32, f.getframerate()
def main(): def main():
args = get_args() args = get_args()
assert_file_exists(args.encoder) assert_file_exists(args.encoder)
assert_file_exists(args.decoder) assert_file_exists(args.decoder)
assert_file_exists(args.joiner) assert_file_exists(args.joiner)
assert_file_exists(args.tokens) assert_file_exists(args.tokens)
if not Path(args.wave_filename).is_file():
print(f"{args.wave_filename} does not exist!")
return
recognizer = sherpa_onnx.OnlineRecognizer( recognizer = sherpa_onnx.OnlineRecognizer(
tokens=args.tokens, tokens=args.tokens,
@@ -99,42 +133,44 @@ def main():
feature_dim=80, feature_dim=80,
decoding_method=args.decoding_method, decoding_method=args.decoding_method,
) )
with wave.open(args.wave_filename) as f:
# If the wave file has a different sampling rate from the one
# expected by the model (16 kHz in our case), we will do
# resampling inside sherpa-onnx
wave_file_sample_rate = f.getframerate()
assert f.getnchannels() == 1, f.getnchannels()
assert f.getsampwidth() == 2, f.getsampwidth() # it is in bytes
num_samples = f.getnframes()
samples = f.readframes(num_samples)
samples_int16 = np.frombuffer(samples, dtype=np.int16)
samples_float32 = samples_int16.astype(np.float32)
samples_float32 = samples_float32 / 32768
duration = len(samples_float32) / wave_file_sample_rate
start_time = time.time()
print("Started!") print("Started!")
start_time = time.time()
stream = recognizer.create_stream() streams = []
total_duration = 0
for wave_filename in args.sound_files:
assert_file_exists(wave_filename)
samples, sample_rate = read_wave(wave_filename)
duration = len(samples) / sample_rate
total_duration += duration
stream.accept_waveform(wave_file_sample_rate, samples_float32) s = recognizer.create_stream()
s.accept_waveform(sample_rate, samples)
tail_paddings = np.zeros(int(0.2 * wave_file_sample_rate), dtype=np.float32) tail_paddings = np.zeros(int(0.2 * sample_rate), dtype=np.float32)
stream.accept_waveform(wave_file_sample_rate, tail_paddings) s.accept_waveform(sample_rate, tail_paddings)
stream.input_finished() s.input_finished()
while recognizer.is_ready(stream): streams.append(s)
recognizer.decode_stream(stream)
print(recognizer.get_result(stream)) while True:
ready_list = []
print("Done!") for s in streams:
if recognizer.is_ready(s):
ready_list.append(s)
if len(ready_list) == 0:
break
recognizer.decode_streams(ready_list)
results = [recognizer.get_result(s) for s in streams]
end_time = time.time() end_time = time.time()
print("Done!")
for wave_filename, result in zip(args.sound_files, results):
print(f"{wave_filename}\n{result}")
print("-" * 10)
elapsed_seconds = end_time - start_time elapsed_seconds = end_time - start_time
rtf = elapsed_seconds / duration rtf = elapsed_seconds / duration
print(f"num_threads: {args.num_threads}") print(f"num_threads: {args.num_threads}")

View File

@@ -27,7 +27,6 @@ https://github.com/k2-fsa/sherpa-onnx/blob/master/sherpa-onnx/csrc/online-websoc
import argparse import argparse
import asyncio import asyncio
import logging import logging
import time
import wave import wave
try: try:

View File

@@ -24,13 +24,12 @@ https://github.com/k2-fsa/sherpa-onnx/blob/master/sherpa-onnx/csrc/online-websoc
import argparse import argparse
import asyncio import asyncio
import sys import sys
import time
import numpy as np import numpy as np
try: try:
import sounddevice as sd import sounddevice as sd
except ImportError as e: except ImportError:
print("Please install sounddevice first. You can use") print("Please install sounddevice first. You can use")
print() print()
print(" pip install sounddevice") print(" pip install sounddevice")
@@ -134,7 +133,7 @@ async def run(
await websocket.send(indata.tobytes()) await websocket.send(indata.tobytes())
decoding_results = await receive_task decoding_results = await receive_task
print("\nFinal result is:\n{decoding_results}") print(f"\nFinal result is:\n{decoding_results}")
async def main(): async def main():

View File

@@ -13,7 +13,7 @@ from pathlib import Path
try: try:
import sounddevice as sd import sounddevice as sd
except ImportError as e: except ImportError:
print("Please install sounddevice first. You can use") print("Please install sounddevice first. You can use")
print() print()
print(" pip install sounddevice") print(" pip install sounddevice")
@@ -25,9 +25,11 @@ import sherpa_onnx
def assert_file_exists(filename: str): def assert_file_exists(filename: str):
assert Path( assert Path(filename).is_file(), (
filename f"{filename} does not exist!\n"
).is_file(), f"{filename} does not exist!\nPlease refer to https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html to download it" "Please refer to "
"https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html to download it"
)
def get_args(): def get_args():

View File

@@ -12,7 +12,7 @@ from pathlib import Path
try: try:
import sounddevice as sd import sounddevice as sd
except ImportError as e: except ImportError:
print("Please install sounddevice first. You can use") print("Please install sounddevice first. You can use")
print() print()
print(" pip install sounddevice") print(" pip install sounddevice")
@@ -24,9 +24,11 @@ import sherpa_onnx
def assert_file_exists(filename: str): def assert_file_exists(filename: str):
assert Path( assert Path(filename).is_file(), (
filename f"{filename} does not exist!\n"
).is_file(), f"{filename} does not exist!\nPlease refer to https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html to download it" "Please refer to "
"https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html to download it"
)
def get_args(): def get_args():

View File

@@ -128,7 +128,6 @@ if(SHERPA_ONNX_ENABLE_WEBSOCKET)
) )
target_link_libraries(sherpa-onnx-online-websocket-server sherpa-onnx-core) target_link_libraries(sherpa-onnx-online-websocket-server sherpa-onnx-core)
add_executable(sherpa-onnx-online-websocket-client add_executable(sherpa-onnx-online-websocket-client
online-websocket-client.cc online-websocket-client.cc
) )
@@ -142,6 +141,17 @@ if(SHERPA_ONNX_ENABLE_WEBSOCKET)
target_compile_options(sherpa-onnx-online-websocket-client PRIVATE -Wno-deprecated-declarations) target_compile_options(sherpa-onnx-online-websocket-client PRIVATE -Wno-deprecated-declarations)
endif() endif()
# For offline websocket
add_executable(sherpa-onnx-offline-websocket-server
offline-websocket-server-impl.cc
offline-websocket-server.cc
)
target_link_libraries(sherpa-onnx-offline-websocket-server sherpa-onnx-core)
if(NOT WIN32)
target_link_libraries(sherpa-onnx-offline-websocket-server -pthread)
target_compile_options(sherpa-onnx-offline-websocket-server PRIVATE -Wno-deprecated-declarations)
endif()
endif() endif()

View File

@@ -0,0 +1,285 @@
// sherpa-onnx/csrc/offline-websocket-server-impl.cc
//
// Copyright (c) 2022-2023 Xiaomi Corporation
#include "sherpa-onnx/csrc/offline-websocket-server-impl.h"
#include <algorithm>
#include "sherpa-onnx/csrc/macros.h"
namespace sherpa_onnx {
void OfflineWebsocketDecoderConfig::Register(ParseOptions *po) {
recognizer_config.Register(po);
po->Register("max-batch-size", &max_batch_size,
"Max batch size for decoding.");
po->Register(
"max-utterance-length", &max_utterance_length,
"Max utterance length in seconds. If we receive an utterance "
"longer than this value, we will reject the connection. "
"If you have enough memory, you can select a large value for it.");
}
void OfflineWebsocketDecoderConfig::Validate() const {
if (!recognizer_config.Validate()) {
SHERPA_ONNX_LOGE("Error in recongizer config");
exit(-1);
}
if (max_batch_size <= 0) {
SHERPA_ONNX_LOGE("Expect --max-batch-size > 0. Given: %d", max_batch_size);
exit(-1);
}
if (max_utterance_length <= 0) {
SHERPA_ONNX_LOGE("Expect --max-utterance-length > 0. Given: %f",
max_utterance_length);
exit(-1);
}
}
OfflineWebsocketDecoder::OfflineWebsocketDecoder(OfflineWebsocketServer *server)
: config_(server->GetConfig().decoder_config),
server_(server),
recognizer_(config_.recognizer_config) {}
void OfflineWebsocketDecoder::Push(connection_hdl hdl, ConnectionDataPtr d) {
std::lock_guard<std::mutex> lock(mutex_);
streams_.push_back({hdl, d});
}
void OfflineWebsocketDecoder::Decode() {
std::unique_lock<std::mutex> lock(mutex_);
if (streams_.empty()) {
return;
}
int32_t size =
std::min(static_cast<int32_t>(streams_.size()), config_.max_batch_size);
SHERPA_ONNX_LOGE("size: %d", size);
// We first lock the mutex for streams_, take items from it, and then
// unlock the mutex; in doing so we don't need to lock the mutex to
// access hdl and connection_data later.
std::vector<connection_hdl> handles(size);
// Store connection_data here to prevent the data from being freed
// while we are still using it.
std::vector<ConnectionDataPtr> connection_data(size);
std::vector<const float *> samples(size);
std::vector<int32_t> samples_length(size);
std::vector<std::unique_ptr<OfflineStream>> ss(size);
std::vector<OfflineStream *> p_ss(size);
for (int32_t i = 0; i != size; ++i) {
auto &p = streams_.front();
handles[i] = p.first;
connection_data[i] = p.second;
streams_.pop_front();
auto sample_rate = connection_data[i]->sample_rate;
auto samples =
reinterpret_cast<const float *>(&connection_data[i]->data[0]);
auto num_samples = connection_data[i]->expected_byte_size / sizeof(float);
auto s = recognizer_.CreateStream();
s->AcceptWaveform(sample_rate, samples, num_samples);
ss[i] = std::move(s);
p_ss[i] = ss[i].get();
}
lock.unlock();
// Note: DecodeStreams is thread-safe
recognizer_.DecodeStreams(p_ss.data(), size);
for (int32_t i = 0; i != size; ++i) {
connection_hdl hdl = handles[i];
asio::post(server_->GetConnectionContext(),
[this, hdl, text = ss[i]->GetResult().text]() {
websocketpp::lib::error_code ec;
server_->GetServer().send(
hdl, text, websocketpp::frame::opcode::text, ec);
if (ec) {
server_->GetServer().get_alog().write(
websocketpp::log::alevel::app, ec.message());
}
});
}
}
void OfflineWebsocketServerConfig::Register(ParseOptions *po) {
decoder_config.Register(po);
po->Register("log-file", &log_file,
"Path to the log file. Logs are "
"appended to this file");
}
void OfflineWebsocketServerConfig::Validate() const {
decoder_config.Validate();
}
OfflineWebsocketServer::OfflineWebsocketServer(
asio::io_context &io_conn, // NOLINT
asio::io_context &io_work, // NOLINT
const OfflineWebsocketServerConfig &config)
: io_conn_(io_conn),
io_work_(io_work),
config_(config),
log_(config.log_file, std::ios::app),
tee_(std::cout, log_),
decoder_(this) {
SetupLog();
server_.init_asio(&io_conn_);
server_.set_open_handler([this](connection_hdl hdl) { OnOpen(hdl); });
server_.set_close_handler([this](connection_hdl hdl) { OnClose(hdl); });
server_.set_message_handler(
[this](connection_hdl hdl, server::message_ptr msg) {
OnMessage(hdl, msg);
});
}
void OfflineWebsocketServer::SetupLog() {
server_.clear_access_channels(websocketpp::log::alevel::all);
server_.set_access_channels(websocketpp::log::alevel::connect);
server_.set_access_channels(websocketpp::log::alevel::disconnect);
// So that it also prints to std::cout and std::cerr
server_.get_alog().set_ostream(&tee_);
server_.get_elog().set_ostream(&tee_);
}
void OfflineWebsocketServer::OnOpen(connection_hdl hdl) {
std::lock_guard<std::mutex> lock(mutex_);
connections_.emplace(hdl, std::make_shared<ConnectionData>());
SHERPA_ONNX_LOGE("Number of active connections: %d",
static_cast<int32_t>(connections_.size()));
}
void OfflineWebsocketServer::OnClose(connection_hdl hdl) {
std::lock_guard<std::mutex> lock(mutex_);
connections_.erase(hdl);
SHERPA_ONNX_LOGE("Number of active connections: %d",
static_cast<int32_t>(connections_.size()));
}
void OfflineWebsocketServer::OnMessage(connection_hdl hdl,
server::message_ptr msg) {
std::unique_lock<std::mutex> lock(mutex_);
auto connection_data = connections_.find(hdl)->second;
lock.unlock();
const std::string &payload = msg->get_payload();
switch (msg->get_opcode()) {
case websocketpp::frame::opcode::text:
if (payload == "Done") {
// The client will not send any more data. We can close the
// connection now.
Close(hdl, websocketpp::close::status::normal, "Done");
} else {
Close(hdl, websocketpp::close::status::normal,
std::string("Invalid payload: ") + payload);
}
break;
case websocketpp::frame::opcode::binary: {
auto p = reinterpret_cast<const int8_t *>(payload.data());
if (connection_data->expected_byte_size == 0) {
if (payload.size() < 8) {
Close(hdl, websocketpp::close::status::normal,
"Payload is too short");
break;
}
connection_data->sample_rate = *reinterpret_cast<const int32_t *>(p);
connection_data->expected_byte_size =
*reinterpret_cast<const int32_t *>(p + 4);
int32_t max_byte_size_ = decoder_.GetConfig().max_utterance_length *
connection_data->sample_rate * sizeof(float);
if (connection_data->expected_byte_size > max_byte_size_) {
float num_samples =
connection_data->expected_byte_size / sizeof(float);
float duration = num_samples / connection_data->sample_rate;
std::ostringstream os;
os << "Max utterance length is configured to "
<< decoder_.GetConfig().max_utterance_length
<< " seconds, received length is " << duration << " seconds. "
<< "Payload is too large!";
Close(hdl, websocketpp::close::status::message_too_big, os.str());
break;
}
connection_data->data.resize(connection_data->expected_byte_size);
std::copy(payload.begin() + 8, payload.end(),
connection_data->data.data());
connection_data->cur = payload.size() - 8;
} else {
std::copy(payload.begin(), payload.end(),
connection_data->data.data() + connection_data->cur);
connection_data->cur += payload.size();
}
if (connection_data->expected_byte_size == connection_data->cur) {
auto d = std::make_shared<ConnectionData>(std::move(*connection_data));
// Clear it so that we can handle the next audio file from the client.
// The client can send multiple audio files for recognition without
// the need to create another connection.
connection_data->sample_rate = 0;
connection_data->expected_byte_size = 0;
connection_data->cur = 0;
decoder_.Push(hdl, d);
connection_data->Clear();
asio::post(io_work_, [this]() { decoder_.Decode(); });
}
break;
}
default:
// Unexpected message, ignore it
break;
}
}
void OfflineWebsocketServer::Close(connection_hdl hdl,
websocketpp::close::status::value code,
const std::string &reason) {
auto con = server_.get_con_from_hdl(hdl);
std::ostringstream os;
os << "Closing " << con->get_remote_endpoint() << " with reason: " << reason
<< "\n";
websocketpp::lib::error_code ec;
server_.close(hdl, code, reason, ec);
if (ec) {
os << "Failed to close" << con->get_remote_endpoint() << ". "
<< ec.message() << "\n";
}
server_.get_alog().write(websocketpp::log::alevel::app, os.str());
}
void OfflineWebsocketServer::Run(uint16_t port) {
server_.set_reuse_addr(true);
server_.listen(asio::ip::tcp::v4(), port);
server_.start_accept();
}
} // namespace sherpa_onnx

View File

@@ -0,0 +1,205 @@
// sherpa-onnx/csrc/offline-websocket-server-impl.h
//
// Copyright (c) 2022-2023 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_OFFLINE_WEBSOCKET_SERVER_IMPL_H_
#define SHERPA_ONNX_CSRC_OFFLINE_WEBSOCKET_SERVER_IMPL_H_
#include <deque>
#include <fstream>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "sherpa-onnx/csrc/offline-recognizer.h"
#include "sherpa-onnx/csrc/parse-options.h"
#include "sherpa-onnx/csrc/tee-stream.h"
#include "websocketpp/config/asio_no_tls.hpp" // TODO(fangjun): support TLS
#include "websocketpp/server.hpp"
using server = websocketpp::server<websocketpp::config::asio>;
using connection_hdl = websocketpp::connection_hdl;
namespace sherpa_onnx {
/** Communication protocol
*
* The client sends a byte stream to the server. The first 4 bytes in little
* endian indicates the sample rate of the audio data that the client will send.
* The next 4 bytes in little endian indicates the total samples in bytes the
* client will send. The remaining bytes represent audio samples. Each audio
* sample is a float occupying 4 bytes and is normalized into the range
* [-1, 1].
*
* The byte stream can be broken into arbitrary number of messages.
* We require that the first message has to be at least 8 bytes so that
* we can get `sample_rate` and `expected_byte_size` from the first message.
*/
struct ConnectionData {
// Sample rate of the audio samples the client
int32_t sample_rate;
// Number of expected bytes sent from the client
int32_t expected_byte_size = 0;
// Number of bytes received so far
int32_t cur = 0;
// It saves the received samples from the client.
// We will **reinterpret_cast** it to float.
// We expect that data.size() == expected_byte_size
std::vector<int8_t> data;
void Clear() {
sample_rate = 0;
expected_byte_size = 0;
cur = 0;
data.clear();
}
};
using ConnectionDataPtr = std::shared_ptr<ConnectionData>;
struct OfflineWebsocketDecoderConfig {
OfflineRecognizerConfig recognizer_config;
int32_t max_batch_size = 5;
float max_utterance_length = 300; // seconds
void Register(ParseOptions *po);
void Validate() const;
};
class OfflineWebsocketServer;
class OfflineWebsocketDecoder {
public:
/**
* @param config Configuration for the decoder.
* @param server **Borrowed** from outside.
*/
explicit OfflineWebsocketDecoder(OfflineWebsocketServer *server);
/** Insert received data to the queue for decoding.
*
* @param hdl A handle to the connection. We can use it to send the result
* back to the client once it finishes decoding.
* @param d The received data
*/
void Push(connection_hdl hdl, ConnectionDataPtr d);
/** It is called by one of the work thread.
*/
void Decode();
const OfflineWebsocketDecoderConfig &GetConfig() const { return config_; }
private:
OfflineWebsocketDecoderConfig config_;
/** When we have received all the data from the client, we put it into
* this queue; the worker threads will get items from this queue for
* decoding.
*
* Number of items to take from this queue is determined by
* `--max-batch-size`. If there are not enough items in the queue, we won't
* wait and take whatever we have for decoding.
*/
std::mutex mutex_;
std::deque<std::pair<connection_hdl, ConnectionDataPtr>> streams_;
OfflineWebsocketServer *server_; // Not owned
OfflineRecognizer recognizer_;
};
struct OfflineWebsocketServerConfig {
OfflineWebsocketDecoderConfig decoder_config;
std::string log_file = "./log.txt";
void Register(ParseOptions *po);
void Validate() const;
};
class OfflineWebsocketServer {
public:
OfflineWebsocketServer(asio::io_context &io_conn, // NOLINT
asio::io_context &io_work, // NOLINT
const OfflineWebsocketServerConfig &config);
asio::io_context &GetConnectionContext() { return io_conn_; }
server &GetServer() { return server_; }
void Run(uint16_t port);
const OfflineWebsocketServerConfig &GetConfig() const { return config_; }
private:
void SetupLog();
// When a websocket client is connected, it will invoke this method
// (Not for HTTP)
void OnOpen(connection_hdl hdl);
// When a websocket client is disconnected, it will invoke this method
void OnClose(connection_hdl hdl);
// When a message is received from a websocket client, this method will
// be invoked.
//
// The protocol between the client and the server is as follows:
//
// (1) The client connects to the server
// (2) The client starts to send binary byte stream to the server.
// The byte stream can be broken into multiple messages or it can
// be put into a single message.
// The first message has to contain at least 8 bytes. The first
// 4 bytes in little endian contains a int32_t indicating the
// sampling rate. The next 4 bytes in little endian contains a int32_t
// indicating total number of bytes of samples the client will send.
// We assume each sample is a float containing 4 bytes and has been
// normalized to the range [-1, 1].
// (4) When the server receives all the samples from the client, it will
// start to decode them. Once decoded, the server sends a text message
// to the client containing the decoded results
// (5) After receiving the decoded results from the server, if the client has
// another audio file to send, it repeats (2), (3), (4)
// (6) If the client has no more audio files to decode, the client sends a
// text message containing "Done" to the server and closes the connection
// (7) The server receives a text message "Done" and closes the connection
//
// Note:
// (a) All models in icefall use features extracted from audio samples
// normalized to the range [-1, 1]. Please send normalized audio samples
// if you use models from icefall.
// (b) Only sound files with a single channel is supported
// (c) Only audio samples are sent. For instance, if we want to decode
// a WAVE file, the RIFF header of the WAVE is not sent.
void OnMessage(connection_hdl hdl, server::message_ptr msg);
// Close a websocket connection with given code and reason
void Close(connection_hdl hdl, websocketpp::close::status::value code,
const std::string &reason);
private:
asio::io_context &io_conn_;
asio::io_context &io_work_;
server server_;
std::map<connection_hdl, ConnectionDataPtr, std::owner_less<connection_hdl>>
connections_;
std::mutex mutex_;
OfflineWebsocketServerConfig config_;
std::ofstream log_;
TeeStream tee_;
OfflineWebsocketDecoder decoder_;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_OFFLINE_WEBSOCKET_SERVER_IMPL_H_

View File

@@ -0,0 +1,120 @@
// sherpa-onnx/csrc/offline-websocket-server.cc
//
// Copyright (c) 2022-2023 Xiaomi Corporation
#include "asio.hpp"
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/offline-websocket-server-impl.h"
#include "sherpa-onnx/csrc/parse-options.h"
static constexpr const char *kUsageMessage = R"(
Automatic speech recognition with sherpa-onnx using websocket.
Usage:
./bin/sherpa-onnx-offline-websocket-server --help
(1) For transducer models
./bin/sherpa-onnx-offline-websocket-server \
--port=6006 \
--num-work-threads=5 \
--tokens=/path/to/tokens.txt \
--encoder=/path/to/encoder.onnx \
--decoder=/path/to/decoder.onnx \
--joiner=/path/to/joiner.onnx \
--log-file=./log.txt \
--max-batch-size=5
(2) For Paraformer
./bin/sherpa-onnx-offline-websocket-server \
--port=6006 \
--num-work-threads=5 \
--tokens=/path/to/tokens.txt \
--paraformer=/path/to/model.onnx \
--log-file=./log.txt \
--max-batch-size=5
Please refer to
https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html
for a list of pre-trained models to download.
)";
int32_t main(int32_t argc, char *argv[]) {
sherpa_onnx::ParseOptions po(kUsageMessage);
sherpa_onnx::OfflineWebsocketServerConfig config;
// the server will listen on this port
int32_t port = 6006;
// size of the thread pool for handling network connections
int32_t num_io_threads = 1;
// size of the thread pool for neural network computation and decoding
int32_t num_work_threads = 3;
po.Register("num-io-threads", &num_io_threads,
"Thread pool size for network connections.");
po.Register("num-work-threads", &num_work_threads,
"Thread pool size for for neural network "
"computation and decoding.");
po.Register("port", &port, "The port on which the server will listen.");
config.Register(&po);
if (argc == 1) {
po.PrintUsage();
exit(EXIT_FAILURE);
}
po.Read(argc, argv);
if (po.NumArgs() != 0) {
SHERPA_ONNX_LOGE("Unrecognized positional arguments!");
po.PrintUsage();
exit(EXIT_FAILURE);
}
config.Validate();
asio::io_context io_conn; // for network connections
asio::io_context io_work; // for neural network and decoding
sherpa_onnx::OfflineWebsocketServer server(io_conn, io_work, config);
server.Run(port);
SHERPA_ONNX_LOGE("Started!");
SHERPA_ONNX_LOGE("Listening on: %d", port);
SHERPA_ONNX_LOGE("Number of work threads: %d", num_work_threads);
// give some work to do for the io_work pool
auto work_guard = asio::make_work_guard(io_work);
std::vector<std::thread> io_threads;
// decrement since the main thread is also used for network communications
for (int32_t i = 0; i < num_io_threads - 1; ++i) {
io_threads.emplace_back([&io_conn]() { io_conn.run(); });
}
std::vector<std::thread> work_threads;
for (int32_t i = 0; i < num_work_threads; ++i) {
work_threads.emplace_back([&io_work]() { io_work.run(); });
}
io_conn.run();
for (auto &t : io_threads) {
t.join();
}
for (auto &t : work_threads) {
t.join();
}
return 0;
}

View File

@@ -76,6 +76,7 @@ int32_t main(int32_t argc, char *argv[]) {
sherpa_onnx::OnlineWebsocketServer server(io_conn, io_work, config); sherpa_onnx::OnlineWebsocketServer server(io_conn, io_work, config);
server.Run(port); server.Run(port);
SHERPA_ONNX_LOGE("Started!");
SHERPA_ONNX_LOGE("Listening on: %d", port); SHERPA_ONNX_LOGE("Listening on: %d", port);
SHERPA_ONNX_LOGE("Number of work threads: %d", num_work_threads); SHERPA_ONNX_LOGE("Number of work threads: %d", num_work_threads);