180 lines
4.9 KiB
Rust
180 lines
4.9 KiB
Rust
use std::{
|
|
sync::Arc,
|
|
time::{Duration, Instant},
|
|
};
|
|
|
|
use tokio::sync::{Mutex, Notify};
|
|
use tracing::{debug, trace};
|
|
|
|
/// Token bucket for rate limiting
|
|
///
|
|
/// This implementation provides:
|
|
/// - Smooth rate limiting with configurable refill rate
|
|
/// - Burst capacity handling
|
|
/// - Fair queuing for waiting requests
|
|
#[derive(Clone)]
|
|
pub struct TokenBucket {
|
|
inner: Arc<Mutex<TokenBucketInner>>,
|
|
notify: Arc<Notify>,
|
|
capacity: f64,
|
|
refill_rate: f64, // tokens per second
|
|
}
|
|
|
|
struct TokenBucketInner {
|
|
tokens: f64,
|
|
last_refill: Instant,
|
|
}
|
|
|
|
impl TokenBucket {
|
|
/// Create a new token bucket
|
|
///
|
|
/// # Arguments
|
|
/// * `capacity` - Maximum number of tokens (burst capacity)
|
|
/// * `refill_rate` - Tokens added per second
|
|
pub fn new(capacity: usize, refill_rate: usize) -> Self {
|
|
let capacity = capacity as f64;
|
|
let refill_rate = refill_rate as f64;
|
|
|
|
let refill_rate = if refill_rate > 0.0 { refill_rate } else { 1.0 };
|
|
|
|
Self {
|
|
inner: Arc::new(Mutex::new(TokenBucketInner {
|
|
tokens: capacity,
|
|
last_refill: Instant::now(),
|
|
})),
|
|
notify: Arc::new(Notify::new()),
|
|
capacity,
|
|
refill_rate,
|
|
}
|
|
}
|
|
|
|
/// Try to acquire tokens immediately
|
|
pub async fn try_acquire(&self, tokens: f64) -> Result<(), ()> {
|
|
let mut inner = self.inner.lock().await;
|
|
|
|
let now = Instant::now();
|
|
let elapsed = now.duration_since(inner.last_refill).as_secs_f64();
|
|
let refill_amount = elapsed * self.refill_rate;
|
|
|
|
inner.tokens = (inner.tokens + refill_amount).min(self.capacity);
|
|
inner.last_refill = now;
|
|
|
|
trace!(
|
|
"Token bucket: {} tokens available, requesting {}",
|
|
inner.tokens,
|
|
tokens
|
|
);
|
|
|
|
if inner.tokens >= tokens {
|
|
inner.tokens -= tokens;
|
|
debug!(
|
|
"Token bucket: acquired {} tokens, {} remaining",
|
|
tokens, inner.tokens
|
|
);
|
|
Ok(())
|
|
} else {
|
|
Err(())
|
|
}
|
|
}
|
|
|
|
/// Acquire tokens, waiting if necessary
|
|
pub async fn acquire(&self, tokens: f64) -> Result<(), tokio::time::error::Elapsed> {
|
|
if self.try_acquire(tokens).await.is_ok() {
|
|
return Ok(());
|
|
}
|
|
|
|
let wait_time = {
|
|
let inner = self.inner.lock().await;
|
|
let tokens_needed = tokens - inner.tokens;
|
|
let wait_secs = tokens_needed / self.refill_rate;
|
|
Duration::from_secs_f64(wait_secs)
|
|
};
|
|
|
|
debug!(
|
|
"Token bucket: waiting {:?} for {} tokens",
|
|
wait_time, tokens
|
|
);
|
|
|
|
tokio::time::timeout(wait_time, async {
|
|
loop {
|
|
if self.try_acquire(tokens).await.is_ok() {
|
|
return;
|
|
}
|
|
|
|
tokio::select! {
|
|
_ = self.notify.notified() => {},
|
|
_ = tokio::time::sleep(Duration::from_millis(10)) => {},
|
|
}
|
|
}
|
|
})
|
|
.await?;
|
|
|
|
Ok(())
|
|
}
|
|
|
|
/// Acquire tokens with custom timeout
|
|
pub async fn acquire_timeout(
|
|
&self,
|
|
tokens: f64,
|
|
timeout: Duration,
|
|
) -> Result<(), tokio::time::error::Elapsed> {
|
|
tokio::time::timeout(timeout, self.acquire(tokens)).await?
|
|
}
|
|
|
|
/// Return tokens to the bucket (for cancelled requests)
|
|
pub async fn return_tokens(&self, tokens: f64) {
|
|
let mut inner = self.inner.lock().await;
|
|
inner.tokens = (inner.tokens + tokens).min(self.capacity);
|
|
self.notify.notify_waiters();
|
|
debug!(
|
|
"Token bucket: returned {} tokens, {} available",
|
|
tokens, inner.tokens
|
|
);
|
|
}
|
|
|
|
/// Get current available tokens (for monitoring)
|
|
pub async fn available_tokens(&self) -> f64 {
|
|
let mut inner = self.inner.lock().await;
|
|
|
|
let now = Instant::now();
|
|
let elapsed = now.duration_since(inner.last_refill).as_secs_f64();
|
|
let refill_amount = elapsed * self.refill_rate;
|
|
|
|
inner.tokens = (inner.tokens + refill_amount).min(self.capacity);
|
|
inner.last_refill = now;
|
|
|
|
inner.tokens
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
|
|
mod tests {
|
|
use super::*;
|
|
|
|
#[tokio::test]
|
|
async fn test_token_bucket_basic() {
|
|
let bucket = TokenBucket::new(10, 5);
|
|
|
|
assert!(bucket.try_acquire(5.0).await.is_ok());
|
|
assert!(bucket.try_acquire(5.0).await.is_ok());
|
|
|
|
assert!(bucket.try_acquire(1.0).await.is_err());
|
|
|
|
tokio::time::sleep(Duration::from_millis(300)).await;
|
|
|
|
assert!(bucket.try_acquire(1.0).await.is_ok());
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn test_token_bucket_refill() {
|
|
let bucket = TokenBucket::new(10, 10);
|
|
|
|
assert!(bucket.try_acquire(10.0).await.is_ok());
|
|
|
|
tokio::time::sleep(Duration::from_millis(500)).await;
|
|
|
|
let available = bucket.available_tokens().await;
|
|
assert!((4.0..=6.0).contains(&available));
|
|
}
|
|
}
|