diff --git a/sgl-router/py_src/sglang_router/launch_router.py b/sgl-router/py_src/sglang_router/launch_router.py index 4adf9eb71..d1d80ec60 100644 --- a/sgl-router/py_src/sglang_router/launch_router.py +++ b/sgl-router/py_src/sglang_router/launch_router.py @@ -72,6 +72,12 @@ class RouterArgs: request_timeout_secs: int = 1800 # Max concurrent requests for rate limiting max_concurrent_requests: int = 256 + # Queue size for pending requests when max concurrent limit reached + queue_size: int = 100 + # Maximum time (in seconds) a request can wait in queue before timing out + queue_timeout_secs: int = 60 + # Token bucket refill rate (tokens per second). If not set, defaults to max_concurrent_requests + rate_limit_tokens_per_second: Optional[int] = None # CORS allowed origins cors_allowed_origins: List[str] = dataclasses.field(default_factory=list) # Retry configuration @@ -402,6 +408,24 @@ class RouterArgs: default=RouterArgs.max_concurrent_requests, help="Maximum number of concurrent requests allowed (for rate limiting)", ) + parser.add_argument( + f"--{prefix}queue-size", + type=int, + default=RouterArgs.queue_size, + help="Queue size for pending requests when max concurrent limit reached (0 = no queue, return 429 immediately)", + ) + parser.add_argument( + f"--{prefix}queue-timeout-secs", + type=int, + default=RouterArgs.queue_timeout_secs, + help="Maximum time (in seconds) a request can wait in queue before timing out", + ) + parser.add_argument( + f"--{prefix}rate-limit-tokens-per-second", + type=int, + default=RouterArgs.rate_limit_tokens_per_second, + help="Token bucket refill rate (tokens per second). 
If not set, defaults to max_concurrent_requests", + ) parser.add_argument( f"--{prefix}cors-allowed-origins", type=str, @@ -478,6 +502,21 @@ class RouterArgs: f"{prefix}max_concurrent_requests", RouterArgs.max_concurrent_requests, ), + queue_size=getattr( + args, + f"{prefix}queue_size", + RouterArgs.queue_size, + ), + queue_timeout_secs=getattr( + args, + f"{prefix}queue_timeout_secs", + RouterArgs.queue_timeout_secs, + ), + rate_limit_tokens_per_second=getattr( + args, + f"{prefix}rate_limit_tokens_per_second", + RouterArgs.rate_limit_tokens_per_second, + ), cors_allowed_origins=getattr(args, f"{prefix}cors_allowed_origins", []), retry_max_retries=getattr(args, f"{prefix}retry_max_retries"), retry_initial_backoff_ms=getattr(args, f"{prefix}retry_initial_backoff_ms"), @@ -700,6 +739,9 @@ def launch_router(args: argparse.Namespace) -> Optional[Router]: ), request_id_headers=router_args.request_id_headers, max_concurrent_requests=router_args.max_concurrent_requests, + queue_size=router_args.queue_size, + queue_timeout_secs=router_args.queue_timeout_secs, + rate_limit_tokens_per_second=router_args.rate_limit_tokens_per_second, cors_allowed_origins=router_args.cors_allowed_origins, retry_max_retries=router_args.retry_max_retries, retry_initial_backoff_ms=router_args.retry_initial_backoff_ms, diff --git a/sgl-router/py_src/sglang_router/router.py b/sgl-router/py_src/sglang_router/router.py index 9abed9d96..d6c53e032 100644 --- a/sgl-router/py_src/sglang_router/router.py +++ b/sgl-router/py_src/sglang_router/router.py @@ -64,7 +64,10 @@ class Router: bootstrap_port_annotation: Kubernetes annotation name for bootstrap port (PD mode). Default: 'sglang.ai/bootstrap-port' request_timeout_secs: Request timeout in seconds. Default: 600 - max_concurrent_requests: Maximum number of concurrent requests allowed for rate limiting. Default: 64 + max_concurrent_requests: Maximum number of concurrent requests allowed for rate limiting. 
Default: 256 + queue_size: Queue size for pending requests when max concurrent limit reached (0 = no queue, return 429 immediately). Default: 100 + queue_timeout_secs: Maximum time (in seconds) a request can wait in queue before timing out. Default: 60 + rate_limit_tokens_per_second: Token bucket refill rate (tokens per second). If not set, defaults to max_concurrent_requests. Default: None cors_allowed_origins: List of allowed origins for CORS. Empty list allows all origins. Default: [] health_failure_threshold: Number of consecutive health check failures before marking worker unhealthy. Default: 3 health_success_threshold: Number of consecutive health check successes before marking worker healthy. Default: 2 @@ -108,6 +111,9 @@ class Router: prefill_policy: Optional[PolicyType] = None, decode_policy: Optional[PolicyType] = None, max_concurrent_requests: int = 256, + queue_size: int = 100, + queue_timeout_secs: int = 60, + rate_limit_tokens_per_second: Optional[int] = None, cors_allowed_origins: List[str] = None, retry_max_retries: int = 5, retry_initial_backoff_ms: int = 50, @@ -169,6 +175,9 @@ class Router: prefill_policy=prefill_policy, decode_policy=decode_policy, max_concurrent_requests=max_concurrent_requests, + queue_size=queue_size, + queue_timeout_secs=queue_timeout_secs, + rate_limit_tokens_per_second=rate_limit_tokens_per_second, cors_allowed_origins=cors_allowed_origins, retry_max_retries=retry_max_retries, retry_initial_backoff_ms=retry_initial_backoff_ms, diff --git a/sgl-router/src/config/types.rs b/sgl-router/src/config/types.rs index 45e7e8d96..6afc3348e 100644 --- a/sgl-router/src/config/types.rs +++ b/sgl-router/src/config/types.rs @@ -37,6 +37,12 @@ pub struct RouterConfig { pub request_id_headers: Option>, /// Maximum concurrent requests allowed (for rate limiting) pub max_concurrent_requests: usize, + /// Queue size for pending requests when max concurrent limit reached (0 = no queue, return 429 immediately) + pub queue_size: usize, + /// 
Maximum time (in seconds) a request can wait in queue before timing out + pub queue_timeout_secs: u64, + /// Token bucket refill rate (tokens per second). If not set, defaults to max_concurrent_requests + pub rate_limit_tokens_per_second: Option, /// CORS allowed origins pub cors_allowed_origins: Vec, /// Retry configuration @@ -320,6 +326,9 @@ impl Default for RouterConfig { log_level: None, request_id_headers: None, max_concurrent_requests: 256, + queue_size: 100, + queue_timeout_secs: 60, + rate_limit_tokens_per_second: None, cors_allowed_origins: vec![], retry: RetryConfig::default(), circuit_breaker: CircuitBreakerConfig::default(), @@ -466,6 +475,9 @@ mod tests { disable_circuit_breaker: false, health_check: HealthCheckConfig::default(), enable_igw: false, + queue_size: 100, + queue_timeout_secs: 60, + rate_limit_tokens_per_second: None, }; let json = serde_json::to_string(&config).unwrap(); @@ -899,6 +911,9 @@ mod tests { disable_circuit_breaker: false, health_check: HealthCheckConfig::default(), enable_igw: false, + queue_size: 100, + queue_timeout_secs: 60, + rate_limit_tokens_per_second: None, }; assert!(config.mode.is_pd_mode()); @@ -956,6 +971,9 @@ mod tests { disable_circuit_breaker: false, health_check: HealthCheckConfig::default(), enable_igw: false, + queue_size: 100, + queue_timeout_secs: 60, + rate_limit_tokens_per_second: None, }; assert!(!config.mode.is_pd_mode()); @@ -1009,6 +1027,9 @@ mod tests { disable_circuit_breaker: false, health_check: HealthCheckConfig::default(), enable_igw: false, + queue_size: 100, + queue_timeout_secs: 60, + rate_limit_tokens_per_second: None, }; assert!(config.has_service_discovery()); diff --git a/sgl-router/src/core/mod.rs b/sgl-router/src/core/mod.rs index 101578119..4ccb05fb0 100644 --- a/sgl-router/src/core/mod.rs +++ b/sgl-router/src/core/mod.rs @@ -9,6 +9,7 @@ pub mod circuit_breaker; pub mod error; pub mod retry; +pub mod token_bucket; pub mod worker; // Re-export commonly used types at the module level 
diff --git a/sgl-router/src/core/token_bucket.rs b/sgl-router/src/core/token_bucket.rs
new file mode 100644
index 000000000..65117331a
--- /dev/null
+++ b/sgl-router/src/core/token_bucket.rs
@@ -0,0 +1,226 @@
+use std::sync::Arc;
+use std::time::{Duration, Instant};
+use tokio::sync::{Mutex, Notify};
+use tracing::{debug, trace};
+
+/// How often a waiting `acquire` re-checks the bucket when no notification
+/// arrives.
+const POLL_INTERVAL: Duration = Duration::from_millis(10);
+
+/// Token bucket for rate limiting
+///
+/// This implementation provides:
+/// - Smooth rate limiting with configurable refill rate
+/// - Burst capacity handling
+/// - Waiting acquisition that is woken when tokens are returned
+///
+/// Clones share the same underlying bucket state.
+#[derive(Clone)]
+pub struct TokenBucket {
+    inner: Arc<Mutex<TokenBucketInner>>,
+    notify: Arc<Notify>,
+    capacity: f64,
+    refill_rate: f64, // tokens per second
+}
+
+struct TokenBucketInner {
+    tokens: f64,
+    last_refill: Instant,
+}
+
+impl TokenBucketInner {
+    /// Credit the tokens accrued since `last_refill`, capped at `capacity`.
+    fn refill(&mut self, capacity: f64, refill_rate: f64) {
+        let now = Instant::now();
+        let elapsed = now.duration_since(self.last_refill).as_secs_f64();
+
+        self.tokens = (self.tokens + elapsed * refill_rate).min(capacity);
+        self.last_refill = now;
+    }
+}
+
+impl TokenBucket {
+    /// Create a new token bucket
+    ///
+    /// # Arguments
+    /// * `capacity` - Maximum number of tokens (burst capacity)
+    /// * `refill_rate` - Tokens added per second
+    pub fn new(capacity: usize, refill_rate: usize) -> Self {
+        let capacity = capacity as f64;
+        let refill_rate = refill_rate as f64;
+
+        // Ensure refill_rate is not zero to prevent division by zero
+        let refill_rate = if refill_rate > 0.0 {
+            refill_rate
+        } else {
+            1.0 // Default to 1 token per second if zero
+        };
+
+        Self {
+            inner: Arc::new(Mutex::new(TokenBucketInner {
+                tokens: capacity, // Start full
+                last_refill: Instant::now(),
+            })),
+            notify: Arc::new(Notify::new()),
+            capacity,
+            refill_rate,
+        }
+    }
+
+    /// Try to acquire tokens immediately
+    ///
+    /// Returns `Ok(())` and deducts `tokens` when enough are available after
+    /// refilling; returns `Err(())` without deducting anything otherwise.
+    pub async fn try_acquire(&self, tokens: f64) -> Result<(), ()> {
+        let mut inner = self.inner.lock().await;
+
+        // Refill tokens based on elapsed time
+        inner.refill(self.capacity, self.refill_rate);
+
+        trace!(
+            "Token bucket: {} tokens available, requesting {}",
+            inner.tokens,
+            tokens
+        );
+
+        if inner.tokens >= tokens {
+            inner.tokens -= tokens;
+            debug!(
+                "Token bucket: acquired {} tokens, {} remaining",
+                tokens, inner.tokens
+            );
+            Ok(())
+        } else {
+            Err(())
+        }
+    }
+
+    /// Acquire tokens, waiting if necessary
+    ///
+    /// Waits at most the time the refill rate needs to cover the current
+    /// deficit, plus a small polling margin, and returns `Err(Elapsed)` when
+    /// the tokens could not be obtained by then.
+    pub async fn acquire(&self, tokens: f64) -> Result<(), tokio::time::error::Elapsed> {
+        // First try to acquire immediately
+        if self.try_acquire(tokens).await.is_ok() {
+            return Ok(());
+        }
+
+        // Calculate wait time
+        let wait_time = {
+            let inner = self.inner.lock().await;
+            // Tokens may have been returned since the failed attempt above, so
+            // the deficit can be negative; `Duration::from_secs_f64` panics on
+            // negative input, hence the clamp.
+            let tokens_needed = (tokens - inner.tokens).max(0.0);
+            let wait_secs = tokens_needed / self.refill_rate;
+            // The loop below re-checks only every `POLL_INTERVAL`; without
+            // this margin the timeout can fire just before the check that
+            // would have succeeded.
+            Duration::from_secs_f64(wait_secs) + POLL_INTERVAL * 2
+        };
+
+        debug!(
+            "Token bucket: waiting {:?} for {} tokens",
+            wait_time, tokens
+        );
+
+        // Wait for tokens to be available
+        tokio::time::timeout(wait_time, async {
+            loop {
+                // Check if we can acquire now
+                if self.try_acquire(tokens).await.is_ok() {
+                    return;
+                }
+
+                // Wait for notification or small interval
+                tokio::select! {
+                    _ = self.notify.notified() => {},
+                    _ = tokio::time::sleep(POLL_INTERVAL) => {},
+                }
+            }
+        })
+        .await?;
+
+        Ok(())
+    }
+
+    /// Acquire tokens with custom timeout
+    pub async fn acquire_timeout(
+        &self,
+        tokens: f64,
+        timeout: Duration,
+    ) -> Result<(), tokio::time::error::Elapsed> {
+        tokio::time::timeout(timeout, self.acquire(tokens)).await?
+    }
+
+    /// Return tokens to the bucket (for cancelled requests)
+    pub async fn return_tokens(&self, tokens: f64) {
+        let mut inner = self.inner.lock().await;
+        inner.tokens = (inner.tokens + tokens).min(self.capacity);
+        self.notify.notify_waiters();
+        debug!(
+            "Token bucket: returned {} tokens, {} available",
+            tokens, inner.tokens
+        );
+    }
+
+    /// Get current available tokens (for monitoring)
+    pub async fn available_tokens(&self) -> f64 {
+        let mut inner = self.inner.lock().await;
+
+        // Refill before checking
+        inner.refill(self.capacity, self.refill_rate);
+
+        inner.tokens
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[tokio::test]
+    async fn test_token_bucket_basic() {
+        let bucket = TokenBucket::new(10, 5); // 10 capacity, 5 per second
+
+        // Should succeed - bucket starts full
+        assert!(bucket.try_acquire(5.0).await.is_ok());
+        assert!(bucket.try_acquire(5.0).await.is_ok());
+
+        // Should fail - no tokens left
+        assert!(bucket.try_acquire(1.0).await.is_err());
+
+        // Wait for refill
+        tokio::time::sleep(Duration::from_millis(300)).await;
+
+        // Should have ~1.5 tokens now
+        assert!(bucket.try_acquire(1.0).await.is_ok());
+    }
+
+    #[tokio::test]
+    async fn test_token_bucket_refill() {
+        let bucket = TokenBucket::new(10, 10); // 10 capacity, 10 per second
+
+        // Use all tokens
+        assert!(bucket.try_acquire(10.0).await.is_ok());
+
+        // Wait for partial refill
+        tokio::time::sleep(Duration::from_millis(500)).await;
+
+        // Should have ~5 tokens
+        let available = bucket.available_tokens().await;
+        assert!((4.0..=6.0).contains(&available));
+    }
+
+    #[tokio::test]
+    async fn test_token_bucket_acquire_waits_for_refill() {
+        let bucket = TokenBucket::new(1, 20); // 1 capacity, 20 per second
+
+        // Drain the bucket
+        assert!(bucket.try_acquire(1.0).await.is_ok());
+
+        // One token refills in ~50ms; acquire must wait for it, not time out
+        assert!(bucket.acquire(1.0).await.is_ok());
+    }
+}
diff --git a/sgl-router/src/lib.rs b/sgl-router/src/lib.rs index 40d8ee162..03a616e90 100644 --- a/sgl-router/src/lib.rs +++ b/sgl-router/src/lib.rs @@ -85,6 +85,9 @@ struct Router { health_check_endpoint: String, // 
IGW (Inference Gateway) configuration enable_igw: bool, + queue_size: usize, + queue_timeout_secs: u64, + rate_limit_tokens_per_second: Option, } impl Router { @@ -176,6 +179,9 @@ impl Router { log_level: self.log_level.clone(), request_id_headers: self.request_id_headers.clone(), max_concurrent_requests: self.max_concurrent_requests, + queue_size: self.queue_size, + queue_timeout_secs: self.queue_timeout_secs, + rate_limit_tokens_per_second: self.rate_limit_tokens_per_second, cors_allowed_origins: self.cors_allowed_origins.clone(), retry: config::RetryConfig { max_retries: self.retry_max_retries, @@ -190,8 +196,8 @@ impl Router { timeout_duration_secs: self.cb_timeout_duration_secs, window_duration_secs: self.cb_window_duration_secs, }, - disable_retries: false, - disable_circuit_breaker: false, + disable_retries: self.disable_retries, + disable_circuit_breaker: self.disable_circuit_breaker, health_check: config::HealthCheckConfig { failure_threshold: self.health_failure_threshold, success_threshold: self.health_success_threshold, @@ -263,6 +269,9 @@ impl Router { health_check_endpoint = String::from("/health"), // IGW defaults enable_igw = false, + queue_size = 100, + queue_timeout_secs = 60, + rate_limit_tokens_per_second = None, ))] #[allow(clippy::too_many_arguments)] fn new( @@ -317,6 +326,9 @@ impl Router { health_check_interval_secs: u64, health_check_endpoint: String, enable_igw: bool, + queue_size: usize, + queue_timeout_secs: u64, + rate_limit_tokens_per_second: Option, ) -> PyResult { Ok(Router { host, @@ -370,6 +382,9 @@ impl Router { health_check_interval_secs, health_check_endpoint, enable_igw, + queue_size, + queue_timeout_secs, + rate_limit_tokens_per_second, }) } diff --git a/sgl-router/src/main.rs b/sgl-router/src/main.rs index a2956e88c..1221d2b62 100644 --- a/sgl-router/src/main.rs +++ b/sgl-router/src/main.rs @@ -394,6 +394,8 @@ impl CliArgs { Some(self.request_id_headers.clone()) }, max_concurrent_requests: self.max_concurrent_requests, + 
queue_size: 100, // Default queue size + queue_timeout_secs: 60, // Default timeout cors_allowed_origins: self.cors_allowed_origins.clone(), retry: RetryConfig { max_retries: self.retry_max_retries, @@ -418,6 +420,7 @@ impl CliArgs { endpoint: self.health_check_endpoint.clone(), }, enable_igw: self.enable_igw, + rate_limit_tokens_per_second: None, }) } 
diff --git a/sgl-router/src/middleware.rs b/sgl-router/src/middleware.rs
index 26c22c768..abe137572 100644
--- a/sgl-router/src/middleware.rs
+++ b/sgl-router/src/middleware.rs
@@ -1,10 +1,19 @@
-use axum::{extract::Request, http::HeaderValue, response::Response};
+use axum::{
+    extract::Request, extract::State, http::HeaderValue, http::StatusCode, middleware::Next,
+    response::IntoResponse, response::Response,
+};
 use rand::Rng;
 use std::sync::Arc;
+use std::time::Duration;
 use std::time::Instant;
+use tokio::sync::{mpsc, oneshot};
 use tower::{Layer, Service};
 use tower_http::trace::{MakeSpan, OnRequest, OnResponse, TraceLayer};
-use tracing::{field::Empty, info_span, Span};
+use tracing::{debug, error, field::Empty, info, info_span, warn, Span};
+
+pub use crate::core::token_bucket::TokenBucket;
+
+use crate::server::AppState;
 
 /// Generate OpenAI-compatible request ID based on endpoint
 fn generate_request_id(path: &str) -> String {
@@ -313,3 +322,195 @@ pub fn log_request(entry: RequestLogEntry) {
         );
     }
 }
+
+// ============ Concurrency Limiting with Queue Support ============
+
+/// Request queue entry
+pub struct QueuedRequest {
+    /// Time when the request was queued
+    queued_at: Instant,
+    /// Channel to send the permit back when acquired
+    permit_tx: oneshot::Sender<Result<(), StatusCode>>,
+}
+
+/// Queue metrics for monitoring
+#[derive(Debug, Default)]
+pub struct QueueMetrics {
+    pub total_queued: std::sync::atomic::AtomicU64,
+    pub current_queued: std::sync::atomic::AtomicU64,
+    pub total_timeout: std::sync::atomic::AtomicU64,
+    pub total_rejected: std::sync::atomic::AtomicU64,
+}
+
+/// Queue processor that handles queued requests
+pub struct QueueProcessor {
+    token_bucket: Arc<TokenBucket>,
+    queue_rx: mpsc::Receiver<QueuedRequest>,
+    queue_timeout: Duration,
+}
+
+impl QueueProcessor {
+    /// Create a processor that hands tokens from `token_bucket` to requests
+    /// received on `queue_rx`, each waiting at most `queue_timeout` in total.
+    pub fn new(
+        token_bucket: Arc<TokenBucket>,
+        queue_rx: mpsc::Receiver<QueuedRequest>,
+        queue_timeout: Duration,
+    ) -> Self {
+        Self {
+            token_bucket,
+            queue_rx,
+            queue_timeout,
+        }
+    }
+
+    /// Serve queued requests until the queue channel is closed.
+    ///
+    /// Each request is answered on its `permit_tx`: `Ok(())` once a token is
+    /// acquired, `Err(StatusCode::REQUEST_TIMEOUT)` if `queue_timeout` passes.
+    pub async fn run(mut self) {
+        info!("Starting concurrency queue processor");
+
+        // Process requests in a single task to reduce overhead
+        while let Some(queued) = self.queue_rx.recv().await {
+            // Check timeout immediately
+            let elapsed = queued.queued_at.elapsed();
+            if elapsed >= self.queue_timeout {
+                warn!("Request already timed out in queue");
+                let _ = queued.permit_tx.send(Err(StatusCode::REQUEST_TIMEOUT));
+                continue;
+            }
+
+            let remaining_timeout = self.queue_timeout - elapsed;
+
+            // Try to acquire token for this request
+            if self.token_bucket.try_acquire(1.0).await.is_ok() {
+                // Got token immediately
+                debug!("Queue: acquired token immediately for queued request");
+                if queued.permit_tx.send(Ok(())).is_err() {
+                    // The requester is gone (e.g. client disconnected), so
+                    // nobody will return this token; give it back here.
+                    self.token_bucket.return_tokens(1.0).await;
+                }
+            } else {
+                // Need to wait for token
+                let token_bucket = self.token_bucket.clone();
+
+                // Spawn task only when we actually need to wait
+                tokio::spawn(async move {
+                    if token_bucket
+                        .acquire_timeout(1.0, remaining_timeout)
+                        .await
+                        .is_ok()
+                    {
+                        debug!("Queue: acquired token after waiting");
+                        if queued.permit_tx.send(Ok(())).is_err() {
+                            // The requester is gone, so nobody will return
+                            // this token; give it back here.
+                            token_bucket.return_tokens(1.0).await;
+                        }
+                    } else {
+                        warn!("Queue: request timed out waiting for token");
+                        let _ = queued.permit_tx.send(Err(StatusCode::REQUEST_TIMEOUT));
+                    }
+                });
+            }
+        }
+
+        warn!("Concurrency queue processor shutting down");
+    }
+}
+
+/// State for the concurrency limiter
+pub struct ConcurrencyLimiter {
+    pub queue_tx: Option<mpsc::Sender<QueuedRequest>>,
+}
+
+impl ConcurrencyLimiter {
+    /// Create new concurrency limiter with optional queue
+    pub fn new(
+        token_bucket: Arc<TokenBucket>,
+        queue_size: usize,
+        queue_timeout: Duration,
+    ) -> (Self, Option<QueueProcessor>) {
+        if queue_size > 0 {
+            let (queue_tx, queue_rx) = mpsc::channel(queue_size);
+            let processor = QueueProcessor::new(token_bucket, queue_rx, queue_timeout);
+
+            (
+                Self {
+                    queue_tx: Some(queue_tx),
+                },
+                Some(processor),
+            )
+        } else {
+            (Self { queue_tx: None }, None)
+        }
+    }
+}
+
+/// Middleware function for concurrency limiting with optional queuing
+pub async fn concurrency_limit_middleware(
+    State(app_state): State<Arc<AppState>>,
+    request: Request,
+    next: Next,
+) -> Response {
+    let token_bucket = app_state.context.rate_limiter.clone();
+
+    // Try to acquire token immediately
+    if token_bucket.try_acquire(1.0).await.is_ok() {
+        debug!("Acquired token immediately");
+        let response = next.run(request).await;
+
+        // Return the token to the bucket
+        token_bucket.return_tokens(1.0).await;
+
+        response
+    } else {
+        // No tokens available, try to queue if enabled
+        if let Some(queue_tx) = &app_state.concurrency_queue_tx {
+            debug!("No tokens available, attempting to queue request");
+
+            // Create a channel for the token response
+            let (permit_tx, permit_rx) = oneshot::channel();
+
+            let queued = QueuedRequest {
+                queued_at: Instant::now(),
+                permit_tx,
+            };
+
+            // Try to send to queue
+            match queue_tx.try_send(queued) {
+                Ok(_) => {
+                    // Wait for token from queue processor
+                    match permit_rx.await {
+                        Ok(Ok(())) => {
+                            debug!("Acquired token from queue");
+                            let response = next.run(request).await;
+
+                            // Return the token to the bucket
+                            token_bucket.return_tokens(1.0).await;
+
+                            response
+                        }
+                        Ok(Err(status)) => {
+                            warn!("Queue returned error status: {}", status);
+                            status.into_response()
+                        }
+                        Err(_) => {
+                            error!("Queue response channel closed");
+                            StatusCode::INTERNAL_SERVER_ERROR.into_response()
+                        }
+                    }
+                }
+                Err(_) => {
+                    warn!("Request queue is full, returning 429");
+                    StatusCode::TOO_MANY_REQUESTS.into_response()
+                }
+            }
+        } else {
+            warn!("No tokens available and queuing is disabled, returning 429");
+            StatusCode::TOO_MANY_REQUESTS.into_response()
+        }
+    }
+}
diff --git a/sgl-router/src/server.rs b/sgl-router/src/server.rs index 7ca6b9388..e4af619c9 100644 --- 
a/sgl-router/src/server.rs +++ b/sgl-router/src/server.rs @@ -1,6 +1,7 @@ use crate::config::RouterConfig; use crate::logging::{self, LoggingConfig}; use crate::metrics::{self, PrometheusConfig}; +use crate::middleware::TokenBucket; use crate::protocols::spec::{ChatCompletionRequest, CompletionRequest, GenerateRequest}; use crate::routers::{RouterFactory, RouterTrait}; use crate::service_discovery::{start_service_discovery, ServiceDiscoveryConfig}; @@ -25,7 +26,7 @@ use tracing::{error, info, warn, Level}; pub struct AppContext { pub client: Client, pub router_config: RouterConfig, - pub concurrency_limiter: Arc, + pub rate_limiter: Arc, // Future dependencies can be added here } @@ -34,12 +35,14 @@ impl AppContext { router_config: RouterConfig, client: Client, max_concurrent_requests: usize, + rate_limit_tokens_per_second: Option, ) -> Self { - let concurrency_limiter = Arc::new(tokio::sync::Semaphore::new(max_concurrent_requests)); + let rate_limit_tokens = rate_limit_tokens_per_second.unwrap_or(max_concurrent_requests); + let rate_limiter = Arc::new(TokenBucket::new(max_concurrent_requests, rate_limit_tokens)); Self { client, router_config, - concurrency_limiter, + rate_limiter, } } } @@ -48,6 +51,7 @@ impl AppContext { pub struct AppState { pub router: Arc, pub context: Arc, + pub concurrency_queue_tx: Option>, } // Fallback handler for unmatched routes @@ -186,7 +190,11 @@ pub fn build_app( let protected_routes = Router::new() .route("/generate", post(generate)) .route("/v1/chat/completions", post(v1_chat_completions)) - .route("/v1/completions", post(v1_completions)); + .route("/v1/completions", post(v1_completions)) + .route_layer(axum::middleware::from_fn_with_state( + app_state.clone(), + crate::middleware::concurrency_limit_middleware, + )); let public_routes = Router::new() .route("/liveness", get(liveness)) @@ -282,15 +290,33 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box Arc { config.clone(), reqwest::Client::new(), 
config.max_concurrent_requests, + config.rate_limit_tokens_per_second, )) } diff --git a/sgl-router/tests/common/test_app.rs b/sgl-router/tests/common/test_app.rs index d4961f9c3..554845363 100644 --- a/sgl-router/tests/common/test_app.rs +++ b/sgl-router/tests/common/test_app.rs @@ -19,12 +19,14 @@ pub fn create_test_app( router_config.clone(), client, router_config.max_concurrent_requests, + router_config.rate_limit_tokens_per_second, )); // Create AppState with the test router and context let app_state = Arc::new(AppState { router, context: app_context, + concurrency_queue_tx: None, // No queue for tests }); // Configure request ID headers (use defaults if not specified) diff --git a/sgl-router/tests/request_formats_test.rs b/sgl-router/tests/request_formats_test.rs index c62461754..2e91b82a6 100644 --- a/sgl-router/tests/request_formats_test.rs +++ b/sgl-router/tests/request_formats_test.rs @@ -36,6 +36,9 @@ impl TestContext { log_level: None, request_id_headers: None, max_concurrent_requests: 64, + queue_size: 0, + queue_timeout_secs: 60, + rate_limit_tokens_per_second: None, cors_allowed_origins: vec![], retry: RetryConfig::default(), circuit_breaker: CircuitBreakerConfig::default(), diff --git a/sgl-router/tests/streaming_tests.rs b/sgl-router/tests/streaming_tests.rs index 5e7828952..ce8f8cfdf 100644 --- a/sgl-router/tests/streaming_tests.rs +++ b/sgl-router/tests/streaming_tests.rs @@ -37,6 +37,9 @@ impl TestContext { log_level: None, request_id_headers: None, max_concurrent_requests: 64, + queue_size: 0, + queue_timeout_secs: 60, + rate_limit_tokens_per_second: None, cors_allowed_origins: vec![], retry: RetryConfig::default(), circuit_breaker: CircuitBreakerConfig::default(), diff --git a/sgl-router/tests/test_pd_routing.rs b/sgl-router/tests/test_pd_routing.rs index 33091824d..401ee1119 100644 --- a/sgl-router/tests/test_pd_routing.rs +++ b/sgl-router/tests/test_pd_routing.rs @@ -178,6 +178,8 @@ mod test_pd_routing { log_level: None, request_id_headers: 
None, max_concurrent_requests: 64, + queue_size: 0, + queue_timeout_secs: 60, cors_allowed_origins: vec![], retry: RetryConfig::default(), circuit_breaker: CircuitBreakerConfig::default(), @@ -185,11 +187,12 @@ mod test_pd_routing { disable_circuit_breaker: false, health_check: sglang_router_rs::config::HealthCheckConfig::default(), enable_igw: false, + rate_limit_tokens_per_second: None, }; // Router creation will fail due to health checks, but config should be valid let app_context = - sglang_router_rs::server::AppContext::new(config, reqwest::Client::new(), 64); + sglang_router_rs::server::AppContext::new(config, reqwest::Client::new(), 64, None); let app_context = std::sync::Arc::new(app_context); let result = RouterFactory::create_router(&app_context).await; assert!(result.is_err());