[router] add token bucket rate limiter (#9656)
This commit is contained in:
@@ -1,10 +1,19 @@
|
||||
use axum::{extract::Request, http::HeaderValue, response::Response};
|
||||
use axum::{
|
||||
extract::Request, extract::State, http::HeaderValue, http::StatusCode, middleware::Next,
|
||||
response::IntoResponse, response::Response,
|
||||
};
|
||||
use rand::Rng;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use std::time::Instant;
|
||||
use tokio::sync::{mpsc, oneshot};
|
||||
use tower::{Layer, Service};
|
||||
use tower_http::trace::{MakeSpan, OnRequest, OnResponse, TraceLayer};
|
||||
use tracing::{field::Empty, info_span, Span};
|
||||
use tracing::{debug, error, field::Empty, info, info_span, warn, Span};
|
||||
|
||||
pub use crate::core::token_bucket::TokenBucket;
|
||||
|
||||
use crate::server::AppState;
|
||||
|
||||
/// Generate OpenAI-compatible request ID based on endpoint
|
||||
fn generate_request_id(path: &str) -> String {
|
||||
@@ -313,3 +322,181 @@ pub fn log_request(entry: RequestLogEntry) {
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// ============ Concurrency Limiting with Queue Support ============
|
||||
|
||||
/// Request queue entry
|
||||
pub struct QueuedRequest {
|
||||
/// Time when the request was queued
|
||||
queued_at: Instant,
|
||||
/// Channel to send the permit back when acquired
|
||||
permit_tx: oneshot::Sender<Result<(), StatusCode>>,
|
||||
}
|
||||
|
||||
/// Atomic counters describing request-queue activity, intended for monitoring.
///
/// All counters start at zero via `Default`.
#[derive(Debug, Default)]
pub struct QueueMetrics {
    /// Count of requests that were queued in total.
    pub total_queued: std::sync::atomic::AtomicU64,
    /// Count of requests currently queued.
    pub current_queued: std::sync::atomic::AtomicU64,
    /// Count of requests that timed out.
    pub total_timeout: std::sync::atomic::AtomicU64,
    /// Count of requests that were rejected.
    pub total_rejected: std::sync::atomic::AtomicU64,
}
|
||||
|
||||
/// Queue processor that handles queued requests
|
||||
pub struct QueueProcessor {
|
||||
token_bucket: Arc<TokenBucket>,
|
||||
queue_rx: mpsc::Receiver<QueuedRequest>,
|
||||
queue_timeout: Duration,
|
||||
}
|
||||
|
||||
impl QueueProcessor {
|
||||
pub fn new(
|
||||
token_bucket: Arc<TokenBucket>,
|
||||
queue_rx: mpsc::Receiver<QueuedRequest>,
|
||||
queue_timeout: Duration,
|
||||
) -> Self {
|
||||
Self {
|
||||
token_bucket,
|
||||
queue_rx,
|
||||
queue_timeout,
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn run(mut self) {
|
||||
info!("Starting concurrency queue processor");
|
||||
|
||||
// Process requests in a single task to reduce overhead
|
||||
while let Some(queued) = self.queue_rx.recv().await {
|
||||
// Check timeout immediately
|
||||
let elapsed = queued.queued_at.elapsed();
|
||||
if elapsed >= self.queue_timeout {
|
||||
warn!("Request already timed out in queue");
|
||||
let _ = queued.permit_tx.send(Err(StatusCode::REQUEST_TIMEOUT));
|
||||
continue;
|
||||
}
|
||||
|
||||
let remaining_timeout = self.queue_timeout - elapsed;
|
||||
|
||||
// Try to acquire token for this request
|
||||
if self.token_bucket.try_acquire(1.0).await.is_ok() {
|
||||
// Got token immediately
|
||||
debug!("Queue: acquired token immediately for queued request");
|
||||
let _ = queued.permit_tx.send(Ok(()));
|
||||
} else {
|
||||
// Need to wait for token
|
||||
let token_bucket = self.token_bucket.clone();
|
||||
|
||||
// Spawn task only when we actually need to wait
|
||||
tokio::spawn(async move {
|
||||
if token_bucket
|
||||
.acquire_timeout(1.0, remaining_timeout)
|
||||
.await
|
||||
.is_ok()
|
||||
{
|
||||
debug!("Queue: acquired token after waiting");
|
||||
let _ = queued.permit_tx.send(Ok(()));
|
||||
} else {
|
||||
warn!("Queue: request timed out waiting for token");
|
||||
let _ = queued.permit_tx.send(Err(StatusCode::REQUEST_TIMEOUT));
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
warn!("Concurrency queue processor shutting down");
|
||||
}
|
||||
}
|
||||
|
||||
/// State for the concurrency limiter
|
||||
pub struct ConcurrencyLimiter {
|
||||
pub queue_tx: Option<mpsc::Sender<QueuedRequest>>,
|
||||
}
|
||||
|
||||
impl ConcurrencyLimiter {
|
||||
/// Create new concurrency limiter with optional queue
|
||||
pub fn new(
|
||||
token_bucket: Arc<TokenBucket>,
|
||||
queue_size: usize,
|
||||
queue_timeout: Duration,
|
||||
) -> (Self, Option<QueueProcessor>) {
|
||||
if queue_size > 0 {
|
||||
let (queue_tx, queue_rx) = mpsc::channel(queue_size);
|
||||
let processor = QueueProcessor::new(token_bucket, queue_rx, queue_timeout);
|
||||
|
||||
(
|
||||
Self {
|
||||
queue_tx: Some(queue_tx),
|
||||
},
|
||||
Some(processor),
|
||||
)
|
||||
} else {
|
||||
(Self { queue_tx: None }, None)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Middleware function for concurrency limiting with optional queuing
|
||||
pub async fn concurrency_limit_middleware(
|
||||
State(app_state): State<Arc<AppState>>,
|
||||
request: Request<axum::body::Body>,
|
||||
next: Next,
|
||||
) -> Response {
|
||||
let token_bucket = app_state.context.rate_limiter.clone();
|
||||
|
||||
// Try to acquire token immediately
|
||||
if token_bucket.try_acquire(1.0).await.is_ok() {
|
||||
debug!("Acquired token immediately");
|
||||
let response = next.run(request).await;
|
||||
|
||||
// Return the token to the bucket
|
||||
token_bucket.return_tokens(1.0).await;
|
||||
|
||||
response
|
||||
} else {
|
||||
// No tokens available, try to queue if enabled
|
||||
if let Some(queue_tx) = &app_state.concurrency_queue_tx {
|
||||
debug!("No tokens available, attempting to queue request");
|
||||
|
||||
// Create a channel for the token response
|
||||
let (permit_tx, permit_rx) = oneshot::channel();
|
||||
|
||||
let queued = QueuedRequest {
|
||||
queued_at: Instant::now(),
|
||||
permit_tx,
|
||||
};
|
||||
|
||||
// Try to send to queue
|
||||
match queue_tx.try_send(queued) {
|
||||
Ok(_) => {
|
||||
// Wait for token from queue processor
|
||||
match permit_rx.await {
|
||||
Ok(Ok(())) => {
|
||||
debug!("Acquired token from queue");
|
||||
let response = next.run(request).await;
|
||||
|
||||
// Return the token to the bucket
|
||||
token_bucket.return_tokens(1.0).await;
|
||||
|
||||
response
|
||||
}
|
||||
Ok(Err(status)) => {
|
||||
warn!("Queue returned error status: {}", status);
|
||||
status.into_response()
|
||||
}
|
||||
Err(_) => {
|
||||
error!("Queue response channel closed");
|
||||
StatusCode::INTERNAL_SERVER_ERROR.into_response()
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(_) => {
|
||||
warn!("Request queue is full, returning 429");
|
||||
StatusCode::TOO_MANY_REQUESTS.into_response()
|
||||
}
|
||||
}
|
||||
} else {
|
||||
warn!("No tokens available and queuing is disabled, returning 429");
|
||||
StatusCode::TOO_MANY_REQUESTS.into_response()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user