Files
sglang/sgl-router/src/core/token_bucket.rs

180 lines
4.9 KiB
Rust

use std::{
sync::Arc,
time::{Duration, Instant},
};
use tokio::sync::{Mutex, Notify};
use tracing::{debug, trace};
/// Token bucket for rate limiting
///
/// This implementation provides:
/// - Smooth rate limiting with configurable refill rate
/// - Burst capacity handling
/// - Fair queuing for waiting requests
#[derive(Clone)]
pub struct TokenBucket {
inner: Arc<Mutex<TokenBucketInner>>,
notify: Arc<Notify>,
capacity: f64,
refill_rate: f64, // tokens per second
}
struct TokenBucketInner {
tokens: f64,
last_refill: Instant,
}
impl TokenBucket {
/// Create a new token bucket
///
/// # Arguments
/// * `capacity` - Maximum number of tokens (burst capacity)
/// * `refill_rate` - Tokens added per second
pub fn new(capacity: usize, refill_rate: usize) -> Self {
let capacity = capacity as f64;
let refill_rate = refill_rate as f64;
let refill_rate = if refill_rate > 0.0 { refill_rate } else { 1.0 };
Self {
inner: Arc::new(Mutex::new(TokenBucketInner {
tokens: capacity,
last_refill: Instant::now(),
})),
notify: Arc::new(Notify::new()),
capacity,
refill_rate,
}
}
/// Try to acquire tokens immediately
pub async fn try_acquire(&self, tokens: f64) -> Result<(), ()> {
let mut inner = self.inner.lock().await;
let now = Instant::now();
let elapsed = now.duration_since(inner.last_refill).as_secs_f64();
let refill_amount = elapsed * self.refill_rate;
inner.tokens = (inner.tokens + refill_amount).min(self.capacity);
inner.last_refill = now;
trace!(
"Token bucket: {} tokens available, requesting {}",
inner.tokens,
tokens
);
if inner.tokens >= tokens {
inner.tokens -= tokens;
debug!(
"Token bucket: acquired {} tokens, {} remaining",
tokens, inner.tokens
);
Ok(())
} else {
Err(())
}
}
/// Acquire tokens, waiting if necessary
pub async fn acquire(&self, tokens: f64) -> Result<(), tokio::time::error::Elapsed> {
if self.try_acquire(tokens).await.is_ok() {
return Ok(());
}
let wait_time = {
let inner = self.inner.lock().await;
let tokens_needed = tokens - inner.tokens;
let wait_secs = tokens_needed / self.refill_rate;
Duration::from_secs_f64(wait_secs)
};
debug!(
"Token bucket: waiting {:?} for {} tokens",
wait_time, tokens
);
tokio::time::timeout(wait_time, async {
loop {
if self.try_acquire(tokens).await.is_ok() {
return;
}
tokio::select! {
_ = self.notify.notified() => {},
_ = tokio::time::sleep(Duration::from_millis(10)) => {},
}
}
})
.await?;
Ok(())
}
/// Acquire tokens with custom timeout
pub async fn acquire_timeout(
&self,
tokens: f64,
timeout: Duration,
) -> Result<(), tokio::time::error::Elapsed> {
tokio::time::timeout(timeout, self.acquire(tokens)).await?
}
/// Return tokens to the bucket (for cancelled requests)
pub async fn return_tokens(&self, tokens: f64) {
let mut inner = self.inner.lock().await;
inner.tokens = (inner.tokens + tokens).min(self.capacity);
self.notify.notify_waiters();
debug!(
"Token bucket: returned {} tokens, {} available",
tokens, inner.tokens
);
}
/// Get current available tokens (for monitoring)
pub async fn available_tokens(&self) -> f64 {
let mut inner = self.inner.lock().await;
let now = Instant::now();
let elapsed = now.duration_since(inner.last_refill).as_secs_f64();
let refill_amount = elapsed * self.refill_rate;
inner.tokens = (inner.tokens + refill_amount).min(self.capacity);
inner.last_refill = now;
inner.tokens
}
}
#[cfg(test)]
mod tests {
use super::*;
#[tokio::test]
async fn test_token_bucket_basic() {
let bucket = TokenBucket::new(10, 5);
assert!(bucket.try_acquire(5.0).await.is_ok());
assert!(bucket.try_acquire(5.0).await.is_ok());
assert!(bucket.try_acquire(1.0).await.is_err());
tokio::time::sleep(Duration::from_millis(300)).await;
assert!(bucket.try_acquire(1.0).await.is_ok());
}
#[tokio::test]
async fn test_token_bucket_refill() {
let bucket = TokenBucket::new(10, 10);
assert!(bucket.try_acquire(10.0).await.is_ok());
tokio::time::sleep(Duration::from_millis(500)).await;
let available = bucket.available_tokens().await;
assert!((4.0..=6.0).contains(&available));
}
}