This repository has been archived on 2025-08-26. You can view files and clone it, but cannot push or open issues or pull requests.
Files
enginex-mr_series-sherpa-onnx/sherpa-onnx/csrc/online-websocket-client.cc
2024-04-24 18:41:48 +08:00

275 lines
8.3 KiB
C++

// sherpa-onnx/csrc/online-websocket-client.cc
//
// Copyright (c) 2022 Xiaomi Corporation
#include <chrono> // NOLINT
#include <cstdint>
#include <cstdlib>
#include <fstream>
#include <string>
#include <thread>  // NOLINT
#include <vector>

#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/parse-options.h"
#include "sherpa-onnx/csrc/wave-reader.h"
#include "websocketpp/client.hpp"
#include "websocketpp/config/asio_no_tls_client.hpp"
#include "websocketpp/uri.hpp"
// websocketpp client endpoint using the plain (non-TLS) asio configuration.
using client = websocketpp::client<websocketpp::config::asio_client>;
// Pointer type for messages delivered to the message handler.
using message_ptr = client::message_ptr;
// Handle identifying a websocket connection in callbacks and send/close calls.
using connection_hdl = websocketpp::connection_hdl;
// Usage text handed to ParseOptions; printed (via PrintUsage) when the
// command line does not contain exactly one wave file.
static constexpr const char *kUsageMessage = R"(
Automatic speech recognition with sherpa-onnx using websocket.
Usage:
./bin/sherpa-onnx-online-websocket-client --help
./bin/sherpa-onnx-online-websocket-client \
--server-ip=127.0.0.1 \
--server-port=6006 \
--samples-per-message=8000 \
--seconds-per-message=0.2 \
/path/to/foo.wav
It supports only wave files with a single channel and 16-bit samples.
The sample rate of the file must equal --sample-rate (default: 16000).
)";
// Streams audio samples to a sherpa-onnx online websocket server and logs
// every text message the server sends back.
//
// Behavior implemented below:
//   - audio goes out as binary frames of raw float samples, each holding
//     `samples_per_message` samples (the last frame may be shorter);
//   - sends are paced so that frame k is not sent before
//     k * `seconds_per_message` seconds have elapsed since the connection
//     opened, simulating real-time audio;
//   - after the last audio frame the text message "Done" is sent;
//   - when the server replies "Done!", the connection is closed.
//
// The constructor starts connecting immediately; the caller must run `io`
// to drive the connection.
class Client {
 public:
  // @param io                   io_context that drives the connection; it is
  //                             stored by reference and must outlive *this.
  // @param ip                   server IP address.
  // @param port                 server port (1-65535).
  // @param samples              audio samples to send (copied).
  // @param samples_per_message  number of samples per binary frame.
  // @param seconds_per_message  simulated duration of one frame.
  Client(asio::io_context &io,  // NOLINT
         const std::string &ip, uint16_t port,
         const std::vector<float> &samples, int32_t samples_per_message,
         float seconds_per_message)
      : io_(io),
        uri_(/*secure*/ false, ip, port, /*resource*/ "/"),
        samples_(samples),
        samples_per_message_(samples_per_message),
        seconds_per_message_(seconds_per_message) {
    c_.clear_access_channels(websocketpp::log::alevel::all);
    // c_.set_access_channels(websocketpp::log::alevel::connect);
    // c_.set_access_channels(websocketpp::log::alevel::disconnect);

    c_.init_asio(&io_);
    c_.set_open_handler([this](connection_hdl hdl) { OnOpen(hdl); });
    c_.set_close_handler(
        [](connection_hdl /*hdl*/) { SHERPA_ONNX_LOGE("Disconnected"); });
    c_.set_message_handler(
        [this](connection_hdl hdl, message_ptr msg) { OnMessage(hdl, msg); });

    Run();
  }

 private:
  // Creates the connection for `uri_` and queues it on the endpoint.
  // Terminates the process if the connection cannot be created.
  void Run() {
    websocketpp::lib::error_code ec;
    client::connection_ptr con = c_.get_connection(uri_.str(), ec);
    if (ec) {
      SHERPA_ONNX_LOGE("Could not create connection to %s because %s",
                       uri_.str().c_str(), ec.message().c_str());
      exit(EXIT_FAILURE);
    }

    c_.connect(con);
  }

