From da53e13cbb11c2491acf0a9bac49f9e568aec10e Mon Sep 17 00:00:00 2001 From: Simo Lin Date: Fri, 15 Aug 2025 11:01:47 -0700 Subject: [PATCH] [router] preserve original worker response header in router (#9236) --- sgl-router/src/routers/header_utils.rs | 53 ++++++++++++++++++++++ sgl-router/src/routers/mod.rs | 1 + sgl-router/src/routers/pd_router.rs | 63 +++++++++++++++++++------- sgl-router/src/routers/router.rs | 55 +++++++++++++--------- 4 files changed, 135 insertions(+), 37 deletions(-) create mode 100644 sgl-router/src/routers/header_utils.rs diff --git a/sgl-router/src/routers/header_utils.rs b/sgl-router/src/routers/header_utils.rs new file mode 100644 index 000000000..92ce5d3b6 --- /dev/null +++ b/sgl-router/src/routers/header_utils.rs @@ -0,0 +1,53 @@ +use axum::body::Body; +use axum::extract::Request; +use axum::http::{HeaderMap, HeaderName, HeaderValue}; + +/// Copy request headers to a Vec of name-value string pairs +/// Used for forwarding headers to backend workers +pub fn copy_request_headers(req: &Request) -> Vec<(String, String)> { + req.headers() + .iter() + .filter_map(|(name, value)| { + // Convert header value to string, skipping non-UTF8 headers + value + .to_str() + .ok() + .map(|v| (name.to_string(), v.to_string())) + }) + .collect() +} + +/// Convert headers from reqwest Response to axum HeaderMap +/// Filters out hop-by-hop headers that shouldn't be forwarded +pub fn preserve_response_headers(reqwest_headers: &HeaderMap) -> HeaderMap { + let mut headers = HeaderMap::new(); + + for (name, value) in reqwest_headers.iter() { + // Skip hop-by-hop headers that shouldn't be forwarded + let name_str = name.as_str().to_lowercase(); + if should_forward_header(&name_str) { + // The original name and value are already valid, so we can just clone them + headers.insert(name.clone(), value.clone()); + } + } + + headers +} + +/// Determine if a header should be forwarded from backend to client +fn should_forward_header(name: &str) -> bool { + // List of headers that should NOT be forwarded (hop-by-hop headers) + !matches!( + name, + "connection" | + "keep-alive" | + "proxy-authenticate" | + "proxy-authorization" | + "te" | + "trailers" | + "transfer-encoding" | + "upgrade" | + "content-encoding" | // Let axum/hyper handle encoding + "host" // Should not forward the backend's host header + ) +} diff --git a/sgl-router/src/routers/mod.rs b/sgl-router/src/routers/mod.rs index 3b3137423..bfcb5ad2e 100644 --- a/sgl-router/src/routers/mod.rs +++ b/sgl-router/src/routers/mod.rs @@ -12,6 +12,7 @@ use std::fmt::Debug; use crate::openai_api_types::{ChatCompletionRequest, CompletionRequest, GenerateRequest}; pub mod factory; +pub mod header_utils; pub mod pd_router; pub mod pd_types; pub mod router; diff --git a/sgl-router/src/routers/pd_router.rs b/sgl-router/src/routers/pd_router.rs index a15048a35..1d0dd3a60 100644 --- a/sgl-router/src/routers/pd_router.rs +++ b/sgl-router/src/routers/pd_router.rs @@ -1,5 +1,6 @@ // PD (Prefill-Decode) Router Implementation // This module handles routing for disaggregated prefill-decode systems +use super::header_utils; use super::pd_types::{api_path, PDRouterError}; use crate::config::types::{ CircuitBreakerConfig as ConfigCircuitBreakerConfig, @@ -170,17 +171,26 @@ impl PDRouter { } match request_builder.send().await { - Ok(res) if res.status().is_success() => match res.bytes().await { - Ok(body) => (StatusCode::OK, body).into_response(), - Err(e) => { - error!("Failed to read response body: {}", e); - ( - StatusCode::INTERNAL_SERVER_ERROR, - format!("Failed to read response body: {}", e), - ) - .into_response() + Ok(res) if res.status().is_success() => { + let response_headers = header_utils::preserve_response_headers(res.headers()); + + match res.bytes().await { + Ok(body) => { + let mut response = Response::new(axum::body::Body::from(body)); + *response.status_mut() = StatusCode::OK; + *response.headers_mut() = response_headers; + response + } + Err(e) => { + error!("Failed to read response body: {}", e); + ( + StatusCode::INTERNAL_SERVER_ERROR, + format!("Failed to read response body: {}", e), + ) + .into_response() + } } - }, + } Ok(res) => { let status = StatusCode::from_u16(res.status().as_u16()) .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR); @@ -822,12 +832,16 @@ impl PDRouter { json.pointer("/meta_info/input_token_logprobs").cloned() }); + let response_headers = + header_utils::preserve_response_headers(res.headers()); + Self::create_streaming_response( res.bytes_stream(), status, prefill_logprobs, return_logprob, None, + Some(response_headers), ) } else { // Non-streaming response with logprobs @@ -918,17 +932,30 @@ impl PDRouter { } else if is_stream { // Streaming response without logprobs - direct passthrough let decode_url = decode.url().to_string(); + let response_headers = + header_utils::preserve_response_headers(res.headers()); + Self::create_streaming_response( res.bytes_stream(), status, None, false, Some(decode_url), + Some(response_headers), ) } else { // Non-streaming response without logprobs - direct passthrough like fast version + let response_headers = + header_utils::preserve_response_headers(res.headers()); + match res.bytes().await { - Ok(decode_body) => (status, decode_body).into_response(), + Ok(decode_body) => { + let mut response = + Response::new(axum::body::Body::from(decode_body)); + *response.status_mut() = status; + *response.headers_mut() = response_headers; + response + } Err(e) => { error!("Failed to read decode response: {}", e); (StatusCode::INTERNAL_SERVER_ERROR, "Failed to read response") @@ -1081,6 +1108,7 @@ impl PDRouter { prefill_logprobs: Option, return_logprob: bool, decode_url: Option, + headers: Option, ) -> Response { let (tx, rx) = tokio::sync::mpsc::unbounded_channel(); @@ -1118,9 +1146,12 @@ impl PDRouter { let mut response = Response::new(body); *response.status_mut() = status; - response - .headers_mut() - .insert(CONTENT_TYPE, HeaderValue::from_static("text/event-stream")); + + // Use provided headers or create new ones, then ensure content-type is set for streaming + let mut headers = headers.unwrap_or_else(HeaderMap::new); + headers.insert(CONTENT_TYPE, HeaderValue::from_static("text/event-stream")); + *response.headers_mut() = headers; + response } @@ -1556,7 +1587,7 @@ impl RouterTrait for PDRouter { async fn get_models(&self, req: Request) -> Response { // Extract headers first to avoid Send issues - let headers = crate::routers::router::copy_request_headers(&req); + let headers = header_utils::copy_request_headers(&req); // Proxy to first prefill worker self.proxy_to_first_worker(&self.prefill_workers, "v1/models", "prefill", Some(headers)) @@ -1565,7 +1596,7 @@ impl RouterTrait for PDRouter { async fn get_model_info(&self, req: Request) -> Response { // Extract headers first to avoid Send issues - let headers = crate::routers::router::copy_request_headers(&req); + let headers = header_utils::copy_request_headers(&req); // Proxy to first prefill worker self.proxy_to_first_worker( diff --git a/sgl-router/src/routers/router.rs b/sgl-router/src/routers/router.rs index ca3210a63..36123a37c 100644 --- a/sgl-router/src/routers/router.rs +++ b/sgl-router/src/routers/router.rs @@ -1,3 +1,4 @@ +use super::header_utils; use crate::config::types::{ CircuitBreakerConfig as ConfigCircuitBreakerConfig, HealthCheckConfig as ConfigHealthCheckConfig, RetryConfig, @@ -24,17 +25,6 @@ use std::sync::{Arc, RwLock}; use std::time::{Duration, Instant}; use tokio_stream::wrappers::UnboundedReceiverStream; use tracing::{debug, error, info, warn}; -pub fn copy_request_headers(req: &Request) -> Vec<(String, String)> { - req.headers() - .iter() - .filter_map(|(name, value)| { - value - .to_str() - .ok() - .map(|v| (name.to_string(), v.to_string())) - }) - .collect() -} /// Regular router that uses injected load balancing policies #[derive(Debug)] @@ -400,7 +390,7 @@ impl Router { // Helper method to proxy GET requests to the first available worker async fn proxy_get_request(&self, req: Request, endpoint: &str) -> Response { - let headers = copy_request_headers(&req); + let headers = super::header_utils::copy_request_headers(&req); match self.select_first_worker() { Ok(worker_url) => { @@ -416,8 +406,18 @@ impl Router { Ok(res) => { let status = StatusCode::from_u16(res.status().as_u16()) .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR); + + // Preserve headers from backend + let response_headers = + header_utils::preserve_response_headers(res.headers()); + match res.bytes().await { - Ok(body) => (status, body).into_response(), + Ok(body) => { + let mut response = Response::new(axum::body::Body::from(body)); + *response.status_mut() = status; + *response.headers_mut() = response_headers; + response + } Err(e) => ( StatusCode::INTERNAL_SERVER_ERROR, format!("Failed to read response: {}", e), @@ -645,9 +645,16 @@ impl Router { .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR); if !is_stream { - // For non-streaming requests, get response first + // For non-streaming requests, preserve headers + let response_headers = super::header_utils::preserve_response_headers(res.headers()); + let response = match res.bytes().await { - Ok(body) => (status, body).into_response(), + Ok(body) => { + let mut response = Response::new(axum::body::Body::from(body)); + *response.status_mut() = status; + *response.headers_mut() = response_headers; + response + } Err(e) => { let error_msg = format!("Failed to get response body: {}", e); (StatusCode::INTERNAL_SERVER_ERROR, error_msg).into_response() @@ -670,6 +677,11 @@ impl Router { let workers = Arc::clone(&self.workers); let worker_url = worker_url.to_string(); + // Preserve headers for streaming response + let mut response_headers = header_utils::preserve_response_headers(res.headers()); + // Ensure we set the correct content-type for SSE + response_headers.insert(CONTENT_TYPE, HeaderValue::from_static("text/event-stream")); + let stream = res.bytes_stream(); let (tx, rx) = tokio::sync::mpsc::unbounded_channel(); @@ -724,12 +736,15 @@ impl Router { let mut response = Response::new(body); *response.status_mut() = status; - response - .headers_mut() - .insert(CONTENT_TYPE, HeaderValue::from_static("text/event-stream")); + *response.headers_mut() = response_headers; response } else { // For requests without load tracking, just stream + // Preserve headers for streaming response + let mut response_headers = header_utils::preserve_response_headers(res.headers()); + // Ensure we set the correct content-type for SSE + response_headers.insert(CONTENT_TYPE, HeaderValue::from_static("text/event-stream")); + let stream = res.bytes_stream(); let (tx, rx) = tokio::sync::mpsc::unbounded_channel(); @@ -756,9 +771,7 @@ impl Router { let mut response = Response::new(body); *response.status_mut() = status; - response - .headers_mut() - .insert(CONTENT_TYPE, HeaderValue::from_static("text/event-stream")); + *response.headers_mut() = response_headers; response } }