[router]: Add Embedding routing logic (#10129)
Signed-off-by: Jintao Zhang <zhangjintao9020@gmail.com>
Co-authored-by: Waël Boukhobza <wawa_wael@live.fr>
This commit is contained in:
@@ -3,6 +3,7 @@ use axum::{
|
||||
response::IntoResponse, response::Response,
|
||||
};
|
||||
use rand::Rng;
|
||||
use std::sync::atomic::{AtomicU64, Ordering};
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use std::time::Instant;
|
||||
@@ -13,6 +14,7 @@ use tracing::{debug, error, field::Empty, info, info_span, warn, Span};
|
||||
|
||||
pub use crate::core::token_bucket::TokenBucket;
|
||||
|
||||
use crate::metrics::RouterMetrics;
|
||||
use crate::server::AppState;
|
||||
|
||||
/// Generate OpenAI-compatible request ID based on endpoint
|
||||
@@ -441,6 +443,11 @@ pub async fn concurrency_limit_middleware(
|
||||
request: Request<axum::body::Body>,
|
||||
next: Next,
|
||||
) -> Response {
|
||||
// Static counter for embeddings queue size
|
||||
static EMBEDDINGS_QUEUE_SIZE: AtomicU64 = AtomicU64::new(0);
|
||||
|
||||
// Identify if this is an embeddings request based on path
|
||||
let is_embeddings = request.uri().path().contains("/v1/embeddings");
|
||||
let token_bucket = app_state.context.rate_limiter.clone();
|
||||
|
||||
// Try to acquire token immediately
|
||||
@@ -468,10 +475,23 @@ pub async fn concurrency_limit_middleware(
|
||||
// Try to send to queue
|
||||
match queue_tx.try_send(queued) {
|
||||
Ok(_) => {
|
||||
// On successful enqueue, update embeddings queue gauge if applicable
|
||||
if is_embeddings {
|
||||
let new_val = EMBEDDINGS_QUEUE_SIZE.fetch_add(1, Ordering::Relaxed) + 1;
|
||||
RouterMetrics::set_embeddings_queue_size(new_val as usize);
|
||||
}
|
||||
|
||||
// Wait for token from queue processor
|
||||
match permit_rx.await {
|
||||
Ok(Ok(())) => {
|
||||
debug!("Acquired token from queue");
|
||||
// Dequeue for embeddings
|
||||
if is_embeddings {
|
||||
let new_val =
|
||||
EMBEDDINGS_QUEUE_SIZE.fetch_sub(1, Ordering::Relaxed) - 1;
|
||||
RouterMetrics::set_embeddings_queue_size(new_val as usize);
|
||||
}
|
||||
|
||||
let response = next.run(request).await;
|
||||
|
||||
// Return the token to the bucket
|
||||
@@ -481,10 +501,22 @@ pub async fn concurrency_limit_middleware(
|
||||
}
|
||||
Ok(Err(status)) => {
|
||||
warn!("Queue returned error status: {}", status);
|
||||
// Dequeue for embeddings on error
|
||||
if is_embeddings {
|
||||
let new_val =
|
||||
EMBEDDINGS_QUEUE_SIZE.fetch_sub(1, Ordering::Relaxed) - 1;
|
||||
RouterMetrics::set_embeddings_queue_size(new_val as usize);
|
||||
}
|
||||
status.into_response()
|
||||
}
|
||||
Err(_) => {
|
||||
error!("Queue response channel closed");
|
||||
// Dequeue for embeddings on channel error
|
||||
if is_embeddings {
|
||||
let new_val =
|
||||
EMBEDDINGS_QUEUE_SIZE.fetch_sub(1, Ordering::Relaxed) - 1;
|
||||
RouterMetrics::set_embeddings_queue_size(new_val as usize);
|
||||
}
|
||||
StatusCode::INTERNAL_SERVER_ERROR.into_response()
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user