[router] consolidate worker load monitoring (#10894)

This commit is contained in:
Simo Lin
2025-09-25 09:59:30 -04:00
committed by GitHub
parent 77830a265e
commit d511b2d905
7 changed files with 199 additions and 232 deletions

View File

@@ -1,8 +1,7 @@
use super::pd_types::api_path;
use crate::config::types::RetryConfig;
use crate::core::{
is_retryable_status, ConnectionMode, RetryExecutor, Worker, WorkerLoadGuard, WorkerManager,
WorkerRegistry, WorkerType,
is_retryable_status, RetryExecutor, Worker, WorkerLoadGuard, WorkerRegistry, WorkerType,
};
use crate::metrics::RouterMetrics;
use crate::policies::{LoadBalancingPolicy, PolicyRegistry};
@@ -23,18 +22,15 @@ use futures_util::StreamExt;
use reqwest::Client;
use serde::Serialize;
use serde_json::{json, Value};
use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, Instant};
use std::time::Instant;
use tokio_stream::wrappers::UnboundedReceiverStream;
use tracing::{debug, error, info, warn};
use tracing::{debug, error, warn};
#[derive(Debug)]
pub struct PDRouter {
pub worker_registry: Arc<WorkerRegistry>,
pub policy_registry: Arc<PolicyRegistry>,
pub worker_loads: Arc<tokio::sync::watch::Receiver<HashMap<String, isize>>>,
pub load_monitor_handle: Option<Arc<tokio::task::JoinHandle<()>>>,
pub client: Client,
pub retry_config: RetryConfig,
pub api_key: Option<String>,
@@ -124,71 +120,9 @@ impl PDRouter {
}
pub async fn new(ctx: &Arc<crate::server::AppContext>) -> Result<Self, String> {
let prefill_workers = ctx.worker_registry.get_workers_filtered(
None, // any model
Some(WorkerType::Prefill {
bootstrap_port: None,
}),
Some(ConnectionMode::Http),
false, // include all workers
);
let decode_workers = ctx.worker_registry.get_workers_filtered(
None, // any model
Some(WorkerType::Decode),
Some(ConnectionMode::Http),
false, // include all workers
);
let all_urls: Vec<String> = prefill_workers
.iter()
.chain(decode_workers.iter())
.map(|w| w.url().to_string())
.collect();
let all_api_keys: Vec<Option<String>> = prefill_workers
.iter()
.chain(decode_workers.iter())
.map(|w| w.api_key().clone())
.collect();
let (tx, rx) = tokio::sync::watch::channel(HashMap::new());
let worker_loads = Arc::new(rx);
let prefill_policy = ctx.policy_registry.get_prefill_policy();
let decode_policy = ctx.policy_registry.get_decode_policy();
let load_monitor_handle =
if prefill_policy.name() == "power_of_two" || decode_policy.name() == "power_of_two" {
let monitor_urls = all_urls.clone();
let monitor_api_keys = all_api_keys.clone();
let monitor_interval = ctx.router_config.worker_startup_check_interval_secs;
let monitor_client = ctx.client.clone();
let prefill_policy_clone = Arc::clone(&prefill_policy);
let decode_policy_clone = Arc::clone(&decode_policy);
Some(Arc::new(tokio::spawn(async move {
Self::monitor_worker_loads_with_client(
monitor_urls,
monitor_api_keys,
tx,
monitor_interval,
monitor_client,
prefill_policy_clone,
decode_policy_clone,
)
.await;
})))
} else {
None
};
// No longer need prefill drain channel - we'll wait for both responses
Ok(PDRouter {
worker_registry: Arc::clone(&ctx.worker_registry),
policy_registry: Arc::clone(&ctx.policy_registry),
worker_loads,
load_monitor_handle,
client: ctx.client.clone(),
retry_config: ctx.router_config.effective_retry_config(),
api_key: ctx.router_config.api_key.clone(),
@@ -708,55 +642,6 @@ impl PDRouter {
Ok(available_workers[selected_idx].clone())
}
async fn monitor_worker_loads_with_client(
worker_urls: Vec<String>,
worker_api_keys: Vec<Option<String>>,
tx: tokio::sync::watch::Sender<HashMap<String, isize>>,
interval_secs: u64,
client: Client,
prefill_policy: Arc<dyn LoadBalancingPolicy>,
decode_policy: Arc<dyn LoadBalancingPolicy>,
) {
loop {
let mut loads = HashMap::new();
let futures: Vec<_> = worker_urls
.iter()
.zip(worker_api_keys.iter())
.map(|(url, api_key)| {
let client = client.clone();
let url = url.clone();
let api_key = api_key.clone();
async move {
let load =
WorkerManager::get_worker_load(&url, api_key.as_deref(), &client)
.await
.unwrap_or(0);
(url, load)
}
})
.collect();
let results = futures_util::future::join_all(futures).await;
for (url, load) in results {
loads.insert(url, load);
}
debug!("Worker loads updated: {:?}", loads);
prefill_policy.update_loads(&loads);
decode_policy.update_loads(&loads);
if tx.send(loads).is_err() {
info!("Load monitor receiver dropped, shutting down monitor task");
break;
}
tokio::time::sleep(Duration::from_secs(interval_secs)).await;
}
}
#[allow(clippy::too_many_arguments)]
fn create_streaming_response(
&self,
@@ -1375,8 +1260,6 @@ mod tests {
PDRouter {
worker_registry,
policy_registry,
worker_loads: Arc::new(tokio::sync::watch::channel(HashMap::new()).1),
load_monitor_handle: None,
client: Client::new(),
retry_config: RetryConfig::default(),
api_key: Some("test_api_key".to_string()),
@@ -1436,31 +1319,6 @@ mod tests {
assert!(result.unwrap_err().contains("No prefill workers available"));
}
#[tokio::test]
async fn test_load_monitor_updates() {
let power_of_two_policy = Arc::new(crate::policies::PowerOfTwoPolicy::new());
let mut router = create_test_pd_router();
router
.policy_registry
.set_prefill_policy(power_of_two_policy.clone());
router
.policy_registry
.set_decode_policy(power_of_two_policy);
let (tx, rx) = tokio::sync::watch::channel(HashMap::new());
router.worker_loads = Arc::new(rx);
let mut loads = HashMap::new();
loads.insert("http://worker1".to_string(), 10);
loads.insert("http://worker2".to_string(), 5);
let _ = tx.send(loads.clone());
let received = router.worker_loads.borrow().clone();
assert_eq!(received.get("http://worker1"), Some(&10));
assert_eq!(received.get("http://worker2"), Some(&5));
}
#[test]
fn test_worker_load_metrics() {
let prefill_worker = create_test_worker(

View File

@@ -1,10 +1,9 @@
use crate::config::types::RetryConfig;
use crate::core::{
is_retryable_status, ConnectionMode, RetryExecutor, Worker, WorkerManager, WorkerRegistry,
WorkerType,
is_retryable_status, ConnectionMode, RetryExecutor, Worker, WorkerRegistry, WorkerType,
};
use crate::metrics::RouterMetrics;
use crate::policies::{LoadBalancingPolicy, PolicyRegistry};
use crate::policies::PolicyRegistry;
use crate::protocols::spec::{
ChatCompletionRequest, CompletionRequest, EmbeddingRequest, GenerateRequest, GenerationRequest,
RerankRequest, RerankResponse, RerankResult, ResponsesGetParams, ResponsesRequest,
@@ -23,9 +22,8 @@ use axum::{
};
use futures_util::StreamExt;
use reqwest::Client;
use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, Instant};
use std::time::Instant;
use tokio_stream::wrappers::UnboundedReceiverStream;
use tracing::{debug, error};
@@ -38,8 +36,6 @@ pub struct Router {
dp_aware: bool,
enable_igw: bool,
retry_config: RetryConfig,
_worker_loads: Arc<tokio::sync::watch::Receiver<HashMap<String, isize>>>,
_load_monitor_handle: Option<Arc<tokio::task::JoinHandle<()>>>,
}
impl Router {
@@ -54,42 +50,6 @@ impl Router {
RouterMetrics::set_active_workers(workers.len());
let worker_urls: Vec<String> = workers.iter().map(|w| w.url().to_string()).collect();
let (tx, rx) = tokio::sync::watch::channel(HashMap::new());
let worker_loads = Arc::new(rx);
let default_policy = ctx.policy_registry.get_default_policy();
let load_monitor_handle = if default_policy.name() == "power_of_two" {
let monitor_urls = worker_urls.clone();
let monitor_api_keys = monitor_urls
.iter()
.map(|url| {
ctx.worker_registry
.get_by_url(url)
.and_then(|w| w.api_key().clone())
})
.collect::<Vec<Option<String>>>();
let monitor_interval = ctx.router_config.worker_startup_check_interval_secs;
let policy_clone = default_policy.clone();
let client_clone = ctx.client.clone();
Some(Arc::new(tokio::spawn(async move {
Self::monitor_worker_loads(
monitor_urls,
monitor_api_keys,
tx,
monitor_interval,
policy_clone,
client_clone,
)
.await;
})))
} else {
None
};
Ok(Router {
worker_registry: ctx.worker_registry.clone(),
policy_registry: ctx.policy_registry.clone(),
@@ -97,8 +57,6 @@ impl Router {
dp_aware: ctx.router_config.dp_aware,
enable_igw: ctx.router_config.enable_igw,
retry_config: ctx.router_config.effective_retry_config(),
_worker_loads: worker_loads,
_load_monitor_handle: load_monitor_handle,
})
}
@@ -661,42 +619,6 @@ impl Router {
}
}
// Background task to monitor worker loads
async fn monitor_worker_loads(
worker_urls: Vec<String>,
worker_api_keys: Vec<Option<String>>,
tx: tokio::sync::watch::Sender<HashMap<String, isize>>,
interval_secs: u64,
policy: Arc<dyn LoadBalancingPolicy>,
client: Client,
) {
let mut interval = tokio::time::interval(Duration::from_secs(interval_secs));
loop {
interval.tick().await;
let mut loads = HashMap::new();
for (url, api_key) in worker_urls.iter().zip(worker_api_keys.iter()) {
// Use WorkerManager for consistent load fetching
if let Some(load) =
WorkerManager::get_worker_load(url, api_key.as_deref(), &client).await
{
loads.insert(url.clone(), load);
}
}
if !loads.is_empty() {
// Update policy with new loads
policy.update_loads(&loads);
// Send to watchers
if let Err(e) = tx.send(loads) {
error!("Failed to send load update: {}", e);
}
}
}
}
async fn build_rerank_response(
req: &RerankRequest,
response: Response,
@@ -858,7 +780,6 @@ impl RouterTrait for Router {
mod tests {
use super::*;
use crate::core::BasicWorkerBuilder;
use std::collections::HashMap;
fn create_test_regular_router() -> Router {
// Create registries
@@ -877,15 +798,12 @@ mod tests {
worker_registry.register(Arc::new(worker1));
worker_registry.register(Arc::new(worker2));
let (_, rx) = tokio::sync::watch::channel(HashMap::new());
Router {
worker_registry,
policy_registry,
dp_aware: false,
client: Client::new(),
retry_config: RetryConfig::default(),
_worker_loads: Arc::new(rx),
_load_monitor_handle: None,
enable_igw: false,
}
}