[router] router circuit breaker core (#8941)
This commit is contained in:
@@ -41,6 +41,8 @@ pub struct RouterConfig {
|
|||||||
pub cors_allowed_origins: Vec<String>,
|
pub cors_allowed_origins: Vec<String>,
|
||||||
/// Retry configuration
|
/// Retry configuration
|
||||||
pub retry: RetryConfig,
|
pub retry: RetryConfig,
|
||||||
|
/// Circuit breaker configuration
|
||||||
|
pub circuit_breaker: CircuitBreakerConfig,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Routing mode configuration
|
/// Routing mode configuration
|
||||||
@@ -208,6 +210,30 @@ impl Default for RetryConfig {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Circuit breaker configuration for worker reliability
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct CircuitBreakerConfig {
|
||||||
|
/// Number of consecutive failures before opening circuit
|
||||||
|
pub failure_threshold: u32,
|
||||||
|
/// Number of consecutive successes before closing circuit
|
||||||
|
pub success_threshold: u32,
|
||||||
|
/// Time before attempting to recover from open state (in seconds)
|
||||||
|
pub timeout_duration_secs: u64,
|
||||||
|
/// Window duration for failure tracking (in seconds)
|
||||||
|
pub window_duration_secs: u64,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for CircuitBreakerConfig {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self {
|
||||||
|
failure_threshold: 5,
|
||||||
|
success_threshold: 2,
|
||||||
|
timeout_duration_secs: 30,
|
||||||
|
window_duration_secs: 60,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Metrics configuration
|
/// Metrics configuration
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
pub struct MetricsConfig {
|
pub struct MetricsConfig {
|
||||||
@@ -249,6 +275,7 @@ impl Default for RouterConfig {
|
|||||||
max_concurrent_requests: 64,
|
max_concurrent_requests: 64,
|
||||||
cors_allowed_origins: vec![],
|
cors_allowed_origins: vec![],
|
||||||
retry: RetryConfig::default(),
|
retry: RetryConfig::default(),
|
||||||
|
circuit_breaker: CircuitBreakerConfig::default(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -360,6 +387,7 @@ mod tests {
|
|||||||
max_concurrent_requests: 64,
|
max_concurrent_requests: 64,
|
||||||
cors_allowed_origins: vec![],
|
cors_allowed_origins: vec![],
|
||||||
retry: RetryConfig::default(),
|
retry: RetryConfig::default(),
|
||||||
|
circuit_breaker: CircuitBreakerConfig::default(),
|
||||||
};
|
};
|
||||||
|
|
||||||
let json = serde_json::to_string(&config).unwrap();
|
let json = serde_json::to_string(&config).unwrap();
|
||||||
@@ -788,6 +816,7 @@ mod tests {
|
|||||||
max_concurrent_requests: 64,
|
max_concurrent_requests: 64,
|
||||||
cors_allowed_origins: vec![],
|
cors_allowed_origins: vec![],
|
||||||
retry: RetryConfig::default(),
|
retry: RetryConfig::default(),
|
||||||
|
circuit_breaker: CircuitBreakerConfig::default(),
|
||||||
};
|
};
|
||||||
|
|
||||||
assert!(config.mode.is_pd_mode());
|
assert!(config.mode.is_pd_mode());
|
||||||
@@ -840,6 +869,7 @@ mod tests {
|
|||||||
max_concurrent_requests: 64,
|
max_concurrent_requests: 64,
|
||||||
cors_allowed_origins: vec![],
|
cors_allowed_origins: vec![],
|
||||||
retry: RetryConfig::default(),
|
retry: RetryConfig::default(),
|
||||||
|
circuit_breaker: CircuitBreakerConfig::default(),
|
||||||
};
|
};
|
||||||
|
|
||||||
assert!(!config.mode.is_pd_mode());
|
assert!(!config.mode.is_pd_mode());
|
||||||
@@ -888,6 +918,7 @@ mod tests {
|
|||||||
max_concurrent_requests: 64,
|
max_concurrent_requests: 64,
|
||||||
cors_allowed_origins: vec![],
|
cors_allowed_origins: vec![],
|
||||||
retry: RetryConfig::default(),
|
retry: RetryConfig::default(),
|
||||||
|
circuit_breaker: CircuitBreakerConfig::default(),
|
||||||
};
|
};
|
||||||
|
|
||||||
assert!(config.has_service_discovery());
|
assert!(config.has_service_discovery());
|
||||||
|
|||||||
545
sgl-router/src/core/circuit_breaker.rs
Normal file
545
sgl-router/src/core/circuit_breaker.rs
Normal file
@@ -0,0 +1,545 @@
|
|||||||
|
use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
|
||||||
|
use std::sync::{Arc, RwLock};
|
||||||
|
use std::time::{Duration, Instant};
|
||||||
|
|
||||||
|
/// Tunable thresholds and timings for a [`CircuitBreaker`].
#[derive(Debug, Clone)]
pub struct CircuitBreakerConfig {
    /// Consecutive failures that open the circuit.
    pub failure_threshold: u32,
    /// Consecutive successes in the half-open state that close the circuit.
    pub success_threshold: u32,
    /// How long the circuit stays open before a half-open attempt is allowed.
    pub timeout_duration: Duration,
    /// Time window intended for failure counting.
    pub window_duration: Duration,
}
|
||||||
|
|
||||||
|
impl Default for CircuitBreakerConfig {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self {
|
||||||
|
failure_threshold: 5,
|
||||||
|
success_threshold: 2,
|
||||||
|
timeout_duration: Duration::from_secs(30),
|
||||||
|
window_duration: Duration::from_secs(60),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// The three states a circuit breaker can be in.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CircuitState {
    /// Requests flow normally.
    Closed,
    /// Requests are rejected.
    Open,
    /// Recovery is being probed; requests are let through to test the service.
    HalfOpen,
}
|
||||||
|
|
||||||
|
impl std::fmt::Display for CircuitState {
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
match self {
|
||||||
|
CircuitState::Closed => write!(f, "Closed"),
|
||||||
|
CircuitState::Open => write!(f, "Open"),
|
||||||
|
CircuitState::HalfOpen => write!(f, "HalfOpen"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Circuit breaker implementation
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct CircuitBreaker {
|
||||||
|
state: Arc<RwLock<CircuitState>>,
|
||||||
|
consecutive_failures: Arc<AtomicU32>,
|
||||||
|
consecutive_successes: Arc<AtomicU32>,
|
||||||
|
total_failures: Arc<AtomicU64>,
|
||||||
|
total_successes: Arc<AtomicU64>,
|
||||||
|
last_failure_time: Arc<RwLock<Option<Instant>>>,
|
||||||
|
last_state_change: Arc<RwLock<Instant>>,
|
||||||
|
config: CircuitBreakerConfig,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl CircuitBreaker {
|
||||||
|
/// Create a new circuit breaker with default configuration
|
||||||
|
pub fn new() -> Self {
|
||||||
|
Self::with_config(CircuitBreakerConfig::default())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create a new circuit breaker with custom configuration
|
||||||
|
pub fn with_config(config: CircuitBreakerConfig) -> Self {
|
||||||
|
Self {
|
||||||
|
state: Arc::new(RwLock::new(CircuitState::Closed)),
|
||||||
|
consecutive_failures: Arc::new(AtomicU32::new(0)),
|
||||||
|
consecutive_successes: Arc::new(AtomicU32::new(0)),
|
||||||
|
total_failures: Arc::new(AtomicU64::new(0)),
|
||||||
|
total_successes: Arc::new(AtomicU64::new(0)),
|
||||||
|
last_failure_time: Arc::new(RwLock::new(None)),
|
||||||
|
last_state_change: Arc::new(RwLock::new(Instant::now())),
|
||||||
|
config,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Check if a request can be executed
|
||||||
|
pub fn can_execute(&self) -> bool {
|
||||||
|
// First check if we need to transition from Open to HalfOpen
|
||||||
|
self.check_and_update_state();
|
||||||
|
|
||||||
|
let state = *self.state.read().unwrap();
|
||||||
|
match state {
|
||||||
|
CircuitState::Closed => true,
|
||||||
|
CircuitState::Open => false,
|
||||||
|
CircuitState::HalfOpen => true, // Allow limited requests in half-open state
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get the current state
|
||||||
|
pub fn state(&self) -> CircuitState {
|
||||||
|
self.check_and_update_state();
|
||||||
|
*self.state.read().unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Record the outcome of a request
|
||||||
|
pub fn record_outcome(&self, success: bool) {
|
||||||
|
if success {
|
||||||
|
self.record_success();
|
||||||
|
} else {
|
||||||
|
self.record_failure();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Record a successful request
|
||||||
|
pub fn record_success(&self) {
|
||||||
|
self.total_successes.fetch_add(1, Ordering::Relaxed);
|
||||||
|
self.consecutive_failures.store(0, Ordering::Release);
|
||||||
|
let successes = self.consecutive_successes.fetch_add(1, Ordering::AcqRel) + 1;
|
||||||
|
|
||||||
|
let current_state = *self.state.read().unwrap();
|
||||||
|
|
||||||
|
match current_state {
|
||||||
|
CircuitState::HalfOpen => {
|
||||||
|
// Check if we've reached the success threshold to close the circuit
|
||||||
|
if successes >= self.config.success_threshold {
|
||||||
|
self.transition_to(CircuitState::Closed);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
CircuitState::Closed => {
|
||||||
|
// Already closed, nothing to do
|
||||||
|
}
|
||||||
|
CircuitState::Open => {
|
||||||
|
// Shouldn't happen, but if it does, stay open
|
||||||
|
tracing::warn!("Success recorded while circuit is open");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Record a failed request
|
||||||
|
pub fn record_failure(&self) {
|
||||||
|
self.total_failures.fetch_add(1, Ordering::Relaxed);
|
||||||
|
self.consecutive_successes.store(0, Ordering::Release);
|
||||||
|
let failures = self.consecutive_failures.fetch_add(1, Ordering::AcqRel) + 1;
|
||||||
|
|
||||||
|
// Update last failure time
|
||||||
|
{
|
||||||
|
let mut last_failure = self.last_failure_time.write().unwrap();
|
||||||
|
*last_failure = Some(Instant::now());
|
||||||
|
}
|
||||||
|
|
||||||
|
let current_state = *self.state.read().unwrap();
|
||||||
|
|
||||||
|
match current_state {
|
||||||
|
CircuitState::Closed => {
|
||||||
|
// Check if we've reached the failure threshold to open the circuit
|
||||||
|
if failures >= self.config.failure_threshold {
|
||||||
|
self.transition_to(CircuitState::Open);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
CircuitState::HalfOpen => {
|
||||||
|
// Single failure in half-open state reopens the circuit
|
||||||
|
self.transition_to(CircuitState::Open);
|
||||||
|
}
|
||||||
|
CircuitState::Open => {
|
||||||
|
// Already open, nothing to do
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Check and update state based on timeout
|
||||||
|
fn check_and_update_state(&self) {
|
||||||
|
let current_state = *self.state.read().unwrap();
|
||||||
|
|
||||||
|
if current_state == CircuitState::Open {
|
||||||
|
// Check if timeout has expired
|
||||||
|
let last_change = *self.last_state_change.read().unwrap();
|
||||||
|
if last_change.elapsed() >= self.config.timeout_duration {
|
||||||
|
self.transition_to(CircuitState::HalfOpen);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Transition to a new state
|
||||||
|
fn transition_to(&self, new_state: CircuitState) {
|
||||||
|
let mut state = self.state.write().unwrap();
|
||||||
|
let old_state = *state;
|
||||||
|
|
||||||
|
if old_state != new_state {
|
||||||
|
*state = new_state;
|
||||||
|
|
||||||
|
// Update last state change time
|
||||||
|
let mut last_change = self.last_state_change.write().unwrap();
|
||||||
|
*last_change = Instant::now();
|
||||||
|
|
||||||
|
// Reset counters based on transition
|
||||||
|
match new_state {
|
||||||
|
CircuitState::Closed => {
|
||||||
|
self.consecutive_failures.store(0, Ordering::Release);
|
||||||
|
self.consecutive_successes.store(0, Ordering::Release);
|
||||||
|
}
|
||||||
|
CircuitState::Open => {
|
||||||
|
self.consecutive_successes.store(0, Ordering::Release);
|
||||||
|
}
|
||||||
|
CircuitState::HalfOpen => {
|
||||||
|
self.consecutive_failures.store(0, Ordering::Release);
|
||||||
|
self.consecutive_successes.store(0, Ordering::Release);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
tracing::info!(
|
||||||
|
"Circuit breaker state transition: {} -> {}",
|
||||||
|
old_state,
|
||||||
|
new_state
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get the number of consecutive failures
|
||||||
|
pub fn failure_count(&self) -> u32 {
|
||||||
|
self.consecutive_failures.load(Ordering::Acquire)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get the number of consecutive successes
|
||||||
|
pub fn success_count(&self) -> u32 {
|
||||||
|
self.consecutive_successes.load(Ordering::Acquire)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get total failures
|
||||||
|
pub fn total_failures(&self) -> u64 {
|
||||||
|
self.total_failures.load(Ordering::Relaxed)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get total successes
|
||||||
|
pub fn total_successes(&self) -> u64 {
|
||||||
|
self.total_successes.load(Ordering::Relaxed)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get time since last failure
|
||||||
|
pub fn time_since_last_failure(&self) -> Option<Duration> {
|
||||||
|
self.last_failure_time.read().unwrap().map(|t| t.elapsed())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get time since last state change
|
||||||
|
pub fn time_since_last_state_change(&self) -> Duration {
|
||||||
|
self.last_state_change.read().unwrap().elapsed()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Check if the circuit is in a half-open state
|
||||||
|
pub fn is_half_open(&self) -> bool {
|
||||||
|
self.state() == CircuitState::HalfOpen
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Record a test success (for health check probing)
|
||||||
|
pub fn record_test_success(&self) {
|
||||||
|
if self.is_half_open() {
|
||||||
|
self.record_success();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Record a test failure (for health check probing)
|
||||||
|
pub fn record_test_failure(&self) {
|
||||||
|
if self.is_half_open() {
|
||||||
|
self.record_failure();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Reset the circuit breaker to closed state
|
||||||
|
pub fn reset(&self) {
|
||||||
|
self.transition_to(CircuitState::Closed);
|
||||||
|
self.consecutive_failures.store(0, Ordering::Release);
|
||||||
|
self.consecutive_successes.store(0, Ordering::Release);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Force the circuit to open (for manual intervention)
|
||||||
|
pub fn force_open(&self) {
|
||||||
|
self.transition_to(CircuitState::Open);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get circuit breaker statistics
|
||||||
|
pub fn stats(&self) -> CircuitBreakerStats {
|
||||||
|
CircuitBreakerStats {
|
||||||
|
state: self.state(),
|
||||||
|
consecutive_failures: self.failure_count(),
|
||||||
|
consecutive_successes: self.success_count(),
|
||||||
|
total_failures: self.total_failures(),
|
||||||
|
total_successes: self.total_successes(),
|
||||||
|
time_since_last_failure: self.time_since_last_failure(),
|
||||||
|
time_since_last_state_change: self.time_since_last_state_change(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Clone for CircuitBreaker {
|
||||||
|
fn clone(&self) -> Self {
|
||||||
|
Self {
|
||||||
|
state: Arc::clone(&self.state),
|
||||||
|
consecutive_failures: Arc::clone(&self.consecutive_failures),
|
||||||
|
consecutive_successes: Arc::clone(&self.consecutive_successes),
|
||||||
|
total_failures: Arc::clone(&self.total_failures),
|
||||||
|
total_successes: Arc::clone(&self.total_successes),
|
||||||
|
last_failure_time: Arc::clone(&self.last_failure_time),
|
||||||
|
last_state_change: Arc::clone(&self.last_state_change),
|
||||||
|
config: self.config.clone(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for CircuitBreaker {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self::new()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Circuit breaker statistics
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct CircuitBreakerStats {
|
||||||
|
pub state: CircuitState,
|
||||||
|
pub consecutive_failures: u32,
|
||||||
|
pub consecutive_successes: u32,
|
||||||
|
pub total_failures: u64,
|
||||||
|
pub total_successes: u64,
|
||||||
|
pub time_since_last_failure: Option<Duration>,
|
||||||
|
pub time_since_last_state_change: Duration,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;

    /// Build a breaker from `config`, record one failure and check that the
    /// circuit opened. Only meaningful for configs with `failure_threshold: 1`.
    fn tripped(config: CircuitBreakerConfig) -> CircuitBreaker {
        let breaker = CircuitBreaker::with_config(config);
        breaker.record_failure();
        assert_eq!(breaker.state(), CircuitState::Open);
        breaker
    }

    #[test]
    fn test_circuit_breaker_initial_state() {
        let breaker = CircuitBreaker::new();

        assert_eq!(breaker.state(), CircuitState::Closed);
        assert!(breaker.can_execute());
        assert_eq!(breaker.failure_count(), 0);
        assert_eq!(breaker.success_count(), 0);
    }

    #[test]
    fn test_circuit_opens_on_threshold() {
        let breaker = CircuitBreaker::with_config(CircuitBreakerConfig {
            failure_threshold: 3,
            ..Default::default()
        });

        // The circuit stays closed right up until the threshold is hit.
        for _ in 0..3 {
            assert_eq!(breaker.state(), CircuitState::Closed);
            breaker.record_failure();
        }

        // The third failure opens it.
        assert_eq!(breaker.state(), CircuitState::Open);
        assert!(!breaker.can_execute());
        assert_eq!(breaker.failure_count(), 3);
    }

    #[test]
    fn test_circuit_half_open_after_timeout() {
        let breaker = tripped(CircuitBreakerConfig {
            failure_threshold: 1,
            timeout_duration: Duration::from_millis(100),
            ..Default::default()
        });

        // Let the open timeout expire.
        thread::sleep(Duration::from_millis(150));

        assert_eq!(breaker.state(), CircuitState::HalfOpen);
        assert!(breaker.can_execute());
    }

    #[test]
    fn test_circuit_closes_on_success_threshold() {
        let breaker = tripped(CircuitBreakerConfig {
            failure_threshold: 1,
            success_threshold: 2,
            timeout_duration: Duration::from_millis(50),
            ..Default::default()
        });

        // Let the open timeout expire.
        thread::sleep(Duration::from_millis(100));
        assert_eq!(breaker.state(), CircuitState::HalfOpen);

        // One success is not enough...
        breaker.record_success();
        assert_eq!(breaker.state(), CircuitState::HalfOpen);

        // ...the second one closes the circuit.
        breaker.record_success();
        assert_eq!(breaker.state(), CircuitState::Closed);
        assert!(breaker.can_execute());
    }

    #[test]
    fn test_circuit_reopens_on_half_open_failure() {
        let breaker = tripped(CircuitBreakerConfig {
            failure_threshold: 1,
            timeout_duration: Duration::from_millis(50),
            ..Default::default()
        });

        // Let the open timeout expire.
        thread::sleep(Duration::from_millis(100));
        assert_eq!(breaker.state(), CircuitState::HalfOpen);

        // Any failure while half-open re-opens immediately.
        breaker.record_failure();
        assert_eq!(breaker.state(), CircuitState::Open);
        assert!(!breaker.can_execute());
    }

    #[test]
    fn test_success_resets_failure_count() {
        let breaker = CircuitBreaker::with_config(CircuitBreakerConfig {
            failure_threshold: 3,
            ..Default::default()
        });

        breaker.record_failure();
        breaker.record_failure();
        assert_eq!(breaker.failure_count(), 2);

        // A success wipes the failure run.
        breaker.record_success();
        assert_eq!(breaker.failure_count(), 0);
        assert_eq!(breaker.success_count(), 1);

        // Two more failures are again below the threshold.
        breaker.record_failure();
        breaker.record_failure();
        assert_eq!(breaker.state(), CircuitState::Closed);
    }

    #[test]
    fn test_manual_reset() {
        let breaker = tripped(CircuitBreakerConfig {
            failure_threshold: 1,
            ..Default::default()
        });

        breaker.reset();

        assert_eq!(breaker.state(), CircuitState::Closed);
        assert_eq!(breaker.failure_count(), 0);
        assert_eq!(breaker.success_count(), 0);
    }

    #[test]
    fn test_force_open() {
        let breaker = CircuitBreaker::new();
        assert_eq!(breaker.state(), CircuitState::Closed);

        breaker.force_open();

        assert_eq!(breaker.state(), CircuitState::Open);
        assert!(!breaker.can_execute());
    }

    #[test]
    fn test_stats() {
        let breaker = CircuitBreaker::with_config(CircuitBreakerConfig {
            failure_threshold: 2,
            ..Default::default()
        });

        breaker.record_success();
        breaker.record_failure();
        breaker.record_failure();

        let snapshot = breaker.stats();
        assert_eq!(snapshot.state, CircuitState::Open);
        assert_eq!(snapshot.consecutive_failures, 2);
        assert_eq!(snapshot.consecutive_successes, 0);
        assert_eq!(snapshot.total_failures, 2);
        assert_eq!(snapshot.total_successes, 1);
    }

    #[test]
    fn test_clone() {
        let first = CircuitBreaker::new();
        first.record_failure();

        let second = first.clone();
        assert_eq!(second.failure_count(), 1);

        // Clones share state: an update through one is visible via the other.
        first.record_failure();
        assert_eq!(second.failure_count(), 2);
    }

    #[test]
    fn test_thread_safety() {
        use std::sync::Arc;

        let breaker = Arc::new(CircuitBreaker::new());

        // Ten threads, one hundred failures each.
        let workers: Vec<_> = (0..10)
            .map(|_| {
                let shared = Arc::clone(&breaker);
                thread::spawn(move || {
                    for _ in 0..100 {
                        shared.record_failure();
                    }
                })
            })
            .collect();

        for worker in workers {
            worker.join().unwrap();
        }

        // Every failure must have been counted.
        assert_eq!(breaker.total_failures(), 1000);
    }
}
|
||||||
@@ -3,12 +3,17 @@
|
|||||||
//! This module contains the fundamental types and traits used throughout the router:
|
//! This module contains the fundamental types and traits used throughout the router:
|
||||||
//! - Worker trait and implementations
|
//! - Worker trait and implementations
|
||||||
//! - Error types
|
//! - Error types
|
||||||
|
//! - Circuit breaker for reliability
|
||||||
//! - Common utilities
|
//! - Common utilities
|
||||||
|
|
||||||
|
pub mod circuit_breaker;
|
||||||
pub mod error;
|
pub mod error;
|
||||||
pub mod worker;
|
pub mod worker;
|
||||||
|
|
||||||
// Re-export commonly used types at the module level
|
// Re-export commonly used types at the module level
|
||||||
|
pub use circuit_breaker::{
|
||||||
|
CircuitBreaker, CircuitBreakerConfig, CircuitBreakerStats, CircuitState,
|
||||||
|
};
|
||||||
pub use error::{WorkerError, WorkerResult};
|
pub use error::{WorkerError, WorkerResult};
|
||||||
pub use worker::{
|
pub use worker::{
|
||||||
start_health_checker, BasicWorker, DPAwareWorker, HealthChecker, Worker, WorkerCollection,
|
start_health_checker, BasicWorker, DPAwareWorker, HealthChecker, Worker, WorkerCollection,
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
use super::{WorkerError, WorkerResult};
|
use super::{CircuitBreaker, CircuitBreakerConfig, WorkerError, WorkerResult};
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use futures;
|
use futures;
|
||||||
use serde_json;
|
use serde_json;
|
||||||
@@ -66,6 +66,19 @@ pub trait Worker: Send + Sync + fmt::Debug {
|
|||||||
/// Clone the worker (for trait objects)
|
/// Clone the worker (for trait objects)
|
||||||
fn clone_worker(&self) -> Box<dyn Worker>;
|
fn clone_worker(&self) -> Box<dyn Worker>;
|
||||||
|
|
||||||
|
/// Get the circuit breaker for this worker
|
||||||
|
fn circuit_breaker(&self) -> &CircuitBreaker;
|
||||||
|
|
||||||
|
/// Check if the worker is available (healthy + circuit closed/half-open)
|
||||||
|
fn is_available(&self) -> bool {
|
||||||
|
self.is_healthy() && self.circuit_breaker().can_execute()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Record the outcome of a request to this worker
|
||||||
|
fn record_outcome(&self, success: bool) {
|
||||||
|
self.circuit_breaker().record_outcome(success);
|
||||||
|
}
|
||||||
|
|
||||||
// === DP-aware methods ===
|
// === DP-aware methods ===
|
||||||
|
|
||||||
/// Check if this worker is DP-aware
|
/// Check if this worker is DP-aware
|
||||||
@@ -172,6 +185,7 @@ pub struct BasicWorker {
|
|||||||
load_counter: Arc<AtomicUsize>,
|
load_counter: Arc<AtomicUsize>,
|
||||||
processed_counter: Arc<AtomicUsize>,
|
processed_counter: Arc<AtomicUsize>,
|
||||||
healthy: Arc<AtomicBool>,
|
healthy: Arc<AtomicBool>,
|
||||||
|
circuit_breaker: CircuitBreaker,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl BasicWorker {
|
impl BasicWorker {
|
||||||
@@ -188,6 +202,7 @@ impl BasicWorker {
|
|||||||
load_counter: Arc::new(AtomicUsize::new(0)),
|
load_counter: Arc::new(AtomicUsize::new(0)),
|
||||||
processed_counter: Arc::new(AtomicUsize::new(0)),
|
processed_counter: Arc::new(AtomicUsize::new(0)),
|
||||||
healthy: Arc::new(AtomicBool::new(true)),
|
healthy: Arc::new(AtomicBool::new(true)),
|
||||||
|
circuit_breaker: CircuitBreaker::new(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -201,6 +216,11 @@ impl BasicWorker {
|
|||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Builder-style setter: replace this worker's circuit breaker with a
/// fresh one built from `config`, then return the worker.
pub fn with_circuit_breaker_config(mut self, config: CircuitBreakerConfig) -> Self {
    let breaker = CircuitBreaker::with_config(config);
    self.circuit_breaker = breaker;
    self
}
|
||||||
|
|
||||||
pub fn normalised_url(&self) -> WorkerResult<&str> {
|
pub fn normalised_url(&self) -> WorkerResult<&str> {
|
||||||
if self.url().contains("@") {
|
if self.url().contains("@") {
|
||||||
// Need to extract the URL from "http://host:port@dp_rank"
|
// Need to extract the URL from "http://host:port@dp_rank"
|
||||||
@@ -304,6 +324,10 @@ impl Worker for BasicWorker {
|
|||||||
fn clone_worker(&self) -> Box<dyn Worker> {
|
fn clone_worker(&self) -> Box<dyn Worker> {
|
||||||
Box::new(self.clone())
|
Box::new(self.clone())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn circuit_breaker(&self) -> &CircuitBreaker {
|
||||||
|
&self.circuit_breaker
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// A DP-aware worker that handles data-parallel routing
|
/// A DP-aware worker that handles data-parallel routing
|
||||||
@@ -421,6 +445,10 @@ impl Worker for DPAwareWorker {
|
|||||||
Box::new(self.clone())
|
Box::new(self.clone())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Borrow the circuit breaker of the wrapped base worker.
fn circuit_breaker(&self) -> &CircuitBreaker {
    let base = &self.base_worker;
    base.circuit_breaker()
}
|
||||||
|
|
||||||
// DP-aware specific implementations
|
// DP-aware specific implementations
|
||||||
|
|
||||||
fn is_dp_aware(&self) -> bool {
|
fn is_dp_aware(&self) -> bool {
|
||||||
@@ -469,6 +497,17 @@ impl WorkerFactory {
|
|||||||
Box::new(BasicWorker::new(url, WorkerType::Regular))
|
Box::new(BasicWorker::new(url, WorkerType::Regular))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Create a regular worker with custom circuit breaker configuration
|
||||||
|
pub fn create_regular_with_config(
|
||||||
|
url: String,
|
||||||
|
circuit_breaker_config: CircuitBreakerConfig,
|
||||||
|
) -> Box<dyn Worker> {
|
||||||
|
Box::new(
|
||||||
|
BasicWorker::new(url, WorkerType::Regular)
|
||||||
|
.with_circuit_breaker_config(circuit_breaker_config),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
/// Create a prefill worker with optional bootstrap port
|
/// Create a prefill worker with optional bootstrap port
|
||||||
pub fn create_prefill(url: String, bootstrap_port: Option<u16>) -> Box<dyn Worker> {
|
pub fn create_prefill(url: String, bootstrap_port: Option<u16>) -> Box<dyn Worker> {
|
||||||
Box::new(BasicWorker::new(
|
Box::new(BasicWorker::new(
|
||||||
@@ -477,11 +516,34 @@ impl WorkerFactory {
|
|||||||
))
|
))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Create a prefill worker with custom circuit breaker configuration
|
||||||
|
pub fn create_prefill_with_config(
|
||||||
|
url: String,
|
||||||
|
bootstrap_port: Option<u16>,
|
||||||
|
circuit_breaker_config: CircuitBreakerConfig,
|
||||||
|
) -> Box<dyn Worker> {
|
||||||
|
Box::new(
|
||||||
|
BasicWorker::new(url, WorkerType::Prefill { bootstrap_port })
|
||||||
|
.with_circuit_breaker_config(circuit_breaker_config),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
/// Create a decode worker
|
/// Create a decode worker
|
||||||
pub fn create_decode(url: String) -> Box<dyn Worker> {
|
pub fn create_decode(url: String) -> Box<dyn Worker> {
|
||||||
Box::new(BasicWorker::new(url, WorkerType::Decode))
|
Box::new(BasicWorker::new(url, WorkerType::Decode))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Create a decode worker with custom circuit breaker configuration
|
||||||
|
pub fn create_decode_with_config(
|
||||||
|
url: String,
|
||||||
|
circuit_breaker_config: CircuitBreakerConfig,
|
||||||
|
) -> Box<dyn Worker> {
|
||||||
|
Box::new(
|
||||||
|
BasicWorker::new(url, WorkerType::Decode)
|
||||||
|
.with_circuit_breaker_config(circuit_breaker_config),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
/// Create workers from URLs with automatic type detection
|
/// Create workers from URLs with automatic type detection
|
||||||
pub fn create_from_urls(
|
pub fn create_from_urls(
|
||||||
regular_urls: Vec<String>,
|
regular_urls: Vec<String>,
|
||||||
@@ -796,6 +858,7 @@ pub fn start_health_checker(
|
|||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
use std::sync::RwLock;
|
use std::sync::RwLock;
|
||||||
|
use std::thread;
|
||||||
use std::time::Duration;
|
use std::time::Duration;
|
||||||
use tokio::time::timeout;
|
use tokio::time::timeout;
|
||||||
|
|
||||||
@@ -1574,6 +1637,94 @@ mod tests {
|
|||||||
assert_eq!(workers[1].url(), "http://w2:8080");
|
assert_eq!(workers[1].url(), "http://w2:8080");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ===== Circuit Breaker Integration Tests =====
|
||||||
|
|
||||||
|
#[test]
fn test_worker_circuit_breaker() {
    let worker = BasicWorker::new("http://test:8080".to_string(), WorkerType::Regular);

    // A fresh worker is available and its circuit is closed.
    assert!(worker.is_available());
    assert_eq!(
        worker.circuit_breaker().state(),
        crate::core::CircuitState::Closed
    );

    // Two failures are below the default threshold of 5.
    for _ in 0..2 {
        worker.record_outcome(false);
    }
    assert!(worker.is_available());

    // Three more reach the threshold and open the circuit.
    for _ in 0..3 {
        worker.record_outcome(false);
    }

    // The worker is still healthy, but the open circuit makes it unavailable.
    assert!(!worker.is_available());
    assert!(worker.is_healthy());
    assert!(!worker.circuit_breaker().can_execute());
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_worker_with_circuit_breaker_config() {
|
||||||
|
let config = crate::core::CircuitBreakerConfig {
|
||||||
|
failure_threshold: 2,
|
||||||
|
success_threshold: 1,
|
||||||
|
timeout_duration: Duration::from_millis(100),
|
||||||
|
window_duration: Duration::from_secs(60),
|
||||||
|
};
|
||||||
|
|
||||||
|
let worker = BasicWorker::new("http://test:8080".to_string(), WorkerType::Regular)
|
||||||
|
.with_circuit_breaker_config(config);
|
||||||
|
|
||||||
|
// Should open after 2 failures
|
||||||
|
worker.record_outcome(false);
|
||||||
|
assert!(worker.is_available());
|
||||||
|
worker.record_outcome(false);
|
||||||
|
assert!(!worker.is_available());
|
||||||
|
|
||||||
|
// Wait for timeout
|
||||||
|
thread::sleep(Duration::from_millis(150));
|
||||||
|
|
||||||
|
// Should be half-open
|
||||||
|
assert!(worker.is_available());
|
||||||
|
assert_eq!(
|
||||||
|
worker.circuit_breaker().state(),
|
||||||
|
crate::core::CircuitState::HalfOpen
|
||||||
|
);
|
||||||
|
|
||||||
|
// Success should close it
|
||||||
|
worker.record_outcome(true);
|
||||||
|
assert_eq!(
|
||||||
|
worker.circuit_breaker().state(),
|
||||||
|
crate::core::CircuitState::Closed
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_dp_aware_worker_circuit_breaker() {
|
||||||
|
let dp_worker =
|
||||||
|
DPAwareWorker::new("http://worker:8080".to_string(), 0, 2, WorkerType::Regular);
|
||||||
|
|
||||||
|
// Should have circuit breaker
|
||||||
|
assert!(dp_worker.is_available());
|
||||||
|
|
||||||
|
// Record failures
|
||||||
|
for _ in 0..5 {
|
||||||
|
dp_worker.record_outcome(false);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should not be available
|
||||||
|
assert!(!dp_worker.is_available());
|
||||||
|
assert_eq!(
|
||||||
|
dp_worker.circuit_breaker().state(),
|
||||||
|
crate::core::CircuitState::Open
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
// ===== Integration tests =====
|
// ===== Integration tests =====
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
|
|||||||
@@ -147,6 +147,7 @@ impl Router {
|
|||||||
max_concurrent_requests: self.max_concurrent_requests,
|
max_concurrent_requests: self.max_concurrent_requests,
|
||||||
cors_allowed_origins: self.cors_allowed_origins.clone(),
|
cors_allowed_origins: self.cors_allowed_origins.clone(),
|
||||||
retry: config::RetryConfig::default(),
|
retry: config::RetryConfig::default(),
|
||||||
|
circuit_breaker: config::CircuitBreakerConfig::default(),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -51,6 +51,7 @@ impl RouterFactory {
|
|||||||
ctx.router_config.dp_aware,
|
ctx.router_config.dp_aware,
|
||||||
ctx.router_config.api_key.clone(),
|
ctx.router_config.api_key.clone(),
|
||||||
ctx.router_config.retry.clone(),
|
ctx.router_config.retry.clone(),
|
||||||
|
ctx.router_config.circuit_breaker.clone(),
|
||||||
)?;
|
)?;
|
||||||
|
|
||||||
Ok(Box::new(router))
|
Ok(Box::new(router))
|
||||||
@@ -81,6 +82,7 @@ impl RouterFactory {
|
|||||||
ctx.router_config.worker_startup_timeout_secs,
|
ctx.router_config.worker_startup_timeout_secs,
|
||||||
ctx.router_config.worker_startup_check_interval_secs,
|
ctx.router_config.worker_startup_check_interval_secs,
|
||||||
ctx.router_config.retry.clone(),
|
ctx.router_config.retry.clone(),
|
||||||
|
ctx.router_config.circuit_breaker.clone(),
|
||||||
)?;
|
)?;
|
||||||
|
|
||||||
Ok(Box::new(router))
|
Ok(Box::new(router))
|
||||||
|
|||||||
@@ -1,8 +1,8 @@
|
|||||||
// PD (Prefill-Decode) Router Implementation
|
// PD (Prefill-Decode) Router Implementation
|
||||||
// This module handles routing for disaggregated prefill-decode systems
|
// This module handles routing for disaggregated prefill-decode systems
|
||||||
use super::pd_types::{api_path, PDRouterError};
|
use super::pd_types::{api_path, PDRouterError};
|
||||||
use crate::config::types::RetryConfig;
|
use crate::config::types::{CircuitBreakerConfig as ConfigCircuitBreakerConfig, RetryConfig};
|
||||||
use crate::core::{HealthChecker, Worker, WorkerFactory, WorkerLoadGuard};
|
use crate::core::{CircuitBreakerConfig, HealthChecker, Worker, WorkerFactory, WorkerLoadGuard};
|
||||||
use crate::metrics::RouterMetrics;
|
use crate::metrics::RouterMetrics;
|
||||||
use crate::openai_api_types::{ChatCompletionRequest, CompletionRequest, GenerateRequest};
|
use crate::openai_api_types::{ChatCompletionRequest, CompletionRequest, GenerateRequest};
|
||||||
use crate::policies::LoadBalancingPolicy;
|
use crate::policies::LoadBalancingPolicy;
|
||||||
@@ -41,6 +41,7 @@ pub struct PDRouter {
|
|||||||
// Dedicated client for prefill fire-and-forget (non-logprob) requests
|
// Dedicated client for prefill fire-and-forget (non-logprob) requests
|
||||||
pub prefill_client: Client,
|
pub prefill_client: Client,
|
||||||
pub retry_config: RetryConfig,
|
pub retry_config: RetryConfig,
|
||||||
|
pub circuit_breaker_config: CircuitBreakerConfig,
|
||||||
_prefill_health_checker: Option<HealthChecker>,
|
_prefill_health_checker: Option<HealthChecker>,
|
||||||
_decode_health_checker: Option<HealthChecker>,
|
_decode_health_checker: Option<HealthChecker>,
|
||||||
}
|
}
|
||||||
@@ -68,8 +69,12 @@ impl PDRouter {
|
|||||||
// Wait for the new server to be healthy
|
// Wait for the new server to be healthy
|
||||||
self.wait_for_server_health(&url).await?;
|
self.wait_for_server_health(&url).await?;
|
||||||
|
|
||||||
// Create Worker for the new prefill server
|
// Create Worker for the new prefill server with circuit breaker configuration
|
||||||
let worker = WorkerFactory::create_prefill(url.clone(), bootstrap_port);
|
let worker = WorkerFactory::create_prefill_with_config(
|
||||||
|
url.clone(),
|
||||||
|
bootstrap_port,
|
||||||
|
self.circuit_breaker_config.clone(),
|
||||||
|
);
|
||||||
|
|
||||||
// Add to prefill workers list
|
// Add to prefill workers list
|
||||||
let mut workers = self
|
let mut workers = self
|
||||||
@@ -99,8 +104,11 @@ impl PDRouter {
|
|||||||
// Wait for the new server to be healthy
|
// Wait for the new server to be healthy
|
||||||
self.wait_for_server_health(&url).await?;
|
self.wait_for_server_health(&url).await?;
|
||||||
|
|
||||||
// Create Worker for the new decode server
|
// Create Worker for the new decode server with circuit breaker configuration
|
||||||
let worker = WorkerFactory::create_decode(url.clone());
|
let worker = WorkerFactory::create_decode_with_config(
|
||||||
|
url.clone(),
|
||||||
|
self.circuit_breaker_config.clone(),
|
||||||
|
);
|
||||||
|
|
||||||
// Add to decode workers list
|
// Add to decode workers list
|
||||||
let mut workers = self
|
let mut workers = self
|
||||||
@@ -189,16 +197,31 @@ impl PDRouter {
|
|||||||
timeout_secs: u64,
|
timeout_secs: u64,
|
||||||
interval_secs: u64,
|
interval_secs: u64,
|
||||||
retry_config: RetryConfig,
|
retry_config: RetryConfig,
|
||||||
|
circuit_breaker_config: ConfigCircuitBreakerConfig,
|
||||||
) -> Result<Self, String> {
|
) -> Result<Self, String> {
|
||||||
|
// Convert config CircuitBreakerConfig to core CircuitBreakerConfig
|
||||||
|
let core_cb_config = CircuitBreakerConfig {
|
||||||
|
failure_threshold: circuit_breaker_config.failure_threshold,
|
||||||
|
success_threshold: circuit_breaker_config.success_threshold,
|
||||||
|
timeout_duration: std::time::Duration::from_secs(
|
||||||
|
circuit_breaker_config.timeout_duration_secs,
|
||||||
|
),
|
||||||
|
window_duration: std::time::Duration::from_secs(
|
||||||
|
circuit_breaker_config.window_duration_secs,
|
||||||
|
),
|
||||||
|
};
|
||||||
|
|
||||||
// Convert URLs to Worker trait objects
|
// Convert URLs to Worker trait objects
|
||||||
let prefill_workers: Vec<Box<dyn Worker>> = prefill_urls
|
let prefill_workers: Vec<Box<dyn Worker>> = prefill_urls
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.map(|(url, port)| WorkerFactory::create_prefill(url, port))
|
.map(|(url, port)| {
|
||||||
|
WorkerFactory::create_prefill_with_config(url, port, core_cb_config.clone())
|
||||||
|
})
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
let decode_workers: Vec<Box<dyn Worker>> = decode_urls
|
let decode_workers: Vec<Box<dyn Worker>> = decode_urls
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.map(WorkerFactory::create_decode)
|
.map(|url| WorkerFactory::create_decode_with_config(url, core_cb_config.clone()))
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
// Wait for PD workers to be healthy (skip if empty - for service discovery mode)
|
// Wait for PD workers to be healthy (skip if empty - for service discovery mode)
|
||||||
@@ -280,6 +303,7 @@ impl PDRouter {
|
|||||||
client,
|
client,
|
||||||
prefill_client,
|
prefill_client,
|
||||||
retry_config,
|
retry_config,
|
||||||
|
circuit_breaker_config: core_cb_config,
|
||||||
_prefill_health_checker: Some(prefill_health_checker),
|
_prefill_health_checker: Some(prefill_health_checker),
|
||||||
_decode_health_checker: Some(decode_health_checker),
|
_decode_health_checker: Some(decode_health_checker),
|
||||||
})
|
})
|
||||||
@@ -1848,6 +1872,7 @@ mod tests {
|
|||||||
client: Client::new(),
|
client: Client::new(),
|
||||||
prefill_client: Client::new(),
|
prefill_client: Client::new(),
|
||||||
retry_config: RetryConfig::default(),
|
retry_config: RetryConfig::default(),
|
||||||
|
circuit_breaker_config: CircuitBreakerConfig::default(),
|
||||||
_prefill_health_checker: None,
|
_prefill_health_checker: None,
|
||||||
_decode_health_checker: None,
|
_decode_health_checker: None,
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
use crate::config::types::RetryConfig;
|
use crate::config::types::{CircuitBreakerConfig as ConfigCircuitBreakerConfig, RetryConfig};
|
||||||
use crate::core::{HealthChecker, Worker, WorkerFactory};
|
use crate::core::{CircuitBreakerConfig, HealthChecker, Worker, WorkerFactory};
|
||||||
use crate::metrics::RouterMetrics;
|
use crate::metrics::RouterMetrics;
|
||||||
use crate::openai_api_types::{ChatCompletionRequest, CompletionRequest, GenerateRequest};
|
use crate::openai_api_types::{ChatCompletionRequest, CompletionRequest, GenerateRequest};
|
||||||
use crate::policies::LoadBalancingPolicy;
|
use crate::policies::LoadBalancingPolicy;
|
||||||
@@ -42,6 +42,7 @@ pub struct Router {
|
|||||||
dp_aware: bool,
|
dp_aware: bool,
|
||||||
api_key: Option<String>,
|
api_key: Option<String>,
|
||||||
retry_config: RetryConfig,
|
retry_config: RetryConfig,
|
||||||
|
circuit_breaker_config: CircuitBreakerConfig,
|
||||||
_worker_loads: Arc<tokio::sync::watch::Receiver<HashMap<String, isize>>>,
|
_worker_loads: Arc<tokio::sync::watch::Receiver<HashMap<String, isize>>>,
|
||||||
_load_monitor_handle: Option<Arc<tokio::task::JoinHandle<()>>>,
|
_load_monitor_handle: Option<Arc<tokio::task::JoinHandle<()>>>,
|
||||||
_health_checker: Option<HealthChecker>,
|
_health_checker: Option<HealthChecker>,
|
||||||
@@ -58,6 +59,7 @@ impl Router {
|
|||||||
dp_aware: bool,
|
dp_aware: bool,
|
||||||
api_key: Option<String>,
|
api_key: Option<String>,
|
||||||
retry_config: RetryConfig,
|
retry_config: RetryConfig,
|
||||||
|
circuit_breaker_config: ConfigCircuitBreakerConfig,
|
||||||
) -> Result<Self, String> {
|
) -> Result<Self, String> {
|
||||||
// Update active workers gauge
|
// Update active workers gauge
|
||||||
RouterMetrics::set_active_workers(worker_urls.len());
|
RouterMetrics::set_active_workers(worker_urls.len());
|
||||||
@@ -75,10 +77,24 @@ impl Router {
|
|||||||
worker_urls
|
worker_urls
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Convert config CircuitBreakerConfig to core CircuitBreakerConfig
|
||||||
|
let core_cb_config = CircuitBreakerConfig {
|
||||||
|
failure_threshold: circuit_breaker_config.failure_threshold,
|
||||||
|
success_threshold: circuit_breaker_config.success_threshold,
|
||||||
|
timeout_duration: std::time::Duration::from_secs(
|
||||||
|
circuit_breaker_config.timeout_duration_secs,
|
||||||
|
),
|
||||||
|
window_duration: std::time::Duration::from_secs(
|
||||||
|
circuit_breaker_config.window_duration_secs,
|
||||||
|
),
|
||||||
|
};
|
||||||
|
|
||||||
// Create Worker trait objects from URLs
|
// Create Worker trait objects from URLs
|
||||||
let workers: Vec<Box<dyn Worker>> = worker_urls
|
let workers: Vec<Box<dyn Worker>> = worker_urls
|
||||||
.iter()
|
.iter()
|
||||||
.map(|url| WorkerFactory::create_regular(url.clone()))
|
.map(|url| {
|
||||||
|
WorkerFactory::create_regular_with_config(url.clone(), core_cb_config.clone())
|
||||||
|
})
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
// Initialize policy with workers if needed (e.g., for cache-aware)
|
// Initialize policy with workers if needed (e.g., for cache-aware)
|
||||||
@@ -125,6 +141,7 @@ impl Router {
|
|||||||
dp_aware,
|
dp_aware,
|
||||||
api_key,
|
api_key,
|
||||||
retry_config,
|
retry_config,
|
||||||
|
circuit_breaker_config: core_cb_config,
|
||||||
_worker_loads: worker_loads,
|
_worker_loads: worker_loads,
|
||||||
_load_monitor_handle: load_monitor_handle,
|
_load_monitor_handle: load_monitor_handle,
|
||||||
_health_checker: Some(health_checker),
|
_health_checker: Some(health_checker),
|
||||||
@@ -752,7 +769,10 @@ impl Router {
|
|||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
info!("Added worker: {}", dp_url);
|
info!("Added worker: {}", dp_url);
|
||||||
let new_worker = WorkerFactory::create_regular(dp_url.to_string());
|
let new_worker = WorkerFactory::create_regular_with_config(
|
||||||
|
dp_url.to_string(),
|
||||||
|
self.circuit_breaker_config.clone(),
|
||||||
|
);
|
||||||
workers_guard.push(new_worker);
|
workers_guard.push(new_worker);
|
||||||
worker_added = true;
|
worker_added = true;
|
||||||
}
|
}
|
||||||
@@ -764,7 +784,10 @@ impl Router {
|
|||||||
return Err(format!("Worker {} already exists", worker_url));
|
return Err(format!("Worker {} already exists", worker_url));
|
||||||
}
|
}
|
||||||
info!("Added worker: {}", worker_url);
|
info!("Added worker: {}", worker_url);
|
||||||
let new_worker = WorkerFactory::create_regular(worker_url.to_string());
|
let new_worker = WorkerFactory::create_regular_with_config(
|
||||||
|
worker_url.to_string(),
|
||||||
|
self.circuit_breaker_config.clone(),
|
||||||
|
);
|
||||||
workers_guard.push(new_worker);
|
workers_guard.push(new_worker);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1223,6 +1246,7 @@ mod tests {
|
|||||||
api_key: None,
|
api_key: None,
|
||||||
client: Client::new(),
|
client: Client::new(),
|
||||||
retry_config: RetryConfig::default(),
|
retry_config: RetryConfig::default(),
|
||||||
|
circuit_breaker_config: CircuitBreakerConfig::default(),
|
||||||
_worker_loads: Arc::new(rx),
|
_worker_loads: Arc::new(rx),
|
||||||
_load_monitor_handle: None,
|
_load_monitor_handle: None,
|
||||||
_health_checker: None,
|
_health_checker: None,
|
||||||
|
|||||||
@@ -589,6 +589,7 @@ mod tests {
|
|||||||
false,
|
false,
|
||||||
None,
|
None,
|
||||||
crate::config::types::RetryConfig::default(),
|
crate::config::types::RetryConfig::default(),
|
||||||
|
crate::config::types::CircuitBreakerConfig::default(),
|
||||||
)
|
)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
Arc::new(router) as Arc<dyn RouterTrait>
|
Arc::new(router) as Arc<dyn RouterTrait>
|
||||||
|
|||||||
@@ -8,7 +8,9 @@ use axum::{
|
|||||||
use common::mock_worker::{HealthStatus, MockWorker, MockWorkerConfig, WorkerType};
|
use common::mock_worker::{HealthStatus, MockWorker, MockWorkerConfig, WorkerType};
|
||||||
use reqwest::Client;
|
use reqwest::Client;
|
||||||
use serde_json::json;
|
use serde_json::json;
|
||||||
use sglang_router_rs::config::{PolicyConfig, RetryConfig, RouterConfig, RoutingMode};
|
use sglang_router_rs::config::{
|
||||||
|
CircuitBreakerConfig, PolicyConfig, RetryConfig, RouterConfig, RoutingMode,
|
||||||
|
};
|
||||||
use sglang_router_rs::routers::{RouterFactory, RouterTrait};
|
use sglang_router_rs::routers::{RouterFactory, RouterTrait};
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use tower::ServiceExt;
|
use tower::ServiceExt;
|
||||||
@@ -45,6 +47,7 @@ impl TestContext {
|
|||||||
max_concurrent_requests: 64,
|
max_concurrent_requests: 64,
|
||||||
cors_allowed_origins: vec![],
|
cors_allowed_origins: vec![],
|
||||||
retry: RetryConfig::default(),
|
retry: RetryConfig::default(),
|
||||||
|
circuit_breaker: CircuitBreakerConfig::default(),
|
||||||
};
|
};
|
||||||
|
|
||||||
Self::new_with_config(config, worker_configs).await
|
Self::new_with_config(config, worker_configs).await
|
||||||
@@ -1087,6 +1090,7 @@ mod error_tests {
|
|||||||
max_concurrent_requests: 64,
|
max_concurrent_requests: 64,
|
||||||
cors_allowed_origins: vec![],
|
cors_allowed_origins: vec![],
|
||||||
retry: RetryConfig::default(),
|
retry: RetryConfig::default(),
|
||||||
|
circuit_breaker: CircuitBreakerConfig::default(),
|
||||||
};
|
};
|
||||||
|
|
||||||
let ctx = TestContext::new_with_config(
|
let ctx = TestContext::new_with_config(
|
||||||
@@ -1434,6 +1438,7 @@ mod pd_mode_tests {
|
|||||||
max_concurrent_requests: 64,
|
max_concurrent_requests: 64,
|
||||||
cors_allowed_origins: vec![],
|
cors_allowed_origins: vec![],
|
||||||
retry: RetryConfig::default(),
|
retry: RetryConfig::default(),
|
||||||
|
circuit_breaker: CircuitBreakerConfig::default(),
|
||||||
};
|
};
|
||||||
|
|
||||||
// Create app context
|
// Create app context
|
||||||
@@ -1588,6 +1593,7 @@ mod request_id_tests {
|
|||||||
max_concurrent_requests: 64,
|
max_concurrent_requests: 64,
|
||||||
cors_allowed_origins: vec![],
|
cors_allowed_origins: vec![],
|
||||||
retry: RetryConfig::default(),
|
retry: RetryConfig::default(),
|
||||||
|
circuit_breaker: CircuitBreakerConfig::default(),
|
||||||
};
|
};
|
||||||
|
|
||||||
let ctx = TestContext::new_with_config(
|
let ctx = TestContext::new_with_config(
|
||||||
|
|||||||
@@ -3,7 +3,9 @@ mod common;
|
|||||||
use common::mock_worker::{HealthStatus, MockWorker, MockWorkerConfig, WorkerType};
|
use common::mock_worker::{HealthStatus, MockWorker, MockWorkerConfig, WorkerType};
|
||||||
use reqwest::Client;
|
use reqwest::Client;
|
||||||
use serde_json::json;
|
use serde_json::json;
|
||||||
use sglang_router_rs::config::{PolicyConfig, RetryConfig, RouterConfig, RoutingMode};
|
use sglang_router_rs::config::{
|
||||||
|
CircuitBreakerConfig, PolicyConfig, RetryConfig, RouterConfig, RoutingMode,
|
||||||
|
};
|
||||||
use sglang_router_rs::routers::{RouterFactory, RouterTrait};
|
use sglang_router_rs::routers::{RouterFactory, RouterTrait};
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
@@ -36,6 +38,7 @@ impl TestContext {
|
|||||||
max_concurrent_requests: 64,
|
max_concurrent_requests: 64,
|
||||||
cors_allowed_origins: vec![],
|
cors_allowed_origins: vec![],
|
||||||
retry: RetryConfig::default(),
|
retry: RetryConfig::default(),
|
||||||
|
circuit_breaker: CircuitBreakerConfig::default(),
|
||||||
};
|
};
|
||||||
|
|
||||||
let mut workers = Vec::new();
|
let mut workers = Vec::new();
|
||||||
|
|||||||
@@ -4,7 +4,9 @@ use common::mock_worker::{HealthStatus, MockWorker, MockWorkerConfig, WorkerType
|
|||||||
use futures_util::StreamExt;
|
use futures_util::StreamExt;
|
||||||
use reqwest::Client;
|
use reqwest::Client;
|
||||||
use serde_json::json;
|
use serde_json::json;
|
||||||
use sglang_router_rs::config::{PolicyConfig, RetryConfig, RouterConfig, RoutingMode};
|
use sglang_router_rs::config::{
|
||||||
|
CircuitBreakerConfig, PolicyConfig, RetryConfig, RouterConfig, RoutingMode,
|
||||||
|
};
|
||||||
use sglang_router_rs::routers::{RouterFactory, RouterTrait};
|
use sglang_router_rs::routers::{RouterFactory, RouterTrait};
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
@@ -37,6 +39,7 @@ impl TestContext {
|
|||||||
max_concurrent_requests: 64,
|
max_concurrent_requests: 64,
|
||||||
cors_allowed_origins: vec![],
|
cors_allowed_origins: vec![],
|
||||||
retry: RetryConfig::default(),
|
retry: RetryConfig::default(),
|
||||||
|
circuit_breaker: CircuitBreakerConfig::default(),
|
||||||
};
|
};
|
||||||
|
|
||||||
let mut workers = Vec::new();
|
let mut workers = Vec::new();
|
||||||
|
|||||||
@@ -2,7 +2,9 @@
|
|||||||
mod test_pd_routing {
|
mod test_pd_routing {
|
||||||
use rand::Rng;
|
use rand::Rng;
|
||||||
use serde_json::json;
|
use serde_json::json;
|
||||||
use sglang_router_rs::config::{PolicyConfig, RetryConfig, RouterConfig, RoutingMode};
|
use sglang_router_rs::config::{
|
||||||
|
CircuitBreakerConfig, PolicyConfig, RetryConfig, RouterConfig, RoutingMode,
|
||||||
|
};
|
||||||
use sglang_router_rs::core::{WorkerFactory, WorkerType};
|
use sglang_router_rs::core::{WorkerFactory, WorkerType};
|
||||||
use sglang_router_rs::routers::pd_types::get_hostname;
|
use sglang_router_rs::routers::pd_types::get_hostname;
|
||||||
use sglang_router_rs::routers::pd_types::PDSelectionPolicy;
|
use sglang_router_rs::routers::pd_types::PDSelectionPolicy;
|
||||||
@@ -179,6 +181,7 @@ mod test_pd_routing {
|
|||||||
max_concurrent_requests: 64,
|
max_concurrent_requests: 64,
|
||||||
cors_allowed_origins: vec![],
|
cors_allowed_origins: vec![],
|
||||||
retry: RetryConfig::default(),
|
retry: RetryConfig::default(),
|
||||||
|
circuit_breaker: CircuitBreakerConfig::default(),
|
||||||
};
|
};
|
||||||
|
|
||||||
// Router creation will fail due to health checks, but config should be valid
|
// Router creation will fail due to health checks, but config should be valid
|
||||||
|
|||||||
Reference in New Issue
Block a user