add offline websocket server/client (#98)

This commit is contained in:
Fangjun Kuang
2023-03-29 21:48:45 +08:00
committed by GitHub
parent 5e5620ea23
commit 6707ec4124
15 changed files with 1032 additions and 59 deletions

View File

@@ -128,7 +128,6 @@ if(SHERPA_ONNX_ENABLE_WEBSOCKET)
)
target_link_libraries(sherpa-onnx-online-websocket-server sherpa-onnx-core)
add_executable(sherpa-onnx-online-websocket-client
online-websocket-client.cc
)
@@ -142,6 +141,17 @@ if(SHERPA_ONNX_ENABLE_WEBSOCKET)
target_compile_options(sherpa-onnx-online-websocket-client PRIVATE -Wno-deprecated-declarations)
endif()
# For offline websocket
add_executable(sherpa-onnx-offline-websocket-server
offline-websocket-server-impl.cc
offline-websocket-server.cc
)
target_link_libraries(sherpa-onnx-offline-websocket-server sherpa-onnx-core)
if(NOT WIN32)
target_link_libraries(sherpa-onnx-offline-websocket-server -pthread)
target_compile_options(sherpa-onnx-offline-websocket-server PRIVATE -Wno-deprecated-declarations)
endif()
endif()

View File

@@ -0,0 +1,285 @@
// sherpa-onnx/csrc/offline-websocket-server-impl.cc
//
// Copyright (c) 2022-2023 Xiaomi Corporation
#include "sherpa-onnx/csrc/offline-websocket-server-impl.h"
#include <algorithm>
#include "sherpa-onnx/csrc/macros.h"
namespace sherpa_onnx {
// Registers the decoder's command-line options with the option parser.
// Also forwards to the embedded recognizer config so model-related
// options are registered as well.
void OfflineWebsocketDecoderConfig::Register(ParseOptions *po) {
  recognizer_config.Register(po);

  static constexpr const char *kMaxBatchSizeHelp =
      "Max batch size for decoding.";
  po->Register("max-batch-size", &max_batch_size, kMaxBatchSizeHelp);

  static constexpr const char *kMaxUttLenHelp =
      "Max utterance length in seconds. If we receive an utterance "
      "longer than this value, we will reject the connection. "
      "If you have enough memory, you can select a large value for it.";
  po->Register("max-utterance-length", &max_utterance_length, kMaxUttLenHelp);
}
void OfflineWebsocketDecoderConfig::Validate() const {
if (!recognizer_config.Validate()) {
SHERPA_ONNX_LOGE("Error in recongizer config");
exit(-1);
}
if (max_batch_size <= 0) {
SHERPA_ONNX_LOGE("Expect --max-batch-size > 0. Given: %d", max_batch_size);
exit(-1);
}
if (max_utterance_length <= 0) {
SHERPA_ONNX_LOGE("Expect --max-utterance-length > 0. Given: %f",
max_utterance_length);
exit(-1);
}
}
// Constructs the decoder from the server's configuration.
//
// @param server **Borrowed** pointer; the server owns this decoder as a
//               member, so it always outlives the decoder.
//
// Note: config_ is declared before recognizer_ in the class, so it is
// initialized first and recognizer_(config_.recognizer_config) is safe.
OfflineWebsocketDecoder::OfflineWebsocketDecoder(OfflineWebsocketServer *server)
    : config_(server->GetConfig().decoder_config),
      server_(server),
      recognizer_(config_.recognizer_config) {}
// Appends a finished utterance to the decoding queue.
//
// @param hdl Handle of the connection; used later to send back the result.
// @param d   The complete audio data received from that connection.
void OfflineWebsocketDecoder::Push(connection_hdl hdl, ConnectionDataPtr d) {
  std::lock_guard<std::mutex> guard(mutex_);
  streams_.emplace_back(std::move(hdl), std::move(d));
}
// Takes up to --max-batch-size items from the queue, decodes them as one
// batch, and posts each result back to its connection on the network
// io_context. Called from the worker thread pool.
//
// Change from the original: removed the `samples` and `samples_length`
// vectors that were declared but never filled (the loop-local `samples`
// pointer also shadowed the dead vector).
void OfflineWebsocketDecoder::Decode() {
  std::unique_lock<std::mutex> lock(mutex_);
  if (streams_.empty()) {
    return;
  }

  int32_t size =
      std::min(static_cast<int32_t>(streams_.size()), config_.max_batch_size);
  SHERPA_ONNX_LOGE("size: %d", size);

  // We first lock the mutex for streams_, take items from it, and then
  // unlock the mutex; in doing so we don't need to lock the mutex to
  // access hdl and connection_data later.
  std::vector<connection_hdl> handles(size);

  // Store connection_data here to prevent the data from being freed
  // while we are still using it.
  std::vector<ConnectionDataPtr> connection_data(size);

  std::vector<std::unique_ptr<OfflineStream>> ss(size);
  std::vector<OfflineStream *> p_ss(size);

  for (int32_t i = 0; i != size; ++i) {
    auto &p = streams_.front();
    handles[i] = p.first;
    connection_data[i] = p.second;
    streams_.pop_front();

    auto sample_rate = connection_data[i]->sample_rate;
    // The raw bytes received from the client are reinterpreted as floats;
    // see the protocol description in the header.
    auto samples =
        reinterpret_cast<const float *>(&connection_data[i]->data[0]);
    auto num_samples = connection_data[i]->expected_byte_size / sizeof(float);
    auto s = recognizer_.CreateStream();
    s->AcceptWaveform(sample_rate, samples, num_samples);

    ss[i] = std::move(s);
    p_ss[i] = ss[i].get();
  }
  lock.unlock();

  // Note: DecodeStreams is thread-safe
  recognizer_.DecodeStreams(p_ss.data(), size);

  for (int32_t i = 0; i != size; ++i) {
    connection_hdl hdl = handles[i];
    // Send the result from the connection io_context; capturing the text
    // by value keeps it alive after ss goes out of scope.
    asio::post(server_->GetConnectionContext(),
               [this, hdl, text = ss[i]->GetResult().text]() {
                 websocketpp::lib::error_code ec;
                 server_->GetServer().send(
                     hdl, text, websocketpp::frame::opcode::text, ec);
                 if (ec) {
                   server_->GetServer().get_alog().write(
                       websocketpp::log::alevel::app, ec.message());
                 }
               });
  }
}
// Registers the server's command-line options, including all options of
// the embedded decoder config.
void OfflineWebsocketServerConfig::Register(ParseOptions *po) {
  decoder_config.Register(po);

  static constexpr const char *kLogFileHelp =
      "Path to the log file. Logs are "
      "appended to this file";
  po->Register("log-file", &log_file, kLogFileHelp);
}
// Validates the server configuration. Exits the process if the decoder
// configuration is invalid (see OfflineWebsocketDecoderConfig::Validate).
void OfflineWebsocketServerConfig::Validate() const {
  decoder_config.Validate();
}
// Constructs the server.
//
// @param io_conn io_context used for network I/O (accept/read/write).
// @param io_work io_context used for neural-network computation/decoding.
// @param config  Server configuration (copied into config_).
OfflineWebsocketServer::OfflineWebsocketServer(
    asio::io_context &io_conn,  // NOLINT
    asio::io_context &io_work,  // NOLINT
    const OfflineWebsocketServerConfig &config)
    : io_conn_(io_conn),
      io_work_(io_work),
      config_(config),
      // Open the log file in append mode so restarts don't truncate it
      log_(config.log_file, std::ios::app),
      // tee_ duplicates log output to std::cout and the log file
      tee_(std::cout, log_),
      decoder_(this) {
  SetupLog();

  // All websocket I/O runs on the connection io_context
  server_.init_asio(&io_conn_);

  server_.set_open_handler([this](connection_hdl hdl) { OnOpen(hdl); });

  server_.set_close_handler([this](connection_hdl hdl) { OnClose(hdl); });

  server_.set_message_handler(
      [this](connection_hdl hdl, server::message_ptr msg) {
        OnMessage(hdl, msg);
      });
}
// Configures websocketpp logging: only connect/disconnect events are
// logged, and both access and error logs are routed through the tee
// stream so they appear on stdout/stderr as well as in the log file.
void OfflineWebsocketServer::SetupLog() {
  auto &s = server_;
  s.clear_access_channels(websocketpp::log::alevel::all);
  s.set_access_channels(websocketpp::log::alevel::connect);
  s.set_access_channels(websocketpp::log::alevel::disconnect);

  s.get_alog().set_ostream(&tee_);
  s.get_elog().set_ostream(&tee_);
}
// Invoked by websocketpp when a websocket client connects.
// Allocates per-connection receive state for the new connection.
void OfflineWebsocketServer::OnOpen(connection_hdl hdl) {
  std::lock_guard<std::mutex> guard(mutex_);
  connections_.emplace(hdl, std::make_shared<ConnectionData>());

  auto num_active = static_cast<int32_t>(connections_.size());
  SHERPA_ONNX_LOGE("Number of active connections: %d", num_active);
}
// Invoked by websocketpp when a websocket client disconnects.
// Drops the per-connection receive state.
void OfflineWebsocketServer::OnClose(connection_hdl hdl) {
  std::lock_guard<std::mutex> guard(mutex_);
  connections_.erase(hdl);

  auto num_active = static_cast<int32_t>(connections_.size());
  SHERPA_ONNX_LOGE("Number of active connections: %d", num_active);
}
// Handles an incoming websocket message (see the protocol description in
// the header file).
//
// Changes from the original:
//  - Reject a zero/negative expected_byte_size from the header; a negative
//    value would otherwise be passed to data.resize() (huge allocation/UB)
//    since it slips past the max-length check.
//  - Bounds-check both the first chunk and continuation chunks so a client
//    sending more bytes than announced cannot write past the end of the
//    receive buffer (this is untrusted network input).
//  - Removed the field-by-field reset that duplicated Clear().
//  - Renamed the local `max_byte_size_` (member-style name) to
//    `max_byte_size`.
void OfflineWebsocketServer::OnMessage(connection_hdl hdl,
                                       server::message_ptr msg) {
  std::unique_lock<std::mutex> lock(mutex_);
  // The handler is only invoked for live connections, so the lookup is
  // expected to succeed.
  auto connection_data = connections_.find(hdl)->second;
  lock.unlock();
  const std::string &payload = msg->get_payload();

  switch (msg->get_opcode()) {
    case websocketpp::frame::opcode::text:
      if (payload == "Done") {
        // The client will not send any more data. We can close the
        // connection now.
        Close(hdl, websocketpp::close::status::normal, "Done");
      } else {
        Close(hdl, websocketpp::close::status::normal,
              std::string("Invalid payload: ") + payload);
      }
      break;

    case websocketpp::frame::opcode::binary: {
      auto p = reinterpret_cast<const int8_t *>(payload.data());

      if (connection_data->expected_byte_size == 0) {
        // First message of a new utterance. It must carry the 8-byte
        // header: sample_rate (4 bytes) + expected_byte_size (4 bytes),
        // both little-endian int32_t.
        if (payload.size() < 8) {
          Close(hdl, websocketpp::close::status::normal,
                "Payload is too short");
          break;
        }

        connection_data->sample_rate = *reinterpret_cast<const int32_t *>(p);
        connection_data->expected_byte_size =
            *reinterpret_cast<const int32_t *>(p + 4);

        if (connection_data->expected_byte_size <= 0) {
          connection_data->Clear();
          Close(hdl, websocketpp::close::status::normal,
                "Invalid expected_byte_size in the header");
          break;
        }

        int32_t max_byte_size = decoder_.GetConfig().max_utterance_length *
                                connection_data->sample_rate * sizeof(float);
        if (connection_data->expected_byte_size > max_byte_size) {
          float num_samples =
              connection_data->expected_byte_size / sizeof(float);
          float duration = num_samples / connection_data->sample_rate;

          std::ostringstream os;
          os << "Max utterance length is configured to "
             << decoder_.GetConfig().max_utterance_length
             << " seconds, received length is " << duration << " seconds. "
             << "Payload is too large!";
          Close(hdl, websocketpp::close::status::message_too_big, os.str());
          break;
        }

        int32_t received = static_cast<int32_t>(payload.size()) - 8;
        if (received > connection_data->expected_byte_size) {
          // The client sent more bytes than it announced in the header.
          connection_data->Clear();
          Close(hdl, websocketpp::close::status::normal,
                "Received more data than expected_byte_size");
          break;
        }

        connection_data->data.resize(connection_data->expected_byte_size);
        std::copy(payload.begin() + 8, payload.end(),
                  connection_data->data.data());
        connection_data->cur = received;
      } else {
        // Continuation chunk; make sure it does not overflow the buffer.
        if (connection_data->cur + static_cast<int32_t>(payload.size()) >
            connection_data->expected_byte_size) {
          connection_data->Clear();
          Close(hdl, websocketpp::close::status::normal,
                "Received more data than expected_byte_size");
          break;
        }

        std::copy(payload.begin(), payload.end(),
                  connection_data->data.data() + connection_data->cur);
        connection_data->cur += payload.size();
      }

      if (connection_data->expected_byte_size == connection_data->cur) {
        // All samples have arrived. Hand the buffer over to the decoder
        // and reset connection_data so that the client can send the next
        // audio file over the same connection.
        auto d = std::make_shared<ConnectionData>(std::move(*connection_data));
        connection_data->Clear();

        decoder_.Push(hdl, d);

        asio::post(io_work_, [this]() { decoder_.Decode(); });
      }
      break;
    }

    default:
      // Unexpected message, ignore it
      break;
  }
}
// Closes a websocket connection with the given close code and reason, and
// logs the outcome through the access log.
//
// Fix: the failure message was missing a space between "close" and the
// remote endpoint ("Failed to closeX.X.X.X").
void OfflineWebsocketServer::Close(connection_hdl hdl,
                                   websocketpp::close::status::value code,
                                   const std::string &reason) {
  auto con = server_.get_con_from_hdl(hdl);

  std::ostringstream os;
  os << "Closing " << con->get_remote_endpoint() << " with reason: " << reason
     << "\n";

  websocketpp::lib::error_code ec;
  server_.close(hdl, code, reason, ec);
  if (ec) {
    os << "Failed to close " << con->get_remote_endpoint() << ". "
       << ec.message() << "\n";
  }
  server_.get_alog().write(websocketpp::log::alevel::app, os.str());
}
// Binds to the given port on all IPv4 interfaces and starts accepting
// websocket connections. This only registers the listener; connections
// are actually served once the connection io_context is run.
void OfflineWebsocketServer::Run(uint16_t port) {
  // Allow quick restarts without waiting for TIME_WAIT sockets
  server_.set_reuse_addr(true);
  server_.listen(asio::ip::tcp::v4(), port);
  server_.start_accept();
}
} // namespace sherpa_onnx

