diff --git a/sgl-router/src/server.rs b/sgl-router/src/server.rs index 0706c57c0..a381cdcc7 100644 --- a/sgl-router/src/server.rs +++ b/sgl-router/src/server.rs @@ -1,11 +1,15 @@ use crate::router::PolicyConfig; use crate::router::Router; -use actix_web::{get, post, web, App, HttpRequest, HttpResponse, HttpServer, Responder}; +use actix_web::{ + error, get, post, web, App, Error, HttpRequest, HttpResponse, HttpServer, Responder, +}; use bytes::Bytes; use env_logger::Builder; +use futures_util::StreamExt; use log::{info, LevelFilter}; use std::collections::HashMap; use std::io::Write; +use std::time::Duration; #[derive(Debug)] pub struct AppState { @@ -25,6 +29,22 @@ impl AppState { } } +async fn sink_handler(_req: HttpRequest, mut payload: web::Payload) -> Result { + // Drain the payload + while let Some(chunk) = payload.next().await { + if let Err(err) = chunk { + println!("Error while draining payload: {:?}", err); + break; + } + } + Ok(HttpResponse::NotFound().finish()) +} + +// Custom error handler for JSON payload errors. +fn json_error_handler(_err: error::JsonPayloadError, _req: &HttpRequest) -> Error { + error::ErrorPayloadTooLarge("Payload too large") +} + #[get("/health")] async fn health(req: HttpRequest, data: web::Data) -> impl Responder { data.router @@ -162,6 +182,7 @@ pub async fn startup(config: ServerConfig) -> std::io::Result<()> { ); let client = reqwest::Client::builder() + .pool_idle_timeout(Some(Duration::from_secs(50))) .build() .expect("Failed to create HTTP client"); @@ -180,7 +201,11 @@ pub async fn startup(config: ServerConfig) -> std::io::Result<()> { HttpServer::new(move || { App::new() .app_data(app_state.clone()) - .app_data(web::JsonConfig::default().limit(config.max_payload_size)) + .app_data( + web::JsonConfig::default() + .limit(config.max_payload_size) + .error_handler(json_error_handler), + ) .app_data(web::PayloadConfig::default().limit(config.max_payload_size)) .service(generate) .service(v1_chat_completions) @@ -192,6 +217,8 @@ pub async fn startup(config: ServerConfig) -> std::io::Result<()> { .service(get_server_info) .service(add_worker) .service(remove_worker) + // Default handler for unmatched routes. + .default_service(web::route().to(sink_handler)) }) .bind((config.host, config.port))? .run()