[router] preserve original worker response header in router (#9236)
This commit is contained in:
53
sgl-router/src/routers/header_utils.rs
Normal file
53
sgl-router/src/routers/header_utils.rs
Normal file
@@ -0,0 +1,53 @@
|
|||||||
|
use axum::body::Body;
|
||||||
|
use axum::extract::Request;
|
||||||
|
use axum::http::{HeaderMap, HeaderName, HeaderValue};
|
||||||
|
|
||||||
|
/// Copy request headers to a Vec of name-value string pairs
|
||||||
|
/// Used for forwarding headers to backend workers
|
||||||
|
pub fn copy_request_headers(req: &Request<Body>) -> Vec<(String, String)> {
|
||||||
|
req.headers()
|
||||||
|
.iter()
|
||||||
|
.filter_map(|(name, value)| {
|
||||||
|
// Convert header value to string, skipping non-UTF8 headers
|
||||||
|
value
|
||||||
|
.to_str()
|
||||||
|
.ok()
|
||||||
|
.map(|v| (name.to_string(), v.to_string()))
|
||||||
|
})
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Convert headers from reqwest Response to axum HeaderMap
|
||||||
|
/// Filters out hop-by-hop headers that shouldn't be forwarded
|
||||||
|
pub fn preserve_response_headers(reqwest_headers: &HeaderMap) -> HeaderMap {
|
||||||
|
let mut headers = HeaderMap::new();
|
||||||
|
|
||||||
|
for (name, value) in reqwest_headers.iter() {
|
||||||
|
// Skip hop-by-hop headers that shouldn't be forwarded
|
||||||
|
let name_str = name.as_str().to_lowercase();
|
||||||
|
if should_forward_header(&name_str) {
|
||||||
|
// The original name and value are already valid, so we can just clone them
|
||||||
|
headers.insert(name.clone(), value.clone());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
headers
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Determine if a header should be forwarded from backend to client
|
||||||
|
fn should_forward_header(name: &str) -> bool {
|
||||||
|
// List of headers that should NOT be forwarded (hop-by-hop headers)
|
||||||
|
!matches!(
|
||||||
|
name,
|
||||||
|
"connection" |
|
||||||
|
"keep-alive" |
|
||||||
|
"proxy-authenticate" |
|
||||||
|
"proxy-authorization" |
|
||||||
|
"te" |
|
||||||
|
"trailers" |
|
||||||
|
"transfer-encoding" |
|
||||||
|
"upgrade" |
|
||||||
|
"content-encoding" | // Let axum/hyper handle encoding
|
||||||
|
"host" // Should not forward the backend's host header
|
||||||
|
)
|
||||||
|
}
|
||||||
@@ -12,6 +12,7 @@ use std::fmt::Debug;
|
|||||||
use crate::openai_api_types::{ChatCompletionRequest, CompletionRequest, GenerateRequest};
|
use crate::openai_api_types::{ChatCompletionRequest, CompletionRequest, GenerateRequest};
|
||||||
|
|
||||||
pub mod factory;
|
pub mod factory;
|
||||||
|
pub mod header_utils;
|
||||||
pub mod pd_router;
|
pub mod pd_router;
|
||||||
pub mod pd_types;
|
pub mod pd_types;
|
||||||
pub mod router;
|
pub mod router;
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
// PD (Prefill-Decode) Router Implementation
|
// PD (Prefill-Decode) Router Implementation
|
||||||
// This module handles routing for disaggregated prefill-decode systems
|
// This module handles routing for disaggregated prefill-decode systems
|
||||||
|
use super::header_utils;
|
||||||
use super::pd_types::{api_path, PDRouterError};
|
use super::pd_types::{api_path, PDRouterError};
|
||||||
use crate::config::types::{
|
use crate::config::types::{
|
||||||
CircuitBreakerConfig as ConfigCircuitBreakerConfig,
|
CircuitBreakerConfig as ConfigCircuitBreakerConfig,
|
||||||
@@ -170,17 +171,26 @@ impl PDRouter {
|
|||||||
}
|
}
|
||||||
|
|
||||||
match request_builder.send().await {
|
match request_builder.send().await {
|
||||||
Ok(res) if res.status().is_success() => match res.bytes().await {
|
Ok(res) if res.status().is_success() => {
|
||||||
Ok(body) => (StatusCode::OK, body).into_response(),
|
let response_headers = header_utils::preserve_response_headers(res.headers());
|
||||||
Err(e) => {
|
|
||||||
error!("Failed to read response body: {}", e);
|
match res.bytes().await {
|
||||||
(
|
Ok(body) => {
|
||||||
StatusCode::INTERNAL_SERVER_ERROR,
|
let mut response = Response::new(axum::body::Body::from(body));
|
||||||
format!("Failed to read response body: {}", e),
|
*response.status_mut() = StatusCode::OK;
|
||||||
)
|
*response.headers_mut() = response_headers;
|
||||||
.into_response()
|
response
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
error!("Failed to read response body: {}", e);
|
||||||
|
(
|
||||||
|
StatusCode::INTERNAL_SERVER_ERROR,
|
||||||
|
format!("Failed to read response body: {}", e),
|
||||||
|
)
|
||||||
|
.into_response()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
},
|
}
|
||||||
Ok(res) => {
|
Ok(res) => {
|
||||||
let status = StatusCode::from_u16(res.status().as_u16())
|
let status = StatusCode::from_u16(res.status().as_u16())
|
||||||
.unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
|
.unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
|
||||||
@@ -822,12 +832,16 @@ impl PDRouter {
|
|||||||
json.pointer("/meta_info/input_token_logprobs").cloned()
|
json.pointer("/meta_info/input_token_logprobs").cloned()
|
||||||
});
|
});
|
||||||
|
|
||||||
|
let response_headers =
|
||||||
|
header_utils::preserve_response_headers(res.headers());
|
||||||
|
|
||||||
Self::create_streaming_response(
|
Self::create_streaming_response(
|
||||||
res.bytes_stream(),
|
res.bytes_stream(),
|
||||||
status,
|
status,
|
||||||
prefill_logprobs,
|
prefill_logprobs,
|
||||||
return_logprob,
|
return_logprob,
|
||||||
None,
|
None,
|
||||||
|
Some(response_headers),
|
||||||
)
|
)
|
||||||
} else {
|
} else {
|
||||||
// Non-streaming response with logprobs
|
// Non-streaming response with logprobs
|
||||||
@@ -918,17 +932,30 @@ impl PDRouter {
|
|||||||
} else if is_stream {
|
} else if is_stream {
|
||||||
// Streaming response without logprobs - direct passthrough
|
// Streaming response without logprobs - direct passthrough
|
||||||
let decode_url = decode.url().to_string();
|
let decode_url = decode.url().to_string();
|
||||||
|
let response_headers =
|
||||||
|
header_utils::preserve_response_headers(res.headers());
|
||||||
|
|
||||||
Self::create_streaming_response(
|
Self::create_streaming_response(
|
||||||
res.bytes_stream(),
|
res.bytes_stream(),
|
||||||
status,
|
status,
|
||||||
None,
|
None,
|
||||||
false,
|
false,
|
||||||
Some(decode_url),
|
Some(decode_url),
|
||||||
|
Some(response_headers),
|
||||||
)
|
)
|
||||||
} else {
|
} else {
|
||||||
// Non-streaming response without logprobs - direct passthrough like fast version
|
// Non-streaming response without logprobs - direct passthrough like fast version
|
||||||
|
let response_headers =
|
||||||
|
header_utils::preserve_response_headers(res.headers());
|
||||||
|
|
||||||
match res.bytes().await {
|
match res.bytes().await {
|
||||||
Ok(decode_body) => (status, decode_body).into_response(),
|
Ok(decode_body) => {
|
||||||
|
let mut response =
|
||||||
|
Response::new(axum::body::Body::from(decode_body));
|
||||||
|
*response.status_mut() = status;
|
||||||
|
*response.headers_mut() = response_headers;
|
||||||
|
response
|
||||||
|
}
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
error!("Failed to read decode response: {}", e);
|
error!("Failed to read decode response: {}", e);
|
||||||
(StatusCode::INTERNAL_SERVER_ERROR, "Failed to read response")
|
(StatusCode::INTERNAL_SERVER_ERROR, "Failed to read response")
|
||||||
@@ -1081,6 +1108,7 @@ impl PDRouter {
|
|||||||
prefill_logprobs: Option<Value>,
|
prefill_logprobs: Option<Value>,
|
||||||
return_logprob: bool,
|
return_logprob: bool,
|
||||||
decode_url: Option<String>,
|
decode_url: Option<String>,
|
||||||
|
headers: Option<HeaderMap>,
|
||||||
) -> Response {
|
) -> Response {
|
||||||
let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
|
let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
|
||||||
|
|
||||||
@@ -1118,9 +1146,12 @@ impl PDRouter {
|
|||||||
|
|
||||||
let mut response = Response::new(body);
|
let mut response = Response::new(body);
|
||||||
*response.status_mut() = status;
|
*response.status_mut() = status;
|
||||||
response
|
|
||||||
.headers_mut()
|
// Use provided headers or create new ones, then ensure content-type is set for streaming
|
||||||
.insert(CONTENT_TYPE, HeaderValue::from_static("text/event-stream"));
|
let mut headers = headers.unwrap_or_else(HeaderMap::new);
|
||||||
|
headers.insert(CONTENT_TYPE, HeaderValue::from_static("text/event-stream"));
|
||||||
|
*response.headers_mut() = headers;
|
||||||
|
|
||||||
response
|
response
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1556,7 +1587,7 @@ impl RouterTrait for PDRouter {
|
|||||||
|
|
||||||
async fn get_models(&self, req: Request<Body>) -> Response {
|
async fn get_models(&self, req: Request<Body>) -> Response {
|
||||||
// Extract headers first to avoid Send issues
|
// Extract headers first to avoid Send issues
|
||||||
let headers = crate::routers::router::copy_request_headers(&req);
|
let headers = header_utils::copy_request_headers(&req);
|
||||||
|
|
||||||
// Proxy to first prefill worker
|
// Proxy to first prefill worker
|
||||||
self.proxy_to_first_worker(&self.prefill_workers, "v1/models", "prefill", Some(headers))
|
self.proxy_to_first_worker(&self.prefill_workers, "v1/models", "prefill", Some(headers))
|
||||||
@@ -1565,7 +1596,7 @@ impl RouterTrait for PDRouter {
|
|||||||
|
|
||||||
async fn get_model_info(&self, req: Request<Body>) -> Response {
|
async fn get_model_info(&self, req: Request<Body>) -> Response {
|
||||||
// Extract headers first to avoid Send issues
|
// Extract headers first to avoid Send issues
|
||||||
let headers = crate::routers::router::copy_request_headers(&req);
|
let headers = header_utils::copy_request_headers(&req);
|
||||||
|
|
||||||
// Proxy to first prefill worker
|
// Proxy to first prefill worker
|
||||||
self.proxy_to_first_worker(
|
self.proxy_to_first_worker(
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
use super::header_utils;
|
||||||
use crate::config::types::{
|
use crate::config::types::{
|
||||||
CircuitBreakerConfig as ConfigCircuitBreakerConfig,
|
CircuitBreakerConfig as ConfigCircuitBreakerConfig,
|
||||||
HealthCheckConfig as ConfigHealthCheckConfig, RetryConfig,
|
HealthCheckConfig as ConfigHealthCheckConfig, RetryConfig,
|
||||||
@@ -24,17 +25,6 @@ use std::sync::{Arc, RwLock};
|
|||||||
use std::time::{Duration, Instant};
|
use std::time::{Duration, Instant};
|
||||||
use tokio_stream::wrappers::UnboundedReceiverStream;
|
use tokio_stream::wrappers::UnboundedReceiverStream;
|
||||||
use tracing::{debug, error, info, warn};
|
use tracing::{debug, error, info, warn};
|
||||||
pub fn copy_request_headers(req: &Request<Body>) -> Vec<(String, String)> {
|
|
||||||
req.headers()
|
|
||||||
.iter()
|
|
||||||
.filter_map(|(name, value)| {
|
|
||||||
value
|
|
||||||
.to_str()
|
|
||||||
.ok()
|
|
||||||
.map(|v| (name.to_string(), v.to_string()))
|
|
||||||
})
|
|
||||||
.collect()
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Regular router that uses injected load balancing policies
|
/// Regular router that uses injected load balancing policies
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
@@ -400,7 +390,7 @@ impl Router {
|
|||||||
|
|
||||||
// Helper method to proxy GET requests to the first available worker
|
// Helper method to proxy GET requests to the first available worker
|
||||||
async fn proxy_get_request(&self, req: Request<Body>, endpoint: &str) -> Response {
|
async fn proxy_get_request(&self, req: Request<Body>, endpoint: &str) -> Response {
|
||||||
let headers = copy_request_headers(&req);
|
let headers = super::header_utils::copy_request_headers(&req);
|
||||||
|
|
||||||
match self.select_first_worker() {
|
match self.select_first_worker() {
|
||||||
Ok(worker_url) => {
|
Ok(worker_url) => {
|
||||||
@@ -416,8 +406,18 @@ impl Router {
|
|||||||
Ok(res) => {
|
Ok(res) => {
|
||||||
let status = StatusCode::from_u16(res.status().as_u16())
|
let status = StatusCode::from_u16(res.status().as_u16())
|
||||||
.unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
|
.unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
|
||||||
|
|
||||||
|
// Preserve headers from backend
|
||||||
|
let response_headers =
|
||||||
|
header_utils::preserve_response_headers(res.headers());
|
||||||
|
|
||||||
match res.bytes().await {
|
match res.bytes().await {
|
||||||
Ok(body) => (status, body).into_response(),
|
Ok(body) => {
|
||||||
|
let mut response = Response::new(axum::body::Body::from(body));
|
||||||
|
*response.status_mut() = status;
|
||||||
|
*response.headers_mut() = response_headers;
|
||||||
|
response
|
||||||
|
}
|
||||||
Err(e) => (
|
Err(e) => (
|
||||||
StatusCode::INTERNAL_SERVER_ERROR,
|
StatusCode::INTERNAL_SERVER_ERROR,
|
||||||
format!("Failed to read response: {}", e),
|
format!("Failed to read response: {}", e),
|
||||||
@@ -645,9 +645,16 @@ impl Router {
|
|||||||
.unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
|
.unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
|
||||||
|
|
||||||
if !is_stream {
|
if !is_stream {
|
||||||
// For non-streaming requests, get response first
|
// For non-streaming requests, preserve headers
|
||||||
|
let response_headers = super::header_utils::preserve_response_headers(res.headers());
|
||||||
|
|
||||||
let response = match res.bytes().await {
|
let response = match res.bytes().await {
|
||||||
Ok(body) => (status, body).into_response(),
|
Ok(body) => {
|
||||||
|
let mut response = Response::new(axum::body::Body::from(body));
|
||||||
|
*response.status_mut() = status;
|
||||||
|
*response.headers_mut() = response_headers;
|
||||||
|
response
|
||||||
|
}
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
let error_msg = format!("Failed to get response body: {}", e);
|
let error_msg = format!("Failed to get response body: {}", e);
|
||||||
(StatusCode::INTERNAL_SERVER_ERROR, error_msg).into_response()
|
(StatusCode::INTERNAL_SERVER_ERROR, error_msg).into_response()
|
||||||
@@ -670,6 +677,11 @@ impl Router {
|
|||||||
let workers = Arc::clone(&self.workers);
|
let workers = Arc::clone(&self.workers);
|
||||||
let worker_url = worker_url.to_string();
|
let worker_url = worker_url.to_string();
|
||||||
|
|
||||||
|
// Preserve headers for streaming response
|
||||||
|
let mut response_headers = header_utils::preserve_response_headers(res.headers());
|
||||||
|
// Ensure we set the correct content-type for SSE
|
||||||
|
response_headers.insert(CONTENT_TYPE, HeaderValue::from_static("text/event-stream"));
|
||||||
|
|
||||||
let stream = res.bytes_stream();
|
let stream = res.bytes_stream();
|
||||||
let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
|
let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
|
||||||
|
|
||||||
@@ -724,12 +736,15 @@ impl Router {
|
|||||||
|
|
||||||
let mut response = Response::new(body);
|
let mut response = Response::new(body);
|
||||||
*response.status_mut() = status;
|
*response.status_mut() = status;
|
||||||
response
|
*response.headers_mut() = response_headers;
|
||||||
.headers_mut()
|
|
||||||
.insert(CONTENT_TYPE, HeaderValue::from_static("text/event-stream"));
|
|
||||||
response
|
response
|
||||||
} else {
|
} else {
|
||||||
// For requests without load tracking, just stream
|
// For requests without load tracking, just stream
|
||||||
|
// Preserve headers for streaming response
|
||||||
|
let mut response_headers = header_utils::preserve_response_headers(res.headers());
|
||||||
|
// Ensure we set the correct content-type for SSE
|
||||||
|
response_headers.insert(CONTENT_TYPE, HeaderValue::from_static("text/event-stream"));
|
||||||
|
|
||||||
let stream = res.bytes_stream();
|
let stream = res.bytes_stream();
|
||||||
let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
|
let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
|
||||||
|
|
||||||
@@ -756,9 +771,7 @@ impl Router {
|
|||||||
|
|
||||||
let mut response = Response::new(body);
|
let mut response = Response::new(body);
|
||||||
*response.status_mut() = status;
|
*response.status_mut() = status;
|
||||||
response
|
*response.headers_mut() = response_headers;
|
||||||
.headers_mut()
|
|
||||||
.insert(CONTENT_TYPE, HeaderValue::from_static("text/event-stream"));
|
|
||||||
response
|
response
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user