[router] add token bucket rate limiter (#9656)

This commit is contained in:
Chang Su
2025-08-26 10:36:26 -07:00
committed by GitHub
parent 3578eb1e9b
commit 90313fb09a
15 changed files with 533 additions and 10 deletions

View File

@@ -1,6 +1,7 @@
use crate::config::RouterConfig;
use crate::logging::{self, LoggingConfig};
use crate::metrics::{self, PrometheusConfig};
use crate::middleware::TokenBucket;
use crate::protocols::spec::{ChatCompletionRequest, CompletionRequest, GenerateRequest};
use crate::routers::{RouterFactory, RouterTrait};
use crate::service_discovery::{start_service_discovery, ServiceDiscoveryConfig};
@@ -25,7 +26,7 @@ use tracing::{error, info, warn, Level};
/// Shared application dependencies: the HTTP client, the router
/// configuration, and the two admission-control primitives.
///
/// An instance is passed by reference to `RouterFactory::create_router` and
/// stored inside `AppState` during startup.
pub struct AppContext {
/// Client handle kept here so that holders of the context can reuse it.
pub client: Client,
/// Router configuration this context was built from.
pub router_config: RouterConfig,
/// Semaphore that the constructor in this file creates with
/// `max_concurrent_requests` permits.
pub concurrency_limiter: Arc<tokio::sync::Semaphore>,
/// Token bucket that the constructor in this file builds from
/// `max_concurrent_requests` and the configured tokens-per-second value;
/// when that value is `None`, `max_concurrent_requests` is used instead.
/// A clone of this handle is given to the `ConcurrencyLimiter` at startup.
// NOTE(review): the argument order suggests (capacity, refill rate) — confirm
// against `TokenBucket::new` in `crate::middleware`.
pub rate_limiter: Arc<TokenBucket>,
// Future dependencies can be added here
}
@@ -34,12 +35,14 @@ impl AppContext {
router_config: RouterConfig,
client: Client,
max_concurrent_requests: usize,
rate_limit_tokens_per_second: Option<usize>,
) -> Self {
let concurrency_limiter = Arc::new(tokio::sync::Semaphore::new(max_concurrent_requests));
let rate_limit_tokens = rate_limit_tokens_per_second.unwrap_or(max_concurrent_requests);
let rate_limiter = Arc::new(TokenBucket::new(max_concurrent_requests, rate_limit_tokens));
Self {
client,
router_config,
concurrency_limiter,
rate_limiter,
}
}
}
@@ -48,6 +51,7 @@ impl AppContext {
/// State attached to the axum application.
///
/// A clone is passed to `axum::middleware::from_fn_with_state` so that
/// `concurrency_limit_middleware` can reach it on the protected routes.
pub struct AppState {
/// Router produced by `RouterFactory::create_router`, shared behind an `Arc`.
pub router: Arc<dyn RouterTrait>,
/// Shared application context (client, configuration, limiters).
pub context: Arc<AppContext>,
/// Sending half of the request queue, cloned from the
/// `ConcurrencyLimiter`'s `queue_tx` at startup.
// NOTE(review): presumably `None` when queuing is disabled — confirm in
// `crate::middleware::ConcurrencyLimiter::new`.
pub concurrency_queue_tx: Option<tokio::sync::mpsc::Sender<crate::middleware::QueuedRequest>>,
}
// Fallback handler for unmatched routes
@@ -186,7 +190,11 @@ pub fn build_app(
let protected_routes = Router::new()
.route("/generate", post(generate))
.route("/v1/chat/completions", post(v1_chat_completions))
.route("/v1/completions", post(v1_completions));
.route("/v1/completions", post(v1_completions))
.route_layer(axum::middleware::from_fn_with_state(
app_state.clone(),
crate::middleware::concurrency_limit_middleware,
));
let public_routes = Router::new()
.route("/liveness", get(liveness))
@@ -282,15 +290,33 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box<dyn std::error::Err
config.router_config.clone(),
client.clone(),
config.router_config.max_concurrent_requests,
config.router_config.rate_limit_tokens_per_second,
));
// Create router with the context
let router = RouterFactory::create_router(&app_context).await?;
// Set up concurrency limiter with queue if configured
let (limiter, processor) = crate::middleware::ConcurrencyLimiter::new(
app_context.rate_limiter.clone(),
config.router_config.queue_size,
Duration::from_secs(config.router_config.queue_timeout_secs),
);
// Start queue processor if enabled
if let Some(processor) = processor {
tokio::spawn(processor.run());
info!(
"Started request queue with size: {}, timeout: {}s",
config.router_config.queue_size, config.router_config.queue_timeout_secs
);
}
// Create app state with router and context
let app_state = Arc::new(AppState {
router: Arc::from(router),
context: app_context.clone(),
concurrency_queue_tx: limiter.queue_tx.clone(),
});
let router_arc = Arc::clone(&app_state.router);