[router] dedicated prefill HTTP client and request-path optimizations (#8923)
This commit is contained in:
@@ -38,6 +38,8 @@ pub struct PDRouter {
|
||||
pub worker_loads: Arc<tokio::sync::watch::Receiver<HashMap<String, isize>>>,
|
||||
pub load_monitor_handle: Option<Arc<tokio::task::JoinHandle<()>>>,
|
||||
pub client: Client,
|
||||
// Dedicated client for prefill fire-and-forget (non-logprob) requests
|
||||
pub prefill_client: Client,
|
||||
pub retry_config: RetryConfig,
|
||||
_prefill_health_checker: Option<HealthChecker>,
|
||||
_decode_health_checker: Option<HealthChecker>,
|
||||
@@ -255,6 +257,15 @@ impl PDRouter {
|
||||
let decode_health_checker =
|
||||
crate::core::start_health_checker(Arc::clone(&decode_workers), interval_secs);
|
||||
|
||||
// Build a dedicated prefill client for fire-and-forget semantics
|
||||
let prefill_client = reqwest::Client::builder()
|
||||
.pool_max_idle_per_host(0)
|
||||
.http1_only()
|
||||
.connect_timeout(Duration::from_millis(300))
|
||||
.timeout(Duration::from_secs(2))
|
||||
.build()
|
||||
.map_err(|e| format!("Failed to build prefill client: {}", e))?;
|
||||
|
||||
Ok(PDRouter {
|
||||
prefill_workers,
|
||||
decode_workers,
|
||||
@@ -267,6 +278,7 @@ impl PDRouter {
|
||||
worker_loads,
|
||||
load_monitor_handle,
|
||||
client,
|
||||
prefill_client,
|
||||
retry_config,
|
||||
_prefill_health_checker: Some(prefill_health_checker),
|
||||
_decode_health_checker: Some(decode_health_checker),
|
||||
@@ -365,41 +377,69 @@ impl PDRouter {
|
||||
None
|
||||
}
|
||||
|
||||
// Helper to create request with bootstrap fields
|
||||
fn create_request_with_bootstrap<T: serde::Serialize>(
|
||||
request: &T,
|
||||
// Helper to inject bootstrap fields into an existing JSON request value
|
||||
fn inject_bootstrap_into_value(
|
||||
mut original: Value,
|
||||
prefill_worker: &dyn Worker,
|
||||
batch_size: Option<usize>,
|
||||
) -> Result<serde_json::Value, serde_json::Error> {
|
||||
// Get bootstrap port from prefill worker
|
||||
) -> Result<Value, String> {
|
||||
let bootstrap_port = match prefill_worker.worker_type() {
|
||||
crate::core::WorkerType::Prefill { bootstrap_port } => bootstrap_port,
|
||||
_ => None,
|
||||
};
|
||||
let hostname = super::pd_types::get_hostname(prefill_worker.url());
|
||||
|
||||
// Create optimized request with bootstrap fields
|
||||
if let Some(batch_size) = batch_size {
|
||||
// Batch request
|
||||
let request_with_bootstrap = super::pd_types::BatchRequestWithBootstrap {
|
||||
original: request,
|
||||
bootstrap_host: vec![hostname; batch_size],
|
||||
bootstrap_port: vec![bootstrap_port; batch_size],
|
||||
bootstrap_room: (0..batch_size)
|
||||
.map(|_| super::pd_types::generate_room_id())
|
||||
.collect(),
|
||||
};
|
||||
serde_json::to_value(&request_with_bootstrap)
|
||||
let obj = original
|
||||
.as_object_mut()
|
||||
.ok_or_else(|| "Request must be a JSON object".to_string())?;
|
||||
|
||||
if let Some(n) = batch_size {
|
||||
let mut hosts = Vec::with_capacity(n);
|
||||
let mut ports = Vec::with_capacity(n);
|
||||
let mut rooms = Vec::with_capacity(n);
|
||||
for _ in 0..n {
|
||||
hosts.push(hostname.clone());
|
||||
ports.push(bootstrap_port);
|
||||
rooms.push(super::pd_types::generate_room_id());
|
||||
}
|
||||
obj.insert(
|
||||
"bootstrap_host".to_string(),
|
||||
Value::Array(hosts.into_iter().map(serde_json::Value::from).collect()),
|
||||
);
|
||||
obj.insert(
|
||||
"bootstrap_port".to_string(),
|
||||
Value::Array(
|
||||
ports
|
||||
.into_iter()
|
||||
.map(|p| match p {
|
||||
Some(v) => serde_json::Value::from(v),
|
||||
None => Value::Null,
|
||||
})
|
||||
.collect(),
|
||||
),
|
||||
);
|
||||
obj.insert(
|
||||
"bootstrap_room".to_string(),
|
||||
Value::Array(rooms.into_iter().map(serde_json::Value::from).collect()),
|
||||
);
|
||||
} else {
|
||||
// Single request
|
||||
let request_with_bootstrap = super::pd_types::RequestWithBootstrap {
|
||||
original: request,
|
||||
bootstrap_host: hostname,
|
||||
bootstrap_port,
|
||||
bootstrap_room: super::pd_types::generate_room_id(),
|
||||
};
|
||||
serde_json::to_value(&request_with_bootstrap)
|
||||
obj.insert(
|
||||
"bootstrap_host".to_string(),
|
||||
serde_json::Value::from(hostname),
|
||||
);
|
||||
obj.insert(
|
||||
"bootstrap_port".to_string(),
|
||||
match bootstrap_port {
|
||||
Some(v) => serde_json::Value::from(v),
|
||||
None => Value::Null,
|
||||
},
|
||||
);
|
||||
obj.insert(
|
||||
"bootstrap_room".to_string(),
|
||||
serde_json::Value::from(super::pd_types::generate_room_id()),
|
||||
);
|
||||
}
|
||||
Ok(original)
|
||||
}
|
||||
|
||||
// Execute the dual dispatch to prefill and decode servers
|
||||
@@ -417,12 +457,15 @@ impl PDRouter {
|
||||
// Update load tracking for both workers
|
||||
let _guard = WorkerLoadGuard::new_multi(vec![prefill, decode]);
|
||||
|
||||
// Build requests with headers
|
||||
let prefill_request =
|
||||
self.build_request_with_headers(prefill.url(), route, &json_request, headers);
|
||||
|
||||
let decode_request =
|
||||
self.build_request_with_headers(decode.url(), route, &json_request, headers);
|
||||
// Build decode request with shared client
|
||||
let decode_request = self.build_post_with_headers(
|
||||
&self.client,
|
||||
decode.url(),
|
||||
route,
|
||||
&json_request,
|
||||
headers,
|
||||
false,
|
||||
);
|
||||
|
||||
// Send both requests concurrently
|
||||
debug!(
|
||||
@@ -432,6 +475,15 @@ impl PDRouter {
|
||||
);
|
||||
|
||||
if return_logprob {
|
||||
// Build prefill request with shared client when we need response body
|
||||
let prefill_request = self.build_post_with_headers(
|
||||
&self.client,
|
||||
prefill.url(),
|
||||
route,
|
||||
&json_request,
|
||||
headers,
|
||||
false,
|
||||
);
|
||||
// When we need logprobs, wait for both responses
|
||||
let (prefill_result, decode_result) =
|
||||
tokio::join!(prefill_request.send(), decode_request.send());
|
||||
@@ -525,19 +577,27 @@ impl PDRouter {
|
||||
} else {
|
||||
// When we don't need logprobs, only wait for decode response
|
||||
// Send both requests concurrently but don't wait for prefill
|
||||
// Add headers to minimize response size when we don't need the body
|
||||
let prefill_future = prefill_request.header("Connection", "close").send();
|
||||
// Use dedicated prefill client with Connection: close
|
||||
let prefill_future = self
|
||||
.build_post_with_headers(
|
||||
&self.prefill_client,
|
||||
prefill.url(),
|
||||
route,
|
||||
&json_request,
|
||||
headers,
|
||||
true,
|
||||
)
|
||||
.send();
|
||||
let decode_future = decode_request.send();
|
||||
|
||||
tokio::spawn(async move {
|
||||
if let Ok(response) = prefill_future.await {
|
||||
// Consume with a short timeout to free connection quickly
|
||||
let consume_future = async {
|
||||
let _ = response.bytes().await;
|
||||
};
|
||||
|
||||
// Give it 100ms to consume, then abandon
|
||||
let _ = tokio::time::timeout(Duration::from_millis(100), consume_future).await;
|
||||
// Consume at most one small chunk with a very short timeout to advance flow control
|
||||
let _ = tokio::time::timeout(Duration::from_millis(20), async {
|
||||
let mut s = response.bytes_stream();
|
||||
let _ = s.next().await;
|
||||
})
|
||||
.await;
|
||||
}
|
||||
});
|
||||
|
||||
@@ -879,29 +939,34 @@ impl PDRouter {
|
||||
Ok((prefill_status, prefill_body))
|
||||
}
|
||||
|
||||
// Helper to build a request with headers copied from the original request
|
||||
fn build_request_with_headers(
|
||||
fn build_post_with_headers(
|
||||
&self,
|
||||
client: &reqwest::Client,
|
||||
url: &str,
|
||||
route: &str,
|
||||
json_request: &Value,
|
||||
json_request: &serde_json::Value,
|
||||
headers: Option<&HeaderMap>,
|
||||
connection_close: bool,
|
||||
) -> reqwest::RequestBuilder {
|
||||
let mut request = self.client.post(api_path(url, route)).json(json_request);
|
||||
|
||||
// Copy headers from original request (excluding content-type and content-length which are set by .json())
|
||||
let mut request = client.post(api_path(url, route)).json(json_request);
|
||||
if connection_close {
|
||||
request = request.header("Connection", "close");
|
||||
}
|
||||
if let Some(headers) = headers {
|
||||
for (name, value) in headers.iter() {
|
||||
let name_str = name.as_str();
|
||||
if name_str != "content-type" && name_str != "content-length" {
|
||||
// Skip headers with non-ASCII values
|
||||
if value.to_str().is_ok() {
|
||||
request = request.header(name, value);
|
||||
let name_lc = name.as_str().to_ascii_lowercase();
|
||||
// Whitelist important end-to-end headers, skip hop-by-hop
|
||||
let forward = matches!(
|
||||
name_lc.as_str(),
|
||||
"authorization" | "x-request-id" | "x-correlation-id"
|
||||
) || name_lc.starts_with("x-request-id-");
|
||||
if forward {
|
||||
if let Ok(val) = value.to_str() {
|
||||
request = request.header(name, val);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
request
|
||||
}
|
||||
|
||||
@@ -1109,11 +1174,12 @@ impl RouterTrait for PDRouter {
|
||||
|
||||
// Test prefill server's health_generate
|
||||
let prefill_url = format!("{}/health_generate", prefill.url());
|
||||
let prefill_result = self.client.get(&prefill_url).send().await;
|
||||
|
||||
// Test decode server's health_generate
|
||||
let decode_url = format!("{}/health_generate", decode.url());
|
||||
let decode_result = self.client.get(&decode_url).send().await;
|
||||
let (prefill_result, decode_result) = tokio::join!(
|
||||
self.client.get(&prefill_url).send(),
|
||||
self.client
|
||||
.get(&format!("{}/health_generate", decode.url()))
|
||||
.send()
|
||||
);
|
||||
|
||||
// Check results
|
||||
let mut errors = Vec::new();
|
||||
@@ -1399,10 +1465,13 @@ impl RouterTrait for PDRouter {
|
||||
decode.url()
|
||||
);
|
||||
|
||||
// Create optimized request with bootstrap fields
|
||||
let batch_size = Self::get_generate_batch_size(body);
|
||||
let json = match Self::create_request_with_bootstrap(body, prefill.as_ref(), batch_size) {
|
||||
Ok(json) => json,
|
||||
let original = match serde_json::to_value(body) {
|
||||
Ok(v) => v,
|
||||
Err(e) => return Self::handle_serialization_error(e),
|
||||
};
|
||||
let json = match Self::inject_bootstrap_into_value(original, prefill.as_ref(), batch_size) {
|
||||
Ok(v) => v,
|
||||
Err(e) => return Self::handle_serialization_error(e),
|
||||
};
|
||||
|
||||
@@ -1464,10 +1533,13 @@ impl RouterTrait for PDRouter {
|
||||
decode.url()
|
||||
);
|
||||
|
||||
// Create optimized request with bootstrap fields
|
||||
let batch_size = Self::get_chat_batch_size(body);
|
||||
let json = match Self::create_request_with_bootstrap(body, prefill.as_ref(), batch_size) {
|
||||
Ok(json) => json,
|
||||
let original = match serde_json::to_value(body) {
|
||||
Ok(v) => v,
|
||||
Err(e) => return Self::handle_serialization_error(e),
|
||||
};
|
||||
let json = match Self::inject_bootstrap_into_value(original, prefill.as_ref(), batch_size) {
|
||||
Ok(v) => v,
|
||||
Err(e) => return Self::handle_serialization_error(e),
|
||||
};
|
||||
|
||||
@@ -1519,10 +1591,13 @@ impl RouterTrait for PDRouter {
|
||||
decode.url()
|
||||
);
|
||||
|
||||
// Create optimized request with bootstrap fields
|
||||
let batch_size = Self::get_completion_batch_size(body);
|
||||
let json = match Self::create_request_with_bootstrap(body, prefill.as_ref(), batch_size) {
|
||||
Ok(json) => json,
|
||||
let original = match serde_json::to_value(body) {
|
||||
Ok(v) => v,
|
||||
Err(e) => return Self::handle_serialization_error(e),
|
||||
};
|
||||
let json = match Self::inject_bootstrap_into_value(original, prefill.as_ref(), batch_size) {
|
||||
Ok(v) => v,
|
||||
Err(e) => return Self::handle_serialization_error(e),
|
||||
};
|
||||
|
||||
@@ -1771,6 +1846,7 @@ mod tests {
|
||||
worker_loads: Arc::new(tokio::sync::watch::channel(HashMap::new()).1),
|
||||
load_monitor_handle: None,
|
||||
client: Client::new(),
|
||||
prefill_client: Client::new(),
|
||||
retry_config: RetryConfig::default(),
|
||||
_prefill_health_checker: None,
|
||||
_decode_health_checker: None,
|
||||
|
||||
Reference in New Issue
Block a user