From 61a4680494902efeef00ce7559dca3869bda3881 Mon Sep 17 00:00:00 2001 From: Simo Lin Date: Fri, 8 Aug 2025 09:20:22 -0700 Subject: [PATCH] [router] router circuit breaker core (#8941) --- sgl-router/src/config/types.rs | 31 ++ sgl-router/src/core/circuit_breaker.rs | 545 +++++++++++++++++++++++ sgl-router/src/core/mod.rs | 5 + sgl-router/src/core/worker.rs | 153 ++++++- sgl-router/src/lib.rs | 1 + sgl-router/src/routers/factory.rs | 2 + sgl-router/src/routers/pd_router.rs | 41 +- sgl-router/src/routers/router.rs | 34 +- sgl-router/src/service_discovery.rs | 1 + sgl-router/tests/api_endpoints_test.rs | 8 +- sgl-router/tests/request_formats_test.rs | 5 +- sgl-router/tests/streaming_tests.rs | 5 +- sgl-router/tests/test_pd_routing.rs | 5 +- 13 files changed, 818 insertions(+), 18 deletions(-) create mode 100644 sgl-router/src/core/circuit_breaker.rs diff --git a/sgl-router/src/config/types.rs b/sgl-router/src/config/types.rs index a52e124ad..c72981b5f 100644 --- a/sgl-router/src/config/types.rs +++ b/sgl-router/src/config/types.rs @@ -41,6 +41,8 @@ pub struct RouterConfig { pub cors_allowed_origins: Vec, /// Retry configuration pub retry: RetryConfig, + /// Circuit breaker configuration + pub circuit_breaker: CircuitBreakerConfig, } /// Routing mode configuration @@ -208,6 +210,30 @@ impl Default for RetryConfig { } } +/// Circuit breaker configuration for worker reliability +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct CircuitBreakerConfig { + /// Number of consecutive failures before opening circuit + pub failure_threshold: u32, + /// Number of consecutive successes before closing circuit + pub success_threshold: u32, + /// Time before attempting to recover from open state (in seconds) + pub timeout_duration_secs: u64, + /// Window duration for failure tracking (in seconds) + pub window_duration_secs: u64, +} + +impl Default for CircuitBreakerConfig { + fn default() -> Self { + Self { + failure_threshold: 5, + success_threshold: 2, + timeout_duration_secs: 30, + window_duration_secs: 60, + } + } +} + /// Metrics configuration #[derive(Debug, Clone, Serialize, Deserialize)] pub struct MetricsConfig { @@ -249,6 +275,7 @@ impl Default for RouterConfig { max_concurrent_requests: 64, cors_allowed_origins: vec![], retry: RetryConfig::default(), + circuit_breaker: CircuitBreakerConfig::default(), } } } @@ -360,6 +387,7 @@ mod tests { max_concurrent_requests: 64, cors_allowed_origins: vec![], retry: RetryConfig::default(), + circuit_breaker: CircuitBreakerConfig::default(), }; let json = serde_json::to_string(&config).unwrap(); @@ -788,6 +816,7 @@ mod tests { max_concurrent_requests: 64, cors_allowed_origins: vec![], retry: RetryConfig::default(), + circuit_breaker: CircuitBreakerConfig::default(), }; assert!(config.mode.is_pd_mode()); @@ -840,6 +869,7 @@ mod tests { max_concurrent_requests: 64, cors_allowed_origins: vec![], retry: RetryConfig::default(), + circuit_breaker: CircuitBreakerConfig::default(), }; assert!(!config.mode.is_pd_mode()); @@ -888,6 +918,7 @@ mod tests { max_concurrent_requests: 64, cors_allowed_origins: vec![], retry: RetryConfig::default(), + circuit_breaker: CircuitBreakerConfig::default(), }; assert!(config.has_service_discovery()); diff --git a/sgl-router/src/core/circuit_breaker.rs b/sgl-router/src/core/circuit_breaker.rs new file mode 100644 index 000000000..037f4192b --- /dev/null +++ b/sgl-router/src/core/circuit_breaker.rs @@ -0,0 +1,545 @@ +use std::sync::atomic::{AtomicU32, AtomicU64, Ordering}; +use std::sync::{Arc, RwLock}; +use std::time::{Duration, Instant}; + +/// Circuit breaker configuration +#[derive(Debug, Clone)] +pub struct CircuitBreakerConfig { + /// Number of consecutive failures to open the circuit + pub failure_threshold: u32, + /// Success threshold to close circuit from half-open + pub success_threshold: u32, + /// Duration to wait before attempting half-open + pub timeout_duration: Duration, + /// Time window for failure counting + pub window_duration: Duration, +} + +impl Default for CircuitBreakerConfig { + fn default() -> Self { + Self { + failure_threshold: 5, + success_threshold: 2, + timeout_duration: Duration::from_secs(30), + window_duration: Duration::from_secs(60), + } + } +} + +/// Circuit breaker state +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum CircuitState { + /// Normal operation - requests are allowed + Closed, + /// Circuit is open - requests are rejected + Open, + /// Testing if service has recovered - limited requests allowed + HalfOpen, +} + +impl std::fmt::Display for CircuitState { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + CircuitState::Closed => write!(f, "Closed"), + CircuitState::Open => write!(f, "Open"), + CircuitState::HalfOpen => write!(f, "HalfOpen"), + } + } +} + +/// Circuit breaker implementation +#[derive(Debug)] +pub struct CircuitBreaker { + state: Arc>, + consecutive_failures: Arc, + consecutive_successes: Arc, + total_failures: Arc, + total_successes: Arc, + last_failure_time: Arc>>, + last_state_change: Arc>, + config: CircuitBreakerConfig, +} + +impl CircuitBreaker { + /// Create a new circuit breaker with default configuration + pub fn new() -> Self { + Self::with_config(CircuitBreakerConfig::default()) + } + + /// Create a new circuit breaker with custom configuration + pub fn with_config(config: CircuitBreakerConfig) -> Self { + Self { + state: Arc::new(RwLock::new(CircuitState::Closed)), + consecutive_failures: Arc::new(AtomicU32::new(0)), + consecutive_successes: Arc::new(AtomicU32::new(0)), + total_failures: Arc::new(AtomicU64::new(0)), + total_successes: Arc::new(AtomicU64::new(0)), + last_failure_time: Arc::new(RwLock::new(None)), + last_state_change: Arc::new(RwLock::new(Instant::now())), + config, + } + } + + /// Check if a request can be executed + pub fn can_execute(&self) -> bool { + // First check if we need to transition from Open to HalfOpen + self.check_and_update_state(); + + let state = *self.state.read().unwrap(); + match state { + CircuitState::Closed => true, + CircuitState::Open => false, + CircuitState::HalfOpen => true, // Allow limited requests in half-open state + } + } + + /// Get the current state + pub fn state(&self) -> CircuitState { + self.check_and_update_state(); + *self.state.read().unwrap() + } + + /// Record the outcome of a request + pub fn record_outcome(&self, success: bool) { + if success { + self.record_success(); + } else { + self.record_failure(); + } + } + + /// Record a successful request + pub fn record_success(&self) { + self.total_successes.fetch_add(1, Ordering::Relaxed); + self.consecutive_failures.store(0, Ordering::Release); + let successes = self.consecutive_successes.fetch_add(1, Ordering::AcqRel) + 1; + + let current_state = *self.state.read().unwrap(); + + match current_state { + CircuitState::HalfOpen => { + // Check if we've reached the success threshold to close the circuit + if successes >= self.config.success_threshold { + self.transition_to(CircuitState::Closed); + } + } + CircuitState::Closed => { + // Already closed, nothing to do + } + CircuitState::Open => { + // Shouldn't happen, but if it does, stay open + tracing::warn!("Success recorded while circuit is open"); + } + } + } + + /// Record a failed request + pub fn record_failure(&self) { + self.total_failures.fetch_add(1, Ordering::Relaxed); + self.consecutive_successes.store(0, Ordering::Release); + let failures = self.consecutive_failures.fetch_add(1, Ordering::AcqRel) + 1; + + // Update last failure time + { + let mut last_failure = self.last_failure_time.write().unwrap(); + *last_failure = Some(Instant::now()); + } + + let current_state = *self.state.read().unwrap(); + + match current_state { + CircuitState::Closed => { + // Check if we've reached the failure threshold to open the circuit + if failures >= self.config.failure_threshold { + self.transition_to(CircuitState::Open); + } + } + CircuitState::HalfOpen => { + // Single failure in half-open state reopens the circuit + self.transition_to(CircuitState::Open); + } + CircuitState::Open => { + // Already open, nothing to do + } + } + } + + /// Check and update state based on timeout + fn check_and_update_state(&self) { + let current_state = *self.state.read().unwrap(); + + if current_state == CircuitState::Open { + // Check if timeout has expired + let last_change = *self.last_state_change.read().unwrap(); + if last_change.elapsed() >= self.config.timeout_duration { + self.transition_to(CircuitState::HalfOpen); + } + } + } + + /// Transition to a new state + fn transition_to(&self, new_state: CircuitState) { + let mut state = self.state.write().unwrap(); + let old_state = *state; + + if old_state != new_state { + *state = new_state; + + // Update last state change time + let mut last_change = self.last_state_change.write().unwrap(); + *last_change = Instant::now(); + + // Reset counters based on transition + match new_state { + CircuitState::Closed => { + self.consecutive_failures.store(0, Ordering::Release); + self.consecutive_successes.store(0, Ordering::Release); + } + CircuitState::Open => { + self.consecutive_successes.store(0, Ordering::Release); + } + CircuitState::HalfOpen => { + self.consecutive_failures.store(0, Ordering::Release); + self.consecutive_successes.store(0, Ordering::Release); + } + } + + tracing::info!( + "Circuit breaker state transition: {} -> {}", + old_state, + new_state + ); + } + } + + /// Get the number of consecutive failures + pub fn failure_count(&self) -> u32 { + self.consecutive_failures.load(Ordering::Acquire) + } + + /// Get the number of consecutive successes + pub fn success_count(&self) -> u32 { + self.consecutive_successes.load(Ordering::Acquire) + } + + /// Get total failures + pub fn total_failures(&self) -> u64 { + self.total_failures.load(Ordering::Relaxed) + } + + /// Get total successes + pub fn total_successes(&self) -> u64 { + self.total_successes.load(Ordering::Relaxed) + } + + /// Get time since last failure + pub fn time_since_last_failure(&self) -> Option { + self.last_failure_time.read().unwrap().map(|t| t.elapsed()) + } + + /// Get time since last state change + pub fn time_since_last_state_change(&self) -> Duration { + self.last_state_change.read().unwrap().elapsed() + } + + /// Check if the circuit is in a half-open state + pub fn is_half_open(&self) -> bool { + self.state() == CircuitState::HalfOpen + } + + /// Record a test success (for health check probing) + pub fn record_test_success(&self) { + if self.is_half_open() { + self.record_success(); + } + } + + /// Record a test failure (for health check probing) + pub fn record_test_failure(&self) { + if self.is_half_open() { + self.record_failure(); + } + } + + /// Reset the circuit breaker to closed state + pub fn reset(&self) { + self.transition_to(CircuitState::Closed); + self.consecutive_failures.store(0, Ordering::Release); + self.consecutive_successes.store(0, Ordering::Release); + } + + /// Force the circuit to open (for manual intervention) + pub fn force_open(&self) { + self.transition_to(CircuitState::Open); + } + + /// Get circuit breaker statistics + pub fn stats(&self) -> CircuitBreakerStats { + CircuitBreakerStats { + state: self.state(), + consecutive_failures: self.failure_count(), + consecutive_successes: self.success_count(), + total_failures: self.total_failures(), + total_successes: self.total_successes(), + time_since_last_failure: self.time_since_last_failure(), + time_since_last_state_change: self.time_since_last_state_change(), + } + } +} + +impl Clone for CircuitBreaker { + fn clone(&self) -> Self { + Self { + state: Arc::clone(&self.state), + consecutive_failures: Arc::clone(&self.consecutive_failures), + consecutive_successes: Arc::clone(&self.consecutive_successes), + total_failures: Arc::clone(&self.total_failures), + total_successes: Arc::clone(&self.total_successes), + last_failure_time: Arc::clone(&self.last_failure_time), + last_state_change: Arc::clone(&self.last_state_change), + config: self.config.clone(), + } + } +} + +impl Default for CircuitBreaker { + fn default() -> Self { + Self::new() + } +} + +/// Circuit breaker statistics +#[derive(Debug, Clone)] +pub struct CircuitBreakerStats { + pub state: CircuitState, + pub consecutive_failures: u32, + pub consecutive_successes: u32, + pub total_failures: u64, + pub total_successes: u64, + pub time_since_last_failure: Option, + pub time_since_last_state_change: Duration, +} + +#[cfg(test)] +mod tests { + use super::*; + use std::thread; + + #[test] + fn test_circuit_breaker_initial_state() { + let cb = CircuitBreaker::new(); + assert_eq!(cb.state(), CircuitState::Closed); + assert!(cb.can_execute()); + assert_eq!(cb.failure_count(), 0); + assert_eq!(cb.success_count(), 0); + } + + #[test] + fn test_circuit_opens_on_threshold() { + let config = CircuitBreakerConfig { + failure_threshold: 3, + ..Default::default() + }; + let cb = CircuitBreaker::with_config(config); + + // Record failures up to threshold + assert_eq!(cb.state(), CircuitState::Closed); + cb.record_failure(); + assert_eq!(cb.state(), CircuitState::Closed); + cb.record_failure(); + assert_eq!(cb.state(), CircuitState::Closed); + cb.record_failure(); + + // Circuit should now be open + assert_eq!(cb.state(), CircuitState::Open); + assert!(!cb.can_execute()); + assert_eq!(cb.failure_count(), 3); + } + + #[test] + fn test_circuit_half_open_after_timeout() { + let config = CircuitBreakerConfig { + failure_threshold: 1, + timeout_duration: Duration::from_millis(100), + ..Default::default() + }; + let cb = CircuitBreaker::with_config(config); + + // Open the circuit + cb.record_failure(); + assert_eq!(cb.state(), CircuitState::Open); + + // Wait for timeout + thread::sleep(Duration::from_millis(150)); + + // Circuit should be half-open + assert_eq!(cb.state(), CircuitState::HalfOpen); + assert!(cb.can_execute()); + } + + #[test] + fn test_circuit_closes_on_success_threshold() { + let config = CircuitBreakerConfig { + failure_threshold: 1, + success_threshold: 2, + timeout_duration: Duration::from_millis(50), + ..Default::default() + }; + let cb = CircuitBreaker::with_config(config); + + // Open the circuit + cb.record_failure(); + assert_eq!(cb.state(), CircuitState::Open); + + // Wait for timeout + thread::sleep(Duration::from_millis(100)); + assert_eq!(cb.state(), CircuitState::HalfOpen); + + // Record successes + cb.record_success(); + assert_eq!(cb.state(), CircuitState::HalfOpen); + cb.record_success(); + + // Circuit should now be closed + assert_eq!(cb.state(), CircuitState::Closed); + assert!(cb.can_execute()); + } + + #[test] + fn test_circuit_reopens_on_half_open_failure() { + let config = CircuitBreakerConfig { + failure_threshold: 1, + timeout_duration: Duration::from_millis(50), + ..Default::default() + }; + let cb = CircuitBreaker::with_config(config); + + // Open the circuit + cb.record_failure(); + assert_eq!(cb.state(), CircuitState::Open); + + // Wait for timeout + thread::sleep(Duration::from_millis(100)); + assert_eq!(cb.state(), CircuitState::HalfOpen); + + // Record a failure in half-open state + cb.record_failure(); + + // Circuit should reopen immediately + assert_eq!(cb.state(), CircuitState::Open); + assert!(!cb.can_execute()); + } + + #[test] + fn test_success_resets_failure_count() { + let config = CircuitBreakerConfig { + failure_threshold: 3, + ..Default::default() + }; + let cb = CircuitBreaker::with_config(config); + + // Record some failures + cb.record_failure(); + cb.record_failure(); + assert_eq!(cb.failure_count(), 2); + + // Success should reset failure count + cb.record_success(); + assert_eq!(cb.failure_count(), 0); + assert_eq!(cb.success_count(), 1); + + // Can now record more failures without opening + cb.record_failure(); + cb.record_failure(); + assert_eq!(cb.state(), CircuitState::Closed); + } + + #[test] + fn test_manual_reset() { + let config = CircuitBreakerConfig { + failure_threshold: 1, + ..Default::default() + }; + let cb = CircuitBreaker::with_config(config); + + // Open the circuit + cb.record_failure(); + assert_eq!(cb.state(), CircuitState::Open); + + // Manual reset + cb.reset(); + assert_eq!(cb.state(), CircuitState::Closed); + assert_eq!(cb.failure_count(), 0); + assert_eq!(cb.success_count(), 0); + } + + #[test] + fn test_force_open() { + let cb = CircuitBreaker::new(); + assert_eq!(cb.state(), CircuitState::Closed); + + cb.force_open(); + assert_eq!(cb.state(), CircuitState::Open); + assert!(!cb.can_execute()); + } + + #[test] + fn test_stats() { + let config = CircuitBreakerConfig { + failure_threshold: 2, + ..Default::default() + }; + let cb = CircuitBreaker::with_config(config); + + cb.record_success(); + cb.record_failure(); + cb.record_failure(); + + let stats = cb.stats(); + assert_eq!(stats.state, CircuitState::Open); + assert_eq!(stats.consecutive_failures, 2); + assert_eq!(stats.consecutive_successes, 0); + assert_eq!(stats.total_failures, 2); + assert_eq!(stats.total_successes, 1); + } + + #[test] + fn test_clone() { + let cb1 = CircuitBreaker::new(); + cb1.record_failure(); + + let cb2 = cb1.clone(); + assert_eq!(cb2.failure_count(), 1); + + // Changes to cb1 affect cb2 (shared state) + cb1.record_failure(); + assert_eq!(cb2.failure_count(), 2); + } + + #[test] + fn test_thread_safety() { + use std::sync::Arc; + + let cb = Arc::new(CircuitBreaker::new()); + let mut handles = vec![]; + + // Spawn threads that record failures + for _ in 0..10 { + let cb_clone = Arc::clone(&cb); + let handle = thread::spawn(move || { + for _ in 0..100 { + cb_clone.record_failure(); + } + }); + handles.push(handle); + } + + // Wait for all threads + for handle in handles { + handle.join().unwrap(); + } + + // Should have recorded 1000 failures + assert_eq!(cb.total_failures(), 1000); + } +} diff --git a/sgl-router/src/core/mod.rs b/sgl-router/src/core/mod.rs index e344190b2..399b04b36 100644 --- a/sgl-router/src/core/mod.rs +++ b/sgl-router/src/core/mod.rs @@ -3,12 +3,17 @@ //! This module contains the fundamental types and traits used throughout the router: //! - Worker trait and implementations //! - Error types +//! - Circuit breaker for reliability //! - Common utilities +pub mod circuit_breaker; pub mod error; pub mod worker; // Re-export commonly used types at the module level +pub use circuit_breaker::{ + CircuitBreaker, CircuitBreakerConfig, CircuitBreakerStats, CircuitState, +}; pub use error::{WorkerError, WorkerResult}; pub use worker::{ start_health_checker, BasicWorker, DPAwareWorker, HealthChecker, Worker, WorkerCollection, diff --git a/sgl-router/src/core/worker.rs b/sgl-router/src/core/worker.rs index 9f865fa8f..d22a69abc 100644 --- a/sgl-router/src/core/worker.rs +++ b/sgl-router/src/core/worker.rs @@ -1,4 +1,4 @@ -use super::{WorkerError, WorkerResult}; +use super::{CircuitBreaker, CircuitBreakerConfig, WorkerError, WorkerResult}; use async_trait::async_trait; use futures; use serde_json; @@ -66,6 +66,19 @@ pub trait Worker: Send + Sync + fmt::Debug { /// Clone the worker (for trait objects) fn clone_worker(&self) -> Box; + /// Get the circuit breaker for this worker + fn circuit_breaker(&self) -> &CircuitBreaker; + + /// Check if the worker is available (healthy + circuit closed/half-open) + fn is_available(&self) -> bool { + self.is_healthy() && self.circuit_breaker().can_execute() + } + + /// Record the outcome of a request to this worker + fn record_outcome(&self, success: bool) { + self.circuit_breaker().record_outcome(success); + } + // === DP-aware methods === /// Check if this worker is DP-aware @@ -172,6 +185,7 @@ pub struct BasicWorker { load_counter: Arc, processed_counter: Arc, healthy: Arc, + circuit_breaker: CircuitBreaker, } impl BasicWorker { @@ -188,6 +202,7 @@ impl BasicWorker { load_counter: Arc::new(AtomicUsize::new(0)), processed_counter: Arc::new(AtomicUsize::new(0)), healthy: Arc::new(AtomicBool::new(true)), + circuit_breaker: CircuitBreaker::new(), } } @@ -201,6 +216,11 @@ impl BasicWorker { self } + pub fn with_circuit_breaker_config(mut self, config: CircuitBreakerConfig) -> Self { + self.circuit_breaker = CircuitBreaker::with_config(config); + self + } + pub fn normalised_url(&self) -> WorkerResult<&str> { if self.url().contains("@") { // Need to extract the URL from "http://host:port@dp_rank" @@ -304,6 +324,10 @@ impl Worker for BasicWorker { fn clone_worker(&self) -> Box { Box::new(self.clone()) } + + fn circuit_breaker(&self) -> &CircuitBreaker { + &self.circuit_breaker + } } /// A DP-aware worker that handles data-parallel routing @@ -421,6 +445,10 @@ impl Worker for DPAwareWorker { Box::new(self.clone()) } + fn circuit_breaker(&self) -> &CircuitBreaker { + self.base_worker.circuit_breaker() + } + // DP-aware specific implementations fn is_dp_aware(&self) -> bool { @@ -469,6 +497,17 @@ impl WorkerFactory { Box::new(BasicWorker::new(url, WorkerType::Regular)) } + /// Create a regular worker with custom circuit breaker configuration + pub fn create_regular_with_config( + url: String, + circuit_breaker_config: CircuitBreakerConfig, + ) -> Box { + Box::new( + BasicWorker::new(url, WorkerType::Regular) + .with_circuit_breaker_config(circuit_breaker_config), + ) + } + /// Create a prefill worker with optional bootstrap port pub fn create_prefill(url: String, bootstrap_port: Option) -> Box { Box::new(BasicWorker::new( @@ -477,11 +516,34 @@ impl WorkerFactory { )) } + /// Create a prefill worker with custom circuit breaker configuration + pub fn create_prefill_with_config( + url: String, + bootstrap_port: Option, + circuit_breaker_config: CircuitBreakerConfig, + ) -> Box { + Box::new( + BasicWorker::new(url, WorkerType::Prefill { bootstrap_port }) + .with_circuit_breaker_config(circuit_breaker_config), + ) + } + /// Create a decode worker pub fn create_decode(url: String) -> Box { Box::new(BasicWorker::new(url, WorkerType::Decode)) } + /// Create a decode worker with custom circuit breaker configuration + pub fn create_decode_with_config( + url: String, + circuit_breaker_config: CircuitBreakerConfig, + ) -> Box { + Box::new( + BasicWorker::new(url, WorkerType::Decode) + .with_circuit_breaker_config(circuit_breaker_config), + ) + } + /// Create workers from URLs with automatic type detection pub fn create_from_urls( regular_urls: Vec, @@ -796,6 +858,7 @@ pub fn start_health_checker( mod tests { use super::*; use std::sync::RwLock; + use std::thread; use std::time::Duration; use tokio::time::timeout; @@ -1574,6 +1637,94 @@ mod tests { assert_eq!(workers[1].url(), "http://w2:8080"); } + // ===== Circuit Breaker Integration Tests ===== + + #[test] + fn test_worker_circuit_breaker() { + let worker = BasicWorker::new("http://test:8080".to_string(), WorkerType::Regular); + + // Initial state should be available + assert!(worker.is_available()); + assert_eq!( + worker.circuit_breaker().state(), + crate::core::CircuitState::Closed + ); + + // Record some failures + worker.record_outcome(false); + worker.record_outcome(false); + + // Still available (default threshold is 5) + assert!(worker.is_available()); + + // Record more failures to open circuit + worker.record_outcome(false); + worker.record_outcome(false); + worker.record_outcome(false); + + // Circuit should be open, worker not available + assert!(!worker.is_available()); + assert!(worker.is_healthy()); // Still healthy + assert!(!worker.circuit_breaker().can_execute()); // But circuit is open + } + + #[test] + fn test_worker_with_circuit_breaker_config() { + let config = crate::core::CircuitBreakerConfig { + failure_threshold: 2, + success_threshold: 1, + timeout_duration: Duration::from_millis(100), + window_duration: Duration::from_secs(60), + }; + + let worker = BasicWorker::new("http://test:8080".to_string(), WorkerType::Regular) + .with_circuit_breaker_config(config); + + // Should open after 2 failures + worker.record_outcome(false); + assert!(worker.is_available()); + worker.record_outcome(false); + assert!(!worker.is_available()); + + // Wait for timeout + thread::sleep(Duration::from_millis(150)); + + // Should be half-open + assert!(worker.is_available()); + assert_eq!( + worker.circuit_breaker().state(), + crate::core::CircuitState::HalfOpen + ); + + // Success should close it + worker.record_outcome(true); + assert_eq!( + worker.circuit_breaker().state(), + crate::core::CircuitState::Closed + ); + } + + #[test] + fn test_dp_aware_worker_circuit_breaker() { + let dp_worker = + DPAwareWorker::new("http://worker:8080".to_string(), 0, 2, WorkerType::Regular); + + // Should have circuit breaker + assert!(dp_worker.is_available()); + + // Record failures + for _ in 0..5 { + dp_worker.record_outcome(false); + } + + // Should not be available + assert!(!dp_worker.is_available()); + assert_eq!( + dp_worker.circuit_breaker().state(), + crate::core::CircuitState::Open + ); + } + // ===== Integration tests ===== #[tokio::test] diff --git a/sgl-router/src/lib.rs b/sgl-router/src/lib.rs index 290fbda9a..680db06a0 100644 --- a/sgl-router/src/lib.rs +++ b/sgl-router/src/lib.rs @@ -147,6 +147,7 @@ impl Router { max_concurrent_requests: self.max_concurrent_requests, cors_allowed_origins: self.cors_allowed_origins.clone(), retry: config::RetryConfig::default(), + circuit_breaker: config::CircuitBreakerConfig::default(), }) } } diff --git a/sgl-router/src/routers/factory.rs b/sgl-router/src/routers/factory.rs index e67ce6650..357007278 100644 --- a/sgl-router/src/routers/factory.rs +++ b/sgl-router/src/routers/factory.rs @@ -51,6 +51,7 @@ impl RouterFactory { ctx.router_config.dp_aware, ctx.router_config.api_key.clone(), ctx.router_config.retry.clone(), + ctx.router_config.circuit_breaker.clone(), )?; Ok(Box::new(router)) @@ -81,6 +82,7 @@ impl RouterFactory { ctx.router_config.worker_startup_timeout_secs, ctx.router_config.worker_startup_check_interval_secs, ctx.router_config.retry.clone(), + ctx.router_config.circuit_breaker.clone(), )?; Ok(Box::new(router)) diff --git a/sgl-router/src/routers/pd_router.rs b/sgl-router/src/routers/pd_router.rs index ab22e1d9d..1815f1bfa 100644 --- a/sgl-router/src/routers/pd_router.rs +++ b/sgl-router/src/routers/pd_router.rs @@ -1,8 +1,8 @@ // PD (Prefill-Decode) Router Implementation // This module handles routing for disaggregated prefill-decode systems use super::pd_types::{api_path, PDRouterError}; -use crate::config::types::RetryConfig; -use crate::core::{HealthChecker, Worker, WorkerFactory, WorkerLoadGuard}; +use crate::config::types::{CircuitBreakerConfig as ConfigCircuitBreakerConfig, RetryConfig}; +use crate::core::{CircuitBreakerConfig, HealthChecker, Worker, WorkerFactory, WorkerLoadGuard}; use crate::metrics::RouterMetrics; use crate::openai_api_types::{ChatCompletionRequest, CompletionRequest, GenerateRequest}; use crate::policies::LoadBalancingPolicy; @@ -41,6 +41,7 @@ pub struct PDRouter { // Dedicated client for prefill fire-and-forget (non-logprob) requests pub prefill_client: Client, pub retry_config: RetryConfig, + pub circuit_breaker_config: CircuitBreakerConfig, _prefill_health_checker: Option, _decode_health_checker: Option, } @@ -68,8 +69,12 @@ impl PDRouter { // Wait for the new server to be healthy self.wait_for_server_health(&url).await?; - // Create Worker for the new prefill server - let worker = WorkerFactory::create_prefill(url.clone(), bootstrap_port); + // Create Worker for the new prefill server with circuit breaker configuration + let worker = WorkerFactory::create_prefill_with_config( + url.clone(), + bootstrap_port, + self.circuit_breaker_config.clone(), + ); // Add to prefill workers list let mut workers = self @@ -99,8 +104,11 @@ impl PDRouter { // Wait for the new server to be healthy self.wait_for_server_health(&url).await?; - // Create Worker for the new decode server - let worker = WorkerFactory::create_decode(url.clone()); + // Create Worker for the new decode server with circuit breaker configuration + let worker = WorkerFactory::create_decode_with_config( + url.clone(), + self.circuit_breaker_config.clone(), + ); // Add to decode workers list let mut workers = self @@ -189,16 +197,31 @@ impl PDRouter { timeout_secs: u64, interval_secs: u64, retry_config: RetryConfig, + circuit_breaker_config: ConfigCircuitBreakerConfig, ) -> Result { + // Convert config CircuitBreakerConfig to core CircuitBreakerConfig + let core_cb_config = CircuitBreakerConfig { + failure_threshold: circuit_breaker_config.failure_threshold, + success_threshold: circuit_breaker_config.success_threshold, + timeout_duration: std::time::Duration::from_secs( + circuit_breaker_config.timeout_duration_secs, + ), + window_duration: std::time::Duration::from_secs( + circuit_breaker_config.window_duration_secs, + ), + }; + // Convert URLs to Worker trait objects let prefill_workers: Vec> = prefill_urls .into_iter() - .map(|(url, port)| WorkerFactory::create_prefill(url, port)) + .map(|(url, port)| { + WorkerFactory::create_prefill_with_config(url, port, core_cb_config.clone()) + }) .collect(); let decode_workers: Vec> = decode_urls .into_iter() - .map(WorkerFactory::create_decode) + .map(|url| WorkerFactory::create_decode_with_config(url, core_cb_config.clone())) .collect(); // Wait for PD workers to be healthy (skip if empty - for service discovery mode) @@ -280,6 +303,7 @@ impl PDRouter { client, prefill_client, retry_config, + circuit_breaker_config: core_cb_config, _prefill_health_checker: Some(prefill_health_checker), _decode_health_checker: Some(decode_health_checker), }) @@ -1848,6 +1872,7 @@ mod tests { client: Client::new(), prefill_client: Client::new(), retry_config: RetryConfig::default(), + circuit_breaker_config: CircuitBreakerConfig::default(), _prefill_health_checker: None, _decode_health_checker: None, } diff --git a/sgl-router/src/routers/router.rs b/sgl-router/src/routers/router.rs index 933728a4f..894629b9b 100644 --- a/sgl-router/src/routers/router.rs +++ b/sgl-router/src/routers/router.rs @@ -1,5 +1,5 @@ -use crate::config::types::RetryConfig; -use crate::core::{HealthChecker, Worker, WorkerFactory}; +use crate::config::types::{CircuitBreakerConfig as ConfigCircuitBreakerConfig, RetryConfig}; +use crate::core::{CircuitBreakerConfig, HealthChecker, Worker, WorkerFactory}; use crate::metrics::RouterMetrics; use crate::openai_api_types::{ChatCompletionRequest, CompletionRequest, GenerateRequest}; use crate::policies::LoadBalancingPolicy; @@ -42,6 +42,7 @@ pub struct Router { dp_aware: bool, api_key: Option, retry_config: RetryConfig, + circuit_breaker_config: CircuitBreakerConfig, _worker_loads: Arc>>, _load_monitor_handle: Option>>, _health_checker: Option, @@ -58,6 +59,7 @@ impl Router { dp_aware: bool, api_key: Option, retry_config: RetryConfig, + circuit_breaker_config: ConfigCircuitBreakerConfig, ) -> Result { // Update active workers gauge RouterMetrics::set_active_workers(worker_urls.len()); @@ -75,10 +77,24 @@ impl Router { worker_urls }; + // Convert config CircuitBreakerConfig to core CircuitBreakerConfig + let core_cb_config = CircuitBreakerConfig { + failure_threshold: circuit_breaker_config.failure_threshold, + success_threshold: circuit_breaker_config.success_threshold, + timeout_duration: std::time::Duration::from_secs( + circuit_breaker_config.timeout_duration_secs, + ), + window_duration: std::time::Duration::from_secs( + circuit_breaker_config.window_duration_secs, + ), + }; + // Create Worker trait objects from URLs let workers: Vec> = worker_urls .iter() - .map(|url| WorkerFactory::create_regular(url.clone())) + .map(|url| { + WorkerFactory::create_regular_with_config(url.clone(), core_cb_config.clone()) + }) .collect(); // Initialize policy with workers if needed (e.g., for cache-aware) @@ -125,6 +141,7 @@ impl Router { dp_aware, api_key, retry_config, + circuit_breaker_config: core_cb_config, _worker_loads: worker_loads, _load_monitor_handle: load_monitor_handle, _health_checker: Some(health_checker), @@ -752,7 +769,10 @@ impl Router { continue; } info!("Added worker: {}", dp_url); - let new_worker = WorkerFactory::create_regular(dp_url.to_string()); + let new_worker = WorkerFactory::create_regular_with_config( + dp_url.to_string(), + self.circuit_breaker_config.clone(), + ); workers_guard.push(new_worker); worker_added = true; } @@ -764,7 +784,10 @@ impl Router { return Err(format!("Worker {} already exists", worker_url)); } info!("Added worker: {}", worker_url); - let new_worker = WorkerFactory::create_regular(worker_url.to_string()); + let new_worker = WorkerFactory::create_regular_with_config( + worker_url.to_string(), + self.circuit_breaker_config.clone(), + ); workers_guard.push(new_worker); } @@ -1223,6 +1246,7 @@ mod tests { api_key: None, client: Client::new(), retry_config: RetryConfig::default(), + circuit_breaker_config: CircuitBreakerConfig::default(), _worker_loads: Arc::new(rx), _load_monitor_handle: None, _health_checker: None, diff --git a/sgl-router/src/service_discovery.rs b/sgl-router/src/service_discovery.rs index 32c14d868..5a848cf98 100644 --- a/sgl-router/src/service_discovery.rs +++ b/sgl-router/src/service_discovery.rs @@ -589,6 +589,7 @@ mod tests { false, None, crate::config::types::RetryConfig::default(), + crate::config::types::CircuitBreakerConfig::default(), ) .unwrap(); Arc::new(router) as Arc diff --git a/sgl-router/tests/api_endpoints_test.rs b/sgl-router/tests/api_endpoints_test.rs index a4115926a..68b63f0b3 100644 --- a/sgl-router/tests/api_endpoints_test.rs +++ b/sgl-router/tests/api_endpoints_test.rs @@ -8,7 +8,9 @@ use axum::{ use common::mock_worker::{HealthStatus, MockWorker, MockWorkerConfig, WorkerType}; use reqwest::Client; use serde_json::json; -use sglang_router_rs::config::{PolicyConfig, RetryConfig, RouterConfig, RoutingMode}; +use sglang_router_rs::config::{ + CircuitBreakerConfig, PolicyConfig, RetryConfig, RouterConfig, RoutingMode, +}; use sglang_router_rs::routers::{RouterFactory, RouterTrait}; use std::sync::Arc; use tower::ServiceExt; @@ -45,6 +47,7 @@ impl TestContext { max_concurrent_requests: 64, cors_allowed_origins: vec![], retry: RetryConfig::default(), + circuit_breaker: CircuitBreakerConfig::default(), }; Self::new_with_config(config, worker_configs).await @@ -1087,6 +1090,7 @@ mod error_tests { max_concurrent_requests: 64, cors_allowed_origins: vec![], retry: RetryConfig::default(), + circuit_breaker: CircuitBreakerConfig::default(), }; let ctx = TestContext::new_with_config( @@ -1434,6 +1438,7 @@ mod pd_mode_tests { max_concurrent_requests: 64, cors_allowed_origins: vec![], retry: RetryConfig::default(), + circuit_breaker: CircuitBreakerConfig::default(), }; // Create app context @@ -1588,6 +1593,7 @@ mod request_id_tests { max_concurrent_requests: 64, cors_allowed_origins: vec![], retry: RetryConfig::default(), + circuit_breaker: CircuitBreakerConfig::default(), }; let ctx = TestContext::new_with_config( diff --git a/sgl-router/tests/request_formats_test.rs b/sgl-router/tests/request_formats_test.rs index 4e9e1562d..7ae7ab383 100644 --- a/sgl-router/tests/request_formats_test.rs +++ b/sgl-router/tests/request_formats_test.rs @@ -3,7 +3,9 @@ mod common; use common::mock_worker::{HealthStatus, MockWorker, MockWorkerConfig, WorkerType}; use reqwest::Client; use serde_json::json; -use sglang_router_rs::config::{PolicyConfig, RetryConfig, RouterConfig, RoutingMode}; +use sglang_router_rs::config::{ + CircuitBreakerConfig, PolicyConfig, RetryConfig, RouterConfig, RoutingMode, +}; use sglang_router_rs::routers::{RouterFactory, RouterTrait}; use std::sync::Arc; @@ -36,6 +38,7 @@ impl TestContext { max_concurrent_requests: 64, cors_allowed_origins: vec![], retry: RetryConfig::default(), + circuit_breaker: CircuitBreakerConfig::default(), }; let mut workers = Vec::new(); diff --git a/sgl-router/tests/streaming_tests.rs b/sgl-router/tests/streaming_tests.rs index dcf0ffc93..94abc739b 100644 --- a/sgl-router/tests/streaming_tests.rs +++ b/sgl-router/tests/streaming_tests.rs @@ -4,7 +4,9 @@ use common::mock_worker::{HealthStatus, MockWorker, MockWorkerConfig, WorkerType use futures_util::StreamExt; use reqwest::Client; use serde_json::json; -use sglang_router_rs::config::{PolicyConfig, RetryConfig, RouterConfig, RoutingMode}; +use sglang_router_rs::config::{ + CircuitBreakerConfig, PolicyConfig, RetryConfig, RouterConfig, RoutingMode, +}; use sglang_router_rs::routers::{RouterFactory, RouterTrait}; use std::sync::Arc; @@ -37,6 +39,7 @@ impl TestContext { max_concurrent_requests: 64, cors_allowed_origins: vec![], retry: RetryConfig::default(), + circuit_breaker: CircuitBreakerConfig::default(), }; let mut workers = Vec::new(); diff --git a/sgl-router/tests/test_pd_routing.rs b/sgl-router/tests/test_pd_routing.rs index 20b37aaa8..574f0e88e 100644 --- a/sgl-router/tests/test_pd_routing.rs +++ b/sgl-router/tests/test_pd_routing.rs @@ -2,7 +2,9 @@ mod test_pd_routing { use rand::Rng; use serde_json::json; - use sglang_router_rs::config::{PolicyConfig, RetryConfig, RouterConfig, RoutingMode}; + use sglang_router_rs::config::{ + CircuitBreakerConfig, PolicyConfig, RetryConfig, RouterConfig, RoutingMode, + }; use sglang_router_rs::core::{WorkerFactory, WorkerType}; use sglang_router_rs::routers::pd_types::get_hostname; use sglang_router_rs::routers::pd_types::PDSelectionPolicy; @@ -179,6 +181,7 @@ mod test_pd_routing { max_concurrent_requests: 64, cors_allowed_origins: vec![], retry: RetryConfig::default(), + circuit_breaker: CircuitBreakerConfig::default(), }; // Router creation will fail due to health checks, but config should be valid