diff --git a/sgl-router/src/core/worker.rs b/sgl-router/src/core/worker.rs
index 2466d00b0..f3039ae21 100644
--- a/sgl-router/src/core/worker.rs
+++ b/sgl-router/src/core/worker.rs
@@ -55,6 +55,12 @@ pub trait Worker: Send + Sync + fmt::Debug {
     /// Decrement the load counter
     fn decrement_load(&self);
 
+    /// Reset the load counter to 0 (for sync/recovery)
+    fn reset_load(&self) {
+        // Default implementation - does nothing
+        // Workers that track load should override this
+    }
+
     /// Get the number of processed requests
     fn processed_requests(&self) -> usize;
 
@@ -364,6 +370,10 @@ impl Worker for BasicWorker {
             .ok();
     }
 
+    fn reset_load(&self) {
+        self.load_counter.store(0, Ordering::Relaxed);
+    }
+
     fn processed_requests(&self) -> usize {
         self.processed_counter.load(Ordering::Relaxed)
     }
@@ -449,6 +459,10 @@ impl Worker for DPAwareWorker {
         self.base_worker.decrement_load();
     }
 
+    fn reset_load(&self) {
+        self.base_worker.reset_load();
+    }
+
     fn processed_requests(&self) -> usize {
         self.base_worker.processed_requests()
     }
@@ -825,6 +839,10 @@ pub fn start_health_checker(
         let mut interval =
             tokio::time::interval(tokio::time::Duration::from_secs(check_interval_secs));
 
+        // Counter for periodic load reset (every 10 health check cycles)
+        let mut check_count = 0u64;
+        const LOAD_RESET_INTERVAL: u64 = 10;
+
         loop {
             interval.tick().await;
 
@@ -834,6 +852,8 @@
                 break;
             }
 
+            check_count += 1;
+
             // Check health of all workers
             let workers_to_check = match workers.read() {
                 Ok(guard) => guard.iter().map(|w| w.clone_worker()).collect::<Vec<_>>(),
@@ -843,6 +863,22 @@
                 }
             };
 
+            // Periodically reset load counters to prevent drift
+            // Only do this when we believe all workers should be idle
+            if check_count.is_multiple_of(LOAD_RESET_INTERVAL) {
+                let max_load = workers_to_check.iter().map(|w| w.load()).max().unwrap_or(0);
+                // Only reset if load appears to be very low (likely drift)
+                if max_load <= 2 {
+                    tracing::debug!(
+                        "Resetting load counters to prevent drift (max_load: {})",
+                        max_load
+                    );
+                    for worker in &workers_to_check {
+                        worker.reset_load();
+                    }
+                }
+            }
+
             // Perform health checks concurrently
             let health_checks = workers_to_check.iter().map(|worker| {
                 let worker_url = worker.url().to_string();
diff --git a/sgl-router/src/routers/pd_router.rs b/sgl-router/src/routers/pd_router.rs
index 3511582f0..42fd54598 100644
--- a/sgl-router/src/routers/pd_router.rs
+++ b/sgl-router/src/routers/pd_router.rs
@@ -1243,10 +1243,19 @@ impl PDRouter {
         let decode_workers = self.decode_workers.clone();
 
         tokio::spawn(async move {
+            // Use a flag to track whether stream completed successfully
+            let mut stream_completed = false;
+
             futures_util::pin_mut!(stream);
             while let Some(chunk_result) = stream.next().await {
                 match chunk_result {
                     Ok(chunk) => {
+                        // Check for stream end marker to decrement load early
+                        let is_done = chunk
+                            .as_ref()
+                            .windows(12)
+                            .any(|window| window == b"data: [DONE]");
+
                         let result = if return_logprob && prefill_logprobs.is_some() {
                             // Try to merge logprobs
                             Self::merge_streaming_logprobs(prefill_logprobs.clone(), &chunk)
@@ -1258,6 +1267,12 @@
                         if tx.send(Ok(result)).is_err() {
                             break;
                         }
+
+                        // If we see the done marker, decrement load immediately
+                        if is_done {
+                            stream_completed = true;
+                            break;
+                        }
                     }
                     Err(e) => {
                         if let Some(ref url) = decode_url {
@@ -1270,20 +1285,30 @@
                }
            }
 
-            // Decrement load after streaming is complete
+            // Always decrement load after streaming (either completes or errors)
+            // Find and decrement prefill worker
            if let Ok(prefill_workers_guard) = prefill_workers.read() {
                for worker in prefill_workers_guard.iter() {
                    if worker.url() == prefill_url.as_str() {
                        worker.decrement_load();
+                        debug!(
+                            "Decremented load for prefill worker: {} (stream_completed: {})",
+                            prefill_url, stream_completed
+                        );
                        break;
                    }
                }
            }
 
+            // Find and decrement decode worker
            if let Ok(decode_workers_guard) = decode_workers.read() {
                for worker in decode_workers_guard.iter() {
                    if worker.url() == decode_url_str.as_str() {
                        worker.decrement_load();
+                        debug!(
+                            "Decremented load for decode worker: {} (stream_completed: {})",
+                            decode_url_str, stream_completed
+                        );
                        break;
                    }
                }
diff --git a/sgl-router/src/routers/router.rs b/sgl-router/src/routers/router.rs
index 00dbe32dc..077ad6d4f 100644
--- a/sgl-router/src/routers/router.rs
+++ b/sgl-router/src/routers/router.rs
@@ -490,6 +490,13 @@ impl Router {
                    false
                };
 
+                // Keep a clone for potential cleanup on retry
+                let worker_for_cleanup = if load_incremented {
+                    Some(worker.clone_worker())
+                } else {
+                    None
+                };
+
                let response = self
                    .send_typed_request(
                        headers,
@@ -502,6 +509,19 @@
                    .await;
 
                worker.record_outcome(response.status().is_success());
+
+                // For retryable failures, we need to decrement load since send_typed_request
+                // won't have done it (it only decrements on success or non-retryable failures)
+                if is_retryable_status(response.status()) && load_incremented {
+                    if let Some(cleanup_worker) = worker_for_cleanup {
+                        cleanup_worker.decrement_load();
+                        RouterMetrics::set_running_requests(
+                            cleanup_worker.url(),
+                            cleanup_worker.load(),
+                        );
+                    }
+                }
+
                response
            },
            // should_retry predicate
@@ -657,13 +677,25 @@
                response
            }
            Err(e) => {
+                // IMPORTANT: Decrement load on error before returning
+                if load_incremented {
+                    if let Ok(workers_guard) = self.workers.read() {
+                        if let Some(worker) =
+                            workers_guard.iter().find(|w| w.url() == worker_url)
+                        {
+                            worker.decrement_load();
+                            RouterMetrics::set_running_requests(worker_url, worker.load());
+                        }
+                    }
+                }
+
                let error_msg = format!("Failed to get response body: {}", e);
                (StatusCode::INTERNAL_SERVER_ERROR, error_msg).into_response()
            }
        };
 
        // Decrement load counter for non-streaming requests if it was incremented
-        if load_incremented && !is_stream {
+        if load_incremented {
            if let Ok(workers_guard) = self.workers.read() {
                if let Some(worker) = workers_guard.iter().find(|w| w.url() == worker_url) {
                    worker.decrement_load();
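Reviewer note, not part of the patch: the `windows(12)` scan added in pd_router.rs works because the
SSE end marker b"data: [DONE]" is exactly 12 bytes long. Below is a minimal standalone sketch of the
same check; `contains_done_marker` is a hypothetical helper name used only for illustration, not
something the router defines.

fn contains_done_marker(chunk: &[u8]) -> bool {
    // The SSE end-of-stream marker the patch looks for; 12 bytes long.
    const MARKER: &[u8] = b"data: [DONE]";
    // Slide a MARKER-sized window over the chunk and compare byte-for-byte,
    // mirroring the `.windows(12).any(...)` pattern in the patch.
    chunk.windows(MARKER.len()).any(|window| window == MARKER)
}

fn main() {
    assert_eq!(b"data: [DONE]".len(), 12);
    assert!(contains_done_marker(b"data: {\"text\":\"hi\"}\n\ndata: [DONE]\n\n"));
    assert!(!contains_done_marker(b"data: {\"text\":\"partial\"}\n\n"));
}

Using MARKER.len() instead of a hard-coded 12 keeps the window size and the marker from drifting
apart if the marker ever changes; the patch's literal 12 is equivalent today.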