diff --git a/.github/workflows/pr-test-pd-router.yml b/.github/workflows/pr-test-pd-router.yml index 28584f28a..45e210b68 100644 --- a/.github/workflows/pr-test-pd-router.yml +++ b/.github/workflows/pr-test-pd-router.yml @@ -219,6 +219,7 @@ jobs: --decode http://127.0.0.7:30007 \ --decode http://127.0.0.8:30008 \ --host 127.0.0.9 \ + --log-level warning \ --port 8000 & ROUTER_PID=$! @@ -300,8 +301,8 @@ jobs: --task text-to-text \ --num-concurrency 64 \ --traffic-scenario "D(8000,2000)" \ - --max-requests-per-run 640 \ - --max-time-per-run 2 \ + --max-requests-per-run 1000 \ + --max-time-per-run 5 \ --experiment-folder-name "benchmark_${policy}" \ --experiment-base-dir "." @@ -341,7 +342,7 @@ jobs: # These can be adjusted based on your performance requirements ttft_threshold=4.7 # Max 4.7 seconds for mean TTFT e2e_latency_threshold=35.0 # Max 35.0 seconds for mean E2E latency - input_throughput_threshold=12000 # Min 12000 tokens/s for mean input throughput + input_throughput_threshold=10000 # Min 10000 tokens/s for mean input throughput output_throughput_threshold=68 # Min 68 tokens/s for mean output throughput @@ -558,12 +559,12 @@ jobs: # Check thresholds (using same values as in main workflow) validation_status="✅" if [ "$ttft" != "N/A" ] && [ "$ttft" != "null" ]; then - if (( $(echo "$ttft > 2.0" | bc -l 2>/dev/null || echo "0") )); then + if (( $(echo "$ttft > 4.7" | bc -l 2>/dev/null || echo "0") )); then validation_status="❌" fi fi if [ "$e2e_latency" != "N/A" ] && [ "$e2e_latency" != "null" ]; then - if (( $(echo "$e2e_latency > 24.0" | bc -l 2>/dev/null || echo "0") )); then + if (( $(echo "$e2e_latency > 35.0" | bc -l 2>/dev/null || echo "0") )); then validation_status="❌" fi fi @@ -573,7 +574,7 @@ jobs: fi fi if [ "$output_throughput" != "N/A" ] && [ "$output_throughput" != "null" ]; then - if (( $(echo "$output_throughput < 90" | bc -l 2>/dev/null || echo "0") )); then + if (( $(echo "$output_throughput < 68" | bc -l 2>/dev/null || echo "0") )); then 
validation_status="❌" fi fi diff --git a/sgl-router/src/routers/http/pd_router.rs b/sgl-router/src/routers/http/pd_router.rs index 4248ae060..23ff7ab57 100644 --- a/sgl-router/src/routers/http/pd_router.rs +++ b/sgl-router/src/routers/http/pd_router.rs @@ -27,7 +27,6 @@ use serde_json::{json, Value}; use std::collections::HashMap; use std::sync::Arc; use std::time::{Duration, Instant}; -use tokio::sync::mpsc; use tokio_stream::wrappers::UnboundedReceiverStream; use tracing::{debug, error, info, warn}; @@ -38,11 +37,9 @@ pub struct PDRouter { pub worker_loads: Arc>>, pub load_monitor_handle: Option>>, pub client: Client, - pub prefill_client: Client, pub retry_config: RetryConfig, pub api_key: Option, pub enable_igw: bool, - prefill_drain_tx: mpsc::Sender, } #[derive(Clone)] @@ -241,72 +238,7 @@ impl PDRouter { None }; - let prefill_client = Client::builder() - .pool_max_idle_per_host(0) - .http1_only() - .connect_timeout(Duration::from_millis(300)) - .timeout(Duration::from_secs(ctx.router_config.request_timeout_secs)) - .build() - .map_err(|e| format!("Failed to build prefill client: {}", e))?; - - let (prefill_drain_tx, mut prefill_drain_rx) = mpsc::channel::(2000); - - // TODO reevaluate a simpler approach (e.g. 
do we really need to deal with fire and forget) - tokio::spawn(async move { - info!("Prefill drain coordinator started"); - - let max_concurrent_drains = 100; - let semaphore = Arc::new(tokio::sync::Semaphore::new(max_concurrent_drains)); - - while let Some(response) = prefill_drain_rx.recv().await { - let permit = semaphore.clone().acquire_owned().await; - - match permit { - Ok(permit) => { - tokio::spawn(async move { - let url = response.url().to_string(); - let status = response.status(); - - if !status.is_success() { - error!("Prefill drain: error status={} url={}", status, url); - RouterMetrics::record_pd_prefill_error(&url); - } - - let start = Instant::now(); - let mut stream = response.bytes_stream(); - let mut bytes_drained = 0; - - while let Some(chunk_result) = stream.next().await { - match chunk_result { - Ok(chunk) => bytes_drained += chunk.len(), - Err(e) => { - debug!( - "Prefill drain: error streaming url={} error={}", - url, e - ); - break; - } - } - } - - let elapsed = start.elapsed(); - if elapsed > Duration::from_millis(100) { - debug!( - "Prefill drain: slow drain {} bytes from {} in {:?}", - bytes_drained, url, elapsed - ); - } - - drop(permit); - }); - } - Err(_) => { - break; - } - } - } - info!("Prefill drain coordinator shutting down"); - }); + // No longer need prefill drain channel - we'll wait for both responses Ok(PDRouter { worker_registry: Arc::clone(&ctx.worker_registry), @@ -314,8 +246,6 @@ impl PDRouter { worker_loads, load_monitor_handle, client: ctx.client.clone(), - prefill_client, - prefill_drain_tx, retry_config: ctx.router_config.effective_retry_config(), api_key: ctx.router_config.api_key.clone(), enable_igw: ctx.router_config.enable_igw, @@ -585,7 +515,15 @@ impl PDRouter { None }; - // Build decode request with shared client + // Build both requests + let prefill_request = self.build_post_with_headers( + &self.client, + prefill.url(), + context.route, + &json_request, + headers, + false, + ); let decode_request = 
self.build_post_with_headers( &self.client, decode.url(), @@ -595,57 +533,46 @@ impl PDRouter { false, ); - // Send both requests concurrently + // Send both requests concurrently and wait for both debug!( "Sending concurrent requests to prefill={} decode={}", prefill.url(), decode.url() ); - if context.return_logprob { - // Build prefill request with shared client when we need response body - let prefill_request = self.build_post_with_headers( - &self.client, - prefill.url(), - context.route, - &json_request, - headers, - false, - ); - // When we need logprobs, wait for both responses - let (prefill_result, decode_result) = - tokio::join!(prefill_request.send(), decode_request.send()); - debug!("Received responses from both servers"); + let (prefill_result, decode_result) = + tokio::join!(prefill_request.send(), decode_request.send()); + debug!("Received responses from both servers"); - let duration = start_time.elapsed(); - RouterMetrics::record_pd_request_duration(context.route, duration); - RouterMetrics::record_pd_request(context.route); - RouterMetrics::record_pd_prefill_request(prefill.url()); - RouterMetrics::record_pd_decode_request(decode.url()); + let duration = start_time.elapsed(); + RouterMetrics::record_pd_request_duration(context.route, duration); + RouterMetrics::record_pd_request(context.route); + RouterMetrics::record_pd_prefill_request(prefill.url()); + RouterMetrics::record_pd_decode_request(decode.url()); - // Process decode response with prefill for logprobs - debug!("Processing decode response with logprobs"); - match decode_result { - Ok(res) => { - let status = StatusCode::from_u16(res.status().as_u16()) - .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR); - debug!("Decode response status: {}", status); + // Process decode response + match decode_result { + Ok(res) => { + let status = StatusCode::from_u16(res.status().as_u16()) + .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR); + debug!("Decode response status: {}", status); - if 
!status.is_success() { - RouterMetrics::record_pd_decode_error(decode.url()); - error!( - "Decode server returned error status decode_url={} status={}", - decode.url(), - status - ); + if !status.is_success() { + RouterMetrics::record_pd_decode_error(decode.url()); + error!( + "Decode server returned error status decode_url={} status={}", + decode.url(), + status + ); - return self - .handle_decode_error_response(res, &context, prefill, decode) - .await; - } + return self + .handle_decode_error_response(res, &context, prefill, decode) + .await; + } - // Process prefill response for logprobs - let prefill_body = match self + // Process prefill response + let prefill_body = if context.return_logprob { + match self .process_prefill_response( prefill_result, prefill.url(), @@ -655,32 +582,46 @@ impl PDRouter { { Ok((_, body)) => body, Err(error_response) => return error_response, - }; + } + } else { + // Even if we don't need logprobs, we should check prefill status + match self + .process_prefill_response(prefill_result, prefill.url(), false) + .await + { + Ok((_, body)) => body, + Err(error_response) => return error_response, + } + }; - if context.is_stream { - // Streaming response with logprobs - let prefill_logprobs = prefill_body + if context.is_stream { + // Streaming response + let prefill_logprobs = if context.return_logprob { + prefill_body .as_ref() .and_then(|body| serde_json::from_slice::(body).ok()) .and_then(|json| { json.pointer("/meta_info/input_token_logprobs").cloned() - }); - - let response_headers = - header_utils::preserve_response_headers(res.headers()); - - self.create_streaming_response( - res.bytes_stream(), - status, - prefill_logprobs, - context.return_logprob, - None, - Some(response_headers), - prefill, - decode, - ) + }) } else { - // Non-streaming response with logprobs + None + }; + + let response_headers = header_utils::preserve_response_headers(res.headers()); + + self.create_streaming_response( + res.bytes_stream(), + status, + 
prefill_logprobs, + context.return_logprob, + None, + Some(response_headers), + prefill, + decode, + ) + } else { + // Non-streaming response + if context.return_logprob { self.process_non_streaming_response( res, status, @@ -688,122 +629,8 @@ impl PDRouter { prefill_body, ) .await - } - } - Err(e) => { - error!( - decode_url = %decode.url(), - error = %e, - "Decode request failed" - ); - RouterMetrics::record_pd_decode_error(decode.url()); - ( - StatusCode::BAD_GATEWAY, - format!("Decode server error: {}", e), - ) - .into_response() - } - } - } else { - // When we don't need logprobs, only wait for decode response - // Send both requests concurrently but don't wait for prefill - // Use dedicated prefill client with Connection: close - let prefill_future = self - .build_post_with_headers( - &self.prefill_client, - prefill.url(), - context.route, - &json_request, - headers, - true, - ) - .send(); - let decode_future = decode_request.send(); - - // Send prefill response to background worker for draining - // This ensures HTTP compliance without blocking - let drain_tx = self.prefill_drain_tx.clone(); - let prefill_url = prefill.url().to_string(); - tokio::spawn(async move { - if let Ok(response) = prefill_future.await { - // Try to send to drain worker - // If channel is full (under extreme load), drain inline as fallback - match drain_tx.try_send(response) { - Ok(_) => { - // Successfully queued for draining - debug!("Prefill response queued for draining"); - } - Err(mpsc::error::TrySendError::Full(response)) => { - // Channel full - drain inline as fallback - warn!("Prefill drain channel full (capacity exceeded), draining inline for {}", prefill_url); - RouterMetrics::record_pd_prefill_error(&prefill_url); - - // Drain inline with timeout to prevent blocking too long - let drain_future = async { - let mut stream = response.bytes_stream(); - while stream.next().await.is_some() { - // Just drain - } - }; - - match tokio::time::timeout(Duration::from_secs(1), 
drain_future).await { - Ok(_) => debug!("Inline drain completed for {}", prefill_url), - Err(_) => error!("Inline drain timeout for {}", prefill_url), - } - } - Err(mpsc::error::TrySendError::Closed(_)) => { - error!("Prefill drain channel closed!"); - } - } - } - }); - - // Wait only for decode response - let decode_result = decode_future.await; - debug!("Received decode response"); - - let duration = start_time.elapsed(); - RouterMetrics::record_pd_request_duration(context.route, duration); - RouterMetrics::record_pd_request(context.route); - RouterMetrics::record_pd_prefill_request(prefill.url()); - RouterMetrics::record_pd_decode_request(decode.url()); - - // Process decode response immediately - debug!("Processing decode response (no logprobs)"); - match decode_result { - Ok(res) => { - let status = StatusCode::from_u16(res.status().as_u16()) - .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR); - debug!("Decode response status: {}", status); - - if !status.is_success() { - RouterMetrics::record_pd_decode_error(decode.url()); - error!( - "Decode server returned error status decode_url={} status={}", - decode.url(), - status - ); - - self.handle_decode_error_response(res, &context, prefill, decode) - .await - } else if context.is_stream { - // Streaming response without logprobs - direct passthrough - let decode_url = decode.url().to_string(); - let response_headers = - header_utils::preserve_response_headers(res.headers()); - - self.create_streaming_response( - res.bytes_stream(), - status, - None, - false, - Some(decode_url), - Some(response_headers), - prefill, - decode, - ) } else { - // Non-streaming response without logprobs - direct passthrough like fast version + // Direct passthrough when no logprobs needed let response_headers = header_utils::preserve_response_headers(res.headers()); @@ -823,19 +650,19 @@ impl PDRouter { } } } - Err(e) => { - error!( - decode_url = %decode.url(), - error = %e, - "Decode request failed" - ); - 
RouterMetrics::record_pd_decode_error(decode.url()); - ( - StatusCode::BAD_GATEWAY, - format!("Decode server error: {}", e), - ) - .into_response() - } + } + Err(e) => { + error!( + decode_url = %decode.url(), + error = %e, + "Decode request failed" + ); + RouterMetrics::record_pd_decode_error(decode.url()); + ( + StatusCode::BAD_GATEWAY, + format!("Decode server error: {}", e), + ) + .into_response() } } } @@ -1802,8 +1629,6 @@ mod tests { worker_loads: Arc::new(tokio::sync::watch::channel(HashMap::new()).1), load_monitor_handle: None, client: Client::new(), - prefill_client: Client::new(), - prefill_drain_tx: mpsc::channel(100).0, retry_config: RetryConfig::default(), api_key: Some("test_api_key".to_string()), enable_igw: false,