add streaming websocket server and client (#62)
This commit is contained in:
@@ -18,6 +18,7 @@ option(BUILD_SHARED_LIBS "Whether to build shared libraries" OFF)
|
||||
option(SHERPA_ONNX_ENABLE_PORTAUDIO "Whether to build with portaudio" ON)
|
||||
option(SHERPA_ONNX_ENABLE_JNI "Whether to build JNI internface" OFF)
|
||||
option(SHERPA_ONNX_ENABLE_C_API "Whether to build C API" ON)
|
||||
option(SHERPA_ONNX_ENABLE_WEBSOCKET "Whether to build webscoket server/client" ON)
|
||||
|
||||
set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib")
|
||||
set(CMAKE_LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib")
|
||||
@@ -59,6 +60,8 @@ message(STATUS "SHERPA_ONNX_ENABLE_TESTS ${SHERPA_ONNX_ENABLE_TESTS}")
|
||||
message(STATUS "SHERPA_ONNX_ENABLE_CHECK ${SHERPA_ONNX_ENABLE_CHECK}")
|
||||
message(STATUS "SHERPA_ONNX_ENABLE_PORTAUDIO ${SHERPA_ONNX_ENABLE_PORTAUDIO}")
|
||||
message(STATUS "SHERPA_ONNX_ENABLE_JNI ${SHERPA_ONNX_ENABLE_JNI}")
|
||||
message(STATUS "SHERPA_ONNX_ENABLE_C_API ${SHERPA_ONNX_ENABLE_C_API}")
|
||||
message(STATUS "SHERPA_ONNX_ENABLE_WEBSOCKET ${SHERPA_ONNX_ENABLE_WEBSOCKET}")
|
||||
|
||||
set(CMAKE_CXX_STANDARD 14 CACHE STRING "The C++ version to be used.")
|
||||
set(CMAKE_CXX_EXTENSIONS OFF)
|
||||
@@ -91,6 +94,11 @@ if(SHERPA_ONNX_ENABLE_TESTS)
|
||||
include(googletest)
|
||||
endif()
|
||||
|
||||
if(SHERPA_ONNX_ENABLE_WEBSOCKET)
|
||||
include(websocketpp)
|
||||
include(asio)
|
||||
endif()
|
||||
|
||||
add_subdirectory(sherpa-onnx)
|
||||
|
||||
if(SHERPA_ONNX_ENABLE_C_API)
|
||||
|
||||
@@ -40,6 +40,10 @@ cmake \
|
||||
-DSHERPA_ONNX_ENABLE_TESTS=OFF \
|
||||
-DSHERPA_ONNX_ENABLE_PYTHON=OFF \
|
||||
-DSHERPA_ONNX_ENABLE_CHECK=OFF \
|
||||
-DSHERPA_ONNX_ENABLE_PORTAUDIO=OFF \
|
||||
-DSHERPA_ONNX_ENABLE_JNI=OFF \
|
||||
-DSHERPA_ONNX_ENABLE_C_API=OFF \
|
||||
-DSHERPA_ONNX_ENABLE_WEBSOCKET=OFF \
|
||||
-DCMAKE_TOOLCHAIN_FILE=../toolchains/aarch64-linux-gnu.toolchain.cmake \
|
||||
..
|
||||
|
||||
|
||||
@@ -76,6 +76,8 @@ cmake -DCMAKE_TOOLCHAIN_FILE="$ANDROID_NDK/build/cmake/android.toolchain.cmake"
|
||||
-DSHERPA_ONNX_ENABLE_JNI=ON \
|
||||
-DCMAKE_INSTALL_PREFIX=./install \
|
||||
-DANDROID_ABI="x86_64" \
|
||||
-DSHERPA_ONNX_ENABLE_C_API=OFF \
|
||||
-DSHERPA_ONNX_ENABLE_WEBSOCKET=OFF \
|
||||
-DANDROID_PLATFORM=android-21 ..
|
||||
|
||||
# make VERBOSE=1 -j4
|
||||
|
||||
39
cmake/asio.cmake
Normal file
39
cmake/asio.cmake
Normal file
@@ -0,0 +1,39 @@
|
||||
function(download_asio)
|
||||
include(FetchContent)
|
||||
|
||||
set(asio_URL "https://github.com/chriskohlhoff/asio/archive/refs/tags/asio-1-24-0.tar.gz")
|
||||
set(asio_HASH "SHA256=cbcaaba0f66722787b1a7c33afe1befb3a012b5af3ad7da7ff0f6b8c9b7a8a5b")
|
||||
|
||||
# If you don't have access to the Internet,
|
||||
# please pre-download asio
|
||||
set(possible_file_locations
|
||||
$ENV{HOME}/Downloads/asio-asio-1-24-0.tar.gz
|
||||
${PROJECT_SOURCE_DIR}/asio-asio-1-24-0.tar.gz
|
||||
${PROJECT_BINARY_DIR}/asio-asio-1-24-0.tar.gz
|
||||
/tmp/asio-asio-1-24-0.tar.gz
|
||||
/star-fj/fangjun/download/github/asio-asio-1-24-0.tar.gz
|
||||
)
|
||||
|
||||
foreach(f IN LISTS possible_file_locations)
|
||||
if(EXISTS ${f})
|
||||
set(asio_URL "file://${f}")
|
||||
break()
|
||||
endif()
|
||||
endforeach()
|
||||
|
||||
FetchContent_Declare(asio
|
||||
URL ${asio_URL}
|
||||
URL_HASH ${asio_HASH}
|
||||
)
|
||||
|
||||
FetchContent_GetProperties(asio)
|
||||
if(NOT asio_POPULATED)
|
||||
message(STATUS "Downloading asio ${asio_URL}")
|
||||
FetchContent_Populate(asio)
|
||||
endif()
|
||||
message(STATUS "asio is downloaded to ${asio_SOURCE_DIR}")
|
||||
# add_subdirectory(${asio_SOURCE_DIR} ${asio_BINARY_DIR} EXCLUDE_FROM_ALL)
|
||||
include_directories(${asio_SOURCE_DIR}/asio/include)
|
||||
endfunction()
|
||||
|
||||
download_asio()
|
||||
40
cmake/websocketpp.cmake
Normal file
40
cmake/websocketpp.cmake
Normal file
@@ -0,0 +1,40 @@
|
||||
function(download_websocketpp)
|
||||
include(FetchContent)
|
||||
|
||||
# The latest commit on the develop branch os as 2022-10-22
|
||||
set(websocketpp_URL "https://github.com/zaphoyd/websocketpp/archive/b9aeec6eaf3d5610503439b4fae3581d9aff08e8.zip")
|
||||
set(websocketpp_HASH "SHA256=1385135ede8191a7fbef9ec8099e3c5a673d48df0c143958216cd1690567f583")
|
||||
|
||||
# If you don't have access to the Internet,
|
||||
# please pre-download websocketpp
|
||||
set(possible_file_locations
|
||||
$ENV{HOME}/Downloads/websocketpp-b9aeec6eaf3d5610503439b4fae3581d9aff08e8.zip
|
||||
${PROJECT_SOURCE_DIR}/websocketpp-b9aeec6eaf3d5610503439b4fae3581d9aff08e8.zip
|
||||
${PROJECT_BINARY_DIR}/websocketpp-b9aeec6eaf3d5610503439b4fae3581d9aff08e8.zip
|
||||
/tmp/websocketpp-b9aeec6eaf3d5610503439b4fae3581d9aff08e8.zip
|
||||
/star-fj/fangjun/download/github/websocketpp-b9aeec6eaf3d5610503439b4fae3581d9aff08e8.zip
|
||||
)
|
||||
|
||||
foreach(f IN LISTS possible_file_locations)
|
||||
if(EXISTS ${f})
|
||||
set(websocketpp_URL "file://${f}")
|
||||
break()
|
||||
endif()
|
||||
endforeach()
|
||||
|
||||
FetchContent_Declare(websocketpp
|
||||
URL ${websocketpp_URL}
|
||||
URL_HASH ${websocketpp_HASH}
|
||||
)
|
||||
|
||||
FetchContent_GetProperties(websocketpp)
|
||||
if(NOT websocketpp_POPULATED)
|
||||
message(STATUS "Downloading websocketpp from ${websocketpp_URL}")
|
||||
FetchContent_Populate(websocketpp)
|
||||
endif()
|
||||
message(STATUS "websocketpp is downloaded to ${websocketpp_SOURCE_DIR}")
|
||||
# add_subdirectory(${websocketpp_SOURCE_DIR} ${websocketpp_BINARY_DIR} EXCLUDE_FROM_ALL)
|
||||
include_directories(${websocketpp_SOURCE_DIR})
|
||||
endfunction()
|
||||
|
||||
download_websocketpp()
|
||||
@@ -4,6 +4,7 @@ set(sources
|
||||
cat.cc
|
||||
endpoint.cc
|
||||
features.cc
|
||||
file-utils.cc
|
||||
online-lstm-transducer-model.cc
|
||||
online-recognizer.cc
|
||||
online-stream.cc
|
||||
@@ -86,6 +87,32 @@ if(SHERPA_ONNX_ENABLE_PORTAUDIO)
|
||||
install(TARGETS sherpa-onnx-microphone DESTINATION bin)
|
||||
endif()
|
||||
|
||||
if(SHERPA_ONNX_ENABLE_WEBSOCKET)
|
||||
add_definitions(-DASIO_STANDALONE)
|
||||
add_definitions(-D_WEBSOCKETPP_CPP11_STL_)
|
||||
|
||||
add_executable(sherpa-onnx-online-websocket-server
|
||||
online-websocket-server-impl.cc
|
||||
online-websocket-server.cc
|
||||
)
|
||||
target_link_libraries(sherpa-onnx-online-websocket-server sherpa-onnx-core)
|
||||
|
||||
|
||||
add_executable(sherpa-onnx-online-websocket-client
|
||||
online-websocket-client.cc
|
||||
)
|
||||
target_link_libraries(sherpa-onnx-online-websocket-client sherpa-onnx-core)
|
||||
|
||||
if(NOT WIN32)
|
||||
target_link_libraries(sherpa-onnx-online-websocket-server -pthread)
|
||||
target_compile_options(sherpa-onnx-online-websocket-server PRIVATE -Wno-deprecated-declarations)
|
||||
|
||||
target_link_libraries(sherpa-onnx-online-websocket-client -pthread)
|
||||
target_compile_options(sherpa-onnx-online-websocket-client PRIVATE -Wno-deprecated-declarations)
|
||||
endif()
|
||||
|
||||
endif()
|
||||
|
||||
|
||||
if(SHERPA_ONNX_ENABLE_TESTS)
|
||||
set(sherpa_onnx_test_srcs
|
||||
|
||||
1
sherpa-onnx/csrc/CPPLINT.cfg
Normal file
1
sherpa-onnx/csrc/CPPLINT.cfg
Normal file
@@ -0,0 +1 @@
|
||||
exclude_files=tee-stream.h
|
||||
@@ -14,6 +14,15 @@
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
void FeatureExtractorConfig::Register(ParseOptions *po) {
|
||||
po->Register("sample-rate", &sampling_rate,
|
||||
"Sampling rate of the input waveform. Must match the one "
|
||||
"expected by the model.");
|
||||
|
||||
po->Register("feat-dim", &feature_dim,
|
||||
"Feature dimension. Must match the one expected by the model.");
|
||||
}
|
||||
|
||||
std::string FeatureExtractorConfig::ToString() const {
|
||||
std::ostringstream os;
|
||||
|
||||
|
||||
@@ -9,6 +9,8 @@
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "sherpa-onnx/csrc/parse-options.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
struct FeatureExtractorConfig {
|
||||
@@ -16,6 +18,8 @@ struct FeatureExtractorConfig {
|
||||
int32_t feature_dim = 80;
|
||||
|
||||
std::string ToString() const;
|
||||
|
||||
void Register(ParseOptions *po);
|
||||
};
|
||||
|
||||
class FeatureExtractor {
|
||||
|
||||
24
sherpa-onnx/csrc/file-utils.cc
Normal file
24
sherpa-onnx/csrc/file-utils.cc
Normal file
@@ -0,0 +1,24 @@
|
||||
// sherpa-onnx/csrc/file-utils.cc
|
||||
//
|
||||
// Copyright (c) 2022-2023 Xiaomi Corporation
|
||||
|
||||
#include "sherpa-onnx/csrc/file-utils.h"
|
||||
|
||||
#include <fstream>
|
||||
#include <string>
|
||||
|
||||
#include "sherpa-onnx/csrc/log.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
bool FileExists(const std::string &filename) {
|
||||
return std::ifstream(filename).good();
|
||||
}
|
||||
|
||||
void AssertFileExists(const std::string &filename) {
|
||||
if (!FileExists(filename)) {
|
||||
SHERPA_ONNX_LOG(FATAL) << filename << " does not exist!";
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
28
sherpa-onnx/csrc/file-utils.h
Normal file
28
sherpa-onnx/csrc/file-utils.h
Normal file
@@ -0,0 +1,28 @@
|
||||
// sherpa-onnx/csrc/file-utils.h
|
||||
//
|
||||
// Copyright (c) 2022-2023 Xiaomi Corporation
|
||||
|
||||
#ifndef SHERPA_ONNX_CSRC_FILE_UTILS_H_
|
||||
#define SHERPA_ONNX_CSRC_FILE_UTILS_H_
|
||||
|
||||
#include <fstream>
|
||||
#include <string>
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
/** Check whether a given path is a file or not
|
||||
*
|
||||
* @param filename Path to check.
|
||||
* @return Return true if the given path is a file; return false otherwise.
|
||||
*/
|
||||
bool FileExists(const std::string &filename);
|
||||
|
||||
/** Abort if the file does not exist.
|
||||
*
|
||||
* @param filename The file to check.
|
||||
*/
|
||||
void AssertFileExists(const std::string &filename);
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_FILE_UTILS_H_
|
||||
@@ -12,6 +12,7 @@
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "sherpa-onnx/csrc/file-utils.h"
|
||||
#include "sherpa-onnx/csrc/online-transducer-decoder.h"
|
||||
#include "sherpa-onnx/csrc/online-transducer-greedy-search-decoder.h"
|
||||
#include "sherpa-onnx/csrc/online-transducer-model.h"
|
||||
@@ -31,6 +32,19 @@ static OnlineRecognizerResult Convert(const OnlineTransducerDecoderResult &src,
|
||||
return ans;
|
||||
}
|
||||
|
||||
void OnlineRecognizerConfig::Register(ParseOptions *po) {
|
||||
feat_config.Register(po);
|
||||
model_config.Register(po);
|
||||
endpoint_config.Register(po);
|
||||
|
||||
po->Register("enable-endpoint", &enable_endpoint,
|
||||
"True to enable endpoint detection. False to disable it.");
|
||||
}
|
||||
|
||||
bool OnlineRecognizerConfig::Validate() const {
|
||||
return model_config.Validate();
|
||||
}
|
||||
|
||||
std::string OnlineRecognizerConfig::ToString() const {
|
||||
std::ostringstream os;
|
||||
|
||||
|
||||
@@ -17,11 +17,15 @@
|
||||
#include "sherpa-onnx/csrc/features.h"
|
||||
#include "sherpa-onnx/csrc/online-stream.h"
|
||||
#include "sherpa-onnx/csrc/online-transducer-model-config.h"
|
||||
#include "sherpa-onnx/csrc/parse-options.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
struct OnlineRecognizerResult {
|
||||
std::string text;
|
||||
|
||||
// TODO(fangjun): Add a method to return a json string
|
||||
std::string ToString() const { return text; }
|
||||
};
|
||||
|
||||
struct OnlineRecognizerConfig {
|
||||
@@ -41,6 +45,9 @@ struct OnlineRecognizerConfig {
|
||||
endpoint_config(endpoint_config),
|
||||
enable_endpoint(enable_endpoint) {}
|
||||
|
||||
void Register(ParseOptions *po);
|
||||
bool Validate() const;
|
||||
|
||||
std::string ToString() const;
|
||||
};
|
||||
|
||||
|
||||
@@ -5,8 +5,52 @@
|
||||
|
||||
#include <sstream>
|
||||
|
||||
#include "sherpa-onnx/csrc/file-utils.h"
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
void OnlineTransducerModelConfig::Register(ParseOptions *po) {
|
||||
po->Register("encoder", &encoder_filename, "Path to encoder.onnx");
|
||||
po->Register("decoder", &decoder_filename, "Path to decoder.onnx");
|
||||
po->Register("joiner", &joiner_filename, "Path to joiner.onnx");
|
||||
po->Register("tokens", &tokens, "Path to tokens.txt");
|
||||
po->Register("num_threads", &num_threads,
|
||||
"Number of threads to run the neural network");
|
||||
|
||||
po->Register("debug", &debug,
|
||||
"true to print model information while loading it.");
|
||||
}
|
||||
|
||||
bool OnlineTransducerModelConfig::Validate() const {
|
||||
if (!FileExists(tokens)) {
|
||||
SHERPA_ONNX_LOGE("%s does not exist", tokens.c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!FileExists(encoder_filename)) {
|
||||
SHERPA_ONNX_LOGE("%s does not exist", encoder_filename.c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!FileExists(decoder_filename)) {
|
||||
SHERPA_ONNX_LOGE("%s does not exist", decoder_filename.c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!FileExists(joiner_filename)) {
|
||||
SHERPA_ONNX_LOGE("%s does not exist", joiner_filename.c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
if (num_threads < 1) {
|
||||
SHERPA_ONNX_LOGE("num_threads should be > 0. Given %d", num_threads);
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
std::string OnlineTransducerModelConfig::ToString() const {
|
||||
std::ostringstream os;
|
||||
|
||||
|
||||
@@ -6,6 +6,8 @@
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "sherpa-onnx/csrc/parse-options.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
struct OnlineTransducerModelConfig {
|
||||
@@ -13,7 +15,7 @@ struct OnlineTransducerModelConfig {
|
||||
std::string decoder_filename;
|
||||
std::string joiner_filename;
|
||||
std::string tokens;
|
||||
int32_t num_threads;
|
||||
int32_t num_threads = 2;
|
||||
bool debug = false;
|
||||
|
||||
OnlineTransducerModelConfig() = default;
|
||||
@@ -29,6 +31,9 @@ struct OnlineTransducerModelConfig {
|
||||
num_threads(num_threads),
|
||||
debug(debug) {}
|
||||
|
||||
void Register(ParseOptions *po);
|
||||
bool Validate() const;
|
||||
|
||||
std::string ToString() const;
|
||||
};
|
||||
|
||||
|
||||
267
sherpa-onnx/csrc/online-websocket-client.cc
Normal file
267
sherpa-onnx/csrc/online-websocket-client.cc
Normal file
@@ -0,0 +1,267 @@
|
||||
// sherpa/cpp_api/websocket/online-websocket-client.cc
|
||||
//
|
||||
// Copyright (c) 2022 Xiaomi Corporation
|
||||
#include <chrono> // NOLINT
|
||||
#include <fstream>
|
||||
#include <string>
|
||||
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/csrc/parse-options.h"
|
||||
#include "sherpa-onnx/csrc/wave-reader.h"
|
||||
#include "websocketpp/client.hpp"
|
||||
#include "websocketpp/config/asio_no_tls_client.hpp"
|
||||
#include "websocketpp/uri.hpp"
|
||||
|
||||
using client = websocketpp::client<websocketpp::config::asio_client>;
|
||||
|
||||
using message_ptr = client::message_ptr;
|
||||
using websocketpp::connection_hdl;
|
||||
|
||||
static constexpr const char *kUsageMessage = R"(
|
||||
Automatic speech recognition with sherpa-onnx using websocket.
|
||||
|
||||
Usage:
|
||||
|
||||
./bin/sherpa-onnx-online-websocket-client --help
|
||||
|
||||
./bin/sherpa-onnx-online-websocket-client \
|
||||
--server-ip=127.0.0.1 \
|
||||
--server-port=6006 \
|
||||
--samples-per-message=8000 \
|
||||
--seconds-per-message=0.2 \
|
||||
/path/to/foo.wav
|
||||
|
||||
It support only wave of with a single channel, 16kHz, 16-bit samples.
|
||||
)";
|
||||
|
||||
class Client {
|
||||
public:
|
||||
Client(asio::io_context &io, // NOLINT
|
||||
const std::string &ip, int16_t port, const std::vector<float> &samples,
|
||||
int32_t samples_per_message, float seconds_per_message)
|
||||
: io_(io),
|
||||
uri_(/*secure*/ false, ip, port, /*resource*/ "/"),
|
||||
samples_(samples),
|
||||
samples_per_message_(samples_per_message),
|
||||
seconds_per_message_(seconds_per_message) {
|
||||
c_.clear_access_channels(websocketpp::log::alevel::all);
|
||||
// c_.set_access_channels(websocketpp::log::alevel::connect);
|
||||
// c_.set_access_channels(websocketpp::log::alevel::disconnect);
|
||||
|
||||
c_.init_asio(&io_);
|
||||
c_.set_open_handler([this](connection_hdl hdl) { OnOpen(hdl); });
|
||||
c_.set_close_handler(
|
||||
[this](connection_hdl /*hdl*/) { SHERPA_ONNX_LOGE("Disconnected"); });
|
||||
c_.set_message_handler(
|
||||
[this](connection_hdl hdl, message_ptr msg) { OnMessage(hdl, msg); });
|
||||
|
||||
Run();
|
||||
}
|
||||
|
||||
private:
|
||||
void Run() {
|
||||
websocketpp::lib::error_code ec;
|
||||
client::connection_ptr con = c_.get_connection(uri_.str(), ec);
|
||||
if (ec) {
|
||||
SHERPA_ONNX_LOGE("Could not create connection to %s because %s",
|
||||
uri_.str().c_str(), ec.message().c_str());
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
|
||||
c_.connect(con);
|
||||
}
|
||||
|
||||
void OnOpen(connection_hdl hdl) {
|
||||
auto start_time = std::chrono::steady_clock::now();
|
||||
asio::post(
|
||||
io_, [this, hdl, start_time]() { this->SendMessage(hdl, start_time); });
|
||||
}
|
||||
|
||||
void OnMessage(connection_hdl hdl, message_ptr msg) {
|
||||
const std::string &payload = msg->get_payload();
|
||||
|
||||
if (payload == "Done!") {
|
||||
websocketpp::lib::error_code ec;
|
||||
c_.close(hdl, websocketpp::close::status::normal, "I'm exiting now", ec);
|
||||
if (ec) {
|
||||
SHERPA_ONNX_LOGE("Failed to close because %s", ec.message().c_str());
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
} else {
|
||||
SHERPA_ONNX_LOGE("%s", payload.c_str());
|
||||
}
|
||||
}
|
||||
|
||||
void SendMessage(
|
||||
connection_hdl hdl,
|
||||
std::chrono::time_point<std::chrono::steady_clock> start_time) {
|
||||
int32_t num_samples = samples_.size();
|
||||
int32_t num_messages = num_samples / samples_per_message_;
|
||||
|
||||
websocketpp::lib::error_code ec;
|
||||
auto time = std::chrono::steady_clock::now();
|
||||
int elapsed_time_ms =
|
||||
std::chrono::duration_cast<std::chrono::milliseconds>(time - start_time)
|
||||
.count();
|
||||
|
||||
if (elapsed_time_ms <
|
||||
static_cast<int>(seconds_per_message_ * num_sent_messages_ * 1000)) {
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(int(
|
||||
seconds_per_message_ * num_sent_messages_ * 1000 - elapsed_time_ms)));
|
||||
}
|
||||
|
||||
if (num_sent_messages_ < 1) {
|
||||
SHERPA_ONNX_LOGE("Starting to send audio");
|
||||
}
|
||||
|
||||
if (num_sent_messages_ < num_messages) {
|
||||
c_.send(hdl, samples_.data() + num_sent_messages_ * samples_per_message_,
|
||||
samples_per_message_ * sizeof(float),
|
||||
websocketpp::frame::opcode::binary, ec);
|
||||
|
||||
if (ec) {
|
||||
SHERPA_ONNX_LOGE("Failed to send audio samples because %s",
|
||||
ec.message().c_str());
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
|
||||
ec.clear();
|
||||
|
||||
++num_sent_messages_;
|
||||
}
|
||||
|
||||
if (num_sent_messages_ == num_messages) {
|
||||
int32_t remaining_samples = num_samples % samples_per_message_;
|
||||
if (remaining_samples) {
|
||||
c_.send(hdl,
|
||||
samples_.data() + num_sent_messages_ * samples_per_message_,
|
||||
remaining_samples * sizeof(float),
|
||||
websocketpp::frame::opcode::binary, ec);
|
||||
|
||||
if (ec) {
|
||||
SHERPA_ONNX_LOGE("Failed to send audio samples because %s",
|
||||
ec.message().c_str());
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
ec.clear();
|
||||
}
|
||||
|
||||
// To signal that we have send all the messages
|
||||
c_.send(hdl, "Done", websocketpp::frame::opcode::text, ec);
|
||||
SHERPA_ONNX_LOGE("Sent Done Signal");
|
||||
|
||||
if (ec) {
|
||||
SHERPA_ONNX_LOGE("Failed to send audio samples because %s",
|
||||
ec.message().c_str());
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
} else {
|
||||
asio::post(io_, [this, hdl, start_time]() {
|
||||
this->SendMessage(hdl, start_time);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
client c_;
|
||||
asio::io_context &io_;
|
||||
websocketpp::uri uri_;
|
||||
std::vector<float> samples_;
|
||||
int32_t samples_per_message_ = 8000; // 0.5 seconds
|
||||
float seconds_per_message_ = 0.2;
|
||||
int32_t num_sent_messages_ = 0;
|
||||
};
|
||||
|
||||
int32_t main(int32_t argc, char *argv[]) {
|
||||
std::string server_ip = "127.0.0.1";
|
||||
int32_t server_port = 6006;
|
||||
|
||||
// Sample rate of the input wave. No resampling is made.
|
||||
int32_t sample_rate = 16000;
|
||||
int32_t samples_per_message = 8000;
|
||||
float seconds_per_message = 0.2;
|
||||
|
||||
sherpa_onnx::ParseOptions po(kUsageMessage);
|
||||
|
||||
po.Register("server-ip", &server_ip, "IP address of the websocket server");
|
||||
po.Register("server-port", &server_port, "Port of the websocket server");
|
||||
po.Register("sample-rate", &sample_rate,
|
||||
"Sample rate of the input wave. Should be the one expected by "
|
||||
"the server");
|
||||
|
||||
po.Register("samples-per-message", &samples_per_message,
|
||||
"Send this number of samples per message.");
|
||||
|
||||
po.Register("seconds-per-message", &seconds_per_message,
|
||||
"We will simulate that each message takes this number of seconds "
|
||||
"to send. If you select a very large value, it will take a long "
|
||||
"time to send all the samples");
|
||||
|
||||
po.Read(argc, argv);
|
||||
|
||||
if (!websocketpp::uri_helper::ipv4_literal(server_ip.begin(),
|
||||
server_ip.end())) {
|
||||
SHERPA_ONNX_LOGE("Invalid server IP: %s", server_ip.c_str());
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (server_port <= 0 || server_port > 65535) {
|
||||
SHERPA_ONNX_LOGE("Invalid server port: %d", server_port);
|
||||
return -1;
|
||||
}
|
||||
|
||||
// 0.01 is an arbitrary value. You can change it.
|
||||
if (samples_per_message <= 0.01 * sample_rate) {
|
||||
SHERPA_ONNX_LOGE("--samples-per-message is too small: %d",
|
||||
samples_per_message);
|
||||
return -1;
|
||||
}
|
||||
|
||||
// 100 is an arbitrary value. You can change it.
|
||||
if (samples_per_message >= sample_rate * 100) {
|
||||
SHERPA_ONNX_LOGE("--samples-per-message is too small: %d",
|
||||
samples_per_message);
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (seconds_per_message < 0) {
|
||||
SHERPA_ONNX_LOGE("--seconds-per-message is too small: %.3f",
|
||||
seconds_per_message);
|
||||
return -1;
|
||||
}
|
||||
|
||||
// 1 is an arbitrary value.
|
||||
if (seconds_per_message > 1) {
|
||||
SHERPA_ONNX_LOGE(
|
||||
"--seconds-per-message is too large: %.3f. You will wait a long time "
|
||||
"to "
|
||||
"send all the samples",
|
||||
seconds_per_message);
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (po.NumArgs() != 1) {
|
||||
po.PrintUsage();
|
||||
return -1;
|
||||
}
|
||||
|
||||
std::string wave_filename = po.GetArg(1);
|
||||
|
||||
bool is_ok = false;
|
||||
std::vector<float> samples =
|
||||
sherpa_onnx::ReadWave(wave_filename, sample_rate, &is_ok);
|
||||
|
||||
if (!is_ok) {
|
||||
SHERPA_ONNX_LOGE("Failed to read %s", wave_filename.c_str());
|
||||
return -1;
|
||||
}
|
||||
|
||||
asio::io_context io_conn; // for network connections
|
||||
Client c(io_conn, server_ip, server_port, samples, samples_per_message,
|
||||
seconds_per_message);
|
||||
|
||||
io_conn.run(); // will exit when the above connection is closed
|
||||
|
||||
SHERPA_ONNX_LOGE("Done!");
|
||||
return 0;
|
||||
}
|
||||
327
sherpa-onnx/csrc/online-websocket-server-impl.cc
Normal file
327
sherpa-onnx/csrc/online-websocket-server-impl.cc
Normal file
@@ -0,0 +1,327 @@
|
||||
// sherpa-onnx/csrc/online-websocket-server-impl.cc
|
||||
//
|
||||
// Copyright (c) 2022-2023 Xiaomi Corporation
|
||||
|
||||
#include "sherpa-onnx/csrc/online-websocket-server-impl.h"
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "sherpa-onnx/csrc/file-utils.h"
|
||||
#include "sherpa-onnx/csrc/log.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
void OnlineWebsocketDecoderConfig::Register(ParseOptions *po) {
|
||||
recognizer_config.Register(po);
|
||||
|
||||
po->Register("loop-interval-ms", &loop_interval_ms,
|
||||
"It determines how often the decoder loop runs. ");
|
||||
|
||||
po->Register("max-batch-size", &max_batch_size,
|
||||
"Max batch size for recognition.");
|
||||
}
|
||||
|
||||
void OnlineWebsocketDecoderConfig::Validate() const {
|
||||
recognizer_config.Validate();
|
||||
SHERPA_ONNX_CHECK_GT(loop_interval_ms, 0);
|
||||
SHERPA_ONNX_CHECK_GT(max_batch_size, 0);
|
||||
}
|
||||
|
||||
void OnlineWebsocketServerConfig::Register(sherpa_onnx::ParseOptions *po) {
|
||||
decoder_config.Register(po);
|
||||
|
||||
po->Register("log-file", &log_file,
|
||||
"Path to the log file. Logs are "
|
||||
"appended to this file");
|
||||
}
|
||||
|
||||
void OnlineWebsocketServerConfig::Validate() const {
|
||||
decoder_config.Validate();
|
||||
}
|
||||
|
||||
OnlineWebsocketDecoder::OnlineWebsocketDecoder(OnlineWebsocketServer *server)
|
||||
: server_(server),
|
||||
config_(server->GetConfig().decoder_config),
|
||||
timer_(server->GetWorkContext()) {
|
||||
recognizer_ = std::make_unique<OnlineRecognizer>(config_.recognizer_config);
|
||||
}
|
||||
|
||||
std::shared_ptr<Connection> OnlineWebsocketDecoder::GetOrCreateConnection(
|
||||
connection_hdl hdl) {
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
auto it = connections_.find(hdl);
|
||||
if (it != connections_.end()) {
|
||||
return it->second;
|
||||
} else {
|
||||
// create a new connection
|
||||
std::shared_ptr<OnlineStream> s = recognizer_->CreateStream();
|
||||
auto c = std::make_shared<Connection>(hdl, s);
|
||||
connections_.insert({hdl, c});
|
||||
return c;
|
||||
}
|
||||
}
|
||||
|
||||
void OnlineWebsocketDecoder::AcceptWaveform(std::shared_ptr<Connection> c) {
|
||||
std::lock_guard<std::mutex> lock(c->mutex);
|
||||
float sample_rate = config_.recognizer_config.feat_config.sampling_rate;
|
||||
while (!c->samples.empty()) {
|
||||
const auto &s = c->samples.front();
|
||||
c->s->AcceptWaveform(sample_rate, s.data(), s.size());
|
||||
c->samples.pop_front();
|
||||
}
|
||||
}
|
||||
|
||||
void OnlineWebsocketDecoder::InputFinished(std::shared_ptr<Connection> c) {
|
||||
std::lock_guard<std::mutex> lock(c->mutex);
|
||||
|
||||
float sample_rate = config_.recognizer_config.feat_config.sampling_rate;
|
||||
|
||||
while (!c->samples.empty()) {
|
||||
const auto &s = c->samples.front();
|
||||
c->s->AcceptWaveform(sample_rate, s.data(), s.size());
|
||||
c->samples.pop_front();
|
||||
}
|
||||
|
||||
// TODO(fangjun): Change the amount of paddings to be configurable
|
||||
std::vector<float> tail_padding(static_cast<int64_t>(0.8 * sample_rate));
|
||||
|
||||
c->s->AcceptWaveform(sample_rate, tail_padding.data(), tail_padding.size());
|
||||
|
||||
c->s->InputFinished();
|
||||
c->eof = true;
|
||||
}
|
||||
|
||||
void OnlineWebsocketDecoder::Run() {
|
||||
timer_.expires_after(std::chrono::milliseconds(config_.loop_interval_ms));
|
||||
|
||||
timer_.async_wait(
|
||||
[this](const asio::error_code &ec) { ProcessConnections(ec); });
|
||||
}
|
||||
|
||||
void OnlineWebsocketDecoder::ProcessConnections(const asio::error_code &ec) {
|
||||
if (ec) {
|
||||
SHERPA_ONNX_LOG(FATAL) << "The decoder loop is aborted!";
|
||||
}
|
||||
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
std::vector<connection_hdl> to_remove;
|
||||
for (auto &p : connections_) {
|
||||
auto hdl = p.first;
|
||||
auto c = p.second;
|
||||
|
||||
// The order of `if` below matters!
|
||||
if (!server_->Contains(hdl)) {
|
||||
// If the connection is disconnected, we stop processing it
|
||||
to_remove.push_back(hdl);
|
||||
continue;
|
||||
}
|
||||
|
||||
if (active_.count(hdl)) {
|
||||
// Another thread is decoding this stream, so skip it
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!recognizer_->IsReady(c->s.get()) && !c->eof) {
|
||||
// this stream has not enough frames to decode, so skip it
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!recognizer_->IsReady(c->s.get()) && c->eof) {
|
||||
// We won't receive samples from the client, so send a Done! to client
|
||||
|
||||
asio::post(server_->GetWorkContext(),
|
||||
[this, hdl = c->hdl]() { server_->Send(hdl, "Done!"); });
|
||||
|
||||
to_remove.push_back(hdl);
|
||||
continue;
|
||||
}
|
||||
|
||||
// TODO(fangun): If the connection is timed out, we need to also
|
||||
// add it to `to_remove`
|
||||
|
||||
// this stream has enough frames and is currently not processed by any
|
||||
// threads, so put it into the ready queue
|
||||
ready_connections_.push_back(c);
|
||||
|
||||
// In `Decode()`, it will remove hdl from `active_`
|
||||
active_.insert(c->hdl);
|
||||
}
|
||||
|
||||
for (auto hdl : to_remove) {
|
||||
connections_.erase(hdl);
|
||||
}
|
||||
|
||||
if (!ready_connections_.empty()) {
|
||||
asio::post(server_->GetWorkContext(), [this]() { Decode(); });
|
||||
}
|
||||
|
||||
// Schedule another call
|
||||
timer_.expires_after(std::chrono::milliseconds(config_.loop_interval_ms));
|
||||
|
||||
timer_.async_wait(
|
||||
[this](const asio::error_code &ec) { ProcessConnections(ec); });
|
||||
}
|
||||
|
||||
void OnlineWebsocketDecoder::Decode() {
|
||||
std::unique_lock<std::mutex> lock(mutex_);
|
||||
if (ready_connections_.empty()) {
|
||||
// There are no connections that are ready for decoding,
|
||||
// so we return directly
|
||||
return;
|
||||
}
|
||||
|
||||
std::vector<std::shared_ptr<Connection>> c_vec;
|
||||
std::vector<OnlineStream *> s_vec;
|
||||
while (!ready_connections_.empty() &&
|
||||
static_cast<int32_t>(s_vec.size()) < config_.max_batch_size) {
|
||||
auto c = ready_connections_.front();
|
||||
ready_connections_.pop_front();
|
||||
|
||||
c_vec.push_back(c);
|
||||
s_vec.push_back(c->s.get());
|
||||
}
|
||||
|
||||
if (!ready_connections_.empty()) {
|
||||
// there are too many ready connections but this thread can only handle
|
||||
// max_batch_size connections at a time, so we schedule another call
|
||||
// to Decode() and let other threads to process the ready connections
|
||||
asio::post(server_->GetWorkContext(), [this]() { Decode(); });
|
||||
}
|
||||
|
||||
lock.unlock();
|
||||
recognizer_->DecodeStreams(s_vec.data(), s_vec.size());
|
||||
lock.lock();
|
||||
|
||||
for (auto c : c_vec) {
|
||||
auto result = recognizer_->GetResult(c->s.get());
|
||||
|
||||
asio::post(server_->GetConnectionContext(),
|
||||
[this, hdl = c->hdl, str = result.ToString()]() {
|
||||
server_->Send(hdl, str);
|
||||
});
|
||||
active_.erase(c->hdl);
|
||||
}
|
||||
}
|
||||
|
||||
OnlineWebsocketServer::OnlineWebsocketServer(
|
||||
asio::io_context &io_conn, asio::io_context &io_work,
|
||||
const OnlineWebsocketServerConfig &config)
|
||||
: config_(config),
|
||||
io_conn_(io_conn),
|
||||
io_work_(io_work),
|
||||
log_(config.log_file, std::ios::app),
|
||||
tee_(std::cout, log_),
|
||||
decoder_(this) {
|
||||
SetupLog();
|
||||
|
||||
server_.init_asio(&io_conn_);
|
||||
|
||||
server_.set_open_handler([this](connection_hdl hdl) { OnOpen(hdl); });
|
||||
|
||||
server_.set_close_handler([this](connection_hdl hdl) { OnClose(hdl); });
|
||||
|
||||
server_.set_message_handler(
|
||||
[this](connection_hdl hdl, server::message_ptr msg) {
|
||||
OnMessage(hdl, msg);
|
||||
});
|
||||
}
|
||||
|
||||
void OnlineWebsocketServer::Run(uint16_t port) {
|
||||
server_.set_reuse_addr(true);
|
||||
server_.listen(asio::ip::tcp::v4(), port);
|
||||
server_.start_accept();
|
||||
decoder_.Run();
|
||||
}
|
||||
|
||||
void OnlineWebsocketServer::SetupLog() {
|
||||
server_.clear_access_channels(websocketpp::log::alevel::all);
|
||||
// server_.set_access_channels(websocketpp::log::alevel::connect);
|
||||
// server_.set_access_channels(websocketpp::log::alevel::disconnect);
|
||||
|
||||
// So that it also prints to std::cout and std::cerr
|
||||
server_.get_alog().set_ostream(&tee_);
|
||||
server_.get_elog().set_ostream(&tee_);
|
||||
}
|
||||
|
||||
void OnlineWebsocketServer::Send(connection_hdl hdl, const std::string &text) {
|
||||
websocketpp::lib::error_code ec;
|
||||
if (!Contains(hdl)) {
|
||||
return;
|
||||
}
|
||||
|
||||
server_.send(hdl, text, websocketpp::frame::opcode::text, ec);
|
||||
if (ec) {
|
||||
server_.get_alog().write(websocketpp::log::alevel::app, ec.message());
|
||||
}
|
||||
}
|
||||
|
||||
void OnlineWebsocketServer::OnOpen(connection_hdl hdl) {
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
connections_.insert(hdl);
|
||||
|
||||
std::ostringstream os;
|
||||
os << "New connection: "
|
||||
<< server_.get_con_from_hdl(hdl)->get_remote_endpoint() << ". "
|
||||
<< "Number of active connections: " << connections_.size() << ".\n";
|
||||
SHERPA_ONNX_LOG(INFO) << os.str();
|
||||
}
|
||||
|
||||
void OnlineWebsocketServer::OnClose(connection_hdl hdl) {
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
connections_.erase(hdl);
|
||||
|
||||
SHERPA_ONNX_LOG(INFO) << "Number of active connections: "
|
||||
<< connections_.size() << "\n";
|
||||
}
|
||||
|
||||
bool OnlineWebsocketServer::Contains(connection_hdl hdl) const {
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
return connections_.count(hdl);
|
||||
}
|
||||
|
||||
void OnlineWebsocketServer::OnMessage(connection_hdl hdl,
|
||||
server::message_ptr msg) {
|
||||
auto c = decoder_.GetOrCreateConnection(hdl);
|
||||
|
||||
const std::string &payload = msg->get_payload();
|
||||
|
||||
switch (msg->get_opcode()) {
|
||||
case websocketpp::frame::opcode::text:
|
||||
if (payload == "Done") {
|
||||
asio::post(io_work_, [this, c]() { decoder_.InputFinished(c); });
|
||||
}
|
||||
break;
|
||||
case websocketpp::frame::opcode::binary: {
|
||||
auto p = reinterpret_cast<const float *>(payload.data());
|
||||
int32_t num_samples = payload.size() / sizeof(float);
|
||||
std::vector<float> samples(p, p + num_samples);
|
||||
|
||||
c->samples.push_back(std::move(samples));
|
||||
|
||||
asio::post(io_work_, [this, c]() { decoder_.AcceptWaveform(c); });
|
||||
break;
|
||||
}
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
void OnlineWebsocketServer::Close(connection_hdl hdl,
|
||||
websocketpp::close::status::value code,
|
||||
const std::string &reason) {
|
||||
auto con = server_.get_con_from_hdl(hdl);
|
||||
|
||||
std::ostringstream os;
|
||||
os << "Closing " << con->get_remote_endpoint() << " with reason: " << reason
|
||||
<< "\n";
|
||||
|
||||
websocketpp::lib::error_code ec;
|
||||
server_.close(hdl, code, reason, ec);
|
||||
if (ec) {
|
||||
os << "Failed to close" << con->get_remote_endpoint() << ". "
|
||||
<< ec.message() << "\n";
|
||||
}
|
||||
server_.get_alog().write(websocketpp::log::alevel::app, os.str());
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
177
sherpa-onnx/csrc/online-websocket-server-impl.h
Normal file
177
sherpa-onnx/csrc/online-websocket-server-impl.h
Normal file
@@ -0,0 +1,177 @@
|
||||
// sherpa-onnx/csrc/online-websocket-server-impl.h
|
||||
//
|
||||
// Copyright (c) 2022-2023 Xiaomi Corporation
|
||||
|
||||
#ifndef SHERPA_ONNX_CSRC_ONLINE_WEBSOCKET_SERVER_IMPL_H_
|
||||
#define SHERPA_ONNX_CSRC_ONLINE_WEBSOCKET_SERVER_IMPL_H_
|
||||
|
||||
#include <deque>
|
||||
#include <fstream>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <mutex> // NOLINT
|
||||
#include <set>
|
||||
#include <string>
|
||||
#include <unordered_set>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "asio.hpp"
|
||||
#include "sherpa-onnx/csrc/online-recognizer.h"
|
||||
#include "sherpa-onnx/csrc/online-stream.h"
|
||||
#include "sherpa-onnx/csrc/parse-options.h"
|
||||
#include "sherpa-onnx/csrc/tee-stream.h"
|
||||
#include "websocketpp/config/asio_no_tls.hpp" // TODO(fangjun): support TLS
|
||||
#include "websocketpp/server.hpp"
|
||||
using server = websocketpp::server<websocketpp::config::asio>;
|
||||
using connection_hdl = websocketpp::connection_hdl;
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
struct Connection {
|
||||
// handle to the connection. We can use it to send messages to the client
|
||||
connection_hdl hdl;
|
||||
std::shared_ptr<OnlineStream> s;
|
||||
|
||||
// set it to true when InputFinished() is called
|
||||
bool eof = false;
|
||||
|
||||
// The last time we received a message from the client
|
||||
// TODO(fangjun): Use it to disconnect from a client if it is inactive
|
||||
// for a specified time.
|
||||
std::chrono::steady_clock::time_point last_active;
|
||||
|
||||
std::mutex mutex; // protect samples
|
||||
|
||||
// Audio samples received from the client.
|
||||
//
|
||||
// The I/O threads receive audio samples into this queue
|
||||
// and invoke work threads to compute features
|
||||
std::deque<std::vector<float>> samples;
|
||||
|
||||
Connection() = default;
|
||||
Connection(connection_hdl hdl, std::shared_ptr<OnlineStream> s)
|
||||
: hdl(hdl), s(s), last_active(std::chrono::steady_clock::now()) {}
|
||||
};
|
||||
|
||||
struct OnlineWebsocketDecoderConfig {
|
||||
OnlineRecognizerConfig recognizer_config;
|
||||
|
||||
// It determines how often the decoder loop runs.
|
||||
int32_t loop_interval_ms = 10;
|
||||
|
||||
int32_t max_batch_size = 5;
|
||||
|
||||
void Register(ParseOptions *po);
|
||||
void Validate() const;
|
||||
};
|
||||
|
||||
class OnlineWebsocketServer;
|
||||
|
||||
class OnlineWebsocketDecoder {
|
||||
public:
|
||||
/**
|
||||
* @param server Not owned.
|
||||
*/
|
||||
explicit OnlineWebsocketDecoder(OnlineWebsocketServer *server);
|
||||
|
||||
std::shared_ptr<Connection> GetOrCreateConnection(connection_hdl hdl);
|
||||
|
||||
// Compute features for a stream given audio samples
|
||||
void AcceptWaveform(std::shared_ptr<Connection> c);
|
||||
|
||||
// signal that there will be no more audio samples for a stream
|
||||
void InputFinished(std::shared_ptr<Connection> c);
|
||||
|
||||
void Run();
|
||||
|
||||
private:
|
||||
void ProcessConnections(const asio::error_code &ec);
|
||||
|
||||
/** It is called by one of the worker thread.
|
||||
*/
|
||||
void Decode();
|
||||
|
||||
private:
|
||||
OnlineWebsocketServer *server_; // not owned
|
||||
std::unique_ptr<OnlineRecognizer> recognizer_;
|
||||
OnlineWebsocketDecoderConfig config_;
|
||||
asio::steady_timer timer_;
|
||||
|
||||
// It protects `connections_`, `ready_connections_`, and `active_`
|
||||
std::mutex mutex_;
|
||||
|
||||
std::map<connection_hdl, std::shared_ptr<Connection>,
|
||||
std::owner_less<connection_hdl>>
|
||||
connections_;
|
||||
|
||||
// Whenever a connection has enough feature frames for decoding, we put
|
||||
// it in this queue
|
||||
std::deque<std::shared_ptr<Connection>> ready_connections_;
|
||||
|
||||
// If we are decoding a stream, we put it in the active_ set so that
|
||||
// only one thread can decode a stream at a time.
|
||||
std::set<connection_hdl, std::owner_less<connection_hdl>> active_;
|
||||
};
|
||||
|
||||
struct OnlineWebsocketServerConfig {
|
||||
OnlineWebsocketDecoderConfig decoder_config;
|
||||
|
||||
std::string log_file = "./log.txt";
|
||||
|
||||
void Register(sherpa_onnx::ParseOptions *po);
|
||||
void Validate() const;
|
||||
};
|
||||
|
||||
class OnlineWebsocketServer {
|
||||
public:
|
||||
explicit OnlineWebsocketServer(asio::io_context &io_conn, // NOLINT
|
||||
asio::io_context &io_work, // NOLINT
|
||||
const OnlineWebsocketServerConfig &config);
|
||||
|
||||
void Run(uint16_t port);
|
||||
|
||||
const OnlineWebsocketServerConfig &GetConfig() const { return config_; }
|
||||
asio::io_context &GetConnectionContext() { return io_conn_; }
|
||||
asio::io_context &GetWorkContext() { return io_work_; }
|
||||
server &GetServer() { return server_; }
|
||||
|
||||
void Send(connection_hdl hdl, const std::string &text);
|
||||
|
||||
bool Contains(connection_hdl hdl) const;
|
||||
|
||||
private:
|
||||
void SetupLog();
|
||||
|
||||
// When a websocket client is connected, it will invoke this method
|
||||
// (Not for HTTP)
|
||||
void OnOpen(connection_hdl hdl);
|
||||
|
||||
// When a websocket client is disconnected, it will invoke this method
|
||||
void OnClose(connection_hdl hdl);
|
||||
|
||||
void OnMessage(connection_hdl hdl, server::message_ptr msg);
|
||||
|
||||
// Close a websocket connection with given code and reason
|
||||
void Close(connection_hdl hdl, websocketpp::close::status::value code,
|
||||
const std::string &reason);
|
||||
|
||||
private:
|
||||
OnlineWebsocketServerConfig config_;
|
||||
asio::io_context &io_conn_;
|
||||
asio::io_context &io_work_;
|
||||
server server_;
|
||||
|
||||
std::ofstream log_;
|
||||
sherpa_onnx::TeeStream tee_;
|
||||
|
||||
OnlineWebsocketDecoder decoder_;
|
||||
|
||||
mutable std::mutex mutex_;
|
||||
|
||||
std::set<connection_hdl, std::owner_less<connection_hdl>> connections_;
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_ONLINE_WEBSOCKET_SERVER_IMPL_H_
|
||||
108
sherpa-onnx/csrc/online-websocket-server.cc
Normal file
108
sherpa-onnx/csrc/online-websocket-server.cc
Normal file
@@ -0,0 +1,108 @@
|
||||
// sherpa-onnx/csrc/online-websocket-server.cc
|
||||
//
|
||||
// Copyright (c) 2022-2023 Xiaomi Corporation
|
||||
|
||||
#include "asio.hpp"
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/csrc/online-websocket-server-impl.h"
|
||||
#include "sherpa-onnx/csrc/parse-options.h"
|
||||
|
||||
static constexpr const char *kUsageMessage = R"(
|
||||
Automatic speech recognition with sherpa-onnx using websocket.
|
||||
|
||||
Usage:
|
||||
|
||||
./bin/sherpa-onnx-online-websocket-server --help
|
||||
|
||||
./bin/sherpa-onnx-online-websocket-server \
|
||||
--port=6006 \
|
||||
--num-work-threads=5 \
|
||||
--tokens=/path/to/tokens.txt \
|
||||
--encoder=/path/to/encoder.onnx \
|
||||
--decoder=/path/to/decoder.onnx \
|
||||
--joiner=/path/to/joiner.onnx \
|
||||
--log-file=./log.txt \
|
||||
--max-batch-size=5 \
|
||||
--loop-interval-ms=10
|
||||
|
||||
Please refer to
|
||||
https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html
|
||||
for a list of pre-trained models to download.
|
||||
)";
|
||||
|
||||
int32_t main(int32_t argc, char *argv[]) {
|
||||
sherpa_onnx::ParseOptions po(kUsageMessage);
|
||||
|
||||
sherpa_onnx::OnlineWebsocketServerConfig config;
|
||||
|
||||
// the server will listen on this port
|
||||
int32_t port = 6006;
|
||||
|
||||
// size of the thread pool for handling network connections
|
||||
int32_t num_io_threads = 1;
|
||||
|
||||
// size of the thread pool for neural network computation and decoding
|
||||
int32_t num_work_threads = 3;
|
||||
|
||||
po.Register("num-io-threads", &num_io_threads,
|
||||
"Thread pool size for network connections.");
|
||||
|
||||
po.Register("num-work-threads", &num_work_threads,
|
||||
"Thread pool size for for neural network "
|
||||
"computation and decoding.");
|
||||
|
||||
po.Register("port", &port, "The port on which the server will listen.");
|
||||
|
||||
config.Register(&po);
|
||||
|
||||
if (argc == 1) {
|
||||
po.PrintUsage();
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
|
||||
po.Read(argc, argv);
|
||||
|
||||
if (po.NumArgs() != 0) {
|
||||
SHERPA_ONNX_LOGE("Unrecognized positional arguments!");
|
||||
po.PrintUsage();
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
|
||||
config.Validate();
|
||||
|
||||
asio::io_context io_conn; // for network connections
|
||||
asio::io_context io_work; // for neural network and decoding
|
||||
|
||||
sherpa_onnx::OnlineWebsocketServer server(io_conn, io_work, config);
|
||||
server.Run(port);
|
||||
|
||||
SHERPA_ONNX_LOGE("Listening on: %d", port);
|
||||
SHERPA_ONNX_LOGE("Number of work threads: %d", num_work_threads);
|
||||
|
||||
// give some work to do for the io_work pool
|
||||
auto work_guard = asio::make_work_guard(io_work);
|
||||
|
||||
std::vector<std::thread> io_threads;
|
||||
|
||||
// decrement since the main thread is also used for network communications
|
||||
for (int32_t i = 0; i < num_io_threads - 1; ++i) {
|
||||
io_threads.emplace_back([&io_conn]() { io_conn.run(); });
|
||||
}
|
||||
|
||||
std::vector<std::thread> work_threads;
|
||||
for (int32_t i = 0; i < num_work_threads; ++i) {
|
||||
work_threads.emplace_back([&io_work]() { io_work.run(); });
|
||||
}
|
||||
|
||||
io_conn.run();
|
||||
|
||||
for (auto &t : io_threads) {
|
||||
t.join();
|
||||
}
|
||||
|
||||
for (auto &t : work_threads) {
|
||||
t.join();
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
61
sherpa-onnx/csrc/tee-stream.h
Normal file
61
sherpa-onnx/csrc/tee-stream.h
Normal file
@@ -0,0 +1,61 @@
|
||||
// Code in this file is copied and modified from
|
||||
// https://wordaligned.org/articles/cpp-streambufs
|
||||
|
||||
#ifndef SHERPA_ONNX_CSRC_TEE_STREAM_H_
|
||||
#define SHERPA_ONNX_CSRC_TEE_STREAM_H_
|
||||
#include <ostream>
|
||||
#include <streambuf>
|
||||
#include <string>
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
template <typename char_type, typename traits = std::char_traits<char_type>>
|
||||
class basic_teebuf : public std::basic_streambuf<char_type, traits> {
|
||||
public:
|
||||
using int_type = typename traits::int_type;
|
||||
|
||||
basic_teebuf(std::basic_streambuf<char_type, traits> *sb1,
|
||||
std::basic_streambuf<char_type, traits> *sb2)
|
||||
: sb1(sb1), sb2(sb2) {}
|
||||
|
||||
private:
|
||||
int sync() override {
|
||||
int const r1 = sb1->pubsync();
|
||||
int const r2 = sb2->pubsync();
|
||||
return r1 == 0 && r2 == 0 ? 0 : -1;
|
||||
}
|
||||
|
||||
int_type overflow(int_type c) override {
|
||||
int_type const eof = traits::eof();
|
||||
|
||||
if (traits::eq_int_type(c, eof)) {
|
||||
return traits::not_eof(c);
|
||||
} else {
|
||||
char_type const ch = traits::to_char_type(c);
|
||||
int_type const r1 = sb1->sputc(ch);
|
||||
int_type const r2 = sb2->sputc(ch);
|
||||
|
||||
return traits::eq_int_type(r1, eof) || traits::eq_int_type(r2, eof) ? eof
|
||||
: c;
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
std::basic_streambuf<char_type, traits> *sb1;
|
||||
std::basic_streambuf<char_type, traits> *sb2;
|
||||
};
|
||||
|
||||
using teebuf = basic_teebuf<char>;
|
||||
|
||||
class TeeStream : public std::ostream {
|
||||
public:
|
||||
TeeStream(std::ostream &o1, std::ostream &o2)
|
||||
: std::ostream(&tbuf), tbuf(o1.rdbuf(), o2.rdbuf()) {}
|
||||
|
||||
private:
|
||||
teebuf tbuf;
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_TEE_STREAM_H_
|
||||
Reference in New Issue
Block a user