[router] migrate router from actix to axum (#8479)

This commit is contained in:
Simo Lin
2025-07-30 17:47:19 -07:00
committed by GitHub
parent 299803343d
commit 66a398f49d
18 changed files with 3626 additions and 3549 deletions

View File

@@ -35,6 +35,10 @@ pub struct RouterConfig {
pub log_level: Option<String>,
/// Custom request ID headers to check (defaults to common headers)
pub request_id_headers: Option<Vec<String>>,
/// Maximum concurrent requests allowed (for rate limiting)
pub max_concurrent_requests: usize,
/// CORS allowed origins
pub cors_allowed_origins: Vec<String>,
}
/// Routing mode configuration
@@ -216,6 +220,8 @@ impl Default for RouterConfig {
log_dir: None,
log_level: None,
request_id_headers: None,
max_concurrent_requests: 64,
cors_allowed_origins: vec![],
}
}
}
@@ -324,6 +330,8 @@ mod tests {
log_dir: Some("/var/log".to_string()),
log_level: Some("debug".to_string()),
request_id_headers: None,
max_concurrent_requests: 64,
cors_allowed_origins: vec![],
};
let json = serde_json::to_string(&config).unwrap();
@@ -749,6 +757,8 @@ mod tests {
log_dir: Some("/var/log/sglang".to_string()),
log_level: Some("info".to_string()),
request_id_headers: None,
max_concurrent_requests: 64,
cors_allowed_origins: vec![],
};
assert!(config.mode.is_pd_mode());
@@ -798,6 +808,8 @@ mod tests {
log_dir: None,
log_level: Some("debug".to_string()),
request_id_headers: None,
max_concurrent_requests: 64,
cors_allowed_origins: vec![],
};
assert!(!config.mode.is_pd_mode());
@@ -843,6 +855,8 @@ mod tests {
log_dir: Some("/opt/logs/sglang".to_string()),
log_level: Some("trace".to_string()),
request_id_headers: None,
max_concurrent_requests: 64,
cors_allowed_origins: vec![],
};
assert!(config.has_service_discovery());

View File

