diff --git a/CMakeLists.txt b/CMakeLists.txt index ab57e664..b64d9241 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -18,6 +18,7 @@ option(BUILD_SHARED_LIBS "Whether to build shared libraries" OFF) option(SHERPA_ONNX_ENABLE_PORTAUDIO "Whether to build with portaudio" ON) option(SHERPA_ONNX_ENABLE_JNI "Whether to build JNI internface" OFF) option(SHERPA_ONNX_ENABLE_C_API "Whether to build C API" ON) +option(SHERPA_ONNX_ENABLE_WEBSOCKET "Whether to build webscoket server/client" ON) set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib") set(CMAKE_LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib") @@ -59,6 +60,8 @@ message(STATUS "SHERPA_ONNX_ENABLE_TESTS ${SHERPA_ONNX_ENABLE_TESTS}") message(STATUS "SHERPA_ONNX_ENABLE_CHECK ${SHERPA_ONNX_ENABLE_CHECK}") message(STATUS "SHERPA_ONNX_ENABLE_PORTAUDIO ${SHERPA_ONNX_ENABLE_PORTAUDIO}") message(STATUS "SHERPA_ONNX_ENABLE_JNI ${SHERPA_ONNX_ENABLE_JNI}") +message(STATUS "SHERPA_ONNX_ENABLE_C_API ${SHERPA_ONNX_ENABLE_C_API}") +message(STATUS "SHERPA_ONNX_ENABLE_WEBSOCKET ${SHERPA_ONNX_ENABLE_WEBSOCKET}") set(CMAKE_CXX_STANDARD 14 CACHE STRING "The C++ version to be used.") set(CMAKE_CXX_EXTENSIONS OFF) @@ -91,6 +94,11 @@ if(SHERPA_ONNX_ENABLE_TESTS) include(googletest) endif() +if(SHERPA_ONNX_ENABLE_WEBSOCKET) + include(websocketpp) + include(asio) +endif() + add_subdirectory(sherpa-onnx) if(SHERPA_ONNX_ENABLE_C_API) diff --git a/build-aarch64-linux-gnu.sh b/build-aarch64-linux-gnu.sh index d483eede..ee7234ad 100755 --- a/build-aarch64-linux-gnu.sh +++ b/build-aarch64-linux-gnu.sh @@ -40,6 +40,10 @@ cmake \ -DSHERPA_ONNX_ENABLE_TESTS=OFF \ -DSHERPA_ONNX_ENABLE_PYTHON=OFF \ -DSHERPA_ONNX_ENABLE_CHECK=OFF \ + -DSHERPA_ONNX_ENABLE_PORTAUDIO=OFF \ + -DSHERPA_ONNX_ENABLE_JNI=OFF \ + -DSHERPA_ONNX_ENABLE_C_API=OFF \ + -DSHERPA_ONNX_ENABLE_WEBSOCKET=OFF \ -DCMAKE_TOOLCHAIN_FILE=../toolchains/aarch64-linux-gnu.toolchain.cmake \ .. 
diff --git a/build-android-x86-64.sh b/build-android-x86-64.sh index 4a05d2aa..e97a189b 100755 --- a/build-android-x86-64.sh +++ b/build-android-x86-64.sh @@ -76,6 +76,8 @@ cmake -DCMAKE_TOOLCHAIN_FILE="$ANDROID_NDK/build/cmake/android.toolchain.cmake" -DSHERPA_ONNX_ENABLE_JNI=ON \ -DCMAKE_INSTALL_PREFIX=./install \ -DANDROID_ABI="x86_64" \ + -DSHERPA_ONNX_ENABLE_C_API=OFF \ + -DSHERPA_ONNX_ENABLE_WEBSOCKET=OFF \ -DANDROID_PLATFORM=android-21 .. # make VERBOSE=1 -j4 diff --git a/cmake/asio.cmake b/cmake/asio.cmake new file mode 100644 index 00000000..8e6940e3 --- /dev/null +++ b/cmake/asio.cmake @@ -0,0 +1,39 @@ +function(download_asio) + include(FetchContent) + + set(asio_URL "https://github.com/chriskohlhoff/asio/archive/refs/tags/asio-1-24-0.tar.gz") + set(asio_HASH "SHA256=cbcaaba0f66722787b1a7c33afe1befb3a012b5af3ad7da7ff0f6b8c9b7a8a5b") + + # If you don't have access to the Internet, + # please pre-download asio + set(possible_file_locations + $ENV{HOME}/Downloads/asio-asio-1-24-0.tar.gz + ${PROJECT_SOURCE_DIR}/asio-asio-1-24-0.tar.gz + ${PROJECT_BINARY_DIR}/asio-asio-1-24-0.tar.gz + /tmp/asio-asio-1-24-0.tar.gz + /star-fj/fangjun/download/github/asio-asio-1-24-0.tar.gz + ) + + foreach(f IN LISTS possible_file_locations) + if(EXISTS ${f}) + set(asio_URL "file://${f}") + break() + endif() + endforeach() + + FetchContent_Declare(asio + URL ${asio_URL} + URL_HASH ${asio_HASH} + ) + + FetchContent_GetProperties(asio) + if(NOT asio_POPULATED) + message(STATUS "Downloading asio ${asio_URL}") + FetchContent_Populate(asio) + endif() + message(STATUS "asio is downloaded to ${asio_SOURCE_DIR}") + # add_subdirectory(${asio_SOURCE_DIR} ${asio_BINARY_DIR} EXCLUDE_FROM_ALL) + include_directories(${asio_SOURCE_DIR}/asio/include) +endfunction() + +download_asio() diff --git a/cmake/websocketpp.cmake b/cmake/websocketpp.cmake new file mode 100644 index 00000000..d00b96be --- /dev/null +++ b/cmake/websocketpp.cmake @@ -0,0 +1,40 @@ +function(download_websocketpp) + 
include(FetchContent) + + # The latest commit on the develop branch as of 2022-10-22 + set(websocketpp_URL "https://github.com/zaphoyd/websocketpp/archive/b9aeec6eaf3d5610503439b4fae3581d9aff08e8.zip") + set(websocketpp_HASH "SHA256=1385135ede8191a7fbef9ec8099e3c5a673d48df0c143958216cd1690567f583") + + # If you don't have access to the Internet, + # please pre-download websocketpp + set(possible_file_locations + $ENV{HOME}/Downloads/websocketpp-b9aeec6eaf3d5610503439b4fae3581d9aff08e8.zip + ${PROJECT_SOURCE_DIR}/websocketpp-b9aeec6eaf3d5610503439b4fae3581d9aff08e8.zip + ${PROJECT_BINARY_DIR}/websocketpp-b9aeec6eaf3d5610503439b4fae3581d9aff08e8.zip + /tmp/websocketpp-b9aeec6eaf3d5610503439b4fae3581d9aff08e8.zip + /star-fj/fangjun/download/github/websocketpp-b9aeec6eaf3d5610503439b4fae3581d9aff08e8.zip + ) + + foreach(f IN LISTS possible_file_locations) + if(EXISTS ${f}) + set(websocketpp_URL "file://${f}") + break() + endif() + endforeach() + + FetchContent_Declare(websocketpp + URL ${websocketpp_URL} + URL_HASH ${websocketpp_HASH} + ) + + FetchContent_GetProperties(websocketpp) + if(NOT websocketpp_POPULATED) + message(STATUS "Downloading websocketpp from ${websocketpp_URL}") + FetchContent_Populate(websocketpp) + endif() + message(STATUS "websocketpp is downloaded to ${websocketpp_SOURCE_DIR}") + # add_subdirectory(${websocketpp_SOURCE_DIR} ${websocketpp_BINARY_DIR} EXCLUDE_FROM_ALL) + include_directories(${websocketpp_SOURCE_DIR}) +endfunction() + +download_websocketpp() diff --git a/sherpa-onnx/csrc/CMakeLists.txt b/sherpa-onnx/csrc/CMakeLists.txt index 8058346c..23044355 100644 --- a/sherpa-onnx/csrc/CMakeLists.txt +++ b/sherpa-onnx/csrc/CMakeLists.txt @@ -4,6 +4,7 @@ set(sources cat.cc endpoint.cc features.cc + file-utils.cc online-lstm-transducer-model.cc online-recognizer.cc online-stream.cc @@ -86,6 +87,32 @@ if(SHERPA_ONNX_ENABLE_PORTAUDIO) install(TARGETS sherpa-onnx-microphone DESTINATION bin) endif() +if(SHERPA_ONNX_ENABLE_WEBSOCKET) + 
add_definitions(-DASIO_STANDALONE) + add_definitions(-D_WEBSOCKETPP_CPP11_STL_) + + add_executable(sherpa-onnx-online-websocket-server + online-websocket-server-impl.cc + online-websocket-server.cc + ) + target_link_libraries(sherpa-onnx-online-websocket-server sherpa-onnx-core) + + + add_executable(sherpa-onnx-online-websocket-client + online-websocket-client.cc + ) + target_link_libraries(sherpa-onnx-online-websocket-client sherpa-onnx-core) + + if(NOT WIN32) + target_link_libraries(sherpa-onnx-online-websocket-server -pthread) + target_compile_options(sherpa-onnx-online-websocket-server PRIVATE -Wno-deprecated-declarations) + + target_link_libraries(sherpa-onnx-online-websocket-client -pthread) + target_compile_options(sherpa-onnx-online-websocket-client PRIVATE -Wno-deprecated-declarations) + endif() + +endif() + if(SHERPA_ONNX_ENABLE_TESTS) set(sherpa_onnx_test_srcs diff --git a/sherpa-onnx/csrc/CPPLINT.cfg b/sherpa-onnx/csrc/CPPLINT.cfg new file mode 100644 index 00000000..d0129441 --- /dev/null +++ b/sherpa-onnx/csrc/CPPLINT.cfg @@ -0,0 +1 @@ +exclude_files=tee-stream.h diff --git a/sherpa-onnx/csrc/features.cc b/sherpa-onnx/csrc/features.cc index 520ddfb7..ed90b58a 100644 --- a/sherpa-onnx/csrc/features.cc +++ b/sherpa-onnx/csrc/features.cc @@ -14,6 +14,15 @@ namespace sherpa_onnx { +void FeatureExtractorConfig::Register(ParseOptions *po) { + po->Register("sample-rate", &sampling_rate, + "Sampling rate of the input waveform. Must match the one " + "expected by the model."); + + po->Register("feat-dim", &feature_dim, + "Feature dimension. 
Must match the one expected by the model."); +} + std::string FeatureExtractorConfig::ToString() const { std::ostringstream os; diff --git a/sherpa-onnx/csrc/features.h b/sherpa-onnx/csrc/features.h index bbe6e628..5f0ad967 100644 --- a/sherpa-onnx/csrc/features.h +++ b/sherpa-onnx/csrc/features.h @@ -9,6 +9,8 @@ #include #include +#include "sherpa-onnx/csrc/parse-options.h" + namespace sherpa_onnx { struct FeatureExtractorConfig { @@ -16,6 +18,8 @@ struct FeatureExtractorConfig { int32_t feature_dim = 80; std::string ToString() const; + + void Register(ParseOptions *po); }; class FeatureExtractor { diff --git a/sherpa-onnx/csrc/file-utils.cc b/sherpa-onnx/csrc/file-utils.cc new file mode 100644 index 00000000..8a64decf --- /dev/null +++ b/sherpa-onnx/csrc/file-utils.cc @@ -0,0 +1,24 @@ +// sherpa-onnx/csrc/file-utils.cc +// +// Copyright (c) 2022-2023 Xiaomi Corporation + +#include "sherpa-onnx/csrc/file-utils.h" + +#include +#include + +#include "sherpa-onnx/csrc/log.h" + +namespace sherpa_onnx { + +bool FileExists(const std::string &filename) { + return std::ifstream(filename).good(); +} + +void AssertFileExists(const std::string &filename) { + if (!FileExists(filename)) { + SHERPA_ONNX_LOG(FATAL) << filename << " does not exist!"; + } +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/file-utils.h b/sherpa-onnx/csrc/file-utils.h new file mode 100644 index 00000000..a41f6c9c --- /dev/null +++ b/sherpa-onnx/csrc/file-utils.h @@ -0,0 +1,28 @@ +// sherpa-onnx/csrc/file-utils.h +// +// Copyright (c) 2022-2023 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_FILE_UTILS_H_ +#define SHERPA_ONNX_CSRC_FILE_UTILS_H_ + +#include +#include + +namespace sherpa_onnx { + +/** Check whether a given path is a file or not + * + * @param filename Path to check. + * @return Return true if the given path is a file; return false otherwise. + */ +bool FileExists(const std::string &filename); + +/** Abort if the file does not exist. + * + * @param filename The file to check. 
+ */ +void AssertFileExists(const std::string &filename); + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_FILE_UTILS_H_ diff --git a/sherpa-onnx/csrc/online-recognizer.cc b/sherpa-onnx/csrc/online-recognizer.cc index 37177081..d6d6bfd7 100644 --- a/sherpa-onnx/csrc/online-recognizer.cc +++ b/sherpa-onnx/csrc/online-recognizer.cc @@ -12,6 +12,7 @@ #include #include +#include "sherpa-onnx/csrc/file-utils.h" #include "sherpa-onnx/csrc/online-transducer-decoder.h" #include "sherpa-onnx/csrc/online-transducer-greedy-search-decoder.h" #include "sherpa-onnx/csrc/online-transducer-model.h" @@ -31,6 +32,19 @@ static OnlineRecognizerResult Convert(const OnlineTransducerDecoderResult &src, return ans; } +void OnlineRecognizerConfig::Register(ParseOptions *po) { + feat_config.Register(po); + model_config.Register(po); + endpoint_config.Register(po); + + po->Register("enable-endpoint", &enable_endpoint, + "True to enable endpoint detection. False to disable it."); +} + +bool OnlineRecognizerConfig::Validate() const { + return model_config.Validate(); +} + std::string OnlineRecognizerConfig::ToString() const { std::ostringstream os; diff --git a/sherpa-onnx/csrc/online-recognizer.h b/sherpa-onnx/csrc/online-recognizer.h index 6fe1ddda..cceadca6 100644 --- a/sherpa-onnx/csrc/online-recognizer.h +++ b/sherpa-onnx/csrc/online-recognizer.h @@ -17,11 +17,15 @@ #include "sherpa-onnx/csrc/features.h" #include "sherpa-onnx/csrc/online-stream.h" #include "sherpa-onnx/csrc/online-transducer-model-config.h" +#include "sherpa-onnx/csrc/parse-options.h" namespace sherpa_onnx { struct OnlineRecognizerResult { std::string text; + + // TODO(fangjun): Add a method to return a json string + std::string ToString() const { return text; } }; struct OnlineRecognizerConfig { @@ -41,6 +45,9 @@ struct OnlineRecognizerConfig { endpoint_config(endpoint_config), enable_endpoint(enable_endpoint) {} + void Register(ParseOptions *po); + bool Validate() const; + std::string ToString() const; }; 
diff --git a/sherpa-onnx/csrc/online-transducer-model-config.cc b/sherpa-onnx/csrc/online-transducer-model-config.cc index 9661f72b..83fb0978 100644 --- a/sherpa-onnx/csrc/online-transducer-model-config.cc +++ b/sherpa-onnx/csrc/online-transducer-model-config.cc @@ -5,8 +5,52 @@ #include +#include "sherpa-onnx/csrc/file-utils.h" +#include "sherpa-onnx/csrc/macros.h" + namespace sherpa_onnx { +void OnlineTransducerModelConfig::Register(ParseOptions *po) { + po->Register("encoder", &encoder_filename, "Path to encoder.onnx"); + po->Register("decoder", &decoder_filename, "Path to decoder.onnx"); + po->Register("joiner", &joiner_filename, "Path to joiner.onnx"); + po->Register("tokens", &tokens, "Path to tokens.txt"); + po->Register("num_threads", &num_threads, + "Number of threads to run the neural network"); + + po->Register("debug", &debug, + "true to print model information while loading it."); +} + +bool OnlineTransducerModelConfig::Validate() const { + if (!FileExists(tokens)) { + SHERPA_ONNX_LOGE("%s does not exist", tokens.c_str()); + return false; + } + + if (!FileExists(encoder_filename)) { + SHERPA_ONNX_LOGE("%s does not exist", encoder_filename.c_str()); + return false; + } + + if (!FileExists(decoder_filename)) { + SHERPA_ONNX_LOGE("%s does not exist", decoder_filename.c_str()); + return false; + } + + if (!FileExists(joiner_filename)) { + SHERPA_ONNX_LOGE("%s does not exist", joiner_filename.c_str()); + return false; + } + + if (num_threads < 1) { + SHERPA_ONNX_LOGE("num_threads should be > 0. 
Given %d", num_threads); + return false; + } + + return true; +} + std::string OnlineTransducerModelConfig::ToString() const { std::ostringstream os; diff --git a/sherpa-onnx/csrc/online-transducer-model-config.h b/sherpa-onnx/csrc/online-transducer-model-config.h index 778f72c2..62c5d3d8 100644 --- a/sherpa-onnx/csrc/online-transducer-model-config.h +++ b/sherpa-onnx/csrc/online-transducer-model-config.h @@ -6,6 +6,8 @@ #include +#include "sherpa-onnx/csrc/parse-options.h" + namespace sherpa_onnx { struct OnlineTransducerModelConfig { @@ -13,7 +15,7 @@ struct OnlineTransducerModelConfig { std::string decoder_filename; std::string joiner_filename; std::string tokens; - int32_t num_threads; + int32_t num_threads = 2; bool debug = false; OnlineTransducerModelConfig() = default; @@ -29,6 +31,9 @@ struct OnlineTransducerModelConfig { num_threads(num_threads), debug(debug) {} + void Register(ParseOptions *po); + bool Validate() const; + std::string ToString() const; }; diff --git a/sherpa-onnx/csrc/online-websocket-client.cc b/sherpa-onnx/csrc/online-websocket-client.cc new file mode 100644 index 00000000..2df87b6c --- /dev/null +++ b/sherpa-onnx/csrc/online-websocket-client.cc @@ -0,0 +1,267 @@ +// sherpa/cpp_api/websocket/online-websocket-client.cc +// +// Copyright (c) 2022 Xiaomi Corporation +#include // NOLINT +#include +#include + +#include "sherpa-onnx/csrc/macros.h" +#include "sherpa-onnx/csrc/parse-options.h" +#include "sherpa-onnx/csrc/wave-reader.h" +#include "websocketpp/client.hpp" +#include "websocketpp/config/asio_no_tls_client.hpp" +#include "websocketpp/uri.hpp" + +using client = websocketpp::client; + +using message_ptr = client::message_ptr; +using websocketpp::connection_hdl; + +static constexpr const char *kUsageMessage = R"( +Automatic speech recognition with sherpa-onnx using websocket. 
+ +Usage: + +./bin/sherpa-onnx-online-websocket-client --help + +./bin/sherpa-onnx-online-websocket-client \ + --server-ip=127.0.0.1 \ + --server-port=6006 \ + --samples-per-message=8000 \ + --seconds-per-message=0.2 \ + /path/to/foo.wav + +It support only wave of with a single channel, 16kHz, 16-bit samples. +)"; + +class Client { + public: + Client(asio::io_context &io, // NOLINT + const std::string &ip, int16_t port, const std::vector &samples, + int32_t samples_per_message, float seconds_per_message) + : io_(io), + uri_(/*secure*/ false, ip, port, /*resource*/ "/"), + samples_(samples), + samples_per_message_(samples_per_message), + seconds_per_message_(seconds_per_message) { + c_.clear_access_channels(websocketpp::log::alevel::all); + // c_.set_access_channels(websocketpp::log::alevel::connect); + // c_.set_access_channels(websocketpp::log::alevel::disconnect); + + c_.init_asio(&io_); + c_.set_open_handler([this](connection_hdl hdl) { OnOpen(hdl); }); + c_.set_close_handler( + [this](connection_hdl /*hdl*/) { SHERPA_ONNX_LOGE("Disconnected"); }); + c_.set_message_handler( + [this](connection_hdl hdl, message_ptr msg) { OnMessage(hdl, msg); }); + + Run(); + } + + private: + void Run() { + websocketpp::lib::error_code ec; + client::connection_ptr con = c_.get_connection(uri_.str(), ec); + if (ec) { + SHERPA_ONNX_LOGE("Could not create connection to %s because %s", + uri_.str().c_str(), ec.message().c_str()); + exit(EXIT_FAILURE); + } + + c_.connect(con); + } + + void OnOpen(connection_hdl hdl) { + auto start_time = std::chrono::steady_clock::now(); + asio::post( + io_, [this, hdl, start_time]() { this->SendMessage(hdl, start_time); }); + } + + void OnMessage(connection_hdl hdl, message_ptr msg) { + const std::string &payload = msg->get_payload(); + + if (payload == "Done!") { + websocketpp::lib::error_code ec; + c_.close(hdl, websocketpp::close::status::normal, "I'm exiting now", ec); + if (ec) { + SHERPA_ONNX_LOGE("Failed to close because %s", 
ec.message().c_str()); + exit(EXIT_FAILURE); + } + } else { + SHERPA_ONNX_LOGE("%s", payload.c_str()); + } + } + + void SendMessage( + connection_hdl hdl, + std::chrono::time_point start_time) { + int32_t num_samples = samples_.size(); + int32_t num_messages = num_samples / samples_per_message_; + + websocketpp::lib::error_code ec; + auto time = std::chrono::steady_clock::now(); + int elapsed_time_ms = + std::chrono::duration_cast(time - start_time) + .count(); + + if (elapsed_time_ms < + static_cast(seconds_per_message_ * num_sent_messages_ * 1000)) { + std::this_thread::sleep_for(std::chrono::milliseconds(int( + seconds_per_message_ * num_sent_messages_ * 1000 - elapsed_time_ms))); + } + + if (num_sent_messages_ < 1) { + SHERPA_ONNX_LOGE("Starting to send audio"); + } + + if (num_sent_messages_ < num_messages) { + c_.send(hdl, samples_.data() + num_sent_messages_ * samples_per_message_, + samples_per_message_ * sizeof(float), + websocketpp::frame::opcode::binary, ec); + + if (ec) { + SHERPA_ONNX_LOGE("Failed to send audio samples because %s", + ec.message().c_str()); + exit(EXIT_FAILURE); + } + + ec.clear(); + + ++num_sent_messages_; + } + + if (num_sent_messages_ == num_messages) { + int32_t remaining_samples = num_samples % samples_per_message_; + if (remaining_samples) { + c_.send(hdl, + samples_.data() + num_sent_messages_ * samples_per_message_, + remaining_samples * sizeof(float), + websocketpp::frame::opcode::binary, ec); + + if (ec) { + SHERPA_ONNX_LOGE("Failed to send audio samples because %s", + ec.message().c_str()); + exit(EXIT_FAILURE); + } + ec.clear(); + } + + // To signal that we have sent all the messages + c_.send(hdl, "Done", websocketpp::frame::opcode::text, ec); + SHERPA_ONNX_LOGE("Sent Done Signal"); + + if (ec) { + SHERPA_ONNX_LOGE("Failed to send audio samples because %s", + ec.message().c_str()); + exit(EXIT_FAILURE); + } + } else { + asio::post(io_, [this, hdl, start_time]() { + this->SendMessage(hdl, start_time); + }); + } + } + 
private: + client c_; + asio::io_context &io_; + websocketpp::uri uri_; + std::vector samples_; + int32_t samples_per_message_ = 8000; // 0.5 seconds + float seconds_per_message_ = 0.2; + int32_t num_sent_messages_ = 0; +}; + +int32_t main(int32_t argc, char *argv[]) { + std::string server_ip = "127.0.0.1"; + int32_t server_port = 6006; + + // Sample rate of the input wave. No resampling is made. + int32_t sample_rate = 16000; + int32_t samples_per_message = 8000; + float seconds_per_message = 0.2; + + sherpa_onnx::ParseOptions po(kUsageMessage); + + po.Register("server-ip", &server_ip, "IP address of the websocket server"); + po.Register("server-port", &server_port, "Port of the websocket server"); + po.Register("sample-rate", &sample_rate, + "Sample rate of the input wave. Should be the one expected by " + "the server"); + + po.Register("samples-per-message", &samples_per_message, + "Send this number of samples per message."); + + po.Register("seconds-per-message", &seconds_per_message, + "We will simulate that each message takes this number of seconds " + "to send. If you select a very large value, it will take a long " + "time to send all the samples"); + + po.Read(argc, argv); + + if (!websocketpp::uri_helper::ipv4_literal(server_ip.begin(), + server_ip.end())) { + SHERPA_ONNX_LOGE("Invalid server IP: %s", server_ip.c_str()); + return -1; + } + + if (server_port <= 0 || server_port > 65535) { + SHERPA_ONNX_LOGE("Invalid server port: %d", server_port); + return -1; + } + + // 0.01 is an arbitrary value. You can change it. + if (samples_per_message <= 0.01 * sample_rate) { + SHERPA_ONNX_LOGE("--samples-per-message is too small: %d", + samples_per_message); + return -1; + } + + // 100 is an arbitrary value. You can change it. 
+ if (samples_per_message >= sample_rate * 100) { + SHERPA_ONNX_LOGE("--samples-per-message is too large: %d", + samples_per_message); + return -1; + } + + if (seconds_per_message < 0) { + SHERPA_ONNX_LOGE("--seconds-per-message is too small: %.3f", + seconds_per_message); + return -1; + } + + // 1 is an arbitrary value. + if (seconds_per_message > 1) { + SHERPA_ONNX_LOGE( + "--seconds-per-message is too large: %.3f. You will wait a long time " + "to " + "send all the samples", + seconds_per_message); + return -1; + } + + if (po.NumArgs() != 1) { + po.PrintUsage(); + return -1; + } + + std::string wave_filename = po.GetArg(1); + + bool is_ok = false; + std::vector samples = + sherpa_onnx::ReadWave(wave_filename, sample_rate, &is_ok); + + if (!is_ok) { + SHERPA_ONNX_LOGE("Failed to read %s", wave_filename.c_str()); + return -1; + } + + asio::io_context io_conn; // for network connections + Client c(io_conn, server_ip, server_port, samples, samples_per_message, + seconds_per_message); + + io_conn.run(); // will exit when the above connection is closed + + SHERPA_ONNX_LOGE("Done!"); + return 0; +} diff --git a/sherpa-onnx/csrc/online-websocket-server-impl.cc b/sherpa-onnx/csrc/online-websocket-server-impl.cc new file mode 100644 index 00000000..7b267785 --- /dev/null +++ b/sherpa-onnx/csrc/online-websocket-server-impl.cc @@ -0,0 +1,327 @@ +// sherpa-onnx/csrc/online-websocket-server-impl.cc +// +// Copyright (c) 2022-2023 Xiaomi Corporation + +#include "sherpa-onnx/csrc/online-websocket-server-impl.h" + +#include + +#include "sherpa-onnx/csrc/file-utils.h" +#include "sherpa-onnx/csrc/log.h" + +namespace sherpa_onnx { + +void OnlineWebsocketDecoderConfig::Register(ParseOptions *po) { + recognizer_config.Register(po); + + po->Register("loop-interval-ms", &loop_interval_ms, + "It determines how often the decoder loop runs. 
"); + + po->Register("max-batch-size", &max_batch_size, + "Max batch size for recognition."); +} + +void OnlineWebsocketDecoderConfig::Validate() const { + recognizer_config.Validate(); + SHERPA_ONNX_CHECK_GT(loop_interval_ms, 0); + SHERPA_ONNX_CHECK_GT(max_batch_size, 0); +} + +void OnlineWebsocketServerConfig::Register(sherpa_onnx::ParseOptions *po) { + decoder_config.Register(po); + + po->Register("log-file", &log_file, + "Path to the log file. Logs are " + "appended to this file"); +} + +void OnlineWebsocketServerConfig::Validate() const { + decoder_config.Validate(); +} + +OnlineWebsocketDecoder::OnlineWebsocketDecoder(OnlineWebsocketServer *server) + : server_(server), + config_(server->GetConfig().decoder_config), + timer_(server->GetWorkContext()) { + recognizer_ = std::make_unique(config_.recognizer_config); +} + +std::shared_ptr OnlineWebsocketDecoder::GetOrCreateConnection( + connection_hdl hdl) { + std::lock_guard lock(mutex_); + auto it = connections_.find(hdl); + if (it != connections_.end()) { + return it->second; + } else { + // create a new connection + std::shared_ptr s = recognizer_->CreateStream(); + auto c = std::make_shared(hdl, s); + connections_.insert({hdl, c}); + return c; + } +} + +void OnlineWebsocketDecoder::AcceptWaveform(std::shared_ptr c) { + std::lock_guard lock(c->mutex); + float sample_rate = config_.recognizer_config.feat_config.sampling_rate; + while (!c->samples.empty()) { + const auto &s = c->samples.front(); + c->s->AcceptWaveform(sample_rate, s.data(), s.size()); + c->samples.pop_front(); + } +} + +void OnlineWebsocketDecoder::InputFinished(std::shared_ptr c) { + std::lock_guard lock(c->mutex); + + float sample_rate = config_.recognizer_config.feat_config.sampling_rate; + + while (!c->samples.empty()) { + const auto &s = c->samples.front(); + c->s->AcceptWaveform(sample_rate, s.data(), s.size()); + c->samples.pop_front(); + } + + // TODO(fangjun): Change the amount of paddings to be configurable + std::vector 
tail_padding(static_cast(0.8 * sample_rate)); + + c->s->AcceptWaveform(sample_rate, tail_padding.data(), tail_padding.size()); + + c->s->InputFinished(); + c->eof = true; +} + +void OnlineWebsocketDecoder::Run() { + timer_.expires_after(std::chrono::milliseconds(config_.loop_interval_ms)); + + timer_.async_wait( + [this](const asio::error_code &ec) { ProcessConnections(ec); }); +} + +void OnlineWebsocketDecoder::ProcessConnections(const asio::error_code &ec) { + if (ec) { + SHERPA_ONNX_LOG(FATAL) << "The decoder loop is aborted!"; + } + + std::lock_guard lock(mutex_); + std::vector to_remove; + for (auto &p : connections_) { + auto hdl = p.first; + auto c = p.second; + + // The order of `if` below matters! + if (!server_->Contains(hdl)) { + // If the connection is disconnected, we stop processing it + to_remove.push_back(hdl); + continue; + } + + if (active_.count(hdl)) { + // Another thread is decoding this stream, so skip it + continue; + } + + if (!recognizer_->IsReady(c->s.get()) && !c->eof) { + // this stream has not enough frames to decode, so skip it + continue; + } + + if (!recognizer_->IsReady(c->s.get()) && c->eof) { + // We won't receive samples from the client, so send a Done! 
to client + + asio::post(server_->GetWorkContext(), + [this, hdl = c->hdl]() { server_->Send(hdl, "Done!"); }); + + to_remove.push_back(hdl); + continue; + } + + // TODO(fangjun): If the connection is timed out, we need to also + // add it to `to_remove` + + // this stream has enough frames and is currently not processed by any + // threads, so put it into the ready queue + ready_connections_.push_back(c); + + // In `Decode()`, it will remove hdl from `active_` + active_.insert(c->hdl); + } + + for (auto hdl : to_remove) { + connections_.erase(hdl); + } + + if (!ready_connections_.empty()) { + asio::post(server_->GetWorkContext(), [this]() { Decode(); }); + } + + // Schedule another call + timer_.expires_after(std::chrono::milliseconds(config_.loop_interval_ms)); + + timer_.async_wait( + [this](const asio::error_code &ec) { ProcessConnections(ec); }); +} + +void OnlineWebsocketDecoder::Decode() { + std::unique_lock lock(mutex_); + if (ready_connections_.empty()) { + // There are no connections that are ready for decoding, + // so we return directly + return; + } + + std::vector> c_vec; + std::vector s_vec; + while (!ready_connections_.empty() && + static_cast(s_vec.size()) < config_.max_batch_size) { + auto c = ready_connections_.front(); + ready_connections_.pop_front(); + + c_vec.push_back(c); + s_vec.push_back(c->s.get()); + } + + if (!ready_connections_.empty()) { + // there are too many ready connections but this thread can only handle + // max_batch_size connections at a time, so we schedule another call + // to Decode() and let other threads process the ready connections + asio::post(server_->GetWorkContext(), [this]() { Decode(); }); + } + + lock.unlock(); + recognizer_->DecodeStreams(s_vec.data(), s_vec.size()); + lock.lock(); + + for (auto c : c_vec) { + auto result = recognizer_->GetResult(c->s.get()); + + asio::post(server_->GetConnectionContext(), + [this, hdl = c->hdl, str = result.ToString()]() { + server_->Send(hdl, str); + }); + 
active_.erase(c->hdl); + } +} + +OnlineWebsocketServer::OnlineWebsocketServer( + asio::io_context &io_conn, asio::io_context &io_work, + const OnlineWebsocketServerConfig &config) + : config_(config), + io_conn_(io_conn), + io_work_(io_work), + log_(config.log_file, std::ios::app), + tee_(std::cout, log_), + decoder_(this) { + SetupLog(); + + server_.init_asio(&io_conn_); + + server_.set_open_handler([this](connection_hdl hdl) { OnOpen(hdl); }); + + server_.set_close_handler([this](connection_hdl hdl) { OnClose(hdl); }); + + server_.set_message_handler( + [this](connection_hdl hdl, server::message_ptr msg) { + OnMessage(hdl, msg); + }); +} + +void OnlineWebsocketServer::Run(uint16_t port) { + server_.set_reuse_addr(true); + server_.listen(asio::ip::tcp::v4(), port); + server_.start_accept(); + decoder_.Run(); +} + +void OnlineWebsocketServer::SetupLog() { + server_.clear_access_channels(websocketpp::log::alevel::all); + // server_.set_access_channels(websocketpp::log::alevel::connect); + // server_.set_access_channels(websocketpp::log::alevel::disconnect); + + // So that it also prints to std::cout and std::cerr + server_.get_alog().set_ostream(&tee_); + server_.get_elog().set_ostream(&tee_); +} + +void OnlineWebsocketServer::Send(connection_hdl hdl, const std::string &text) { + websocketpp::lib::error_code ec; + if (!Contains(hdl)) { + return; + } + + server_.send(hdl, text, websocketpp::frame::opcode::text, ec); + if (ec) { + server_.get_alog().write(websocketpp::log::alevel::app, ec.message()); + } +} + +void OnlineWebsocketServer::OnOpen(connection_hdl hdl) { + std::lock_guard lock(mutex_); + connections_.insert(hdl); + + std::ostringstream os; + os << "New connection: " + << server_.get_con_from_hdl(hdl)->get_remote_endpoint() << ". 
"
+     << "Number of active connections: " << connections_.size() << ".\n";
+  SHERPA_ONNX_LOG(INFO) << os.str();
+}
+
+// Called when a websocket client disconnects. Removes the handle from the
+// set of tracked connections and logs how many connections remain.
+void OnlineWebsocketServer::OnClose(connection_hdl hdl) {
+  std::lock_guard<std::mutex> lock(mutex_);
+  connections_.erase(hdl);
+
+  SHERPA_ONNX_LOG(INFO) << "Number of active connections: "
+                        << connections_.size() << "\n";
+}
+
+// Returns true if `hdl` is currently in the set of tracked connections.
+bool OnlineWebsocketServer::Contains(connection_hdl hdl) const {
+  std::lock_guard<std::mutex> lock(mutex_);
+  return connections_.count(hdl) != 0;
+}
+
+// Handles one websocket message received from a client.
+//
+//  - A text message equal to "Done" schedules InputFinished() for the
+//    connection.
+//  - A binary message is interpreted as an array of raw float samples. The
+//    samples are queued on the connection and AcceptWaveform() is scheduled.
+//  - Everything else is ignored.
+//
+// The decoder calls are posted to io_work_ instead of being run inline.
+void OnlineWebsocketServer::OnMessage(connection_hdl hdl,
+                                      server::message_ptr msg) {
+  auto c = decoder_.GetOrCreateConnection(hdl);
+
+  const std::string &payload = msg->get_payload();
+
+  switch (msg->get_opcode()) {
+    case websocketpp::frame::opcode::text:
+      if (payload == "Done") {
+        asio::post(io_work_, [this, c]() { decoder_.InputFinished(c); });
+      }
+      break;
+    case websocketpp::frame::opcode::binary: {
+      auto p = reinterpret_cast<const float *>(payload.data());
+      // Trailing bytes that do not form a whole float are dropped.
+      int32_t num_samples = payload.size() / sizeof(float);
+      std::vector<float> samples(p, p + num_samples);
+
+      {
+        // Connection::mutex protects Connection::samples. The queue is
+        // filled here by the I/O threads and handed to the work threads, so
+        // the append must not happen without the lock.
+        std::lock_guard<std::mutex> lock(c->mutex);
+        c->samples.push_back(std::move(samples));
+      }
+
+      asio::post(io_work_, [this, c]() { decoder_.AcceptWaveform(c); });
+      break;
+    }
+    default:
+      break;
+  }
+}
+
+// Closes the websocket connection `hdl` with the given close code and
+// reason. The attempt, and the error message if closing fails, are written
+// to the server's access log at the `app` level.
+void OnlineWebsocketServer::Close(connection_hdl hdl,
+                                  websocketpp::close::status::value code,
+                                  const std::string &reason) {
+  auto con = server_.get_con_from_hdl(hdl);
+
+  std::ostringstream os;
+  os << "Closing " << con->get_remote_endpoint() << " with reason: " << reason
+     << "\n";
+
+  websocketpp::lib::error_code ec;
+  server_.close(hdl, code, reason, ec);
+  if (ec) {
+    os << "Failed to close " << con->get_remote_endpoint() << ". "
+       << ec.message() << "\n";
+  }
+  server_.get_alog().write(websocketpp::log::alevel::app, os.str());
+}
+
+}  // namespace sherpa_onnx
diff --git a/sherpa-onnx/csrc/online-websocket-server-impl.h b/sherpa-onnx/csrc/online-websocket-server-impl.h
new file mode 100644
index 00000000..a82170fb
--- /dev/null
+++ b/sherpa-onnx/csrc/online-websocket-server-impl.h
@@ -0,0 +1,177 @@
+// sherpa-onnx/csrc/online-websocket-server-impl.h
+//
+// Copyright (c) 2022-2023 Xiaomi Corporation
+
+#ifndef SHERPA_ONNX_CSRC_ONLINE_WEBSOCKET_SERVER_IMPL_H_
+#define SHERPA_ONNX_CSRC_ONLINE_WEBSOCKET_SERVER_IMPL_H_
+
+#include <chrono>  // NOLINT
+#include <deque>
+#include <fstream>
+#include <map>
+#include <memory>
+#include <mutex>  // NOLINT
+#include <set>
+#include <string>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+#include "asio.hpp"
+#include "sherpa-onnx/csrc/online-recognizer.h"
+#include "sherpa-onnx/csrc/online-stream.h"
+#include "sherpa-onnx/csrc/parse-options.h"
+#include "sherpa-onnx/csrc/tee-stream.h"
+#include "websocketpp/config/asio_no_tls.hpp"  // TODO(fangjun): support TLS
+#include "websocketpp/server.hpp"
+using server = websocketpp::server<websocketpp::config::asio>;
+using connection_hdl = websocketpp::connection_hdl;
+
+namespace sherpa_onnx {
+
+// Per-client state shared between the I/O threads and the work threads.
+struct Connection {
+  // handle to the connection. We can use it to send messages to the client
+  connection_hdl hdl;
+
+  // The recognition stream associated with this client.
+  std::shared_ptr<OnlineStream> s;
+
+  // set it to true when InputFinished() is called
+  bool eof = false;
+
+  // The last time we received a message from the client
+  // TODO(fangjun): Use it to disconnect from a client if it is inactive
+  // for a specified time.
+  std::chrono::steady_clock::time_point last_active;
+
+  std::mutex mutex;  // protect samples
+
+  // Audio samples received from the client.
+  //
+  // The I/O threads receive audio samples into this queue
+  // and invoke work threads to compute features
+  std::deque<std::vector<float>> samples;
+
+  Connection() = default;
+  Connection(connection_hdl hdl, std::shared_ptr<OnlineStream> s)
+      : hdl(std::move(hdl)),
+        s(std::move(s)),
+        last_active(std::chrono::steady_clock::now()) {}
+};
+
+// Options controlling the decoder side of the server.
+struct OnlineWebsocketDecoderConfig {
+  OnlineRecognizerConfig recognizer_config;
+
+  // It determines how often the decoder loop runs.
+  int32_t loop_interval_ms = 10;
+
+  // NOTE(review): presumably the maximum number of streams decoded together
+  // in one batch; confirm against the decoder implementation.
+  int32_t max_batch_size = 5;
+
+  // Registers the command line options of this config with `po`.
+  void Register(ParseOptions *po);
+
+  // Checks the configured values.
+  void Validate() const;
+};
+
+class OnlineWebsocketServer;
+
+// Owns the recognizer and the per-connection state, and turns the audio
+// queued on each Connection into recognition results.
+class OnlineWebsocketDecoder {
+ public:
+  /**
+   * @param server  Not owned.
+   */
+  explicit OnlineWebsocketDecoder(OnlineWebsocketServer *server);
+
+  // Returns the Connection associated with `hdl`.
+  // (Judging by the name, one is created if it does not exist yet.)
+  std::shared_ptr<Connection> GetOrCreateConnection(connection_hdl hdl);
+
+  // Compute features for a stream given audio samples
+  void AcceptWaveform(std::shared_ptr<Connection> c);
+
+  // signal that there will be no more audio samples for a stream
+  void InputFinished(std::shared_ptr<Connection> c);
+
+  void Run();
+
+ private:
+  void ProcessConnections(const asio::error_code &ec);
+
+  /** It is called by one of the worker threads.
+   */
+  void Decode();
+
+ private:
+  OnlineWebsocketServer *server_;  // not owned
+  std::unique_ptr<OnlineRecognizer> recognizer_;
+  OnlineWebsocketDecoderConfig config_;
+  asio::steady_timer timer_;
+
+  // It protects `connections_`, `ready_connections_`, and `active_`
+  std::mutex mutex_;
+
+  std::map<connection_hdl, std::shared_ptr<Connection>,
+           std::owner_less<connection_hdl>>
+      connections_;
+
+  // Whenever a connection has enough feature frames for decoding, we put
+  // it in this queue
+  std::deque<std::shared_ptr<Connection>> ready_connections_;
+
+  // If we are decoding a stream, we put it in the active_ set so that
+  // only one thread can decode a stream at a time.
+  std::set<connection_hdl, std::owner_less<connection_hdl>> active_;
+};
+
+// Options for the whole server: decoder options plus the log file location.
+struct OnlineWebsocketServerConfig {
+  OnlineWebsocketDecoderConfig decoder_config;
+
+  std::string log_file = "./log.txt";
+
+  // Registers the command line options of this config with `po`.
+  void Register(sherpa_onnx::ParseOptions *po);
+
+  // Checks the configured values.
+  void Validate() const;
+};
+
+// The websocket endpoint. It accepts client connections on `io_conn` and
+// hands the received audio to an OnlineWebsocketDecoder, whose work is
+// posted to `io_work`.
+class OnlineWebsocketServer {
+ public:
+  // Both io_context objects are stored by reference and must outlive the
+  // server.
+  explicit OnlineWebsocketServer(asio::io_context &io_conn,  // NOLINT
+                                 asio::io_context &io_work,  // NOLINT
+                                 const OnlineWebsocketServerConfig &config);
+
+  // @param port  The port on which the server listens.
+  void Run(uint16_t port);
+
+  const OnlineWebsocketServerConfig &GetConfig() const { return config_; }
+  asio::io_context &GetConnectionContext() { return io_conn_; }
+  asio::io_context &GetWorkContext() { return io_work_; }
+  server &GetServer() { return server_; }
+
+  // Sends `text` to the client identified by `hdl`.
+  void Send(connection_hdl hdl, const std::string &text);
+
+  // Returns true if `hdl` is one of the currently tracked connections.
+  bool Contains(connection_hdl hdl) const;
+
+ private:
+  void SetupLog();
+
+  // When a websocket client is connected, it will invoke this method
+  // (Not for HTTP)
+  void OnOpen(connection_hdl hdl);
+
+  // When a websocket client is disconnected, it will invoke this method
+  void OnClose(connection_hdl hdl);
+
+  // Invoked for every message received from a client.
+  void OnMessage(connection_hdl hdl, server::message_ptr msg);
+
+  // Close a websocket connection with given code and reason
+  void Close(connection_hdl hdl, websocketpp::close::status::value code,
+             const std::string &reason);
+
+ private:
+  OnlineWebsocketServerConfig config_;
+  asio::io_context &io_conn_;
+  asio::io_context &io_work_;
+  server server_;
+
+  std::ofstream log_;
+  sherpa_onnx::TeeStream tee_;
+
+  OnlineWebsocketDecoder decoder_;
+
+  // It protects `connections_`
+  mutable std::mutex mutex_;
+
+  std::set<connection_hdl, std::owner_less<connection_hdl>> connections_;
+};
+
+}  // namespace sherpa_onnx
+
+#endif  // SHERPA_ONNX_CSRC_ONLINE_WEBSOCKET_SERVER_IMPL_H_
diff --git a/sherpa-onnx/csrc/online-websocket-server.cc b/sherpa-onnx/csrc/online-websocket-server.cc
new file mode 100644
index 00000000..274f0344
--- /dev/null
+++ b/sherpa-onnx/csrc/online-websocket-server.cc
@@ -0,0 +1,108 @@
+// sherpa-onnx/csrc/online-websocket-server.cc
+//
+// Copyright (c) 2022-2023 Xiaomi
Corporation
+
+#include <thread>  // NOLINT
+#include <vector>
+
+#include "asio.hpp"
+#include "sherpa-onnx/csrc/macros.h"
+#include "sherpa-onnx/csrc/online-websocket-server-impl.h"
+#include "sherpa-onnx/csrc/parse-options.h"
+
+static constexpr const char *kUsageMessage = R"(
+Automatic speech recognition with sherpa-onnx using websocket.
+
+Usage:
+
+./bin/sherpa-onnx-online-websocket-server --help
+
+./bin/sherpa-onnx-online-websocket-server \
+  --port=6006 \
+  --num-work-threads=5 \
+  --tokens=/path/to/tokens.txt \
+  --encoder=/path/to/encoder.onnx \
+  --decoder=/path/to/decoder.onnx \
+  --joiner=/path/to/joiner.onnx \
+  --log-file=./log.txt \
+  --max-batch-size=5 \
+  --loop-interval-ms=10
+
+Please refer to
+https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html
+for a list of pre-trained models to download.
+)";
+
+// Entry point of the streaming websocket server.
+//
+// It parses the command line, starts the server on the requested port and
+// then runs two thread pools: one for network I/O (the main thread is part
+// of it) and one for neural network computation and decoding.
+int32_t main(int32_t argc, char *argv[]) {
+  sherpa_onnx::ParseOptions po(kUsageMessage);
+
+  sherpa_onnx::OnlineWebsocketServerConfig config;
+
+  // the server will listen on this port
+  int32_t port = 6006;
+
+  // size of the thread pool for handling network connections
+  int32_t num_io_threads = 1;
+
+  // size of the thread pool for neural network computation and decoding
+  int32_t num_work_threads = 3;
+
+  po.Register("num-io-threads", &num_io_threads,
+              "Thread pool size for network connections.");
+
+  po.Register("num-work-threads", &num_work_threads,
+              "Thread pool size for neural network "
+              "computation and decoding.");
+
+  po.Register("port", &port, "The port on which the server will listen.");
+
+  config.Register(&po);
+
+  if (argc == 1) {
+    po.PrintUsage();
+    exit(EXIT_FAILURE);
+  }
+
+  po.Read(argc, argv);
+
+  if (po.NumArgs() != 0) {
+    SHERPA_ONNX_LOGE("Unrecognized positional arguments!");
+    po.PrintUsage();
+    exit(EXIT_FAILURE);
+  }
+
+  config.Validate();
+
+  // OnlineWebsocketServer::Run() takes a uint16_t. Reject values that would
+  // otherwise be silently truncated to a different port.
+  if (port < 0 || port > 65535) {
+    SHERPA_ONNX_LOGE("Invalid --port=%d. Expected a value in [0, 65535].",
+                     port);
+    exit(EXIT_FAILURE);
+  }
+
+  // All decoding work is posted to io_work. Without at least one work
+  // thread nothing would ever run it.
+  if (num_work_threads < 1) {
+    SHERPA_ONNX_LOGE("Invalid --num-work-threads=%d. Expected a value >= 1.",
+                     num_work_threads);
+    exit(EXIT_FAILURE);
+  }
+
+  asio::io_context io_conn;  // for network connections
+  asio::io_context io_work;  // for neural network and decoding
+
+  sherpa_onnx::OnlineWebsocketServer server(io_conn, io_work, config);
+  server.Run(static_cast<uint16_t>(port));
+
+  SHERPA_ONNX_LOGE("Listening on: %d", port);
+  SHERPA_ONNX_LOGE("Number of work threads: %d", num_work_threads);
+
+  // give some work to do for the io_work pool
+  auto work_guard = asio::make_work_guard(io_work);
+
+  std::vector<std::thread> io_threads;
+
+  // decrement since the main thread is also used for network communications
+  for (int32_t i = 0; i < num_io_threads - 1; ++i) {
+    io_threads.emplace_back([&io_conn]() { io_conn.run(); });
+  }
+
+  std::vector<std::thread> work_threads;
+  for (int32_t i = 0; i < num_work_threads; ++i) {
+    work_threads.emplace_back([&io_work]() { io_work.run(); });
+  }
+
+  io_conn.run();
+
+  for (auto &t : io_threads) {
+    t.join();
+  }
+
+  // Release the guard so that io_work.run() can return once the pending
+  // handlers are done. While the guard is held, the joins below would never
+  // finish.
+  work_guard.reset();
+
+  for (auto &t : work_threads) {
+    t.join();
+  }
+
+  return 0;
+}
diff --git a/sherpa-onnx/csrc/tee-stream.h b/sherpa-onnx/csrc/tee-stream.h
new file mode 100644
index 00000000..6fbd3d29
--- /dev/null
+++ b/sherpa-onnx/csrc/tee-stream.h
@@ -0,0 +1,61 @@
+// Code in this file is copied and modified from
+// https://wordaligned.org/articles/cpp-streambufs
+
+#ifndef SHERPA_ONNX_CSRC_TEE_STREAM_H_
+#define SHERPA_ONNX_CSRC_TEE_STREAM_H_
+
+#include <ostream>
+#include <streambuf>
+#include <string>
+
+namespace sherpa_onnx {
+
+// A stream buffer that forwards every character written to it to two other
+// stream buffers. Neither of the two buffers is owned.
+template <typename char_type, typename traits = std::char_traits<char_type>>
+class basic_teebuf : public std::basic_streambuf<char_type, traits> {
+ public:
+  using int_type = typename traits::int_type;
+
+  // @param sb1  First destination buffer. Not owned.
+  // @param sb2  Second destination buffer. Not owned.
+  basic_teebuf(std::basic_streambuf<char_type, traits> *sb1,
+               std::basic_streambuf<char_type, traits> *sb2)
+      : sb1(sb1), sb2(sb2) {}
+
+ private:
+  // Synchronizes both destinations. Returns 0 only if both succeed,
+  // otherwise -1.
+  int sync() override {
+    int const r1 = sb1->pubsync();
+    int const r2 = sb2->pubsync();
+    return r1 == 0 && r2 == 0 ? 0 : -1;
+  }
+
+  // Writes `c` to both destinations. Returns eof if either write fails,
+  // otherwise `c`. If `c` itself is eof there is nothing to write and a
+  // non-eof value is returned to signal success.
+  int_type overflow(int_type c) override {
+    int_type const eof = traits::eof();
+
+    if (traits::eq_int_type(c, eof)) {
+      return traits::not_eof(c);
+    }
+
+    char_type const ch = traits::to_char_type(c);
+    int_type const r1 = sb1->sputc(ch);
+    int_type const r2 = sb2->sputc(ch);
+
+    bool const failed =
+        traits::eq_int_type(r1, eof) || traits::eq_int_type(r2, eof);
+    return failed ? eof : c;
+  }
+
+ private:
+  std::basic_streambuf<char_type, traits> *sb1;
+  std::basic_streambuf<char_type, traits> *sb2;
+};
+
+using teebuf = basic_teebuf<char>;
+
+// An output stream that duplicates everything written to it into the
+// buffers of `o1` and `o2`. Only the buffers' addresses are stored, so both
+// streams must outlive the TeeStream.
+class TeeStream : public std::ostream {
+ public:
+  // The base class receives the address of `tbuf` before `tbuf` itself is
+  // constructed; only the pointer is stored at that point.
+  TeeStream(std::ostream &o1, std::ostream &o2)
+      : std::ostream(&tbuf), tbuf(o1.rdbuf(), o2.rdbuf()) {}
+
+ private:
+  teebuf tbuf;
+};
+
+}  // namespace sherpa_onnx
+
+#endif  // SHERPA_ONNX_CSRC_TEE_STREAM_H_