[router] introduce prefill response draining for http compliance (#9281)
This commit is contained in:
@@ -29,6 +29,7 @@ use serde_json::Value;
|
|||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::sync::{Arc, RwLock};
|
use std::sync::{Arc, RwLock};
|
||||||
use std::time::{Duration, Instant};
|
use std::time::{Duration, Instant};
|
||||||
|
use tokio::sync::mpsc;
|
||||||
use tokio_stream::wrappers::UnboundedReceiverStream;
|
use tokio_stream::wrappers::UnboundedReceiverStream;
|
||||||
use tracing::{debug, error, info, warn};
|
use tracing::{debug, error, info, warn};
|
||||||
|
|
||||||
@@ -49,6 +50,8 @@ pub struct PDRouter {
|
|||||||
pub circuit_breaker_config: CircuitBreakerConfig,
|
pub circuit_breaker_config: CircuitBreakerConfig,
|
||||||
_prefill_health_checker: Option<HealthChecker>,
|
_prefill_health_checker: Option<HealthChecker>,
|
||||||
_decode_health_checker: Option<HealthChecker>,
|
_decode_health_checker: Option<HealthChecker>,
|
||||||
|
// Channel for sending prefill responses to background workers for draining
|
||||||
|
prefill_drain_tx: mpsc::Sender<reqwest::Response>,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Request context for PD router operations
|
// Request context for PD router operations
|
||||||
@@ -501,6 +504,75 @@ impl PDRouter {
|
|||||||
.build()
|
.build()
|
||||||
.map_err(|e| format!("Failed to build prefill client: {}", e))?;
|
.map_err(|e| format!("Failed to build prefill client: {}", e))?;
|
||||||
|
|
||||||
|
// Create bounded channel for prefill response draining
|
||||||
|
// Larger buffer for high concurrency scenarios
|
||||||
|
let (prefill_drain_tx, mut prefill_drain_rx) = mpsc::channel::<reqwest::Response>(2000);
|
||||||
|
|
||||||
|
// Spawn a coordinator with limited concurrent drain tasks
|
||||||
|
// This prevents unbounded task spawning under extreme load
|
||||||
|
tokio::spawn(async move {
|
||||||
|
info!("Prefill drain coordinator started");
|
||||||
|
|
||||||
|
// Use a semaphore to limit concurrent drain operations
|
||||||
|
let max_concurrent_drains = 100;
|
||||||
|
let semaphore = Arc::new(tokio::sync::Semaphore::new(max_concurrent_drains));
|
||||||
|
|
||||||
|
while let Some(response) = prefill_drain_rx.recv().await {
|
||||||
|
let permit = semaphore.clone().acquire_owned().await;
|
||||||
|
|
||||||
|
match permit {
|
||||||
|
Ok(permit) => {
|
||||||
|
// Spawn a task to drain this response
|
||||||
|
tokio::spawn(async move {
|
||||||
|
let url = response.url().to_string();
|
||||||
|
let status = response.status();
|
||||||
|
|
||||||
|
if !status.is_success() {
|
||||||
|
error!("Prefill drain: error status={} url={}", status, url);
|
||||||
|
RouterMetrics::record_pd_prefill_error(&url);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Drain the response body efficiently
|
||||||
|
// Use streaming to avoid loading entire body into memory
|
||||||
|
let start = std::time::Instant::now();
|
||||||
|
let mut stream = response.bytes_stream();
|
||||||
|
let mut bytes_drained = 0;
|
||||||
|
|
||||||
|
while let Some(chunk_result) = stream.next().await {
|
||||||
|
match chunk_result {
|
||||||
|
Ok(chunk) => bytes_drained += chunk.len(),
|
||||||
|
Err(e) => {
|
||||||
|
debug!(
|
||||||
|
"Prefill drain: error streaming url={} error={}",
|
||||||
|
url, e
|
||||||
|
);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let elapsed = start.elapsed();
|
||||||
|
if elapsed > Duration::from_millis(100) {
|
||||||
|
// Only log slow drains
|
||||||
|
debug!(
|
||||||
|
"Prefill drain: slow drain {} bytes from {} in {:?}",
|
||||||
|
bytes_drained, url, elapsed
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
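// Keep this explicit drop: it is the only use of `permit` inside the async block, so it is what moves the permit into this task and holds it until the drain finishes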
|
||||||
|
drop(permit);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
Err(_) => {
|
||||||
|
// acquire_owned() only fails if the semaphore has been closed; nothing visible here closes it, so this is a defensive exit
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
info!("Prefill drain coordinator shutting down");
|
||||||
|
});
|
||||||
|
|
||||||
Ok(PDRouter {
|
Ok(PDRouter {
|
||||||
prefill_workers,
|
prefill_workers,
|
||||||
decode_workers,
|
decode_workers,
|
||||||
@@ -512,6 +584,7 @@ impl PDRouter {
|
|||||||
load_monitor_handle,
|
load_monitor_handle,
|
||||||
client,
|
client,
|
||||||
prefill_client,
|
prefill_client,
|
||||||
|
prefill_drain_tx,
|
||||||
retry_config,
|
retry_config,
|
||||||
circuit_breaker_config: core_cb_config,
|
circuit_breaker_config: core_cb_config,
|
||||||
_prefill_health_checker: Some(prefill_health_checker),
|
_prefill_health_checker: Some(prefill_health_checker),
|
||||||
@@ -702,11 +775,9 @@ impl PDRouter {
|
|||||||
.execute_dual_dispatch_internal(
|
.execute_dual_dispatch_internal(
|
||||||
headers,
|
headers,
|
||||||
json_request,
|
json_request,
|
||||||
context.route,
|
context,
|
||||||
prefill.as_ref(),
|
prefill.as_ref(),
|
||||||
decode.as_ref(),
|
decode.as_ref(),
|
||||||
context.is_stream,
|
|
||||||
context.return_logprob,
|
|
||||||
start_time,
|
start_time,
|
||||||
)
|
)
|
||||||
.await;
|
.await;
|
||||||
@@ -734,16 +805,13 @@ impl PDRouter {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Internal method that performs the actual dual dispatch (without retry logic)
|
// Internal method that performs the actual dual dispatch (without retry logic)
|
||||||
#[allow(clippy::too_many_arguments)]
|
|
||||||
async fn execute_dual_dispatch_internal(
|
async fn execute_dual_dispatch_internal(
|
||||||
&self,
|
&self,
|
||||||
headers: Option<&HeaderMap>,
|
headers: Option<&HeaderMap>,
|
||||||
json_request: Value,
|
json_request: Value,
|
||||||
route: &str,
|
context: PDRequestContext,
|
||||||
prefill: &dyn Worker,
|
prefill: &dyn Worker,
|
||||||
decode: &dyn Worker,
|
decode: &dyn Worker,
|
||||||
is_stream: bool,
|
|
||||||
return_logprob: bool,
|
|
||||||
start_time: Instant,
|
start_time: Instant,
|
||||||
) -> Response {
|
) -> Response {
|
||||||
// Update load tracking for both workers
|
// Update load tracking for both workers
|
||||||
@@ -753,7 +821,7 @@ impl PDRouter {
|
|||||||
let decode_request = self.build_post_with_headers(
|
let decode_request = self.build_post_with_headers(
|
||||||
&self.client,
|
&self.client,
|
||||||
decode.url(),
|
decode.url(),
|
||||||
route,
|
context.route,
|
||||||
&json_request,
|
&json_request,
|
||||||
headers,
|
headers,
|
||||||
false,
|
false,
|
||||||
@@ -766,12 +834,12 @@ impl PDRouter {
|
|||||||
decode.url()
|
decode.url()
|
||||||
);
|
);
|
||||||
|
|
||||||
if return_logprob {
|
if context.return_logprob {
|
||||||
// Build prefill request with shared client when we need response body
|
// Build prefill request with shared client when we need response body
|
||||||
let prefill_request = self.build_post_with_headers(
|
let prefill_request = self.build_post_with_headers(
|
||||||
&self.client,
|
&self.client,
|
||||||
prefill.url(),
|
prefill.url(),
|
||||||
route,
|
context.route,
|
||||||
&json_request,
|
&json_request,
|
||||||
headers,
|
headers,
|
||||||
false,
|
false,
|
||||||
@@ -783,8 +851,8 @@ impl PDRouter {
|
|||||||
|
|
||||||
// Update metrics
|
// Update metrics
|
||||||
let duration = start_time.elapsed();
|
let duration = start_time.elapsed();
|
||||||
RouterMetrics::record_pd_request_duration(route, duration);
|
RouterMetrics::record_pd_request_duration(context.route, duration);
|
||||||
RouterMetrics::record_pd_request(route);
|
RouterMetrics::record_pd_request(context.route);
|
||||||
RouterMetrics::record_pd_prefill_request(prefill.url());
|
RouterMetrics::record_pd_prefill_request(prefill.url());
|
||||||
RouterMetrics::record_pd_decode_request(decode.url());
|
RouterMetrics::record_pd_decode_request(decode.url());
|
||||||
|
|
||||||
@@ -818,14 +886,18 @@ impl PDRouter {
|
|||||||
|
|
||||||
// Process prefill response for logprobs
|
// Process prefill response for logprobs
|
||||||
let prefill_body = match self
|
let prefill_body = match self
|
||||||
.process_prefill_response(prefill_result, prefill.url(), return_logprob)
|
.process_prefill_response(
|
||||||
|
prefill_result,
|
||||||
|
prefill.url(),
|
||||||
|
context.return_logprob,
|
||||||
|
)
|
||||||
.await
|
.await
|
||||||
{
|
{
|
||||||
Ok((_, body)) => body,
|
Ok((_, body)) => body,
|
||||||
Err(error_response) => return error_response,
|
Err(error_response) => return error_response,
|
||||||
};
|
};
|
||||||
|
|
||||||
if is_stream {
|
if context.is_stream {
|
||||||
// Streaming response with logprobs
|
// Streaming response with logprobs
|
||||||
let prefill_logprobs = prefill_body
|
let prefill_logprobs = prefill_body
|
||||||
.as_ref()
|
.as_ref()
|
||||||
@@ -841,7 +913,7 @@ impl PDRouter {
|
|||||||
res.bytes_stream(),
|
res.bytes_stream(),
|
||||||
status,
|
status,
|
||||||
prefill_logprobs,
|
prefill_logprobs,
|
||||||
return_logprob,
|
context.return_logprob,
|
||||||
None,
|
None,
|
||||||
Some(response_headers),
|
Some(response_headers),
|
||||||
)
|
)
|
||||||
@@ -850,7 +922,7 @@ impl PDRouter {
|
|||||||
self.process_non_streaming_response(
|
self.process_non_streaming_response(
|
||||||
res,
|
res,
|
||||||
status,
|
status,
|
||||||
return_logprob,
|
context.return_logprob,
|
||||||
prefill_body,
|
prefill_body,
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
@@ -878,7 +950,7 @@ impl PDRouter {
|
|||||||
.build_post_with_headers(
|
.build_post_with_headers(
|
||||||
&self.prefill_client,
|
&self.prefill_client,
|
||||||
prefill.url(),
|
prefill.url(),
|
||||||
route,
|
context.route,
|
||||||
&json_request,
|
&json_request,
|
||||||
headers,
|
headers,
|
||||||
true,
|
true,
|
||||||
@@ -886,11 +958,41 @@ impl PDRouter {
|
|||||||
.send();
|
.send();
|
||||||
let decode_future = decode_request.send();
|
let decode_future = decode_request.send();
|
||||||
|
|
||||||
|
// Send prefill response to background worker for draining
|
||||||
|
// This ensures HTTP compliance without blocking
|
||||||
|
let drain_tx = self.prefill_drain_tx.clone();
|
||||||
|
let prefill_url = prefill.url().to_string();
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
if let Ok(response) = prefill_future.await {
|
if let Ok(response) = prefill_future.await {
|
||||||
// Consume the entire response body to maintain HTTP compliance
|
// Try to send to drain worker
|
||||||
// This runs in the background and won't block the decode response
|
// If channel is full (under extreme load), drain inline as fallback
|
||||||
let _ = response.bytes().await;
|
match drain_tx.try_send(response) {
|
||||||
|
Ok(_) => {
|
||||||
|
// Successfully queued for draining
|
||||||
|
debug!("Prefill response queued for draining");
|
||||||
|
}
|
||||||
|
Err(mpsc::error::TrySendError::Full(response)) => {
|
||||||
|
// Channel full - drain inline as fallback
|
||||||
|
warn!("Prefill drain channel full (capacity exceeded), draining inline for {}", prefill_url);
|
||||||
|
RouterMetrics::record_pd_prefill_error(&prefill_url);
|
||||||
|
|
||||||
|
// Drain inline with timeout to prevent blocking too long
|
||||||
|
let drain_future = async {
|
||||||
|
let mut stream = response.bytes_stream();
|
||||||
|
while stream.next().await.is_some() {
|
||||||
|
// Just drain
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
match tokio::time::timeout(Duration::from_secs(1), drain_future).await {
|
||||||
|
Ok(_) => debug!("Inline drain completed for {}", prefill_url),
|
||||||
|
Err(_) => error!("Inline drain timeout for {}", prefill_url),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(mpsc::error::TrySendError::Closed(_)) => {
|
||||||
|
error!("Prefill drain channel closed!");
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
@@ -900,8 +1002,8 @@ impl PDRouter {
|
|||||||
|
|
||||||
// Update metrics
|
// Update metrics
|
||||||
let duration = start_time.elapsed();
|
let duration = start_time.elapsed();
|
||||||
RouterMetrics::record_pd_request_duration(route, duration);
|
RouterMetrics::record_pd_request_duration(context.route, duration);
|
||||||
RouterMetrics::record_pd_request(route);
|
RouterMetrics::record_pd_request(context.route);
|
||||||
RouterMetrics::record_pd_prefill_request(prefill.url());
|
RouterMetrics::record_pd_prefill_request(prefill.url());
|
||||||
RouterMetrics::record_pd_decode_request(decode.url());
|
RouterMetrics::record_pd_decode_request(decode.url());
|
||||||
|
|
||||||
@@ -928,7 +1030,7 @@ impl PDRouter {
|
|||||||
(status, format!("Decode server error: {}", e)).into_response()
|
(status, format!("Decode server error: {}", e)).into_response()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else if is_stream {
|
} else if context.is_stream {
|
||||||
// Streaming response without logprobs - direct passthrough
|
// Streaming response without logprobs - direct passthrough
|
||||||
let decode_url = decode.url().to_string();
|
let decode_url = decode.url().to_string();
|
||||||
let response_headers =
|
let response_headers =
|
||||||
@@ -1280,10 +1382,10 @@ impl PDRouter {
|
|||||||
|
|
||||||
fn build_post_with_headers(
|
fn build_post_with_headers(
|
||||||
&self,
|
&self,
|
||||||
client: &reqwest::Client,
|
client: &Client,
|
||||||
url: &str,
|
url: &str,
|
||||||
route: &str,
|
route: &str,
|
||||||
json_request: &serde_json::Value,
|
json_request: &Value,
|
||||||
headers: Option<&HeaderMap>,
|
headers: Option<&HeaderMap>,
|
||||||
connection_close: bool,
|
connection_close: bool,
|
||||||
) -> reqwest::RequestBuilder {
|
) -> reqwest::RequestBuilder {
|
||||||
@@ -1894,6 +1996,7 @@ mod tests {
|
|||||||
load_monitor_handle: None,
|
load_monitor_handle: None,
|
||||||
client: Client::new(),
|
client: Client::new(),
|
||||||
prefill_client: Client::new(),
|
prefill_client: Client::new(),
|
||||||
|
prefill_drain_tx: mpsc::channel(100).0,
|
||||||
retry_config: RetryConfig::default(),
|
retry_config: RetryConfig::default(),
|
||||||
circuit_breaker_config: CircuitBreakerConfig::default(),
|
circuit_breaker_config: CircuitBreakerConfig::default(),
|
||||||
_prefill_health_checker: None,
|
_prefill_health_checker: None,
|
||||||
|
|||||||
Reference in New Issue
Block a user