[router] regular router circuit breaker (#8997)
package-lock.json (generated, new file, +6)
@@ -0,0 +1,6 @@
{
  "name": "sglang",
  "lockfileVersion": 3,
  "requires": true,
  "packages": {}
}
@@ -116,6 +116,39 @@ python -m sglang_router.launch_router \
    --prometheus-port 9000
```

### Retries and Circuit Breakers

- Retries (regular router) are enabled by default with exponential backoff and jitter. You can tune them via the CLI:

```bash
python -m sglang_router.launch_router \
    --worker-urls http://localhost:8080 http://localhost:8081 \
    --retry-max-retries 3 \
    --retry-initial-backoff-ms 100 \
    --retry-max-backoff-ms 10000 \
    --retry-backoff-multiplier 2.0 \
    --retry-jitter-factor 0.1
```

- Circuit breaker defaults protect workers and recover automatically. Tune the thresholds and timeouts:

```bash
python -m sglang_router.launch_router \
    --worker-urls http://localhost:8080 http://localhost:8081 \
    --cb-failure-threshold 5 \
    --cb-success-threshold 2 \
    --cb-timeout-duration-secs 30 \
    --cb-window-duration-secs 60
```

Behavior summary:
- Closed → Open after N consecutive failures (failure-threshold)
- Open → HalfOpen after the timeout elapses (timeout-duration-secs)
- HalfOpen → Closed after M consecutive successes (success-threshold)
- Any failure in HalfOpen reopens the circuit immediately

Retry predicate (regular router): retry on HTTP 408/429/500/502/503/504; any other status is returned immediately. Backoff with jitter is applied between attempts, as sketched below.
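For intuition, the backoff math can be restated in a few lines of Python. This is illustrative only (the helper below is not part of sglang); it mirrors the documented formula, effective delay D' = D * (1 + U[-j, +j]), with the base delay capped at the max backoff:

```python
import random

def backoff_delay_ms(attempt: int, initial_ms: int = 100, max_ms: int = 10_000,
                     multiplier: float = 2.0, jitter: float = 0.1) -> int:
    """Exponential backoff with jitter, mirroring the router's documented defaults."""
    base = min(initial_ms * multiplier ** attempt, max_ms)  # cap before jitter
    return max(0, round(base * (1 + random.uniform(-jitter, jitter))))

# attempt 0 -> ~100 ms, attempt 1 -> ~200 ms, ..., capped near 10 s
print([backoff_delay_ms(a, jitter=0.0) for a in range(8)])
```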
### Request ID Tracking

Track requests across distributed systems with configurable headers:
@@ -74,6 +74,19 @@ class RouterArgs:
    max_concurrent_requests: int = 64
    # CORS allowed origins
    cors_allowed_origins: List[str] = dataclasses.field(default_factory=list)
    # Retry configuration
    retry_max_retries: int = 3
    retry_initial_backoff_ms: int = 100
    retry_max_backoff_ms: int = 10_000
    retry_backoff_multiplier: float = 2.0
    retry_jitter_factor: float = 0.1
    disable_retries: bool = False
    # Circuit breaker configuration
    cb_failure_threshold: int = 5
    cb_success_threshold: int = 2
    cb_timeout_duration_secs: int = 30
    cb_window_duration_secs: int = 60
    disable_circuit_breaker: bool = False

    @staticmethod
    def add_cli_args(
@@ -289,6 +302,63 @@ class RouterArgs:
            default=RouterArgs.request_timeout_secs,
            help="Request timeout in seconds",
        )
        # Retry configuration
        parser.add_argument(
            f"--{prefix}retry-max-retries",
            type=int,
            default=RouterArgs.retry_max_retries,
        )
        parser.add_argument(
            f"--{prefix}retry-initial-backoff-ms",
            type=int,
            default=RouterArgs.retry_initial_backoff_ms,
        )
        parser.add_argument(
            f"--{prefix}retry-max-backoff-ms",
            type=int,
            default=RouterArgs.retry_max_backoff_ms,
        )
        parser.add_argument(
            f"--{prefix}retry-backoff-multiplier",
            type=float,
            default=RouterArgs.retry_backoff_multiplier,
        )
        parser.add_argument(
            f"--{prefix}retry-jitter-factor",
            type=float,
            default=RouterArgs.retry_jitter_factor,
        )
        parser.add_argument(
            f"--{prefix}disable-retries",
            action="store_true",
            help="Disable retries (equivalent to setting retry_max_retries=1)",
        )
        # Circuit breaker configuration
        parser.add_argument(
            f"--{prefix}cb-failure-threshold",
            type=int,
            default=RouterArgs.cb_failure_threshold,
        )
        parser.add_argument(
            f"--{prefix}cb-success-threshold",
            type=int,
            default=RouterArgs.cb_success_threshold,
        )
        parser.add_argument(
            f"--{prefix}cb-timeout-duration-secs",
            type=int,
            default=RouterArgs.cb_timeout_duration_secs,
        )
        parser.add_argument(
            f"--{prefix}cb-window-duration-secs",
            type=int,
            default=RouterArgs.cb_window_duration_secs,
        )
        parser.add_argument(
            f"--{prefix}disable-circuit-breaker",
            action="store_true",
            help="Disable circuit breaker (equivalent to setting cb_failure_threshold to u32::MAX)",
        )
        parser.add_argument(
            f"--{prefix}max-concurrent-requests",
            type=int,
@@ -372,6 +442,19 @@ class RouterArgs:
                RouterArgs.max_concurrent_requests,
            ),
            cors_allowed_origins=getattr(args, f"{prefix}cors_allowed_origins", []),
            retry_max_retries=getattr(args, f"{prefix}retry_max_retries"),
            retry_initial_backoff_ms=getattr(args, f"{prefix}retry_initial_backoff_ms"),
            retry_max_backoff_ms=getattr(args, f"{prefix}retry_max_backoff_ms"),
            retry_backoff_multiplier=getattr(args, f"{prefix}retry_backoff_multiplier"),
            retry_jitter_factor=getattr(args, f"{prefix}retry_jitter_factor"),
            cb_failure_threshold=getattr(args, f"{prefix}cb_failure_threshold"),
            cb_success_threshold=getattr(args, f"{prefix}cb_success_threshold"),
            cb_timeout_duration_secs=getattr(args, f"{prefix}cb_timeout_duration_secs"),
            cb_window_duration_secs=getattr(args, f"{prefix}cb_window_duration_secs"),
            disable_retries=getattr(args, f"{prefix}disable_retries", False),
            disable_circuit_breaker=getattr(
                args, f"{prefix}disable_circuit_breaker", False
            ),
        )

    @staticmethod
@@ -558,6 +641,17 @@ def launch_router(args: argparse.Namespace) -> Optional[Router]:
        request_id_headers=router_args.request_id_headers,
        max_concurrent_requests=router_args.max_concurrent_requests,
        cors_allowed_origins=router_args.cors_allowed_origins,
        retry_max_retries=router_args.retry_max_retries,
        retry_initial_backoff_ms=router_args.retry_initial_backoff_ms,
        retry_max_backoff_ms=router_args.retry_max_backoff_ms,
        retry_backoff_multiplier=router_args.retry_backoff_multiplier,
        retry_jitter_factor=router_args.retry_jitter_factor,
        cb_failure_threshold=router_args.cb_failure_threshold,
        cb_success_threshold=router_args.cb_success_threshold,
        cb_timeout_duration_secs=router_args.cb_timeout_duration_secs,
        cb_window_duration_secs=router_args.cb_window_duration_secs,
        disable_retries=router_args.disable_retries,
        disable_circuit_breaker=router_args.disable_circuit_breaker,
    )

    router.start()
@@ -158,6 +158,7 @@ def main():
        default=31000,
        help="Base port number for data parallel workers",
    )
    # No extra retry/CB flags here; RouterArgs.add_cli_args already defines them with the router- prefix

    args = parser.parse_args()
    server_args = ServerArgs.from_cli_args(args)
@@ -104,6 +104,17 @@ class Router:
        decode_policy: Optional[PolicyType] = None,
        max_concurrent_requests: int = 64,
        cors_allowed_origins: List[str] = None,
        retry_max_retries: int = 3,
        retry_initial_backoff_ms: int = 100,
        retry_max_backoff_ms: int = 10_000,
        retry_backoff_multiplier: float = 2.0,
        retry_jitter_factor: float = 0.1,
        cb_failure_threshold: int = 5,
        cb_success_threshold: int = 2,
        cb_timeout_duration_secs: int = 30,
        cb_window_duration_secs: int = 60,
        disable_retries: bool = False,
        disable_circuit_breaker: bool = False,
    ):
        if selector is None:
            selector = {}
@@ -149,6 +160,17 @@ class Router:
            decode_policy=decode_policy,
            max_concurrent_requests=max_concurrent_requests,
            cors_allowed_origins=cors_allowed_origins,
            retry_max_retries=retry_max_retries,
            retry_initial_backoff_ms=retry_initial_backoff_ms,
            retry_max_backoff_ms=retry_max_backoff_ms,
            retry_backoff_multiplier=retry_backoff_multiplier,
            retry_jitter_factor=retry_jitter_factor,
            cb_failure_threshold=cb_failure_threshold,
            cb_success_threshold=cb_success_threshold,
            cb_timeout_duration_secs=cb_timeout_duration_secs,
            cb_window_duration_secs=cb_window_duration_secs,
            disable_retries=disable_retries,
            disable_circuit_breaker=disable_circuit_breaker,
        )

    def start(self) -> None:
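From Python, the new knobs are plain keyword arguments on this wrapper's constructor. A hedged sketch (the top-level import path is assumed, since this diff does not show it):

```python
from sglang_router import Router  # import path assumed

router = Router(
    worker_urls=["http://localhost:8080", "http://localhost:8081"],
    retry_max_retries=3,
    retry_initial_backoff_ms=100,
    retry_max_backoff_ms=10_000,
    retry_backoff_multiplier=2.0,
    retry_jitter_factor=0.1,
    cb_failure_threshold=5,
    cb_success_threshold=2,
    cb_timeout_duration_secs=30,
    cb_window_duration_secs=60,
)
router.start()  # begin serving and load-balancing across the workers
```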
@@ -53,6 +53,17 @@ class TestLaunchRouter(unittest.TestCase):
            prefill=None,
            decode=None,
            worker_urls=[],
            retry_max_retries=3,
            retry_initial_backoff_ms=100,
            retry_max_backoff_ms=10_000,
            retry_backoff_multiplier=2.0,
            retry_jitter_factor=0.1,
            cb_failure_threshold=5,
            cb_success_threshold=2,
            cb_timeout_duration_secs=30,
            cb_window_duration_secs=60,
            disable_retries=False,
            disable_circuit_breaker=False,
        )

    def create_router_args(self, **kwargs):
@@ -31,6 +31,16 @@ def popen_launch_router(
    prometheus_port: int = None,
    prometheus_host: str = None,
    dp_aware: bool = False,
    # Router retry/CB tuning (optional)
    router_retry_max_retries: int = None,
    router_retry_initial_backoff_ms: int = None,
    router_retry_max_backoff_ms: int = None,
    router_retry_backoff_multiplier: float = None,
    router_retry_jitter_factor: float = None,
    router_cb_failure_threshold: int = None,
    router_cb_success_threshold: int = None,
    router_cb_timeout_duration_secs: int = None,
    router_cb_window_duration_secs: int = None,
):
    """
    Launch the router server process.
@@ -107,6 +117,21 @@ def popen_launch_router(
    if dp_aware:
        command.append("--router-dp-aware")

    # Append router retry/CB tuning flags if provided
    def _add(flag: str, val):
        if val is not None:
            command.extend([flag, str(val)])

    _add("--router-retry-max-retries", router_retry_max_retries)
    _add("--router-retry-initial-backoff-ms", router_retry_initial_backoff_ms)
    _add("--router-retry-max-backoff-ms", router_retry_max_backoff_ms)
    _add("--router-retry-backoff-multiplier", router_retry_backoff_multiplier)
    _add("--router-retry-jitter-factor", router_retry_jitter_factor)
    _add("--router-cb-failure-threshold", router_cb_failure_threshold)
    _add("--router-cb-success-threshold", router_cb_success_threshold)
    _add("--router-cb-timeout-duration-secs", router_cb_timeout_duration_secs)
    _add("--router-cb-window-duration-secs", router_cb_window_duration_secs)

    process = subprocess.Popen(command, stdout=None, stderr=None)

    start_time = time.perf_counter()
@@ -43,6 +43,12 @@ pub struct RouterConfig {
    pub retry: RetryConfig,
    /// Circuit breaker configuration
    pub circuit_breaker: CircuitBreakerConfig,
    /// Disable retries (overrides retry.max_retries to 1 when true)
    #[serde(default)]
    pub disable_retries: bool,
    /// Disable circuit breaker (overrides circuit_breaker.failure_threshold to u32::MAX when true)
    #[serde(default)]
    pub disable_circuit_breaker: bool,
}

/// Routing mode configuration
@@ -197,6 +203,10 @@ pub struct RetryConfig {
    pub max_backoff_ms: u64,
    /// Backoff multiplier for exponential backoff
    pub backoff_multiplier: f32,
    /// Jitter factor applied to backoff (0.0 - 1.0)
    /// Effective delay D' = D * (1 + U[-j, +j])
    #[serde(default = "default_retry_jitter_factor")]
    pub jitter_factor: f32,
}

impl Default for RetryConfig {
@@ -206,10 +216,15 @@ impl Default for RetryConfig {
            initial_backoff_ms: 100,
            max_backoff_ms: 10000,
            backoff_multiplier: 2.0,
            jitter_factor: 0.1,
        }
    }
}

fn default_retry_jitter_factor() -> f32 {
    0.1
}

/// Circuit breaker configuration for worker reliability
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CircuitBreakerConfig {
@@ -276,6 +291,8 @@ impl Default for RouterConfig {
            cors_allowed_origins: vec![],
            retry: RetryConfig::default(),
            circuit_breaker: CircuitBreakerConfig::default(),
            disable_retries: false,
            disable_circuit_breaker: false,
        }
    }
}
@@ -312,6 +329,24 @@ impl RouterConfig {
    pub fn has_metrics(&self) -> bool {
        self.metrics.is_some()
    }

    /// Compute the effective retry config considering disable flag
    pub fn effective_retry_config(&self) -> RetryConfig {
        let mut cfg = self.retry.clone();
        if self.disable_retries {
            cfg.max_retries = 1;
        }
        cfg
    }

    /// Compute the effective circuit breaker config considering disable flag
    pub fn effective_circuit_breaker_config(&self) -> CircuitBreakerConfig {
        let mut cfg = self.circuit_breaker.clone();
        if self.disable_circuit_breaker {
            cfg.failure_threshold = u32::MAX;
        }
        cfg
    }
}

#[cfg(test)]
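The disable flags are pure overrides of the tuned values. A one-glance Python paraphrase of the semantics (illustrative, not part of the codebase):

```python
U32_MAX = 2**32 - 1

def effective_max_retries(max_retries: int, disable_retries: bool) -> int:
    return 1 if disable_retries else max_retries  # one attempt == no retries

def effective_failure_threshold(threshold: int, disable_cb: bool) -> int:
    return U32_MAX if disable_cb else threshold  # unreachable threshold == CB off
```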
@@ -388,6 +423,8 @@ mod tests {
            cors_allowed_origins: vec![],
            retry: RetryConfig::default(),
            circuit_breaker: CircuitBreakerConfig::default(),
            disable_retries: false,
            disable_circuit_breaker: false,
        };

        let json = serde_json::to_string(&config).unwrap();
@@ -817,6 +854,8 @@ mod tests {
            cors_allowed_origins: vec![],
            retry: RetryConfig::default(),
            circuit_breaker: CircuitBreakerConfig::default(),
            disable_retries: false,
            disable_circuit_breaker: false,
        };

        assert!(config.mode.is_pd_mode());
@@ -870,6 +909,8 @@ mod tests {
            cors_allowed_origins: vec![],
            retry: RetryConfig::default(),
            circuit_breaker: CircuitBreakerConfig::default(),
            disable_retries: false,
            disable_circuit_breaker: false,
        };

        assert!(!config.mode.is_pd_mode());
@@ -919,6 +960,8 @@ mod tests {
            cors_allowed_origins: vec![],
            retry: RetryConfig::default(),
            circuit_breaker: CircuitBreakerConfig::default(),
            disable_retries: false,
            disable_circuit_breaker: false,
        };

        assert!(config.has_service_discovery());
@@ -23,6 +23,12 @@ impl ConfigValidator {
        Self::validate_compatibility(config)?;

        // Validate effective retry/CB configs (respect disable flags)
        let retry_cfg = config.effective_retry_config();
        let cb_cfg = config.effective_circuit_breaker_config();
        Self::validate_retry(&retry_cfg)?;
        Self::validate_circuit_breaker(&cb_cfg)?;

        Ok(())
    }
@@ -263,6 +269,79 @@ impl ConfigValidator {
        Ok(())
    }

    /// Validate retry configuration
    fn validate_retry(retry: &RetryConfig) -> ConfigResult<()> {
        if retry.max_retries < 1 {
            return Err(ConfigError::InvalidValue {
                field: "retry.max_retries".to_string(),
                value: retry.max_retries.to_string(),
                reason: "Must be >= 1 (set to 1 to effectively disable retries)".to_string(),
            });
        }
        if retry.initial_backoff_ms == 0 {
            return Err(ConfigError::InvalidValue {
                field: "retry.initial_backoff_ms".to_string(),
                value: retry.initial_backoff_ms.to_string(),
                reason: "Must be > 0".to_string(),
            });
        }
        if retry.max_backoff_ms < retry.initial_backoff_ms {
            return Err(ConfigError::InvalidValue {
                field: "retry.max_backoff_ms".to_string(),
                value: retry.max_backoff_ms.to_string(),
                reason: "Must be >= initial_backoff_ms".to_string(),
            });
        }
        if retry.backoff_multiplier < 1.0 {
            return Err(ConfigError::InvalidValue {
                field: "retry.backoff_multiplier".to_string(),
                value: retry.backoff_multiplier.to_string(),
                reason: "Must be >= 1.0".to_string(),
            });
        }
        if !(0.0..=1.0).contains(&retry.jitter_factor) {
            return Err(ConfigError::InvalidValue {
                field: "retry.jitter_factor".to_string(),
                value: retry.jitter_factor.to_string(),
                reason: "Must be between 0.0 and 1.0".to_string(),
            });
        }
        Ok(())
    }

    /// Validate circuit breaker configuration
    fn validate_circuit_breaker(cb: &CircuitBreakerConfig) -> ConfigResult<()> {
        if cb.failure_threshold < 1 {
            return Err(ConfigError::InvalidValue {
                field: "circuit_breaker.failure_threshold".to_string(),
                value: cb.failure_threshold.to_string(),
                reason: "Must be >= 1 (set to u32::MAX to effectively disable CB)".to_string(),
            });
        }
        if cb.success_threshold < 1 {
            return Err(ConfigError::InvalidValue {
                field: "circuit_breaker.success_threshold".to_string(),
                value: cb.success_threshold.to_string(),
                reason: "Must be >= 1".to_string(),
            });
        }
        if cb.timeout_duration_secs == 0 {
            return Err(ConfigError::InvalidValue {
                field: "circuit_breaker.timeout_duration_secs".to_string(),
                value: cb.timeout_duration_secs.to_string(),
                reason: "Must be > 0".to_string(),
            });
        }
        if cb.window_duration_secs == 0 {
            return Err(ConfigError::InvalidValue {
                field: "circuit_breaker.window_duration_secs".to_string(),
                value: cb.window_duration_secs.to_string(),
                reason: "Must be > 0".to_string(),
            });
        }
        Ok(())
    }

    /// Validate compatibility between different configuration sections
    fn validate_compatibility(config: &RouterConfig) -> ConfigResult<()> {
        // All policies are now supported for both router types thanks to the unified trait design
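Restated compactly, the validator enforces the following invariants; a Python paraphrase for quick reference (illustrative, not part of the codebase):

```python
def check_retry_and_cb(retry: dict, cb: dict) -> None:
    """Mirror of the validator's constraints on the effective configs."""
    assert retry["max_retries"] >= 1          # 1 == retries effectively disabled
    assert retry["initial_backoff_ms"] > 0
    assert retry["max_backoff_ms"] >= retry["initial_backoff_ms"]
    assert retry["backoff_multiplier"] >= 1.0
    assert 0.0 <= retry["jitter_factor"] <= 1.0
    assert cb["failure_threshold"] >= 1       # u32::MAX == CB effectively disabled
    assert cb["success_threshold"] >= 1
    assert cb["timeout_duration_secs"] > 0
    assert cb["window_duration_secs"] > 0
```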
@@ -1,6 +1,7 @@
use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
use std::sync::{Arc, RwLock};
use std::time::{Duration, Instant};
use tracing::info;

/// Circuit breaker configuration
#[derive(Debug, Clone)]
@@ -113,6 +114,7 @@ impl CircuitBreaker {
        self.total_successes.fetch_add(1, Ordering::Relaxed);
        self.consecutive_failures.store(0, Ordering::Release);
        let successes = self.consecutive_successes.fetch_add(1, Ordering::AcqRel) + 1;
        // Outcome-level metrics are recorded at the worker level where the worker label is known

        let current_state = *self.state.read().unwrap();

@@ -138,6 +140,7 @@ impl CircuitBreaker {
        self.total_failures.fetch_add(1, Ordering::Relaxed);
        self.consecutive_successes.store(0, Ordering::Release);
        let failures = self.consecutive_failures.fetch_add(1, Ordering::AcqRel) + 1;
        // Outcome-level metrics are recorded at the worker level where the worker label is known

        // Update last failure time
        {
@@ -204,11 +207,18 @@ impl CircuitBreaker {
            }
        }

        tracing::info!(
            "Circuit breaker state transition: {} -> {}",
            old_state,
            new_state
        );
        let from = match old_state {
            CircuitState::Closed => "closed",
            CircuitState::Open => "open",
            CircuitState::HalfOpen => "half_open",
        };
        let to = match new_state {
            CircuitState::Closed => "closed",
            CircuitState::Open => "open",
            CircuitState::HalfOpen => "half_open",
        };
        info!("Circuit breaker state transition: {} -> {}", from, to);
        // Transition metrics are recorded at the worker level where the worker label is known
    }
}
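The transitions above implement the four rules from the README's behavior summary. A toy Python model of that state machine (ours, for reasoning about the thresholds; not the Rust implementation):

```python
import time

class ToyCircuitBreaker:
    def __init__(self, failure_threshold=5, success_threshold=2, timeout_secs=30):
        self.failure_threshold = failure_threshold
        self.success_threshold = success_threshold
        self.timeout_secs = timeout_secs
        self.state = "closed"
        self.fails = 0
        self.oks = 0
        self.opened_at = 0.0

    def can_execute(self) -> bool:
        if self.state == "open" and time.monotonic() - self.opened_at >= self.timeout_secs:
            self.state, self.oks = "half_open", 0   # Open -> HalfOpen after timeout
        return self.state != "open"

    def record_outcome(self, success: bool) -> None:
        if success:
            self.fails, self.oks = 0, self.oks + 1
            if self.state == "half_open" and self.oks >= self.success_threshold:
                self.state = "closed"               # HalfOpen -> Closed after M successes
        else:
            self.oks, self.fails = 0, self.fails + 1
            if self.state == "half_open" or self.fails >= self.failure_threshold:
                self.state = "open"                 # Closed -> Open after N failures;
                self.opened_at = time.monotonic()   # any HalfOpen failure reopens immediately
```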
@@ -8,6 +8,7 @@
pub mod circuit_breaker;
pub mod error;
pub mod retry;
pub mod worker;

// Re-export commonly used types at the module level
@@ -15,6 +16,7 @@ pub use circuit_breaker::{
    CircuitBreaker, CircuitBreakerConfig, CircuitBreakerStats, CircuitState,
};
pub use error::{WorkerError, WorkerResult};
pub use retry::{BackoffCalculator, RetryError, RetryExecutor};
pub use worker::{
    start_health_checker, BasicWorker, DPAwareWorker, HealthChecker, Worker, WorkerCollection,
    WorkerFactory, WorkerLoadGuard, WorkerType,
sgl-router/src/core/retry.rs (new file, +395)
@@ -0,0 +1,395 @@
use crate::config::types::RetryConfig;
use axum::response::Response;
use rand::Rng;
use std::time::Duration;
use tracing::debug;

/// Computes exponential backoff with optional jitter.
#[derive(Debug, Clone)]
pub struct BackoffCalculator;

impl BackoffCalculator {
    /// Calculate backoff delay for a given attempt index (0-based).
    pub fn calculate_delay(config: &RetryConfig, attempt: u32) -> Duration {
        // Base exponential backoff
        let pow = config.backoff_multiplier.powi(attempt as i32);
        let mut delay_ms = (config.initial_backoff_ms as f32 * pow) as u64;
        if delay_ms > config.max_backoff_ms {
            delay_ms = config.max_backoff_ms;
        }

        // Apply jitter in range [-j, +j]
        let jitter = config.jitter_factor.max(0.0).min(1.0);
        if jitter > 0.0 {
            let mut rng = rand::thread_rng();
            let jitter_scale: f32 = rng.gen_range(-jitter..=jitter);
            let jitter_ms = (delay_ms as f32 * jitter_scale)
                .round()
                .max(-(delay_ms as f32));
            let adjusted = (delay_ms as i64 + jitter_ms as i64).max(0) as u64;
            return Duration::from_millis(adjusted);
        }

        Duration::from_millis(delay_ms)
    }
}

#[derive(Debug, thiserror::Error)]
pub enum RetryError {
    #[error("no available workers")]
    NoAvailableWorkers,
    #[error("maximum retry attempts exceeded")]
    MaxRetriesExceeded,
}

/// A thin async retry executor for generic operations.
#[derive(Debug, Clone, Default)]
pub struct RetryExecutor;

impl RetryExecutor {
    /// Execute an async operation with retries and backoff.
    /// The `operation` closure is invoked each attempt with the attempt index.
    pub async fn execute_with_retry<F, Fut, T>(
        config: &RetryConfig,
        mut operation: F,
    ) -> Result<T, RetryError>
    where
        F: FnMut(u32) -> Fut,
        Fut: std::future::Future<Output = Result<T, ()>>,
    {
        let max = config.max_retries.max(1);
        let mut attempt: u32 = 0;
        loop {
            match operation(attempt).await {
                Ok(val) => return Ok(val),
                Err(_) => {
                    // Use the number of failures so far (0-indexed) to compute delay,
                    // so the first retry uses `initial_backoff_ms`.
                    let is_last = attempt + 1 >= max;
                    if is_last {
                        return Err(RetryError::MaxRetriesExceeded);
                    }
                    let delay = BackoffCalculator::calculate_delay(config, attempt);
                    attempt += 1; // advance to the next attempt after computing delay
                    tokio::time::sleep(delay).await;
                }
            }
        }
    }

    /// Execute an operation that returns an HTTP Response with retries and backoff.
    ///
    /// Usage pattern:
    /// - `operation(attempt)`: perform one attempt (0-based). Construct and send the request,
    ///   then return the `Response`. Do any per-attempt bookkeeping (e.g., load tracking,
    ///   circuit-breaker outcome recording) inside this closure.
    /// - `should_retry(&response, attempt)`: decide if the given response should be retried
    ///   (e.g., based on HTTP status). Returning false short-circuits and returns the response.
    /// - `on_backoff(delay, next_attempt)`: called before sleeping between attempts.
    ///   Use this to record metrics.
    /// - `on_exhausted()`: called when the executor has exhausted all retry attempts.
    ///
    /// Example:
    /// ```ignore
    /// let resp = RetryExecutor::execute_response_with_retry(
    ///     &retry_cfg,
    ///     |attempt| async move {
    ///         let worker = select_cb_aware_worker()?;
    ///         let resp = send_request(worker).await;
    ///         worker.record_outcome(resp.status().is_success());
    ///         resp
    ///     },
    ///     |res, _| matches!(res.status(), StatusCode::REQUEST_TIMEOUT | StatusCode::TOO_MANY_REQUESTS | StatusCode::INTERNAL_SERVER_ERROR | StatusCode::BAD_GATEWAY | StatusCode::SERVICE_UNAVAILABLE | StatusCode::GATEWAY_TIMEOUT),
    ///     |delay, attempt| RouterMetrics::record_retry_backoff_duration(delay, attempt),
    ///     || RouterMetrics::record_retries_exhausted("/route"),
    /// ).await;
    /// ```
    pub async fn execute_response_with_retry<Op, Fut, ShouldRetry, OnBackoff, OnExhausted>(
        config: &RetryConfig,
        mut operation: Op,
        should_retry: ShouldRetry,
        on_backoff: OnBackoff,
        mut on_exhausted: OnExhausted,
    ) -> Response
    where
        Op: FnMut(u32) -> Fut,
        Fut: std::future::Future<Output = Response>,
        ShouldRetry: Fn(&Response, u32) -> bool,
        OnBackoff: Fn(Duration, u32),
        OnExhausted: FnMut(),
    {
        let max = config.max_retries.max(1);

        let mut attempt: u32 = 0;
        loop {
            let response = operation(attempt).await;
            let is_last = attempt + 1 >= max;

            if !should_retry(&response, attempt) {
                return response;
            }

            if is_last {
                // Exhausted retries
                on_exhausted();
                return response;
            }

            // Backoff before next attempt
            let next_attempt = attempt + 1;
            // Compute delay based on the number of failures so far (0-indexed)
            let delay = BackoffCalculator::calculate_delay(config, attempt);
            debug!(
                attempt = attempt,
                next_attempt = next_attempt,
                delay_ms = delay.as_millis() as u64,
                "Retry backoff"
            );
            on_backoff(delay, next_attempt);
            tokio::time::sleep(delay).await;

            attempt = next_attempt;
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use axum::http::StatusCode;
    use axum::response::IntoResponse;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::sync::Arc;

    fn base_retry_config() -> RetryConfig {
        RetryConfig {
            max_retries: 3,
            initial_backoff_ms: 1,
            max_backoff_ms: 4,
            backoff_multiplier: 2.0,
            jitter_factor: 0.0,
        }
    }

    #[test]
    fn test_backoff_no_jitter_progression_and_cap() {
        let cfg = RetryConfig {
            max_retries: 10,
            initial_backoff_ms: 100,
            max_backoff_ms: 250,
            backoff_multiplier: 2.0,
            jitter_factor: 0.0,
        };
        // attempt=0 => 100ms
        assert_eq!(
            BackoffCalculator::calculate_delay(&cfg, 0),
            Duration::from_millis(100)
        );
        // attempt=1 => 200ms
        assert_eq!(
            BackoffCalculator::calculate_delay(&cfg, 1),
            Duration::from_millis(200)
        );
        // attempt=2 => 400ms -> capped to 250ms
        assert_eq!(
            BackoffCalculator::calculate_delay(&cfg, 2),
            Duration::from_millis(250)
        );
        // large attempt still capped
        assert_eq!(
            BackoffCalculator::calculate_delay(&cfg, 10),
            Duration::from_millis(250)
        );
    }

    #[test]
    fn test_backoff_with_jitter_within_bounds() {
        let cfg = RetryConfig {
            max_retries: 5,
            initial_backoff_ms: 100,
            max_backoff_ms: 10_000,
            backoff_multiplier: 2.0,
            jitter_factor: 0.5,
        };
        // attempt=2 => base 400ms, jitter in [0.5x, 1.5x]
        let base = 400.0;
        for _ in 0..50 {
            let d = BackoffCalculator::calculate_delay(&cfg, 2).as_millis() as f32;
            assert!(d >= base * 0.5 - 1.0 && d <= base * 1.5 + 1.0);
        }
    }

    #[tokio::test]
    async fn test_execute_with_retry_success_after_failures() {
        let cfg = base_retry_config();
        let remaining = Arc::new(AtomicU32::new(2));
        let calls = Arc::new(AtomicU32::new(0));

        let res: Result<u32, RetryError> = RetryExecutor::execute_with_retry(&cfg, {
            let remaining = remaining.clone();
            let calls = calls.clone();
            move |_attempt| {
                calls.fetch_add(1, Ordering::Relaxed);
                let remaining = remaining.clone();
                async move {
                    if remaining
                        .fetch_update(Ordering::AcqRel, Ordering::Acquire, |v| v.checked_sub(1))
                        .is_ok()
                    {
                        Err(())
                    } else {
                        Ok(42u32)
                    }
                }
            }
        })
        .await;

        assert!(res.is_ok());
        assert_eq!(res.unwrap(), 42);
        assert_eq!(calls.load(Ordering::Relaxed), 3); // 2 fails + 1 success
    }

    #[tokio::test]
    async fn test_execute_with_retry_exhausted() {
        let cfg = base_retry_config();
        let calls = Arc::new(AtomicU32::new(0));
        let res: Result<u32, RetryError> = RetryExecutor::execute_with_retry(&cfg, {
            let calls = calls.clone();
            move |_attempt| {
                calls.fetch_add(1, Ordering::Relaxed);
                async move { Err(()) }
            }
        })
        .await;

        assert!(matches!(res, Err(RetryError::MaxRetriesExceeded)));
        assert_eq!(calls.load(Ordering::Relaxed), cfg.max_retries);
    }

    #[tokio::test]
    async fn test_execute_response_with_retry_success_path_and_hooks() {
        let cfg = base_retry_config();
        let remaining = Arc::new(AtomicU32::new(2));
        let calls = Arc::new(AtomicU32::new(0));
        let backoffs = Arc::new(AtomicU32::new(0));
        let exhausted = Arc::new(AtomicU32::new(0));

        let response = RetryExecutor::execute_response_with_retry(
            &cfg,
            {
                let remaining = remaining.clone();
                let calls = calls.clone();
                move |_attempt| {
                    calls.fetch_add(1, Ordering::Relaxed);
                    let remaining = remaining.clone();
                    async move {
                        if remaining
                            .fetch_update(Ordering::AcqRel, Ordering::Acquire, |v| v.checked_sub(1))
                            .is_ok()
                        {
                            (StatusCode::SERVICE_UNAVAILABLE, "fail").into_response()
                        } else {
                            (StatusCode::OK, "ok").into_response()
                        }
                    }
                }
            },
            |res, _attempt| !res.status().is_success(), // retry until success
            {
                let backoffs = backoffs.clone();
                move |_delay, _next_attempt| {
                    backoffs.fetch_add(1, Ordering::Relaxed);
                }
            },
            {
                let exhausted = exhausted.clone();
                move || {
                    exhausted.fetch_add(1, Ordering::Relaxed);
                }
            },
        )
        .await;

        assert_eq!(response.status(), StatusCode::OK);
        assert_eq!(calls.load(Ordering::Relaxed), 3); // 2 fails + 1 success
        assert_eq!(backoffs.load(Ordering::Relaxed), 2);
        assert_eq!(exhausted.load(Ordering::Relaxed), 0);
    }

    #[tokio::test]
    async fn test_execute_response_with_retry_non_retryable_short_circuit() {
        let cfg = base_retry_config();
        let calls = Arc::new(AtomicU32::new(0));
        let backoffs = Arc::new(AtomicU32::new(0));
        let exhausted = Arc::new(AtomicU32::new(0));

        let response = RetryExecutor::execute_response_with_retry(
            &cfg,
            {
                let calls = calls.clone();
                move |_attempt| {
                    calls.fetch_add(1, Ordering::Relaxed);
                    async move { (StatusCode::BAD_REQUEST, "bad").into_response() }
                }
            },
            |_res, _attempt| false, // never retry
            {
                let backoffs = backoffs.clone();
                move |_delay, _next_attempt| {
                    backoffs.fetch_add(1, Ordering::Relaxed);
                }
            },
            {
                let exhausted = exhausted.clone();
                move || {
                    exhausted.fetch_add(1, Ordering::Relaxed);
                }
            },
        )
        .await;

        assert_eq!(response.status(), StatusCode::BAD_REQUEST);
        assert_eq!(calls.load(Ordering::Relaxed), 1);
        assert_eq!(backoffs.load(Ordering::Relaxed), 0);
        assert_eq!(exhausted.load(Ordering::Relaxed), 0);
    }

    #[tokio::test]
    async fn test_execute_response_with_retry_exhausted_hooks() {
        let cfg = base_retry_config();
        let calls = Arc::new(AtomicU32::new(0));
        let backoffs = Arc::new(AtomicU32::new(0));
        let exhausted = Arc::new(AtomicU32::new(0));

        let response = RetryExecutor::execute_response_with_retry(
            &cfg,
            {
                let calls = calls.clone();
                move |_attempt| {
                    calls.fetch_add(1, Ordering::Relaxed);
                    async move { (StatusCode::SERVICE_UNAVAILABLE, "fail").into_response() }
                }
            },
            |_res, _attempt| true, // keep retrying
            {
                let backoffs = backoffs.clone();
                move |_delay, _next_attempt| {
                    backoffs.fetch_add(1, Ordering::Relaxed);
                }
            },
            {
                let exhausted = exhausted.clone();
                move || {
                    exhausted.fetch_add(1, Ordering::Relaxed);
                }
            },
        )
        .await;

        assert_eq!(response.status(), StatusCode::SERVICE_UNAVAILABLE);
        assert_eq!(calls.load(Ordering::Relaxed), cfg.max_retries);
        assert_eq!(backoffs.load(Ordering::Relaxed), cfg.max_retries - 1);
        assert_eq!(exhausted.load(Ordering::Relaxed), 1);
    }
}
@@ -77,7 +77,35 @@ pub trait Worker: Send + Sync + fmt::Debug {
    /// Record the outcome of a request to this worker
    fn record_outcome(&self, success: bool) {
        // Record outcome-level metric with worker label
        let outcome_str = if success { "success" } else { "failure" };
        RouterMetrics::record_cb_outcome(self.url(), outcome_str);

        // Record into circuit breaker and infer state change for metrics
        let before = self.circuit_breaker().state();
        self.circuit_breaker().record_outcome(success);
        let after = self.circuit_breaker().state();

        if before != after {
            let from = match before {
                crate::core::CircuitState::Closed => "closed",
                crate::core::CircuitState::Open => "open",
                crate::core::CircuitState::HalfOpen => "half_open",
            };
            let to = match after {
                crate::core::CircuitState::Closed => "closed",
                crate::core::CircuitState::Open => "open",
                crate::core::CircuitState::HalfOpen => "half_open",
            };
            RouterMetrics::record_cb_state_transition(self.url(), from, to);
        }

        let state_code = match self.circuit_breaker().state() {
            crate::core::CircuitState::Closed => 0u8,
            crate::core::CircuitState::Open => 1u8,
            crate::core::CircuitState::HalfOpen => 2u8,
        };
        RouterMetrics::set_cb_state(self.url(), state_code);
    }

    // === DP-aware methods ===
@@ -59,6 +59,19 @@ struct Router {
    decode_policy: Option<PolicyType>,
    max_concurrent_requests: usize,
    cors_allowed_origins: Vec<String>,
    // Retry configuration
    retry_max_retries: u32,
    retry_initial_backoff_ms: u64,
    retry_max_backoff_ms: u64,
    retry_backoff_multiplier: f32,
    retry_jitter_factor: f32,
    disable_retries: bool,
    // Circuit breaker configuration
    cb_failure_threshold: u32,
    cb_success_threshold: u32,
    cb_timeout_duration_secs: u64,
    cb_window_duration_secs: u64,
    disable_circuit_breaker: bool,
}

impl Router {
@@ -146,8 +159,21 @@ impl Router {
            request_id_headers: self.request_id_headers.clone(),
            max_concurrent_requests: self.max_concurrent_requests,
            cors_allowed_origins: self.cors_allowed_origins.clone(),
            retry: config::RetryConfig::default(),
            circuit_breaker: config::CircuitBreakerConfig::default(),
            retry: config::RetryConfig {
                max_retries: self.retry_max_retries,
                initial_backoff_ms: self.retry_initial_backoff_ms,
                max_backoff_ms: self.retry_max_backoff_ms,
                backoff_multiplier: self.retry_backoff_multiplier,
                jitter_factor: self.retry_jitter_factor,
            },
            circuit_breaker: config::CircuitBreakerConfig {
                failure_threshold: self.cb_failure_threshold,
                success_threshold: self.cb_success_threshold,
                timeout_duration_secs: self.cb_timeout_duration_secs,
                window_duration_secs: self.cb_window_duration_secs,
            },
            disable_retries: false,
            disable_circuit_breaker: false,
        })
    }
}
@@ -189,7 +215,20 @@ impl Router {
        prefill_policy = None,
        decode_policy = None,
        max_concurrent_requests = 64,
        cors_allowed_origins = vec![]
        cors_allowed_origins = vec![],
        // Retry defaults
        retry_max_retries = 3,
        retry_initial_backoff_ms = 100,
        retry_max_backoff_ms = 10_000,
        retry_backoff_multiplier = 2.0,
        retry_jitter_factor = 0.1,
        disable_retries = false,
        // Circuit breaker defaults
        cb_failure_threshold = 5,
        cb_success_threshold = 2,
        cb_timeout_duration_secs = 30,
        cb_window_duration_secs = 60,
        disable_circuit_breaker = false,
    ))]
    fn new(
        worker_urls: Vec<String>,
@@ -226,6 +265,17 @@ impl Router {
        decode_policy: Option<PolicyType>,
        max_concurrent_requests: usize,
        cors_allowed_origins: Vec<String>,
        retry_max_retries: u32,
        retry_initial_backoff_ms: u64,
        retry_max_backoff_ms: u64,
        retry_backoff_multiplier: f32,
        retry_jitter_factor: f32,
        disable_retries: bool,
        cb_failure_threshold: u32,
        cb_success_threshold: u32,
        cb_timeout_duration_secs: u64,
        cb_window_duration_secs: u64,
        disable_circuit_breaker: bool,
    ) -> PyResult<Self> {
        Ok(Router {
            host,
@@ -262,6 +312,17 @@ impl Router {
            decode_policy,
            max_concurrent_requests,
            cors_allowed_origins,
            retry_max_retries,
            retry_initial_backoff_ms,
            retry_max_backoff_ms,
            retry_backoff_multiplier,
            retry_jitter_factor,
            disable_retries,
            cb_failure_threshold,
            cb_success_threshold,
            cb_timeout_duration_secs,
            cb_window_duration_secs,
            disable_circuit_breaker,
        })
    }
@@ -36,6 +36,28 @@ pub fn init_metrics() {
        "sgl_router_retries_total",
        "Total number of request retries by route"
    );
    describe_histogram!(
        "sgl_router_retry_backoff_duration_seconds",
        "Backoff duration in seconds by attempt index"
    );
    describe_counter!(
        "sgl_router_retries_exhausted_total",
        "Total number of requests that exhausted retries by route"
    );

    // Circuit breaker metrics
    describe_gauge!(
        "sgl_router_cb_state",
        "Circuit breaker state per worker (0=closed, 1=open, 2=half_open)"
    );
    describe_counter!(
        "sgl_router_cb_state_transitions_total",
        "Total number of circuit breaker state transitions by worker"
    );
    describe_counter!(
        "sgl_router_cb_outcomes_total",
        "Total number of circuit breaker outcomes by worker and outcome type (success/failure)"
    );

    // Worker metrics
    describe_gauge!(
@@ -186,6 +208,20 @@ impl RouterMetrics {
        .increment(1);
    }

    pub fn record_retry_backoff_duration(duration: Duration, attempt: u32) {
        histogram!("sgl_router_retry_backoff_duration_seconds",
            "attempt" => attempt.to_string()
        )
        .record(duration.as_secs_f64());
    }

    pub fn record_retries_exhausted(route: &str) {
        counter!("sgl_router_retries_exhausted_total",
            "route" => route.to_string()
        )
        .increment(1);
    }

    // Worker metrics
    pub fn set_active_workers(count: usize) {
        gauge!("sgl_router_active_workers").set(count as f64);
@@ -321,6 +357,31 @@ impl RouterMetrics {
        )
        .set(count as f64);
    }

    // Circuit breaker metrics
    pub fn set_cb_state(worker: &str, state_code: u8) {
        gauge!("sgl_router_cb_state",
            "worker" => worker.to_string()
        )
        .set(state_code as f64);
    }

    pub fn record_cb_state_transition(worker: &str, from: &str, to: &str) {
        counter!("sgl_router_cb_state_transitions_total",
            "worker" => worker.to_string(),
            "from" => from.to_string(),
            "to" => to.to_string()
        )
        .increment(1);
    }

    pub fn record_cb_outcome(worker: &str, outcome: &str) {
        counter!("sgl_router_cb_outcomes_total",
            "worker" => worker.to_string(),
            "outcome" => outcome.to_string()
        )
        .increment(1);
    }
}

#[cfg(test)]
@@ -109,7 +109,7 @@ pub(crate) fn get_healthy_worker_indices(workers: &[Box<dyn Worker>]) -> Vec<usi
    workers
        .iter()
        .enumerate()
        .filter(|(_, w)| w.is_healthy())
        .filter(|(_, w)| w.is_healthy() && w.circuit_breaker().can_execute())
        .map(|(idx, _)| idx)
        .collect()
}
@@ -1845,7 +1845,7 @@ impl RouterTrait for PDRouter {
mod tests {
    use super::*;
    use crate::core::{BasicWorker, WorkerType};
    use crate::policies::{CacheAwarePolicy, RandomPolicy};
    use crate::policies::RandomPolicy;

    fn create_test_pd_router() -> PDRouter {
        let prefill_policy = Arc::new(RandomPolicy::new());
@@ -1,5 +1,5 @@
use crate::config::types::{CircuitBreakerConfig as ConfigCircuitBreakerConfig, RetryConfig};
use crate::core::{CircuitBreakerConfig, HealthChecker, Worker, WorkerFactory};
use crate::core::{CircuitBreakerConfig, HealthChecker, RetryExecutor, Worker, WorkerFactory};
use crate::metrics::RouterMetrics;
use crate::openai_api_types::{ChatCompletionRequest, CompletionRequest, GenerateRequest};
use crate::policies::LoadBalancingPolicy;
@@ -382,6 +382,33 @@ impl Router {
    }

    // New method to route typed requests directly
    /// Select worker considering circuit breaker state
    fn select_worker_with_circuit_breaker(&self, text: Option<&str>) -> Option<Box<dyn Worker>> {
        let workers = self.workers.read().ok()?;
        let available: Vec<Box<dyn Worker>> = workers
            .iter()
            .filter(|w| w.is_available())
            .map(|w| w.clone_worker())
            .collect();
        if available.is_empty() {
            return None;
        }
        let idx = self.policy.select_worker(&available, text)?;
        Some(available[idx].clone_worker())
    }

    fn is_retryable_status(status: StatusCode) -> bool {
        matches!(
            status,
            StatusCode::REQUEST_TIMEOUT
                | StatusCode::TOO_MANY_REQUESTS
                | StatusCode::INTERNAL_SERVER_ERROR
                | StatusCode::BAD_GATEWAY
                | StatusCode::SERVICE_UNAVAILABLE
                | StatusCode::GATEWAY_TIMEOUT
        )
    }

    pub async fn route_typed_request<
        T: crate::openai_api_types::GenerationRequest + serde::Serialize + Clone,
    >(
@@ -390,127 +417,70 @@ impl Router {
        typed_req: &T,
        route: &str,
    ) -> Response {
        // Handle retries like the original implementation
        let start = Instant::now();
        // Use retry config for per-worker retries
        let max_request_retries = self.retry_config.max_retries;
        // Total retries across all workers (2x to allow trying multiple workers)
        let max_total_retries = self.retry_config.max_retries * 2;
        let mut total_retries = 0;

        while total_retries < max_total_retries {
            // Extract routing text directly from typed request
            let text = typed_req.extract_text_for_routing();
        let is_stream = typed_req.is_stream();
        let text = typed_req.extract_text_for_routing();

            // Select worker based on text
            let worker_url = self.select_generate_worker_from_text(&text);
            if worker_url.is_empty() {
                RouterMetrics::record_request_error(route, "no_healthy_workers");
        let response = RetryExecutor::execute_response_with_retry(
            &self.retry_config,
            // operation per attempt
            |_: u32| async {
                let worker = match self.select_worker_with_circuit_breaker(Some(&text)) {
                    Some(w) => w,
                    None => {
                        RouterMetrics::record_request_error(route, "no_available_workers");
                        return (
                            StatusCode::SERVICE_UNAVAILABLE,
                            "No healthy workers available",
                            "No available workers (all circuits open or unhealthy)",
                        )
                            .into_response();
                    }
            let mut request_retries = 0;
                };

            // Try the same worker multiple times
            while request_retries < max_request_retries {
                if total_retries >= 1 {
                    info!("Retrying request after {} failed attempts", total_retries);
                    RouterMetrics::record_retry(route);
                }

                // Increment load before request if using RAII load tracking
                // Optional load tracking for cache-aware policy
                let load_incremented = if self.policy.name() == "cache_aware" {
                    let workers_guard = self.workers.read().unwrap();
                    if let Some(worker) = workers_guard.iter().find(|w| w.url() == &worker_url) {
                        worker.increment_load();
                        RouterMetrics::set_running_requests(&worker_url, worker.load());
                        RouterMetrics::set_running_requests(worker.url(), worker.load());
                        true
                    } else {
                        false
                    }
                } else {
                    false
                };

                // Send typed request directly
                let response = self
                    .send_typed_request(
                        headers,
                        typed_req,
                        route,
                        &worker_url,
                        worker.url(),
                        is_stream,
                        load_incremented,
                    )
                    .await;

                worker.record_outcome(response.status().is_success());
                response
            },
            // should_retry predicate
            |res, _attempt| Self::is_retryable_status(res.status()),
            // on_backoff hook
            |delay, attempt| {
                RouterMetrics::record_retry(route);
                RouterMetrics::record_retry_backoff_duration(delay, attempt);
            },
            // on_exhausted hook
            || RouterMetrics::record_retries_exhausted(route),
        )
        .await;

        if response.status().is_success() {
            let duration = start.elapsed();
            RouterMetrics::record_request(route);
            RouterMetrics::record_generate_duration(duration);
            return response;
        } else {
            let status = response.status();
            if status.is_client_error() && status != StatusCode::TOO_MANY_REQUESTS {
                RouterMetrics::record_request_error(route, "client_error");
                return response;
            }
            // if the worker is healthy, it means the request is bad, so return the error response
            let health_response = self.send_health_check(&worker_url).await;
            if health_response.status().is_success() {
                RouterMetrics::record_request_error(route, "request_failed");
                return response;
            }
        } else if !Self::is_retryable_status(response.status()) {
            RouterMetrics::record_request_error(route, "non_retryable_error");
        }

                warn!(
                    "Generate request failed route={} worker_url={} attempt={} max_attempts={}",
                    route,
                    worker_url,
                    request_retries + 1,
                    max_request_retries
                );

                request_retries += 1;
                total_retries += 1;

                if request_retries == max_request_retries {
                    warn!(
                        "Removing failed worker after typed request failures worker_url={}",
                        worker_url
                    );
                    self.remove_worker(&worker_url);
                    break;
                }

                let backoff_ms = (100u64 * (request_retries as u64)).min(1000);
                tokio::time::sleep(Duration::from_millis(backoff_ms)).await;
            }
        }

        RouterMetrics::record_request_error(route, "request_failed");
        (
            StatusCode::INTERNAL_SERVER_ERROR,
            "All retry attempts failed",
        )
            .into_response()
    }

    // Helper method to select worker from text using the policy
    fn select_generate_worker_from_text(&self, text: &str) -> String {
        let workers = self.workers.read().unwrap();

        match self.policy.select_worker(&workers, Some(text)) {
            Some(idx) => workers[idx].url().to_string(),
            None => {
                warn!("No healthy workers available");
                String::new()
            }
        }
        response
    }

    // TODO (rui): Better accommodate to the Worker abstraction
@@ -48,6 +48,8 @@ impl TestContext {
            cors_allowed_origins: vec![],
            retry: RetryConfig::default(),
            circuit_breaker: CircuitBreakerConfig::default(),
            disable_retries: false,
            disable_circuit_breaker: false,
        };

        Self::new_with_config(config, worker_configs).await
@@ -1091,6 +1093,8 @@ mod error_tests {
            cors_allowed_origins: vec![],
            retry: RetryConfig::default(),
            circuit_breaker: CircuitBreakerConfig::default(),
            disable_retries: false,
            disable_circuit_breaker: false,
        };

        let ctx = TestContext::new_with_config(
@@ -1439,6 +1443,8 @@ mod pd_mode_tests {
            cors_allowed_origins: vec![],
            retry: RetryConfig::default(),
            circuit_breaker: CircuitBreakerConfig::default(),
            disable_retries: false,
            disable_circuit_breaker: false,
        };

        // Create app context
@@ -1594,6 +1600,8 @@ mod request_id_tests {
            cors_allowed_origins: vec![],
            retry: RetryConfig::default(),
            circuit_breaker: CircuitBreakerConfig::default(),
            disable_retries: false,
            disable_circuit_breaker: false,
        };

        let ctx = TestContext::new_with_config(

@@ -39,6 +39,8 @@ impl TestContext {
            cors_allowed_origins: vec![],
            retry: RetryConfig::default(),
            circuit_breaker: CircuitBreakerConfig::default(),
            disable_retries: false,
            disable_circuit_breaker: false,
        };

        let mut workers = Vec::new();

@@ -40,6 +40,8 @@ impl TestContext {
            cors_allowed_origins: vec![],
            retry: RetryConfig::default(),
            circuit_breaker: CircuitBreakerConfig::default(),
            disable_retries: false,
            disable_circuit_breaker: false,
        };

        let mut workers = Vec::new();

@@ -182,6 +182,8 @@ mod test_pd_routing {
            cors_allowed_origins: vec![],
            retry: RetryConfig::default(),
            circuit_breaker: CircuitBreakerConfig::default(),
            disable_retries: false,
            disable_circuit_breaker: false,
        };

        // Router creation will fail due to health checks, but config should be valid