[router] add token bucket rate limiter (#9656)
This commit is contained in:
@@ -72,6 +72,12 @@ class RouterArgs:
|
||||
request_timeout_secs: int = 1800
|
||||
# Max concurrent requests for rate limiting
|
||||
max_concurrent_requests: int = 256
|
||||
# Queue size for pending requests when max concurrent limit reached
|
||||
queue_size: int = 100
|
||||
# Maximum time (in seconds) a request can wait in queue before timing out
|
||||
queue_timeout_secs: int = 60
|
||||
# Token bucket refill rate (tokens per second). If not set, defaults to max_concurrent_requests
|
||||
rate_limit_tokens_per_second: Optional[int] = None
|
||||
# CORS allowed origins
|
||||
cors_allowed_origins: List[str] = dataclasses.field(default_factory=list)
|
||||
# Retry configuration
|
||||
@@ -402,6 +408,24 @@ class RouterArgs:
|
||||
default=RouterArgs.max_concurrent_requests,
|
||||
help="Maximum number of concurrent requests allowed (for rate limiting)",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}queue-size",
|
||||
type=int,
|
||||
default=RouterArgs.queue_size,
|
||||
help="Queue size for pending requests when max concurrent limit reached (0 = no queue, return 429 immediately)",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}queue-timeout-secs",
|
||||
type=int,
|
||||
default=RouterArgs.queue_timeout_secs,
|
||||
help="Maximum time (in seconds) a request can wait in queue before timing out",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}rate-limit-tokens-per-second",
|
||||
type=int,
|
||||
default=RouterArgs.rate_limit_tokens_per_second,
|
||||
help="Token bucket refill rate (tokens per second). If not set, defaults to max_concurrent_requests",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}cors-allowed-origins",
|
||||
type=str,
|
||||
@@ -478,6 +502,21 @@ class RouterArgs:
|
||||
f"{prefix}max_concurrent_requests",
|
||||
RouterArgs.max_concurrent_requests,
|
||||
),
|
||||
queue_size=getattr(
|
||||
args,
|
||||
f"{prefix}queue_size",
|
||||
RouterArgs.queue_size,
|
||||
),
|
||||
queue_timeout_secs=getattr(
|
||||
args,
|
||||
f"{prefix}queue_timeout_secs",
|
||||
RouterArgs.queue_timeout_secs,
|
||||
),
|
||||
rate_limit_tokens_per_second=getattr(
|
||||
args,
|
||||
f"{prefix}rate_limit_tokens_per_second",
|
||||
RouterArgs.rate_limit_tokens_per_second,
|
||||
),
|
||||
cors_allowed_origins=getattr(args, f"{prefix}cors_allowed_origins", []),
|
||||
retry_max_retries=getattr(args, f"{prefix}retry_max_retries"),
|
||||
retry_initial_backoff_ms=getattr(args, f"{prefix}retry_initial_backoff_ms"),
|
||||
@@ -700,6 +739,9 @@ def launch_router(args: argparse.Namespace) -> Optional[Router]:
|
||||
),
|
||||
request_id_headers=router_args.request_id_headers,
|
||||
max_concurrent_requests=router_args.max_concurrent_requests,
|
||||
queue_size=router_args.queue_size,
|
||||
queue_timeout_secs=router_args.queue_timeout_secs,
|
||||
rate_limit_tokens_per_second=router_args.rate_limit_tokens_per_second,
|
||||
cors_allowed_origins=router_args.cors_allowed_origins,
|
||||
retry_max_retries=router_args.retry_max_retries,
|
||||
retry_initial_backoff_ms=router_args.retry_initial_backoff_ms,
|
||||
|
||||
@@ -64,7 +64,10 @@ class Router:
|
||||
bootstrap_port_annotation: Kubernetes annotation name for bootstrap port (PD mode).
|
||||
Default: 'sglang.ai/bootstrap-port'
|
||||
request_timeout_secs: Request timeout in seconds. Default: 600
|
||||
max_concurrent_requests: Maximum number of concurrent requests allowed for rate limiting. Default: 64
|
||||
max_concurrent_requests: Maximum number of concurrent requests allowed for rate limiting. Default: 256
|
||||
queue_size: Queue size for pending requests when max concurrent limit reached (0 = no queue, return 429 immediately). Default: 100
|
||||
queue_timeout_secs: Maximum time (in seconds) a request can wait in queue before timing out. Default: 60
|
||||
rate_limit_tokens_per_second: Token bucket refill rate (tokens per second). If not set, defaults to max_concurrent_requests. Default: None
|
||||
cors_allowed_origins: List of allowed origins for CORS. Empty list allows all origins. Default: []
|
||||
health_failure_threshold: Number of consecutive health check failures before marking worker unhealthy. Default: 3
|
||||
health_success_threshold: Number of consecutive health check successes before marking worker healthy. Default: 2
|
||||
@@ -108,6 +111,9 @@ class Router:
|
||||
prefill_policy: Optional[PolicyType] = None,
|
||||
decode_policy: Optional[PolicyType] = None,
|
||||
max_concurrent_requests: int = 256,
|
||||
queue_size: int = 100,
|
||||
queue_timeout_secs: int = 60,
|
||||
rate_limit_tokens_per_second: Optional[int] = None,
|
||||
cors_allowed_origins: List[str] = None,
|
||||
retry_max_retries: int = 5,
|
||||
retry_initial_backoff_ms: int = 50,
|
||||
@@ -169,6 +175,9 @@ class Router:
|
||||
prefill_policy=prefill_policy,
|
||||
decode_policy=decode_policy,
|
||||
max_concurrent_requests=max_concurrent_requests,
|
||||
queue_size=queue_size,
|
||||
queue_timeout_secs=queue_timeout_secs,
|
||||
rate_limit_tokens_per_second=rate_limit_tokens_per_second,
|
||||
cors_allowed_origins=cors_allowed_origins,
|
||||
retry_max_retries=retry_max_retries,
|
||||
retry_initial_backoff_ms=retry_initial_backoff_ms,
|
||||
|
||||
@@ -37,6 +37,12 @@ pub struct RouterConfig {
|
||||
pub request_id_headers: Option<Vec<String>>,
|
||||
/// Maximum concurrent requests allowed (for rate limiting)
|
||||
pub max_concurrent_requests: usize,
|
||||
/// Queue size for pending requests when max concurrent limit reached (0 = no queue, return 429 immediately)
|
||||
pub queue_size: usize,
|
||||
/// Maximum time (in seconds) a request can wait in queue before timing out
|
||||
pub queue_timeout_secs: u64,
|
||||
/// Token bucket refill rate (tokens per second). If not set, defaults to max_concurrent_requests
|
||||
pub rate_limit_tokens_per_second: Option<usize>,
|
||||
/// CORS allowed origins
|
||||
pub cors_allowed_origins: Vec<String>,
|
||||
/// Retry configuration
|
||||
@@ -320,6 +326,9 @@ impl Default for RouterConfig {
|
||||
log_level: None,
|
||||
request_id_headers: None,
|
||||
max_concurrent_requests: 256,
|
||||
queue_size: 100,
|
||||
queue_timeout_secs: 60,
|
||||
rate_limit_tokens_per_second: None,
|
||||
cors_allowed_origins: vec![],
|
||||
retry: RetryConfig::default(),
|
||||
circuit_breaker: CircuitBreakerConfig::default(),
|
||||
@@ -466,6 +475,9 @@ mod tests {
|
||||
disable_circuit_breaker: false,
|
||||
health_check: HealthCheckConfig::default(),
|
||||
enable_igw: false,
|
||||
queue_size: 100,
|
||||
queue_timeout_secs: 60,
|
||||
rate_limit_tokens_per_second: None,
|
||||
};
|
||||
|
||||
let json = serde_json::to_string(&config).unwrap();
|
||||
@@ -899,6 +911,9 @@ mod tests {
|
||||
disable_circuit_breaker: false,
|
||||
health_check: HealthCheckConfig::default(),
|
||||
enable_igw: false,
|
||||
queue_size: 100,
|
||||
queue_timeout_secs: 60,
|
||||
rate_limit_tokens_per_second: None,
|
||||
};
|
||||
|
||||
assert!(config.mode.is_pd_mode());
|
||||
@@ -956,6 +971,9 @@ mod tests {
|
||||
disable_circuit_breaker: false,
|
||||
health_check: HealthCheckConfig::default(),
|
||||
enable_igw: false,
|
||||
queue_size: 100,
|
||||
queue_timeout_secs: 60,
|
||||
rate_limit_tokens_per_second: None,
|
||||
};
|
||||
|
||||
assert!(!config.mode.is_pd_mode());
|
||||
@@ -1009,6 +1027,9 @@ mod tests {
|
||||
disable_circuit_breaker: false,
|
||||
health_check: HealthCheckConfig::default(),
|
||||
enable_igw: false,
|
||||
queue_size: 100,
|
||||
queue_timeout_secs: 60,
|
||||
rate_limit_tokens_per_second: None,
|
||||
};
|
||||
|
||||
assert!(config.has_service_discovery());
|
||||
|
||||
@@ -9,6 +9,7 @@
|
||||
pub mod circuit_breaker;
|
||||
pub mod error;
|
||||
pub mod retry;
|
||||
pub mod token_bucket;
|
||||
pub mod worker;
|
||||
|
||||
// Re-export commonly used types at the module level
|
||||
|
||||
195
sgl-router/src/core/token_bucket.rs
Normal file
195
sgl-router/src/core/token_bucket.rs
Normal file
@@ -0,0 +1,195 @@
|
||||
use std::sync::Arc;
|
||||
use std::time::{Duration, Instant};
|
||||
use tokio::sync::{Mutex, Notify};
|
||||
use tracing::{debug, trace};
|
||||
|
||||
/// Token bucket for rate limiting
///
/// Tokens refill continuously at `refill_rate` per second up to `capacity`
/// (burst). `Clone` is cheap: clones share the same bucket state via `Arc`.
///
/// This implementation provides:
/// - Smooth rate limiting with configurable refill rate
/// - Burst capacity handling
/// - Best-effort queuing for waiting requests
///   (NOTE(review): waiting is a 10ms poll loop plus `Notify` wake-ups — see
///   `acquire` — so wake order among waiters is not strictly FIFO; "fair"
///   queuing is not guaranteed.)
#[derive(Clone)]
pub struct TokenBucket {
    // Shared mutable state (current tokens + last refill timestamp).
    inner: Arc<Mutex<TokenBucketInner>>,
    // Wakes waiters early when tokens are returned via `return_tokens`.
    notify: Arc<Notify>,
    // Maximum tokens the bucket can hold (burst capacity).
    capacity: f64,
    // Guaranteed > 0 by the constructor.
    refill_rate: f64, // tokens per second
}
|
||||
|
||||
/// Mutable bucket state, protected by the async `Mutex` in `TokenBucket`.
struct TokenBucketInner {
    // Current token count; fractional because refill is time-based.
    tokens: f64,
    // Instant of the last refill; elapsed time since then is converted to
    // tokens lazily on each acquire/inspect.
    last_refill: Instant,
}
|
||||
|
||||
impl TokenBucket {
|
||||
/// Create a new token bucket
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `capacity` - Maximum number of tokens (burst capacity)
|
||||
/// * `refill_rate` - Tokens added per second
|
||||
pub fn new(capacity: usize, refill_rate: usize) -> Self {
|
||||
let capacity = capacity as f64;
|
||||
let refill_rate = refill_rate as f64;
|
||||
|
||||
// Ensure refill_rate is not zero to prevent division by zero
|
||||
let refill_rate = if refill_rate > 0.0 {
|
||||
refill_rate
|
||||
} else {
|
||||
1.0 // Default to 1 token per second if zero
|
||||
};
|
||||
|
||||
Self {
|
||||
inner: Arc::new(Mutex::new(TokenBucketInner {
|
||||
tokens: capacity, // Start full
|
||||
last_refill: Instant::now(),
|
||||
})),
|
||||
notify: Arc::new(Notify::new()),
|
||||
capacity,
|
||||
refill_rate,
|
||||
}
|
||||
}
|
||||
|
||||
/// Try to acquire tokens immediately
|
||||
pub async fn try_acquire(&self, tokens: f64) -> Result<(), ()> {
|
||||
let mut inner = self.inner.lock().await;
|
||||
|
||||
// Refill tokens based on elapsed time
|
||||
let now = Instant::now();
|
||||
let elapsed = now.duration_since(inner.last_refill).as_secs_f64();
|
||||
let refill_amount = elapsed * self.refill_rate;
|
||||
|
||||
inner.tokens = (inner.tokens + refill_amount).min(self.capacity);
|
||||
inner.last_refill = now;
|
||||
|
||||
trace!(
|
||||
"Token bucket: {} tokens available, requesting {}",
|
||||
inner.tokens,
|
||||
tokens
|
||||
);
|
||||
|
||||
if inner.tokens >= tokens {
|
||||
inner.tokens -= tokens;
|
||||
debug!(
|
||||
"Token bucket: acquired {} tokens, {} remaining",
|
||||
tokens, inner.tokens
|
||||
);
|
||||
Ok(())
|
||||
} else {
|
||||
Err(())
|
||||
}
|
||||
}
|
||||
|
||||
/// Acquire tokens, waiting if necessary
|
||||
pub async fn acquire(&self, tokens: f64) -> Result<(), tokio::time::error::Elapsed> {
|
||||
// First try to acquire immediately
|
||||
if self.try_acquire(tokens).await.is_ok() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Calculate wait time
|
||||
let wait_time = {
|
||||
let inner = self.inner.lock().await;
|
||||
let tokens_needed = tokens - inner.tokens;
|
||||
let wait_secs = tokens_needed / self.refill_rate;
|
||||
Duration::from_secs_f64(wait_secs)
|
||||
};
|
||||
|
||||
debug!(
|
||||
"Token bucket: waiting {:?} for {} tokens",
|
||||
wait_time, tokens
|
||||
);
|
||||
|
||||
// Wait for tokens to be available
|
||||
tokio::time::timeout(wait_time, async {
|
||||
loop {
|
||||
// Check if we can acquire now
|
||||
if self.try_acquire(tokens).await.is_ok() {
|
||||
return;
|
||||
}
|
||||
|
||||
// Wait for notification or small interval
|
||||
tokio::select! {
|
||||
_ = self.notify.notified() => {},
|
||||
_ = tokio::time::sleep(Duration::from_millis(10)) => {},
|
||||
}
|
||||
}
|
||||
})
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Acquire tokens with custom timeout
|
||||
pub async fn acquire_timeout(
|
||||
&self,
|
||||
tokens: f64,
|
||||
timeout: Duration,
|
||||
) -> Result<(), tokio::time::error::Elapsed> {
|
||||
tokio::time::timeout(timeout, self.acquire(tokens)).await?
|
||||
}
|
||||
|
||||
/// Return tokens to the bucket (for cancelled requests)
|
||||
pub async fn return_tokens(&self, tokens: f64) {
|
||||
let mut inner = self.inner.lock().await;
|
||||
inner.tokens = (inner.tokens + tokens).min(self.capacity);
|
||||
self.notify.notify_waiters();
|
||||
debug!(
|
||||
"Token bucket: returned {} tokens, {} available",
|
||||
tokens, inner.tokens
|
||||
);
|
||||
}
|
||||
|
||||
/// Get current available tokens (for monitoring)
|
||||
pub async fn available_tokens(&self) -> f64 {
|
||||
let mut inner = self.inner.lock().await;
|
||||
|
||||
// Refill before checking
|
||||
let now = Instant::now();
|
||||
let elapsed = now.duration_since(inner.last_refill).as_secs_f64();
|
||||
let refill_amount = elapsed * self.refill_rate;
|
||||
|
||||
inner.tokens = (inner.tokens + refill_amount).min(self.capacity);
|
||||
inner.last_refill = now;
|
||||
|
||||
inner.tokens
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_token_bucket_basic() {
        // Capacity 10, refilling 5 tokens/second; the bucket starts full.
        let tb = TokenBucket::new(10, 5);

        // Drain the full bucket in two acquisitions.
        for _ in 0..2 {
            assert!(tb.try_acquire(5.0).await.is_ok());
        }

        // Empty now: even a single token is refused.
        assert!(tb.try_acquire(1.0).await.is_err());

        // After 300ms at 5 tokens/s, roughly 1.5 tokens have accrued.
        tokio::time::sleep(Duration::from_millis(300)).await;
        assert!(tb.try_acquire(1.0).await.is_ok());
    }

    #[tokio::test]
    async fn test_token_bucket_refill() {
        // Capacity 10, refilling 10 tokens/second.
        let tb = TokenBucket::new(10, 10);

        // Empty the bucket in one shot.
        assert!(tb.try_acquire(10.0).await.is_ok());

        // Half a second later roughly half the capacity is back.
        tokio::time::sleep(Duration::from_millis(500)).await;
        let remaining = tb.available_tokens().await;
        assert!((4.0..=6.0).contains(&remaining));
    }
}
|
||||
@@ -85,6 +85,9 @@ struct Router {
|
||||
health_check_endpoint: String,
|
||||
// IGW (Inference Gateway) configuration
|
||||
enable_igw: bool,
|
||||
queue_size: usize,
|
||||
queue_timeout_secs: u64,
|
||||
rate_limit_tokens_per_second: Option<usize>,
|
||||
}
|
||||
|
||||
impl Router {
|
||||
@@ -176,6 +179,9 @@ impl Router {
|
||||
log_level: self.log_level.clone(),
|
||||
request_id_headers: self.request_id_headers.clone(),
|
||||
max_concurrent_requests: self.max_concurrent_requests,
|
||||
queue_size: self.queue_size,
|
||||
queue_timeout_secs: self.queue_timeout_secs,
|
||||
rate_limit_tokens_per_second: self.rate_limit_tokens_per_second,
|
||||
cors_allowed_origins: self.cors_allowed_origins.clone(),
|
||||
retry: config::RetryConfig {
|
||||
max_retries: self.retry_max_retries,
|
||||
@@ -190,8 +196,8 @@ impl Router {
|
||||
timeout_duration_secs: self.cb_timeout_duration_secs,
|
||||
window_duration_secs: self.cb_window_duration_secs,
|
||||
},
|
||||
disable_retries: false,
|
||||
disable_circuit_breaker: false,
|
||||
disable_retries: self.disable_retries,
|
||||
disable_circuit_breaker: self.disable_circuit_breaker,
|
||||
health_check: config::HealthCheckConfig {
|
||||
failure_threshold: self.health_failure_threshold,
|
||||
success_threshold: self.health_success_threshold,
|
||||
@@ -263,6 +269,9 @@ impl Router {
|
||||
health_check_endpoint = String::from("/health"),
|
||||
// IGW defaults
|
||||
enable_igw = false,
|
||||
queue_size = 100,
|
||||
queue_timeout_secs = 60,
|
||||
rate_limit_tokens_per_second = None,
|
||||
))]
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn new(
|
||||
@@ -317,6 +326,9 @@ impl Router {
|
||||
health_check_interval_secs: u64,
|
||||
health_check_endpoint: String,
|
||||
enable_igw: bool,
|
||||
queue_size: usize,
|
||||
queue_timeout_secs: u64,
|
||||
rate_limit_tokens_per_second: Option<usize>,
|
||||
) -> PyResult<Self> {
|
||||
Ok(Router {
|
||||
host,
|
||||
@@ -370,6 +382,9 @@ impl Router {
|
||||
health_check_interval_secs,
|
||||
health_check_endpoint,
|
||||
enable_igw,
|
||||
queue_size,
|
||||
queue_timeout_secs,
|
||||
rate_limit_tokens_per_second,
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -394,6 +394,8 @@ impl CliArgs {
|
||||
Some(self.request_id_headers.clone())
|
||||
},
|
||||
max_concurrent_requests: self.max_concurrent_requests,
|
||||
queue_size: 100, // Default queue size
|
||||
queue_timeout_secs: 60, // Default timeout
|
||||
cors_allowed_origins: self.cors_allowed_origins.clone(),
|
||||
retry: RetryConfig {
|
||||
max_retries: self.retry_max_retries,
|
||||
@@ -418,6 +420,7 @@ impl CliArgs {
|
||||
endpoint: self.health_check_endpoint.clone(),
|
||||
},
|
||||
enable_igw: self.enable_igw,
|
||||
rate_limit_tokens_per_second: None,
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -1,10 +1,19 @@
|
||||
use axum::{extract::Request, http::HeaderValue, response::Response};
|
||||
use axum::{
|
||||
extract::Request, extract::State, http::HeaderValue, http::StatusCode, middleware::Next,
|
||||
response::IntoResponse, response::Response,
|
||||
};
|
||||
use rand::Rng;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use std::time::Instant;
|
||||
use tokio::sync::{mpsc, oneshot};
|
||||
use tower::{Layer, Service};
|
||||
use tower_http::trace::{MakeSpan, OnRequest, OnResponse, TraceLayer};
|
||||
use tracing::{field::Empty, info_span, Span};
|
||||
use tracing::{debug, error, field::Empty, info, info_span, warn, Span};
|
||||
|
||||
pub use crate::core::token_bucket::TokenBucket;
|
||||
|
||||
use crate::server::AppState;
|
||||
|
||||
/// Generate OpenAI-compatible request ID based on endpoint
|
||||
fn generate_request_id(path: &str) -> String {
|
||||
@@ -313,3 +322,181 @@ pub fn log_request(entry: RequestLogEntry) {
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// ============ Concurrency Limiting with Queue Support ============
|
||||
|
||||
/// Request queue entry
///
/// Created by the middleware when no token is immediately available and
/// handed to the `QueueProcessor` over the queue channel.
pub struct QueuedRequest {
    /// Time when the request was queued; used to enforce the queue timeout
    /// relative to enqueue time, not dequeue time.
    queued_at: Instant,
    /// Channel to send the permit back when acquired
    /// (`Ok(())` = token granted, `Err(status)` = reject with that status).
    permit_tx: oneshot::Sender<Result<(), StatusCode>>,
}
|
||||
|
||||
/// Queue metrics for monitoring
///
/// NOTE(review): nothing in this file increments these counters — confirm
/// they are wired up by the queue/middleware paths elsewhere, or they are
/// dead code.
#[derive(Debug, Default)]
pub struct QueueMetrics {
    // Total requests ever placed on the queue.
    pub total_queued: std::sync::atomic::AtomicU64,
    // Requests currently waiting in the queue.
    pub current_queued: std::sync::atomic::AtomicU64,
    // Requests that timed out while queued.
    pub total_timeout: std::sync::atomic::AtomicU64,
    // Requests rejected (e.g. queue full).
    pub total_rejected: std::sync::atomic::AtomicU64,
}
|
||||
|
||||
/// Queue processor that handles queued requests
///
/// Owns the receiving end of the request queue; `run` drains it until every
/// sender (held by the middleware via `ConcurrencyLimiter`) is dropped.
pub struct QueueProcessor {
    // Shared bucket that grants one token per admitted request.
    token_bucket: Arc<TokenBucket>,
    // Incoming queued requests from the middleware.
    queue_rx: mpsc::Receiver<QueuedRequest>,
    // Maximum time a request may spend queued, measured from enqueue time.
    queue_timeout: Duration,
}
|
||||
|
||||
impl QueueProcessor {
|
||||
pub fn new(
|
||||
token_bucket: Arc<TokenBucket>,
|
||||
queue_rx: mpsc::Receiver<QueuedRequest>,
|
||||
queue_timeout: Duration,
|
||||
) -> Self {
|
||||
Self {
|
||||
token_bucket,
|
||||
queue_rx,
|
||||
queue_timeout,
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn run(mut self) {
|
||||
info!("Starting concurrency queue processor");
|
||||
|
||||
// Process requests in a single task to reduce overhead
|
||||
while let Some(queued) = self.queue_rx.recv().await {
|
||||
// Check timeout immediately
|
||||
let elapsed = queued.queued_at.elapsed();
|
||||
if elapsed >= self.queue_timeout {
|
||||
warn!("Request already timed out in queue");
|
||||
let _ = queued.permit_tx.send(Err(StatusCode::REQUEST_TIMEOUT));
|
||||
continue;
|
||||
}
|
||||
|
||||
let remaining_timeout = self.queue_timeout - elapsed;
|
||||
|
||||
// Try to acquire token for this request
|
||||
if self.token_bucket.try_acquire(1.0).await.is_ok() {
|
||||
// Got token immediately
|
||||
debug!("Queue: acquired token immediately for queued request");
|
||||
let _ = queued.permit_tx.send(Ok(()));
|
||||
} else {
|
||||
// Need to wait for token
|
||||
let token_bucket = self.token_bucket.clone();
|
||||
|
||||
// Spawn task only when we actually need to wait
|
||||
tokio::spawn(async move {
|
||||
if token_bucket
|
||||
.acquire_timeout(1.0, remaining_timeout)
|
||||
.await
|
||||
.is_ok()
|
||||
{
|
||||
debug!("Queue: acquired token after waiting");
|
||||
let _ = queued.permit_tx.send(Ok(()));
|
||||
} else {
|
||||
warn!("Queue: request timed out waiting for token");
|
||||
let _ = queued.permit_tx.send(Err(StatusCode::REQUEST_TIMEOUT));
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
warn!("Concurrency queue processor shutting down");
|
||||
}
|
||||
}
|
||||
|
||||
/// State for the concurrency limiter
///
/// `queue_tx` is `Some` only when queuing is enabled (queue size > 0); the
/// middleware treats `None` as "reject immediately with 429".
pub struct ConcurrencyLimiter {
    // Sender half of the request queue; paired with a `QueueProcessor`.
    pub queue_tx: Option<mpsc::Sender<QueuedRequest>>,
}
|
||||
|
||||
impl ConcurrencyLimiter {
|
||||
/// Create new concurrency limiter with optional queue
|
||||
pub fn new(
|
||||
token_bucket: Arc<TokenBucket>,
|
||||
queue_size: usize,
|
||||
queue_timeout: Duration,
|
||||
) -> (Self, Option<QueueProcessor>) {
|
||||
if queue_size > 0 {
|
||||
let (queue_tx, queue_rx) = mpsc::channel(queue_size);
|
||||
let processor = QueueProcessor::new(token_bucket, queue_rx, queue_timeout);
|
||||
|
||||
(
|
||||
Self {
|
||||
queue_tx: Some(queue_tx),
|
||||
},
|
||||
Some(processor),
|
||||
)
|
||||
} else {
|
||||
(Self { queue_tx: None }, None)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Middleware function for concurrency limiting with optional queuing
///
/// Each request must obtain one token from the shared bucket before it is
/// forwarded; the token is returned when the response completes. When no
/// token is available the request is queued (if a queue is configured) or
/// rejected with 429.
///
/// NOTE(review): tokens are both returned on completion *and* refilled over
/// time by the bucket, so under sustained load more than `capacity` requests
/// can be admitted concurrently — confirm this hybrid semantic is intended.
pub async fn concurrency_limit_middleware(
    State(app_state): State<Arc<AppState>>,
    request: Request<axum::body::Body>,
    next: Next,
) -> Response {
    let token_bucket = app_state.context.rate_limiter.clone();

    // Try to acquire token immediately
    if token_bucket.try_acquire(1.0).await.is_ok() {
        debug!("Acquired token immediately");
        let response = next.run(request).await;

        // Return the token to the bucket
        token_bucket.return_tokens(1.0).await;

        response
    } else {
        // No tokens available, try to queue if enabled
        if let Some(queue_tx) = &app_state.concurrency_queue_tx {
            debug!("No tokens available, attempting to queue request");

            // Create a channel for the token response
            let (permit_tx, permit_rx) = oneshot::channel();

            let queued = QueuedRequest {
                queued_at: Instant::now(),
                permit_tx,
            };

            // Try to send to queue (non-blocking; Err means queue is full)
            match queue_tx.try_send(queued) {
                Ok(_) => {
                    // Wait for token from queue processor
                    match permit_rx.await {
                        Ok(Ok(())) => {
                            debug!("Acquired token from queue");
                            let response = next.run(request).await;

                            // Return the token to the bucket
                            token_bucket.return_tokens(1.0).await;

                            response
                        }
                        // Queue processor rejected us (e.g. queue timeout);
                        // propagate the status it chose.
                        Ok(Err(status)) => {
                            warn!("Queue returned error status: {}", status);
                            status.into_response()
                        }
                        // Processor dropped the sender without replying —
                        // should not happen; treat as an internal error.
                        Err(_) => {
                            error!("Queue response channel closed");
                            StatusCode::INTERNAL_SERVER_ERROR.into_response()
                        }
                    }
                }
                // Queue at capacity: shed load.
                Err(_) => {
                    warn!("Request queue is full, returning 429");
                    StatusCode::TOO_MANY_REQUESTS.into_response()
                }
            }
        } else {
            warn!("No tokens available and queuing is disabled, returning 429");
            StatusCode::TOO_MANY_REQUESTS.into_response()
        }
    }
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
use crate::config::RouterConfig;
|
||||
use crate::logging::{self, LoggingConfig};
|
||||
use crate::metrics::{self, PrometheusConfig};
|
||||
use crate::middleware::TokenBucket;
|
||||
use crate::protocols::spec::{ChatCompletionRequest, CompletionRequest, GenerateRequest};
|
||||
use crate::routers::{RouterFactory, RouterTrait};
|
||||
use crate::service_discovery::{start_service_discovery, ServiceDiscoveryConfig};
|
||||
@@ -25,7 +26,7 @@ use tracing::{error, info, warn, Level};
|
||||
pub struct AppContext {
|
||||
pub client: Client,
|
||||
pub router_config: RouterConfig,
|
||||
pub concurrency_limiter: Arc<tokio::sync::Semaphore>,
|
||||
pub rate_limiter: Arc<TokenBucket>,
|
||||
// Future dependencies can be added here
|
||||
}
|
||||
|
||||
@@ -34,12 +35,14 @@ impl AppContext {
|
||||
router_config: RouterConfig,
|
||||
client: Client,
|
||||
max_concurrent_requests: usize,
|
||||
rate_limit_tokens_per_second: Option<usize>,
|
||||
) -> Self {
|
||||
let concurrency_limiter = Arc::new(tokio::sync::Semaphore::new(max_concurrent_requests));
|
||||
let rate_limit_tokens = rate_limit_tokens_per_second.unwrap_or(max_concurrent_requests);
|
||||
let rate_limiter = Arc::new(TokenBucket::new(max_concurrent_requests, rate_limit_tokens));
|
||||
Self {
|
||||
client,
|
||||
router_config,
|
||||
concurrency_limiter,
|
||||
rate_limiter,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -48,6 +51,7 @@ impl AppContext {
|
||||
pub struct AppState {
|
||||
pub router: Arc<dyn RouterTrait>,
|
||||
pub context: Arc<AppContext>,
|
||||
pub concurrency_queue_tx: Option<tokio::sync::mpsc::Sender<crate::middleware::QueuedRequest>>,
|
||||
}
|
||||
|
||||
// Fallback handler for unmatched routes
|
||||
@@ -186,7 +190,11 @@ pub fn build_app(
|
||||
let protected_routes = Router::new()
|
||||
.route("/generate", post(generate))
|
||||
.route("/v1/chat/completions", post(v1_chat_completions))
|
||||
.route("/v1/completions", post(v1_completions));
|
||||
.route("/v1/completions", post(v1_completions))
|
||||
.route_layer(axum::middleware::from_fn_with_state(
|
||||
app_state.clone(),
|
||||
crate::middleware::concurrency_limit_middleware,
|
||||
));
|
||||
|
||||
let public_routes = Router::new()
|
||||
.route("/liveness", get(liveness))
|
||||
@@ -282,15 +290,33 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box<dyn std::error::Err
|
||||
config.router_config.clone(),
|
||||
client.clone(),
|
||||
config.router_config.max_concurrent_requests,
|
||||
config.router_config.rate_limit_tokens_per_second,
|
||||
));
|
||||
|
||||
// Create router with the context
|
||||
let router = RouterFactory::create_router(&app_context).await?;
|
||||
|
||||
// Set up concurrency limiter with queue if configured
|
||||
let (limiter, processor) = crate::middleware::ConcurrencyLimiter::new(
|
||||
app_context.rate_limiter.clone(),
|
||||
config.router_config.queue_size,
|
||||
Duration::from_secs(config.router_config.queue_timeout_secs),
|
||||
);
|
||||
|
||||
// Start queue processor if enabled
|
||||
if let Some(processor) = processor {
|
||||
tokio::spawn(processor.run());
|
||||
info!(
|
||||
"Started request queue with size: {}, timeout: {}s",
|
||||
config.router_config.queue_size, config.router_config.queue_timeout_secs
|
||||
);
|
||||
}
|
||||
|
||||
// Create app state with router and context
|
||||
let app_state = Arc::new(AppState {
|
||||
router: Arc::from(router),
|
||||
context: app_context.clone(),
|
||||
concurrency_queue_tx: limiter.queue_tx.clone(),
|
||||
});
|
||||
let router_arc = Arc::clone(&app_state.router);
|
||||
|
||||
|
||||
@@ -45,6 +45,9 @@ impl TestContext {
|
||||
log_level: None,
|
||||
request_id_headers: None,
|
||||
max_concurrent_requests: 64,
|
||||
queue_size: 0,
|
||||
queue_timeout_secs: 60,
|
||||
rate_limit_tokens_per_second: None,
|
||||
cors_allowed_origins: vec![],
|
||||
retry: RetryConfig::default(),
|
||||
circuit_breaker: CircuitBreakerConfig::default(),
|
||||
@@ -1088,6 +1091,9 @@ mod error_tests {
|
||||
log_level: None,
|
||||
request_id_headers: None,
|
||||
max_concurrent_requests: 64,
|
||||
queue_size: 0,
|
||||
queue_timeout_secs: 60,
|
||||
rate_limit_tokens_per_second: None,
|
||||
cors_allowed_origins: vec![],
|
||||
retry: RetryConfig::default(),
|
||||
circuit_breaker: CircuitBreakerConfig::default(),
|
||||
@@ -1440,6 +1446,9 @@ mod pd_mode_tests {
|
||||
log_level: None,
|
||||
request_id_headers: None,
|
||||
max_concurrent_requests: 64,
|
||||
queue_size: 0,
|
||||
queue_timeout_secs: 60,
|
||||
rate_limit_tokens_per_second: None,
|
||||
cors_allowed_origins: vec![],
|
||||
retry: RetryConfig::default(),
|
||||
circuit_breaker: CircuitBreakerConfig::default(),
|
||||
@@ -1596,6 +1605,9 @@ mod request_id_tests {
|
||||
log_level: None,
|
||||
request_id_headers: Some(vec!["custom-id".to_string(), "trace-id".to_string()]),
|
||||
max_concurrent_requests: 64,
|
||||
queue_size: 0,
|
||||
queue_timeout_secs: 60,
|
||||
rate_limit_tokens_per_second: None,
|
||||
cors_allowed_origins: vec![],
|
||||
retry: RetryConfig::default(),
|
||||
circuit_breaker: CircuitBreakerConfig::default(),
|
||||
|
||||
@@ -16,6 +16,7 @@ pub fn create_test_context(config: RouterConfig) -> Arc<AppContext> {
|
||||
config.clone(),
|
||||
reqwest::Client::new(),
|
||||
config.max_concurrent_requests,
|
||||
config.rate_limit_tokens_per_second,
|
||||
))
|
||||
}
|
||||
|
||||
|
||||
@@ -19,12 +19,14 @@ pub fn create_test_app(
|
||||
router_config.clone(),
|
||||
client,
|
||||
router_config.max_concurrent_requests,
|
||||
router_config.rate_limit_tokens_per_second,
|
||||
));
|
||||
|
||||
// Create AppState with the test router and context
|
||||
let app_state = Arc::new(AppState {
|
||||
router,
|
||||
context: app_context,
|
||||
concurrency_queue_tx: None, // No queue for tests
|
||||
});
|
||||
|
||||
// Configure request ID headers (use defaults if not specified)
|
||||
|
||||
@@ -36,6 +36,9 @@ impl TestContext {
|
||||
log_level: None,
|
||||
request_id_headers: None,
|
||||
max_concurrent_requests: 64,
|
||||
queue_size: 0,
|
||||
queue_timeout_secs: 60,
|
||||
rate_limit_tokens_per_second: None,
|
||||
cors_allowed_origins: vec![],
|
||||
retry: RetryConfig::default(),
|
||||
circuit_breaker: CircuitBreakerConfig::default(),
|
||||
|
||||
@@ -37,6 +37,9 @@ impl TestContext {
|
||||
log_level: None,
|
||||
request_id_headers: None,
|
||||
max_concurrent_requests: 64,
|
||||
queue_size: 0,
|
||||
queue_timeout_secs: 60,
|
||||
rate_limit_tokens_per_second: None,
|
||||
cors_allowed_origins: vec![],
|
||||
retry: RetryConfig::default(),
|
||||
circuit_breaker: CircuitBreakerConfig::default(),
|
||||
|
||||
@@ -178,6 +178,8 @@ mod test_pd_routing {
|
||||
log_level: None,
|
||||
request_id_headers: None,
|
||||
max_concurrent_requests: 64,
|
||||
queue_size: 0,
|
||||
queue_timeout_secs: 60,
|
||||
cors_allowed_origins: vec![],
|
||||
retry: RetryConfig::default(),
|
||||
circuit_breaker: CircuitBreakerConfig::default(),
|
||||
@@ -185,11 +187,12 @@ mod test_pd_routing {
|
||||
disable_circuit_breaker: false,
|
||||
health_check: sglang_router_rs::config::HealthCheckConfig::default(),
|
||||
enable_igw: false,
|
||||
rate_limit_tokens_per_second: None,
|
||||
};
|
||||
|
||||
// Router creation will fail due to health checks, but config should be valid
|
||||
let app_context =
|
||||
sglang_router_rs::server::AppContext::new(config, reqwest::Client::new(), 64);
|
||||
sglang_router_rs::server::AppContext::new(config, reqwest::Client::new(), 64, None);
|
||||
let app_context = std::sync::Arc::new(app_context);
|
||||
let result = RouterFactory::create_router(&app_context).await;
|
||||
assert!(result.is_err());
|
||||
|
||||
Reference in New Issue
Block a user