From e7387035476eb2c57fd49608066abf2e5f7551ac Mon Sep 17 00:00:00 2001 From: Simo Lin Date: Wed, 24 Sep 2025 22:13:31 -0400 Subject: [PATCH] [router] consolidate worker get loads (#10880) --- sgl-router/src/core/worker_manager.rs | 98 +++++++++++++- sgl-router/src/protocols/worker_spec.rs | 25 ++++ sgl-router/src/routers/grpc/pd_router.rs | 4 - sgl-router/src/routers/grpc/router.rs | 4 - sgl-router/src/routers/http/openai_router.rs | 8 -- sgl-router/src/routers/http/pd_router.rs | 111 +-------------- sgl-router/src/routers/http/router.rs | 135 +------------------ sgl-router/src/routers/mod.rs | 3 - sgl-router/src/routers/router_manager.rs | 24 ---- sgl-router/src/server.rs | 25 +++- 10 files changed, 157 insertions(+), 280 deletions(-) diff --git a/sgl-router/src/core/worker_manager.rs b/sgl-router/src/core/worker_manager.rs index b5f4da12c..9dde2aca8 100644 --- a/sgl-router/src/core/worker_manager.rs +++ b/sgl-router/src/core/worker_manager.rs @@ -12,7 +12,9 @@ use crate::core::{ Worker, WorkerFactory, WorkerRegistry, WorkerType, }; use crate::policies::PolicyRegistry; -use crate::protocols::worker_spec::{FlushCacheResult, WorkerConfigRequest}; +use crate::protocols::worker_spec::{ + FlushCacheResult, WorkerConfigRequest, WorkerLoadInfo, WorkerLoadsResult, +}; use crate::server::AppContext; use futures::future; use once_cell::sync::Lazy; @@ -1079,6 +1081,100 @@ impl WorkerManager { message, }) } + pub async fn get_worker_load( + url: &str, + api_key: Option<&str>, + client: &reqwest::Client, + ) -> Option { + let load_url = format!("{}/get_load", url); + let mut request = client.get(&load_url); + + if let Some(key) = api_key { + request = request.bearer_auth(key); + } + + match request.send().await { + Ok(response) if response.status().is_success() => { + match response.json::().await { + Ok(json) => { + if let Some(load) = json.get("load").and_then(|v| v.as_i64()) { + debug!("Worker {} load: {}", url, load); + Some(load as isize) + } else { + warn!("Invalid load response from {}: {:?}", url, json); + None + } + } + Err(e) => { + warn!("Failed to parse load response from {}: {}", url, e); + None + } + } + } + Ok(response) => { + warn!( + "Failed to get load from {}: HTTP {}", + url, + response.status() + ); + None + } + Err(e) => { + warn!("Failed to connect to {} for load check: {}", url, e); + None + } + } + } + + pub async fn get_all_worker_loads( + worker_registry: &WorkerRegistry, + client: &reqwest::Client, + ) -> WorkerLoadsResult { + let workers = worker_registry.get_all(); + let total_workers = workers.len(); + + // Prepare tasks for parallel execution + let mut tasks = Vec::new(); + for worker in &workers { + let url = worker.url().to_string(); + let api_key = worker.api_key().clone(); + let worker_type = match worker.worker_type() { + WorkerType::Regular => None, + WorkerType::Prefill { .. } => Some("prefill".to_string()), + WorkerType::Decode => Some("decode".to_string()), + }; + let is_http = matches!(worker.connection_mode(), ConnectionMode::Http); + let client = client.clone(); + + tasks.push(async move { + let load = if is_http { + Self::get_worker_load(&url, api_key.as_deref(), &client) + .await + .unwrap_or(-1) + } else { + -1 + }; + + WorkerLoadInfo { + worker: url, + worker_type, + load, + } + }); + } + + let loads = futures::future::join_all(tasks).await; + + let successful = loads.iter().filter(|l| l.load >= 0).count(); + let failed = loads.iter().filter(|l| l.load < 0).count(); + + WorkerLoadsResult { + loads, + total_workers, + successful, + failed, + } + } } #[cfg(test)] diff --git a/sgl-router/src/protocols/worker_spec.rs b/sgl-router/src/protocols/worker_spec.rs index 9ef35ba4c..e76a8f64e 100644 --- a/sgl-router/src/protocols/worker_spec.rs +++ b/sgl-router/src/protocols/worker_spec.rs @@ -215,3 +215,28 @@ pub struct FlushCacheResult { /// Human-readable summary message pub message: String, } + +/// Result from getting worker loads +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct WorkerLoadsResult { + /// Worker URL and load pairs + pub loads: Vec, + /// Total number of workers + pub total_workers: usize, + /// Number of workers with successful load fetches + pub successful: usize, + /// Number of workers with failed load fetches + pub failed: usize, +} + +/// Individual worker load information +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct WorkerLoadInfo { + /// Worker URL + pub worker: String, + /// Worker type (regular, prefill, decode) + #[serde(skip_serializing_if = "Option::is_none")] + pub worker_type: Option, + /// Current load (-1 indicates failure to fetch) + pub load: isize, +} diff --git a/sgl-router/src/routers/grpc/pd_router.rs b/sgl-router/src/routers/grpc/pd_router.rs index e4fe54e5f..a60744518 100644 --- a/sgl-router/src/routers/grpc/pd_router.rs +++ b/sgl-router/src/routers/grpc/pd_router.rs @@ -340,10 +340,6 @@ impl RouterTrait for GrpcPDRouter { (StatusCode::NOT_IMPLEMENTED).into_response() } - async fn get_worker_loads(&self) -> Response { - (StatusCode::NOT_IMPLEMENTED).into_response() - } - fn router_type(&self) -> &'static str { "grpc_pd" } diff --git a/sgl-router/src/routers/grpc/router.rs b/sgl-router/src/routers/grpc/router.rs index a229b891e..5266bf913 100644 --- a/sgl-router/src/routers/grpc/router.rs +++ b/sgl-router/src/routers/grpc/router.rs @@ -787,10 +787,6 @@ impl RouterTrait for GrpcRouter { (StatusCode::NOT_IMPLEMENTED).into_response() } - async fn get_worker_loads(&self) -> Response { - (StatusCode::NOT_IMPLEMENTED).into_response() - } - fn router_type(&self) -> &'static str { "grpc" } diff --git a/sgl-router/src/routers/http/openai_router.rs b/sgl-router/src/routers/http/openai_router.rs index 4961afeca..035af37dc 100644 --- a/sgl-router/src/routers/http/openai_router.rs +++ b/sgl-router/src/routers/http/openai_router.rs @@ -1296,14 +1296,6 @@ impl super::super::RouterTrait for OpenAIRouter { } } - async fn get_worker_loads(&self) -> Response { - ( - StatusCode::FORBIDDEN, - "get_worker_loads not supported for OpenAI router", - ) - .into_response() - } - fn router_type(&self) -> &'static str { "openai" } diff --git a/sgl-router/src/routers/http/pd_router.rs b/sgl-router/src/routers/http/pd_router.rs index 155fd600a..f655c7923 100644 --- a/sgl-router/src/routers/http/pd_router.rs +++ b/sgl-router/src/routers/http/pd_router.rs @@ -1,8 +1,8 @@ use super::pd_types::api_path; use crate::config::types::RetryConfig; use crate::core::{ - is_retryable_status, ConnectionMode, RetryExecutor, Worker, WorkerLoadGuard, WorkerRegistry, - WorkerType, + is_retryable_status, ConnectionMode, RetryExecutor, Worker, WorkerLoadGuard, WorkerManager, + WorkerRegistry, WorkerType, }; use crate::metrics::RouterMetrics; use crate::policies::{LoadBalancingPolicy, PolicyRegistry}; @@ -18,7 +18,6 @@ use axum::{ extract::Request, http::{header::CONTENT_TYPE, HeaderMap, HeaderValue, StatusCode}, response::{IntoResponse, Response}, - Json, }; use futures_util::StreamExt; use reqwest::Client; @@ -53,26 +52,6 @@ struct PDRequestContext<'a> { } impl PDRouter { - fn _get_worker_url_and_key(&self, w: &Arc) -> (String, Option) { - (w.url().to_string(), w.api_key().clone()) - } - - fn get_prefill_worker_urls_with_api_key(&self) -> Vec<(String, Option)> { - self.worker_registry - .get_prefill_workers() - .iter() - .map(|w| self._get_worker_url_and_key(w)) - .collect() - } - - fn get_decode_worker_urls_with_api_key(&self) -> Vec<(String, Option)> { - self.worker_registry - .get_decode_workers() - .iter() - .map(|w| self._get_worker_url_and_key(w)) - .collect() - } - async fn proxy_to_first_prefill_worker( &self, endpoint: &str, @@ -749,7 +728,10 @@ impl PDRouter { let url = url.clone(); let api_key = api_key.clone(); async move { - let load = get_worker_load(&client, &url, &api_key).await.unwrap_or(0); + let load = + WorkerManager::get_worker_load(&url, api_key.as_deref(), &client) + .await + .unwrap_or(0); (url, load) } }) @@ -1083,49 +1065,6 @@ impl PDRouter { } } -// Helper functions - -async fn get_worker_load( - client: &Client, - worker_url: &str, - api_key: &Option, -) -> Option { - let mut req_builder = client.get(format!("{}/get_load", worker_url)); - if let Some(key) = api_key { - req_builder = req_builder.bearer_auth(key); - } - match req_builder.send().await { - Ok(res) if res.status().is_success() => match res.bytes().await { - Ok(bytes) => match serde_json::from_slice::(&bytes) { - Ok(data) => data - .get("load") - .and_then(|v| v.as_i64()) - .map(|v| v as isize), - Err(e) => { - debug!("Failed to parse load response from {}: {}", worker_url, e); - None - } - }, - Err(e) => { - debug!("Failed to read load response from {}: {}", worker_url, e); - None - } - }, - Ok(res) => { - debug!( - "Worker {} returned non-success status: {}", - worker_url, - res.status() - ); - None - } - Err(e) => { - debug!("Failed to get load from {}: {}", worker_url, e); - None - } - } -} - #[async_trait] impl RouterTrait for PDRouter { fn as_any(&self) -> &dyn std::any::Any { @@ -1418,44 +1357,6 @@ impl RouterTrait for PDRouter { self.execute_dual_dispatch(headers, body, context).await } - async fn get_worker_loads(&self) -> Response { - let mut loads = HashMap::new(); - let mut errors = Vec::new(); - - // Process prefill workers - let prefill_urls_with_key = self.get_prefill_worker_urls_with_api_key(); - for (worker_url, api_key) in prefill_urls_with_key { - match get_worker_load(&self.client, &worker_url, &api_key).await { - Some(load) => { - loads.insert(format!("prefill_{}", worker_url), load); - } - None => { - errors.push(format!("Failed to get load from prefill {}", worker_url)); - } - } - } - - // Process decode workers - let decode_urls_with_key = self.get_decode_worker_urls_with_api_key(); - for (worker_url, api_key) in decode_urls_with_key { - match get_worker_load(&self.client, &worker_url, &api_key).await { - Some(load) => { - loads.insert(format!("decode_{}", worker_url), load); - } - None => { - errors.push(format!("Failed to get load from decode {}", worker_url)); - } - } - } - - let response_data = serde_json::json!({ - "loads": loads, - "errors": errors - }); - - (StatusCode::OK, Json(response_data)).into_response() - } - fn router_type(&self) -> &'static str { "pd" } diff --git a/sgl-router/src/routers/http/router.rs b/sgl-router/src/routers/http/router.rs index c98dd2a4c..8be3490fb 100644 --- a/sgl-router/src/routers/http/router.rs +++ b/sgl-router/src/routers/http/router.rs @@ -1,6 +1,7 @@ use crate::config::types::RetryConfig; use crate::core::{ - is_retryable_status, ConnectionMode, RetryExecutor, Worker, WorkerRegistry, WorkerType, + is_retryable_status, ConnectionMode, RetryExecutor, Worker, WorkerManager, WorkerRegistry, + WorkerType, }; use crate::metrics::RouterMetrics; use crate::policies::{LoadBalancingPolicy, PolicyRegistry}; @@ -660,58 +661,6 @@ impl Router { } } - async fn get_worker_load(&self, worker_url: &str, api_key: &Option) -> Option { - let worker_url = if self.dp_aware { - // Need to extract the URL from "http://host:port@dp_rank" - let (worker_url_prefix, _dp_rank) = match Self::extract_dp_rank(worker_url) { - Ok(tup) => tup, - Err(e) => { - error!("Failed to extract dp_rank: {}", e); - return None; - } - }; - worker_url_prefix - } else { - worker_url - }; - - let mut req_builder = self.client.get(format!("{}/get_load", worker_url)); - if let Some(key) = api_key { - req_builder = req_builder.bearer_auth(key); - } - - match req_builder.send().await { - Ok(res) if res.status().is_success() => match res.bytes().await { - Ok(bytes) => match serde_json::from_slice::(&bytes) { - Ok(data) => data - .get("load") - .and_then(|v| v.as_i64()) - .map(|v| v as isize), - Err(e) => { - debug!("Failed to parse load response from {}: {}", worker_url, e); - None - } - }, - Err(e) => { - debug!("Failed to read load response from {}: {}", worker_url, e); - None - } - }, - Ok(res) => { - debug!( - "Worker {} returned non-success status: {}", - worker_url, - res.status() - ); - None - } - Err(e) => { - debug!("Failed to get load from {}: {}", worker_url, e); - None - } - } - } - // Background task to monitor worker loads async fn monitor_worker_loads( worker_urls: Vec, @@ -728,7 +677,10 @@ impl Router { let mut loads = HashMap::new(); for (url, api_key) in worker_urls.iter().zip(worker_api_keys.iter()) { - if let Some(load) = Self::get_worker_load_static(&client, url, api_key).await { + // Use WorkerManager for consistent load fetching + if let Some(load) = + WorkerManager::get_worker_load(url, api_key.as_deref(), &client).await + { loads.insert(url.clone(), load); } } @@ -745,62 +697,6 @@ impl Router { } } - // Static version of get_worker_load for use in monitoring task - async fn get_worker_load_static( - client: &Client, - worker_url: &str, - api_key: &Option, - ) -> Option { - let worker_url = if worker_url.contains("@") { - // Need to extract the URL from "http://host:port@dp_rank" - let (worker_url_prefix, _dp_rank) = match Self::extract_dp_rank(worker_url) { - Ok(tup) => tup, - Err(e) => { - debug!("Failed to extract dp_rank: {}", e); - return None; - } - }; - worker_url_prefix - } else { - worker_url - }; - - let mut req_builder = client.get(format!("{}/get_load", worker_url)); - if let Some(key) = api_key { - req_builder = req_builder.bearer_auth(key); - } - match req_builder.send().await { - Ok(res) if res.status().is_success() => match res.bytes().await { - Ok(bytes) => match serde_json::from_slice::(&bytes) { - Ok(data) => data - .get("load") - .and_then(|v| v.as_i64()) - .map(|v| v as isize), - Err(e) => { - debug!("Failed to parse load response from {}: {}", worker_url, e); - None - } - }, - Err(e) => { - debug!("Failed to read load response from {}: {}", worker_url, e); - None - } - }, - Ok(res) => { - debug!( - "Worker {} returned non-success status: {}", - worker_url, - res.status() - ); - None - } - Err(e) => { - debug!("Failed to get load from {}: {}", worker_url, e); - None - } - } - } - async fn build_rerank_response( req: &RerankRequest, response: Response, @@ -953,25 +849,6 @@ impl RouterTrait for Router { } } - async fn get_worker_loads(&self) -> Response { - let urls_with_key = self.worker_registry.get_all_urls_with_api_key(); - let mut loads = Vec::new(); - - // Get loads from all workers - for (url, api_key) in &urls_with_key { - let load = self.get_worker_load(url, api_key).await.unwrap_or(-1); - loads.push(serde_json::json!({ - "worker": url, - "load": load - })); - } - - Json(serde_json::json!({ - "workers": loads - })) - .into_response() - } - fn router_type(&self) -> &'static str { "regular" } diff --git a/sgl-router/src/routers/mod.rs b/sgl-router/src/routers/mod.rs index 656d7979f..212afbfcf 100644 --- a/sgl-router/src/routers/mod.rs +++ b/sgl-router/src/routers/mod.rs @@ -126,9 +126,6 @@ pub trait RouterTrait: Send + Sync + Debug { model_id: Option<&str>, ) -> Response; - /// Get worker loads (for monitoring) - async fn get_worker_loads(&self) -> Response; - /// Get router type name fn router_type(&self) -> &'static str; diff --git a/sgl-router/src/routers/router_manager.rs b/sgl-router/src/routers/router_manager.rs index 74f902444..5ad875212 100644 --- a/sgl-router/src/routers/router_manager.rs +++ b/sgl-router/src/routers/router_manager.rs @@ -508,30 +508,6 @@ impl RouterTrait for RouterManager { } } - async fn get_worker_loads(&self) -> Response { - let workers = self.worker_registry.get_all(); - let loads: Vec = workers - .iter() - .map(|w| { - serde_json::json!({ - "url": w.url(), - "model": w.model_id(), - "load": w.load(), - "is_healthy": w.is_healthy() - }) - }) - .collect(); - - ( - StatusCode::OK, - serde_json::json!({ - "workers": loads - }) - .to_string(), - ) - .into_response() - } - fn router_type(&self) -> &'static str { "manager" } diff --git a/sgl-router/src/server.rs b/sgl-router/src/server.rs index 300057741..7b9a5dd4a 100644 --- a/sgl-router/src/server.rs +++ b/sgl-router/src/server.rs @@ -28,7 +28,7 @@ use axum::{ }; use reqwest::Client; use serde::Deserialize; -use serde_json::json; +use serde_json::{json, Value}; use std::{ sync::atomic::{AtomicBool, Ordering}, sync::Arc, @@ -400,7 +400,28 @@ async fn flush_cache(State(state): State>, _req: Request) -> Respo } async fn get_loads(State(state): State>, _req: Request) -> Response { - state.router.get_worker_loads().await + let result = + WorkerManager::get_all_worker_loads(&state.context.worker_registry, &state.context.client) + .await; + + let loads: Vec = result + .loads + .iter() + .map(|info| { + json!({ + "worker": &info.worker, + "load": info.load + }) + }) + .collect(); + + ( + StatusCode::OK, + Json(json!({ + "workers": loads + })), + ) + .into_response() } async fn create_worker(