From 9a0cc2e90e61942483c6e073e9af42cec75364df Mon Sep 17 00:00:00 2001 From: Byron Hsu Date: Thu, 23 Jan 2025 20:30:31 -0800 Subject: [PATCH] [router] Forward all request headers from router to workers (#3070) --- scripts/killall_sglang.sh | 9 ++++ sgl-router/py_test/test_launch_server.py | 56 +++++++++++++++++++ sgl-router/src/router.rs | 68 ++++++++++++++++++------ sgl-router/src/server.rs | 24 +++++---- 4 files changed, 132 insertions(+), 25 deletions(-) diff --git a/scripts/killall_sglang.sh b/scripts/killall_sglang.sh index 53d08703e..163a60f18 100755 --- a/scripts/killall_sglang.sh +++ b/scripts/killall_sglang.sh @@ -1,5 +1,14 @@ #!/bin/bash +# Check if sudo is available +if command -v sudo >/dev/null 2>&1; then + sudo apt-get update + sudo apt-get install -y lsof +else + apt-get update + apt-get install -y lsof +fi + # Show current GPU status nvidia-smi diff --git a/sgl-router/py_test/test_launch_server.py b/sgl-router/py_test/test_launch_server.py index e11602933..80659fc4f 100644 --- a/sgl-router/py_test/test_launch_server.py +++ b/sgl-router/py_test/test_launch_server.py @@ -22,6 +22,7 @@ def popen_launch_router( timeout: float, policy: str = "cache_aware", max_payload_size: int = None, + api_key: str = None, ): """ Launch the router server process. @@ -33,6 +34,7 @@ def popen_launch_router( timeout: Server launch timeout policy: Router policy, one of "cache_aware", "round_robin", "random" max_payload_size: Maximum payload size in bytes + api_key: API key for the router """ _, host, port = base_url.split(":") host = host[2:] @@ -55,6 +57,9 @@ def popen_launch_router( policy, ] + if api_key is not None: + command.extend(["--api-key", api_key]) + if max_payload_size is not None: command.extend(["--router-max-payload-size", str(max_payload_size)]) @@ -333,6 +338,57 @@ class TestLaunchServer(unittest.TestCase): f"1.2MB payload should fail with 413 but got status {response.status_code}", ) + def test_5_api_key(self): + print("Running test_5_api_key...") + + self.process = popen_launch_router( + self.model, + self.base_url, + dp_size=1, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + policy="round_robin", + api_key="correct_api_key", + ) + + # # Test case 1: request without api key should fail + with requests.Session() as session: + response = session.post( + f"{self.base_url}/generate", + json={"text": "Kanye west is, ", "temperature": 0}, + ) + print(f"status code: {response.status_code}, response: {response.text}") + self.assertEqual( + response.status_code, + 401, + "Request without api key should fail with 401", + ) + + # Test case 2: request with invalid api key should fail + with requests.Session() as session: + response = requests.post( + f"{self.base_url}/generate", + json={"text": "Kanye west is, ", "temperature": 0}, + headers={"Authorization": "Bearer 123"}, + ) + print(f"status code: {response.status_code}, response: {response.text}") + self.assertEqual( + response.status_code, + 401, + "Request with invalid api key should fail with 401", + ) + + # Test case 3: request with correct api key should succeed + with requests.Session() as session: + response = session.post( + f"{self.base_url}/generate", + json={"text": "Kanye west is ", "temperature": 0}, + headers={"Authorization": "Bearer correct_api_key"}, + ) + print(f"status code: {response.status_code}, response: {response.text}") + self.assertEqual( + response.status_code, 200, "Request with correct api key should succeed" + ) + if __name__ == "__main__": unittest.main() diff --git a/sgl-router/src/router.rs b/sgl-router/src/router.rs index a189ff9eb..5ee34c598 100644 --- a/sgl-router/src/router.rs +++ b/sgl-router/src/router.rs @@ -12,6 +12,18 @@ use std::thread; use std::time::Duration; use tokio; +fn copy_request_headers(req: &HttpRequest) -> Vec<(String, String)> { + req.headers() + .iter() + .filter_map(|(name, value)| { + value + .to_str() + .ok() + .map(|v| (name.to_string(), v.to_string())) + }) + .collect() +} + #[derive(Debug)] pub enum Router { RoundRobin { @@ -303,8 +315,18 @@ impl Router { client: &reqwest::Client, worker_url: &str, route: &str, + req: &HttpRequest, ) -> HttpResponse { - match client.get(format!("{}{}", worker_url, route)).send().await { + let mut request_builder = client.get(format!("{}{}", worker_url, route)); + + // Copy all headers from original request except for /health because it does not need authorization + if route != "/health" { + for (name, value) in copy_request_headers(req) { + request_builder = request_builder.header(name, value); + } + } + + match request_builder.send().await { Ok(res) => { let status = actix_web::http::StatusCode::from_u16(res.status().as_u16()) .unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR); @@ -322,7 +344,12 @@ impl Router { } } - pub async fn route_to_first(&self, client: &reqwest::Client, route: &str) -> HttpResponse { + pub async fn route_to_first( + &self, + client: &reqwest::Client, + route: &str, + req: &HttpRequest, + ) -> HttpResponse { const MAX_REQUEST_RETRIES: u32 = 3; const MAX_TOTAL_RETRIES: u32 = 6; let mut total_retries = 0; @@ -338,10 +365,17 @@ impl Router { info!("Retrying request after {} failed attempts", total_retries); } - let response = self.send_request(client, &worker_url, route).await; + let response = self.send_request(client, &worker_url, route, req).await; if response.status().is_success() { return response; + } else { + // if the worker is healthy, it means the request is bad, so return the error response + let health_response = + self.send_request(client, &worker_url, "/health", req).await; + if health_response.status().is_success() { + return response; + } } warn!( @@ -496,19 +530,16 @@ impl Router { .map(|v| v.get("stream").and_then(|s| s.as_bool()).unwrap_or(false)) .unwrap_or(false); - let res = match client + let mut request_builder = client .post(format!("{}{}", worker_url, route)) - .header( - "Content-Type", - req.headers() - .get("Content-Type") - .and_then(|h| h.to_str().ok()) - .unwrap_or("application/json"), - ) - .body(body.to_vec()) - .send() - .await - { + .body(body.to_vec()); + + // Copy all headers from original request + for (name, value) in copy_request_headers(req) { + request_builder = request_builder.header(name, value); + } + + let res = match request_builder.send().await { Ok(res) => res, Err(_) => return HttpResponse::InternalServerError().finish(), }; @@ -596,6 +627,13 @@ impl Router { if response.status().is_success() { return response; + } else { + // if the worker is healthy, it means the request is bad, so return the error response + let health_response = + self.send_request(client, &worker_url, "/health", req).await; + if health_response.status().is_success() { + return response; + } } warn!( diff --git a/sgl-router/src/server.rs b/sgl-router/src/server.rs index e3587389e..0706c57c0 100644 --- a/sgl-router/src/server.rs +++ b/sgl-router/src/server.rs @@ -26,33 +26,37 @@ impl AppState { } #[get("/health")] -async fn health(data: web::Data) -> impl Responder { - data.router.route_to_first(&data.client, "/health").await +async fn health(req: HttpRequest, data: web::Data) -> impl Responder { + data.router + .route_to_first(&data.client, "/health", &req) + .await } #[get("/health_generate")] -async fn health_generate(data: web::Data) -> impl Responder { +async fn health_generate(req: HttpRequest, data: web::Data) -> impl Responder { data.router - .route_to_first(&data.client, "/health_generate") + .route_to_first(&data.client, "/health_generate", &req) .await } #[get("/get_server_info")] -async fn get_server_info(data: web::Data) -> impl Responder { +async fn get_server_info(req: HttpRequest, data: web::Data) -> impl Responder { data.router - .route_to_first(&data.client, "/get_server_info") + .route_to_first(&data.client, "/get_server_info", &req) .await } #[get("/v1/models")] -async fn v1_models(data: web::Data) -> impl Responder { - data.router.route_to_first(&data.client, "/v1/models").await +async fn v1_models(req: HttpRequest, data: web::Data) -> impl Responder { + data.router + .route_to_first(&data.client, "/v1/models", &req) + .await } #[get("/get_model_info")] -async fn get_model_info(data: web::Data) -> impl Responder { +async fn get_model_info(req: HttpRequest, data: web::Data) -> impl Responder { data.router - .route_to_first(&data.client, "/get_model_info") + .route_to_first(&data.client, "/get_model_info", &req) .await }