[router] address worker load tracking consistency (#9523)
Co-authored-by: fzyzcjy <5236035+fzyzcjy@users.noreply.github.com>
This commit is contained in:
@@ -55,6 +55,12 @@ pub trait Worker: Send + Sync + fmt::Debug {
|
|||||||
/// Decrement the load counter
|
/// Decrement the load counter
|
||||||
fn decrement_load(&self);
|
fn decrement_load(&self);
|
||||||
|
|
||||||
|
/// Reset the load counter to 0 (for sync/recovery)
|
||||||
|
fn reset_load(&self) {
|
||||||
|
// Default implementation - does nothing
|
||||||
|
// Workers that track load should override this
|
||||||
|
}
|
||||||
|
|
||||||
/// Get the number of processed requests
|
/// Get the number of processed requests
|
||||||
fn processed_requests(&self) -> usize;
|
fn processed_requests(&self) -> usize;
|
||||||
|
|
||||||
@@ -364,6 +370,10 @@ impl Worker for BasicWorker {
|
|||||||
.ok();
|
.ok();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn reset_load(&self) {
|
||||||
|
self.load_counter.store(0, Ordering::Relaxed);
|
||||||
|
}
|
||||||
|
|
||||||
fn processed_requests(&self) -> usize {
|
fn processed_requests(&self) -> usize {
|
||||||
self.processed_counter.load(Ordering::Relaxed)
|
self.processed_counter.load(Ordering::Relaxed)
|
||||||
}
|
}
|
||||||
@@ -449,6 +459,10 @@ impl Worker for DPAwareWorker {
|
|||||||
self.base_worker.decrement_load();
|
self.base_worker.decrement_load();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn reset_load(&self) {
|
||||||
|
self.base_worker.reset_load();
|
||||||
|
}
|
||||||
|
|
||||||
fn processed_requests(&self) -> usize {
|
fn processed_requests(&self) -> usize {
|
||||||
self.base_worker.processed_requests()
|
self.base_worker.processed_requests()
|
||||||
}
|
}
|
||||||
@@ -825,6 +839,10 @@ pub fn start_health_checker(
|
|||||||
let mut interval =
|
let mut interval =
|
||||||
tokio::time::interval(tokio::time::Duration::from_secs(check_interval_secs));
|
tokio::time::interval(tokio::time::Duration::from_secs(check_interval_secs));
|
||||||
|
|
||||||
|
// Counter for periodic load reset (every 10 health check cycles)
|
||||||
|
let mut check_count = 0u64;
|
||||||
|
const LOAD_RESET_INTERVAL: u64 = 10;
|
||||||
|
|
||||||
loop {
|
loop {
|
||||||
interval.tick().await;
|
interval.tick().await;
|
||||||
|
|
||||||
@@ -834,6 +852,8 @@ pub fn start_health_checker(
|
|||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
check_count += 1;
|
||||||
|
|
||||||
// Check health of all workers
|
// Check health of all workers
|
||||||
let workers_to_check = match workers.read() {
|
let workers_to_check = match workers.read() {
|
||||||
Ok(guard) => guard.iter().map(|w| w.clone_worker()).collect::<Vec<_>>(),
|
Ok(guard) => guard.iter().map(|w| w.clone_worker()).collect::<Vec<_>>(),
|
||||||
@@ -843,6 +863,22 @@ pub fn start_health_checker(
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Periodically reset load counters to prevent drift
|
||||||
|
// Only do this when we believe all workers should be idle
|
||||||
|
if check_count.is_multiple_of(LOAD_RESET_INTERVAL) {
|
||||||
|
let max_load = workers_to_check.iter().map(|w| w.load()).max().unwrap_or(0);
|
||||||
|
// Only reset if load appears to be very low (likely drift)
|
||||||
|
if max_load <= 2 {
|
||||||
|
tracing::debug!(
|
||||||
|
"Resetting load counters to prevent drift (max_load: {})",
|
||||||
|
max_load
|
||||||
|
);
|
||||||
|
for worker in &workers_to_check {
|
||||||
|
worker.reset_load();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Perform health checks concurrently
|
// Perform health checks concurrently
|
||||||
let health_checks = workers_to_check.iter().map(|worker| {
|
let health_checks = workers_to_check.iter().map(|worker| {
|
||||||
let worker_url = worker.url().to_string();
|
let worker_url = worker.url().to_string();
|
||||||
|
|||||||
@@ -1243,10 +1243,19 @@ impl PDRouter {
|
|||||||
let decode_workers = self.decode_workers.clone();
|
let decode_workers = self.decode_workers.clone();
|
||||||
|
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
|
// Use a flag to track whether stream completed successfully
|
||||||
|
let mut stream_completed = false;
|
||||||
|
|
||||||
futures_util::pin_mut!(stream);
|
futures_util::pin_mut!(stream);
|
||||||
while let Some(chunk_result) = stream.next().await {
|
while let Some(chunk_result) = stream.next().await {
|
||||||
match chunk_result {
|
match chunk_result {
|
||||||
Ok(chunk) => {
|
Ok(chunk) => {
|
||||||
|
// Check for stream end marker to decrement load early
|
||||||
|
let is_done = chunk
|
||||||
|
.as_ref()
|
||||||
|
.windows(12)
|
||||||
|
.any(|window| window == b"data: [DONE]");
|
||||||
|
|
||||||
let result = if return_logprob && prefill_logprobs.is_some() {
|
let result = if return_logprob && prefill_logprobs.is_some() {
|
||||||
// Try to merge logprobs
|
// Try to merge logprobs
|
||||||
Self::merge_streaming_logprobs(prefill_logprobs.clone(), &chunk)
|
Self::merge_streaming_logprobs(prefill_logprobs.clone(), &chunk)
|
||||||
@@ -1258,6 +1267,12 @@ impl PDRouter {
|
|||||||
if tx.send(Ok(result)).is_err() {
|
if tx.send(Ok(result)).is_err() {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// If we see the done marker, decrement load immediately
|
||||||
|
if is_done {
|
||||||
|
stream_completed = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
if let Some(ref url) = decode_url {
|
if let Some(ref url) = decode_url {
|
||||||
@@ -1270,20 +1285,30 @@ impl PDRouter {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Decrement load after streaming is complete
|
// Always decrement load after streaming (either completes or errors)
|
||||||
|
// Find and decrement prefill worker
|
||||||
if let Ok(prefill_workers_guard) = prefill_workers.read() {
|
if let Ok(prefill_workers_guard) = prefill_workers.read() {
|
||||||
for worker in prefill_workers_guard.iter() {
|
for worker in prefill_workers_guard.iter() {
|
||||||
if worker.url() == prefill_url.as_str() {
|
if worker.url() == prefill_url.as_str() {
|
||||||
worker.decrement_load();
|
worker.decrement_load();
|
||||||
|
debug!(
|
||||||
|
"Decremented load for prefill worker: {} (stream_completed: {})",
|
||||||
|
prefill_url, stream_completed
|
||||||
|
);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Find and decrement decode worker
|
||||||
if let Ok(decode_workers_guard) = decode_workers.read() {
|
if let Ok(decode_workers_guard) = decode_workers.read() {
|
||||||
for worker in decode_workers_guard.iter() {
|
for worker in decode_workers_guard.iter() {
|
||||||
if worker.url() == decode_url_str.as_str() {
|
if worker.url() == decode_url_str.as_str() {
|
||||||
worker.decrement_load();
|
worker.decrement_load();
|
||||||
|
debug!(
|
||||||
|
"Decremented load for decode worker: {} (stream_completed: {})",
|
||||||
|
decode_url_str, stream_completed
|
||||||
|
);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -490,6 +490,13 @@ impl Router {
|
|||||||
false
|
false
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Keep a clone for potential cleanup on retry
|
||||||
|
let worker_for_cleanup = if load_incremented {
|
||||||
|
Some(worker.clone_worker())
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
|
||||||
let response = self
|
let response = self
|
||||||
.send_typed_request(
|
.send_typed_request(
|
||||||
headers,
|
headers,
|
||||||
@@ -502,6 +509,19 @@ impl Router {
|
|||||||
.await;
|
.await;
|
||||||
|
|
||||||
worker.record_outcome(response.status().is_success());
|
worker.record_outcome(response.status().is_success());
|
||||||
|
|
||||||
|
// For retryable failures, we need to decrement load since send_typed_request
|
||||||
|
// won't have done it (it only decrements on success or non-retryable failures)
|
||||||
|
if is_retryable_status(response.status()) && load_incremented {
|
||||||
|
if let Some(cleanup_worker) = worker_for_cleanup {
|
||||||
|
cleanup_worker.decrement_load();
|
||||||
|
RouterMetrics::set_running_requests(
|
||||||
|
cleanup_worker.url(),
|
||||||
|
cleanup_worker.load(),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
response
|
response
|
||||||
},
|
},
|
||||||
// should_retry predicate
|
// should_retry predicate
|
||||||
@@ -657,13 +677,25 @@ impl Router {
|
|||||||
response
|
response
|
||||||
}
|
}
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
|
// IMPORTANT: Decrement load on error before returning
|
||||||
|
if load_incremented {
|
||||||
|
if let Ok(workers_guard) = self.workers.read() {
|
||||||
|
if let Some(worker) =
|
||||||
|
workers_guard.iter().find(|w| w.url() == worker_url)
|
||||||
|
{
|
||||||
|
worker.decrement_load();
|
||||||
|
RouterMetrics::set_running_requests(worker_url, worker.load());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
let error_msg = format!("Failed to get response body: {}", e);
|
let error_msg = format!("Failed to get response body: {}", e);
|
||||||
(StatusCode::INTERNAL_SERVER_ERROR, error_msg).into_response()
|
(StatusCode::INTERNAL_SERVER_ERROR, error_msg).into_response()
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
// Decrement load counter for non-streaming requests if it was incremented
|
// Decrement load counter for non-streaming requests if it was incremented
|
||||||
if load_incremented && !is_stream {
|
if load_incremented {
|
||||||
if let Ok(workers_guard) = self.workers.read() {
|
if let Ok(workers_guard) = self.workers.read() {
|
||||||
if let Some(worker) = workers_guard.iter().find(|w| w.url() == worker_url) {
|
if let Some(worker) = workers_guard.iter().find(|w| w.url() == worker_url) {
|
||||||
worker.decrement_load();
|
worker.decrement_load();
|
||||||
|
|||||||
Reference in New Issue
Block a user