From 6f81a710f79626ef2cf56a84acaf1127c63d0d1c Mon Sep 17 00:00:00 2001 From: Simo Lin Date: Mon, 11 Aug 2025 05:53:26 -0700 Subject: [PATCH] [pd-router] add retry and circuit breaker for pd router (#9051) --- sgl-router/src/core/mod.rs | 2 +- sgl-router/src/core/retry.rs | 18 +- sgl-router/src/routers/pd_router.rs | 365 ++++++++++++++++------------ sgl-router/src/routers/router.rs | 28 +-- 4 files changed, 236 insertions(+), 177 deletions(-) diff --git a/sgl-router/src/core/mod.rs b/sgl-router/src/core/mod.rs index ebc6f2f2c..727cf3515 100644 --- a/sgl-router/src/core/mod.rs +++ b/sgl-router/src/core/mod.rs @@ -16,7 +16,7 @@ pub use circuit_breaker::{ CircuitBreaker, CircuitBreakerConfig, CircuitBreakerStats, CircuitState, }; pub use error::{WorkerError, WorkerResult}; -pub use retry::{BackoffCalculator, RetryError, RetryExecutor}; +pub use retry::{is_retryable_status, BackoffCalculator, RetryError, RetryExecutor}; pub use worker::{ start_health_checker, BasicWorker, DPAwareWorker, HealthChecker, Worker, WorkerCollection, WorkerFactory, WorkerLoadGuard, WorkerType, diff --git a/sgl-router/src/core/retry.rs b/sgl-router/src/core/retry.rs index e7a397540..bae0568be 100644 --- a/sgl-router/src/core/retry.rs +++ b/sgl-router/src/core/retry.rs @@ -1,9 +1,23 @@ use crate::config::types::RetryConfig; +use axum::http::StatusCode; use axum::response::Response; use rand::Rng; use std::time::Duration; use tracing::debug; +/// Check if an HTTP status code indicates a retryable error +pub fn is_retryable_status(status: StatusCode) -> bool { + matches!( + status, + StatusCode::REQUEST_TIMEOUT + | StatusCode::TOO_MANY_REQUESTS + | StatusCode::INTERNAL_SERVER_ERROR + | StatusCode::BAD_GATEWAY + | StatusCode::SERVICE_UNAVAILABLE + | StatusCode::GATEWAY_TIMEOUT + ) +} + /// Computes exponential backoff with optional jitter. 
#[derive(Debug, Clone)] pub struct BackoffCalculator; @@ -21,8 +35,8 @@ impl BackoffCalculator { // Apply jitter in range [-j, +j] let jitter = config.jitter_factor.max(0.0).min(1.0); if jitter > 0.0 { - let mut rng = rand::thread_rng(); - let jitter_scale: f32 = rng.gen_range(-jitter..=jitter); + let mut rng = rand::rng(); + let jitter_scale: f32 = rng.random_range(-jitter..=jitter); let jitter_ms = (delay_ms as f32 * jitter_scale) .round() .max(-(delay_ms as f32)); diff --git a/sgl-router/src/routers/pd_router.rs b/sgl-router/src/routers/pd_router.rs index c3165c31a..cd36bb5cc 100644 --- a/sgl-router/src/routers/pd_router.rs +++ b/sgl-router/src/routers/pd_router.rs @@ -2,7 +2,10 @@ // This module handles routing for disaggregated prefill-decode systems use super::pd_types::{api_path, PDRouterError}; use crate::config::types::{CircuitBreakerConfig as ConfigCircuitBreakerConfig, RetryConfig}; -use crate::core::{CircuitBreakerConfig, HealthChecker, Worker, WorkerFactory, WorkerLoadGuard}; +use crate::core::{ + is_retryable_status, CircuitBreakerConfig, HealthChecker, RetryExecutor, Worker, WorkerFactory, + WorkerLoadGuard, +}; use crate::metrics::RouterMetrics; use crate::openai_api_types::{ChatCompletionRequest, CompletionRequest, GenerateRequest}; use crate::policies::LoadBalancingPolicy; @@ -17,6 +20,7 @@ use axum::{ }; use futures_util::StreamExt; use reqwest::Client; +use serde::Serialize; use serde_json::Value; use std::collections::HashMap; use std::sync::{Arc, RwLock}; @@ -43,6 +47,16 @@ pub struct PDRouter { _decode_health_checker: Option, } +// Request context for PD router operations +#[derive(Clone)] +struct PDRequestContext { + route: &'static str, + batch_size: Option, + is_stream: bool, + return_logprob: bool, + request_text: Option, +} + impl PDRouter { // Dynamic worker management methods for service discovery @@ -218,12 +232,8 @@ impl PDRouter { let core_cb_config = CircuitBreakerConfig { failure_threshold: 
circuit_breaker_config.failure_threshold, success_threshold: circuit_breaker_config.success_threshold, - timeout_duration: std::time::Duration::from_secs( - circuit_breaker_config.timeout_duration_secs, - ), - window_duration: std::time::Duration::from_secs( - circuit_breaker_config.window_duration_secs, - ), + timeout_duration: Duration::from_secs(circuit_breaker_config.timeout_duration_secs), + window_duration: Duration::from_secs(circuit_breaker_config.window_duration_secs), }; // Convert URLs to Worker trait objects @@ -459,8 +469,96 @@ impl PDRouter { Ok(original) } - // Execute the dual dispatch to prefill and decode servers - async fn execute_dual_dispatch( + // Execute the dual dispatch to prefill and decode servers with retries and bootstrap injection + async fn execute_dual_dispatch( + &self, + headers: Option<&HeaderMap>, + original_request: &T, + context: PDRequestContext, + ) -> Response { + let start_time = Instant::now(); + + let route = context.route; + RetryExecutor::execute_response_with_retry( + &self.retry_config, + // Operation per attempt + { + let original_request = original_request.clone(); + move |attempt: u32| { + let original_request = original_request.clone(); + let context = context.clone(); + async move { + // Select workers fresh for each attempt + let (prefill, decode) = + match self.select_pd_pair(context.request_text.as_deref()).await { + Ok(pair) => pair, + Err(e) => { + RouterMetrics::record_pd_error("server_selection"); + return Self::handle_server_selection_error(e); + } + }; + + debug!( + "PD retry attempt {} using prefill={} decode={}", + attempt, + prefill.url(), + decode.url() + ); + + // Serialize the original request + let mut json_request = match serde_json::to_value(&original_request) { + Ok(v) => v, + Err(e) => return Self::handle_serialization_error(e), + }; + + // Inject bootstrap based on current prefill worker + json_request = match Self::inject_bootstrap_into_value( + json_request, + prefill.as_ref(), + 
context.batch_size, + ) { + Ok(v) => v, + Err(e) => return Self::handle_serialization_error(e), + }; + + // Execute the actual dual dispatch + let response = self + .execute_dual_dispatch_internal( + headers, + json_request, + context.route, + prefill.as_ref(), + decode.as_ref(), + context.is_stream, + context.return_logprob, + start_time, + ) + .await; + + // Record outcomes for circuit breakers + let is_success = response.status().is_success(); + prefill.record_outcome(is_success); + decode.record_outcome(is_success); + + response + } + } + }, + // Should retry predicate + |res, _attempt| is_retryable_status(res.status()), + // On backoff hook + |delay, attempt| { + RouterMetrics::record_retry(route); + RouterMetrics::record_retry_backoff_duration(delay, attempt); + }, + // On exhausted hook + || RouterMetrics::record_retries_exhausted(route), + ) + .await + } + + // Internal method that performs the actual dual dispatch (without retry logic) + async fn execute_dual_dispatch_internal( &self, headers: Option<&HeaderMap>, json_request: Value, @@ -696,7 +794,7 @@ impl PDRouter { self.prefill_policy.needs_request_text() || self.decode_policy.needs_request_text() } - // Select a pair of prefill and decode servers + // Select a pair of prefill and decode servers considering circuit breaker state async fn select_pd_pair( &self, request_text: Option<&str>, @@ -711,31 +809,60 @@ impl PDRouter { .read() .map_err(|e| format!("Failed to acquire decode workers lock: {}", e))?; - // Check we have workers - if prefill_workers.is_empty() { - return Err("No prefill workers available. Please check if prefill servers are configured and healthy.".to_string()); - } - if decode_workers.is_empty() { - return Err("No decode workers available. 
Please check if decode servers are configured and healthy.".to_string()); - } + // Select workers using helper function + let prefill = Self::pick_worker_by_policy( + &*prefill_workers, + &*self.prefill_policy, + request_text, + "prefill", + )?; - // Select prefill worker using prefill policy - let prefill_idx = self - .prefill_policy - .select_worker(&prefill_workers, request_text) - .ok_or("Failed to select prefill worker")?; + let decode = Self::pick_worker_by_policy( + &*decode_workers, + &*self.decode_policy, + request_text, + "decode", + )?; - // Select decode worker using decode policy - let decode_idx = self - .decode_policy - .select_worker(&decode_workers, request_text) - .ok_or("Failed to select decode worker")?; - - let prefill = prefill_workers[prefill_idx].clone_worker(); - let decode = decode_workers[decode_idx].clone_worker(); Ok((prefill, decode)) } + // Helper function to select a worker using the policy + fn pick_worker_by_policy( + workers: &[Box], + policy: &dyn LoadBalancingPolicy, + request_text: Option<&str>, + worker_type: &str, + ) -> Result, String> { + // Check if we have any workers + if workers.is_empty() { + return Err(format!( + "No {} workers available. 
Please check if {} servers are configured and healthy.", + worker_type, worker_type + )); + } + + // Filter available workers (healthy + circuit breaker not open) + let available_workers: Vec> = workers + .iter() + .filter(|w| w.is_available()) + .map(|w| w.clone_worker()) + .collect(); + + if available_workers.is_empty() { + return Err(format!( + "No available {} workers (all circuits open or unhealthy)", + worker_type + )); + } + + // Let policy select from available workers only + match policy.select_worker(&available_workers, request_text) { + Some(idx) => Ok(available_workers[idx].clone_worker()), + None => Err(format!("Policy could not select a {} worker", worker_type)), + } + } + // Background task to monitor worker loads with shared client async fn monitor_worker_loads_with_client( worker_urls: Vec, @@ -1449,61 +1576,41 @@ impl RouterTrait for PDRouter { headers: Option<&HeaderMap>, body: &GenerateRequest, ) -> Response { - let start = Instant::now(); - - // Extract flags for routing logic + // Extract parameters let is_stream = body.stream; let return_logprob = body.return_logprob; - // Extract text for cache-aware routing only if needed + // Extract text for cache-aware routing let request_text = if self.policies_need_request_text() { - body.text.as_deref().or_else(|| { - body.prompt.as_ref().and_then(|p| match p { - crate::openai_api_types::StringOrArray::String(s) => Some(s.as_str()), - crate::openai_api_types::StringOrArray::Array(v) => { - v.first().map(|s| s.as_str()) - } + body.text + .as_deref() + .or_else(|| { + body.prompt.as_ref().and_then(|p| match p { + crate::openai_api_types::StringOrArray::String(s) => Some(s.as_str()), + crate::openai_api_types::StringOrArray::Array(v) => { + v.first().map(|s| s.as_str()) + } + }) }) - }) + .map(|s| s.to_string()) } else { None }; - // Select servers - let (prefill, decode) = match self.select_pd_pair(request_text).await { - Ok(pair) => pair, - Err(e) => return Self::handle_server_selection_error(e), - }; 
- - // Log routing decision - info!( - "PD routing decision route=/generate prefill_url={} decode_url={}", - prefill.url(), - decode.url() - ); - + // Calculate batch size let batch_size = Self::get_generate_batch_size(body); - let original = match serde_json::to_value(body) { - Ok(v) => v, - Err(e) => return Self::handle_serialization_error(e), - }; - let json = match Self::inject_bootstrap_into_value(original, prefill.as_ref(), batch_size) { - Ok(v) => v, - Err(e) => return Self::handle_serialization_error(e), - }; - // Execute dual dispatch - self.execute_dual_dispatch( - headers, - json, - "/generate", - prefill.as_ref(), - decode.as_ref(), + // Create context + let context = PDRequestContext { + route: "/generate", + batch_size, is_stream, return_logprob, - start, - ) - .await + request_text, + }; + + // Execute with retry and bootstrap injection + self.execute_dual_dispatch(headers, body, context).await } async fn route_chat( @@ -1511,25 +1618,19 @@ impl RouterTrait for PDRouter { headers: Option<&HeaderMap>, body: &ChatCompletionRequest, ) -> Response { - let start = Instant::now(); - - // Extract flags for routing logic + // Extract parameters let is_stream = body.stream; let return_logprob = body.logprobs; - // Extract text for cache-aware routing from chat messages only if needed + // Extract text for cache-aware routing let request_text = if self.policies_need_request_text() { body.messages.first().and_then(|msg| match msg { - crate::openai_api_types::ChatMessage::User { content, .. } => { - match content { - crate::openai_api_types::UserMessageContent::Text(text) => { - Some(text.as_str()) - } - crate::openai_api_types::UserMessageContent::Parts(_) => None, // Skip complex content - } - } + crate::openai_api_types::ChatMessage::User { content, .. 
} => match content { + crate::openai_api_types::UserMessageContent::Text(text) => Some(text.clone()), + crate::openai_api_types::UserMessageContent::Parts(_) => None, + }, crate::openai_api_types::ChatMessage::System { content, .. } => { - Some(content.as_str()) + Some(content.clone()) } _ => None, }) @@ -1537,41 +1638,20 @@ impl RouterTrait for PDRouter { None }; - // Select servers - let (prefill, decode) = match self.select_pd_pair(request_text).await { - Ok(pair) => pair, - Err(e) => return Self::handle_server_selection_error(e), - }; - - // Log routing decision - info!( - "PD routing decision route=/v1/chat/completions prefill_url={} decode_url={}", - prefill.url(), - decode.url() - ); - + // Calculate batch size let batch_size = Self::get_chat_batch_size(body); - let original = match serde_json::to_value(body) { - Ok(v) => v, - Err(e) => return Self::handle_serialization_error(e), - }; - let json = match Self::inject_bootstrap_into_value(original, prefill.as_ref(), batch_size) { - Ok(v) => v, - Err(e) => return Self::handle_serialization_error(e), - }; - // Execute dual dispatch - self.execute_dual_dispatch( - headers, - json, - "/v1/chat/completions", - prefill.as_ref(), - decode.as_ref(), + // Create context + let context = PDRequestContext { + route: "/v1/chat/completions", + batch_size, is_stream, return_logprob, - start, - ) - .await + request_text, + }; + + // Execute with retry and bootstrap injection + self.execute_dual_dispatch(headers, body, context).await } async fn route_completion( @@ -1579,57 +1659,36 @@ impl RouterTrait for PDRouter { headers: Option<&HeaderMap>, body: &CompletionRequest, ) -> Response { - let start = Instant::now(); - - // Extract flags for routing logic + // Extract parameters let is_stream = body.stream; let return_logprob = body.logprobs.is_some(); - // Extract text for cache-aware routing only if needed + // Extract text for cache-aware routing let request_text = if self.policies_need_request_text() { match &body.prompt { 
- crate::openai_api_types::StringOrArray::String(s) => Some(s.as_str()), - crate::openai_api_types::StringOrArray::Array(v) => v.first().map(|s| s.as_str()), + crate::openai_api_types::StringOrArray::String(s) => Some(s.clone()), + crate::openai_api_types::StringOrArray::Array(v) => { + v.first().map(|s| s.to_string()) + } } } else { None }; - // Select servers - let (prefill, decode) = match self.select_pd_pair(request_text).await { - Ok(pair) => pair, - Err(e) => return Self::handle_server_selection_error(e), - }; - - // Log routing decision - info!( - "PD routing decision route=/v1/completions prefill_url={} decode_url={}", - prefill.url(), - decode.url() - ); - + // Calculate batch size let batch_size = Self::get_completion_batch_size(body); - let original = match serde_json::to_value(body) { - Ok(v) => v, - Err(e) => return Self::handle_serialization_error(e), - }; - let json = match Self::inject_bootstrap_into_value(original, prefill.as_ref(), batch_size) { - Ok(v) => v, - Err(e) => return Self::handle_serialization_error(e), - }; - // Execute dual dispatch - self.execute_dual_dispatch( - headers, - json, - "/v1/completions", - prefill.as_ref(), - decode.as_ref(), + // Create context + let context = PDRequestContext { + route: "/v1/completions", + batch_size, is_stream, return_logprob, - start, - ) - .await + request_text, + }; + + // Execute with retry and bootstrap injection + self.execute_dual_dispatch(headers, body, context).await } async fn flush_cache(&self) -> Response { diff --git a/sgl-router/src/routers/router.rs b/sgl-router/src/routers/router.rs index 7a5d54685..d6ecb0960 100644 --- a/sgl-router/src/routers/router.rs +++ b/sgl-router/src/routers/router.rs @@ -1,5 +1,7 @@ use crate::config::types::{CircuitBreakerConfig as ConfigCircuitBreakerConfig, RetryConfig}; -use crate::core::{CircuitBreakerConfig, HealthChecker, RetryExecutor, Worker, WorkerFactory}; +use crate::core::{ + is_retryable_status, CircuitBreakerConfig, HealthChecker, 
RetryExecutor, Worker, WorkerFactory, +}; use crate::metrics::RouterMetrics; use crate::openai_api_types::{ChatCompletionRequest, CompletionRequest, GenerateRequest}; use crate::policies::LoadBalancingPolicy; @@ -81,12 +83,8 @@ impl Router { let core_cb_config = CircuitBreakerConfig { failure_threshold: circuit_breaker_config.failure_threshold, success_threshold: circuit_breaker_config.success_threshold, - timeout_duration: std::time::Duration::from_secs( - circuit_breaker_config.timeout_duration_secs, - ), - window_duration: std::time::Duration::from_secs( - circuit_breaker_config.window_duration_secs, - ), + timeout_duration: Duration::from_secs(circuit_breaker_config.timeout_duration_secs), + window_duration: Duration::from_secs(circuit_breaker_config.window_duration_secs), }; // Create Worker trait objects from URLs @@ -397,18 +395,6 @@ impl Router { Some(available[idx].clone_worker()) } - fn is_retryable_status(status: StatusCode) -> bool { - matches!( - status, - StatusCode::REQUEST_TIMEOUT - | StatusCode::TOO_MANY_REQUESTS - | StatusCode::INTERNAL_SERVER_ERROR - | StatusCode::BAD_GATEWAY - | StatusCode::SERVICE_UNAVAILABLE - | StatusCode::GATEWAY_TIMEOUT - ) - } - pub async fn route_typed_request< T: crate::openai_api_types::GenerationRequest + serde::Serialize + Clone, >( @@ -461,7 +447,7 @@ impl Router { response }, // should_retry predicate - |res, _attempt| Self::is_retryable_status(res.status()), + |res, _attempt| is_retryable_status(res.status()), // on_backoff hook |delay, attempt| { RouterMetrics::record_retry(route); @@ -476,7 +462,7 @@ impl Router { let duration = start.elapsed(); RouterMetrics::record_request(route); RouterMetrics::record_generate_duration(duration); - } else if !Self::is_retryable_status(response.status()) { + } else if !is_retryable_status(response.status()) { RouterMetrics::record_request_error(route, "non_retryable_error"); }