[router] migrate router from actix to axum (#8479)

This commit is contained in:
Simo Lin
2025-07-30 17:47:19 -07:00
committed by GitHub
parent 299803343d
commit 66a398f49d
18 changed files with 3626 additions and 3549 deletions

View File

@@ -1,285 +1,169 @@
use crate::config::RouterConfig;
use crate::logging::{self, LoggingConfig};
use crate::metrics::{self, PrometheusConfig};
use crate::middleware::{get_request_id, RequestIdMiddleware};
use crate::openai_api_types::{ChatCompletionRequest, CompletionRequest, GenerateRequest};
use crate::routers::{RouterFactory, RouterTrait};
use crate::service_discovery::{start_service_discovery, ServiceDiscoveryConfig};
use actix_web::{
error, get, post, web, App, Error, HttpRequest, HttpResponse, HttpServer, Responder,
use axum::{
extract::{Query, Request, State},
http::StatusCode,
response::{IntoResponse, Response},
routing::{get, post},
Json, Router,
};
use futures_util::StreamExt;
use reqwest::Client;
use std::collections::HashMap;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::time::Duration;
use tokio::net::TcpListener;
use tokio::signal;
use tokio::spawn;
use tracing::{error, info, warn, Level};
#[derive(Debug)]
#[derive(Clone)]
pub struct AppState {
router: Arc<dyn RouterTrait>,
client: Client,
pub router: Arc<dyn RouterTrait>,
pub client: Client,
pub _concurrency_limiter: Arc<tokio::sync::Semaphore>,
}
impl AppState {
pub fn new(router_config: RouterConfig, client: Client) -> Result<Self, String> {
// Use RouterFactory to create the appropriate router type
pub fn new(
router_config: RouterConfig,
client: Client,
max_concurrent_requests: usize,
) -> Result<Self, String> {
let router = RouterFactory::create_router(&router_config)?;
// Convert Box<dyn RouterTrait> to Arc<dyn RouterTrait>
let router = Arc::from(router);
Ok(Self { router, client })
let concurrency_limiter = Arc::new(tokio::sync::Semaphore::new(max_concurrent_requests));
Ok(Self {
router,
client,
_concurrency_limiter: concurrency_limiter,
})
}
}
async fn sink_handler(_req: HttpRequest, mut payload: web::Payload) -> Result<HttpResponse, Error> {
// Drain the payload
while let Some(chunk) = payload.next().await {
if let Err(err) = chunk {
println!("Error while draining payload: {:?}", err);
break;
}
}
Ok(HttpResponse::NotFound().finish())
// Fallback handler for unmatched routes
async fn sink_handler() -> Response {
StatusCode::NOT_FOUND.into_response()
}
// Custom error handler for JSON payload errors.
fn json_error_handler(err: error::JsonPayloadError, req: &HttpRequest) -> Error {
let request_id = get_request_id(req);
match &err {
error::JsonPayloadError::OverflowKnownLength { length, limit } => {
error!(
request_id = %request_id,
"Payload too large length={} limit={}", length, limit
);
error::ErrorPayloadTooLarge(format!(
"Payload too large: {} bytes exceeds limit of {} bytes",
length, limit
))
}
error::JsonPayloadError::Overflow { limit } => {
error!(
request_id = %request_id,
"Payload overflow limit={}", limit
);
error::ErrorPayloadTooLarge(format!("Payload exceeds limit of {} bytes", limit))
}
_ => {
error!(
request_id = %request_id,
"Invalid JSON payload error={}", err
);
error::ErrorBadRequest(format!("Invalid JSON payload: {}", err))
}
}
// Health check endpoints
async fn liveness(State(state): State<Arc<AppState>>) -> Response {
state.router.liveness()
}
#[get("/liveness")]
async fn liveness(_req: HttpRequest, data: web::Data<AppState>) -> impl Responder {
data.router.liveness()
async fn readiness(State(state): State<Arc<AppState>>) -> Response {
state.router.readiness()
}
#[get("/readiness")]
async fn readiness(_req: HttpRequest, data: web::Data<AppState>) -> impl Responder {
data.router.readiness()
async fn health(State(state): State<Arc<AppState>>, req: Request) -> Response {
state.router.health(&state.client, req).await
}
#[get("/health")]
async fn health(req: HttpRequest, data: web::Data<AppState>) -> impl Responder {
data.router.health(&data.client, &req).await
async fn health_generate(State(state): State<Arc<AppState>>, req: Request) -> Response {
state.router.health_generate(&state.client, req).await
}
#[get("/health_generate")]
async fn health_generate(req: HttpRequest, data: web::Data<AppState>) -> impl Responder {
data.router.health_generate(&data.client, &req).await
async fn get_server_info(State(state): State<Arc<AppState>>, req: Request) -> Response {
state.router.get_server_info(&state.client, req).await
}
#[get("/get_server_info")]
async fn get_server_info(req: HttpRequest, data: web::Data<AppState>) -> impl Responder {
data.router.get_server_info(&data.client, &req).await
async fn v1_models(State(state): State<Arc<AppState>>, req: Request) -> Response {
state.router.get_models(&state.client, req).await
}
#[get("/v1/models")]
async fn v1_models(req: HttpRequest, data: web::Data<AppState>) -> impl Responder {
data.router.get_models(&data.client, &req).await
async fn get_model_info(State(state): State<Arc<AppState>>, req: Request) -> Response {
state.router.get_model_info(&state.client, req).await
}
#[get("/get_model_info")]
async fn get_model_info(req: HttpRequest, data: web::Data<AppState>) -> impl Responder {
data.router.get_model_info(&data.client, &req).await
}
#[post("/generate")]
// Generation endpoints
// The RouterTrait now accepts optional headers and typed body directly
async fn generate(
req: HttpRequest,
body: web::Json<GenerateRequest>,
state: web::Data<AppState>,
) -> Result<HttpResponse, Error> {
let request_id = get_request_id(&req);
info!(
request_id = %request_id,
"Received generate request method=\"POST\" path=\"/generate\""
);
let json_body = serde_json::to_value(body.into_inner()).map_err(|e| {
error!(
request_id = %request_id,
"Failed to parse generate request body error={}", e
);
error::ErrorBadRequest(format!("Invalid JSON: {}", e))
})?;
Ok(state
State(state): State<Arc<AppState>>,
headers: http::HeaderMap,
Json(body): Json<GenerateRequest>,
) -> Response {
state
.router
.route_generate(&state.client, &req, json_body)
.await)
.route_generate(&state.client, Some(&headers), &body)
.await
}
#[post("/v1/chat/completions")]
async fn v1_chat_completions(
req: HttpRequest,
body: web::Json<ChatCompletionRequest>,
state: web::Data<AppState>,
) -> Result<HttpResponse, Error> {
let request_id = get_request_id(&req);
info!(
request_id = %request_id,
"Received chat completion request method=\"POST\" path=\"/v1/chat/completions\""
);
let json_body = serde_json::to_value(body.into_inner()).map_err(|e| {
error!(
request_id = %request_id,
"Failed to parse chat completion request body error={}", e
);
error::ErrorBadRequest(format!("Invalid JSON: {}", e))
})?;
Ok(state
State(state): State<Arc<AppState>>,
headers: http::HeaderMap,
Json(body): Json<ChatCompletionRequest>,
) -> Response {
state
.router
.route_chat(&state.client, &req, json_body)
.await)
.route_chat(&state.client, Some(&headers), &body)
.await
}
#[post("/v1/completions")]
async fn v1_completions(
req: HttpRequest,
body: web::Json<CompletionRequest>,
state: web::Data<AppState>,
) -> Result<HttpResponse, Error> {
let request_id = get_request_id(&req);
info!(
request_id = %request_id,
"Received completion request method=\"POST\" path=\"/v1/completions\""
);
let json_body = serde_json::to_value(body.into_inner()).map_err(|e| {
error!(
request_id = %request_id,
"Failed to parse completion request body error={}", e
);
error::ErrorBadRequest(format!("Invalid JSON: {}", e))
})?;
Ok(state
State(state): State<Arc<AppState>>,
headers: http::HeaderMap,
Json(body): Json<CompletionRequest>,
) -> Response {
state
.router
.route_completion(&state.client, &req, json_body)
.await)
.route_completion(&state.client, Some(&headers), &body)
.await
}
#[post("/add_worker")]
// Worker management endpoints
async fn add_worker(
req: HttpRequest,
query: web::Query<HashMap<String, String>>,
data: web::Data<AppState>,
) -> impl Responder {
let request_id = get_request_id(&req);
let worker_url = match query.get("url") {
State(state): State<Arc<AppState>>,
Query(params): Query<HashMap<String, String>>,
) -> Response {
let worker_url = match params.get("url") {
Some(url) => url.to_string(),
None => {
warn!(
request_id = %request_id,
"Add worker request missing URL parameter"
);
return HttpResponse::BadRequest()
.body("Worker URL required. Provide 'url' query parameter");
return (
StatusCode::BAD_REQUEST,
"Worker URL required. Provide 'url' query parameter",
)
.into_response();
}
};
info!(
request_id = %request_id,
worker_url = %worker_url,
"Adding worker"
);
match data.router.add_worker(&worker_url).await {
Ok(message) => {
info!(
request_id = %request_id,
worker_url = %worker_url,
"Successfully added worker"
);
HttpResponse::Ok().body(message)
}
Err(error) => {
error!(
request_id = %request_id,
worker_url = %worker_url,
error = %error,
"Failed to add worker"
);
HttpResponse::BadRequest().body(error)
}
match state.router.add_worker(&worker_url).await {
Ok(message) => (StatusCode::OK, message).into_response(),
Err(error) => (StatusCode::BAD_REQUEST, error).into_response(),
}
}
#[get("/list_workers")]
async fn list_workers(data: web::Data<AppState>) -> impl Responder {
let worker_list = data.router.get_worker_urls();
HttpResponse::Ok().json(serde_json::json!({ "urls": worker_list }))
async fn list_workers(State(state): State<Arc<AppState>>) -> Response {
let worker_list = state.router.get_worker_urls();
Json(serde_json::json!({ "urls": worker_list })).into_response()
}
#[post("/remove_worker")]
async fn remove_worker(
req: HttpRequest,
query: web::Query<HashMap<String, String>>,
data: web::Data<AppState>,
) -> impl Responder {
let request_id = get_request_id(&req);
let worker_url = match query.get("url") {
State(state): State<Arc<AppState>>,
Query(params): Query<HashMap<String, String>>,
) -> Response {
let worker_url = match params.get("url") {
Some(url) => url.to_string(),
None => {
warn!(
request_id = %request_id,
"Remove worker request missing URL parameter"
);
return HttpResponse::BadRequest().finish();
}
None => return StatusCode::BAD_REQUEST.into_response(),
};
info!(
request_id = %request_id,
worker_url = %worker_url,
"Removing worker"
);
data.router.remove_worker(&worker_url);
HttpResponse::Ok().body(format!("Successfully removed worker: {}", worker_url))
state.router.remove_worker(&worker_url);
(
StatusCode::OK,
format!("Successfully removed worker: {}", worker_url),
)
.into_response()
}
#[post("/flush_cache")]
async fn flush_cache(_req: HttpRequest, data: web::Data<AppState>) -> impl Responder {
data.router.flush_cache(&data.client).await
async fn flush_cache(State(state): State<Arc<AppState>>, _req: Request) -> Response {
state.router.flush_cache(&state.client).await
}
#[get("/get_loads")]
async fn get_loads(_req: HttpRequest, data: web::Data<AppState>) -> impl Responder {
data.router.get_worker_loads(&data.client).await
async fn get_loads(State(state): State<Arc<AppState>>, _req: Request) -> Response {
state.router.get_worker_loads(&state.client).await
}
pub struct ServerConfig {
@@ -295,7 +179,58 @@ pub struct ServerConfig {
pub request_id_headers: Option<Vec<String>>,
}
pub async fn startup(config: ServerConfig) -> std::io::Result<()> {
/// Build the Axum application with all routes and middleware
pub fn build_app(
app_state: Arc<AppState>,
max_payload_size: usize,
request_id_headers: Vec<String>,
cors_allowed_origins: Vec<String>,
) -> Router {
// Create routes
let protected_routes = Router::new()
.route("/generate", post(generate))
.route("/v1/chat/completions", post(v1_chat_completions))
.route("/v1/completions", post(v1_completions));
let public_routes = Router::new()
.route("/liveness", get(liveness))
.route("/readiness", get(readiness))
.route("/health", get(health))
.route("/health_generate", get(health_generate))
.route("/v1/models", get(v1_models))
.route("/get_model_info", get(get_model_info))
.route("/get_server_info", get(get_server_info));
let admin_routes = Router::new()
.route("/add_worker", post(add_worker))
.route("/remove_worker", post(remove_worker))
.route("/list_workers", get(list_workers))
.route("/flush_cache", post(flush_cache))
.route("/get_loads", get(get_loads));
// Build app with all routes and middleware
Router::new()
.merge(protected_routes)
.merge(public_routes)
.merge(admin_routes)
// Request body size limiting
.layer(tower_http::limit::RequestBodyLimitLayer::new(
max_payload_size,
))
// Request ID layer - must be added AFTER logging layer in the code
// so it executes BEFORE logging layer at runtime (layers execute bottom-up)
.layer(crate::middleware::RequestIdLayer::new(request_id_headers))
// Custom logging layer that can now see request IDs from extensions
.layer(crate::middleware::create_logging_layer())
// CORS (should be outermost)
.layer(create_cors_layer(cors_allowed_origins))
// Fallback
.fallback(sink_handler)
// State - apply last to get Router<Arc<AppState>>
.with_state(app_state)
}
pub async fn startup(config: ServerConfig) -> Result<(), Box<dyn std::error::Error>> {
// Only initialize logging if not already done (for Python bindings support)
static LOGGING_INITIALIZED: AtomicBool = AtomicBool::new(false);
@@ -338,14 +273,20 @@ pub async fn startup(config: ServerConfig) -> std::io::Result<()> {
let client = Client::builder()
.pool_idle_timeout(Some(Duration::from_secs(50)))
.timeout(Duration::from_secs(config.request_timeout_secs)) // Use configurable timeout
.pool_max_idle_per_host(100) // Increase from default of 1 to allow more concurrent connections
.timeout(Duration::from_secs(config.request_timeout_secs))
.connect_timeout(Duration::from_secs(10)) // Separate connection timeout
.tcp_nodelay(true)
.tcp_keepalive(Some(Duration::from_secs(30))) // Keep connections alive
.build()
.expect("Failed to create HTTP client");
let app_state_init = AppState::new(config.router_config.clone(), client.clone())
.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?;
let router_arc = Arc::clone(&app_state_init.router);
let app_state = web::Data::new(app_state_init);
let app_state = Arc::new(AppState::new(
config.router_config.clone(),
client.clone(),
config.router_config.max_concurrent_requests,
)?);
let router_arc = Arc::clone(&app_state.router);
// Start the service discovery if enabled
if let Some(service_discovery_config) = config.service_discovery_config {
@@ -383,36 +324,83 @@ pub async fn startup(config: ServerConfig) -> std::io::Result<()> {
]
});
HttpServer::new(move || {
let request_id_middleware = RequestIdMiddleware::new(request_id_headers.clone());
// Build the application
let app = build_app(
app_state,
config.max_payload_size,
request_id_headers,
config.router_config.cors_allowed_origins.clone(),
);
App::new()
.wrap(request_id_middleware)
.app_data(app_state.clone())
.app_data(
web::JsonConfig::default()
.limit(config.max_payload_size)
.error_handler(json_error_handler),
)
.app_data(web::PayloadConfig::default().limit(config.max_payload_size))
.service(generate)
.service(v1_chat_completions)
.service(v1_completions)
.service(v1_models)
.service(get_model_info)
.service(liveness)
.service(readiness)
.service(health)
.service(health_generate)
.service(get_server_info)
.service(add_worker)
.service(remove_worker)
.service(list_workers)
.service(flush_cache)
.service(get_loads)
.default_service(web::route().to(sink_handler))
})
.bind_auto_h2c((config.host, config.port))?
.run()
.await
// Create TCP listener - use the configured host
let addr = format!("{}:{}", config.host, config.port);
let listener = TcpListener::bind(&addr).await?;
// Start server with graceful shutdown
info!("Starting server on {}", addr);
// Serve the application with graceful shutdown
axum::serve(listener, app)
.with_graceful_shutdown(shutdown_signal())
.await
.map_err(|e| Box::new(e) as Box<dyn std::error::Error>)?;
Ok(())
}
// Graceful shutdown handler
async fn shutdown_signal() {
let ctrl_c = async {
signal::ctrl_c()
.await
.expect("failed to install Ctrl+C handler");
};
#[cfg(unix)]
let terminate = async {
signal::unix::signal(signal::unix::SignalKind::terminate())
.expect("failed to install signal handler")
.recv()
.await;
};
#[cfg(not(unix))]
let terminate = std::future::pending::<()>();
tokio::select! {
_ = ctrl_c => {
info!("Received Ctrl+C, starting graceful shutdown");
},
_ = terminate => {
info!("Received terminate signal, starting graceful shutdown");
},
}
}
// CORS Layer Creation
fn create_cors_layer(allowed_origins: Vec<String>) -> tower_http::cors::CorsLayer {
use tower_http::cors::Any;
let cors = if allowed_origins.is_empty() {
// Allow all origins if none specified
tower_http::cors::CorsLayer::new()
.allow_origin(Any)
.allow_methods(Any)
.allow_headers(Any)
.expose_headers(Any)
} else {
// Restrict to specific origins
let origins: Vec<http::HeaderValue> = allowed_origins
.into_iter()
.filter_map(|origin| origin.parse().ok())
.collect();
tower_http::cors::CorsLayer::new()
.allow_origin(origins)
.allow_methods([http::Method::GET, http::Method::POST, http::Method::OPTIONS])
.allow_headers([http::header::CONTENT_TYPE, http::header::AUTHORIZATION])
.expose_headers([http::header::HeaderName::from_static("x-request-id")])
};
cors.max_age(Duration::from_secs(3600))
}