[router] add auth middleware for api key auth (#10826)
This commit is contained in:
@@ -69,6 +69,7 @@ rmcp = { version = "0.6.3", features = ["client", "server",
|
|||||||
"reqwest",
|
"reqwest",
|
||||||
"auth"] }
|
"auth"] }
|
||||||
serde_yaml = "0.9"
|
serde_yaml = "0.9"
|
||||||
|
subtle = "2.6"
|
||||||
|
|
||||||
# gRPC and Protobuf dependencies
|
# gRPC and Protobuf dependencies
|
||||||
tonic = { version = "0.12", features = ["tls", "gzip", "transport"] }
|
tonic = { version = "0.12", features = ["tls", "gzip", "transport"] }
|
||||||
|
|||||||
@@ -331,6 +331,79 @@ python -m sglang_router.launch_router \
|
|||||||
--prometheus-port 9090
|
--prometheus-port 9090
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### API Key Authentication
|
||||||
|
|
||||||
|
The router supports multi-level API key authentication for both the router itself and individual workers:
|
||||||
|
|
||||||
|
#### Router API Key
|
||||||
|
Protect access to the router endpoints:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python -m sglang_router.launch_router \
|
||||||
|
--api-key "your-router-api-key" \
|
||||||
|
--worker-urls http://worker1:8000 http://worker2:8000
|
||||||
|
```
|
||||||
|
|
||||||
|
When the router API key is set, clients must send it as a Bearer token in the `Authorization` header:
|
||||||
|
```bash
|
||||||
|
curl -H "Authorization: Bearer your-router-api-key" http://localhost:8080/v1/chat/completions
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Worker API Keys
|
||||||
|
Workers can have their own API keys for authentication:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Workers specified in --worker-urls automatically inherit the router's API key
|
||||||
|
python -m sglang_router.launch_router \
|
||||||
|
--api-key "shared-api-key" \
|
||||||
|
--worker-urls http://worker1:8000 http://worker2:8000
|
||||||
|
# Both workers will use "shared-api-key" for authentication
|
||||||
|
|
||||||
|
# Adding workers dynamically WITHOUT inheriting router's key
|
||||||
|
curl -X POST "http://localhost:8080/add_worker?url=http://worker3:8000"
|
||||||
|
# WARNING: This worker has NO API key even though router has one!
|
||||||
|
|
||||||
|
# Adding workers with specific API keys dynamically
|
||||||
|
curl -X POST "http://localhost:8080/add_worker?url=http://worker3:8000&api_key=worker3-specific-key"
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Security Configurations
|
||||||
|
|
||||||
|
1. **No Authentication** (default):
|
||||||
|
- Router and workers accessible without keys
|
||||||
|
- Suitable for trusted environments
|
||||||
|
|
||||||
|
2. **Router-only Authentication**:
|
||||||
|
- Clients need key to access router
|
||||||
|
- Router can access workers freely
|
||||||
|
|
||||||
|
3. **Worker-only Authentication**:
|
||||||
|
- Router accessible without key
|
||||||
|
- Each worker requires authentication
|
||||||
|
```bash
|
||||||
|
# Add workers with their API keys
|
||||||
|
curl -X POST "http://localhost:8080/add_worker?url=http://worker:8000&api_key=worker-key"
|
||||||
|
```
|
||||||
|
|
||||||
|
4. **Full Authentication**:
|
||||||
|
- Router requires key from clients
|
||||||
|
- Each worker requires its own key
|
||||||
|
```bash
|
||||||
|
# Start router with its key
|
||||||
|
python -m sglang_router.launch_router --api-key "router-key"
|
||||||
|
|
||||||
|
# Add workers with their keys
|
||||||
|
curl -H "Authorization: Bearer router-key" \
|
||||||
|
-X POST "http://localhost:8080/add_worker?url=http://worker:8000&api_key=worker-key"
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Important Notes
|
||||||
|
|
||||||
|
- **Initial Workers**: Workers specified in `--worker-urls` automatically inherit the router's API key
|
||||||
|
- **Dynamic Workers**: When adding workers via API, you must explicitly specify their API keys - they do NOT inherit the router's key
|
||||||
|
- **Security Warning**: When adding workers without API keys while the router has one configured, a warning will be logged
|
||||||
|
- **Common Pitfall**: If router and workers use the same API key, you must still specify the key when adding workers dynamically
|
||||||
|
|
||||||
### Command Line Arguments Reference
|
### Command Line Arguments Reference
|
||||||
|
|
||||||
#### Service Discovery
|
#### Service Discovery
|
||||||
@@ -349,6 +422,9 @@ python -m sglang_router.launch_router \
|
|||||||
- `--prefill-policy`: Separate routing policy for prefill nodes (optional, overrides `--policy` for prefill)
|
- `--prefill-policy`: Separate routing policy for prefill nodes (optional, overrides `--policy` for prefill)
|
||||||
- `--decode-policy`: Separate routing policy for decode nodes (optional, overrides `--policy` for decode)
|
- `--decode-policy`: Separate routing policy for decode nodes (optional, overrides `--policy` for decode)
|
||||||
|
|
||||||
|
#### Authentication
|
||||||
|
- `--api-key`: API key for router authentication (clients must provide this as Bearer token)
|
||||||
|
|
||||||
## Development
|
## Development
|
||||||
|
|
||||||
### Build Process
|
### Build Process
|
||||||
|
|||||||
@@ -131,31 +131,29 @@ def test_dp_aware_worker_expansion_and_api_key(
|
|||||||
r = requests.post(
|
r = requests.post(
|
||||||
f"{router_url}/add_worker",
|
f"{router_url}/add_worker",
|
||||||
params={"url": worker_url, "api_key": api_key},
|
params={"url": worker_url, "api_key": api_key},
|
||||||
|
headers={"Authorization": f"Bearer {api_key}"},
|
||||||
timeout=180,
|
timeout=180,
|
||||||
)
|
)
|
||||||
r.raise_for_status()
|
r.raise_for_status()
|
||||||
|
|
||||||
r = requests.get(f"{router_url}/list_workers", timeout=30)
|
r = requests.get(
|
||||||
|
f"{router_url}/list_workers",
|
||||||
|
headers={"Authorization": f"Bearer {api_key}"},
|
||||||
|
timeout=30,
|
||||||
|
)
|
||||||
r.raise_for_status()
|
r.raise_for_status()
|
||||||
urls = r.json().get("urls", [])
|
urls = r.json().get("urls", [])
|
||||||
assert len(urls) == 2
|
assert len(urls) == 2
|
||||||
assert set(urls) == {f"{worker_url}@0", f"{worker_url}@1"}
|
assert set(urls) == {f"{worker_url}@0", f"{worker_url}@1"}
|
||||||
|
|
||||||
# TODO: Router currently doesn't enforce API key authentication on incoming requests.
|
# Verify API key enforcement
|
||||||
# It only adds the API key to outgoing requests to workers.
|
# 1) Without Authorization -> Should get 401 Unauthorized
|
||||||
# Need to implement auth middleware to properly protect router endpoints.
|
|
||||||
# For now, both requests succeed (200) regardless of client authentication.
|
|
||||||
|
|
||||||
# Verify API key enforcement path-through
|
|
||||||
# 1) Without Authorization -> Currently 200 (should be 401 after auth middleware added)
|
|
||||||
r = requests.post(
|
r = requests.post(
|
||||||
f"{router_url}/v1/completions",
|
f"{router_url}/v1/completions",
|
||||||
json={"model": e2e_model, "prompt": "hi", "max_tokens": 1},
|
json={"model": e2e_model, "prompt": "hi", "max_tokens": 1},
|
||||||
timeout=60,
|
timeout=60,
|
||||||
)
|
)
|
||||||
assert (
|
assert r.status_code == 401
|
||||||
r.status_code == 200
|
|
||||||
) # TODO: Change to 401 after auth middleware implementation
|
|
||||||
|
|
||||||
# 2) With correct Authorization -> 200
|
# 2) With correct Authorization -> 200
|
||||||
r = requests.post(
|
r = requests.post(
|
||||||
|
|||||||
@@ -1,12 +1,13 @@
|
|||||||
use axum::{
|
use axum::{
|
||||||
extract::Request, extract::State, http::HeaderValue, http::StatusCode, middleware::Next,
|
body::Body, extract::Request, extract::State, http::header, http::HeaderValue,
|
||||||
response::IntoResponse, response::Response,
|
http::StatusCode, middleware::Next, response::IntoResponse, response::Response,
|
||||||
};
|
};
|
||||||
use rand::Rng;
|
use rand::Rng;
|
||||||
use std::sync::atomic::{AtomicU64, Ordering};
|
use std::sync::atomic::{AtomicU64, Ordering};
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use std::time::Duration;
|
use std::time::Duration;
|
||||||
use std::time::Instant;
|
use std::time::Instant;
|
||||||
|
use subtle::ConstantTimeEq;
|
||||||
use tokio::sync::{mpsc, oneshot};
|
use tokio::sync::{mpsc, oneshot};
|
||||||
use tower::{Layer, Service};
|
use tower::{Layer, Service};
|
||||||
use tower_http::trace::{MakeSpan, OnRequest, OnResponse, TraceLayer};
|
use tower_http::trace::{MakeSpan, OnRequest, OnResponse, TraceLayer};
|
||||||
@@ -17,6 +18,49 @@ pub use crate::core::token_bucket::TokenBucket;
|
|||||||
use crate::metrics::RouterMetrics;
|
use crate::metrics::RouterMetrics;
|
||||||
use crate::server::AppState;
|
use crate::server::AppState;
|
||||||
|
|
||||||
|
/// Authentication settings handed to `auth_middleware` as its axum state.
#[derive(Clone)]
|
||||||
|
pub struct AuthConfig {
|
||||||
|
/// Expected Bearer token. When `None`, `auth_middleware` performs no check
/// and forwards every request.
pub api_key: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Middleware to validate Bearer token against configured API key
|
||||||
|
/// Only active when router has an API key configured
|
||||||
|
pub async fn auth_middleware(
|
||||||
|
State(auth_config): State<AuthConfig>,
|
||||||
|
request: Request<Body>,
|
||||||
|
next: Next,
|
||||||
|
) -> Result<Response, StatusCode> {
|
||||||
|
if let Some(expected_key) = &auth_config.api_key {
|
||||||
|
// Extract Authorization header
|
||||||
|
let auth_header = request
|
||||||
|
.headers()
|
||||||
|
.get(header::AUTHORIZATION)
|
||||||
|
.and_then(|h| h.to_str().ok());
|
||||||
|
|
||||||
|
match auth_header {
|
||||||
|
Some(header_value) if header_value.starts_with("Bearer ") => {
|
||||||
|
let token = &header_value[7..]; // Skip "Bearer "
|
||||||
|
// Use constant-time comparison to prevent timing attacks
|
||||||
|
let token_bytes = token.as_bytes();
|
||||||
|
let expected_bytes = expected_key.as_bytes();
|
||||||
|
|
||||||
|
// Check if lengths match first (this is not constant-time but necessary)
|
||||||
|
if token_bytes.len() != expected_bytes.len() {
|
||||||
|
return Err(StatusCode::UNAUTHORIZED);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Constant-time comparison of the actual values
|
||||||
|
if token_bytes.ct_eq(expected_bytes).unwrap_u8() != 1 {
|
||||||
|
return Err(StatusCode::UNAUTHORIZED);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ => return Err(StatusCode::UNAUTHORIZED),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(next.run(request).await)
|
||||||
|
}
|
||||||
|
|
||||||
/// Generate OpenAI-compatible request ID based on endpoint
|
/// Generate OpenAI-compatible request ID based on endpoint
|
||||||
fn generate_request_id(path: &str) -> String {
|
fn generate_request_id(path: &str) -> String {
|
||||||
let prefix = if path.contains("/chat/completions") {
|
let prefix = if path.contains("/chat/completions") {
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ use crate::{
|
|||||||
data_connector::{MemoryResponseStorage, NoOpResponseStorage, SharedResponseStorage},
|
data_connector::{MemoryResponseStorage, NoOpResponseStorage, SharedResponseStorage},
|
||||||
logging::{self, LoggingConfig},
|
logging::{self, LoggingConfig},
|
||||||
metrics::{self, PrometheusConfig},
|
metrics::{self, PrometheusConfig},
|
||||||
middleware::{self, QueuedRequest, TokenBucket},
|
middleware::{self, AuthConfig, QueuedRequest, TokenBucket},
|
||||||
policies::PolicyRegistry,
|
policies::PolicyRegistry,
|
||||||
protocols::{
|
protocols::{
|
||||||
spec::{
|
spec::{
|
||||||
@@ -275,6 +275,16 @@ async fn add_worker(
|
|||||||
State(state): State<Arc<AppState>>,
|
State(state): State<Arc<AppState>>,
|
||||||
Query(AddWorkerQuery { url, api_key }): Query<AddWorkerQuery>,
|
Query(AddWorkerQuery { url, api_key }): Query<AddWorkerQuery>,
|
||||||
) -> Response {
|
) -> Response {
|
||||||
|
// Warn if router has API key but worker is being added without one
|
||||||
|
if state.context.router_config.api_key.is_some() && api_key.is_none() {
|
||||||
|
warn!(
|
||||||
|
"Adding worker {} without API key while router has API key configured. \
|
||||||
|
Worker will be accessible without authentication. \
|
||||||
|
If the worker requires the same API key as the router, please specify it explicitly.",
|
||||||
|
url
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
let result = WorkerManager::add_worker(&url, &api_key, &state.context).await;
|
let result = WorkerManager::add_worker(&url, &api_key, &state.context).await;
|
||||||
|
|
||||||
match result {
|
match result {
|
||||||
@@ -312,6 +322,16 @@ async fn create_worker(
|
|||||||
State(state): State<Arc<AppState>>,
|
State(state): State<Arc<AppState>>,
|
||||||
Json(config): Json<WorkerConfigRequest>,
|
Json(config): Json<WorkerConfigRequest>,
|
||||||
) -> Response {
|
) -> Response {
|
||||||
|
// Warn if router has API key but worker is being added without one
|
||||||
|
if state.context.router_config.api_key.is_some() && config.api_key.is_none() {
|
||||||
|
warn!(
|
||||||
|
"Adding worker {} without API key while router has API key configured. \
|
||||||
|
Worker will be accessible without authentication. \
|
||||||
|
If the worker requires the same API key as the router, please specify it explicitly.",
|
||||||
|
config.url
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
let result = WorkerManager::add_worker_from_config(&config, &state.context).await;
|
let result = WorkerManager::add_worker_from_config(&config, &state.context).await;
|
||||||
|
|
||||||
match result {
|
match result {
|
||||||
@@ -423,6 +443,7 @@ pub struct ServerConfig {
|
|||||||
|
|
||||||
pub fn build_app(
|
pub fn build_app(
|
||||||
app_state: Arc<AppState>,
|
app_state: Arc<AppState>,
|
||||||
|
auth_config: AuthConfig,
|
||||||
max_payload_size: usize,
|
max_payload_size: usize,
|
||||||
request_id_headers: Vec<String>,
|
request_id_headers: Vec<String>,
|
||||||
cors_allowed_origins: Vec<String>,
|
cors_allowed_origins: Vec<String>,
|
||||||
@@ -448,6 +469,10 @@ pub fn build_app(
|
|||||||
.route_layer(axum::middleware::from_fn_with_state(
|
.route_layer(axum::middleware::from_fn_with_state(
|
||||||
app_state.clone(),
|
app_state.clone(),
|
||||||
middleware::concurrency_limit_middleware,
|
middleware::concurrency_limit_middleware,
|
||||||
|
))
|
||||||
|
.route_layer(axum::middleware::from_fn_with_state(
|
||||||
|
auth_config.clone(),
|
||||||
|
middleware::auth_middleware,
|
||||||
));
|
));
|
||||||
|
|
||||||
let public_routes = Router::new()
|
let public_routes = Router::new()
|
||||||
@@ -464,13 +489,21 @@ pub fn build_app(
|
|||||||
.route("/remove_worker", post(remove_worker))
|
.route("/remove_worker", post(remove_worker))
|
||||||
.route("/list_workers", get(list_workers))
|
.route("/list_workers", get(list_workers))
|
||||||
.route("/flush_cache", post(flush_cache))
|
.route("/flush_cache", post(flush_cache))
|
||||||
.route("/get_loads", get(get_loads));
|
.route("/get_loads", get(get_loads))
|
||||||
|
.route_layer(axum::middleware::from_fn_with_state(
|
||||||
|
auth_config.clone(),
|
||||||
|
middleware::auth_middleware,
|
||||||
|
));
|
||||||
|
|
||||||
let worker_routes = Router::new()
|
let worker_routes = Router::new()
|
||||||
.route("/workers", post(create_worker))
|
.route("/workers", post(create_worker))
|
||||||
.route("/workers", get(list_workers_rest))
|
.route("/workers", get(list_workers_rest))
|
||||||
.route("/workers/{url}", get(get_worker))
|
.route("/workers/{url}", get(get_worker))
|
||||||
.route("/workers/{url}", delete(delete_worker));
|
.route("/workers/{url}", delete(delete_worker))
|
||||||
|
.route_layer(axum::middleware::from_fn_with_state(
|
||||||
|
auth_config.clone(),
|
||||||
|
middleware::auth_middleware,
|
||||||
|
));
|
||||||
|
|
||||||
Router::new()
|
Router::new()
|
||||||
.merge(protected_routes)
|
.merge(protected_routes)
|
||||||
@@ -629,8 +662,13 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box<dyn std::error::Err
|
|||||||
]
|
]
|
||||||
});
|
});
|
||||||
|
|
||||||
|
let auth_config = AuthConfig {
|
||||||
|
api_key: config.router_config.api_key.clone(),
|
||||||
|
};
|
||||||
|
|
||||||
let app = build_app(
|
let app = build_app(
|
||||||
app_state,
|
app_state,
|
||||||
|
auth_config,
|
||||||
config.max_payload_size,
|
config.max_payload_size,
|
||||||
request_id_headers,
|
request_id_headers,
|
||||||
config.router_config.cors_allowed_origins.clone(),
|
config.router_config.cors_allowed_origins.clone(),
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ use axum::Router;
|
|||||||
use reqwest::Client;
|
use reqwest::Client;
|
||||||
use sglang_router_rs::{
|
use sglang_router_rs::{
|
||||||
config::RouterConfig,
|
config::RouterConfig,
|
||||||
|
middleware::AuthConfig,
|
||||||
routers::RouterTrait,
|
routers::RouterTrait,
|
||||||
server::{build_app, AppContext, AppState},
|
server::{build_app, AppContext, AppState},
|
||||||
};
|
};
|
||||||
@@ -43,9 +44,15 @@ pub fn create_test_app(
|
|||||||
]
|
]
|
||||||
});
|
});
|
||||||
|
|
||||||
|
// Create auth config from router config
|
||||||
|
let auth_config = AuthConfig {
|
||||||
|
api_key: router_config.api_key.clone(),
|
||||||
|
};
|
||||||
|
|
||||||
// Use the actual server's build_app function
|
// Use the actual server's build_app function
|
||||||
build_app(
|
build_app(
|
||||||
app_state,
|
app_state,
|
||||||
|
auth_config,
|
||||||
router_config.max_payload_size,
|
router_config.max_payload_size,
|
||||||
request_id_headers,
|
request_id_headers,
|
||||||
router_config.cors_allowed_origins.clone(),
|
router_config.cors_allowed_origins.clone(),
|
||||||
@@ -79,9 +86,15 @@ pub fn create_test_app_with_context(
|
|||||||
]
|
]
|
||||||
});
|
});
|
||||||
|
|
||||||
|
// Create auth config from router config
|
||||||
|
let auth_config = AuthConfig {
|
||||||
|
api_key: router_config.api_key.clone(),
|
||||||
|
};
|
||||||
|
|
||||||
// Use the actual server's build_app function
|
// Use the actual server's build_app function
|
||||||
build_app(
|
build_app(
|
||||||
app_state,
|
app_state,
|
||||||
|
auth_config,
|
||||||
router_config.max_payload_size,
|
router_config.max_payload_size,
|
||||||
request_id_headers,
|
request_id_headers,
|
||||||
router_config.cors_allowed_origins.clone(),
|
router_config.cors_allowed_origins.clone(),
|
||||||
|
|||||||
Reference in New Issue
Block a user