  // Open handler: records the reference time used for pacing and schedules
  // the first send on the io_context.
  void OnOpen(connection_hdl hdl) {
    auto start_time = std::chrono::steady_clock::now();
    asio::post(
        io_, [this, hdl, start_time]() { this->SendMessage(hdl, start_time); });
  }

  // Message handler: "Done!" closes the connection; any other payload is
  // logged as-is.
  void OnMessage(connection_hdl hdl, message_ptr msg) {
    const std::string &payload = msg->get_payload();
    if (payload == "Done!") {
      websocketpp::lib::error_code ec;
      c_.close(hdl, websocketpp::close::status::normal, "I'm exiting now", ec);
      if (ec) {
        SHERPA_ONNX_LOGE("Failed to close because %s", ec.message().c_str());
        exit(EXIT_FAILURE);
      }
    } else {
      SHERPA_ONNX_LOGE("%s", payload.c_str());
    }
  }

  // Sends the next full frame of audio. Once all full frames are out, it
  // sends the leftover partial frame (if any) followed by "Done".
  // Otherwise it re-posts itself to send the next frame.
  void SendMessage(
      connection_hdl hdl,
      std::chrono::time_point<std::chrono::steady_clock> start_time) {
    const int32_t num_samples = static_cast<int32_t>(samples_.size());
    const int32_t num_messages = num_samples / samples_per_message_;

    // Frame k must not go out before k * seconds_per_message_ seconds after
    // `start_time`. This sleep blocks the thread running io_, so incoming
    // results are only processed between sends.
    const std::chrono::duration<double> offset(
        static_cast<double>(seconds_per_message_) * num_sent_messages_);
    std::this_thread::sleep_until(
        start_time +
        std::chrono::duration_cast<std::chrono::steady_clock::duration>(
            offset));

    if (num_sent_messages_ == 0) {
      SHERPA_ONNX_LOGE("Starting to send audio");
    }

    if (num_sent_messages_ < num_messages) {
      SendSamples(hdl, num_sent_messages_ * samples_per_message_,
                  samples_per_message_);
      ++num_sent_messages_;
    }

    if (num_sent_messages_ == num_messages) {
      const int32_t remaining_samples = num_samples % samples_per_message_;
      if (remaining_samples > 0) {
        SendSamples(hdl, num_sent_messages_ * samples_per_message_,
                    remaining_samples);
      }
      SendDone(hdl);
    } else {
      asio::post(io_, [this, hdl, start_time]() {
        this->SendMessage(hdl, start_time);
      });
    }
  }

  // Sends `count` samples starting at index `offset` as one binary frame.
  // Terminates the process on failure.
  void SendSamples(connection_hdl hdl, int32_t offset, int32_t count) {
    websocketpp::lib::error_code ec;
    c_.send(hdl, samples_.data() + offset, count * sizeof(float),
            websocketpp::frame::opcode::binary, ec);
    if (ec) {
      SHERPA_ONNX_LOGE("Failed to send audio samples because %s",
                       ec.message().c_str());
      exit(EXIT_FAILURE);
    }
  }

  // Sends the text message "Done" to signal that all audio has been sent.
  // Terminates the process on failure. Logs only after a successful send.
  void SendDone(connection_hdl hdl) {
    websocketpp::lib::error_code ec;
    c_.send(hdl, "Done", websocketpp::frame::opcode::text, ec);
    if (ec) {
      SHERPA_ONNX_LOGE("Failed to send the Done signal because %s",
                       ec.message().c_str());
      exit(EXIT_FAILURE);
    }
    SHERPA_ONNX_LOGE("Sent Done Signal");
  }

