diff --git a/sgl-router/src/routers/factory.rs b/sgl-router/src/routers/factory.rs index b97974367..8dc40527a 100644 --- a/sgl-router/src/routers/factory.rs +++ b/sgl-router/src/routers/factory.rs @@ -1,18 +1,20 @@ //! Factory for creating router instances use super::{pd_router::PDRouter, router::Router, RouterTrait}; -use crate::config::{PolicyConfig, RouterConfig, RoutingMode}; +use crate::config::{PolicyConfig, RoutingMode}; use crate::policies::PolicyFactory; +use crate::server::AppContext; +use std::sync::Arc; /// Factory for creating router instances based on configuration pub struct RouterFactory; impl RouterFactory { - /// Create a router instance from configuration - pub fn create_router(config: &RouterConfig) -> Result, String> { - match &config.mode { + /// Create a router instance from application context + pub fn create_router(ctx: &Arc) -> Result, String> { + match &ctx.router_config.mode { RoutingMode::Regular { worker_urls } => { - Self::create_regular_router(worker_urls, &config.policy, config) + Self::create_regular_router(worker_urls, &ctx.router_config.policy, ctx) } RoutingMode::PrefillDecode { prefill_urls, @@ -24,8 +26,8 @@ impl RouterFactory { decode_urls, prefill_policy.as_ref(), decode_policy.as_ref(), - &config.policy, - config, + &ctx.router_config.policy, + ctx, ), } } @@ -34,19 +36,20 @@ impl RouterFactory { fn create_regular_router( worker_urls: &[String], policy_config: &PolicyConfig, - router_config: &RouterConfig, + ctx: &Arc, ) -> Result, String> { // Create policy let policy = PolicyFactory::create_from_config(policy_config); - // Create regular router with injected policy + // Create regular router with injected policy and client let router = Router::new( worker_urls.to_vec(), policy, - router_config.worker_startup_timeout_secs, - router_config.worker_startup_check_interval_secs, - router_config.dp_aware, - router_config.api_key.clone(), + ctx.client.clone(), + ctx.router_config.worker_startup_timeout_secs, + ctx.router_config.worker_startup_check_interval_secs, + ctx.router_config.dp_aware, + ctx.router_config.api_key.clone(), )?; Ok(Box::new(router)) @@ -59,7 +62,7 @@ impl RouterFactory { prefill_policy_config: Option<&PolicyConfig>, decode_policy_config: Option<&PolicyConfig>, main_policy_config: &PolicyConfig, - router_config: &RouterConfig, + ctx: &Arc, ) -> Result, String> { // Create policies - use specific policies if provided, otherwise fall back to main policy let prefill_policy = @@ -67,14 +70,15 @@ impl RouterFactory { let decode_policy = PolicyFactory::create_from_config(decode_policy_config.unwrap_or(main_policy_config)); - // Create PD router with separate policies + // Create PD router with separate policies and client let router = PDRouter::new( prefill_urls.to_vec(), decode_urls.to_vec(), prefill_policy, decode_policy, - router_config.worker_startup_timeout_secs, - router_config.worker_startup_check_interval_secs, + ctx.client.clone(), + ctx.router_config.worker_startup_timeout_secs, + ctx.router_config.worker_startup_check_interval_secs, )?; Ok(Box::new(router)) diff --git a/sgl-router/src/routers/mod.rs b/sgl-router/src/routers/mod.rs index 21250d5f1..75f12c63b 100644 --- a/sgl-router/src/routers/mod.rs +++ b/sgl-router/src/routers/mod.rs @@ -7,7 +7,6 @@ use axum::{ http::{HeaderMap, StatusCode}, response::{IntoResponse, Response}, }; -use reqwest::Client; use std::fmt::Debug; use crate::openai_api_types::{ChatCompletionRequest, CompletionRequest, GenerateRequest}; @@ -46,32 +45,27 @@ pub trait RouterTrait: Send + Sync + Debug + WorkerManagement { fn as_any(&self) -> &dyn std::any::Any; /// Route a health check request - async fn health(&self, client: &Client, req: Request) -> Response; + async fn health(&self, req: Request) -> Response; /// Route a health generate request - async fn health_generate(&self, client: &Client, req: Request) -> Response; + async fn health_generate(&self, req: Request) -> Response; /// Get server information - async fn get_server_info(&self, client: &Client, req: Request) -> Response; + async fn get_server_info(&self, req: Request) -> Response; /// Get available models - async fn get_models(&self, client: &Client, req: Request) -> Response; + async fn get_models(&self, req: Request) -> Response; /// Get model information - async fn get_model_info(&self, client: &Client, req: Request) -> Response; + async fn get_model_info(&self, req: Request) -> Response; /// Route a generate request - async fn route_generate( - &self, - client: &Client, - headers: Option<&HeaderMap>, - body: &GenerateRequest, - ) -> Response; + async fn route_generate(&self, headers: Option<&HeaderMap>, body: &GenerateRequest) + -> Response; /// Route a chat completion request async fn route_chat( &self, - client: &Client, headers: Option<&HeaderMap>, body: &ChatCompletionRequest, ) -> Response; @@ -79,16 +73,15 @@ pub trait RouterTrait: Send + Sync + Debug + WorkerManagement { /// Route a completion request async fn route_completion( &self, - client: &Client, headers: Option<&HeaderMap>, body: &CompletionRequest, ) -> Response; /// Flush cache on all workers - async fn flush_cache(&self, client: &Client) -> Response; + async fn flush_cache(&self) -> Response; /// Get worker loads (for monitoring) - async fn get_worker_loads(&self, client: &Client) -> Response; + async fn get_worker_loads(&self) -> Response; /// Get router type name fn router_type(&self) -> &'static str; diff --git a/sgl-router/src/routers/pd_router.rs b/sgl-router/src/routers/pd_router.rs index d9cbf9bac..b799237a9 100644 --- a/sgl-router/src/routers/pd_router.rs +++ b/sgl-router/src/routers/pd_router.rs @@ -35,7 +35,7 @@ pub struct PDRouter { pub interval_secs: u64, pub worker_loads: Arc>>, pub load_monitor_handle: Option>>, - pub http_client: Client, + pub client: Client, _prefill_health_checker: Option, _decode_health_checker: Option, } @@ -177,6 +177,7 @@ impl PDRouter { decode_urls: Vec, prefill_policy: Arc, decode_policy: Arc, + client: Client, timeout_secs: u64, interval_secs: u64, ) -> Result { @@ -215,17 +216,11 @@ impl PDRouter { let (tx, rx) = tokio::sync::watch::channel(HashMap::new()); let worker_loads = Arc::new(rx); - // Create a shared HTTP client for all operations - let http_client = Client::builder() - .timeout(Duration::from_secs(timeout_secs)) - .build() - .map_err(|e| format!("Failed to create HTTP client: {}", e))?; - let load_monitor_handle = if prefill_policy.name() == "power_of_two" || decode_policy.name() == "power_of_two" { let monitor_urls = all_urls.clone(); let monitor_interval = interval_secs; - let monitor_client = http_client.clone(); + let monitor_client = client.clone(); let prefill_policy_clone = Arc::clone(&prefill_policy); let decode_policy_clone = Arc::clone(&decode_policy); @@ -264,7 +259,7 @@ impl PDRouter { interval_secs, worker_loads, load_monitor_handle, - http_client, + client, _prefill_health_checker: Some(prefill_health_checker), _decode_health_checker: Some(decode_health_checker), }) @@ -302,7 +297,6 @@ impl PDRouter { // Route a typed generate request pub async fn route_generate( &self, - client: &Client, headers: Option<&HeaderMap>, mut typed_req: GenerateReqInput, route: &str, @@ -371,7 +365,6 @@ impl PDRouter { // Execute dual dispatch self.execute_dual_dispatch( - client, headers, json_with_bootstrap, route, @@ -387,7 +380,6 @@ impl PDRouter { // Route a typed chat request pub async fn route_chat( &self, - client: &Client, headers: Option<&HeaderMap>, mut typed_req: ChatReqInput, route: &str, @@ -459,7 +451,6 @@ impl PDRouter { // Execute dual dispatch self.execute_dual_dispatch( - client, headers, json_with_bootstrap, route, @@ -475,7 +466,6 @@ impl PDRouter { // Route a completion request while preserving OpenAI format pub async fn route_completion( &self, - client: &Client, headers: Option<&HeaderMap>, mut typed_req: CompletionRequest, route: &str, @@ -540,7 +530,6 @@ impl PDRouter { // Execute dual dispatch self.execute_dual_dispatch( - client, headers, json_with_bootstrap, route, @@ -554,10 +543,8 @@ impl PDRouter { } // Execute the dual dispatch to prefill and decode servers - #[allow(clippy::too_many_arguments)] async fn execute_dual_dispatch( &self, - client: &Client, headers: Option<&HeaderMap>, json_request: Value, route: &str, @@ -571,11 +558,13 @@ impl PDRouter { let _guard = WorkerLoadGuard::new_multi(vec![prefill, decode]); // Build requests using .json() method - let mut prefill_request = client + let mut prefill_request = self + .client .post(api_path(prefill.url(), route)) .json(&json_request); - let mut decode_request = client + let mut decode_request = self + .client .post(api_path(decode.url(), route)) .json(&json_request); @@ -987,7 +976,7 @@ async fn get_worker_load(client: &reqwest::Client, worker_url: &str) -> Option Response { + pub async fn health_generate(&self) -> Response { // Test model generation capability by selecting a random pair and testing them // Note: This endpoint actually causes the model to generate tokens, so we only test one pair @@ -1005,11 +994,11 @@ impl PDRouter { // Test prefill server's health_generate let prefill_url = format!("{}/health_generate", prefill.url()); - let prefill_result = client.get(&prefill_url).send().await; + let prefill_result = self.client.get(&prefill_url).send().await; // Test decode server's health_generate let decode_url = format!("{}/health_generate", decode.url()); - let decode_result = client.get(&decode_url).send().await; + let decode_result = self.client.get(&decode_url).send().await; // Check results let mut errors = Vec::new(); @@ -1068,7 +1057,7 @@ impl PDRouter { } } - pub async fn get_server_info(&self, client: &reqwest::Client) -> Response { + pub async fn get_server_info(&self) -> Response { // Get info from the first decode server to match sglang's server info format let first_decode_url = if let Ok(workers) = self.decode_workers.read() { workers.first().map(|w| w.url().to_string()) @@ -1081,7 +1070,8 @@ impl PDRouter { }; if let Some(worker_url) = first_decode_url { - match client + match self + .client .get(format!("{}/get_server_info", worker_url)) .send() .await @@ -1130,7 +1120,7 @@ impl PDRouter { } } - pub async fn get_models(&self, client: &reqwest::Client, req: Request) -> Response { + pub async fn get_models(&self, req: Request) -> Response { // Extract headers first to avoid Send issues let headers = crate::routers::router::copy_request_headers(&req); @@ -1147,7 +1137,7 @@ impl PDRouter { if let Some(worker_url) = first_worker_url { // Send request directly without going through Router - let mut request_builder = client.get(format!("{}/v1/models", worker_url)); + let mut request_builder = self.client.get(format!("{}/v1/models", worker_url)); for (name, value) in headers { if name.to_lowercase() != "content-type" && name.to_lowercase() != "content-length" { @@ -1224,7 +1214,7 @@ impl PDRouter { .into_response() } - pub async fn get_model_info(&self, client: &reqwest::Client, req: Request) -> Response { + pub async fn get_model_info(&self, req: Request) -> Response { // Extract headers first to avoid Send issues let headers = crate::routers::router::copy_request_headers(&req); @@ -1241,7 +1231,7 @@ impl PDRouter { }; if let Some(worker_url) = first_worker_url { - let mut request_builder = client.get(format!("{}/get_model_info", worker_url)); + let mut request_builder = self.client.get(format!("{}/get_model_info", worker_url)); for (name, value) in headers { if name.to_lowercase() != "content-type" && name.to_lowercase() != "content-length" { @@ -1384,7 +1374,7 @@ impl RouterTrait for PDRouter { self } - async fn health(&self, _client: &Client, _req: Request) -> Response { + async fn health(&self, _req: Request) -> Response { // This is a server readiness check - checking if we have healthy workers // Workers handle their own health checks in the background let mut all_healthy = true; @@ -1417,68 +1407,65 @@ impl RouterTrait for PDRouter { } } - async fn health_generate(&self, client: &Client, _req: Request) -> Response { + async fn health_generate(&self, _req: Request) -> Response { // Use the existing PDRouter health_generate method - PDRouter::health_generate(self, client).await + PDRouter::health_generate(self).await } - async fn get_server_info(&self, client: &Client, _req: Request) -> Response { + async fn get_server_info(&self, _req: Request) -> Response { // Use the existing PDRouter get_server_info method - PDRouter::get_server_info(self, client).await + PDRouter::get_server_info(self).await } - async fn get_models(&self, client: &Client, req: Request) -> Response { + async fn get_models(&self, req: Request) -> Response { // Use the existing PDRouter get_models method - PDRouter::get_models(self, client, req).await + PDRouter::get_models(self, req).await } - async fn get_model_info(&self, client: &Client, req: Request) -> Response { + async fn get_model_info(&self, req: Request) -> Response { // Use the existing PDRouter get_model_info method - PDRouter::get_model_info(self, client, req).await + PDRouter::get_model_info(self, req).await } async fn route_generate( &self, - client: &Client, headers: Option<&HeaderMap>, body: &GenerateRequest, ) -> Response { // Convert OpenAI format to PD format let pd_req = body.clone().to_pd_request(); - PDRouter::route_generate(self, client, headers, pd_req, "/generate").await + PDRouter::route_generate(self, headers, pd_req, "/generate").await } async fn route_chat( &self, - client: &Client, headers: Option<&HeaderMap>, body: &ChatCompletionRequest, ) -> Response { // Convert OpenAI format to PD format let pd_req = body.clone().to_pd_request(); - PDRouter::route_chat(self, client, headers, pd_req, "/v1/chat/completions").await + PDRouter::route_chat(self, headers, pd_req, "/v1/chat/completions").await } async fn route_completion( &self, - client: &Client, headers: Option<&HeaderMap>, body: &CompletionRequest, ) -> Response { // Use the new method that preserves OpenAI format - PDRouter::route_completion(self, client, headers, body.clone(), "/v1/completions").await + PDRouter::route_completion(self, headers, body.clone(), "/v1/completions").await } - async fn flush_cache(&self, client: &Client) -> Response { + async fn flush_cache(&self) -> Response { // Use the existing PDRouter flush_cache method - PDRouter::flush_cache(self, client).await + PDRouter::flush_cache(self, &self.client).await } - async fn get_worker_loads(&self, client: &Client) -> Response { + async fn get_worker_loads(&self) -> Response { // Use the existing PDRouter get_loads method - PDRouter::get_loads(self, client).await + PDRouter::get_loads(self, &self.client).await } fn router_type(&self) -> &'static str { @@ -1570,7 +1557,7 @@ mod tests { interval_secs: 1, worker_loads: Arc::new(tokio::sync::watch::channel(HashMap::new()).1), load_monitor_handle: None, - http_client: reqwest::Client::new(), + client: Client::new(), _prefill_health_checker: None, _decode_health_checker: None, } @@ -1959,11 +1946,10 @@ mod tests { router.decode_workers.write().unwrap().push(decode_worker); // Test health endpoint - let client = reqwest::Client::new(); let http_req = axum::http::Request::builder() .body(axum::body::Body::empty()) .unwrap(); - let response = router.health(&client, http_req).await; + let response = router.health(http_req).await; assert_eq!(response.status(), 200); diff --git a/sgl-router/src/routers/router.rs b/sgl-router/src/routers/router.rs index 41277c17e..1a6ddeea4 100644 --- a/sgl-router/src/routers/router.rs +++ b/sgl-router/src/routers/router.rs @@ -34,6 +34,7 @@ pub fn copy_request_headers(req: &Request) -> Vec<(String, String)> { pub struct Router { workers: Arc>>>, policy: Arc, + client: Client, timeout_secs: u64, interval_secs: u64, dp_aware: bool, @@ -44,10 +45,11 @@ pub struct Router { } impl Router { - /// Create a new router with injected policy + /// Create a new router with injected policy and client pub fn new( worker_urls: Vec, policy: Arc, + client: Client, timeout_secs: u64, interval_secs: u64, dp_aware: bool, @@ -94,9 +96,17 @@ impl Router { let monitor_urls = worker_urls.clone(); let monitor_interval = interval_secs; let policy_clone = Arc::clone(&policy); + let client_clone = client.clone(); Some(Arc::new(tokio::spawn(async move { - Self::monitor_worker_loads(monitor_urls, tx, monitor_interval, policy_clone).await; + Self::monitor_worker_loads( + monitor_urls, + tx, + monitor_interval, + policy_clone, + client_clone, + ) + .await; }))) } else { None @@ -105,6 +115,7 @@ impl Router { Ok(Router { workers, policy, + client, timeout_secs, interval_secs, dp_aware, @@ -245,7 +256,7 @@ impl Router { } } - pub async fn send_health_check(&self, client: &Client, worker_url: &str) -> Response { + pub async fn send_health_check(&self, worker_url: &str) -> Response { let health_url = if self.dp_aware { // Need to extract the URL from "http://host:port@dp_rank" match Self::extract_dp_rank(worker_url) { @@ -263,7 +274,7 @@ impl Router { worker_url }; - let request_builder = client.get(format!("{}/health", health_url)); + let request_builder = self.client.get(format!("{}/health", health_url)); let response = match request_builder.send().await { Ok(res) => { @@ -305,17 +316,12 @@ impl Router { } // Helper method to proxy GET requests to the first available worker - async fn proxy_get_request( - &self, - client: &Client, - req: Request, - endpoint: &str, - ) -> Response { + async fn proxy_get_request(&self, req: Request, endpoint: &str) -> Response { let headers = copy_request_headers(&req); match self.select_first_worker() { Ok(worker_url) => { - let mut request_builder = client.get(format!("{}/{}", worker_url, endpoint)); + let mut request_builder = self.client.get(format!("{}/{}", worker_url, endpoint)); for (name, value) in headers { if name.to_lowercase() != "content-type" && name.to_lowercase() != "content-length" @@ -353,7 +359,6 @@ impl Router { T: crate::openai_api_types::GenerationRequest + serde::Serialize + Clone, >( &self, - client: &reqwest::Client, headers: Option<&HeaderMap>, typed_req: &T, route: &str, @@ -397,7 +402,6 @@ impl Router { // Send typed request directly let response = self .send_typed_request( - client, headers, typed_req, route, @@ -413,7 +417,7 @@ impl Router { return response; } else { // if the worker is healthy, it means the request is bad, so return the error response - let health_response = self.send_health_check(client, &worker_url).await; + let health_response = self.send_health_check(&worker_url).await; if health_response.status().is_success() { RouterMetrics::record_request_error(route, "request_failed"); return response; @@ -483,7 +487,6 @@ impl Router { // Send typed request directly without conversion async fn send_typed_request( &self, - client: &reqwest::Client, headers: Option<&HeaderMap>, typed_req: &T, route: &str, @@ -536,11 +539,11 @@ impl Router { .into_response(); } - client + self.client .post(format!("{}{}", worker_url_prefix, route)) .json(&json_val) } else { - client + self.client .post(format!("{}{}", worker_url, route)) .json(typed_req) // Use json() directly with typed request }; @@ -866,7 +869,7 @@ impl Router { } } - async fn get_worker_load(&self, client: &reqwest::Client, worker_url: &str) -> Option { + async fn get_worker_load(&self, worker_url: &str) -> Option { let worker_url = if self.dp_aware { // Need to extract the URL from "http://host:port@dp_rank" let (worker_url_prefix, _dp_rank) = match Self::extract_dp_rank(worker_url) { @@ -881,7 +884,12 @@ impl Router { worker_url }; - match client.get(&format!("{}/get_load", worker_url)).send().await { + match self + .client + .get(&format!("{}/get_load", worker_url)) + .send() + .await + { Ok(res) if res.status().is_success() => match res.bytes().await { Ok(bytes) => match serde_json::from_slice::(&bytes) { Ok(data) => data @@ -919,18 +927,8 @@ impl Router { tx: tokio::sync::watch::Sender>, interval_secs: u64, policy: Arc, + client: Client, ) { - let client = match reqwest::Client::builder() - .timeout(Duration::from_secs(5)) - .build() - { - Ok(c) => c, - Err(e) => { - error!("Failed to create HTTP client for load monitoring: {}", e); - return; - } - }; - let mut interval = tokio::time::interval(Duration::from_secs(interval_secs)); loop { @@ -1028,7 +1026,7 @@ impl RouterTrait for Router { self } - async fn health(&self, _client: &Client, _req: Request) -> Response { + async fn health(&self, _req: Request) -> Response { let workers = self.workers.read().unwrap(); let unhealthy_servers: Vec<_> = workers .iter() @@ -1047,53 +1045,49 @@ impl RouterTrait for Router { } } - async fn health_generate(&self, client: &Client, req: Request) -> Response { - self.proxy_get_request(client, req, "health_generate").await + async fn health_generate(&self, req: Request) -> Response { + self.proxy_get_request(req, "health_generate").await } - async fn get_server_info(&self, client: &Client, req: Request) -> Response { - self.proxy_get_request(client, req, "get_server_info").await + async fn get_server_info(&self, req: Request) -> Response { + self.proxy_get_request(req, "get_server_info").await } - async fn get_models(&self, client: &Client, req: Request) -> Response { - self.proxy_get_request(client, req, "v1/models").await + async fn get_models(&self, req: Request) -> Response { + self.proxy_get_request(req, "v1/models").await } - async fn get_model_info(&self, client: &Client, req: Request) -> Response { - self.proxy_get_request(client, req, "get_model_info").await + async fn get_model_info(&self, req: Request) -> Response { + self.proxy_get_request(req, "get_model_info").await } async fn route_generate( &self, - client: &Client, headers: Option<&HeaderMap>, body: &GenerateRequest, ) -> Response { - self.route_typed_request(client, headers, body, "/generate") - .await + self.route_typed_request(headers, body, "/generate").await } async fn route_chat( &self, - client: &Client, headers: Option<&HeaderMap>, body: &ChatCompletionRequest, ) -> Response { - self.route_typed_request(client, headers, body, "/v1/chat/completions") + self.route_typed_request(headers, body, "/v1/chat/completions") .await } async fn route_completion( &self, - client: &Client, headers: Option<&HeaderMap>, body: &CompletionRequest, ) -> Response { - self.route_typed_request(client, headers, body, "/v1/completions") + self.route_typed_request(headers, body, "/v1/completions") .await } - async fn flush_cache(&self, client: &Client) -> Response { + async fn flush_cache(&self) -> Response { // Get all worker URLs let worker_urls = self.get_worker_urls(); @@ -1117,7 +1111,7 @@ impl RouterTrait for Router { } else { worker_url }; - let request_builder = client.post(format!("{}/flush_cache", worker_url)); + let request_builder = self.client.post(format!("{}/flush_cache", worker_url)); tasks.push(request_builder.send()); } @@ -1142,13 +1136,13 @@ impl RouterTrait for Router { } } - async fn get_worker_loads(&self, client: &Client) -> Response { + async fn get_worker_loads(&self) -> Response { let urls = self.get_worker_urls(); let mut loads = Vec::new(); // Get loads from all workers for url in &urls { - let load = self.get_worker_load(client, url).await.unwrap_or(-1); + let load = self.get_worker_load(url).await.unwrap_or(-1); loads.push(serde_json::json!({ "worker": url, "load": load @@ -1215,6 +1209,7 @@ mod tests { interval_secs: 1, dp_aware: false, api_key: None, + client: Client::new(), _worker_loads: Arc::new(rx), _load_monitor_handle: None, _health_checker: None, diff --git a/sgl-router/src/server.rs b/sgl-router/src/server.rs index 0463f1f2a..b6027e70b 100644 --- a/sgl-router/src/server.rs +++ b/sgl-router/src/server.rs @@ -22,29 +22,34 @@ use tokio::spawn; use tracing::{error, info, warn, Level}; #[derive(Clone)] -pub struct AppState { - pub router: Arc, +pub struct AppContext { pub client: Client, - pub _concurrency_limiter: Arc, + pub router_config: RouterConfig, + pub concurrency_limiter: Arc, + // Future dependencies can be added here } -impl AppState { +impl AppContext { pub fn new( router_config: RouterConfig, client: Client, max_concurrent_requests: usize, - ) -> Result { - let router = RouterFactory::create_router(&router_config)?; - let router = Arc::from(router); + ) -> Self { let concurrency_limiter = Arc::new(tokio::sync::Semaphore::new(max_concurrent_requests)); - Ok(Self { - router, + Self { client, - _concurrency_limiter: concurrency_limiter, - }) + router_config, + concurrency_limiter, + } } } +#[derive(Clone)] +pub struct AppState { + pub router: Arc, + pub context: Arc, +} + // Fallback handler for unmatched routes async fn sink_handler() -> Response { StatusCode::NOT_FOUND.into_response() @@ -60,23 +65,23 @@ async fn readiness(State(state): State>) -> Response { } async fn health(State(state): State>, req: Request) -> Response { - state.router.health(&state.client, req).await + state.router.health(req).await } async fn health_generate(State(state): State>, req: Request) -> Response { - state.router.health_generate(&state.client, req).await + state.router.health_generate(req).await } async fn get_server_info(State(state): State>, req: Request) -> Response { - state.router.get_server_info(&state.client, req).await + state.router.get_server_info(req).await } async fn v1_models(State(state): State>, req: Request) -> Response { - state.router.get_models(&state.client, req).await + state.router.get_models(req).await } async fn get_model_info(State(state): State>, req: Request) -> Response { - state.router.get_model_info(&state.client, req).await + state.router.get_model_info(req).await } // Generation endpoints @@ -86,10 +91,7 @@ async fn generate( headers: http::HeaderMap, Json(body): Json, ) -> Response { - state - .router - .route_generate(&state.client, Some(&headers), &body) - .await + state.router.route_generate(Some(&headers), &body).await } async fn v1_chat_completions( @@ -97,10 +99,7 @@ async fn v1_chat_completions( headers: http::HeaderMap, Json(body): Json, ) -> Response { - state - .router - .route_chat(&state.client, Some(&headers), &body) - .await + state.router.route_chat(Some(&headers), &body).await } async fn v1_completions( @@ -108,10 +107,7 @@ async fn v1_completions( headers: http::HeaderMap, Json(body): Json, ) -> Response { - state - .router - .route_completion(&state.client, Some(&headers), &body) - .await + state.router.route_completion(Some(&headers), &body).await } // Worker management endpoints @@ -159,11 +155,11 @@ async fn remove_worker( } async fn flush_cache(State(state): State>, _req: Request) -> Response { - state.router.flush_cache(&state.client).await + state.router.flush_cache().await } async fn get_loads(State(state): State>, _req: Request) -> Response { - state.router.get_worker_loads(&state.client).await + state.router.get_worker_loads().await } pub struct ServerConfig { @@ -281,11 +277,21 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box } diff --git a/sgl-router/tests/api_endpoints_test.rs b/sgl-router/tests/api_endpoints_test.rs index 2626174ce..6beda2b7a 100644 --- a/sgl-router/tests/api_endpoints_test.rs +++ b/sgl-router/tests/api_endpoints_test.rs @@ -83,12 +83,12 @@ impl TestContext { .build() .unwrap(); - // Clone config for the closure - let config_clone = config.clone(); + // Create app context + let app_context = common::create_test_context(config.clone()); // Create router using sync factory in a blocking context let router = - tokio::task::spawn_blocking(move || RouterFactory::create_router(&config_clone)) + tokio::task::spawn_blocking(move || RouterFactory::create_router(&app_context)) .await .unwrap() .unwrap(); @@ -1433,9 +1433,12 @@ mod pd_mode_tests { cors_allowed_origins: vec![], }; + // Create app context + let app_context = common::create_test_context(config); + // Create router - this might fail due to health check issues let router_result = - tokio::task::spawn_blocking(move || RouterFactory::create_router(&config)) + tokio::task::spawn_blocking(move || RouterFactory::create_router(&app_context)) .await .unwrap(); diff --git a/sgl-router/tests/common/mod.rs b/sgl-router/tests/common/mod.rs index 436b57a6c..4ca499e84 100644 --- a/sgl-router/tests/common/mod.rs +++ b/sgl-router/tests/common/mod.rs @@ -1,2 +1,15 @@ pub mod mock_worker; pub mod test_app; + +use sglang_router_rs::config::RouterConfig; +use sglang_router_rs::server::AppContext; +use std::sync::Arc; + +/// Helper function to create AppContext for tests +pub fn create_test_context(config: RouterConfig) -> Arc { + Arc::new(AppContext::new( + config.clone(), + reqwest::Client::new(), + config.max_concurrent_requests, + )) +} diff --git a/sgl-router/tests/common/test_app.rs b/sgl-router/tests/common/test_app.rs index d4a001ce3..7c4cf76eb 100644 --- a/sgl-router/tests/common/test_app.rs +++ b/sgl-router/tests/common/test_app.rs @@ -3,7 +3,7 @@ use reqwest::Client; use sglang_router_rs::{ config::RouterConfig, routers::RouterTrait, - server::{build_app, AppState}, + server::{build_app, AppContext, AppState}, }; use std::sync::Arc; @@ -13,13 +13,17 @@ pub fn create_test_app( client: Client, router_config: &RouterConfig, ) -> Router { - // Create AppState with the test router + // Create AppContext + let app_context = Arc::new(AppContext::new( + router_config.clone(), + client, + router_config.max_concurrent_requests, + )); + + // Create AppState with the test router and context let app_state = Arc::new(AppState { router, - client, - _concurrency_limiter: Arc::new(tokio::sync::Semaphore::new( - router_config.max_concurrent_requests, - )), + context: app_context, }); // Configure request ID headers (use defaults if not specified) diff --git a/sgl-router/tests/request_formats_test.rs b/sgl-router/tests/request_formats_test.rs index 320ad893e..a3cd12edb 100644 --- a/sgl-router/tests/request_formats_test.rs +++ b/sgl-router/tests/request_formats_test.rs @@ -53,10 +53,12 @@ impl TestContext { config.mode = RoutingMode::Regular { worker_urls }; - let router = tokio::task::spawn_blocking(move || RouterFactory::create_router(&config)) - .await - .unwrap() - .unwrap(); + let app_context = common::create_test_context(config); + let router = + tokio::task::spawn_blocking(move || RouterFactory::create_router(&app_context)) + .await + .unwrap() + .unwrap(); let router = Arc::from(router); if !workers.is_empty() { diff --git a/sgl-router/tests/streaming_tests.rs b/sgl-router/tests/streaming_tests.rs index b64aa9a4a..2ef2e0929 100644 --- a/sgl-router/tests/streaming_tests.rs +++ b/sgl-router/tests/streaming_tests.rs @@ -54,10 +54,12 @@ impl TestContext { config.mode = RoutingMode::Regular { worker_urls }; - let router = tokio::task::spawn_blocking(move || RouterFactory::create_router(&config)) - .await - .unwrap() - .unwrap(); + let app_context = common::create_test_context(config); + let router = + tokio::task::spawn_blocking(move || RouterFactory::create_router(&app_context)) + .await + .unwrap() + .unwrap(); let router = Arc::from(router); if !workers.is_empty() { diff --git a/sgl-router/tests/test_pd_routing.rs b/sgl-router/tests/test_pd_routing.rs index aea6df4d3..d0877eeb8 100644 --- a/sgl-router/tests/test_pd_routing.rs +++ b/sgl-router/tests/test_pd_routing.rs @@ -181,7 +181,10 @@ mod test_pd_routing { }; // Router creation will fail due to health checks, but config should be valid - let result = RouterFactory::create_router(&config); + let app_context = + sglang_router_rs::server::AppContext::new(config, reqwest::Client::new(), 64); + let app_context = std::sync::Arc::new(app_context); + let result = RouterFactory::create_router(&app_context); assert!(result.is_err()); let error_msg = result.unwrap_err(); // Error should be about health/timeout, not configuration