From 0f8cee8cd3732de680fda61259cc0289ca551f0b Mon Sep 17 00:00:00 2001 From: Simo Lin Date: Thu, 21 Aug 2025 22:48:29 -0700 Subject: [PATCH] [router] fix router load guard tracking for streaming (#9491) --- sgl-router/src/routers/pd_router.rs | 128 +++++++++++++++++++++++++++- 1 file changed, 124 insertions(+), 4 deletions(-) diff --git a/sgl-router/src/routers/pd_router.rs b/sgl-router/src/routers/pd_router.rs index cba55c5cd..a3e749f93 100644 --- a/sgl-router/src/routers/pd_router.rs +++ b/sgl-router/src/routers/pd_router.rs @@ -821,8 +821,13 @@ impl PDRouter { decode: &dyn Worker, start_time: Instant, ) -> Response { - // Update load tracking for both workers - let _guard = WorkerLoadGuard::new_multi(vec![prefill, decode]); + // For non-streaming: use guard for automatic load management + // For streaming: load will be managed in create_streaming_response + let _guard = if !context.is_stream { + Some(WorkerLoadGuard::new_multi(vec![prefill, decode])) + } else { + None + }; // Build decode request with shared client let decode_request = self.build_post_with_headers( @@ -916,13 +921,15 @@ impl PDRouter { let response_headers = header_utils::preserve_response_headers(res.headers()); - Self::create_streaming_response( + self.create_streaming_response( res.bytes_stream(), status, prefill_logprobs, context.return_logprob, None, Some(response_headers), + prefill, + decode, ) } else { // Non-streaming response with logprobs @@ -1043,13 +1050,15 @@ impl PDRouter { let response_headers = header_utils::preserve_response_headers(res.headers()); - Self::create_streaming_response( + self.create_streaming_response( res.bytes_stream(), status, None, false, Some(decode_url), Some(response_headers), + prefill, + decode, ) } else { // Non-streaming response without logprobs - direct passthrough like fast version @@ -1210,16 +1219,32 @@ impl PDRouter { } // Helper to create a streaming response + #[allow(clippy::too_many_arguments)] fn create_streaming_response( + &self, stream: impl futures_util::Stream> + Send + 'static, status: StatusCode, prefill_logprobs: Option, return_logprob: bool, decode_url: Option, headers: Option, + prefill: &dyn Worker, + decode: &dyn Worker, ) -> Response { + // For streaming, increment load now - will be decremented when streaming completes + prefill.increment_load(); + decode.increment_load(); + + // Store URLs to find workers later for decrementing + let prefill_url = prefill.url().to_string(); + let decode_url_str = decode.url().to_string(); + let (tx, rx) = tokio::sync::mpsc::unbounded_channel(); + // Clone the worker collections for the spawned task + let prefill_workers = self.prefill_workers.clone(); + let decode_workers = self.decode_workers.clone(); + tokio::spawn(async move { futures_util::pin_mut!(stream); while let Some(chunk_result) = stream.next().await { @@ -1247,6 +1272,25 @@ impl PDRouter { } } } + + // Decrement load after streaming is complete + if let Ok(prefill_workers_guard) = prefill_workers.read() { + for worker in prefill_workers_guard.iter() { + if worker.url() == prefill_url.as_str() { + worker.decrement_load(); + break; + } + } + } + + if let Ok(decode_workers_guard) = decode_workers.read() { + for worker in decode_workers_guard.iter() { + if worker.url() == decode_url_str.as_str() { + worker.decrement_load(); + break; + } + } + } }); let stream = UnboundedReceiverStream::new(rx); @@ -2279,6 +2323,82 @@ mod tests { assert_eq!(decode_worker.load(), 0); } + #[tokio::test] + async fn test_streaming_load_tracking() { + use futures_util::StreamExt; + use tokio::time::{sleep, Duration}; + + let router = create_test_pd_router(); + + // Add workers + let prefill_worker = create_test_worker( + "http://prefill".to_string(), + WorkerType::Prefill { + bootstrap_port: None, + }, + true, + ); + let decode_worker = + create_test_worker("http://decode".to_string(), WorkerType::Decode, true); + + router.prefill_workers.write().unwrap().push(prefill_worker); + router.decode_workers.write().unwrap().push(decode_worker); + + // Get references to the workers - clone to avoid holding lock across await + let (prefill_ref, decode_ref) = { + let workers = router.prefill_workers.read().unwrap(); + let prefill = workers[0].clone_worker(); + drop(workers); + let workers = router.decode_workers.read().unwrap(); + let decode = workers[0].clone_worker(); + (prefill, decode) + }; + + // Initially load should be 0 + assert_eq!(prefill_ref.load(), 0); + assert_eq!(decode_ref.load(), 0); + + // Create a mock streaming response + let (tx, rx) = tokio::sync::mpsc::unbounded_channel(); + let stream = tokio_stream::wrappers::UnboundedReceiverStream::new(rx); + + // Call create_streaming_response which should increment load + let _response = router.create_streaming_response( + stream.map(Ok), + StatusCode::OK, + None, + false, + None, + None, + prefill_ref.as_ref(), + decode_ref.as_ref(), + ); + + // Load should be incremented immediately + assert_eq!(prefill_ref.load(), 1); + assert_eq!(decode_ref.load(), 1); + + // Send some data through the stream + tx.send(bytes::Bytes::from("test data")).unwrap(); + + // Give time for the spawned task to process + sleep(Duration::from_millis(10)).await; + + // Load should still be 1 (streaming in progress) + assert_eq!(prefill_ref.load(), 1); + assert_eq!(decode_ref.load(), 1); + + // Close the stream + drop(tx); + + // Give time for cleanup + sleep(Duration::from_millis(100)).await; + + // Load should be decremented after streaming completes + assert_eq!(prefill_ref.load(), 0); + assert_eq!(decode_ref.load(), 0); + } + // ============= Concurrent Operations Tests ============= #[tokio::test]