[router] preserve original worker response header in router (#9236)
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
// PD (Prefill-Decode) Router Implementation
|
||||
// This module handles routing for disaggregated prefill-decode systems
|
||||
use super::header_utils;
|
||||
use super::pd_types::{api_path, PDRouterError};
|
||||
use crate::config::types::{
|
||||
CircuitBreakerConfig as ConfigCircuitBreakerConfig,
|
||||
@@ -170,17 +171,26 @@ impl PDRouter {
|
||||
}
|
||||
|
||||
match request_builder.send().await {
|
||||
Ok(res) if res.status().is_success() => match res.bytes().await {
|
||||
Ok(body) => (StatusCode::OK, body).into_response(),
|
||||
Err(e) => {
|
||||
error!("Failed to read response body: {}", e);
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("Failed to read response body: {}", e),
|
||||
)
|
||||
.into_response()
|
||||
Ok(res) if res.status().is_success() => {
|
||||
let response_headers = header_utils::preserve_response_headers(res.headers());
|
||||
|
||||
match res.bytes().await {
|
||||
Ok(body) => {
|
||||
let mut response = Response::new(axum::body::Body::from(body));
|
||||
*response.status_mut() = StatusCode::OK;
|
||||
*response.headers_mut() = response_headers;
|
||||
response
|
||||
}
|
||||
Err(e) => {
|
||||
error!("Failed to read response body: {}", e);
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("Failed to read response body: {}", e),
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
Ok(res) => {
|
||||
let status = StatusCode::from_u16(res.status().as_u16())
|
||||
.unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
|
||||
@@ -822,12 +832,16 @@ impl PDRouter {
|
||||
json.pointer("/meta_info/input_token_logprobs").cloned()
|
||||
});
|
||||
|
||||
let response_headers =
|
||||
header_utils::preserve_response_headers(res.headers());
|
||||
|
||||
Self::create_streaming_response(
|
||||
res.bytes_stream(),
|
||||
status,
|
||||
prefill_logprobs,
|
||||
return_logprob,
|
||||
None,
|
||||
Some(response_headers),
|
||||
)
|
||||
} else {
|
||||
// Non-streaming response with logprobs
|
||||
@@ -918,17 +932,30 @@ impl PDRouter {
|
||||
} else if is_stream {
|
||||
// Streaming response without logprobs - direct passthrough
|
||||
let decode_url = decode.url().to_string();
|
||||
let response_headers =
|
||||
header_utils::preserve_response_headers(res.headers());
|
||||
|
||||
Self::create_streaming_response(
|
||||
res.bytes_stream(),
|
||||
status,
|
||||
None,
|
||||
false,
|
||||
Some(decode_url),
|
||||
Some(response_headers),
|
||||
)
|
||||
} else {
|
||||
// Non-streaming response without logprobs - direct passthrough like fast version
|
||||
let response_headers =
|
||||
header_utils::preserve_response_headers(res.headers());
|
||||
|
||||
match res.bytes().await {
|
||||
Ok(decode_body) => (status, decode_body).into_response(),
|
||||
Ok(decode_body) => {
|
||||
let mut response =
|
||||
Response::new(axum::body::Body::from(decode_body));
|
||||
*response.status_mut() = status;
|
||||
*response.headers_mut() = response_headers;
|
||||
response
|
||||
}
|
||||
Err(e) => {
|
||||
error!("Failed to read decode response: {}", e);
|
||||
(StatusCode::INTERNAL_SERVER_ERROR, "Failed to read response")
|
||||
@@ -1081,6 +1108,7 @@ impl PDRouter {
|
||||
prefill_logprobs: Option<Value>,
|
||||
return_logprob: bool,
|
||||
decode_url: Option<String>,
|
||||
headers: Option<HeaderMap>,
|
||||
) -> Response {
|
||||
let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
|
||||
|
||||
@@ -1118,9 +1146,12 @@ impl PDRouter {
|
||||
|
||||
let mut response = Response::new(body);
|
||||
*response.status_mut() = status;
|
||||
response
|
||||
.headers_mut()
|
||||
.insert(CONTENT_TYPE, HeaderValue::from_static("text/event-stream"));
|
||||
|
||||
// Use provided headers or create new ones, then ensure content-type is set for streaming
|
||||
let mut headers = headers.unwrap_or_else(HeaderMap::new);
|
||||
headers.insert(CONTENT_TYPE, HeaderValue::from_static("text/event-stream"));
|
||||
*response.headers_mut() = headers;
|
||||
|
||||
response
|
||||
}
|
||||
|
||||
@@ -1556,7 +1587,7 @@ impl RouterTrait for PDRouter {
|
||||
|
||||
async fn get_models(&self, req: Request<Body>) -> Response {
|
||||
// Extract headers first to avoid Send issues
|
||||
let headers = crate::routers::router::copy_request_headers(&req);
|
||||
let headers = header_utils::copy_request_headers(&req);
|
||||
|
||||
// Proxy to first prefill worker
|
||||
self.proxy_to_first_worker(&self.prefill_workers, "v1/models", "prefill", Some(headers))
|
||||
@@ -1565,7 +1596,7 @@ impl RouterTrait for PDRouter {
|
||||
|
||||
async fn get_model_info(&self, req: Request<Body>) -> Response {
|
||||
// Extract headers first to avoid Send issues
|
||||
let headers = crate::routers::router::copy_request_headers(&req);
|
||||
let headers = header_utils::copy_request_headers(&req);
|
||||
|
||||
// Proxy to first prefill worker
|
||||
self.proxy_to_first_worker(
|
||||
|
||||
Reference in New Issue
Block a user