View File

@@ -0,0 +1,205 @@
// sherpa-onnx/csrc/offline-websocket-server-impl.h
//
// Copyright (c) 2022-2023 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_OFFLINE_WEBSOCKET_SERVER_IMPL_H_
#define SHERPA_ONNX_CSRC_OFFLINE_WEBSOCKET_SERVER_IMPL_H_
#include <deque>
#include <fstream>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "sherpa-onnx/csrc/offline-recognizer.h"
#include "sherpa-onnx/csrc/parse-options.h"
#include "sherpa-onnx/csrc/tee-stream.h"
#include "websocketpp/config/asio_no_tls.hpp" // TODO(fangjun): support TLS
#include "websocketpp/server.hpp"
using server = websocketpp::server<websocketpp::config::asio>;
using connection_hdl = websocketpp::connection_hdl;
namespace sherpa_onnx {
/** Communication protocol
*
* The client sends a byte stream to the server. The first 4 bytes in little
* endian indicates the sample rate of the audio data that the client will send.
* The next 4 bytes in little endian indicates the total samples in bytes the
* client will send. The remaining bytes represent audio samples. Each audio
* sample is a float occupying 4 bytes and is normalized into the range
* [-1, 1].
*
* The byte stream can be broken into arbitrary number of messages.
* We require that the first message has to be at least 8 bytes so that
* we can get `sample_rate` and `expected_byte_size` from the first message.
*/
// Per-connection receive state for one utterance.
//
// Fix: sample_rate now has an in-class initializer like the other fields;
// it was previously left indeterminate on a freshly-constructed instance.
struct ConnectionData {
  // Sample rate of the audio samples the client sends (from the 8-byte
  // header of the first message of each utterance)
  int32_t sample_rate = 0;

