From b341b7dbce705153a0d05fbbad521ae5cc648328 Mon Sep 17 00:00:00 2001 From: Simo Lin Date: Sun, 17 Aug 2025 14:23:04 -0700 Subject: [PATCH] [router] introduce prefill response draining for http compliance (#9281) --- sgl-router/src/routers/pd_router.rs | 153 +++++++++++++++++++++++----- 1 file changed, 128 insertions(+), 25 deletions(-) diff --git a/sgl-router/src/routers/pd_router.rs b/sgl-router/src/routers/pd_router.rs index 8a1a407a3..0d70f4ab9 100644 --- a/sgl-router/src/routers/pd_router.rs +++ b/sgl-router/src/routers/pd_router.rs @@ -29,6 +29,7 @@ use serde_json::Value; use std::collections::HashMap; use std::sync::{Arc, RwLock}; use std::time::{Duration, Instant}; +use tokio::sync::mpsc; use tokio_stream::wrappers::UnboundedReceiverStream; use tracing::{debug, error, info, warn}; @@ -49,6 +50,8 @@ pub struct PDRouter { pub circuit_breaker_config: CircuitBreakerConfig, _prefill_health_checker: Option, _decode_health_checker: Option, + // Channel for sending prefill responses to background workers for draining + prefill_drain_tx: mpsc::Sender, } // Request context for PD router operations @@ -501,6 +504,75 @@ impl PDRouter { .build() .map_err(|e| format!("Failed to build prefill client: {}", e))?; + // Create bounded channel for prefill response draining + // Larger buffer for high concurrency scenarios + let (prefill_drain_tx, mut prefill_drain_rx) = mpsc::channel::(2000); + + // Spawn a coordinator with limited concurrent drain tasks + // This prevents unbounded task spawning under extreme load + tokio::spawn(async move { + info!("Prefill drain coordinator started"); + + // Use a semaphore to limit concurrent drain operations + let max_concurrent_drains = 100; + let semaphore = Arc::new(tokio::sync::Semaphore::new(max_concurrent_drains)); + + while let Some(response) = prefill_drain_rx.recv().await { + let permit = semaphore.clone().acquire_owned().await; + + match permit { + Ok(permit) => { + // Spawn a task to drain this response + tokio::spawn(async move { + let url = response.url().to_string(); + let status = response.status(); + + if !status.is_success() { + error!("Prefill drain: error status={} url={}", status, url); + RouterMetrics::record_pd_prefill_error(&url); + } + + // Drain the response body efficiently + // Use streaming to avoid loading entire body into memory + let start = std::time::Instant::now(); + let mut stream = response.bytes_stream(); + let mut bytes_drained = 0; + + while let Some(chunk_result) = stream.next().await { + match chunk_result { + Ok(chunk) => bytes_drained += chunk.len(), + Err(e) => { + debug!( + "Prefill drain: error streaming url={} error={}", + url, e + ); + break; + } + } + } + + let elapsed = start.elapsed(); + if elapsed > Duration::from_millis(100) { + // Only log slow drains + debug!( + "Prefill drain: slow drain {} bytes from {} in {:?}", + bytes_drained, url, elapsed + ); + } + + // Permit is automatically released when dropped + drop(permit); + }); + } + Err(_) => { + // Semaphore closed, shutting down + break; + } + } + } + info!("Prefill drain coordinator shutting down"); + }); + Ok(PDRouter { prefill_workers, decode_workers, @@ -512,6 +584,7 @@ impl PDRouter { load_monitor_handle, client, prefill_client, + prefill_drain_tx, retry_config, circuit_breaker_config: core_cb_config, _prefill_health_checker: Some(prefill_health_checker), @@ -702,11 +775,9 @@ impl PDRouter { .execute_dual_dispatch_internal( headers, json_request, - context.route, + context, prefill.as_ref(), decode.as_ref(), - context.is_stream, - context.return_logprob, start_time, ) .await; @@ -734,16 +805,13 @@ impl PDRouter { } // Internal method that performs the actual dual dispatch (without retry logic) - #[allow(clippy::too_many_arguments)] async fn execute_dual_dispatch_internal( &self, headers: Option<&HeaderMap>, json_request: Value, - route: &str, + context: PDRequestContext, prefill: &dyn Worker, decode: &dyn Worker, - is_stream: bool, - return_logprob: bool, start_time: Instant, ) -> Response { // Update load tracking for both workers @@ -753,7 +821,7 @@ impl PDRouter { let decode_request = self.build_post_with_headers( &self.client, decode.url(), - route, + context.route, &json_request, headers, false, @@ -766,12 +834,12 @@ impl PDRouter { decode.url() ); - if return_logprob { + if context.return_logprob { // Build prefill request with shared client when we need response body let prefill_request = self.build_post_with_headers( &self.client, prefill.url(), - route, + context.route, &json_request, headers, false, @@ -783,8 +851,8 @@ impl PDRouter { // Update metrics let duration = start_time.elapsed(); - RouterMetrics::record_pd_request_duration(route, duration); - RouterMetrics::record_pd_request(route); + RouterMetrics::record_pd_request_duration(context.route, duration); + RouterMetrics::record_pd_request(context.route); RouterMetrics::record_pd_prefill_request(prefill.url()); RouterMetrics::record_pd_decode_request(decode.url()); @@ -818,14 +886,18 @@ impl PDRouter { // Process prefill response for logprobs let prefill_body = match self - .process_prefill_response(prefill_result, prefill.url(), return_logprob) + .process_prefill_response( + prefill_result, + prefill.url(), + context.return_logprob, + ) .await { Ok((_, body)) => body, Err(error_response) => return error_response, }; - if is_stream { + if context.is_stream { // Streaming response with logprobs let prefill_logprobs = prefill_body .as_ref() @@ -841,7 +913,7 @@ impl PDRouter { res.bytes_stream(), status, prefill_logprobs, - return_logprob, + context.return_logprob, None, Some(response_headers), ) @@ -850,7 +922,7 @@ impl PDRouter { self.process_non_streaming_response( res, status, - return_logprob, + context.return_logprob, prefill_body, ) .await @@ -878,7 +950,7 @@ impl PDRouter { .build_post_with_headers( &self.prefill_client, prefill.url(), - route, + context.route, &json_request, headers, true, @@ -886,11 +958,41 @@ impl PDRouter { .send(); let decode_future = decode_request.send(); + // Send prefill response to background worker for draining + // This ensures HTTP compliance without blocking + let drain_tx = self.prefill_drain_tx.clone(); + let prefill_url = prefill.url().to_string(); tokio::spawn(async move { if let Ok(response) = prefill_future.await { - // Consume the entire response body to maintain HTTP compliance - // This runs in the background and won't block the decode response - let _ = response.bytes().await; + // Try to send to drain worker + // If channel is full (under extreme load), drain inline as fallback + match drain_tx.try_send(response) { + Ok(_) => { + // Successfully queued for draining + debug!("Prefill response queued for draining"); + } + Err(mpsc::error::TrySendError::Full(response)) => { + // Channel full - drain inline as fallback + warn!("Prefill drain channel full (capacity exceeded), draining inline for {}", prefill_url); + RouterMetrics::record_pd_prefill_error(&prefill_url); + + // Drain inline with timeout to prevent blocking too long + let drain_future = async { + let mut stream = response.bytes_stream(); + while stream.next().await.is_some() { + // Just drain + } + }; + + match tokio::time::timeout(Duration::from_secs(1), drain_future).await { + Ok(_) => debug!("Inline drain completed for {}", prefill_url), + Err(_) => error!("Inline drain timeout for {}", prefill_url), + } + } + Err(mpsc::error::TrySendError::Closed(_)) => { + error!("Prefill drain channel closed!"); + } + } } }); @@ -900,8 +1002,8 @@ impl PDRouter { // Update metrics let duration = start_time.elapsed(); - RouterMetrics::record_pd_request_duration(route, duration); - RouterMetrics::record_pd_request(route); + RouterMetrics::record_pd_request_duration(context.route, duration); + RouterMetrics::record_pd_request(context.route); RouterMetrics::record_pd_prefill_request(prefill.url()); RouterMetrics::record_pd_decode_request(decode.url()); @@ -928,7 +1030,7 @@ impl PDRouter { (status, format!("Decode server error: {}", e)).into_response() } } - } else if is_stream { + } else if context.is_stream { // Streaming response without logprobs - direct passthrough let decode_url = decode.url().to_string(); let response_headers = @@ -1280,10 +1382,10 @@ impl PDRouter { fn build_post_with_headers( &self, - client: &reqwest::Client, + client: &Client, url: &str, route: &str, - json_request: &serde_json::Value, + json_request: &Value, headers: Option<&HeaderMap>, connection_close: bool, ) -> reqwest::RequestBuilder { @@ -1894,6 +1996,7 @@ mod tests { load_monitor_handle: None, client: Client::new(), prefill_client: Client::new(), + prefill_drain_tx: mpsc::channel(100).0, retry_config: RetryConfig::default(), circuit_breaker_config: CircuitBreakerConfig::default(), _prefill_health_checker: None,