@@ -60,6 +60,9 @@ struct Router {
decode_urls: Option<Vec<String>>,
prefill_policy: Option<PolicyType>,
decode_policy: Option<PolicyType>,
// Additional server config fields
max_concurrent_requests: usize,
cors_allowed_origins: Vec<String>,
}
impl Router {
@@ -145,6 +148,8 @@ impl Router {
log_dir: self.log_dir.clone(),
log_level: self.log_level.clone(),
request_id_headers: self.request_id_headers.clone(),
max_concurrent_requests: self.max_concurrent_requests,
cors_allowed_origins: self.cors_allowed_origins.clone(),
})
}
}
@@ -184,7 +189,9 @@ impl Router {
prefill_urls = None,
decode_urls = None,
prefill_policy = None,
decode_policy = None
decode_policy = None,
max_concurrent_requests = 64,
cors_allowed_origins = vec![]
))]
fn new(
worker_urls: Vec<String>,
@@ -219,6 +226,8 @@ impl Router {
decode_urls: Option<Vec<String>>,
prefill_policy: Option<PolicyType>,
decode_policy: Option<PolicyType>,
max_concurrent_requests: usize,
cors_allowed_origins: Vec<String>,
) -> PyResult<Self> {
Ok(Router {
host,
@@ -253,6 +262,8 @@ impl Router {
decode_urls,
prefill_policy,
decode_policy,
max_concurrent_requests,
cors_allowed_origins,
})
}

View File

@@ -1,9 +1,9 @@
use actix_web::{
dev::{forward_ready, Service, ServiceRequest, ServiceResponse, Transform},
Error, HttpMessage, HttpRequest,
};
use futures_util::future::LocalBoxFuture;
use std::future::{ready, Ready};
use axum::{extract::Request, http::HeaderValue, response::Response};
use std::sync::Arc;
use std::time::Instant;
use tower::{Layer, Service};
use tower_http::trace::{MakeSpan, OnRequest, OnResponse, TraceLayer};
use tracing::{field::Empty, info_span, Span};
/// Generate OpenAI-compatible request ID based on endpoint
fn generate_request_id(path: &str) -> String {
@@ -31,67 +31,67 @@ fn generate_request_id(path: &str) -> String {
format!("{}{}", prefix, random_part)
}
/// Extract request ID from request extensions or generate a new one
pub fn get_request_id(req: &HttpRequest) -> String {
req.extensions()
.get::<String>()
.cloned()
.unwrap_or_else(|| generate_request_id(req.path()))
/// Extension type for storing request ID
#[derive(Clone, Debug)]
pub struct RequestId(pub String);
/// Tower Layer for request ID middleware
#[derive(Clone)]
pub struct RequestIdLayer {
headers: Arc<Vec<String>>,
}
/// Middleware for injecting request ID into request extensions
pub struct RequestIdMiddleware {
headers: Vec<String>,
}
impl RequestIdMiddleware {
impl RequestIdLayer {
pub fn new(headers: Vec<String>) -> Self {
Self { headers }
Self {
headers: Arc::new(headers),
}
}
}
impl<S, B> Transform<S, ServiceRequest> for RequestIdMiddleware
where
S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
S::Future: 'static,
B: 'static,
{
type Response = ServiceResponse<B>;
type Error = Error;
type InitError = ();
type Transform = RequestIdMiddlewareService<S>;
type Future = Ready<Result<Self::Transform, Self::InitError>>;
impl<S> Layer<S> for RequestIdLayer {
type Service = RequestIdMiddleware<S>;
fn new_transform(&self, service: S) -> Self::Future {
ready(Ok(RequestIdMiddlewareService {
service,
fn layer(&self, inner: S) -> Self::Service {
RequestIdMiddleware {
inner,
headers: self.headers.clone(),
}))
}
}
}
pub struct RequestIdMiddlewareService<S> {
service: S,
headers: Vec<String>,
/// Tower Service for request ID middleware
#[derive(Clone)]
pub struct RequestIdMiddleware<S> {
inner: S,
headers: Arc<Vec<String>>,
}
impl<S, B> Service<ServiceRequest> for RequestIdMiddlewareService<S>
impl<S> Service<Request> for RequestIdMiddleware<S>
where
S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
S::Future: 'static,
B: 'static,
S: Service<Request, Response = Response> + Send + 'static,
S::Future: Send + 'static,
{
type Response = ServiceResponse<B>;
type Error = Error;
type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;
type Response = S::Response;
type Error = S::Error;
type Future = std::pin::Pin<
Box<dyn std::future::Future<Output = Result<Self::Response, Self::Error>> + Send>,
>;
forward_ready!(service);
fn poll_ready(
&mut self,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx)
}
fn call(&mut self, mut req: Request) -> Self::Future {
let headers = self.headers.clone();
fn call(&self, req: ServiceRequest) -> Self::Future {
// Extract request ID from headers or generate new one
let mut request_id = None;
for header_name in &self.headers {
for header_name in headers.iter() {
if let Some(header_value) = req.headers().get(header_name) {
if let Ok(value) = header_value.to_str() {
request_id = Some(value.to_string());
@@ -100,12 +100,216 @@ where
}
}
let request_id = request_id.unwrap_or_else(|| generate_request_id(req.path()));
let request_id = request_id.unwrap_or_else(|| generate_request_id(req.uri().path()));
// Insert request ID into request extensions
req.extensions_mut().insert(request_id);
req.extensions_mut().insert(RequestId(request_id.clone()));
let fut = self.service.call(req);
Box::pin(async move { fut.await })
// Create a span with the request ID for this request
let span = tracing::info_span!(
"http_request",
method = %req.method(),
uri = %req.uri(),
version = ?req.version(),
request_id = %request_id
);
// Log within the span
let _enter = span.enter();
tracing::info!(
target: "sglang_router_rs::request",
"started processing request"
);
drop(_enter);
// Capture values we need in the async block
let method = req.method().clone();
let uri = req.uri().clone();
let version = req.version();
// Call the inner service
let future = self.inner.call(req);
Box::pin(async move {
let start_time = Instant::now();
let mut response = future.await?;
let latency = start_time.elapsed();
// Add request ID to response headers
response.headers_mut().insert(
"x-request-id",
HeaderValue::from_str(&request_id)
.unwrap_or_else(|_| HeaderValue::from_static("invalid-request-id")),
);
// Log the response with proper request ID in span
let status = response.status();
let span = tracing::info_span!(
"http_request",
method = %method,
uri = %uri,
version = ?version,
request_id = %request_id,
status = %status,
latency = ?latency
);
let _enter = span.enter();
if status.is_server_error() {
tracing::error!(
target: "sglang_router_rs::response",
"request failed with server error"
);
} else if status.is_client_error() {
tracing::warn!(
target: "sglang_router_rs::response",
"request failed with client error"
);
} else {
tracing::info!(
target: "sglang_router_rs::response",
"finished processing request"
);
}
Ok(response)
})
}
}
// ============= Logging Middleware =============
/// Span factory for the HTTP trace layer.
///
/// Opens one `http_request` span per incoming request. The request's method,
/// URI and HTTP version are filled in immediately; `request_id`,
/// `status_code`, `latency` and `error` are declared empty so they can be
/// recorded on the span later.
#[derive(Clone, Debug)]
pub struct RequestSpan;

impl<B> MakeSpan<B> for RequestSpan {
    /// Builds the span for `request`.
    fn make_span(&mut self, request: &Request<B>) -> Span {
        let method = request.method();
        let uri = request.uri();
        let version = request.version();

        // The request ID is intentionally not looked up here: it is not
        // available yet, because the RequestIdLayer runs after TraceLayer
        // has created the span.
        info_span!(
            "http_request",
            method = %method,
            uri = %uri,
            version = ?version,
            request_id = Empty, // filled in later
            status_code = Empty,
            latency = Empty,
            error = Empty,
        )
    }
}
/// `on_request` hook for the HTTP trace layer.
///
/// Copies the request ID (if one is already present in the request
/// extensions) into the span's `request_id` field. It emits no log event.
#[derive(Clone, Debug)]
pub struct RequestLogger;

impl<B> OnRequest<B> for RequestLogger {
    /// Records the request ID on `span` when a [`RequestId`] extension exists.
    fn on_request(&mut self, request: &Request<B>, span: &Span) {
        // Keep the span entered for the duration of the callback.
        let _enter = span.enter();

        // The extension is only present if RequestIdLayer has already run for
        // this request; otherwise the field is left empty.
        if let Some(request_id) = request.extensions().get::<RequestId>() {
            // `as_str()` already yields a `&str`, which is a valid span value;
            // no extra reference is needed.
            span.record("request_id", request_id.0.as_str());
        }

        // Nothing is logged here: `RequestIdMiddleware` emits the
        // "started processing request" event with the proper request ID.
    }
}
/// Custom on_response handler
#[derive(Clone, Debug)]
pub struct ResponseLogger {
_start_time: Instant,
}
impl Default for ResponseLogger {
fn default() -> Self {
Self {
_start_time: Instant::now(),
}
}
}
impl<B> OnResponse<B> for ResponseLogger {
fn on_response(self, response: &Response<B>, latency: std::time::Duration, span: &Span) {
let status = response.status();
// Record these in the span for structured logging/observability tools
span.record("status_code", status.as_u16());
span.record("latency", format!("{:?}", latency));
// Don't log here - RequestIdService handles all logging with proper request IDs
}
}
/// Builds the `TraceLayer` used for HTTP request tracing.
///
/// The layer is wired with [`RequestSpan`], [`RequestLogger`] and
/// [`ResponseLogger`]. These only create the span and record fields on it;
/// the actual request/response log events (with request IDs) are emitted by
/// `RequestIdMiddleware`.
pub fn create_logging_layer() -> TraceLayer<
    tower_http::classify::SharedClassifier<tower_http::classify::ServerErrorsAsFailures>,
    RequestSpan,
    RequestLogger,
    ResponseLogger,
> {
    let http_layer = TraceLayer::new_for_http();
    let with_span = http_layer.make_span_with(RequestSpan);
    let with_request_hook = with_span.on_request(RequestLogger);
    with_request_hook.on_response(ResponseLogger::default())
}
/// Structured logging data for requests.
///
/// One entry describes a single HTTP request/response pair and is consumed by
/// `log_request`, which picks the log level from `status`.
#[derive(Debug, serde::Serialize)]
pub struct RequestLogEntry {
/// Timestamp of the request as text. The format is not fixed by this type,
/// and `log_request` does not read this field.
pub timestamp: String,
/// Request ID associated with the request.
pub request_id: String,
/// HTTP method, as text.
pub method: String,
/// Request URI, as text.
pub uri: String,
/// Numeric HTTP status code; `>= 500` is logged as an error and `>= 400`
/// as a warning by `log_request`.
pub status: u16,
/// Request latency; presumably in milliseconds, going by the field name.
pub latency_ms: u64,
/// User agent of the client, if known.
pub user_agent: Option<String>,
/// Remote address of the client, if known.
pub remote_addr: Option<String>,
/// Error description, if any. `log_request` only includes it for
/// statuses `>= 500`.
pub error: Option<String>,
}
/// Log a request with structured data
pub fn log_request(entry: RequestLogEntry) {
if entry.status >= 500 {
tracing::error!(
target: "sglang_router_rs::http",
request_id = %entry.request_id,
method = %entry.method,
uri = %entry.uri,
status = entry.status,
latency_ms = entry.latency_ms,
user_agent = ?entry.user_agent,
remote_addr = ?entry.remote_addr,
error = ?entry.error,
"HTTP request failed"
);
} else if entry.status >= 400 {
tracing::warn!(
target: "sglang_router_rs::http",
request_id = %entry.request_id,
method = %entry.method,
uri = %entry.uri,
status = entry.status,
latency_ms = entry.latency_ms,
user_agent = ?entry.user_agent,
remote_addr = ?entry.remote_addr,
"HTTP request client error"
);
} else {
tracing::info!(
target: "sglang_router_rs::http",
request_id = %entry.request_id,
method = %entry.method,
uri = %entry.uri,
status = entry.status,
latency_ms = entry.latency_ms,
user_agent = ?entry.user_agent,
remote_addr = ?entry.remote_addr,
"HTTP request completed"
);
}
}

View File

@@ -1,10 +1,17 @@
//! Router implementations
use actix_web::{HttpRequest, HttpResponse};
use async_trait::async_trait;
use axum::{
body::Body,
extract::Request,
http::{HeaderMap, StatusCode},
response::{IntoResponse, Response},
};
use reqwest::Client;
use std::fmt::Debug;
use crate::openai_api_types::{ChatCompletionRequest, CompletionRequest, GenerateRequest};
pub mod factory;
pub mod pd_router;
pub mod pd_types;
@@ -33,54 +40,55 @@ pub trait WorkerManagement: Send + Sync {
///
/// This trait provides a unified interface for routing requests,
/// regardless of whether it's a regular router or PD router.
#[async_trait(?Send)]
#[async_trait]
pub trait RouterTrait: Send + Sync + Debug + WorkerManagement {
/// Get a reference to self as Any for downcasting
fn as_any(&self) -> &dyn std::any::Any;
/// Route a health check request
async fn health(&self, client: &Client, req: &HttpRequest) -> HttpResponse;
async fn health(&self, client: &Client, req: Request<Body>) -> Response;
/// Route a health generate request
async fn health_generate(&self, client: &Client, req: &HttpRequest) -> HttpResponse;
async fn health_generate(&self, client: &Client, req: Request<Body>) -> Response;
/// Get server information
async fn get_server_info(&self, client: &Client, req: &HttpRequest) -> HttpResponse;
async fn get_server_info(&self, client: &Client, req: Request<Body>) -> Response;
/// Get available models
async fn get_models(&self, client: &Client, req: &HttpRequest) -> HttpResponse;
async fn get_models(&self, client: &Client, req: Request<Body>) -> Response;
/// Get model information
async fn get_model_info(&self, client: &Client, req: &HttpRequest) -> HttpResponse;
async fn get_model_info(&self, client: &Client, req: Request<Body>) -> Response;
/// Route a generate request
async fn route_generate(
&self,
client: &Client,
req: &HttpRequest,
body: serde_json::Value,
) -> HttpResponse;
headers: Option<&HeaderMap>,
body: &GenerateRequest,
) -> Response;
/// Route a chat completion request
async fn route_chat(
&self,
client: &Client,
req: &HttpRequest,
body: serde_json::Value,
) -> HttpResponse;
headers: Option<&HeaderMap>,
body: &ChatCompletionRequest,
) -> Response;
/// Route a completion request
async fn route_completion(
&self,
client: &Client,
req: &HttpRequest,
body: serde_json::Value,
) -> HttpResponse;
headers: Option<&HeaderMap>,
body: &CompletionRequest,
) -> Response;
/// Flush cache on all workers
async fn flush_cache(&self, client: &Client) -> HttpResponse;
async fn flush_cache(&self, client: &Client) -> Response;
/// Get worker loads (for monitoring)
async fn get_worker_loads(&self, client: &Client) -> HttpResponse;
async fn get_worker_loads(&self, client: &Client) -> Response;
/// Get router type name
fn router_type(&self) -> &'static str;
@@ -91,11 +99,11 @@ pub trait RouterTrait: Send + Sync + Debug + WorkerManagement {
}
/// Server liveness check - is the server process running
fn liveness(&self) -> HttpResponse {
fn liveness(&self) -> Response {
// Simple liveness check - if we can respond, we're alive
HttpResponse::Ok().body("OK")
(StatusCode::OK, "OK").into_response()
}
/// Server readiness check - is the server ready to handle requests
fn readiness(&self) -> HttpResponse;
fn readiness(&self) -> Response;
}

File diff suppressed because it is too large Load Diff

View File

@@ -1,17 +1,23 @@
use crate::core::{HealthChecker, Worker, WorkerFactory};
use crate::metrics::RouterMetrics;
use crate::middleware::get_request_id;
use crate::openai_api_types::{ChatCompletionRequest, CompletionRequest, GenerateRequest};
use crate::policies::LoadBalancingPolicy;
use actix_web::http::header::{HeaderValue, CONTENT_TYPE};
use actix_web::{HttpRequest, HttpResponse};
use futures_util::{StreamExt, TryStreamExt};
use crate::routers::{RouterTrait, WorkerManagement};
use axum::{
body::Body,
extract::Request,
http::{header::CONTENT_TYPE, HeaderMap, HeaderValue, StatusCode},
response::{IntoResponse, Response},
Json,
};
use futures_util::StreamExt;
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use std::thread;
use std::time::{Duration, Instant};
use tokio_stream::wrappers::UnboundedReceiverStream;
use tracing::{debug, error, info, warn};
pub fn copy_request_headers(req: &HttpRequest) -> Vec<(String, String)> {
pub fn copy_request_headers(req: &Request<Body>) -> Vec<(String, String)> {
req.headers()
.iter()
.filter_map(|(name, value)| {
@@ -239,154 +245,107 @@ impl Router {
}
}
pub async fn send_request(
&self,
client: &reqwest::Client,
worker_url: &str,
route: &str,
req: &HttpRequest,
) -> HttpResponse {
let request_id = get_request_id(req);
let start = Instant::now();
let worker_url = if self.dp_aware {
pub async fn send_health_check(&self, client: &Client, worker_url: &str) -> Response {
let health_url = if self.dp_aware {
// Need to extract the URL from "http://host:port@dp_rank"
let (worker_url_prefix, _dp_rank) = match Self::extract_dp_rank(worker_url) {
Ok(tup) => tup,
match Self::extract_dp_rank(worker_url) {
Ok((worker_url_prefix, _dp_rank)) => worker_url_prefix,
Err(e) => {
error!("Failed to extract dp_rank: {}", e);
return HttpResponse::InternalServerError().finish();
error!("Failed to extract dp_rank for health check: {}", e);
return (
StatusCode::INTERNAL_SERVER_ERROR,
format!("Failed to extract dp_rank: {}", e),
)
.into_response();
}
};
worker_url_prefix
}
} else {
worker_url
};
let mut request_builder = client.get(format!("{}{}", worker_url, route));
// Copy all headers from original request except for /health because it does not need authorization
if route != "/health" {
for (name, value) in copy_request_headers(req) {
// Skip Content-Type and Content-Length as .json() sets them
if name.to_lowercase() != "content-type" && name.to_lowercase() != "content-length"
{
request_builder = request_builder.header(name, value);
}
}
}
let request_builder = client.get(format!("{}/health", health_url));
let response = match request_builder.send().await {
Ok(res) => {
let status = actix_web::http::StatusCode::from_u16(res.status().as_u16())
.unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR);
let status = StatusCode::from_u16(res.status().as_u16())
.unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
match res.bytes().await {
Ok(body) => HttpResponse::build(status).body(body.to_vec()),
Ok(body) => (status, body).into_response(),
Err(e) => {
error!(
request_id = %request_id,
worker_url = %worker_url,
route = %route,
worker_url = %health_url,
error = %e,
"Failed to read response body"
"Failed to read health response body"
);
HttpResponse::InternalServerError()
.body(format!("Failed to read response body: {}", e))
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("Failed to read response body: {}", e),
)
.into_response()
}
}
}
Err(e) => {
error!(
request_id = %request_id,
worker_url = %worker_url,
route = %route,
worker_url = %health_url,
error = %e,
"Failed to send request to worker"
"Failed to send health request to worker"
);
HttpResponse::InternalServerError().body(format!(
"Failed to send request to worker {}: {}",
worker_url, e
))
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("Failed to send request to worker {}: {}", health_url, e),
)
.into_response()
}
};
// Record request metrics
if route != "/health" {
let duration = start.elapsed();
RouterMetrics::record_request(route);
RouterMetrics::record_request_duration(route, duration);
if !response.status().is_success() {
RouterMetrics::record_request_error(route, "request_failed");
}
}
// Don't record metrics for health checks
response
}
pub async fn route_to_first(
// Helper method to proxy GET requests to the first available worker
async fn proxy_get_request(
&self,
client: &reqwest::Client,
route: &str,
req: &HttpRequest,
) -> HttpResponse {
let request_id = get_request_id(req);
const MAX_REQUEST_RETRIES: u32 = 3;
const MAX_TOTAL_RETRIES: u32 = 6;
let mut total_retries = 0;
client: &Client,
req: Request<Body>,
endpoint: &str,
) -> Response {
let headers = copy_request_headers(&req);
while total_retries < MAX_TOTAL_RETRIES {
match self.select_first_worker() {
Ok(worker_url) => {
let mut request_retries = 0;
// Try the same worker multiple times
while request_retries < MAX_REQUEST_RETRIES {
if total_retries >= 1 {
info!("Retrying request after {} failed attempts", total_retries);
}
let response = self.send_request(client, &worker_url, route, req).await;
if response.status().is_success() {
return response;
} else {
// if the worker is healthy, it means the request is bad, so return the error response
let health_response =
self.send_request(client, &worker_url, "/health", req).await;
if health_response.status().is_success() {
return response;
}
}
warn!(
request_id = %request_id,
route = %route,
worker_url = %worker_url,
attempt = request_retries + 1,
max_attempts = MAX_REQUEST_RETRIES,
"Request failed"
);
request_retries += 1;
total_retries += 1;
if request_retries == MAX_REQUEST_RETRIES {
warn!(
request_id = %request_id,
worker_url = %worker_url,
"Removing failed worker"
);
self.remove_failed_worker(&worker_url);
break;
}
match self.select_first_worker() {
Ok(worker_url) => {
let mut request_builder = client.get(format!("{}/{}", worker_url, endpoint));
for (name, value) in headers {
if name.to_lowercase() != "content-type"
&& name.to_lowercase() != "content-length"
{
request_builder = request_builder.header(name, value);
}
}
Err(e) => return HttpResponse::InternalServerError().body(e),
}
}
HttpResponse::InternalServerError().body("All retry attempts failed")
match request_builder.send().await {
Ok(res) => {
let status = StatusCode::from_u16(res.status().as_u16())
.unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
match res.bytes().await {
Ok(body) => (status, body).into_response(),
Err(e) => (
StatusCode::INTERNAL_SERVER_ERROR,
format!("Failed to read response: {}", e),
)
.into_response(),
}
}
Err(e) => (
StatusCode::INTERNAL_SERVER_ERROR,
format!("Request failed: {}", e),
)
.into_response(),
}
}
Err(e) => (StatusCode::SERVICE_UNAVAILABLE, e).into_response(),
}
}
// New method to route typed requests directly
@@ -395,11 +354,10 @@ impl Router {
>(
&self,
client: &reqwest::Client,
req: &HttpRequest,
headers: Option<&HeaderMap>,
typed_req: &T,
route: &str,
) -> HttpResponse {
let request_id = get_request_id(req);
) -> Response {
// Handle retries like the original implementation
let start = Instant::now();
const MAX_REQUEST_RETRIES: u32 = 3;
@@ -440,7 +398,7 @@ impl Router {
let response = self
.send_typed_request(
client,
req,
headers,
typed_req,
route,
&worker_url,
@@ -455,8 +413,7 @@ impl Router {
return response;
} else {
// if the worker is healthy, it means the request is bad, so return the error response
let health_response =
self.send_request(client, &worker_url, "/health", req).await;
let health_response = self.send_health_check(client, &worker_url).await;
if health_response.status().is_success() {
RouterMetrics::record_request_error(route, "request_failed");
return response;
@@ -464,9 +421,11 @@ impl Router {
}
warn!(
request_id = %request_id,
"Generate request failed route={} worker_url={} attempt={} max_attempts={}",
route, worker_url, request_retries + 1, MAX_REQUEST_RETRIES
route,
worker_url,
request_retries + 1,
MAX_REQUEST_RETRIES
);
request_retries += 1;
@@ -474,17 +433,21 @@ impl Router {
if request_retries == MAX_REQUEST_RETRIES {
warn!(
request_id = %request_id,
"Removing failed worker after typed request failures worker_url={}", worker_url
"Removing failed worker after typed request failures worker_url={}",
worker_url
);
self.remove_failed_worker(&worker_url);
self.remove_worker(&worker_url);
break;
}
}
}
RouterMetrics::record_request_error(route, "request_failed");
HttpResponse::InternalServerError().body("All retry attempts failed")
(
StatusCode::INTERNAL_SERVER_ERROR,
"All retry attempts failed",
)
.into_response()
}
// Helper method to select worker from text using the policy
@@ -521,14 +484,13 @@ impl Router {
async fn send_typed_request<T: serde::Serialize>(
&self,
client: &reqwest::Client,
req: &HttpRequest,
headers: Option<&HeaderMap>,
typed_req: &T,
route: &str,
worker_url: &str,
is_stream: bool,
load_incremented: bool, // Whether load was incremented for this request
) -> HttpResponse {
let request_id = get_request_id(req);
) -> Response {
let start = Instant::now();
let mut request_builder = if self.dp_aware {
@@ -536,7 +498,11 @@ impl Router {
Ok(tup) => tup,
Err(e) => {
error!("Failed to extract dp_rank: {}", e);
return HttpResponse::InternalServerError().finish();
return (
StatusCode::INTERNAL_SERVER_ERROR,
format!("Failed to extract dp_rank: {}", e),
)
.into_response();
}
};
@@ -544,8 +510,11 @@ impl Router {
let mut json_val = match serde_json::to_value(typed_req) {
Ok(j) => j,
Err(e) => {
return HttpResponse::BadRequest()
.body(format!("Convert into serde_json::Value failed: {}", e));
return (
StatusCode::BAD_REQUEST,
format!("Convert into serde_json::Value failed: {}", e),
)
.into_response();
}
};
@@ -560,8 +529,11 @@ impl Router {
serde_json::to_string(&json_val).unwrap_or(String::from("ERR"))
);
} else {
return HttpResponse::BadRequest()
.body("Failed to insert the data_parallel_rank field into the request body");
return (
StatusCode::BAD_REQUEST,
"Failed to insert the data_parallel_rank field into the request body",
)
.into_response();
}
client
@@ -573,11 +545,15 @@ impl Router {
.json(typed_req) // Use json() directly with typed request
};
// Copy all headers from original request
for (name, value) in copy_request_headers(req) {
// Skip Content-Type and Content-Length as .json() sets them
if name.to_lowercase() != "content-type" && name.to_lowercase() != "content-length" {
request_builder = request_builder.header(&name, &value);
// Copy all headers from original request if provided
if let Some(headers) = headers {
for (name, value) in headers {
// Skip Content-Type and Content-Length as .json() sets them
if name.to_string().to_lowercase() != "content-type"
&& name.to_string().to_lowercase() != "content-length"
{
request_builder = request_builder.header(name, value);
}
}
}
@@ -585,7 +561,6 @@ impl Router {
Ok(res) => res,
Err(e) => {
error!(
request_id = %request_id,
"Failed to send typed request worker_url={} route={} error={}",
worker_url, route, e
);
@@ -600,20 +575,24 @@ impl Router {
}
}
return HttpResponse::InternalServerError().body(format!("Request failed: {}", e));
return (
StatusCode::INTERNAL_SERVER_ERROR,
format!("Request failed: {}", e),
)
.into_response();
}
};
let status = actix_web::http::StatusCode::from_u16(res.status().as_u16())
.unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR);
let status = StatusCode::from_u16(res.status().as_u16())
.unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
if !is_stream {
// For non-streaming requests, get response first
let response = match res.bytes().await {
Ok(body) => HttpResponse::build(status).body(body.to_vec()),
Ok(body) => (status, body).into_response(),
Err(e) => {
let error_msg = format!("Failed to get response body: {}", e);
HttpResponse::InternalServerError().body(error_msg)
(StatusCode::INTERNAL_SERVER_ERROR, error_msg).into_response()
}
};
@@ -638,42 +617,86 @@ impl Router {
let workers = Arc::clone(&self.workers);
let worker_url = worker_url.to_string();
HttpResponse::build(status)
.insert_header((CONTENT_TYPE, HeaderValue::from_static("text/event-stream")))
.streaming(
res.bytes_stream()
.map_err(|_| {
actix_web::error::ErrorInternalServerError("Failed to read stream")
})
.inspect(move |bytes| {
if let Ok(bytes) = bytes {
if bytes
.as_ref()
.windows(12)
.any(|window| window == b"data: [DONE]")
{
if let Ok(workers_guard) = workers.read() {
if let Some(worker) =
workers_guard.iter().find(|w| w.url() == &worker_url)
{
worker.decrement_load();
RouterMetrics::set_running_requests(
&worker_url,
worker.load(),
);
}
let stream = res.bytes_stream();
let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
// Spawn task to forward stream and detect completion
tokio::spawn(async move {
let mut stream = stream;
while let Some(chunk) = stream.next().await {
match chunk {
Ok(bytes) => {
// Check for stream end marker
if bytes
.as_ref()
.windows(12)
.any(|window| window == b"data: [DONE]")
{
if let Ok(workers_guard) = workers.read() {
if let Some(worker) =
workers_guard.iter().find(|w| w.url() == &worker_url)
{
worker.decrement_load();
RouterMetrics::set_running_requests(
&worker_url,
worker.load(),
);
}
}
}
}),
)
if tx.send(Ok(bytes)).is_err() {
break;
}
}
Err(e) => {
let _ = tx.send(Err(format!("Stream error: {}", e)));
break;
}
}
}
});
let stream = UnboundedReceiverStream::new(rx);
let body = Body::from_stream(stream);
let mut response = Response::new(body);
*response.status_mut() = status;
response
.headers_mut()
.insert(CONTENT_TYPE, HeaderValue::from_static("text/event-stream"));
response
} else {
// For requests without load tracking, just stream
HttpResponse::build(status)
.insert_header((CONTENT_TYPE, HeaderValue::from_static("text/event-stream")))
.streaming(res.bytes_stream().map_err(|_| {
actix_web::error::ErrorInternalServerError("Failed to read stream")
}))
let stream = res.bytes_stream();
let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
// Spawn task to forward stream
tokio::spawn(async move {
let mut stream = stream;
while let Some(chunk) = stream.next().await {
match chunk {
Ok(bytes) => {
if tx.send(Ok(bytes)).is_err() {
break;
}
}
Err(e) => {
let _ = tx.send(Err(format!("Stream error: {}", e)));
break;
}
}
}
});
let stream = UnboundedReceiverStream::new(rx);
let body = Body::from_stream(stream);
let mut response = Response::new(body);
*response.status_mut() = status;
response
.headers_mut()
.insert(CONTENT_TYPE, HeaderValue::from_static("text/event-stream"));
response
}
}
@@ -775,7 +798,6 @@ impl Router {
}
}
/// Remove all the worker(s) that match the URL prefix
pub fn remove_worker(&self, worker_url: &str) {
if self.dp_aware {
// remove dp-aware workers in a prefix-matching fashion
@@ -844,28 +866,6 @@ impl Router {
}
}
/// Remove a specific failed worker; for internal usage
fn remove_failed_worker(&self, worker_url: &str) {
let mut workers_guard = self.workers.write().unwrap();
if let Some(index) = workers_guard.iter().position(|w| w.url() == worker_url) {
workers_guard.remove(index);
info!("Removed failed worker: {}", worker_url);
RouterMetrics::set_active_workers(workers_guard.len());
} else {
warn!("Worker {} not found, skipping removal", worker_url);
return;
}
// If cache aware policy, remove the worker from the tree
if let Some(cache_aware) = self
.policy
.as_any()
.downcast_ref::<crate::policies::CacheAwarePolicy>()
{
cache_aware.remove_worker(worker_url);
}
}
async fn get_worker_load(&self, client: &reqwest::Client, worker_url: &str) -> Option<isize> {
let worker_url = if self.dp_aware {
// Need to extract the URL from "http://host:port@dp_rank"
@@ -1004,7 +1004,6 @@ impl Router {
}
}
use crate::routers::{RouterTrait, WorkerManagement};
use async_trait::async_trait;
use reqwest::Client;
@@ -1023,100 +1022,78 @@ impl WorkerManagement for Router {
}
}
#[async_trait(?Send)]
#[async_trait]
impl RouterTrait for Router {
fn as_any(&self) -> &dyn std::any::Any {
self
}
async fn health(&self, _client: &Client, _req: &HttpRequest) -> HttpResponse {
// Check local health state of all workers (consistent with PD router)
// Note: This uses cached health status from background health checks, not live checks
let mut all_healthy = true;
let mut unhealthy_servers = Vec::new();
async fn health(&self, _client: &Client, _req: Request<Body>) -> Response {
let workers = self.workers.read().unwrap();
let unhealthy_servers: Vec<_> = workers
.iter()
.filter(|w| !w.is_healthy())
.map(|w| w.url().to_string())
.collect();
for worker in self.workers.read().unwrap().iter() {
if !worker.is_healthy() {
all_healthy = false;
unhealthy_servers.push(worker.url().to_string());
}
}
if all_healthy {
HttpResponse::Ok().body("All servers healthy")
if unhealthy_servers.is_empty() {
(StatusCode::OK, "All servers healthy").into_response()
} else {
HttpResponse::ServiceUnavailable()
.body(format!("Unhealthy servers: {:?}", unhealthy_servers))
(
StatusCode::SERVICE_UNAVAILABLE,
format!("Unhealthy servers: {:?}", unhealthy_servers),
)
.into_response()
}
}
async fn health_generate(&self, client: &Client, req: &HttpRequest) -> HttpResponse {
// Test model generation capability by sending to first available worker
// Note: This endpoint actually causes the model to generate a token, so we only test one worker
self.route_to_first(client, "/health_generate", req).await
/// Forwards the incoming request through `proxy_get_request` using the
/// `health_generate` endpoint and returns whatever response that helper yields.
// NOTE(review): the pre-migration code targeted "/health_generate" on the first
// available worker; confirm `proxy_get_request` adds the leading slash and
// keeps the same worker selection.
async fn health_generate(&self, client: &Client, req: Request<Body>) -> Response {
self.proxy_get_request(client, req, "health_generate").await
}
async fn get_server_info(&self, client: &Client, req: &HttpRequest) -> HttpResponse {
self.route_to_first(client, "/get_server_info", req).await
/// Forwards the incoming request through `proxy_get_request` using the
/// `get_server_info` endpoint and returns the resulting response.
async fn get_server_info(&self, client: &Client, req: Request<Body>) -> Response {
self.proxy_get_request(client, req, "get_server_info").await
}
async fn get_models(&self, client: &Client, req: &HttpRequest) -> HttpResponse {
self.route_to_first(client, "/v1/models", req).await
/// Forwards the incoming request through `proxy_get_request` using the
/// `v1/models` endpoint and returns the resulting response.
async fn get_models(&self, client: &Client, req: Request<Body>) -> Response {
self.proxy_get_request(client, req, "v1/models").await
}
async fn get_model_info(&self, client: &Client, req: &HttpRequest) -> HttpResponse {
self.route_to_first(client, "/get_model_info", req).await
/// Forwards the incoming request through `proxy_get_request` using the
/// `get_model_info` endpoint and returns the resulting response.
async fn get_model_info(&self, client: &Client, req: Request<Body>) -> Response {
self.proxy_get_request(client, req, "get_model_info").await
}
async fn route_generate(
&self,
client: &Client,
req: &HttpRequest,
body: serde_json::Value,
) -> HttpResponse {
// Convert JSON to typed request
match serde_json::from_value::<crate::openai_api_types::GenerateRequest>(body) {
Ok(typed_req) => {
self.route_typed_request(client, req, &typed_req, "/generate")
.await
}
Err(e) => HttpResponse::BadRequest().body(format!("Invalid request: {}", e)),
}
headers: Option<&HeaderMap>,
body: &GenerateRequest,
) -> Response {
self.route_typed_request(client, headers, body, "/generate")
.await
}
async fn route_chat(
&self,
client: &Client,
req: &HttpRequest,
body: serde_json::Value,
) -> HttpResponse {
// Convert JSON to typed request
match serde_json::from_value::<crate::openai_api_types::ChatCompletionRequest>(body) {
Ok(typed_req) => {
self.route_typed_request(client, req, &typed_req, "/v1/chat/completions")
.await
}
Err(e) => HttpResponse::BadRequest().body(format!("Invalid request: {}", e)),
}
headers: Option<&HeaderMap>,
body: &ChatCompletionRequest,
) -> Response {
self.route_typed_request(client, headers, body, "/v1/chat/completions")
.await
}
async fn route_completion(
&self,
client: &Client,
req: &HttpRequest,
body: serde_json::Value,
) -> HttpResponse {
// Convert JSON to typed request
match serde_json::from_value::<crate::openai_api_types::CompletionRequest>(body) {
Ok(typed_req) => {
self.route_typed_request(client, req, &typed_req, "/v1/completions")
.await
}
Err(e) => HttpResponse::BadRequest().body(format!("Invalid request: {}", e)),
}
headers: Option<&HeaderMap>,
body: &CompletionRequest,
) -> Response {
self.route_typed_request(client, headers, body, "/v1/completions")
.await
}
async fn flush_cache(&self, client: &Client) -> HttpResponse {
async fn flush_cache(&self, client: &Client) -> Response {
// Get all worker URLs
let worker_urls = self.get_worker_urls();
@@ -1129,7 +1106,11 @@ impl RouterTrait for Router {
Ok(tup) => tup,
Err(e) => {
error!("Failed to extract dp_rank: {}", e);
return HttpResponse::InternalServerError().finish();
return (
StatusCode::INTERNAL_SERVER_ERROR,
format!("Failed to extract dp_rank: {}", e),
)
.into_response();
}
};
worker_url_prefix
@@ -1151,13 +1132,17 @@ impl RouterTrait for Router {
});
if all_success {
HttpResponse::Ok().body("Cache flushed on all servers")
(StatusCode::OK, "Cache flushed on all servers").into_response()
} else {
HttpResponse::InternalServerError().body("Cache flush failed on one or more servers")
(
StatusCode::INTERNAL_SERVER_ERROR,
"Cache flush failed on one or more servers",
)
.into_response()
}
}
async fn get_worker_loads(&self, client: &Client) -> HttpResponse {
async fn get_worker_loads(&self, client: &Client) -> Response {
let urls = self.get_worker_urls();
let mut loads = Vec::new();
@@ -1170,16 +1155,17 @@ impl RouterTrait for Router {
}));
}
HttpResponse::Ok().json(serde_json::json!({
Json(serde_json::json!({
"workers": loads
}))
.into_response()
}
/// Static label identifying this router implementation; always `"regular"`.
fn router_type(&self) -> &'static str {
"regular"
}
fn readiness(&self) -> HttpResponse {
fn readiness(&self) -> Response {
// Regular router is ready if it has at least one healthy worker
let healthy_count = self
.workers
@@ -1190,17 +1176,22 @@ impl RouterTrait for Router {
.count();
if healthy_count > 0 {
HttpResponse::Ok().json(serde_json::json!({
Json(serde_json::json!({
"status": "ready",
"healthy_workers": healthy_count,
"total_workers": self.workers.read().unwrap().len()
}))
.into_response()
} else {
HttpResponse::ServiceUnavailable().json(serde_json::json!({
"status": "not_ready",
"reason": "no healthy workers available",
"total_workers": self.workers.read().unwrap().len()
}))
(
StatusCode::SERVICE_UNAVAILABLE,
Json(serde_json::json!({
"status": "not_ready",
"reason": "no healthy workers available",
"total_workers": self.workers.read().unwrap().len()
})),
)
.into_response()
}
}
}

View File

@@ -1,285 +1,169 @@
use crate::config::RouterConfig;
use crate::logging::{self, LoggingConfig};
use crate::metrics::{self, PrometheusConfig};
use crate::middleware::{get_request_id, RequestIdMiddleware};
use crate::openai_api_types::{ChatCompletionRequest, CompletionRequest, GenerateRequest};
use crate::routers::{RouterFactory, RouterTrait};
use crate::service_discovery::{start_service_discovery, ServiceDiscoveryConfig};
use actix_web::{
error, get, post, web, App, Error, HttpRequest, HttpResponse, HttpServer, Responder,
use axum::{
extract::{Query, Request, State},
http::StatusCode,
response::{IntoResponse, Response},
routing::{get, post},
Json, Router,
};
use futures_util::StreamExt;
use reqwest::Client;
use std::collections::HashMap;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::time::Duration;
use tokio::net::TcpListener;
use tokio::signal;
use tokio::spawn;
use tracing::{error, info, warn, Level};
#[derive(Debug)]
#[derive(Clone)]
pub struct AppState {
router: Arc<dyn RouterTrait>,
client: Client,
pub router: Arc<dyn RouterTrait>,
pub client: Client,
pub _concurrency_limiter: Arc<tokio::sync::Semaphore>,
}
impl AppState {
pub fn new(router_config: RouterConfig, client: Client) -> Result<Self, String> {
// Use RouterFactory to create the appropriate router type
pub fn new(
router_config: RouterConfig,
client: Client,
max_concurrent_requests: usize,
) -> Result<Self, String> {
let router = RouterFactory::create_router(&router_config)?;
// Convert Box<dyn RouterTrait> to Arc<dyn RouterTrait>
let router = Arc::from(router);
Ok(Self { router, client })
let concurrency_limiter = Arc::new(tokio::sync::Semaphore::new(max_concurrent_requests));
Ok(Self {
router,
client,
_concurrency_limiter: concurrency_limiter,
})
}
}
async fn sink_handler(_req: HttpRequest, mut payload: web::Payload) -> Result<HttpResponse, Error> {
// Drain the payload
while let Some(chunk) = payload.next().await {
if let Err(err) = chunk {
println!("Error while draining payload: {:?}", err);
break;
}
}
Ok(HttpResponse::NotFound().finish())
/// Fallback handler for unmatched routes.
///
/// Takes no extractors and answers every request with a bare
/// `404 Not Found` status converted into a `Response`.
async fn sink_handler() -> Response {
StatusCode::NOT_FOUND.into_response()
}
// Custom error handler for JSON payload errors.
fn json_error_handler(err: error::JsonPayloadError, req: &HttpRequest) -> Error {
let request_id = get_request_id(req);
match &err {
error::JsonPayloadError::OverflowKnownLength { length, limit } => {
error!(
request_id = %request_id,
"Payload too large length={} limit={}", length, limit
);
error::ErrorPayloadTooLarge(format!(
"Payload too large: {} bytes exceeds limit of {} bytes",
length, limit
))
}
error::JsonPayloadError::Overflow { limit } => {
error!(
request_id = %request_id,
"Payload overflow limit={}", limit
);
error::ErrorPayloadTooLarge(format!("Payload exceeds limit of {} bytes", limit))
}
_ => {
error!(
request_id = %request_id,
"Invalid JSON payload error={}", err
);
error::ErrorBadRequest(format!("Invalid JSON payload: {}", err))
}
}
// Health check endpoints
/// Liveness probe handler: returns the router's synchronous `liveness()`
/// response unchanged.
async fn liveness(State(state): State<Arc<AppState>>) -> Response {
state.router.liveness()
}
#[get("/liveness")]
async fn liveness(_req: HttpRequest, data: web::Data<AppState>) -> impl Responder {
data.router.liveness()
/// Readiness probe handler: returns the router's synchronous `readiness()`
/// response unchanged.
async fn readiness(State(state): State<Arc<AppState>>) -> Response {
state.router.readiness()
}
#[get("/readiness")]
async fn readiness(_req: HttpRequest, data: web::Data<AppState>) -> impl Responder {
data.router.readiness()
/// Health handler: hands the full request and the shared HTTP client to the
/// router's `health` method and returns its response.
async fn health(State(state): State<Arc<AppState>>, req: Request) -> Response {
state.router.health(&state.client, req).await
}
#[get("/health")]
async fn health(req: HttpRequest, data: web::Data<AppState>) -> impl Responder {
data.router.health(&data.client, &req).await
/// Generation health handler: hands the full request and the shared HTTP
/// client to the router's `health_generate` method and returns its response.
async fn health_generate(State(state): State<Arc<AppState>>, req: Request) -> Response {
state.router.health_generate(&state.client, req).await
}
#[get("/health_generate")]
async fn health_generate(req: HttpRequest, data: web::Data<AppState>) -> impl Responder {
data.router.health_generate(&data.client, &req).await
/// Server-info handler: hands the full request and the shared HTTP client to
/// the router's `get_server_info` method and returns its response.
async fn get_server_info(State(state): State<Arc<AppState>>, req: Request) -> Response {
state.router.get_server_info(&state.client, req).await
}
#[get("/get_server_info")]
async fn get_server_info(req: HttpRequest, data: web::Data<AppState>) -> impl Responder {
data.router.get_server_info(&data.client, &req).await
/// Model-listing handler: hands the full request and the shared HTTP client to
/// the router's `get_models` method and returns its response.
async fn v1_models(State(state): State<Arc<AppState>>, req: Request) -> Response {
state.router.get_models(&state.client, req).await
}
#[get("/v1/models")]
async fn v1_models(req: HttpRequest, data: web::Data<AppState>) -> impl Responder {
data.router.get_models(&data.client, &req).await
/// Model-info handler: hands the full request and the shared HTTP client to
/// the router's `get_model_info` method and returns its response.
async fn get_model_info(State(state): State<Arc<AppState>>, req: Request) -> Response {
state.router.get_model_info(&state.client, req).await
}
#[get("/get_model_info")]
async fn get_model_info(req: HttpRequest, data: web::Data<AppState>) -> impl Responder {
data.router.get_model_info(&data.client, &req).await
}
#[post("/generate")]
// Generation endpoints
// The RouterTrait now accepts optional headers and typed body directly
async fn generate(
req: HttpRequest,
body: web::Json<GenerateRequest>,
state: web::Data<AppState>,
) -> Result<HttpResponse, Error> {
let request_id = get_request_id(&req);
info!(
request_id = %request_id,
"Received generate request method=\"POST\" path=\"/generate\""
);
let json_body = serde_json::to_value(body.into_inner()).map_err(|e| {
error!(
request_id = %request_id,
"Failed to parse generate request body error={}", e
);
error::ErrorBadRequest(format!("Invalid JSON: {}", e))
})?;
Ok(state
State(state): State<Arc<AppState>>,
headers: http::HeaderMap,
Json(body): Json<GenerateRequest>,
) -> Response {
state
.router
.route_generate(&state.client, &req, json_body)
.await)
.route_generate(&state.client, Some(&headers), &body)
.await
}
#[post("/v1/chat/completions")]
async fn v1_chat_completions(
req: HttpRequest,
body: web::Json<ChatCompletionRequest>,
state: web::Data<AppState>,
) -> Result<HttpResponse, Error> {
let request_id = get_request_id(&req);
info!(
request_id = %request_id,
"Received chat completion request method=\"POST\" path=\"/v1/chat/completions\""
);
let json_body = serde_json::to_value(body.into_inner()).map_err(|e| {
error!(
request_id = %request_id,
"Failed to parse chat completion request body error={}", e
);
error::ErrorBadRequest(format!("Invalid JSON: {}", e))
})?;
Ok(state
State(state): State<Arc<AppState>>,
headers: http::HeaderMap,
Json(body): Json<ChatCompletionRequest>,
) -> Response {
state
.router
.route_chat(&state.client, &req, json_body)
.await)
.route_chat(&state.client, Some(&headers), &body)
.await
}
#[post("/v1/completions")]
async fn v1_completions(
req: HttpRequest,
body: web::Json<CompletionRequest>,
state: web::Data<AppState>,
) -> Result<HttpResponse, Error> {
let request_id = get_request_id(&req);
info!(
request_id = %request_id,
"Received completion request method=\"POST\" path=\"/v1/completions\""
);
let json_body = serde_json::to_value(body.into_inner()).map_err(|e| {
error!(
request_id = %request_id,
"Failed to parse completion request body error={}", e
);
error::ErrorBadRequest(format!("Invalid JSON: {}", e))
})?;
Ok(state
State(state): State<Arc<AppState>>,
headers: http::HeaderMap,
Json(body): Json<CompletionRequest>,
) -> Response {
state
.router
.route_completion(&state.client, &req, json_body)
.await)
.route_completion(&state.client, Some(&headers), &body)
.await
}
#[post("/add_worker")]
// Worker management endpoints
async fn add_worker(
req: HttpRequest,
query: web::Query<HashMap<String, String>>,
data: web::Data<AppState>,
) -> impl Responder {
let request_id = get_request_id(&req);
let worker_url = match query.get("url") {
State(state): State<Arc<AppState>>,
Query(params): Query<HashMap<String, String>>,
) -> Response {
let worker_url = match params.get("url") {
Some(url) => url.to_string(),
None => {
warn!(
request_id = %request_id,
"Add worker request missing URL parameter"
);
return HttpResponse::BadRequest()
.body("Worker URL required. Provide 'url' query parameter");
return (
StatusCode::BAD_REQUEST,
"Worker URL required. Provide 'url' query parameter",
)
.into_response();
}
};
info!(
request_id = %request_id,
worker_url = %worker_url,
"Adding worker"
);
match data.router.add_worker(&worker_url).await {
Ok(message) => {
info!(
request_id = %request_id,
worker_url = %worker_url,
"Successfully added worker"
);
HttpResponse::Ok().body(message)
}
Err(error) => {
error!(
request_id = %request_id,
worker_url = %worker_url,
error = %error,
"Failed to add worker"
);
HttpResponse::BadRequest().body(error)
}
match state.router.add_worker(&worker_url).await {
Ok(message) => (StatusCode::OK, message).into_response(),
Err(error) => (StatusCode::BAD_REQUEST, error).into_response(),
}
}
#[get("/list_workers")]
async fn list_workers(data: web::Data<AppState>) -> impl Responder {
let worker_list = data.router.get_worker_urls();
HttpResponse::Ok().json(serde_json::json!({ "urls": worker_list }))
/// Worker-listing handler.
///
/// Asks the router for its current worker URLs and returns them as the JSON
/// object `{ "urls": [...] }`.
async fn list_workers(State(state): State<Arc<AppState>>) -> Response {
    let payload = serde_json::json!({ "urls": state.router.get_worker_urls() });
    Json(payload).into_response()
}
#[post("/remove_worker")]
async fn remove_worker(
req: HttpRequest,
query: web::Query<HashMap<String, String>>,
data: web::Data<AppState>,
) -> impl Responder {
let request_id = get_request_id(&req);
let worker_url = match query.get("url") {
State(state): State<Arc<AppState>>,
Query(params): Query<HashMap<String, String>>,
) -> Response {
let worker_url = match params.get("url") {
Some(url) => url.to_string(),
None => {
warn!(
request_id = %request_id,
"Remove worker request missing URL parameter"
);
return HttpResponse::BadRequest().finish();
}
None => return StatusCode::BAD_REQUEST.into_response(),
};
info!(
request_id = %request_id,
worker_url = %worker_url,
"Removing worker"
);
data.router.remove_worker(&worker_url);
HttpResponse::Ok().body(format!("Successfully removed worker: {}", worker_url))
state.router.remove_worker(&worker_url);
(
StatusCode::OK,
format!("Successfully removed worker: {}", worker_url),
)
.into_response()
}
#[post("/flush_cache")]
async fn flush_cache(_req: HttpRequest, data: web::Data<AppState>) -> impl Responder {
data.router.flush_cache(&data.client).await
/// Cache-flush handler: ignores the incoming request and returns the result of
/// the router's `flush_cache`, called with the shared HTTP client.
async fn flush_cache(State(state): State<Arc<AppState>>, _req: Request) -> Response {
state.router.flush_cache(&state.client).await
}
#[get("/get_loads")]
async fn get_loads(_req: HttpRequest, data: web::Data<AppState>) -> impl Responder {
data.router.get_worker_loads(&data.client).await
/// Worker-load handler: ignores the incoming request and returns the result of
/// the router's `get_worker_loads`, called with the shared HTTP client.
async fn get_loads(State(state): State<Arc<AppState>>, _req: Request) -> Response {
state.router.get_worker_loads(&state.client).await
}
pub struct ServerConfig {
@@ -295,7 +179,58 @@ pub struct ServerConfig {
pub request_id_headers: Option<Vec<String>>,
}
pub async fn startup(config: ServerConfig) -> std::io::Result<()> {
/// Build the Axum application with all routes and middleware.
///
/// # Arguments
/// * `app_state` - shared state made available to every handler through `State`.
/// * `max_payload_size` - limit handed to `RequestBodyLimitLayer`.
/// * `request_id_headers` - header names handed to `RequestIdLayer`.
/// * `cors_allowed_origins` - origins handed to `create_cors_layer`; an empty
///   list results in a policy that allows any origin.
///
/// # Middleware order
/// `Router::layer` wraps everything registered before it, so the layer added
/// last is the outermost one and sees the request first. The resulting request
/// path is: CORS -> request ID -> logging -> body limit -> handler.
pub fn build_app(
    app_state: Arc<AppState>,
    max_payload_size: usize,
    request_id_headers: Vec<String>,
    cors_allowed_origins: Vec<String>,
) -> Router {
    // Inference endpoints.
    let protected_routes = Router::new()
        .route("/generate", post(generate))
        .route("/v1/chat/completions", post(v1_chat_completions))
        .route("/v1/completions", post(v1_completions));

    // Probes and read-only metadata endpoints.
    let public_routes = Router::new()
        .route("/liveness", get(liveness))
        .route("/readiness", get(readiness))
        .route("/health", get(health))
        .route("/health_generate", get(health_generate))
        .route("/v1/models", get(v1_models))
        .route("/get_model_info", get(get_model_info))
        .route("/get_server_info", get(get_server_info));

    // Worker management endpoints.
    let admin_routes = Router::new()
        .route("/add_worker", post(add_worker))
        .route("/remove_worker", post(remove_worker))
        .route("/list_workers", get(list_workers))
        .route("/flush_cache", post(flush_cache))
        .route("/get_loads", get(get_loads));

    Router::new()
        .merge(protected_routes)
        .merge(public_routes)
        .merge(admin_routes)
        // Innermost layer: request body size limiting.
        .layer(tower_http::limit::RequestBodyLimitLayer::new(
            max_payload_size,
        ))
        // Logging is registered BEFORE the request-ID layer so that it ends up
        // INSIDE it: the request-ID layer then runs first and the logging layer
        // can read the ID it stored in the request extensions.
        .layer(crate::middleware::create_logging_layer())
        // Request ID layer: added after logging, therefore executed before it.
        .layer(crate::middleware::RequestIdLayer::new(request_id_headers))
        // CORS is added last among the layers so it is the outermost one.
        .layer(create_cors_layer(cors_allowed_origins))
        // NOTE(review): axum documents that `layer` only wraps routes and
        // fallbacks that already exist, so this fallback presumably bypasses the
        // middleware stack above - confirm whether that is intended.
        .fallback(sink_handler)
        // Supplying the state last resolves the router into the stateless
        // `Router` type this function returns.
        .with_state(app_state)
}
pub async fn startup(config: ServerConfig) -> Result<(), Box<dyn std::error::Error>> {
// Only initialize logging if not already done (for Python bindings support)
static LOGGING_INITIALIZED: AtomicBool = AtomicBool::new(false);
@@ -338,14 +273,20 @@ pub async fn startup(config: ServerConfig) -> std::io::Result<()> {
let client = Client::builder()
.pool_idle_timeout(Some(Duration::from_secs(50)))
.timeout(Duration::from_secs(config.request_timeout_secs)) // Use configurable timeout
.pool_max_idle_per_host(100) // Increase from default of 1 to allow more concurrent connections
.timeout(Duration::from_secs(config.request_timeout_secs))
.connect_timeout(Duration::from_secs(10)) // Separate connection timeout
.tcp_nodelay(true)
.tcp_keepalive(Some(Duration::from_secs(30))) // Keep connections alive
.build()
.expect("Failed to create HTTP client");
let app_state_init = AppState::new(config.router_config.clone(), client.clone())
.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?;
let router_arc = Arc::clone(&app_state_init.router);
let app_state = web::Data::new(app_state_init);
let app_state = Arc::new(AppState::new(
config.router_config.clone(),
client.clone(),
config.router_config.max_concurrent_requests,
)?);
let router_arc = Arc::clone(&app_state.router);
// Start the service discovery if enabled
if let Some(service_discovery_config) = config.service_discovery_config {
@@ -383,36 +324,83 @@ pub async fn startup(config: ServerConfig) -> std::io::Result<()> {
]
});
HttpServer::new(move || {
let request_id_middleware = RequestIdMiddleware::new(request_id_headers.clone());
// Build the application
let app = build_app(
app_state,
config.max_payload_size,
request_id_headers,
config.router_config.cors_allowed_origins.clone(),
);
App::new()
.wrap(request_id_middleware)
.app_data(app_state.clone())
.app_data(
web::JsonConfig::default()
.limit(config.max_payload_size)
.error_handler(json_error_handler),
)
.app_data(web::PayloadConfig::default().limit(config.max_payload_size))
.service(generate)
.service(v1_chat_completions)
.service(v1_completions)
.service(v1_models)
.service(get_model_info)
.service(liveness)
.service(readiness)
.service(health)
.service(health_generate)
.service(get_server_info)
.service(add_worker)
.service(remove_worker)
.service(list_workers)
.service(flush_cache)
.service(get_loads)
.default_service(web::route().to(sink_handler))
})
.bind_auto_h2c((config.host, config.port))?
.run()
.await
// Create TCP listener - use the configured host
let addr = format!("{}:{}", config.host, config.port);
let listener = TcpListener::bind(&addr).await?;
// Start server with graceful shutdown
info!("Starting server on {}", addr);
// Serve the application with graceful shutdown
axum::serve(listener, app)
.with_graceful_shutdown(shutdown_signal())
.await
.map_err(|e| Box::new(e) as Box<dyn std::error::Error>)?;
Ok(())
}
/// Future that completes when the process is asked to shut down.
///
/// It resolves on the first of:
/// * Ctrl+C, on every platform;
/// * the Unix `terminate` signal, on Unix targets only. On other targets this
///   branch is a future that never completes.
///
/// # Panics
/// Panics if either signal handler cannot be installed.
async fn shutdown_signal() {
    // Both futures are lazy: the handlers are installed on first poll inside
    // `select!`, not at construction time.
    #[cfg(unix)]
    let terminate = async {
        let mut term_stream = signal::unix::signal(signal::unix::SignalKind::terminate())
            .expect("failed to install signal handler");
        term_stream.recv().await;
    };

    #[cfg(not(unix))]
    let terminate = std::future::pending::<()>();

    let ctrl_c = async {
        let outcome = signal::ctrl_c().await;
        outcome.expect("failed to install Ctrl+C handler");
    };

    tokio::select! {
        _ = ctrl_c => info!("Received Ctrl+C, starting graceful shutdown"),
        _ = terminate => info!("Received terminate signal, starting graceful shutdown"),
    }
}
// CORS Layer Creation
fn create_cors_layer(allowed_origins: Vec<String>) -> tower_http::cors::CorsLayer {
use tower_http::cors::Any;
let cors = if allowed_origins.is_empty() {
// Allow all origins if none specified
tower_http::cors::CorsLayer::new()
.allow_origin(Any)
.allow_methods(Any)
.allow_headers(Any)
.expose_headers(Any)
} else {
// Restrict to specific origins
let origins: Vec<http::HeaderValue> = allowed_origins
.into_iter()
.filter_map(|origin| origin.parse().ok())
.collect();
tower_http::cors::CorsLayer::new()
.allow_origin(origins)
.allow_methods([http::Method::GET, http::Method::POST, http::Method::OPTIONS])
.allow_headers([http::header::CONTENT_TYPE, http::header::AUTHORIZATION])
.expose_headers([http::header::HeaderName::from_static("x-request-id")])
};
cors.max_age(Duration::from_secs(3600))
}