  // Number of expected bytes sent from the client
  int32_t expected_byte_size = 0;

  // Number of bytes received so far
  int32_t cur = 0;

  // It saves the received samples from the client.
  // We will **reinterpret_cast** it to float.
  // We expect that data.size() == expected_byte_size
  std::vector<int8_t> data;

  // Resets all fields so the connection can receive the next utterance.
  void Clear() {
    sample_rate = 0;
    expected_byte_size = 0;
    cur = 0;
    data.clear();
  }
};
using ConnectionDataPtr = std::shared_ptr<ConnectionData>;
// Configuration for OfflineWebsocketDecoder.
struct OfflineWebsocketDecoderConfig {
  // Configuration of the underlying offline recognizer (model files, etc.)
  OfflineRecognizerConfig recognizer_config;

  // Maximum number of utterances decoded in one batch.
  // Settable via --max-batch-size.
  int32_t max_batch_size = 5;

  // Utterances longer than this are rejected.
  // Settable via --max-utterance-length.
  float max_utterance_length = 300;  // seconds

  void Register(ParseOptions *po);
  void Validate() const;
};
class OfflineWebsocketServer;
class OfflineWebsocketServer;

// Decodes complete utterances received over websocket connections in
// batches and sends the recognition results back to the clients.
class OfflineWebsocketDecoder {
 public:
  /**
   * @param server **Borrowed** from outside; must outlive this decoder.
   */
  explicit OfflineWebsocketDecoder(OfflineWebsocketServer *server);

