[router] add token bucket rate limiter (#9656)

This commit is contained in:
Chang Su
2025-08-26 10:36:26 -07:00
committed by GitHub
parent 3578eb1e9b
commit 90313fb09a
15 changed files with 533 additions and 10 deletions
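Note: the view below shows only one of the 15 changed files. The `TokenBucket` itself lives in `crate::core::token_bucket` and is not part of this view. For orientation, here is a minimal sketch of the interface the middleware relies on; the three method names and the `f64` token amounts are taken from the call sites below, while the internals (a mutex-guarded counter plus a `Notify` for waiters) are illustrative guesses, not the committed implementation:

```rust
use std::time::Duration;
use tokio::sync::{Mutex, Notify};

/// Illustrative stand-in for `crate::core::token_bucket::TokenBucket`.
/// Only the method signatures are inferred from the diff below.
pub struct TokenBucket {
    capacity: f64,
    available: Mutex<f64>,
    returned: Notify,
}

impl TokenBucket {
    pub fn new(capacity: f64) -> Self {
        Self {
            capacity,
            available: Mutex::new(capacity),
            returned: Notify::new(),
        }
    }

    /// Take `n` tokens if they are available right now, otherwise fail.
    pub async fn try_acquire(&self, n: f64) -> Result<(), ()> {
        let mut avail = self.available.lock().await;
        if *avail >= n {
            *avail -= n;
            Ok(())
        } else {
            Err(())
        }
    }

    /// Retry until `n` tokens are available or `timeout` elapses.
    pub async fn acquire_timeout(&self, n: f64, timeout: Duration) -> Result<(), ()> {
        tokio::time::timeout(timeout, async {
            loop {
                // Register for the wakeup *before* checking, so a token
                // returned in between cannot be missed.
                let notified = self.returned.notified();
                if self.try_acquire(n).await.is_ok() {
                    return;
                }
                notified.await;
            }
        })
        .await
        .map_err(|_| ())
    }

    /// Give `n` tokens back and wake any waiters. The middleware calls this
    /// after a request finishes, so the bucket acts as a concurrency limit.
    pub async fn return_tokens(&self, n: f64) {
        let mut avail = self.available.lock().await;
        *avail = (*avail + n).min(self.capacity);
        self.returned.notify_waiters();
    }
}
```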


@@ -1,10 +1,19 @@
-use axum::{extract::Request, http::HeaderValue, response::Response};
+use axum::{
+    extract::Request, extract::State, http::HeaderValue, http::StatusCode, middleware::Next,
+    response::IntoResponse, response::Response,
+};
 use rand::Rng;
 use std::sync::Arc;
+use std::time::Duration;
+use std::time::Instant;
+use tokio::sync::{mpsc, oneshot};
 use tower::{Layer, Service};
 use tower_http::trace::{MakeSpan, OnRequest, OnResponse, TraceLayer};
-use tracing::{field::Empty, info_span, Span};
+use tracing::{debug, error, field::Empty, info, info_span, warn, Span};
+
+pub use crate::core::token_bucket::TokenBucket;
+use crate::server::AppState;

 /// Generate OpenAI-compatible request ID based on endpoint
 fn generate_request_id(path: &str) -> String {
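The second hunk adds the queue itself. The handoff between the middleware and the queue processor is an mpsc-plus-oneshot handshake: the middleware enqueues a `oneshot::Sender`, and the processor replies through it once a token is acquired (or the wait times out). Stripped of the token logic, the pattern looks like this (names simplified for illustration; not the committed code):

```rust
use tokio::sync::{mpsc, oneshot};

#[tokio::main]
async fn main() {
    // The queue carries one reply channel per waiting request.
    let (queue_tx, mut queue_rx) = mpsc::channel::<oneshot::Sender<bool>>(8);

    // "Processor" side: pop one waiter at a time and hand it a decision.
    tokio::spawn(async move {
        while let Some(reply_tx) = queue_rx.recv().await {
            let _ = reply_tx.send(true); // e.g. "token acquired"
        }
    });

    // "Middleware" side: enqueue a reply channel, then await the decision.
    let (reply_tx, reply_rx) = oneshot::channel();
    queue_tx.try_send(reply_tx).expect("queue full");
    assert!(reply_rx.await.expect("processor dropped the sender"));
}
```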
@@ -313,3 +322,181 @@ pub fn log_request(entry: RequestLogEntry) {
        );
    }
}
// ============ Concurrency Limiting with Queue Support ============

/// Request queue entry
pub struct QueuedRequest {
    /// Time when the request was queued
    queued_at: Instant,
    /// Channel to send the permit back when acquired
    permit_tx: oneshot::Sender<Result<(), StatusCode>>,
}

/// Queue metrics for monitoring
#[derive(Debug, Default)]
pub struct QueueMetrics {
    pub total_queued: std::sync::atomic::AtomicU64,
    pub current_queued: std::sync::atomic::AtomicU64,
    pub total_timeout: std::sync::atomic::AtomicU64,
    pub total_rejected: std::sync::atomic::AtomicU64,
}

/// Queue processor that handles queued requests
pub struct QueueProcessor {
    token_bucket: Arc<TokenBucket>,
    queue_rx: mpsc::Receiver<QueuedRequest>,
    queue_timeout: Duration,
}
impl QueueProcessor {
    pub fn new(
        token_bucket: Arc<TokenBucket>,
        queue_rx: mpsc::Receiver<QueuedRequest>,
        queue_timeout: Duration,
    ) -> Self {
        Self {
            token_bucket,
            queue_rx,
            queue_timeout,
        }
    }

    pub async fn run(mut self) {
        info!("Starting concurrency queue processor");

        // Process requests in a single task to reduce overhead
        while let Some(queued) = self.queue_rx.recv().await {
            // Check timeout immediately
            let elapsed = queued.queued_at.elapsed();
            if elapsed >= self.queue_timeout {
                warn!("Request already timed out in queue");
                let _ = queued.permit_tx.send(Err(StatusCode::REQUEST_TIMEOUT));
                continue;
            }
            let remaining_timeout = self.queue_timeout - elapsed;

            // Try to acquire token for this request
            if self.token_bucket.try_acquire(1.0).await.is_ok() {
                // Got token immediately
                debug!("Queue: acquired token immediately for queued request");
                let _ = queued.permit_tx.send(Ok(()));
            } else {
                // Need to wait for token
                let token_bucket = self.token_bucket.clone();
                // Spawn task only when we actually need to wait
                tokio::spawn(async move {
                    if token_bucket
                        .acquire_timeout(1.0, remaining_timeout)
                        .await
                        .is_ok()
                    {
                        debug!("Queue: acquired token after waiting");
                        let _ = queued.permit_tx.send(Ok(()));
                    } else {
                        warn!("Queue: request timed out waiting for token");
                        let _ = queued.permit_tx.send(Err(StatusCode::REQUEST_TIMEOUT));
                    }
                });
            }
        }

        warn!("Concurrency queue processor shutting down");
    }
}
/// State for the concurrency limiter
pub struct ConcurrencyLimiter {
    pub queue_tx: Option<mpsc::Sender<QueuedRequest>>,
}

impl ConcurrencyLimiter {
    /// Create new concurrency limiter with optional queue
    pub fn new(
        token_bucket: Arc<TokenBucket>,
        queue_size: usize,
        queue_timeout: Duration,
    ) -> (Self, Option<QueueProcessor>) {
        if queue_size > 0 {
            let (queue_tx, queue_rx) = mpsc::channel(queue_size);
            let processor = QueueProcessor::new(token_bucket, queue_rx, queue_timeout);
            (
                Self {
                    queue_tx: Some(queue_tx),
                },
                Some(processor),
            )
        } else {
            (Self { queue_tx: None }, None)
        }
    }
}
/// Middleware function for concurrency limiting with optional queuing
pub async fn concurrency_limit_middleware(
    State(app_state): State<Arc<AppState>>,
    request: Request<axum::body::Body>,
    next: Next,
) -> Response {
    let token_bucket = app_state.context.rate_limiter.clone();

    // Try to acquire token immediately
    if token_bucket.try_acquire(1.0).await.is_ok() {
        debug!("Acquired token immediately");
        let response = next.run(request).await;
        // Return the token to the bucket
        token_bucket.return_tokens(1.0).await;
        response
    } else {
        // No tokens available, try to queue if enabled
        if let Some(queue_tx) = &app_state.concurrency_queue_tx {
            debug!("No tokens available, attempting to queue request");

            // Create a channel for the token response
            let (permit_tx, permit_rx) = oneshot::channel();
            let queued = QueuedRequest {
                queued_at: Instant::now(),
                permit_tx,
            };

            // Try to send to queue
            match queue_tx.try_send(queued) {
                Ok(_) => {
                    // Wait for token from queue processor
                    match permit_rx.await {
                        Ok(Ok(())) => {
                            debug!("Acquired token from queue");
                            let response = next.run(request).await;
                            // Return the token to the bucket
                            token_bucket.return_tokens(1.0).await;
                            response
                        }
                        Ok(Err(status)) => {
                            warn!("Queue returned error status: {}", status);
                            status.into_response()
                        }
                        Err(_) => {
                            error!("Queue response channel closed");
                            StatusCode::INTERNAL_SERVER_ERROR.into_response()
                        }
                    }
                }
                Err(_) => {
                    warn!("Request queue is full, returning 429");
                    StatusCode::TOO_MANY_REQUESTS.into_response()
                }
            }
        } else {
            warn!("No tokens available and queuing is disabled, returning 429");
            StatusCode::TOO_MANY_REQUESTS.into_response()
        }
    }
}
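Not shown in this file: the processor has to be spawned somewhere at startup, and the limiter's `queue_tx` stored on `AppState` as `concurrency_queue_tx` so the middleware can reach it. A hypothetical wiring sketch, with made-up capacity and timeout values (the real ones presumably come from the router's configuration):

```rust
use std::sync::Arc;
use std::time::Duration;

// Hypothetical startup wiring; `ConcurrencyLimiter` and `TokenBucket` are
// from this commit, the queue size and timeout are placeholders.
fn init_concurrency_limit(token_bucket: Arc<TokenBucket>) -> ConcurrencyLimiter {
    // Queue up to 1024 waiting requests; each waits at most 30 seconds.
    let (limiter, processor) =
        ConcurrencyLimiter::new(token_bucket, 1024, Duration::from_secs(30));

    // The processor owns the receiving half of the queue and runs for the
    // lifetime of the server.
    if let Some(processor) = processor {
        tokio::spawn(processor.run());
    }

    limiter
}
```

The returned `limiter.queue_tx` would then be stored as `AppState::concurrency_queue_tx`, and the middleware attached to the router with `axum::middleware::from_fn_with_state(app_state, concurrency_limit_middleware)`.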