[router] preserve original worker response header in router (#9236)

This commit is contained in:
Simo Lin
2025-08-15 11:01:47 -07:00
committed by GitHub
parent d7e38b2f6d
commit da53e13cbb
4 changed files with 135 additions and 37 deletions

View File

@@ -1,3 +1,4 @@
use super::header_utils;
use crate::config::types::{
CircuitBreakerConfig as ConfigCircuitBreakerConfig,
HealthCheckConfig as ConfigHealthCheckConfig, RetryConfig,
@@ -24,17 +25,6 @@ use std::sync::{Arc, RwLock};
use std::time::{Duration, Instant};
use tokio_stream::wrappers::UnboundedReceiverStream;
use tracing::{debug, error, info, warn};
pub fn copy_request_headers(req: &Request<Body>) -> Vec<(String, String)> {
req.headers()
.iter()
.filter_map(|(name, value)| {
value
.to_str()
.ok()
.map(|v| (name.to_string(), v.to_string()))
})
.collect()
}
/// Regular router that uses injected load balancing policies
#[derive(Debug)]
@@ -400,7 +390,7 @@ impl Router {
// Helper method to proxy GET requests to the first available worker
async fn proxy_get_request(&self, req: Request<Body>, endpoint: &str) -> Response {
let headers = copy_request_headers(&req);
let headers = super::header_utils::copy_request_headers(&req);
match self.select_first_worker() {
Ok(worker_url) => {
@@ -416,8 +406,18 @@ impl Router {
Ok(res) => {
let status = StatusCode::from_u16(res.status().as_u16())
.unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
// Preserve headers from backend
let response_headers =
header_utils::preserve_response_headers(res.headers());
match res.bytes().await {
Ok(body) => (status, body).into_response(),
Ok(body) => {
let mut response = Response::new(axum::body::Body::from(body));
*response.status_mut() = status;
*response.headers_mut() = response_headers;
response
}
Err(e) => (
StatusCode::INTERNAL_SERVER_ERROR,
format!("Failed to read response: {}", e),
@@ -645,9 +645,16 @@ impl Router {
.unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
if !is_stream {
// For non-streaming requests, get response first
// For non-streaming requests, preserve headers
let response_headers = super::header_utils::preserve_response_headers(res.headers());
let response = match res.bytes().await {
Ok(body) => (status, body).into_response(),
Ok(body) => {
let mut response = Response::new(axum::body::Body::from(body));
*response.status_mut() = status;
*response.headers_mut() = response_headers;
response
}
Err(e) => {
let error_msg = format!("Failed to get response body: {}", e);
(StatusCode::INTERNAL_SERVER_ERROR, error_msg).into_response()
@@ -670,6 +677,11 @@ impl Router {
let workers = Arc::clone(&self.workers);
let worker_url = worker_url.to_string();
// Preserve headers for streaming response
let mut response_headers = header_utils::preserve_response_headers(res.headers());
// Ensure we set the correct content-type for SSE
response_headers.insert(CONTENT_TYPE, HeaderValue::from_static("text/event-stream"));
let stream = res.bytes_stream();
let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
@@ -724,12 +736,15 @@ impl Router {
let mut response = Response::new(body);
*response.status_mut() = status;
response
.headers_mut()
.insert(CONTENT_TYPE, HeaderValue::from_static("text/event-stream"));
*response.headers_mut() = response_headers;
response
} else {
// For requests without load tracking, just stream
// Preserve headers for streaming response
let mut response_headers = header_utils::preserve_response_headers(res.headers());
// Ensure we set the correct content-type for SSE
response_headers.insert(CONTENT_TYPE, HeaderValue::from_static("text/event-stream"));
let stream = res.bytes_stream();
let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
@@ -756,9 +771,7 @@ impl Router {
let mut response = Response::new(body);
*response.status_mut() = status;
response
.headers_mut()
.insert(CONTENT_TYPE, HeaderValue::from_static("text/event-stream"));
*response.headers_mut() = response_headers;
response
}
}