This repository has been archived on 2025-08-26. You can view files and clone it, but cannot push or open issues or pull requests.
Files
enginex_bi_series-sherpa-onnx/sherpa-onnx/csrc/offline-websocket-server-impl.h
2023-03-29 21:48:45 +08:00

206 lines
6.7 KiB
C++

// sherpa-onnx/csrc/offline-websocket-server-impl.h
//
// Copyright (c) 2022-2023 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_OFFLINE_WEBSOCKET_SERVER_IMPL_H_
#define SHERPA_ONNX_CSRC_OFFLINE_WEBSOCKET_SERVER_IMPL_H_
#include <cstdint>
#include <deque>
#include <fstream>
#include <map>
#include <memory>
#include <mutex>
#include <string>
#include <utility>
#include <vector>

#include "sherpa-onnx/csrc/offline-recognizer.h"
#include "sherpa-onnx/csrc/parse-options.h"
#include "sherpa-onnx/csrc/tee-stream.h"
#include "websocketpp/config/asio_no_tls.hpp" // TODO(fangjun): support TLS
#include "websocketpp/server.hpp"
// Short aliases for the websocketpp types used throughout this header.
// `server` is the websocketpp server instantiated with the asio config from
// the non-TLS header included above; `connection_hdl` is the handle type
// websocketpp uses to refer to a single connection.
using server = websocketpp::server<websocketpp::config::asio>;
using connection_hdl = websocketpp::connection_hdl;
namespace sherpa_onnx {
/** Communication protocol
 *
 * The client sends a byte stream to the server. The first 4 bytes, in little
 * endian, give the sample rate of the audio data that the client will send.
 * The next 4 bytes, in little endian, give the total size in bytes of the
 * audio samples the client will send. The remaining bytes are the audio
 * samples. Each audio sample is a float occupying 4 bytes and is normalized
 * into the range [-1, 1].
 *
 * The byte stream can be broken into an arbitrary number of messages.
 * We require that the first message has to be at least 8 bytes so that
 * we can get `sample_rate` and `expected_byte_size` from the first message.
 *
 * ConnectionData holds the state kept for one connection while such a byte
 * stream is being received.
 */
struct ConnectionData {
  // Sample rate of the audio samples sent by the client.
  // Defaults to 0 so that a default-constructed object never exposes an
  // indeterminate value and matches the state produced by Clear().
  int32_t sample_rate = 0;

  // Number of expected bytes sent from the client
  int32_t expected_byte_size = 0;

  // Number of bytes received so far
  int32_t cur = 0;

  // It saves the received samples from the client.
  // We will **reinterpret_cast** it to float.
  // We expect that data.size() == expected_byte_size
  std::vector<int8_t> data;

  // Puts the object back into its freshly constructed state: all counters
  // are zeroed and the sample buffer is emptied.
  void Clear() {
    sample_rate = 0;
    expected_byte_size = 0;
    cur = 0;
    data.clear();
  }
};
// Shared-ownership pointer to a ConnectionData. This is the form in which
// connection state is stored in the server's connection map and in the
// decoder's queue.
using ConnectionDataPtr = std::shared_ptr<ConnectionData>;
// Settings for OfflineWebsocketDecoder.
struct OfflineWebsocketDecoderConfig {
  // Configuration of the offline recognizer.
  OfflineRecognizerConfig recognizer_config;

  // Default batch size.
  // NOTE(review): presumably the value behind `--max-batch-size`, i.e. the
  // maximum number of queued items decoded together -- confirm in the .cc.
  int32_t max_batch_size{5};

  // Maximum utterance length, in seconds.
  float max_utterance_length{300.0f};

  // NOTE(review): presumably registers the fields above as command-line
  // options on `po` -- confirm in the .cc.
  void Register(ParseOptions *po);

  // NOTE(review): presumably checks the field values; how failures are
  // reported is not visible in this header.
  void Validate() const;
};
// Forward declaration: OfflineWebsocketDecoder (declared next) stores a
// pointer to the server, whose full definition appears later in this header.
class OfflineWebsocketServer;
// Holds the queue of fully received requests and the recognizer used to
// decode them. Requests are added with Push() and processed by Decode().
class OfflineWebsocketDecoder {
public:
/**
* @param server **Borrowed** from outside. It is kept as a non-owning
* pointer in `server_`.
*
* NOTE(review): the constructor takes no `config` argument; `config_` is
* presumably obtained from the server's configuration -- confirm in the
* implementation file.
*/
explicit OfflineWebsocketDecoder(OfflineWebsocketServer *server);
/** Insert received data into the queue for decoding.
*
* @param hdl A handle to the connection. We can use it to send the result
* back to the client once it finishes decoding.
* @param d The received data
*/
void Push(connection_hdl hdl, ConnectionDataPtr d);
/** Called by one of the worker threads.
*/
void Decode();
// Returns a reference to the decoder configuration stored in this object.
const OfflineWebsocketDecoderConfig &GetConfig() const { return config_; }
private:
OfflineWebsocketDecoderConfig config_;
// NOTE(review): presumably protects `streams_`, which is filled by Push()
// and drained by the worker threads -- confirm in the implementation file.
std::mutex mutex_;
/** When we have received all the data from the client, we put it into
* this queue; the worker threads will get items from this queue for
* decoding.
*
* Number of items to take from this queue is determined by
* `--max-batch-size`. If there are not enough items in the queue, we won't
* wait and take whatever we have for decoding.
*/
std::deque<std::pair<connection_hdl, ConnectionDataPtr>> streams_;
OfflineWebsocketServer *server_; // Not owned
// The recognizer that performs the decoding.
// NOTE(review): presumably constructed from `config_.recognizer_config`;
// if so, it relies on `config_` being declared (and thus initialized)
// before it -- confirm in the implementation file.
OfflineRecognizer recognizer_;
};
// Settings for OfflineWebsocketServer.
struct OfflineWebsocketServerConfig {
  // Settings for the decoder.
  OfflineWebsocketDecoderConfig decoder_config;