 private:
  client c_;
  asio::io_context &io_;
  websocketpp::uri uri_;
  std::vector<float> samples_;
  int32_t samples_per_message_ = 8000;  // 0.5 seconds at 16 kHz
  float seconds_per_message_ = 0.2f;
  // Number of full frames sent so far; also indexes the next frame.
  int32_t num_sent_messages_ = 0;
};
// Parses options, validates them, reads the wave file, and streams it to the
// websocket server. Returns 0 on success and -1 on any validation or I/O
// error.
int32_t main(int32_t argc, char *argv[]) {
  std::string server_ip = "127.0.0.1";
  int32_t server_port = 6006;

  // Sample rate of the input wave. No resampling is made.
  int32_t sample_rate = 16000;

  int32_t samples_per_message = 8000;
  float seconds_per_message = 0.2;

  sherpa_onnx::ParseOptions po(kUsageMessage);

  po.Register("server-ip", &server_ip, "IP address of the websocket server");
  po.Register("server-port", &server_port, "Port of the websocket server");
  po.Register("sample-rate", &sample_rate,
              "Sample rate of the input wave. Should be the one expected by "
              "the server");

  po.Register("samples-per-message", &samples_per_message,
              "Send this number of samples per message.");

  po.Register("seconds-per-message", &seconds_per_message,
              "We will simulate that each message takes this number of seconds "
              "to send. If you select a very large value, it will take a long "
              "time to send all the samples");

  po.Read(argc, argv);

  if (!websocketpp::uri_helper::ipv4_literal(server_ip.begin(),
                                             server_ip.end())) {
    SHERPA_ONNX_LOGE("Invalid server IP: %s", server_ip.c_str());
    return -1;
  }

  if (server_port <= 0 || server_port > 65535) {
    SHERPA_ONNX_LOGE("Invalid server port: %d", server_port);
    return -1;
  }

  // The bounds on --samples-per-message below are relative to the sample
  // rate, so reject a non-positive rate up front with a clear message.
  if (sample_rate <= 0) {
    SHERPA_ONNX_LOGE("Invalid sample rate: %d", sample_rate);
    return -1;
  }

  // 0.01 is an arbitrary value. You can change it.
  if (samples_per_message <= 0.01 * sample_rate) {
    SHERPA_ONNX_LOGE("--samples-per-message is too small: %d",
                     samples_per_message);
    return -1;
  }

  // 100 is an arbitrary value. You can change it.
  // Computed in 64 bits so a huge --sample-rate cannot overflow the bound.
  if (samples_per_message >= static_cast<int64_t>(sample_rate) * 100) {
    SHERPA_ONNX_LOGE("--samples-per-message is too large: %d",
                     samples_per_message);
    return -1;
  }

  if (seconds_per_message < 0) {
    SHERPA_ONNX_LOGE("--seconds-per-message is too small: %.3f",
                     seconds_per_message);
    return -1;
  }

  // 1 is an arbitrary value.
  if (seconds_per_message > 1) {
    SHERPA_ONNX_LOGE(
        "--seconds-per-message is too large: %.3f. You will wait a long time "
        "to send all the samples",
        seconds_per_message);
    return -1;
  }

  if (po.NumArgs() != 1) {
    po.PrintUsage();
    return -1;
  }

  std::string wave_filename = po.GetArg(1);

  bool is_ok = false;
  int32_t actual_sample_rate = -1;
  std::vector<float> samples =
      sherpa_onnx::ReadWave(wave_filename, &actual_sample_rate, &is_ok);
  if (!is_ok) {
    SHERPA_ONNX_LOGE("Failed to read '%s'", wave_filename.c_str());
    return -1;
  }

  if (actual_sample_rate != sample_rate) {
    SHERPA_ONNX_LOGE("Expected sample rate: %d, given %d", sample_rate,
                     actual_sample_rate);
    return -1;
  }

  asio::io_context io_conn;  // for network connections
  // The port was range-checked above, so the narrowing cast is lossless.
  Client c(io_conn, server_ip, static_cast<uint16_t>(server_port), samples,
           samples_per_message, seconds_per_message);

  io_conn.run();  // will exit when the above connection is closed

  SHERPA_ONNX_LOGE("Done!");
  return 0;
}