From 38907fe639047fa21dfa22eadbeb7512b1ecd053 Mon Sep 17 00:00:00 2001 From: Simo Lin Date: Mon, 11 Aug 2025 13:32:31 -0700 Subject: [PATCH] refactor(pd-router): extract common patterns to reduce code duplication (#9081) --- sgl-router/src/routers/pd_router.rs | 458 +++++++++++----------------- 1 file changed, 184 insertions(+), 274 deletions(-) diff --git a/sgl-router/src/routers/pd_router.rs b/sgl-router/src/routers/pd_router.rs index cd36bb5cc..b0347e59f 100644 --- a/sgl-router/src/routers/pd_router.rs +++ b/sgl-router/src/routers/pd_router.rs @@ -72,6 +72,138 @@ impl PDRouter { }) } + // Generic helper for processing all workers with an endpoint + async fn process_workers( + &self, + workers: &RwLock>>, + worker_type: &str, + endpoint: &str, + ) -> (Vec, Vec) { + let mut results = Vec::new(); + let mut errors = Vec::new(); + + // Get worker URLs first to avoid holding lock across await + let urls = match workers.read() { + Ok(workers) => workers + .iter() + .map(|w| w.url().to_string()) + .collect::>(), + Err(_) => { + errors.push(format!("Failed to access {} workers", worker_type)); + Vec::new() + } + }; + + // Process each worker + for worker_url in urls { + let url = format!("{}/{}", worker_url, endpoint); + match self.client.post(&url).send().await { + Ok(res) if res.status().is_success() => { + results.push(format!("{} {}: OK", worker_type, worker_url)); + } + Ok(res) => { + errors.push(format!( + "{} {} returned status: {}", + worker_type, + worker_url, + res.status() + )); + } + Err(e) => { + errors.push(format!("{} {} error: {}", worker_type, worker_url, e)); + } + } + } + + (results, errors) + } + + // Helper to get worker URLs from a worker collection + fn get_worker_urls( + workers: &RwLock>>, + worker_type: &str, + ) -> Result, String> { + workers + .read() + .map(|workers| { + workers + .iter() + .map(|w| w.url().to_string()) + .collect::>() + }) + .map_err(|_| format!("Failed to access {} workers", worker_type)) + } + + // Generic helper for proxying requests to the first worker + async fn proxy_to_first_worker( + &self, + workers: &RwLock>>, + endpoint: &str, + worker_type: &str, + headers: Option>, + ) -> Response { + // Get first worker URL to avoid holding lock across await + let first_worker_url = match workers.read() { + Ok(workers) => workers.first().map(|w| w.url().to_string()), + Err(_) => { + return ( + StatusCode::INTERNAL_SERVER_ERROR, + format!("Failed to access {} workers", worker_type), + ) + .into_response(); + } + }; + + if let Some(worker_url) = first_worker_url { + let url = format!("{}/{}", worker_url, endpoint); + let mut request_builder = self.client.get(&url); + + // Add headers if provided + if let Some(headers) = headers { + for (name, value) in headers { + request_builder = request_builder.header(name, value); + } + } + + match request_builder.send().await { + Ok(res) if res.status().is_success() => match res.bytes().await { + Ok(body) => (StatusCode::OK, body).into_response(), + Err(e) => { + error!("Failed to read response body: {}", e); + ( + StatusCode::INTERNAL_SERVER_ERROR, + format!("Failed to read response body: {}", e), + ) + .into_response() + } + }, + Ok(res) => { + let status = StatusCode::from_u16(res.status().as_u16()) + .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR); + ( + status, + format!("{} server returned status: {}", worker_type, res.status()), + ) + .into_response() + } + Err(e) => { + error!("Failed to proxy request to {} server: {}", worker_type, e); + ( + StatusCode::INTERNAL_SERVER_ERROR, + format!("Failed to proxy request: {}", e), + ) + .into_response() + } + } + } else { + ( + StatusCode::SERVICE_UNAVAILABLE, + format!("No {} servers available", worker_type), + ) + .into_response() + } + } + pub async fn add_prefill_server( &self, url: String, @@ -1384,191 +1516,32 @@ impl RouterTrait for PDRouter { async fn get_server_info(&self, _req: Request) -> Response { // Get info from the first decode server to match sglang's server info format - let first_decode_url = if let Ok(workers) = self.decode_workers.read() { - workers.first().map(|w| w.url().to_string()) - } else { - return ( - StatusCode::INTERNAL_SERVER_ERROR, - "Failed to access decode workers", - ) - .into_response(); - }; - - if let Some(worker_url) = first_decode_url { - match self - .client - .get(format!("{}/get_server_info", worker_url)) - .send() - .await - { - Ok(res) if res.status().is_success() => { - match res.json::().await { - Ok(info) => { - // The decode server should already return the proper format - // with tokenizer_path and other fields that bench_one_batch_server.py expects - Json(info).into_response() - } - Err(e) => { - error!("Failed to parse server info: {}", e); - ( - StatusCode::INTERNAL_SERVER_ERROR, - format!("Failed to parse server info: {}", e), - ) - .into_response() - } - } - } - Ok(res) => { - let status = StatusCode::from_u16(res.status().as_u16()) - .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR); - ( - status, - format!("Decode server returned status: {}", res.status()), - ) - .into_response() - } - Err(e) => { - error!("Failed to get server info: {}", e); - ( - StatusCode::INTERNAL_SERVER_ERROR, - format!("Failed to get server info: {}", e), - ) - .into_response() - } - } - } else { - ( - StatusCode::SERVICE_UNAVAILABLE, - "No decode servers available", - ) - .into_response() - } + // Note: We use decode workers for server info to match expected format + self.proxy_to_first_worker(&self.decode_workers, "get_server_info", "decode", None) + .await } async fn get_models(&self, req: Request) -> Response { // Extract headers first to avoid Send issues let headers = crate::routers::router::copy_request_headers(&req); - // Get first prefill worker URL to avoid holding lock across await - let first_worker_url = if let Ok(workers) = self.prefill_workers.read() { - workers.first().map(|w| w.url().to_string()) - } else { - return ( - StatusCode::INTERNAL_SERVER_ERROR, - "Failed to access prefill workers", - ) - .into_response(); - }; - - if let Some(worker_url) = first_worker_url { - let url = format!("{}/v1/models", worker_url); - let mut request_builder = self.client.get(&url); - - // Add headers - for (name, value) in headers { - request_builder = request_builder.header(name, value); - } - - match request_builder.send().await { - Ok(res) if res.status().is_success() => match res.bytes().await { - Ok(body) => (StatusCode::OK, body).into_response(), - Err(e) => { - error!("Failed to read response body: {}", e); - ( - StatusCode::INTERNAL_SERVER_ERROR, - format!("Failed to read response body: {}", e), - ) - .into_response() - } - }, - Ok(res) => { - let status = StatusCode::from_u16(res.status().as_u16()) - .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR); - ( - status, - format!("Prefill server returned status: {}", res.status()), - ) - .into_response() - } - Err(e) => { - error!("Failed to get models: {}", e); - ( - StatusCode::INTERNAL_SERVER_ERROR, - format!("Failed to get models: {}", e), - ) - .into_response() - } - } - } else { - ( - StatusCode::SERVICE_UNAVAILABLE, - "No prefill servers available", - ) - .into_response() - } + // Proxy to first prefill worker + self.proxy_to_first_worker(&self.prefill_workers, "v1/models", "prefill", Some(headers)) + .await } async fn get_model_info(&self, req: Request) -> Response { // Extract headers first to avoid Send issues let headers = crate::routers::router::copy_request_headers(&req); - // Get first prefill worker URL to avoid holding lock across await - let first_worker_url = if let Ok(workers) = self.prefill_workers.read() { - workers.first().map(|w| w.url().to_string()) - } else { - return ( - StatusCode::INTERNAL_SERVER_ERROR, - "Failed to access prefill workers", - ) - .into_response(); - }; - - if let Some(worker_url) = first_worker_url { - let url = format!("{}/get_model_info", worker_url); - let mut request_builder = self.client.get(&url); - - // Add headers - for (name, value) in headers { - request_builder = request_builder.header(name, value); - } - - match request_builder.send().await { - Ok(res) if res.status().is_success() => match res.bytes().await { - Ok(body) => (StatusCode::OK, body).into_response(), - Err(e) => { - error!("Failed to read response body: {}", e); - ( - StatusCode::INTERNAL_SERVER_ERROR, - format!("Failed to read response body: {}", e), - ) - .into_response() - } - }, - Ok(res) => { - let status = StatusCode::from_u16(res.status().as_u16()) - .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR); - ( - status, - format!("Prefill server returned status: {}", res.status()), - ) - .into_response() - } - Err(e) => { - error!("Failed to get model info: {}", e); - ( - StatusCode::INTERNAL_SERVER_ERROR, - format!("Failed to get model info: {}", e), - ) - .into_response() - } - } - } else { - ( - StatusCode::SERVICE_UNAVAILABLE, - "No prefill servers available", - ) - .into_response() - } + // Proxy to first prefill worker + self.proxy_to_first_worker( + &self.prefill_workers, + "get_model_info", + "prefill", + Some(headers), + ) + .await } async fn route_generate( @@ -1692,70 +1665,19 @@ impl RouterTrait for PDRouter { } async fn flush_cache(&self) -> Response { - let mut results = Vec::new(); - let mut errors = Vec::new(); + // Process both prefill and decode workers + let (prefill_results, prefill_errors) = self + .process_workers(&self.prefill_workers, "Prefill", "flush_cache") + .await; + let (decode_results, decode_errors) = self + .process_workers(&self.decode_workers, "Decode", "flush_cache") + .await; - // Get prefill worker URLs first to avoid holding lock across await - let prefill_urls = if let Ok(workers) = self.prefill_workers.read() { - workers - .iter() - .map(|w| w.url().to_string()) - .collect::>() - } else { - errors.push("Failed to access prefill workers".to_string()); - Vec::new() - }; - - // Flush prefill workers - for worker_url in prefill_urls { - let url = format!("{}/flush_cache", worker_url); - match self.client.post(&url).send().await { - Ok(res) if res.status().is_success() => { - results.push(format!("Prefill {}: OK", worker_url)); - } - Ok(res) => { - errors.push(format!( - "Prefill {} returned status: {}", - worker_url, - res.status() - )); - } - Err(e) => { - errors.push(format!("Prefill {} error: {}", worker_url, e)); - } - } - } - - // Get decode worker URLs first to avoid holding lock across await - let decode_urls = if let Ok(workers) = self.decode_workers.read() { - workers - .iter() - .map(|w| w.url().to_string()) - .collect::>() - } else { - errors.push("Failed to access decode workers".to_string()); - Vec::new() - }; - - // Flush decode workers - for worker_url in decode_urls { - let url = format!("{}/flush_cache", worker_url); - match self.client.post(&url).send().await { - Ok(res) if res.status().is_success() => { - results.push(format!("Decode {}: OK", worker_url)); - } - Ok(res) => { - errors.push(format!( - "Decode {} returned status: {}", - worker_url, - res.status() - )); - } - Err(e) => { - errors.push(format!("Decode {} error: {}", worker_url, e)); - } - } - } + // Combine results and errors + let mut results = prefill_results; + results.extend(decode_results); + let mut errors = prefill_errors; + errors.extend(decode_errors); if errors.is_empty() { ( @@ -1779,50 +1701,38 @@ impl RouterTrait for PDRouter { let mut loads = HashMap::new(); let mut errors = Vec::new(); - // Get prefill worker URLs first to avoid holding lock across await - let prefill_urls = if let Ok(workers) = self.prefill_workers.read() { - workers - .iter() - .map(|w| w.url().to_string()) - .collect::>() - } else { - errors.push("Failed to access prefill workers".to_string()); - Vec::new() - }; - - // Get loads from prefill workers - for worker_url in prefill_urls { - match get_worker_load(&self.client, &worker_url).await { - Some(load) => { - loads.insert(format!("prefill_{}", worker_url), load); - } - None => { - errors.push(format!("Failed to get load from prefill {}", worker_url)); + // Process prefill workers + match Self::get_worker_urls(&self.prefill_workers, "prefill") { + Ok(urls) => { + for worker_url in urls { + match get_worker_load(&self.client, &worker_url).await { + Some(load) => { + loads.insert(format!("prefill_{}", worker_url), load); + } + None => { + errors.push(format!("Failed to get load from prefill {}", worker_url)); + } + } } } + Err(e) => errors.push(e), } - // Get decode worker URLs first to avoid holding lock across await - let decode_urls = if let Ok(workers) = self.decode_workers.read() { - workers - .iter() - .map(|w| w.url().to_string()) - .collect::>() - } else { - errors.push("Failed to access decode workers".to_string()); - Vec::new() - }; - - // Get loads from decode workers - for worker_url in decode_urls { - match get_worker_load(&self.client, &worker_url).await { - Some(load) => { - loads.insert(format!("decode_{}", worker_url), load); - } - None => { - errors.push(format!("Failed to get load from decode {}", worker_url)); + // Process decode workers + match Self::get_worker_urls(&self.decode_workers, "decode") { + Ok(urls) => { + for worker_url in urls { + match get_worker_load(&self.client, &worker_url).await { + Some(load) => { + loads.insert(format!("decode_{}", worker_url), load); + } + None => { + errors.push(format!("Failed to get load from decode {}", worker_url)); + } + } } } + Err(e) => errors.push(e), } let response_data = serde_json::json!({