[router] preserve original worker response header in router (#9236)
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
use super::header_utils;
|
||||
use crate::config::types::{
|
||||
CircuitBreakerConfig as ConfigCircuitBreakerConfig,
|
||||
HealthCheckConfig as ConfigHealthCheckConfig, RetryConfig,
|
||||
@@ -24,17 +25,6 @@ use std::sync::{Arc, RwLock};
|
||||
use std::time::{Duration, Instant};
|
||||
use tokio_stream::wrappers::UnboundedReceiverStream;
|
||||
use tracing::{debug, error, info, warn};
|
||||
pub fn copy_request_headers(req: &Request<Body>) -> Vec<(String, String)> {
|
||||
req.headers()
|
||||
.iter()
|
||||
.filter_map(|(name, value)| {
|
||||
value
|
||||
.to_str()
|
||||
.ok()
|
||||
.map(|v| (name.to_string(), v.to_string()))
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Regular router that uses injected load balancing policies
|
||||
#[derive(Debug)]
|
||||
@@ -400,7 +390,7 @@ impl Router {
|
||||
|
||||
// Helper method to proxy GET requests to the first available worker
|
||||
async fn proxy_get_request(&self, req: Request<Body>, endpoint: &str) -> Response {
|
||||
let headers = copy_request_headers(&req);
|
||||
let headers = super::header_utils::copy_request_headers(&req);
|
||||
|
||||
match self.select_first_worker() {
|
||||
Ok(worker_url) => {
|
||||
@@ -416,8 +406,18 @@ impl Router {
|
||||
Ok(res) => {
|
||||
let status = StatusCode::from_u16(res.status().as_u16())
|
||||
.unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
|
||||
|
||||
// Preserve headers from backend
|
||||
let response_headers =
|
||||
header_utils::preserve_response_headers(res.headers());
|
||||
|
||||
match res.bytes().await {
|
||||
Ok(body) => (status, body).into_response(),
|
||||
Ok(body) => {
|
||||
let mut response = Response::new(axum::body::Body::from(body));
|
||||
*response.status_mut() = status;
|
||||
*response.headers_mut() = response_headers;
|
||||
response
|
||||
}
|
||||
Err(e) => (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("Failed to read response: {}", e),
|
||||
@@ -645,9 +645,16 @@ impl Router {
|
||||
.unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
|
||||
|
||||
if !is_stream {
|
||||
// For non-streaming requests, get response first
|
||||
// For non-streaming requests, preserve headers
|
||||
let response_headers = super::header_utils::preserve_response_headers(res.headers());
|
||||
|
||||
let response = match res.bytes().await {
|
||||
Ok(body) => (status, body).into_response(),
|
||||
Ok(body) => {
|
||||
let mut response = Response::new(axum::body::Body::from(body));
|
||||
*response.status_mut() = status;
|
||||
*response.headers_mut() = response_headers;
|
||||
response
|
||||
}
|
||||
Err(e) => {
|
||||
let error_msg = format!("Failed to get response body: {}", e);
|
||||
(StatusCode::INTERNAL_SERVER_ERROR, error_msg).into_response()
|
||||
@@ -670,6 +677,11 @@ impl Router {
|
||||
let workers = Arc::clone(&self.workers);
|
||||
let worker_url = worker_url.to_string();
|
||||
|
||||
// Preserve headers for streaming response
|
||||
let mut response_headers = header_utils::preserve_response_headers(res.headers());
|
||||
// Ensure we set the correct content-type for SSE
|
||||
response_headers.insert(CONTENT_TYPE, HeaderValue::from_static("text/event-stream"));
|
||||
|
||||
let stream = res.bytes_stream();
|
||||
let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
|
||||
|
||||
@@ -724,12 +736,15 @@ impl Router {
|
||||
|
||||
let mut response = Response::new(body);
|
||||
*response.status_mut() = status;
|
||||
response
|
||||
.headers_mut()
|
||||
.insert(CONTENT_TYPE, HeaderValue::from_static("text/event-stream"));
|
||||
*response.headers_mut() = response_headers;
|
||||
response
|
||||
} else {
|
||||
// For requests without load tracking, just stream
|
||||
// Preserve headers for streaming response
|
||||
let mut response_headers = header_utils::preserve_response_headers(res.headers());
|
||||
// Ensure we set the correct content-type for SSE
|
||||
response_headers.insert(CONTENT_TYPE, HeaderValue::from_static("text/event-stream"));
|
||||
|
||||
let stream = res.bytes_stream();
|
||||
let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
|
||||
|
||||
@@ -756,9 +771,7 @@ impl Router {
|
||||
|
||||
let mut response = Response::new(body);
|
||||
*response.status_mut() = status;
|
||||
response
|
||||
.headers_mut()
|
||||
.insert(CONTENT_TYPE, HeaderValue::from_static("text/event-stream"));
|
||||
*response.headers_mut() = response_headers;
|
||||
response
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user