[router] dedicated prefill HTTP client and request-path optimizations (#8923)
This commit is contained in:
@@ -38,6 +38,8 @@ pub struct PDRouter {
|
|||||||
pub worker_loads: Arc<tokio::sync::watch::Receiver<HashMap<String, isize>>>,
|
pub worker_loads: Arc<tokio::sync::watch::Receiver<HashMap<String, isize>>>,
|
||||||
pub load_monitor_handle: Option<Arc<tokio::task::JoinHandle<()>>>,
|
pub load_monitor_handle: Option<Arc<tokio::task::JoinHandle<()>>>,
|
||||||
pub client: Client,
|
pub client: Client,
|
||||||
|
// Dedicated client for prefill fire-and-forget (non-logprob) requests
|
||||||
|
pub prefill_client: Client,
|
||||||
pub retry_config: RetryConfig,
|
pub retry_config: RetryConfig,
|
||||||
_prefill_health_checker: Option<HealthChecker>,
|
_prefill_health_checker: Option<HealthChecker>,
|
||||||
_decode_health_checker: Option<HealthChecker>,
|
_decode_health_checker: Option<HealthChecker>,
|
||||||
@@ -255,6 +257,15 @@ impl PDRouter {
|
|||||||
let decode_health_checker =
|
let decode_health_checker =
|
||||||
crate::core::start_health_checker(Arc::clone(&decode_workers), interval_secs);
|
crate::core::start_health_checker(Arc::clone(&decode_workers), interval_secs);
|
||||||
|
|
||||||
|
// Build a dedicated prefill client for fire-and-forget semantics
|
||||||
|
let prefill_client = reqwest::Client::builder()
|
||||||
|
.pool_max_idle_per_host(0)
|
||||||
|
.http1_only()
|
||||||
|
.connect_timeout(Duration::from_millis(300))
|
||||||
|
.timeout(Duration::from_secs(2))
|
||||||
|
.build()
|
||||||
|
.map_err(|e| format!("Failed to build prefill client: {}", e))?;
|
||||||
|
|
||||||
Ok(PDRouter {
|
Ok(PDRouter {
|
||||||
prefill_workers,
|
prefill_workers,
|
||||||
decode_workers,
|
decode_workers,
|
||||||
@@ -267,6 +278,7 @@ impl PDRouter {
|
|||||||
worker_loads,
|
worker_loads,
|
||||||
load_monitor_handle,
|
load_monitor_handle,
|
||||||
client,
|
client,
|
||||||
|
prefill_client,
|
||||||
retry_config,
|
retry_config,
|
||||||
_prefill_health_checker: Some(prefill_health_checker),
|
_prefill_health_checker: Some(prefill_health_checker),
|
||||||
_decode_health_checker: Some(decode_health_checker),
|
_decode_health_checker: Some(decode_health_checker),
|
||||||
@@ -365,41 +377,69 @@ impl PDRouter {
|
|||||||
None
|
None
|
||||||
}
|
}
|
||||||
|
|
||||||
// Helper to create request with bootstrap fields
|
// Helper to inject bootstrap fields into an existing JSON request value
|
||||||
fn create_request_with_bootstrap<T: serde::Serialize>(
|
fn inject_bootstrap_into_value(
|
||||||
request: &T,
|
mut original: Value,
|
||||||
prefill_worker: &dyn Worker,
|
prefill_worker: &dyn Worker,
|
||||||
batch_size: Option<usize>,
|
batch_size: Option<usize>,
|
||||||
) -> Result<serde_json::Value, serde_json::Error> {
|
) -> Result<Value, String> {
|
||||||
// Get bootstrap port from prefill worker
|
|
||||||
let bootstrap_port = match prefill_worker.worker_type() {
|
let bootstrap_port = match prefill_worker.worker_type() {
|
||||||
crate::core::WorkerType::Prefill { bootstrap_port } => bootstrap_port,
|
crate::core::WorkerType::Prefill { bootstrap_port } => bootstrap_port,
|
||||||
_ => None,
|
_ => None,
|
||||||
};
|
};
|
||||||
let hostname = super::pd_types::get_hostname(prefill_worker.url());
|
let hostname = super::pd_types::get_hostname(prefill_worker.url());
|
||||||
|
|
||||||
// Create optimized request with bootstrap fields
|
let obj = original
|
||||||
if let Some(batch_size) = batch_size {
|
.as_object_mut()
|
||||||
// Batch request
|
.ok_or_else(|| "Request must be a JSON object".to_string())?;
|
||||||
let request_with_bootstrap = super::pd_types::BatchRequestWithBootstrap {
|
|
||||||
original: request,
|
if let Some(n) = batch_size {
|
||||||
bootstrap_host: vec![hostname; batch_size],
|
let mut hosts = Vec::with_capacity(n);
|
||||||
bootstrap_port: vec![bootstrap_port; batch_size],
|
let mut ports = Vec::with_capacity(n);
|
||||||
bootstrap_room: (0..batch_size)
|
let mut rooms = Vec::with_capacity(n);
|
||||||
.map(|_| super::pd_types::generate_room_id())
|
for _ in 0..n {
|
||||||
.collect(),
|
hosts.push(hostname.clone());
|
||||||
};
|
ports.push(bootstrap_port);
|
||||||
serde_json::to_value(&request_with_bootstrap)
|
rooms.push(super::pd_types::generate_room_id());
|
||||||
|
}
|
||||||
|
obj.insert(
|
||||||
|
"bootstrap_host".to_string(),
|
||||||
|
Value::Array(hosts.into_iter().map(serde_json::Value::from).collect()),
|
||||||
|
);
|
||||||
|
obj.insert(
|
||||||
|
"bootstrap_port".to_string(),
|
||||||
|
Value::Array(
|
||||||
|
ports
|
||||||
|
.into_iter()
|
||||||
|
.map(|p| match p {
|
||||||
|
Some(v) => serde_json::Value::from(v),
|
||||||
|
None => Value::Null,
|
||||||
|
})
|
||||||
|
.collect(),
|
||||||
|
),
|
||||||
|
);
|
||||||
|
obj.insert(
|
||||||
|
"bootstrap_room".to_string(),
|
||||||
|
Value::Array(rooms.into_iter().map(serde_json::Value::from).collect()),
|
||||||
|
);
|
||||||
} else {
|
} else {
|
||||||
// Single request
|
obj.insert(
|
||||||
let request_with_bootstrap = super::pd_types::RequestWithBootstrap {
|
"bootstrap_host".to_string(),
|
||||||
original: request,
|
serde_json::Value::from(hostname),
|
||||||
bootstrap_host: hostname,
|
);
|
||||||
bootstrap_port,
|
obj.insert(
|
||||||
bootstrap_room: super::pd_types::generate_room_id(),
|
"bootstrap_port".to_string(),
|
||||||
};
|
match bootstrap_port {
|
||||||
serde_json::to_value(&request_with_bootstrap)
|
Some(v) => serde_json::Value::from(v),
|
||||||
|
None => Value::Null,
|
||||||
|
},
|
||||||
|
);
|
||||||
|
obj.insert(
|
||||||
|
"bootstrap_room".to_string(),
|
||||||
|
serde_json::Value::from(super::pd_types::generate_room_id()),
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
Ok(original)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Execute the dual dispatch to prefill and decode servers
|
// Execute the dual dispatch to prefill and decode servers
|
||||||
@@ -417,12 +457,15 @@ impl PDRouter {
|
|||||||
// Update load tracking for both workers
|
// Update load tracking for both workers
|
||||||
let _guard = WorkerLoadGuard::new_multi(vec![prefill, decode]);
|
let _guard = WorkerLoadGuard::new_multi(vec![prefill, decode]);
|
||||||
|
|
||||||
// Build requests with headers
|
// Build decode request with shared client
|
||||||
let prefill_request =
|
let decode_request = self.build_post_with_headers(
|
||||||
self.build_request_with_headers(prefill.url(), route, &json_request, headers);
|
&self.client,
|
||||||
|
decode.url(),
|
||||||
let decode_request =
|
route,
|
||||||
self.build_request_with_headers(decode.url(), route, &json_request, headers);
|
&json_request,
|
||||||
|
headers,
|
||||||
|
false,
|
||||||
|
);
|
||||||
|
|
||||||
// Send both requests concurrently
|
// Send both requests concurrently
|
||||||
debug!(
|
debug!(
|
||||||
@@ -432,6 +475,15 @@ impl PDRouter {
|
|||||||
);
|
);
|
||||||
|
|
||||||
if return_logprob {
|
if return_logprob {
|
||||||
|
// Build prefill request with shared client when we need response body
|
||||||
|
let prefill_request = self.build_post_with_headers(
|
||||||
|
&self.client,
|
||||||
|
prefill.url(),
|
||||||
|
route,
|
||||||
|
&json_request,
|
||||||
|
headers,
|
||||||
|
false,
|
||||||
|
);
|
||||||
// When we need logprobs, wait for both responses
|
// When we need logprobs, wait for both responses
|
||||||
let (prefill_result, decode_result) =
|
let (prefill_result, decode_result) =
|
||||||
tokio::join!(prefill_request.send(), decode_request.send());
|
tokio::join!(prefill_request.send(), decode_request.send());
|
||||||
@@ -525,19 +577,27 @@ impl PDRouter {
|
|||||||
} else {
|
} else {
|
||||||
// When we don't need logprobs, only wait for decode response
|
// When we don't need logprobs, only wait for decode response
|
||||||
// Send both requests concurrently but don't wait for prefill
|
// Send both requests concurrently but don't wait for prefill
|
||||||
// Add headers to minimize response size when we don't need the body
|
// Use dedicated prefill client with Connection: close
|
||||||
let prefill_future = prefill_request.header("Connection", "close").send();
|
let prefill_future = self
|
||||||
|
.build_post_with_headers(
|
||||||
|
&self.prefill_client,
|
||||||
|
prefill.url(),
|
||||||
|
route,
|
||||||
|
&json_request,
|
||||||
|
headers,
|
||||||
|
true,
|
||||||
|
)
|
||||||
|
.send();
|
||||||
let decode_future = decode_request.send();
|
let decode_future = decode_request.send();
|
||||||
|
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
if let Ok(response) = prefill_future.await {
|
if let Ok(response) = prefill_future.await {
|
||||||
// Consume with a short timeout to free connection quickly
|
// Consume at most one small chunk with a very short timeout to advance flow control
|
||||||
let consume_future = async {
|
let _ = tokio::time::timeout(Duration::from_millis(20), async {
|
||||||
let _ = response.bytes().await;
|
let mut s = response.bytes_stream();
|
||||||
};
|
let _ = s.next().await;
|
||||||
|
})
|
||||||
// Give it 100ms to consume, then abandon
|
.await;
|
||||||
let _ = tokio::time::timeout(Duration::from_millis(100), consume_future).await;
|
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
@@ -879,29 +939,34 @@ impl PDRouter {
|
|||||||
Ok((prefill_status, prefill_body))
|
Ok((prefill_status, prefill_body))
|
||||||
}
|
}
|
||||||
|
|
||||||
// Helper to build a request with headers copied from the original request
|
fn build_post_with_headers(
|
||||||
fn build_request_with_headers(
|
|
||||||
&self,
|
&self,
|
||||||
|
client: &reqwest::Client,
|
||||||
url: &str,
|
url: &str,
|
||||||
route: &str,
|
route: &str,
|
||||||
json_request: &Value,
|
json_request: &serde_json::Value,
|
||||||
headers: Option<&HeaderMap>,
|
headers: Option<&HeaderMap>,
|
||||||
|
connection_close: bool,
|
||||||
) -> reqwest::RequestBuilder {
|
) -> reqwest::RequestBuilder {
|
||||||
let mut request = self.client.post(api_path(url, route)).json(json_request);
|
let mut request = client.post(api_path(url, route)).json(json_request);
|
||||||
|
if connection_close {
|
||||||
// Copy headers from original request (excluding content-type and content-length which are set by .json())
|
request = request.header("Connection", "close");
|
||||||
|
}
|
||||||
if let Some(headers) = headers {
|
if let Some(headers) = headers {
|
||||||
for (name, value) in headers.iter() {
|
for (name, value) in headers.iter() {
|
||||||
let name_str = name.as_str();
|
let name_lc = name.as_str().to_ascii_lowercase();
|
||||||
if name_str != "content-type" && name_str != "content-length" {
|
// Whitelist important end-to-end headers, skip hop-by-hop
|
||||||
// Skip headers with non-ASCII values
|
let forward = matches!(
|
||||||
if value.to_str().is_ok() {
|
name_lc.as_str(),
|
||||||
request = request.header(name, value);
|
"authorization" | "x-request-id" | "x-correlation-id"
|
||||||
|
) || name_lc.starts_with("x-request-id-");
|
||||||
|
if forward {
|
||||||
|
if let Ok(val) = value.to_str() {
|
||||||
|
request = request.header(name, val);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
request
|
request
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1109,11 +1174,12 @@ impl RouterTrait for PDRouter {
|
|||||||
|
|
||||||
// Test prefill server's health_generate
|
// Test prefill server's health_generate
|
||||||
let prefill_url = format!("{}/health_generate", prefill.url());
|
let prefill_url = format!("{}/health_generate", prefill.url());
|
||||||
let prefill_result = self.client.get(&prefill_url).send().await;
|
let (prefill_result, decode_result) = tokio::join!(
|
||||||
|
self.client.get(&prefill_url).send(),
|
||||||
// Test decode server's health_generate
|
self.client
|
||||||
let decode_url = format!("{}/health_generate", decode.url());
|
.get(&format!("{}/health_generate", decode.url()))
|
||||||
let decode_result = self.client.get(&decode_url).send().await;
|
.send()
|
||||||
|
);
|
||||||
|
|
||||||
// Check results
|
// Check results
|
||||||
let mut errors = Vec::new();
|
let mut errors = Vec::new();
|
||||||
@@ -1399,10 +1465,13 @@ impl RouterTrait for PDRouter {
|
|||||||
decode.url()
|
decode.url()
|
||||||
);
|
);
|
||||||
|
|
||||||
// Create optimized request with bootstrap fields
|
|
||||||
let batch_size = Self::get_generate_batch_size(body);
|
let batch_size = Self::get_generate_batch_size(body);
|
||||||
let json = match Self::create_request_with_bootstrap(body, prefill.as_ref(), batch_size) {
|
let original = match serde_json::to_value(body) {
|
||||||
Ok(json) => json,
|
Ok(v) => v,
|
||||||
|
Err(e) => return Self::handle_serialization_error(e),
|
||||||
|
};
|
||||||
|
let json = match Self::inject_bootstrap_into_value(original, prefill.as_ref(), batch_size) {
|
||||||
|
Ok(v) => v,
|
||||||
Err(e) => return Self::handle_serialization_error(e),
|
Err(e) => return Self::handle_serialization_error(e),
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -1464,10 +1533,13 @@ impl RouterTrait for PDRouter {
|
|||||||
decode.url()
|
decode.url()
|
||||||
);
|
);
|
||||||
|
|
||||||
// Create optimized request with bootstrap fields
|
|
||||||
let batch_size = Self::get_chat_batch_size(body);
|
let batch_size = Self::get_chat_batch_size(body);
|
||||||
let json = match Self::create_request_with_bootstrap(body, prefill.as_ref(), batch_size) {
|
let original = match serde_json::to_value(body) {
|
||||||
Ok(json) => json,
|
Ok(v) => v,
|
||||||
|
Err(e) => return Self::handle_serialization_error(e),
|
||||||
|
};
|
||||||
|
let json = match Self::inject_bootstrap_into_value(original, prefill.as_ref(), batch_size) {
|
||||||
|
Ok(v) => v,
|
||||||
Err(e) => return Self::handle_serialization_error(e),
|
Err(e) => return Self::handle_serialization_error(e),
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -1519,10 +1591,13 @@ impl RouterTrait for PDRouter {
|
|||||||
decode.url()
|
decode.url()
|
||||||
);
|
);
|
||||||
|
|
||||||
// Create optimized request with bootstrap fields
|
|
||||||
let batch_size = Self::get_completion_batch_size(body);
|
let batch_size = Self::get_completion_batch_size(body);
|
||||||
let json = match Self::create_request_with_bootstrap(body, prefill.as_ref(), batch_size) {
|
let original = match serde_json::to_value(body) {
|
||||||
Ok(json) => json,
|
Ok(v) => v,
|
||||||
|
Err(e) => return Self::handle_serialization_error(e),
|
||||||
|
};
|
||||||
|
let json = match Self::inject_bootstrap_into_value(original, prefill.as_ref(), batch_size) {
|
||||||
|
Ok(v) => v,
|
||||||
Err(e) => return Self::handle_serialization_error(e),
|
Err(e) => return Self::handle_serialization_error(e),
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -1771,6 +1846,7 @@ mod tests {
|
|||||||
worker_loads: Arc::new(tokio::sync::watch::channel(HashMap::new()).1),
|
worker_loads: Arc::new(tokio::sync::watch::channel(HashMap::new()).1),
|
||||||
load_monitor_handle: None,
|
load_monitor_handle: None,
|
||||||
client: Client::new(),
|
client: Client::new(),
|
||||||
|
prefill_client: Client::new(),
|
||||||
retry_config: RetryConfig::default(),
|
retry_config: RetryConfig::default(),
|
||||||
_prefill_health_checker: None,
|
_prefill_health_checker: None,
|
||||||
_decode_health_checker: None,
|
_decode_health_checker: None,
|
||||||
|
|||||||
Reference in New Issue
Block a user