[router] preserve original worker response headers in router (#9236)

This commit is contained in:
Simo Lin
2025-08-15 11:01:47 -07:00
committed by GitHub
parent d7e38b2f6d
commit da53e13cbb
4 changed files with 135 additions and 37 deletions

View File

@@ -1,5 +1,6 @@
// PD (Prefill-Decode) Router Implementation
// This module handles routing for disaggregated prefill-decode systems
use super::header_utils;
use super::pd_types::{api_path, PDRouterError};
use crate::config::types::{
CircuitBreakerConfig as ConfigCircuitBreakerConfig,
@@ -170,17 +171,26 @@ impl PDRouter {
}
match request_builder.send().await {
Ok(res) if res.status().is_success() => match res.bytes().await {
Ok(body) => (StatusCode::OK, body).into_response(),
Err(e) => {
error!("Failed to read response body: {}", e);
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("Failed to read response body: {}", e),
)
.into_response()
Ok(res) if res.status().is_success() => {
let response_headers = header_utils::preserve_response_headers(res.headers());
match res.bytes().await {
Ok(body) => {
let mut response = Response::new(axum::body::Body::from(body));
*response.status_mut() = StatusCode::OK;
*response.headers_mut() = response_headers;
response
}
Err(e) => {
error!("Failed to read response body: {}", e);
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("Failed to read response body: {}", e),
)
.into_response()
}
}
},
}
Ok(res) => {
let status = StatusCode::from_u16(res.status().as_u16())
.unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
@@ -822,12 +832,16 @@ impl PDRouter {
json.pointer("/meta_info/input_token_logprobs").cloned()
});
let response_headers =
header_utils::preserve_response_headers(res.headers());
Self::create_streaming_response(
res.bytes_stream(),
status,
prefill_logprobs,
return_logprob,
None,
Some(response_headers),
)
} else {
// Non-streaming response with logprobs
@@ -918,17 +932,30 @@ impl PDRouter {
} else if is_stream {
// Streaming response without logprobs - direct passthrough
let decode_url = decode.url().to_string();
let response_headers =
header_utils::preserve_response_headers(res.headers());
Self::create_streaming_response(
res.bytes_stream(),
status,
None,
false,
Some(decode_url),
Some(response_headers),
)
} else {
// Non-streaming response without logprobs - direct passthrough, like the fast version
let response_headers =
header_utils::preserve_response_headers(res.headers());
match res.bytes().await {
Ok(decode_body) => (status, decode_body).into_response(),
Ok(decode_body) => {
let mut response =
Response::new(axum::body::Body::from(decode_body));
*response.status_mut() = status;
*response.headers_mut() = response_headers;
response
}
Err(e) => {
error!("Failed to read decode response: {}", e);
(StatusCode::INTERNAL_SERVER_ERROR, "Failed to read response")
@@ -1081,6 +1108,7 @@ impl PDRouter {
prefill_logprobs: Option<Value>,
return_logprob: bool,
decode_url: Option<String>,
headers: Option<HeaderMap>,
) -> Response {
let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
@@ -1118,9 +1146,12 @@ impl PDRouter {
let mut response = Response::new(body);
*response.status_mut() = status;
response
.headers_mut()
.insert(CONTENT_TYPE, HeaderValue::from_static("text/event-stream"));
// Use provided headers or create new ones, then ensure content-type is set for streaming
let mut headers = headers.unwrap_or_else(HeaderMap::new);
headers.insert(CONTENT_TYPE, HeaderValue::from_static("text/event-stream"));
*response.headers_mut() = headers;
response
}
@@ -1556,7 +1587,7 @@ impl RouterTrait for PDRouter {
async fn get_models(&self, req: Request<Body>) -> Response {
// Extract headers first to avoid Send issues
let headers = crate::routers::router::copy_request_headers(&req);
let headers = header_utils::copy_request_headers(&req);
// Proxy to first prefill worker
self.proxy_to_first_worker(&self.prefill_workers, "v1/models", "prefill", Some(headers))
@@ -1565,7 +1596,7 @@ impl RouterTrait for PDRouter {
async fn get_model_info(&self, req: Request<Body>) -> Response {
// Extract headers first to avoid Send issues
let headers = crate::routers::router::copy_request_headers(&req);
let headers = header_utils::copy_request_headers(&req);
// Proxy to first prefill worker
self.proxy_to_first_worker(