diff --git a/sgl-router/src/routers/pd_router.rs b/sgl-router/src/routers/pd_router.rs index 404fe9904..ab22e1d9d 100644 --- a/sgl-router/src/routers/pd_router.rs +++ b/sgl-router/src/routers/pd_router.rs @@ -38,6 +38,8 @@ pub struct PDRouter { pub worker_loads: Arc>>, pub load_monitor_handle: Option>>, pub client: Client, + // Dedicated client for prefill fire-and-forget (non-logprob) requests + pub prefill_client: Client, pub retry_config: RetryConfig, _prefill_health_checker: Option, _decode_health_checker: Option, @@ -255,6 +257,15 @@ impl PDRouter { let decode_health_checker = crate::core::start_health_checker(Arc::clone(&decode_workers), interval_secs); + // Build a dedicated prefill client for fire-and-forget semantics + let prefill_client = reqwest::Client::builder() + .pool_max_idle_per_host(0) + .http1_only() + .connect_timeout(Duration::from_millis(300)) + .timeout(Duration::from_secs(2)) + .build() + .map_err(|e| format!("Failed to build prefill client: {}", e))?; + Ok(PDRouter { prefill_workers, decode_workers, @@ -267,6 +278,7 @@ impl PDRouter { worker_loads, load_monitor_handle, client, + prefill_client, retry_config, _prefill_health_checker: Some(prefill_health_checker), _decode_health_checker: Some(decode_health_checker), @@ -365,41 +377,69 @@ impl PDRouter { None } - // Helper to create request with bootstrap fields - fn create_request_with_bootstrap( - request: &T, + // Helper to inject bootstrap fields into an existing JSON request value + fn inject_bootstrap_into_value( + mut original: Value, prefill_worker: &dyn Worker, batch_size: Option, - ) -> Result { - // Get bootstrap port from prefill worker + ) -> Result { let bootstrap_port = match prefill_worker.worker_type() { crate::core::WorkerType::Prefill { bootstrap_port } => bootstrap_port, _ => None, }; let hostname = super::pd_types::get_hostname(prefill_worker.url()); - // Create optimized request with bootstrap fields - if let Some(batch_size) = batch_size { - // Batch request - let request_with_bootstrap = super::pd_types::BatchRequestWithBootstrap { - original: request, - bootstrap_host: vec![hostname; batch_size], - bootstrap_port: vec![bootstrap_port; batch_size], - bootstrap_room: (0..batch_size) - .map(|_| super::pd_types::generate_room_id()) - .collect(), - }; - serde_json::to_value(&request_with_bootstrap) + let obj = original + .as_object_mut() + .ok_or_else(|| "Request must be a JSON object".to_string())?; + + if let Some(n) = batch_size { + let mut hosts = Vec::with_capacity(n); + let mut ports = Vec::with_capacity(n); + let mut rooms = Vec::with_capacity(n); + for _ in 0..n { + hosts.push(hostname.clone()); + ports.push(bootstrap_port); + rooms.push(super::pd_types::generate_room_id()); + } + obj.insert( + "bootstrap_host".to_string(), + Value::Array(hosts.into_iter().map(serde_json::Value::from).collect()), + ); + obj.insert( + "bootstrap_port".to_string(), + Value::Array( + ports + .into_iter() + .map(|p| match p { + Some(v) => serde_json::Value::from(v), + None => Value::Null, + }) + .collect(), + ), + ); + obj.insert( + "bootstrap_room".to_string(), + Value::Array(rooms.into_iter().map(serde_json::Value::from).collect()), + ); } else { - // Single request - let request_with_bootstrap = super::pd_types::RequestWithBootstrap { - original: request, - bootstrap_host: hostname, - bootstrap_port, - bootstrap_room: super::pd_types::generate_room_id(), - }; - serde_json::to_value(&request_with_bootstrap) + obj.insert( + "bootstrap_host".to_string(), + serde_json::Value::from(hostname), + ); + obj.insert( + "bootstrap_port".to_string(), + match bootstrap_port { + Some(v) => serde_json::Value::from(v), + None => Value::Null, + }, + ); + obj.insert( + "bootstrap_room".to_string(), + serde_json::Value::from(super::pd_types::generate_room_id()), + ); } + Ok(original) } // Execute the dual dispatch to prefill and decode servers @@ -417,12 +457,15 @@ impl PDRouter { // Update load tracking for both workers let _guard = WorkerLoadGuard::new_multi(vec![prefill, decode]); - // Build requests with headers - let prefill_request = - self.build_request_with_headers(prefill.url(), route, &json_request, headers); - - let decode_request = - self.build_request_with_headers(decode.url(), route, &json_request, headers); + // Build decode request with shared client + let decode_request = self.build_post_with_headers( + &self.client, + decode.url(), + route, + &json_request, + headers, + false, + ); // Send both requests concurrently debug!( @@ -432,6 +475,15 @@ impl PDRouter { ); if return_logprob { + // Build prefill request with shared client when we need response body + let prefill_request = self.build_post_with_headers( + &self.client, + prefill.url(), + route, + &json_request, + headers, + false, + ); // When we need logprobs, wait for both responses let (prefill_result, decode_result) = tokio::join!(prefill_request.send(), decode_request.send()); @@ -525,19 +577,27 @@ impl PDRouter { } else { // When we don't need logprobs, only wait for decode response // Send both requests concurrently but don't wait for prefill - // Add headers to minimize response size when we don't need the body - let prefill_future = prefill_request.header("Connection", "close").send(); + // Use dedicated prefill client with Connection: close + let prefill_future = self + .build_post_with_headers( + &self.prefill_client, + prefill.url(), + route, + &json_request, + headers, + true, + ) + .send(); let decode_future = decode_request.send(); tokio::spawn(async move { if let Ok(response) = prefill_future.await { - // Consume with a short timeout to free connection quickly - let consume_future = async { - let _ = response.bytes().await; - }; - - // Give it 100ms to consume, then abandon - let _ = tokio::time::timeout(Duration::from_millis(100), consume_future).await; + // Consume at most one small chunk with a very short timeout to advance flow control + let _ = tokio::time::timeout(Duration::from_millis(20), async { + let mut s = response.bytes_stream(); + let _ = s.next().await; + }) + .await; } }); @@ -879,29 +939,34 @@ impl PDRouter { Ok((prefill_status, prefill_body)) } - // Helper to build a request with headers copied from the original request - fn build_request_with_headers( + fn build_post_with_headers( &self, + client: &reqwest::Client, url: &str, route: &str, - json_request: &Value, + json_request: &serde_json::Value, headers: Option<&HeaderMap>, + connection_close: bool, ) -> reqwest::RequestBuilder { - let mut request = self.client.post(api_path(url, route)).json(json_request); - - // Copy headers from original request (excluding content-type and content-length which are set by .json()) + let mut request = client.post(api_path(url, route)).json(json_request); + if connection_close { + request = request.header("Connection", "close"); + } if let Some(headers) = headers { for (name, value) in headers.iter() { - let name_str = name.as_str(); - if name_str != "content-type" && name_str != "content-length" { - // Skip headers with non-ASCII values - if value.to_str().is_ok() { - request = request.header(name, value); + let name_lc = name.as_str().to_ascii_lowercase(); + // Whitelist important end-to-end headers, skip hop-by-hop + let forward = matches!( + name_lc.as_str(), + "authorization" | "x-request-id" | "x-correlation-id" + ) || name_lc.starts_with("x-request-id-"); + if forward { + if let Ok(val) = value.to_str() { + request = request.header(name, val); } } } } - request } @@ -1109,11 +1174,12 @@ impl RouterTrait for PDRouter { // Test prefill server's health_generate let prefill_url = format!("{}/health_generate", prefill.url()); - let prefill_result = self.client.get(&prefill_url).send().await; - - // Test decode server's health_generate - let decode_url = format!("{}/health_generate", decode.url()); - let decode_result = self.client.get(&decode_url).send().await; + let (prefill_result, decode_result) = tokio::join!( + self.client.get(&prefill_url).send(), + self.client + .get(&format!("{}/health_generate", decode.url())) + .send() + ); // Check results let mut errors = Vec::new(); @@ -1399,10 +1465,13 @@ impl RouterTrait for PDRouter { decode.url() ); - // Create optimized request with bootstrap fields let batch_size = Self::get_generate_batch_size(body); - let json = match Self::create_request_with_bootstrap(body, prefill.as_ref(), batch_size) { - Ok(json) => json, + let original = match serde_json::to_value(body) { + Ok(v) => v, + Err(e) => return Self::handle_serialization_error(e), + }; + let json = match Self::inject_bootstrap_into_value(original, prefill.as_ref(), batch_size) { + Ok(v) => v, Err(e) => return Self::handle_serialization_error(e), }; @@ -1464,10 +1533,13 @@ impl RouterTrait for PDRouter { decode.url() ); - // Create optimized request with bootstrap fields let batch_size = Self::get_chat_batch_size(body); - let json = match Self::create_request_with_bootstrap(body, prefill.as_ref(), batch_size) { - Ok(json) => json, + let original = match serde_json::to_value(body) { + Ok(v) => v, + Err(e) => return Self::handle_serialization_error(e), + }; + let json = match Self::inject_bootstrap_into_value(original, prefill.as_ref(), batch_size) { + Ok(v) => v, Err(e) => return Self::handle_serialization_error(e), }; @@ -1519,10 +1591,13 @@ impl RouterTrait for PDRouter { decode.url() ); - // Create optimized request with bootstrap fields let batch_size = Self::get_completion_batch_size(body); - let json = match Self::create_request_with_bootstrap(body, prefill.as_ref(), batch_size) { - Ok(json) => json, + let original = match serde_json::to_value(body) { + Ok(v) => v, + Err(e) => return Self::handle_serialization_error(e), + }; + let json = match Self::inject_bootstrap_into_value(original, prefill.as_ref(), batch_size) { + Ok(v) => v, Err(e) => return Self::handle_serialization_error(e), }; @@ -1771,6 +1846,7 @@ mod tests { worker_loads: Arc::new(tokio::sync::watch::channel(HashMap::new()).1), load_monitor_handle: None, client: Client::new(), + prefill_client: Client::new(), retry_config: RetryConfig::default(), _prefill_health_checker: None, _decode_health_checker: None,