// sherpa-onnx/csrc/online-websocket-server-impl.cc // // Copyright (c) 2022-2023 Xiaomi Corporation #include "sherpa-onnx/csrc/online-websocket-server-impl.h" #include #include "sherpa-onnx/csrc/file-utils.h" #include "sherpa-onnx/csrc/log.h" namespace sherpa_onnx { void OnlineWebsocketDecoderConfig::Register(ParseOptions *po) { recognizer_config.Register(po); po->Register("loop-interval-ms", &loop_interval_ms, "It determines how often the decoder loop runs. "); po->Register("max-batch-size", &max_batch_size, "Max batch size for recognition."); } void OnlineWebsocketDecoderConfig::Validate() const { recognizer_config.Validate(); SHERPA_ONNX_CHECK_GT(loop_interval_ms, 0); SHERPA_ONNX_CHECK_GT(max_batch_size, 0); } void OnlineWebsocketServerConfig::Register(sherpa_onnx::ParseOptions *po) { decoder_config.Register(po); po->Register("log-file", &log_file, "Path to the log file. Logs are " "appended to this file"); } void OnlineWebsocketServerConfig::Validate() const { decoder_config.Validate(); } OnlineWebsocketDecoder::OnlineWebsocketDecoder(OnlineWebsocketServer *server) : server_(server), config_(server->GetConfig().decoder_config), timer_(server->GetWorkContext()) { recognizer_ = std::make_unique(config_.recognizer_config); } std::shared_ptr OnlineWebsocketDecoder::GetOrCreateConnection( connection_hdl hdl) { std::lock_guard lock(mutex_); auto it = connections_.find(hdl); if (it != connections_.end()) { return it->second; } else { // create a new connection std::shared_ptr s = recognizer_->CreateStream(); auto c = std::make_shared(hdl, s); connections_.insert({hdl, c}); return c; } } void OnlineWebsocketDecoder::AcceptWaveform(std::shared_ptr c) { std::lock_guard lock(c->mutex); float sample_rate = config_.recognizer_config.feat_config.sampling_rate; while (!c->samples.empty()) { const auto &s = c->samples.front(); c->s->AcceptWaveform(sample_rate, s.data(), s.size()); c->samples.pop_front(); } } void OnlineWebsocketDecoder::InputFinished(std::shared_ptr c) { std::lock_guard lock(c->mutex); float sample_rate = config_.recognizer_config.feat_config.sampling_rate; while (!c->samples.empty()) { const auto &s = c->samples.front(); c->s->AcceptWaveform(sample_rate, s.data(), s.size()); c->samples.pop_front(); } // TODO(fangjun): Change the amount of paddings to be configurable std::vector tail_padding(static_cast(0.8 * sample_rate)); c->s->AcceptWaveform(sample_rate, tail_padding.data(), tail_padding.size()); c->s->InputFinished(); c->eof = true; } void OnlineWebsocketDecoder::Run() { timer_.expires_after(std::chrono::milliseconds(config_.loop_interval_ms)); timer_.async_wait( [this](const asio::error_code &ec) { ProcessConnections(ec); }); } void OnlineWebsocketDecoder::ProcessConnections(const asio::error_code &ec) { if (ec) { SHERPA_ONNX_LOG(FATAL) << "The decoder loop is aborted!"; } std::lock_guard lock(mutex_); std::vector to_remove; for (auto &p : connections_) { auto hdl = p.first; auto c = p.second; // The order of `if` below matters! if (!server_->Contains(hdl)) { // If the connection is disconnected, we stop processing it to_remove.push_back(hdl); continue; } if (active_.count(hdl)) { // Another thread is decoding this stream, so skip it continue; } if (!recognizer_->IsReady(c->s.get()) && !c->eof) { // this stream has not enough frames to decode, so skip it continue; } if (!recognizer_->IsReady(c->s.get()) && c->eof) { // We won't receive samples from the client, so send a Done! to client asio::post(server_->GetWorkContext(), [this, hdl = c->hdl]() { server_->Send(hdl, "Done!"); }); to_remove.push_back(hdl); continue; } // TODO(fangun): If the connection is timed out, we need to also // add it to `to_remove` // this stream has enough frames and is currently not processed by any // threads, so put it into the ready queue ready_connections_.push_back(c); // In `Decode()`, it will remove hdl from `active_` active_.insert(c->hdl); } for (auto hdl : to_remove) { connections_.erase(hdl); } if (!ready_connections_.empty()) { asio::post(server_->GetWorkContext(), [this]() { Decode(); }); } // Schedule another call timer_.expires_after(std::chrono::milliseconds(config_.loop_interval_ms)); timer_.async_wait( [this](const asio::error_code &ec) { ProcessConnections(ec); }); } void OnlineWebsocketDecoder::Decode() { std::unique_lock lock(mutex_); if (ready_connections_.empty()) { // There are no connections that are ready for decoding, // so we return directly return; } std::vector> c_vec; std::vector s_vec; while (!ready_connections_.empty() && static_cast(s_vec.size()) < config_.max_batch_size) { auto c = ready_connections_.front(); ready_connections_.pop_front(); c_vec.push_back(c); s_vec.push_back(c->s.get()); } if (!ready_connections_.empty()) { // there are too many ready connections but this thread can only handle // max_batch_size connections at a time, so we schedule another call // to Decode() and let other threads to process the ready connections asio::post(server_->GetWorkContext(), [this]() { Decode(); }); } lock.unlock(); recognizer_->DecodeStreams(s_vec.data(), s_vec.size()); lock.lock(); for (auto c : c_vec) { auto result = recognizer_->GetResult(c->s.get()); if (recognizer_->IsEndpoint(c->s.get())) { result.is_final = true; recognizer_->Reset(c->s.get()); } if (!recognizer_->IsReady(c->s.get()) && c->eof) { result.is_final = true; } asio::post(server_->GetConnectionContext(), [this, hdl = c->hdl, str = result.AsJsonString()]() { server_->Send(hdl, str); }); active_.erase(c->hdl); } } OnlineWebsocketServer::OnlineWebsocketServer( asio::io_context &io_conn, asio::io_context &io_work, const OnlineWebsocketServerConfig &config) : config_(config), io_conn_(io_conn), io_work_(io_work), log_(config.log_file, std::ios::app), tee_(std::cout, log_), decoder_(this) { SetupLog(); server_.init_asio(&io_conn_); server_.set_open_handler([this](connection_hdl hdl) { OnOpen(hdl); }); server_.set_close_handler([this](connection_hdl hdl) { OnClose(hdl); }); server_.set_message_handler( [this](connection_hdl hdl, server::message_ptr msg) { OnMessage(hdl, msg); }); } void OnlineWebsocketServer::Run(uint16_t port) { server_.set_reuse_addr(true); server_.listen(asio::ip::tcp::v4(), port); server_.start_accept(); decoder_.Run(); } void OnlineWebsocketServer::SetupLog() { server_.clear_access_channels(websocketpp::log::alevel::all); // server_.set_access_channels(websocketpp::log::alevel::connect); // server_.set_access_channels(websocketpp::log::alevel::disconnect); // So that it also prints to std::cout and std::cerr server_.get_alog().set_ostream(&tee_); server_.get_elog().set_ostream(&tee_); } void OnlineWebsocketServer::Send(connection_hdl hdl, const std::string &text) { websocketpp::lib::error_code ec; if (!Contains(hdl)) { return; } server_.send(hdl, text, websocketpp::frame::opcode::text, ec); if (ec) { server_.get_alog().write(websocketpp::log::alevel::app, ec.message()); } } void OnlineWebsocketServer::OnOpen(connection_hdl hdl) { std::lock_guard lock(mutex_); connections_.insert(hdl); std::ostringstream os; os << "New connection: " << server_.get_con_from_hdl(hdl)->get_remote_endpoint() << ". " << "Number of active connections: " << connections_.size() << ".\n"; SHERPA_ONNX_LOG(INFO) << os.str(); } void OnlineWebsocketServer::OnClose(connection_hdl hdl) { std::lock_guard lock(mutex_); connections_.erase(hdl); SHERPA_ONNX_LOG(INFO) << "Number of active connections: " << connections_.size() << "\n"; } bool OnlineWebsocketServer::Contains(connection_hdl hdl) const { std::lock_guard lock(mutex_); return connections_.count(hdl); } void OnlineWebsocketServer::OnMessage(connection_hdl hdl, server::message_ptr msg) { auto c = decoder_.GetOrCreateConnection(hdl); const std::string &payload = msg->get_payload(); switch (msg->get_opcode()) { case websocketpp::frame::opcode::text: if (payload == "Done") { asio::post(io_work_, [this, c]() { decoder_.InputFinished(c); }); } break; case websocketpp::frame::opcode::binary: { auto p = reinterpret_cast(payload.data()); int32_t num_samples = payload.size() / sizeof(float); std::vector samples(p, p + num_samples); c->samples.push_back(std::move(samples)); asio::post(io_work_, [this, c]() { decoder_.AcceptWaveform(c); }); break; } default: break; } } void OnlineWebsocketServer::Close(connection_hdl hdl, websocketpp::close::status::value code, const std::string &reason) { auto con = server_.get_con_from_hdl(hdl); std::ostringstream os; os << "Closing " << con->get_remote_endpoint() << " with reason: " << reason << "\n"; websocketpp::lib::error_code ec; server_.close(hdl, code, reason, ec); if (ec) { os << "Failed to close" << con->get_remote_endpoint() << ". " << ec.message() << "\n"; } server_.get_alog().write(websocketpp::log::alevel::app, os.str()); } } // namespace sherpa_onnx