[router] add token bucket rate limiter (#9656)
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
use crate::config::RouterConfig;
|
||||
use crate::logging::{self, LoggingConfig};
|
||||
use crate::metrics::{self, PrometheusConfig};
|
||||
use crate::middleware::TokenBucket;
|
||||
use crate::protocols::spec::{ChatCompletionRequest, CompletionRequest, GenerateRequest};
|
||||
use crate::routers::{RouterFactory, RouterTrait};
|
||||
use crate::service_discovery::{start_service_discovery, ServiceDiscoveryConfig};
|
||||
@@ -25,7 +26,7 @@ use tracing::{error, info, warn, Level};
|
||||
pub struct AppContext {
|
||||
pub client: Client,
|
||||
pub router_config: RouterConfig,
|
||||
pub concurrency_limiter: Arc<tokio::sync::Semaphore>,
|
||||
pub rate_limiter: Arc<TokenBucket>,
|
||||
// Future dependencies can be added here
|
||||
}
|
||||
|
||||
@@ -34,12 +35,14 @@ impl AppContext {
|
||||
router_config: RouterConfig,
|
||||
client: Client,
|
||||
max_concurrent_requests: usize,
|
||||
rate_limit_tokens_per_second: Option<usize>,
|
||||
) -> Self {
|
||||
let concurrency_limiter = Arc::new(tokio::sync::Semaphore::new(max_concurrent_requests));
|
||||
let rate_limit_tokens = rate_limit_tokens_per_second.unwrap_or(max_concurrent_requests);
|
||||
let rate_limiter = Arc::new(TokenBucket::new(max_concurrent_requests, rate_limit_tokens));
|
||||
Self {
|
||||
client,
|
||||
router_config,
|
||||
concurrency_limiter,
|
||||
rate_limiter,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -48,6 +51,7 @@ impl AppContext {
|
||||
pub struct AppState {
|
||||
pub router: Arc<dyn RouterTrait>,
|
||||
pub context: Arc<AppContext>,
|
||||
pub concurrency_queue_tx: Option<tokio::sync::mpsc::Sender<crate::middleware::QueuedRequest>>,
|
||||
}
|
||||
|
||||
// Fallback handler for unmatched routes
|
||||
@@ -186,7 +190,11 @@ pub fn build_app(
|
||||
let protected_routes = Router::new()
|
||||
.route("/generate", post(generate))
|
||||
.route("/v1/chat/completions", post(v1_chat_completions))
|
||||
.route("/v1/completions", post(v1_completions));
|
||||
.route("/v1/completions", post(v1_completions))
|
||||
.route_layer(axum::middleware::from_fn_with_state(
|
||||
app_state.clone(),
|
||||
crate::middleware::concurrency_limit_middleware,
|
||||
));
|
||||
|
||||
let public_routes = Router::new()
|
||||
.route("/liveness", get(liveness))
|
||||
@@ -282,15 +290,33 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box<dyn std::error::Err
|
||||
config.router_config.clone(),
|
||||
client.clone(),
|
||||
config.router_config.max_concurrent_requests,
|
||||
config.router_config.rate_limit_tokens_per_second,
|
||||
));
|
||||
|
||||
// Create router with the context
|
||||
let router = RouterFactory::create_router(&app_context).await?;
|
||||
|
||||
// Set up concurrency limiter with queue if configured
|
||||
let (limiter, processor) = crate::middleware::ConcurrencyLimiter::new(
|
||||
app_context.rate_limiter.clone(),
|
||||
config.router_config.queue_size,
|
||||
Duration::from_secs(config.router_config.queue_timeout_secs),
|
||||
);
|
||||
|
||||
// Start queue processor if enabled
|
||||
if let Some(processor) = processor {
|
||||
tokio::spawn(processor.run());
|
||||
info!(
|
||||
"Started request queue with size: {}, timeout: {}s",
|
||||
config.router_config.queue_size, config.router_config.queue_timeout_secs
|
||||
);
|
||||
}
|
||||
|
||||
// Create app state with router and context
|
||||
let app_state = Arc::new(AppState {
|
||||
router: Arc::from(router),
|
||||
context: app_context.clone(),
|
||||
concurrency_queue_tx: limiter.queue_tx.clone(),
|
||||
});
|
||||
let router_arc = Arc::clone(&app_state.router);
|
||||
|
||||
|
||||
Reference in New Issue
Block a user