[router] introduce prefill response draining for http compliance (#9281)
This commit is contained in:
@@ -29,6 +29,7 @@ use serde_json::Value;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::{Arc, RwLock};
|
||||
use std::time::{Duration, Instant};
|
||||
use tokio::sync::mpsc;
|
||||
use tokio_stream::wrappers::UnboundedReceiverStream;
|
||||
use tracing::{debug, error, info, warn};
|
||||
|
||||
@@ -49,6 +50,8 @@ pub struct PDRouter {
|
||||
pub circuit_breaker_config: CircuitBreakerConfig,
|
||||
_prefill_health_checker: Option<HealthChecker>,
|
||||
_decode_health_checker: Option<HealthChecker>,
|
||||
// Channel for sending prefill responses to background workers for draining
|
||||
prefill_drain_tx: mpsc::Sender<reqwest::Response>,
|
||||
}
|
||||
|
||||
// Request context for PD router operations
|
||||
@@ -501,6 +504,75 @@ impl PDRouter {
|
||||
.build()
|
||||
.map_err(|e| format!("Failed to build prefill client: {}", e))?;
|
||||
|
||||
// Create bounded channel for prefill response draining
|
||||
// Larger buffer for high concurrency scenarios
|
||||
let (prefill_drain_tx, mut prefill_drain_rx) = mpsc::channel::<reqwest::Response>(2000);
|
||||
|
||||
// Spawn a coordinator with limited concurrent drain tasks
|
||||
// This prevents unbounded task spawning under extreme load
|
||||
tokio::spawn(async move {
|
||||
info!("Prefill drain coordinator started");
|
||||
|
||||
// Use a semaphore to limit concurrent drain operations
|
||||
let max_concurrent_drains = 100;
|
||||
let semaphore = Arc::new(tokio::sync::Semaphore::new(max_concurrent_drains));
|
||||
|
||||
while let Some(response) = prefill_drain_rx.recv().await {
|
||||
let permit = semaphore.clone().acquire_owned().await;
|
||||
|
||||
match permit {
|
||||
Ok(permit) => {
|
||||
// Spawn a task to drain this response
|
||||
tokio::spawn(async move {
|
||||
let url = response.url().to_string();
|
||||
let status = response.status();
|
||||
|
||||
if !status.is_success() {
|
||||
error!("Prefill drain: error status={} url={}", status, url);
|
||||
RouterMetrics::record_pd_prefill_error(&url);
|
||||
}
|
||||
|
||||
// Drain the response body efficiently
|
||||
// Use streaming to avoid loading entire body into memory
|
||||
let start = std::time::Instant::now();
|
||||
let mut stream = response.bytes_stream();
|
||||
let mut bytes_drained = 0;
|
||||
|
||||
while let Some(chunk_result) = stream.next().await {
|
||||
match chunk_result {
|
||||
Ok(chunk) => bytes_drained += chunk.len(),
|
||||
Err(e) => {
|
||||
debug!(
|
||||
"Prefill drain: error streaming url={} error={}",
|
||||
url, e
|
||||
);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let elapsed = start.elapsed();
|
||||
if elapsed > Duration::from_millis(100) {
|
||||
// Only log slow drains
|
||||
debug!(
|
||||
"Prefill drain: slow drain {} bytes from {} in {:?}",
|
||||
bytes_drained, url, elapsed
|
||||
);
|
||||
}
|
||||
|
||||
// Permit is automatically released when dropped
|
||||
drop(permit);
|
||||
});
|
||||
}
|
||||
Err(_) => {
|
||||
// Semaphore closed, shutting down
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
info!("Prefill drain coordinator shutting down");
|
||||
});
|
||||
|
||||
Ok(PDRouter {
|
||||
prefill_workers,
|
||||
decode_workers,
|
||||
@@ -512,6 +584,7 @@ impl PDRouter {
|
||||
load_monitor_handle,
|
||||
client,
|
||||
prefill_client,
|
||||
prefill_drain_tx,
|
||||
retry_config,
|
||||
circuit_breaker_config: core_cb_config,
|
||||
_prefill_health_checker: Some(prefill_health_checker),
|
||||
@@ -702,11 +775,9 @@ impl PDRouter {
|
||||
.execute_dual_dispatch_internal(
|
||||
headers,
|
||||
json_request,
|
||||
context.route,
|
||||
context,
|
||||
prefill.as_ref(),
|
||||
decode.as_ref(),
|
||||
context.is_stream,
|
||||
context.return_logprob,
|
||||
start_time,
|
||||
)
|
||||
.await;
|
||||
@@ -734,16 +805,13 @@ impl PDRouter {
|
||||
}
|
||||
|
||||
// Internal method that performs the actual dual dispatch (without retry logic)
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
async fn execute_dual_dispatch_internal(
|
||||
&self,
|
||||
headers: Option<&HeaderMap>,
|
||||
json_request: Value,
|
||||
route: &str,
|
||||
context: PDRequestContext,
|
||||
prefill: &dyn Worker,
|
||||
decode: &dyn Worker,
|
||||
is_stream: bool,
|
||||
return_logprob: bool,
|
||||
start_time: Instant,
|
||||
) -> Response {
|
||||
// Update load tracking for both workers
|
||||
@@ -753,7 +821,7 @@ impl PDRouter {
|
||||
let decode_request = self.build_post_with_headers(
|
||||
&self.client,
|
||||
decode.url(),
|
||||
route,
|
||||
context.route,
|
||||
&json_request,
|
||||
headers,
|
||||
false,
|
||||
@@ -766,12 +834,12 @@ impl PDRouter {
|
||||
decode.url()
|
||||
);
|
||||
|
||||
if return_logprob {
|
||||
if context.return_logprob {
|
||||
// Build prefill request with shared client when we need response body
|
||||
let prefill_request = self.build_post_with_headers(
|
||||
&self.client,
|
||||
prefill.url(),
|
||||
route,
|
||||
context.route,
|
||||
&json_request,
|
||||
headers,
|
||||
false,
|
||||
@@ -783,8 +851,8 @@ impl PDRouter {
|
||||
|
||||
// Update metrics
|
||||
let duration = start_time.elapsed();
|
||||
RouterMetrics::record_pd_request_duration(route, duration);
|
||||
RouterMetrics::record_pd_request(route);
|
||||
RouterMetrics::record_pd_request_duration(context.route, duration);
|
||||
RouterMetrics::record_pd_request(context.route);
|
||||
RouterMetrics::record_pd_prefill_request(prefill.url());
|
||||
RouterMetrics::record_pd_decode_request(decode.url());
|
||||
|
||||
@@ -818,14 +886,18 @@ impl PDRouter {
|
||||
|
||||
// Process prefill response for logprobs
|
||||
let prefill_body = match self
|
||||
.process_prefill_response(prefill_result, prefill.url(), return_logprob)
|
||||
.process_prefill_response(
|
||||
prefill_result,
|
||||
prefill.url(),
|
||||
context.return_logprob,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok((_, body)) => body,
|
||||
Err(error_response) => return error_response,
|
||||
};
|
||||
|
||||
if is_stream {
|
||||
if context.is_stream {
|
||||
// Streaming response with logprobs
|
||||
let prefill_logprobs = prefill_body
|
||||
.as_ref()
|
||||
@@ -841,7 +913,7 @@ impl PDRouter {
|
||||
res.bytes_stream(),
|
||||
status,
|
||||
prefill_logprobs,
|
||||
return_logprob,
|
||||
context.return_logprob,
|
||||
None,
|
||||
Some(response_headers),
|
||||
)
|
||||
@@ -850,7 +922,7 @@ impl PDRouter {
|
||||
self.process_non_streaming_response(
|
||||
res,
|
||||
status,
|
||||
return_logprob,
|
||||
context.return_logprob,
|
||||
prefill_body,
|
||||
)
|
||||
.await
|
||||
@@ -878,7 +950,7 @@ impl PDRouter {
|
||||
.build_post_with_headers(
|
||||
&self.prefill_client,
|
||||
prefill.url(),
|
||||
route,
|
||||
context.route,
|
||||
&json_request,
|
||||
headers,
|
||||
true,
|
||||
@@ -886,11 +958,41 @@ impl PDRouter {
|
||||
.send();
|
||||
let decode_future = decode_request.send();
|
||||
|
||||
// Send prefill response to background worker for draining
|
||||
// This ensures HTTP compliance without blocking
|
||||
let drain_tx = self.prefill_drain_tx.clone();
|
||||
let prefill_url = prefill.url().to_string();
|
||||
tokio::spawn(async move {
|
||||
if let Ok(response) = prefill_future.await {
|
||||
// Consume the entire response body to maintain HTTP compliance
|
||||
// This runs in the background and won't block the decode response
|
||||
let _ = response.bytes().await;
|
||||
// Try to send to drain worker
|
||||
// If channel is full (under extreme load), drain inline as fallback
|
||||
match drain_tx.try_send(response) {
|
||||
Ok(_) => {
|
||||
// Successfully queued for draining
|
||||
debug!("Prefill response queued for draining");
|
||||
}
|
||||
Err(mpsc::error::TrySendError::Full(response)) => {
|
||||
// Channel full - drain inline as fallback
|
||||
warn!("Prefill drain channel full (capacity exceeded), draining inline for {}", prefill_url);
|
||||
RouterMetrics::record_pd_prefill_error(&prefill_url);
|
||||
|
||||
// Drain inline with timeout to prevent blocking too long
|
||||
let drain_future = async {
|
||||
let mut stream = response.bytes_stream();
|
||||
while stream.next().await.is_some() {
|
||||
// Just drain
|
||||
}
|
||||
};
|
||||
|
||||
match tokio::time::timeout(Duration::from_secs(1), drain_future).await {
|
||||
Ok(_) => debug!("Inline drain completed for {}", prefill_url),
|
||||
Err(_) => error!("Inline drain timeout for {}", prefill_url),
|
||||
}
|
||||
}
|
||||
Err(mpsc::error::TrySendError::Closed(_)) => {
|
||||
error!("Prefill drain channel closed!");
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
@@ -900,8 +1002,8 @@ impl PDRouter {
|
||||
|
||||
// Update metrics
|
||||
let duration = start_time.elapsed();
|
||||
RouterMetrics::record_pd_request_duration(route, duration);
|
||||
RouterMetrics::record_pd_request(route);
|
||||
RouterMetrics::record_pd_request_duration(context.route, duration);
|
||||
RouterMetrics::record_pd_request(context.route);
|
||||
RouterMetrics::record_pd_prefill_request(prefill.url());
|
||||
RouterMetrics::record_pd_decode_request(decode.url());
|
||||
|
||||
@@ -928,7 +1030,7 @@ impl PDRouter {
|
||||
(status, format!("Decode server error: {}", e)).into_response()
|
||||
}
|
||||
}
|
||||
} else if is_stream {
|
||||
} else if context.is_stream {
|
||||
// Streaming response without logprobs - direct passthrough
|
||||
let decode_url = decode.url().to_string();
|
||||
let response_headers =
|
||||
@@ -1280,10 +1382,10 @@ impl PDRouter {
|
||||
|
||||
fn build_post_with_headers(
|
||||
&self,
|
||||
client: &reqwest::Client,
|
||||
client: &Client,
|
||||
url: &str,
|
||||
route: &str,
|
||||
json_request: &serde_json::Value,
|
||||
json_request: &Value,
|
||||
headers: Option<&HeaderMap>,
|
||||
connection_close: bool,
|
||||
) -> reqwest::RequestBuilder {
|
||||
@@ -1894,6 +1996,7 @@ mod tests {
|
||||
load_monitor_handle: None,
|
||||
client: Client::new(),
|
||||
prefill_client: Client::new(),
|
||||
prefill_drain_tx: mpsc::channel(100).0,
|
||||
retry_config: RetryConfig::default(),
|
||||
circuit_breaker_config: CircuitBreakerConfig::default(),
|
||||
_prefill_health_checker: None,
|
||||
|
||||
Reference in New Issue
Block a user