  /** Insert received data to the queue for decoding.
   *
   * @param hdl A handle to the connection. We can use it to send the result
   *            back to the client once it finishes decoding.
   * @param d  The received data
   */
  void Push(connection_hdl hdl, ConnectionDataPtr d);

  /** It is called by one of the work thread.
   */
  void Decode();

  const OfflineWebsocketDecoderConfig &GetConfig() const { return config_; }

 private:
  OfflineWebsocketDecoderConfig config_;

  /** When we have received all the data from the client, we put it into
   * this queue; the worker threads will get items from this queue for
   * decoding.
   *
   * Number of items to take from this queue is determined by
   * `--max-batch-size`. If there are not enough items in the queue, we won't
   * wait and take whatever we have for decoding.
   */
  std::mutex mutex_;  // protects streams_
  std::deque<std::pair<connection_hdl, ConnectionDataPtr>> streams_;

  OfflineWebsocketServer *server_;  // Not owned
  OfflineRecognizer recognizer_;
};
// Configuration for OfflineWebsocketServer.
struct OfflineWebsocketServerConfig {
  OfflineWebsocketDecoderConfig decoder_config;

  // Path of the log file; messages are appended to it.
  // Settable via --log-file.
  std::string log_file = "./log.txt";

  void Register(ParseOptions *po);
  void Validate() const;
};
// Websocket server for offline (non-streaming) speech recognition.
// Network I/O runs on io_conn_; decoding work is posted to io_work_.
class OfflineWebsocketServer {
 public:
  OfflineWebsocketServer(asio::io_context &io_conn,  // NOLINT
                         asio::io_context &io_work,  // NOLINT
                         const OfflineWebsocketServerConfig &config);

