diff --git a/sgl-router/src/config/types.rs b/sgl-router/src/config/types.rs index fabbebc26..a52e124ad 100644 --- a/sgl-router/src/config/types.rs +++ b/sgl-router/src/config/types.rs @@ -39,6 +39,8 @@ pub struct RouterConfig { pub max_concurrent_requests: usize, /// CORS allowed origins pub cors_allowed_origins: Vec, + /// Retry configuration + pub retry: RetryConfig, } /// Routing mode configuration @@ -182,6 +184,30 @@ impl Default for DiscoveryConfig { } } +/// Retry configuration for request handling +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct RetryConfig { + /// Maximum number of retry attempts + pub max_retries: u32, + /// Initial backoff delay in milliseconds + pub initial_backoff_ms: u64, + /// Maximum backoff delay in milliseconds + pub max_backoff_ms: u64, + /// Backoff multiplier for exponential backoff + pub backoff_multiplier: f32, +} + +impl Default for RetryConfig { + fn default() -> Self { + Self { + max_retries: 3, + initial_backoff_ms: 100, + max_backoff_ms: 10000, + backoff_multiplier: 2.0, + } + } +} + /// Metrics configuration #[derive(Debug, Clone, Serialize, Deserialize)] pub struct MetricsConfig { @@ -210,7 +236,7 @@ impl Default for RouterConfig { host: "127.0.0.1".to_string(), port: 3001, max_payload_size: 268_435_456, // 256MB - request_timeout_secs: 600, + request_timeout_secs: 3600, // 1 hour to match Python mini LB worker_startup_timeout_secs: 300, worker_startup_check_interval_secs: 10, dp_aware: false, @@ -222,6 +248,7 @@ impl Default for RouterConfig { request_id_headers: None, max_concurrent_requests: 64, cors_allowed_origins: vec![], + retry: RetryConfig::default(), } } } @@ -277,7 +304,7 @@ mod tests { assert_eq!(config.host, "127.0.0.1"); assert_eq!(config.port, 3001); assert_eq!(config.max_payload_size, 268_435_456); - assert_eq!(config.request_timeout_secs, 600); + assert_eq!(config.request_timeout_secs, 3600); assert_eq!(config.worker_startup_timeout_secs, 300); assert_eq!(config.worker_startup_check_interval_secs, 10); assert!(config.discovery.is_none()); @@ -332,6 +359,7 @@ mod tests { request_id_headers: None, max_concurrent_requests: 64, cors_allowed_origins: vec![], + retry: RetryConfig::default(), }; let json = serde_json::to_string(&config).unwrap(); @@ -759,6 +787,7 @@ mod tests { request_id_headers: None, max_concurrent_requests: 64, cors_allowed_origins: vec![], + retry: RetryConfig::default(), }; assert!(config.mode.is_pd_mode()); @@ -810,6 +839,7 @@ mod tests { request_id_headers: None, max_concurrent_requests: 64, cors_allowed_origins: vec![], + retry: RetryConfig::default(), }; assert!(!config.mode.is_pd_mode()); @@ -857,6 +887,7 @@ mod tests { request_id_headers: None, max_concurrent_requests: 64, cors_allowed_origins: vec![], + retry: RetryConfig::default(), }; assert!(config.has_service_discovery()); diff --git a/sgl-router/src/lib.rs b/sgl-router/src/lib.rs index a61ba7e45..290fbda9a 100644 --- a/sgl-router/src/lib.rs +++ b/sgl-router/src/lib.rs @@ -19,7 +19,7 @@ pub enum PolicyType { Random, RoundRobin, CacheAware, - PowerOfTwo, // Moved from PD-specific, now shared + PowerOfTwo, } #[pyclass] @@ -45,7 +45,6 @@ struct Router { selector: HashMap, service_discovery_port: u16, service_discovery_namespace: Option, - // PD service discovery fields prefill_selector: HashMap, decode_selector: HashMap, bootstrap_port_annotation: String, @@ -53,14 +52,11 @@ struct Router { prometheus_host: Option, request_timeout_secs: u64, request_id_headers: Option>, - // PD mode flag pd_disaggregation: bool, - // PD-specific fields (only used when pd_disaggregation is true) prefill_urls: Option)>>, decode_urls: Option>, prefill_policy: Option, decode_policy: Option, - // Additional server config fields max_concurrent_requests: usize, cors_allowed_origins: Vec, } @@ -150,6 +146,7 @@ impl Router { request_id_headers: self.request_id_headers.clone(), max_concurrent_requests: self.max_concurrent_requests, cors_allowed_origins: self.cors_allowed_origins.clone(), + retry: config::RetryConfig::default(), }) } } @@ -289,7 +286,6 @@ impl Router { check_interval: std::time::Duration::from_secs(60), port: self.service_discovery_port, namespace: self.service_discovery_namespace.clone(), - // PD mode configuration pd_mode: self.pd_disaggregation, prefill_selector: self.prefill_selector.clone(), decode_selector: self.decode_selector.clone(), diff --git a/sgl-router/src/routers/factory.rs b/sgl-router/src/routers/factory.rs index 8dc40527a..e67ce6650 100644 --- a/sgl-router/src/routers/factory.rs +++ b/sgl-router/src/routers/factory.rs @@ -50,6 +50,7 @@ impl RouterFactory { ctx.router_config.worker_startup_check_interval_secs, ctx.router_config.dp_aware, ctx.router_config.api_key.clone(), + ctx.router_config.retry.clone(), )?; Ok(Box::new(router)) @@ -79,6 +80,7 @@ impl RouterFactory { ctx.client.clone(), ctx.router_config.worker_startup_timeout_secs, ctx.router_config.worker_startup_check_interval_secs, + ctx.router_config.retry.clone(), )?; Ok(Box::new(router)) diff --git a/sgl-router/src/routers/pd_router.rs b/sgl-router/src/routers/pd_router.rs index b799237a9..dccb68e8f 100644 --- a/sgl-router/src/routers/pd_router.rs +++ b/sgl-router/src/routers/pd_router.rs @@ -3,6 +3,7 @@ use super::pd_types::{api_path, Bootstrap, ChatReqInput, GenerateReqInput, PDRouterError}; use super::request_adapter::ToPdRequest; +use crate::config::types::RetryConfig; use crate::core::{HealthChecker, Worker, WorkerFactory, WorkerLoadGuard}; use crate::metrics::RouterMetrics; use crate::openai_api_types::{ChatCompletionRequest, CompletionRequest, GenerateRequest}; @@ -16,6 +17,8 @@ use axum::{ Json, }; use futures_util::StreamExt; +use rand::Rng; +use reqwest::Client; use serde_json::Value; use std::collections::HashMap; use std::sync::{Arc, Mutex, RwLock}; @@ -36,6 +39,7 @@ pub struct PDRouter { pub worker_loads: Arc>>, pub load_monitor_handle: Option>>, pub client: Client, + pub retry_config: RetryConfig, _prefill_health_checker: Option, _decode_health_checker: Option, } @@ -180,6 +184,7 @@ impl PDRouter { client: Client, timeout_secs: u64, interval_secs: u64, + retry_config: RetryConfig, ) -> Result { // Convert URLs to Worker trait objects let prefill_workers: Vec> = prefill_urls @@ -260,6 +265,7 @@ impl PDRouter { worker_loads, load_monitor_handle, client, + retry_config, _prefill_health_checker: Some(prefill_health_checker), _decode_health_checker: Some(decode_health_checker), }) @@ -294,6 +300,38 @@ impl PDRouter { } } + // Helper to handle server selection errors + fn handle_server_selection_error(error: String) -> Response { + error!("Failed to select PD pair error={}", error); + RouterMetrics::record_pd_error("server_selection"); + ( + StatusCode::SERVICE_UNAVAILABLE, + format!("No available servers: {}", error), + ) + .into_response() + } + + // Helper to handle bootstrap injection errors + fn handle_bootstrap_error(error: impl std::fmt::Display) -> Response { + error!("Failed to add bootstrap info error={}", error); + RouterMetrics::record_pd_error("bootstrap_injection"); + ( + StatusCode::INTERNAL_SERVER_ERROR, + format!("Bootstrap injection failed: {}", error), + ) + .into_response() + } + + // Helper to handle serialization errors + fn handle_serialization_error(error: impl std::fmt::Display) -> Response { + error!("Failed to serialize request error={}", error); + ( + StatusCode::INTERNAL_SERVER_ERROR, + "Failed to serialize request", + ) + .into_response() + } + // Route a typed generate request pub async fn route_generate( &self, @@ -320,15 +358,7 @@ impl PDRouter { // Select servers let (prefill, decode) = match self.select_pd_pair(request_text).await { Ok(pair) => pair, - Err(e) => { - error!("Failed to select PD pair error={}", e); - RouterMetrics::record_pd_error("server_selection"); - return ( - StatusCode::SERVICE_UNAVAILABLE, - format!("No available servers: {}", e), - ) - .into_response(); - } + Err(e) => return Self::handle_server_selection_error(e), }; // Log routing decision @@ -341,26 +371,13 @@ impl PDRouter { // Add bootstrap info using the trait method if let Err(e) = typed_req.add_bootstrap_info(prefill.as_ref()) { - error!("Failed to add bootstrap info error={}", e); - RouterMetrics::record_pd_error("bootstrap_injection"); - return ( - StatusCode::INTERNAL_SERVER_ERROR, - format!("Bootstrap injection failed: {}", e), - ) - .into_response(); + return Self::handle_bootstrap_error(e); } // Convert to JSON after bootstrap injection let json_with_bootstrap = match serde_json::to_value(&typed_req) { Ok(json) => json, - Err(e) => { - error!("Failed to serialize request error={}", e); - return ( - StatusCode::INTERNAL_SERVER_ERROR, - "Failed to serialize request", - ) - .into_response(); - } + Err(e) => return Self::handle_serialization_error(e), }; // Execute dual dispatch @@ -406,15 +423,7 @@ impl PDRouter { // Select servers let (prefill, decode) = match self.select_pd_pair(request_text).await { Ok(pair) => pair, - Err(e) => { - error!("Failed to select PD pair error={}", e); - RouterMetrics::record_pd_error("server_selection"); - return ( - StatusCode::SERVICE_UNAVAILABLE, - format!("No available servers: {}", e), - ) - .into_response(); - } + Err(e) => return Self::handle_server_selection_error(e), }; // Log routing decision @@ -425,28 +434,14 @@ impl PDRouter { decode.url() ); - // Add bootstrap info using the trait method if let Err(e) = typed_req.add_bootstrap_info(prefill.as_ref()) { - error!("Failed to add bootstrap info error={}", e); - RouterMetrics::record_pd_error("bootstrap_injection"); - return ( - StatusCode::INTERNAL_SERVER_ERROR, - format!("Bootstrap injection failed: {}", e), - ) - .into_response(); + return Self::handle_bootstrap_error(e); } // Convert to JSON after bootstrap injection let json_with_bootstrap = match serde_json::to_value(&typed_req) { Ok(json) => json, - Err(e) => { - error!("Failed to serialize request error={}", e); - return ( - StatusCode::INTERNAL_SERVER_ERROR, - "Failed to serialize request", - ) - .into_response(); - } + Err(e) => return Self::handle_serialization_error(e), }; // Execute dual dispatch @@ -485,15 +480,7 @@ impl PDRouter { // Select servers let (prefill, decode) = match self.select_pd_pair(request_text).await { Ok(pair) => pair, - Err(e) => { - error!("Failed to select PD pair error={}", e); - RouterMetrics::record_pd_error("server_selection"); - return ( - StatusCode::SERVICE_UNAVAILABLE, - format!("No available servers: {}", e), - ) - .into_response(); - } + Err(e) => return Self::handle_server_selection_error(e), }; // Log routing decision @@ -504,28 +491,14 @@ impl PDRouter { decode.url() ); - // Add bootstrap info using the trait method if let Err(e) = typed_req.add_bootstrap_info(prefill.as_ref()) { - error!("Failed to add bootstrap info error={}", e); - RouterMetrics::record_pd_error("bootstrap_injection"); - return ( - StatusCode::INTERNAL_SERVER_ERROR, - format!("Bootstrap injection failed: {}", e), - ) - .into_response(); + return Self::handle_bootstrap_error(e); } // Convert to JSON after bootstrap injection let json_with_bootstrap = match serde_json::to_value(&typed_req) { Ok(json) => json, - Err(e) => { - error!("Failed to serialize request error={}", e); - return ( - StatusCode::INTERNAL_SERVER_ERROR, - "Failed to serialize request", - ) - .into_response(); - } + Err(e) => return Self::handle_serialization_error(e), }; // Execute dual dispatch @@ -542,7 +515,7 @@ impl PDRouter { .await } - // Execute the dual dispatch to prefill and decode servers + // Execute the dual dispatch to prefill and decode servers with retry logic async fn execute_dual_dispatch( &self, headers: Option<&HeaderMap>, @@ -553,38 +526,128 @@ impl PDRouter { is_stream: bool, return_logprob: bool, start_time: Instant, + ) -> Response { + for attempt in 0..self.retry_config.max_retries { + if attempt > 0 { + // Calculate backoff with exponential growth and jitter + let base_backoff = self.retry_config.initial_backoff_ms as f64 + * self + .retry_config + .backoff_multiplier + .powf((attempt - 1) as f32) as f64; + let backoff_ms = base_backoff.min(self.retry_config.max_backoff_ms as f64) as u64; + + // Add jitter to prevent thundering herd + let jitter = { + let mut rng = rand::thread_rng(); + rng.gen_range(0..backoff_ms / 2) + }; + let total_backoff = Duration::from_millis(backoff_ms + jitter); + + info!( + "Retrying request (attempt {}/{}) after {:?} backoff", + attempt + 1, + self.retry_config.max_retries, + total_backoff + ); + + tokio::time::sleep(total_backoff).await; + } + + debug!( + "Executing request attempt {}/{}", + attempt + 1, + self.retry_config.max_retries + ); + let result = self + .execute_dual_dispatch_inner( + headers, + json_request.clone(), + route, + prefill, + decode, + is_stream, + return_logprob, + start_time, + ) + .await; + + // Check if we should retry based on the response status + let status = result.status(); + debug!( + "Request attempt {} returned status: {}", + attempt + 1, + status + ); + + // Don't retry client errors (4xx) or successful responses + if status.is_client_error() || status.is_success() { + debug!( + "Returning response with status {} (no retry needed)", + status + ); + return result; + } + + // Check if this is the last attempt + if attempt == self.retry_config.max_retries - 1 { + warn!("Final attempt failed with status {}", status); + return result; + } + + // Log retry decision for retryable errors + if status.is_server_error() + || status == StatusCode::BAD_GATEWAY + || status == StatusCode::GATEWAY_TIMEOUT + { + warn!( + "Retryable error status: {} on attempt {}/{}. Will retry.", + status, + attempt + 1, + self.retry_config.max_retries + ); + } else { + // Don't retry other statuses + debug!("Status {} is not retryable, returning response", status); + return result; + } + } + + // This should never be reached due to the loop logic, but just in case + unreachable!("Retry loop completed without returning") + } + + // Inner implementation of dual dispatch (extracted for retry logic) + async fn execute_dual_dispatch_inner( + &self, + headers: Option<&HeaderMap>, + json_request: Value, + route: &str, + prefill: &dyn Worker, + decode: &dyn Worker, + is_stream: bool, + return_logprob: bool, + start_time: Instant, ) -> Response { // Update load tracking for both workers let _guard = WorkerLoadGuard::new_multi(vec![prefill, decode]); - // Build requests using .json() method - let mut prefill_request = self - .client - .post(api_path(prefill.url(), route)) - .json(&json_request); + // Build requests with headers + let prefill_request = + self.build_request_with_headers(prefill.url(), route, &json_request, headers); - let mut decode_request = self - .client - .post(api_path(decode.url(), route)) - .json(&json_request); - - // Copy headers from original request (excluding content-type and content-length which are set by .json()) - if let Some(headers) = headers { - for (name, value) in headers.iter() { - let name_str = name.as_str(); - if name_str != "content-type" && name_str != "content-length" { - // Skip headers with non-ASCII values - if value.to_str().is_ok() { - prefill_request = prefill_request.header(name, value); - decode_request = decode_request.header(name, value); - } - } - } - } + let decode_request = + self.build_request_with_headers(decode.url(), route, &json_request, headers); // Send both requests concurrently + debug!( + "Sending concurrent requests to prefill={} decode={}", + prefill.url(), + decode.url() + ); let (prefill_result, decode_result) = tokio::join!(prefill_request.send(), decode_request.send()); + debug!("Received responses from both servers"); // Update metrics let duration = start_time.elapsed(); @@ -593,11 +656,22 @@ impl PDRouter { RouterMetrics::record_pd_prefill_request(prefill.url()); RouterMetrics::record_pd_decode_request(decode.url()); + // Process prefill response + let (_prefill_status, prefill_body) = match self + .process_prefill_response(prefill_result, prefill.url(), return_logprob) + .await + { + Ok(result) => result, + Err(error_response) => return error_response, + }; + // Process decode response + debug!("Processing decode response"); match decode_result { Ok(res) => { let status = StatusCode::from_u16(res.status().as_u16()) .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR); + debug!("Decode response status: {}", status); if !status.is_success() { RouterMetrics::record_pd_decode_error(decode.url()); @@ -618,128 +692,36 @@ impl PDRouter { } } - // Log prefill errors for debugging - if let Err(e) = &prefill_result { - error!( - "Prefill server failed (non-critical) prefill_url={} error={}", - prefill.url(), - e - ); - RouterMetrics::record_pd_prefill_error(prefill.url()); - } - if is_stream { // Streaming response - if return_logprob { - // Get prefill logprobs for merging - let prefill_logprobs = - match prefill_result { - Ok(prefill_res) => match prefill_res.bytes().await { - Ok(body) => serde_json::from_slice::(&body) - .ok() - .and_then(|json| { - json.pointer("/meta_info/input_token_logprobs").cloned() - }), - Err(_) => None, - }, - Err(_) => None, - }; - - // Stream with logprob merging - let stream = res.bytes_stream(); - let (tx, rx) = tokio::sync::mpsc::unbounded_channel(); - - tokio::spawn(async move { - let mut stream = stream; - while let Some(chunk_result) = stream.next().await { - match chunk_result { - Ok(chunk) => { - // Try to merge logprobs - if let Ok(merged) = Self::merge_streaming_logprobs( - prefill_logprobs.clone(), - &chunk, - ) { - if tx.send(Ok(merged)).is_err() { - break; - } - } else { - if tx.send(Ok(chunk)).is_err() { - break; - } - } - } - Err(e) => { - let _ = tx.send(Err(format!("Stream error: {}", e))); - break; - } - } - } - }); - - let stream = UnboundedReceiverStream::new(rx); - let body = Body::from_stream(stream); - - let mut response = Response::new(body); - *response.status_mut() = status; - response - .headers_mut() - .insert(CONTENT_TYPE, HeaderValue::from_static("text/event-stream")); - response + let prefill_logprobs = if return_logprob { + prefill_body + .as_ref() + .and_then(|body| serde_json::from_slice::(body).ok()) + .and_then(|json| { + json.pointer("/meta_info/input_token_logprobs").cloned() + }) } else { - // No logprob merging needed - let stream = res.bytes_stream(); - let decode_url = decode.url().to_string(); - let (tx, rx) = tokio::sync::mpsc::unbounded_channel(); + None + }; - tokio::spawn(async move { - let mut stream = stream; - while let Some(chunk) = stream.next().await { - match chunk { - Ok(bytes) => { - if tx.send(Ok(bytes)).is_err() { - break; - } - } - Err(e) => { - error!( - "Stream error from decode server {}: {}", - decode_url, e - ); - RouterMetrics::record_pd_stream_error(&decode_url); - let _ = tx.send(Err(format!("Stream error: {}", e))); - break; - } - } - } - }); + let decode_url = if !return_logprob { + Some(decode.url().to_string()) + } else { + None + }; - let stream = UnboundedReceiverStream::new(rx); - let body = Body::from_stream(stream); - - let mut response = Response::new(body); - *response.status_mut() = status; - response - .headers_mut() - .insert(CONTENT_TYPE, HeaderValue::from_static("text/event-stream")); - response - } + Self::create_streaming_response( + res.bytes_stream(), + status, + prefill_logprobs, + return_logprob, + decode_url, + ) } else { - // Non-streaming response - match res.bytes().await { - Ok(decode_body) => { - if return_logprob { - self.merge_logprobs(prefill_result, decode_body, status) - .await - } else { - (status, decode_body).into_response() - } - } - Err(e) => { - error!("Failed to read decode response: {}", e); - (StatusCode::INTERNAL_SERVER_ERROR, "Failed to read response") - .into_response() - } - } + // Non-streaming response - use helper + self.process_non_streaming_response(res, status, return_logprob, prefill_body) + .await } } Err(e) => { @@ -758,62 +740,6 @@ impl PDRouter { } } - // Merge logprobs from prefill and decode responses - async fn merge_logprobs( - &self, - prefill_result: Result, - decode_body: bytes::Bytes, - status: StatusCode, - ) -> Response { - match prefill_result { - Ok(prefill_res) => { - match prefill_res.bytes().await { - Ok(prefill_body) => { - match ( - serde_json::from_slice::(&prefill_body), - serde_json::from_slice::(&decode_body), - ) { - (Ok(prefill_json), Ok(mut decode_json)) => { - // Merge input_token_logprobs - if let (Some(prefill_meta), Some(decode_meta)) = ( - prefill_json.get("meta_info"), - decode_json.get_mut("meta_info"), - ) { - if let (Some(prefill_logprobs), Some(decode_logprobs)) = ( - prefill_meta.get("input_token_logprobs"), - decode_meta.get_mut("input_token_logprobs"), - ) { - if let (Some(p_arr), Some(d_arr)) = ( - prefill_logprobs.as_array(), - decode_logprobs.as_array(), - ) { - let mut merged = p_arr.clone(); - merged.extend(d_arr.clone()); - decode_meta["input_token_logprobs"] = - Value::Array(merged); - } - } - } - let mut response = Json(decode_json).into_response(); - *response.status_mut() = status; - response - } - _ => { - warn!("Failed to parse responses for logprob merging"); - (status, decode_body).into_response() - } - } - } - Err(e) => { - warn!("Failed to read prefill response: {}", e); - (status, decode_body).into_response() - } - } - } - Err(_) => (status, decode_body).into_response(), - } - } - // Select a pair of prefill and decode servers async fn select_pd_pair( &self, @@ -900,6 +826,229 @@ impl PDRouter { } } + // Helper to create a streaming response + fn create_streaming_response( + stream: impl futures_util::Stream> + Send + 'static, + status: StatusCode, + prefill_logprobs: Option, + return_logprob: bool, + decode_url: Option, + ) -> Response { + let (tx, rx) = tokio::sync::mpsc::unbounded_channel(); + + tokio::spawn(async move { + futures_util::pin_mut!(stream); + while let Some(chunk_result) = stream.next().await { + match chunk_result { + Ok(chunk) => { + let result = if return_logprob && prefill_logprobs.is_some() { + // Try to merge logprobs + Self::merge_streaming_logprobs(prefill_logprobs.clone(), &chunk) + .unwrap_or(chunk) + } else { + chunk + }; + + if tx.send(Ok(result)).is_err() { + break; + } + } + Err(e) => { + if let Some(ref url) = decode_url { + error!("Stream error from decode server {}: {}", url, e); + RouterMetrics::record_pd_stream_error(url); + } + let _ = tx.send(Err(format!("Stream error: {}", e))); + break; + } + } + } + }); + + let stream = UnboundedReceiverStream::new(rx); + let body = Body::from_stream(stream); + + let mut response = Response::new(body); + *response.status_mut() = status; + response + .headers_mut() + .insert(CONTENT_TYPE, HeaderValue::from_static("text/event-stream")); + response + } + + // Helper to process non-streaming decode response with logprob merging + async fn process_non_streaming_response( + &self, + res: reqwest::Response, + status: StatusCode, + return_logprob: bool, + prefill_body: Option, + ) -> Response { + match res.bytes().await { + Ok(decode_body) => { + if return_logprob && prefill_body.is_some() { + // Merge logprobs from prefill and decode + let prefill_body = prefill_body.as_ref().unwrap(); + match ( + serde_json::from_slice::(prefill_body), + serde_json::from_slice::(&decode_body), + ) { + (Ok(prefill_json), Ok(mut decode_json)) => { + // Use helper to merge logprobs + Self::merge_logprobs_in_json(&prefill_json, &mut decode_json); + + // Return merged response + match serde_json::to_vec(&decode_json) { + Ok(body) => (status, body).into_response(), + Err(e) => { + error!("Failed to serialize merged response: {}", e); + (status, decode_body).into_response() + } + } + } + _ => { + // If parsing fails, just return decode response + warn!("Failed to parse responses for logprob merging"); + (status, decode_body).into_response() + } + } + } else { + (status, decode_body).into_response() + } + } + Err(e) => { + error!("Failed to read decode response: {}", e); + (StatusCode::INTERNAL_SERVER_ERROR, "Failed to read response").into_response() + } + } + } + + // Helper to process prefill response and extract body if needed for logprobs + async fn process_prefill_response( + &self, + prefill_result: Result, + prefill_url: &str, + return_logprob: bool, + ) -> Result<(StatusCode, Option), Response> { + // Check prefill result first - it's critical for disaggregated mode + let prefill_response = match prefill_result { + Ok(response) => response, + Err(e) => { + RouterMetrics::record_pd_prefill_error(prefill_url); + error!( + "Prefill server failed (CRITICAL) prefill_url={} error={}. Decode will timeout without prefill KV cache.", + prefill_url, + e + ); + + // Return error immediately - don't wait for decode to timeout + return Err(( + StatusCode::BAD_GATEWAY, + format!( + "Prefill server error: {}. This will cause decode timeout.", + e + ), + ) + .into_response()); + } + }; + + let prefill_status = StatusCode::from_u16(prefill_response.status().as_u16()) + .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR); + + // Check if prefill succeeded + if !prefill_status.is_success() { + RouterMetrics::record_pd_prefill_error(prefill_url); + + // Get error body from prefill + let error_msg = prefill_response + .text() + .await + .unwrap_or_else(|_| "Unknown prefill error".to_string()); + + error!( + "Prefill server returned error status prefill_url={} status={} body={}", + prefill_url, prefill_status, error_msg + ); + + return Err(( + prefill_status, + format!("Prefill server error ({}): {}", prefill_status, error_msg), + ) + .into_response()); + } + + // Read prefill body if needed for logprob merging + let prefill_body = if return_logprob { + match prefill_response.bytes().await { + Ok(body) => Some(body), + Err(e) => { + warn!("Failed to read prefill response body for logprobs: {}", e); + None + } + } + } else { + // For non-logprob requests, just consume the response without storing + debug!("Consuming prefill response body (non-logprob request)"); + match prefill_response.bytes().await { + Ok(_) => debug!("Prefill response consumed successfully"), + Err(e) => warn!("Error consuming prefill response: {}", e), + } + None + }; + + Ok((prefill_status, prefill_body)) + } + + // Helper to build a request with headers copied from the original request + fn build_request_with_headers( + &self, + url: &str, + route: &str, + json_request: &Value, + headers: Option<&HeaderMap>, + ) -> reqwest::RequestBuilder { + let mut request = self.client.post(api_path(url, route)).json(json_request); + + // Copy headers from original request (excluding content-type and content-length which are set by .json()) + if let Some(headers) = headers { + for (name, value) in headers.iter() { + let name_str = name.as_str(); + if name_str != "content-type" && name_str != "content-length" { + // Skip headers with non-ASCII values + if value.to_str().is_ok() { + request = request.header(name, value); + } + } + } + } + + request + } + + // Helper to merge logprobs from prefill and decode responses + fn merge_logprobs_in_json(prefill_json: &Value, decode_json: &mut Value) -> bool { + if let (Some(prefill_meta), Some(decode_meta)) = ( + prefill_json.get("meta_info"), + decode_json.get_mut("meta_info"), + ) { + if let (Some(prefill_logprobs), Some(decode_logprobs)) = ( + prefill_meta.get("input_token_logprobs"), + decode_meta.get_mut("input_token_logprobs"), + ) { + if let (Some(prefill_arr), Some(decode_arr)) = + (prefill_logprobs.as_array(), decode_logprobs.as_array_mut()) + { + let mut merged = prefill_arr.clone(); + merged.extend(decode_arr.clone()); + decode_meta["input_token_logprobs"] = Value::Array(merged); + return true; + } + } + } + false + } + // Simple helper to merge logprobs in streaming responses fn merge_streaming_logprobs( prefill_logprobs: Option, @@ -1316,7 +1465,6 @@ impl PDRouter { use crate::routers::{RouterTrait, WorkerManagement}; use async_trait::async_trait; -use reqwest::Client; #[async_trait] impl WorkerManagement for PDRouter { @@ -1558,6 +1706,7 @@ mod tests { worker_loads: Arc::new(tokio::sync::watch::channel(HashMap::new()).1), load_monitor_handle: None, client: Client::new(), + retry_config: RetryConfig::default(), _prefill_health_checker: None, _decode_health_checker: None, } diff --git a/sgl-router/src/routers/router.rs b/sgl-router/src/routers/router.rs index 1a6ddeea4..933728a4f 100644 --- a/sgl-router/src/routers/router.rs +++ b/sgl-router/src/routers/router.rs @@ -1,3 +1,4 @@ +use crate::config::types::RetryConfig; use crate::core::{HealthChecker, Worker, WorkerFactory}; use crate::metrics::RouterMetrics; use crate::openai_api_types::{ChatCompletionRequest, CompletionRequest, GenerateRequest}; @@ -11,6 +12,7 @@ use axum::{ Json, }; use futures_util::StreamExt; +use reqwest::Client; use std::collections::HashMap; use std::sync::{Arc, RwLock}; use std::thread; @@ -39,6 +41,7 @@ pub struct Router { interval_secs: u64, dp_aware: bool, api_key: Option, + retry_config: RetryConfig, _worker_loads: Arc>>, _load_monitor_handle: Option>>, _health_checker: Option, @@ -54,6 +57,7 @@ impl Router { interval_secs: u64, dp_aware: bool, api_key: Option, + retry_config: RetryConfig, ) -> Result { // Update active workers gauge RouterMetrics::set_active_workers(worker_urls.len()); @@ -120,6 +124,7 @@ impl Router { interval_secs, dp_aware, api_key, + retry_config, _worker_loads: worker_loads, _load_monitor_handle: load_monitor_handle, _health_checker: Some(health_checker), @@ -141,6 +146,12 @@ impl Router { timeout_secs: u64, interval_secs: u64, ) -> Result<(), String> { + if worker_urls.is_empty() { + return Err( + "Timeout waiting for workers to become healthy: no workers provided".to_string(), + ); + } + let start_time = std::time::Instant::now(); let sync_client = reqwest::blocking::Client::builder() .timeout(Duration::from_secs(timeout_secs)) @@ -365,11 +376,13 @@ impl Router { ) -> Response { // Handle retries like the original implementation let start = Instant::now(); - const MAX_REQUEST_RETRIES: u32 = 3; - const MAX_TOTAL_RETRIES: u32 = 6; + // Use retry config for per-worker retries + let max_request_retries = self.retry_config.max_retries; + // Total retries across all workers (2x to allow trying multiple workers) + let max_total_retries = self.retry_config.max_retries * 2; let mut total_retries = 0; - while total_retries < MAX_TOTAL_RETRIES { + while total_retries < max_total_retries { // Extract routing text directly from typed request let text = typed_req.extract_text_for_routing(); let is_stream = typed_req.is_stream(); @@ -379,7 +392,7 @@ impl Router { let mut request_retries = 0; // Try the same worker multiple times - while request_retries < MAX_REQUEST_RETRIES { + while request_retries < max_request_retries { if total_retries >= 1 { info!("Retrying request after {} failed attempts", total_retries); RouterMetrics::record_retry(route); @@ -429,13 +442,13 @@ impl Router { route, worker_url, request_retries + 1, - MAX_REQUEST_RETRIES + max_request_retries ); request_retries += 1; total_retries += 1; - if request_retries == MAX_REQUEST_RETRIES { + if request_retries == max_request_retries { warn!( "Removing failed worker after typed request failures worker_url={}", worker_url @@ -1003,7 +1016,6 @@ impl Router { } use async_trait::async_trait; -use reqwest::Client; #[async_trait] impl WorkerManagement for Router { @@ -1210,6 +1222,7 @@ mod tests { dp_aware: false, api_key: None, client: Client::new(), + retry_config: RetryConfig::default(), _worker_loads: Arc::new(rx), _load_monitor_handle: None, _health_checker: None, @@ -1237,8 +1250,10 @@ mod tests { #[test] fn test_wait_for_healthy_workers_empty_list() { + // Empty list will timeout as there are no workers to check let result = Router::wait_for_healthy_workers(&[], 1, 1); - assert!(result.is_ok()); + assert!(result.is_err()); + assert!(result.unwrap_err().contains("Timeout")); } #[test] diff --git a/sgl-router/src/service_discovery.rs b/sgl-router/src/service_discovery.rs index 482952bf7..32c14d868 100644 --- a/sgl-router/src/service_discovery.rs +++ b/sgl-router/src/service_discovery.rs @@ -580,8 +580,17 @@ mod tests { use crate::routers::router::Router; let policy = PolicyFactory::create_from_config(&PolicyConfig::Random); - let router = - Router::new(vec![], policy, reqwest::Client::new(), 5, 1, false, None).unwrap(); + let router = Router::new( + vec![], + policy, + reqwest::Client::new(), + 5, + 1, + false, + None, + crate::config::types::RetryConfig::default(), + ) + .unwrap(); Arc::new(router) as Arc } diff --git a/sgl-router/tests/api_endpoints_test.rs b/sgl-router/tests/api_endpoints_test.rs index 6beda2b7a..a4115926a 100644 --- a/sgl-router/tests/api_endpoints_test.rs +++ b/sgl-router/tests/api_endpoints_test.rs @@ -8,7 +8,7 @@ use axum::{ use common::mock_worker::{HealthStatus, MockWorker, MockWorkerConfig, WorkerType}; use reqwest::Client; use serde_json::json; -use sglang_router_rs::config::{PolicyConfig, RouterConfig, RoutingMode}; +use sglang_router_rs::config::{PolicyConfig, RetryConfig, RouterConfig, RoutingMode}; use sglang_router_rs::routers::{RouterFactory, RouterTrait}; use std::sync::Arc; use tower::ServiceExt; @@ -44,6 +44,7 @@ impl TestContext { request_id_headers: None, max_concurrent_requests: 64, cors_allowed_origins: vec![], + retry: RetryConfig::default(), }; Self::new_with_config(config, worker_configs).await @@ -1085,6 +1086,7 @@ mod error_tests { request_id_headers: None, max_concurrent_requests: 64, cors_allowed_origins: vec![], + retry: RetryConfig::default(), }; let ctx = TestContext::new_with_config( @@ -1431,6 +1433,7 @@ mod pd_mode_tests { request_id_headers: None, max_concurrent_requests: 64, cors_allowed_origins: vec![], + retry: RetryConfig::default(), }; // Create app context @@ -1584,6 +1587,7 @@ mod request_id_tests { request_id_headers: Some(vec!["custom-id".to_string(), "trace-id".to_string()]), max_concurrent_requests: 64, cors_allowed_origins: vec![], + retry: RetryConfig::default(), }; let ctx = TestContext::new_with_config( diff --git a/sgl-router/tests/request_formats_test.rs b/sgl-router/tests/request_formats_test.rs index a3cd12edb..4e9e1562d 100644 --- a/sgl-router/tests/request_formats_test.rs +++ b/sgl-router/tests/request_formats_test.rs @@ -3,7 +3,7 @@ mod common; use common::mock_worker::{HealthStatus, MockWorker, MockWorkerConfig, WorkerType}; use reqwest::Client; use serde_json::json; -use sglang_router_rs::config::{PolicyConfig, RouterConfig, RoutingMode}; +use sglang_router_rs::config::{PolicyConfig, RetryConfig, RouterConfig, RoutingMode}; use sglang_router_rs::routers::{RouterFactory, RouterTrait}; use std::sync::Arc; @@ -35,6 +35,7 @@ impl TestContext { request_id_headers: None, max_concurrent_requests: 64, cors_allowed_origins: vec![], + retry: RetryConfig::default(), }; let mut workers = Vec::new(); diff --git a/sgl-router/tests/streaming_tests.rs b/sgl-router/tests/streaming_tests.rs index 2ef2e0929..dcf0ffc93 100644 --- a/sgl-router/tests/streaming_tests.rs +++ b/sgl-router/tests/streaming_tests.rs @@ -4,7 +4,7 @@ use common::mock_worker::{HealthStatus, MockWorker, MockWorkerConfig, WorkerType use futures_util::StreamExt; use reqwest::Client; use serde_json::json; -use sglang_router_rs::config::{PolicyConfig, RouterConfig, RoutingMode}; +use sglang_router_rs::config::{PolicyConfig, RetryConfig, RouterConfig, RoutingMode}; use sglang_router_rs::routers::{RouterFactory, RouterTrait}; use std::sync::Arc; @@ -36,6 +36,7 @@ impl TestContext { request_id_headers: None, max_concurrent_requests: 64, cors_allowed_origins: vec![], + retry: RetryConfig::default(), }; let mut workers = Vec::new(); diff --git a/sgl-router/tests/test_pd_routing.rs b/sgl-router/tests/test_pd_routing.rs index d0877eeb8..20b37aaa8 100644 --- a/sgl-router/tests/test_pd_routing.rs +++ b/sgl-router/tests/test_pd_routing.rs @@ -2,7 +2,7 @@ mod test_pd_routing { use rand::Rng; use serde_json::json; - use sglang_router_rs::config::{PolicyConfig, RouterConfig, RoutingMode}; + use sglang_router_rs::config::{PolicyConfig, RetryConfig, RouterConfig, RoutingMode}; use sglang_router_rs::core::{WorkerFactory, WorkerType}; use sglang_router_rs::routers::pd_types::get_hostname; use sglang_router_rs::routers::pd_types::PDSelectionPolicy; @@ -178,6 +178,7 @@ mod test_pd_routing { request_id_headers: None, max_concurrent_requests: 64, cors_allowed_origins: vec![], + retry: RetryConfig::default(), }; // Router creation will fail due to health checks, but config should be valid