[router] add auth middleware for api key auth (#10826)

This commit is contained in:
Chang Su
2025-09-23 16:07:34 -07:00
committed by GitHub
parent f4e3ebeb05
commit ee704e6265
6 changed files with 186 additions and 16 deletions

View File

@@ -1,12 +1,13 @@
use axum::{
extract::Request, extract::State, http::HeaderValue, http::StatusCode, middleware::Next,
response::IntoResponse, response::Response,
body::Body, extract::Request, extract::State, http::header, http::HeaderValue,
http::StatusCode, middleware::Next, response::IntoResponse, response::Response,
};
use rand::Rng;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use std::time::Duration;
use std::time::Instant;
use subtle::ConstantTimeEq;
use tokio::sync::{mpsc, oneshot};
use tower::{Layer, Service};
use tower_http::trace::{MakeSpan, OnRequest, OnResponse, TraceLayer};
@@ -17,6 +18,49 @@ pub use crate::core::token_bucket::TokenBucket;
use crate::metrics::RouterMetrics;
use crate::server::AppState;
/// Authentication settings handed to `auth_middleware` as its state.
#[derive(Clone)]
pub struct AuthConfig {
/// API key that a request's Bearer token is checked against. When `None`,
/// `auth_middleware` performs no check and forwards every request.
pub api_key: Option<String>,
}
/// Axum middleware that enforces API-key authentication via a Bearer token.
///
/// If `auth_config.api_key` is `None`, authentication is disabled and the
/// request is forwarded to `next` untouched. Otherwise the request must carry
/// an `Authorization` header of the form `Bearer <token>` whose token equals
/// the configured key.
///
/// # Errors
///
/// Returns `StatusCode::UNAUTHORIZED` when a key is configured and the
/// `Authorization` header is missing, cannot be read as a string, does not
/// start with `"Bearer "`, or carries a token that differs from the key.
pub async fn auth_middleware(
    State(auth_config): State<AuthConfig>,
    request: Request<Body>,
    next: Next,
) -> Result<Response, StatusCode> {
    // No key configured: nothing to verify, hand the request on.
    let Some(expected_key) = auth_config.api_key.as_deref() else {
        return Ok(next.run(request).await);
    };

    // Pull the token out of `Authorization: Bearer <token>`; any missing or
    // malformed piece along the way is treated as an authentication failure.
    let presented = request
        .headers()
        .get(header::AUTHORIZATION)
        .and_then(|value| value.to_str().ok())
        .and_then(|value| value.strip_prefix("Bearer "))
        .map(str::as_bytes)
        .ok_or(StatusCode::UNAUTHORIZED)?;
    let expected = expected_key.as_bytes();

    // Reject on length mismatch up front. This step is not constant-time,
    // so it can reveal the key's length but not its contents.
    if presented.len() != expected.len() {
        return Err(StatusCode::UNAUTHORIZED);
    }

    // Compare the bytes in constant time so response timing does not leak
    // how many leading bytes of the token were correct.
    if presented.ct_eq(expected).unwrap_u8() != 1 {
        return Err(StatusCode::UNAUTHORIZED);
    }

    Ok(next.run(request).await)
}
/// Generate OpenAI-compatible request ID based on endpoint
fn generate_request_id(path: &str) -> String {
let prefix = if path.contains("/chat/completions") {

View File

@@ -4,7 +4,7 @@ use crate::{
data_connector::{MemoryResponseStorage, NoOpResponseStorage, SharedResponseStorage},
logging::{self, LoggingConfig},
metrics::{self, PrometheusConfig},
middleware::{self, QueuedRequest, TokenBucket},
middleware::{self, AuthConfig, QueuedRequest, TokenBucket},
policies::PolicyRegistry,
protocols::{
spec::{
@@ -275,6 +275,16 @@ async fn add_worker(
State(state): State<Arc<AppState>>,
Query(AddWorkerQuery { url, api_key }): Query<AddWorkerQuery>,
) -> Response {
// Warn if router has API key but worker is being added without one
if state.context.router_config.api_key.is_some() && api_key.is_none() {
warn!(
"Adding worker {} without API key while router has API key configured. \
Worker will be accessible without authentication. \
If the worker requires the same API key as the router, please specify it explicitly.",
url
);
}
let result = WorkerManager::add_worker(&url, &api_key, &state.context).await;
match result {
@@ -312,6 +322,16 @@ async fn create_worker(
State(state): State<Arc<AppState>>,
Json(config): Json<WorkerConfigRequest>,
) -> Response {
// Warn if router has API key but worker is being added without one
if state.context.router_config.api_key.is_some() && config.api_key.is_none() {
warn!(
"Adding worker {} without API key while router has API key configured. \
Worker will be accessible without authentication. \
If the worker requires the same API key as the router, please specify it explicitly.",
config.url
);
}
let result = WorkerManager::add_worker_from_config(&config, &state.context).await;
match result {
@@ -423,6 +443,7 @@ pub struct ServerConfig {
pub fn build_app(
app_state: Arc<AppState>,
auth_config: AuthConfig,
max_payload_size: usize,
request_id_headers: Vec<String>,
cors_allowed_origins: Vec<String>,
@@ -448,6 +469,10 @@ pub fn build_app(
.route_layer(axum::middleware::from_fn_with_state(
app_state.clone(),
middleware::concurrency_limit_middleware,
))
.route_layer(axum::middleware::from_fn_with_state(
auth_config.clone(),
middleware::auth_middleware,
));
let public_routes = Router::new()
@@ -464,13 +489,21 @@ pub fn build_app(
.route("/remove_worker", post(remove_worker))
.route("/list_workers", get(list_workers))
.route("/flush_cache", post(flush_cache))
.route("/get_loads", get(get_loads));
.route("/get_loads", get(get_loads))
.route_layer(axum::middleware::from_fn_with_state(
auth_config.clone(),
middleware::auth_middleware,
));
let worker_routes = Router::new()
.route("/workers", post(create_worker))
.route("/workers", get(list_workers_rest))
.route("/workers/{url}", get(get_worker))
.route("/workers/{url}", delete(delete_worker));
.route("/workers/{url}", delete(delete_worker))
.route_layer(axum::middleware::from_fn_with_state(
auth_config.clone(),
middleware::auth_middleware,
));
Router::new()
.merge(protected_routes)
@@ -629,8 +662,13 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box<dyn std::error::Err
]
});
let auth_config = AuthConfig {
api_key: config.router_config.api_key.clone(),
};
let app = build_app(
app_state,
auth_config,
config.max_payload_size,
request_id_headers,
config.router_config.cors_allowed_origins.clone(),