  asio::io_context &GetConnectionContext() { return io_conn_; }

  server &GetServer() { return server_; }

  void Run(uint16_t port);

  const OfflineWebsocketServerConfig &GetConfig() const { return config_; }

 private:
  void SetupLog();

  // When a websocket client is connected, it will invoke this method
  // (Not for HTTP)
  void OnOpen(connection_hdl hdl);

  // When a websocket client is disconnected, it will invoke this method
  void OnClose(connection_hdl hdl);

  // When a message is received from a websocket client, this method will
  // be invoked.
  //
  // The protocol between the client and the server is as follows:
  //
  // (1) The client connects to the server
  // (2) The client starts to send binary byte stream to the server.
  //     The byte stream can be broken into multiple messages or it can
  //     be put into a single message.
  //     The first message has to contain at least 8 bytes. The first
  //     4 bytes in little endian contains a int32_t indicating the
  //     sampling rate. The next 4 bytes in little endian contains a int32_t
  //     indicating total number of bytes of samples the client will send.
  //     We assume each sample is a float containing 4 bytes and has been
  //     normalized to the range [-1, 1].
  // (3) The client keeps sending messages until all the announced bytes
  //     have been transferred.
  // (4) When the server receives all the samples from the client, it will
  //     start to decode them. Once decoded, the server sends a text message
  //     to the client containing the decoded results
  // (5) After receiving the decoded results from the server, if the client has
  //     another audio file to send, it repeats (2), (3), (4)
  // (6) If the client has no more audio files to decode, the client sends a
  //     text message containing "Done" to the server and closes the connection
  // (7) The server receives a text message "Done" and closes the connection
  //
  // Note:
  //  (a) All models in icefall use features extracted from audio samples
  //      normalized to the range [-1, 1]. Please send normalized audio samples
  //      if you use models from icefall.
  //  (b) Only sound files with a single channel is supported
  //  (c) Only audio samples are sent. For instance, if we want to decode
  //      a WAVE file, the RIFF header of the WAVE is not sent.
  void OnMessage(connection_hdl hdl, server::message_ptr msg);