  // Path of the log file.
  // NOTE(review): presumably opened by OfflineWebsocketServer::SetupLog() --
  // confirm in the .cc.
  std::string log_file{"./log.txt"};

  // NOTE(review): presumably registers the fields above as command-line
  // options on `po` -- confirm in the .cc.
  void Register(ParseOptions *po);

  // NOTE(review): presumably checks the field values; how failures are
  // reported is not visible in this header.
  void Validate() const;
};
// Websocket server that receives audio from clients and returns the decoded
// text. It owns the websocketpp server object, the per-connection receive
// state and the decoder.
class OfflineWebsocketServer {
public:
// @param io_conn Kept by reference in `io_conn_`; must outlive this object.
//                NOTE(review): presumably the context that runs the
//                network/connection handlers -- confirm.
// @param io_work Kept by reference in `io_work_`; must outlive this object.
//                NOTE(review): presumably the context on which decoding
//                work is scheduled -- confirm.
// @param config  Server configuration; a copy is stored in `config_`.
OfflineWebsocketServer(asio::io_context &io_conn, // NOLINT
asio::io_context &io_work, // NOLINT
const OfflineWebsocketServerConfig &config);
// Returns the io_context passed to the constructor as `io_conn`.
asio::io_context &GetConnectionContext() { return io_conn_; }
// Returns the underlying websocketpp server object.
server &GetServer() { return server_; }
// NOTE(review): presumably starts listening for clients on `port` --
// confirm in the implementation file.
void Run(uint16_t port);
// Returns the configuration this server was constructed with.
const OfflineWebsocketServerConfig &GetConfig() const { return config_; }
private:
// NOTE(review): presumably sets up `log_` / `tee_` using
// `config_.log_file` -- confirm in the implementation file.
void SetupLog();
// When a websocket client is connected, it will invoke this method
// (Not for HTTP)
void OnOpen(connection_hdl hdl);
// When a websocket client is disconnected, it will invoke this method
void OnClose(connection_hdl hdl);
// When a message is received from a websocket client, this method will
// be invoked.
//
// The protocol between the client and the server is as follows:
//
// (1) The client connects to the server
// (2) The client starts to send a binary byte stream to the server.
// The byte stream can be broken into multiple messages or it can
// be put into a single message.
// The first message has to contain at least 8 bytes. The first
// 4 bytes in little endian contain an int32_t indicating the
// sampling rate. The next 4 bytes in little endian contain an int32_t
// indicating the total number of bytes of samples the client will send.
// We assume each sample is a float containing 4 bytes and has been
// normalized to the range [-1, 1].
// (3) When the server receives all the samples from the client, it will
// start to decode them. Once decoded, the server sends a text message
// to the client containing the decoded results
// (4) After receiving the decoded results from the server, if the client has
// another audio file to send, it repeats (2) and (3)
// (5) If the client has no more audio files to decode, the client sends a
// text message containing "Done" to the server and closes the connection
// (6) The server receives a text message "Done" and closes the connection
//
// Note:
// (a) All models in icefall use features extracted from audio samples
// normalized to the range [-1, 1]. Please send normalized audio samples
// if you use models from icefall.
// (b) Only sound files with a single channel are supported
// (c) Only audio samples are sent. For instance, if we want to decode
// a WAVE file, the RIFF header of the WAVE is not sent.
void OnMessage(connection_hdl hdl, server::message_ptr msg);
// Close a websocket connection with given code and reason
void Close(connection_hdl hdl, websocketpp::close::status::value code,
const std::string &reason);
private:
// Both contexts are held by reference and are not owned by this object.
asio::io_context &io_conn_;
asio::io_context &io_work_;
// The websocketpp server object returned by GetServer().
server server_;
// Receive state for each connection, keyed by its connection handle.
// Handles are ordered with std::owner_less.
std::map<connection_hdl, ConnectionDataPtr, std::owner_less<connection_hdl>>
connections_;
// NOTE(review): presumably protects `connections_` -- confirm in the
// implementation file.
std::mutex mutex_;
// Copy of the configuration passed to the constructor.
OfflineWebsocketServerConfig config_;
// Logging state. NOTE(review): presumably `log_` is the log file stream and
// `tee_` duplicates output to it -- confirm in SetupLog().
std::ofstream log_;
TeeStream tee_;
// Declared last, so it is constructed after all the members above.
OfflineWebsocketDecoder decoder_;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_OFFLINE_WEBSOCKET_SERVER_IMPL_H_