diff --git a/sgl-router/src/routers/header_utils.rs b/sgl-router/src/routers/header_utils.rs
new file mode 100644
index 000000000..92ce5d3b6
--- /dev/null
+++ b/sgl-router/src/routers/header_utils.rs
@@ -0,0 +1,53 @@
+use axum::body::Body;
+use axum::extract::Request;
+use axum::http::{HeaderMap, HeaderName, HeaderValue};
+
+/// Copy request headers to a Vec of name-value string pairs
+/// Used for forwarding headers to backend workers
+pub fn copy_request_headers(req: &Request
) -> Vec<(String, String)> {
+ req.headers()
+ .iter()
+ .filter_map(|(name, value)| {
+ // Convert header value to string, skipping non-UTF8 headers
+ value
+ .to_str()
+ .ok()
+ .map(|v| (name.to_string(), v.to_string()))
+ })
+ .collect()
+}
+
+/// Convert headers from reqwest Response to axum HeaderMap
+/// Filters out hop-by-hop headers that shouldn't be forwarded
+pub fn preserve_response_headers(reqwest_headers: &HeaderMap) -> HeaderMap {
+ let mut headers = HeaderMap::new();
+
+ for (name, value) in reqwest_headers.iter() {
+ // Skip hop-by-hop headers that shouldn't be forwarded
+ let name_str = name.as_str().to_lowercase();
+ if should_forward_header(&name_str) {
+ // The original name and value are already valid, so we can just clone them
+ headers.insert(name.clone(), value.clone());
+ }
+ }
+
+ headers
+}
+
+/// Determine if a header should be forwarded from backend to client
+fn should_forward_header(name: &str) -> bool {
+ // List of headers that should NOT be forwarded (hop-by-hop headers)
+ !matches!(
+ name,
+ "connection" |
+ "keep-alive" |
+ "proxy-authenticate" |
+ "proxy-authorization" |
+ "te" |
+ "trailers" |
+ "transfer-encoding" |
+ "upgrade" |
+ "content-encoding" | // Let axum/hyper handle encoding
+ "host" // Should not forward the backend's host header
+ )
+}
diff --git a/sgl-router/src/routers/mod.rs b/sgl-router/src/routers/mod.rs
index 3b3137423..bfcb5ad2e 100644
--- a/sgl-router/src/routers/mod.rs
+++ b/sgl-router/src/routers/mod.rs
@@ -12,6 +12,7 @@ use std::fmt::Debug;
use crate::openai_api_types::{ChatCompletionRequest, CompletionRequest, GenerateRequest};
pub mod factory;
+pub mod header_utils;
pub mod pd_router;
pub mod pd_types;
pub mod router;
diff --git a/sgl-router/src/routers/pd_router.rs b/sgl-router/src/routers/pd_router.rs
index a15048a35..1d0dd3a60 100644
--- a/sgl-router/src/routers/pd_router.rs
+++ b/sgl-router/src/routers/pd_router.rs
@@ -1,5 +1,6 @@
// PD (Prefill-Decode) Router Implementation
// This module handles routing for disaggregated prefill-decode systems
+use super::header_utils;
use super::pd_types::{api_path, PDRouterError};
use crate::config::types::{
CircuitBreakerConfig as ConfigCircuitBreakerConfig,
@@ -170,17 +171,26 @@ impl PDRouter {
}
match request_builder.send().await {
- Ok(res) if res.status().is_success() => match res.bytes().await {
- Ok(body) => (StatusCode::OK, body).into_response(),
- Err(e) => {
- error!("Failed to read response body: {}", e);
- (
- StatusCode::INTERNAL_SERVER_ERROR,
- format!("Failed to read response body: {}", e),
- )
- .into_response()
+ Ok(res) if res.status().is_success() => {
+ let response_headers = header_utils::preserve_response_headers(res.headers());
+
+ match res.bytes().await {
+ Ok(body) => {
+ let mut response = Response::new(axum::body::Body::from(body));
+ *response.status_mut() = StatusCode::OK;
+ *response.headers_mut() = response_headers;
+ response
+ }
+ Err(e) => {
+ error!("Failed to read response body: {}", e);
+ (
+ StatusCode::INTERNAL_SERVER_ERROR,
+ format!("Failed to read response body: {}", e),
+ )
+ .into_response()
+ }
}
- },
+ }
Ok(res) => {
let status = StatusCode::from_u16(res.status().as_u16())
.unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
@@ -822,12 +832,16 @@ impl PDRouter {
json.pointer("/meta_info/input_token_logprobs").cloned()
});
+ let response_headers =
+ header_utils::preserve_response_headers(res.headers());
+
Self::create_streaming_response(
res.bytes_stream(),
status,
prefill_logprobs,
return_logprob,
None,
+ Some(response_headers),
)
} else {
// Non-streaming response with logprobs
@@ -918,17 +932,30 @@ impl PDRouter {
} else if is_stream {
// Streaming response without logprobs - direct passthrough
let decode_url = decode.url().to_string();
+ let response_headers =
+ header_utils::preserve_response_headers(res.headers());
+
Self::create_streaming_response(
res.bytes_stream(),
status,
None,
false,
Some(decode_url),
+ Some(response_headers),
)
} else {
// Non-streaming response without logprobs - direct passthrough like fast version
+ let response_headers =
+ header_utils::preserve_response_headers(res.headers());
+
match res.bytes().await {
- Ok(decode_body) => (status, decode_body).into_response(),
+ Ok(decode_body) => {
+ let mut response =
+ Response::new(axum::body::Body::from(decode_body));
+ *response.status_mut() = status;
+ *response.headers_mut() = response_headers;
+ response
+ }
Err(e) => {
error!("Failed to read decode response: {}", e);
(StatusCode::INTERNAL_SERVER_ERROR, "Failed to read response")
@@ -1081,6 +1108,7 @@ impl PDRouter {
prefill_logprobs: Option,
return_logprob: bool,
decode_url: Option,
+ headers: Option,
) -> Response {
let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
@@ -1118,9 +1146,12 @@ impl PDRouter {
let mut response = Response::new(body);
*response.status_mut() = status;
- response
- .headers_mut()
- .insert(CONTENT_TYPE, HeaderValue::from_static("text/event-stream"));
+
+ // Use provided headers or create new ones, then ensure content-type is set for streaming
+ let mut headers = headers.unwrap_or_else(HeaderMap::new);
+ headers.insert(CONTENT_TYPE, HeaderValue::from_static("text/event-stream"));
+ *response.headers_mut() = headers;
+
response
}
@@ -1556,7 +1587,7 @@ impl RouterTrait for PDRouter {
async fn get_models(&self, req: Request) -> Response {
// Extract headers first to avoid Send issues
- let headers = crate::routers::router::copy_request_headers(&req);
+ let headers = header_utils::copy_request_headers(&req);
// Proxy to first prefill worker
self.proxy_to_first_worker(&self.prefill_workers, "v1/models", "prefill", Some(headers))
@@ -1565,7 +1596,7 @@ impl RouterTrait for PDRouter {
async fn get_model_info(&self, req: Request) -> Response {
// Extract headers first to avoid Send issues
- let headers = crate::routers::router::copy_request_headers(&req);
+ let headers = header_utils::copy_request_headers(&req);
// Proxy to first prefill worker
self.proxy_to_first_worker(
diff --git a/sgl-router/src/routers/router.rs b/sgl-router/src/routers/router.rs
index ca3210a63..36123a37c 100644
--- a/sgl-router/src/routers/router.rs
+++ b/sgl-router/src/routers/router.rs
@@ -1,3 +1,4 @@
+use super::header_utils;
use crate::config::types::{
CircuitBreakerConfig as ConfigCircuitBreakerConfig,
HealthCheckConfig as ConfigHealthCheckConfig, RetryConfig,
@@ -24,17 +25,6 @@ use std::sync::{Arc, RwLock};
use std::time::{Duration, Instant};
use tokio_stream::wrappers::UnboundedReceiverStream;
use tracing::{debug, error, info, warn};
-pub fn copy_request_headers(req: &Request) -> Vec<(String, String)> {
- req.headers()
- .iter()
- .filter_map(|(name, value)| {
- value
- .to_str()
- .ok()
- .map(|v| (name.to_string(), v.to_string()))
- })
- .collect()
-}
/// Regular router that uses injected load balancing policies
#[derive(Debug)]
@@ -400,7 +390,7 @@ impl Router {
// Helper method to proxy GET requests to the first available worker
async fn proxy_get_request(&self, req: Request, endpoint: &str) -> Response {
- let headers = copy_request_headers(&req);
+ let headers = super::header_utils::copy_request_headers(&req);
match self.select_first_worker() {
Ok(worker_url) => {
@@ -416,8 +406,18 @@ impl Router {
Ok(res) => {
let status = StatusCode::from_u16(res.status().as_u16())
.unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
+
+ // Preserve headers from backend
+ let response_headers =
+ header_utils::preserve_response_headers(res.headers());
+
match res.bytes().await {
- Ok(body) => (status, body).into_response(),
+ Ok(body) => {
+ let mut response = Response::new(axum::body::Body::from(body));
+ *response.status_mut() = status;
+ *response.headers_mut() = response_headers;
+ response
+ }
Err(e) => (
StatusCode::INTERNAL_SERVER_ERROR,
format!("Failed to read response: {}", e),
@@ -645,9 +645,16 @@ impl Router {
.unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
if !is_stream {
- // For non-streaming requests, get response first
+ // For non-streaming requests, preserve headers
+ let response_headers = super::header_utils::preserve_response_headers(res.headers());
+
let response = match res.bytes().await {
- Ok(body) => (status, body).into_response(),
+ Ok(body) => {
+ let mut response = Response::new(axum::body::Body::from(body));
+ *response.status_mut() = status;
+ *response.headers_mut() = response_headers;
+ response
+ }
Err(e) => {
let error_msg = format!("Failed to get response body: {}", e);
(StatusCode::INTERNAL_SERVER_ERROR, error_msg).into_response()
@@ -670,6 +677,11 @@ impl Router {
let workers = Arc::clone(&self.workers);
let worker_url = worker_url.to_string();
+ // Preserve headers for streaming response
+ let mut response_headers = header_utils::preserve_response_headers(res.headers());
+ // Ensure we set the correct content-type for SSE
+ response_headers.insert(CONTENT_TYPE, HeaderValue::from_static("text/event-stream"));
+
let stream = res.bytes_stream();
let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
@@ -724,12 +736,15 @@ impl Router {
let mut response = Response::new(body);
*response.status_mut() = status;
- response
- .headers_mut()
- .insert(CONTENT_TYPE, HeaderValue::from_static("text/event-stream"));
+ *response.headers_mut() = response_headers;
response
} else {
// For requests without load tracking, just stream
+ // Preserve headers for streaming response
+ let mut response_headers = header_utils::preserve_response_headers(res.headers());
+ // Ensure we set the correct content-type for SSE
+ response_headers.insert(CONTENT_TYPE, HeaderValue::from_static("text/event-stream"));
+
let stream = res.bytes_stream();
let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
@@ -756,9 +771,7 @@ impl Router {
let mut response = Response::new(body);
*response.status_mut() = status;
- response
- .headers_mut()
- .insert(CONTENT_TYPE, HeaderValue::from_static("text/event-stream"));
+ *response.headers_mut() = response_headers;
response
}
}