  // Close a websocket connection with given code and reason
  void Close(connection_hdl hdl, websocketpp::close::status::value code,
             const std::string &reason);

 private:
  asio::io_context &io_conn_;
  asio::io_context &io_work_;
  server server_;

  // Per-connection receive state, keyed by connection handle.
  // std::owner_less is required because connection_hdl is a weak pointer.
  std::map<connection_hdl, ConnectionDataPtr, std::owner_less<connection_hdl>>
      connections_;
  std::mutex mutex_;  // protects connections_

  OfflineWebsocketServerConfig config_;

  std::ofstream log_;  // target file for log messages
  TeeStream tee_;      // duplicates logs to std::cout and log_

  OfflineWebsocketDecoder decoder_;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_OFFLINE_WEBSOCKET_SERVER_IMPL_H_

View File

@@ -0,0 +1,120 @@
// sherpa-onnx/csrc/offline-websocket-server.cc
//
// Copyright (c) 2022-2023 Xiaomi Corporation
#include "asio.hpp"
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/offline-websocket-server-impl.h"
#include "sherpa-onnx/csrc/parse-options.h"
static constexpr const char *kUsageMessage = R"(
Automatic speech recognition with sherpa-onnx using websocket.
Usage:
./bin/sherpa-onnx-offline-websocket-server --help
(1) For transducer models
./bin/sherpa-onnx-offline-websocket-server \
--port=6006 \
--num-work-threads=5 \
--tokens=/path/to/tokens.txt \
--encoder=/path/to/encoder.onnx \
--decoder=/path/to/decoder.onnx \
--joiner=/path/to/joiner.onnx \
--log-file=./log.txt \
--max-batch-size=5
(2) For Paraformer
./bin/sherpa-onnx-offline-websocket-server \
--port=6006 \
--num-work-threads=5 \
--tokens=/path/to/tokens.txt \
--paraformer=/path/to/model.onnx \
--log-file=./log.txt \
--max-batch-size=5
Please refer to
https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html
for a list of pre-trained models to download.
)";

// Entry point of the offline websocket server.
//
// Parses command-line options, starts the websocket listener, and then
// runs two asio thread pools: one for network connections (io_conn) and
// one for neural-network computation/decoding (io_work).
int32_t main(int32_t argc, char *argv[]) {
  sherpa_onnx::ParseOptions po(kUsageMessage);
  sherpa_onnx::OfflineWebsocketServerConfig config;

  // the server will listen on this port
  int32_t port = 6006;

  // size of the thread pool for handling network connections
  int32_t num_io_threads = 1;

  // size of the thread pool for neural network computation and decoding
  int32_t num_work_threads = 3;

  po.Register("num-io-threads", &num_io_threads,
              "Thread pool size for network connections.");

  po.Register("num-work-threads", &num_work_threads,
              "Thread pool size for for neural network "
              "computation and decoding.");

  po.Register("port", &port, "The port on which the server will listen.");

  config.Register(&po);

  if (argc == 1) {
    po.PrintUsage();
    exit(EXIT_FAILURE);
  }

  po.Read(argc, argv);

  if (po.NumArgs() != 0) {
    SHERPA_ONNX_LOGE("Unrecognized positional arguments!");
    po.PrintUsage();
    exit(EXIT_FAILURE);
  }

  config.Validate();

  asio::io_context io_conn;  // for network connections
  asio::io_context io_work;  // for neural network and decoding

  sherpa_onnx::OfflineWebsocketServer server(io_conn, io_work, config);
  server.Run(port);

  SHERPA_ONNX_LOGE("Started!");
  SHERPA_ONNX_LOGE("Listening on: %d", port);
  SHERPA_ONNX_LOGE("Number of work threads: %d", num_work_threads);

  // give some work to do for the io_work pool; without the work guard,
  // io_work.run() would return immediately since work is only posted later
  auto work_guard = asio::make_work_guard(io_work);

  std::vector<std::thread> io_threads;

  // decrement since the main thread is also used for network communications
  for (int32_t i = 0; i < num_io_threads - 1; ++i) {
    io_threads.emplace_back([&io_conn]() { io_conn.run(); });
  }

  std::vector<std::thread> work_threads;
  for (int32_t i = 0; i < num_work_threads; ++i) {
    work_threads.emplace_back([&io_work]() { io_work.run(); });
  }

  // the main thread also serves network connections; blocks until the
  // io_context is stopped
  io_conn.run();

  for (auto &t : io_threads) {
    t.join();
  }

  for (auto &t : work_threads) {
    t.join();
  }

  return 0;
}

View File

@@ -76,6 +76,7 @@ int32_t main(int32_t argc, char *argv[]) {
sherpa_onnx::OnlineWebsocketServer server(io_conn, io_work, config);
server.Run(port);
SHERPA_ONNX_LOGE("Started!");
SHERPA_ONNX_LOGE("Listening on: %d", port);
SHERPA_ONNX_LOGE("Number of work threads: %d", num_work_threads);