From d511b2d905177ebc4040186ffb1493414e190a80 Mon Sep 17 00:00:00 2001 From: Simo Lin Date: Thu, 25 Sep 2025 09:59:30 -0400 Subject: [PATCH] [router] consolidate worker load monitoring (#10894) --- sgl-router/src/core/mod.rs | 2 +- sgl-router/src/core/worker_manager.rs | 135 +++++++++++++++++++++ sgl-router/src/policies/registry.rs | 41 +++++++ sgl-router/src/routers/http/pd_router.rs | 148 +---------------------- sgl-router/src/routers/http/router.rs | 88 +------------- sgl-router/src/server.rs | 16 ++- sgl-router/src/service_discovery.rs | 1 + 7 files changed, 199 insertions(+), 232 deletions(-) diff --git a/sgl-router/src/core/mod.rs b/sgl-router/src/core/mod.rs index b3f5bbcbe..a3aa24ba0 100644 --- a/sgl-router/src/core/mod.rs +++ b/sgl-router/src/core/mod.rs @@ -25,5 +25,5 @@ pub use worker::{ Worker, WorkerFactory, WorkerLoadGuard, WorkerType, }; pub use worker_builder::{BasicWorkerBuilder, DPAwareWorkerBuilder}; -pub use worker_manager::{DpInfo, ServerInfo, WorkerManager}; +pub use worker_manager::{DpInfo, LoadMonitor, ServerInfo, WorkerManager}; pub use worker_registry::{WorkerId, WorkerRegistry, WorkerRegistryStats}; diff --git a/sgl-router/src/core/worker_manager.rs b/sgl-router/src/core/worker_manager.rs index 9dde2aca8..0fc8e8b0e 100644 --- a/sgl-router/src/core/worker_manager.rs +++ b/sgl-router/src/core/worker_manager.rs @@ -23,6 +23,8 @@ use serde_json::Value; use std::collections::HashMap; use std::sync::Arc; use std::time::Duration; +use tokio::sync::{watch, Mutex}; +use tokio::task::JoinHandle; use tracing::{debug, error, info, warn}; static HTTP_CLIENT: Lazy = Lazy::new(|| { @@ -1177,6 +1179,139 @@ impl WorkerManager { } } +/// Load monitoring service that periodically fetches worker loads +pub struct LoadMonitor { + worker_registry: Arc, + policy_registry: Arc, + client: reqwest::Client, + interval: Duration, + tx: watch::Sender>, + rx: watch::Receiver>, + monitor_handle: Arc>>>, +} + +impl LoadMonitor { + /// Create a new load monitor + pub fn new( + worker_registry: Arc, + policy_registry: Arc, + client: reqwest::Client, + interval_secs: u64, + ) -> Self { + let (tx, rx) = watch::channel(HashMap::new()); + + Self { + worker_registry, + policy_registry, + client, + interval: Duration::from_secs(interval_secs), + tx, + rx, + monitor_handle: Arc::new(Mutex::new(None)), + } + } + + /// Start monitoring worker loads + pub async fn start(&self) { + let mut handle_guard = self.monitor_handle.lock().await; + if handle_guard.is_some() { + debug!("Load monitoring already running"); + return; + } + + info!( + "Starting load monitoring with interval: {:?}", + self.interval + ); + + let worker_registry = Arc::clone(&self.worker_registry); + let policy_registry = Arc::clone(&self.policy_registry); + let client = self.client.clone(); + let interval = self.interval; + let tx = self.tx.clone(); + + let handle = tokio::spawn(async move { + Self::monitor_loop(worker_registry, policy_registry, client, interval, tx).await; + }); + + *handle_guard = Some(handle); + } + + /// Stop monitoring worker loads + pub async fn stop(&self) { + let mut handle_guard = self.monitor_handle.lock().await; + if let Some(handle) = handle_guard.take() { + info!("Stopping load monitoring"); + handle.abort(); + let _ = handle.await; // Wait for task to finish + } + } + + /// Get a receiver for load updates + pub fn subscribe(&self) -> watch::Receiver> { + self.rx.clone() + } + + /// The main monitoring loop + async fn monitor_loop( + worker_registry: Arc, + policy_registry: Arc, + client: reqwest::Client, + interval: Duration, + tx: watch::Sender>, + ) { + let mut interval_timer = tokio::time::interval(interval); + + loop { + interval_timer.tick().await; + + let power_of_two_policies = policy_registry.get_all_power_of_two_policies(); + + if power_of_two_policies.is_empty() { + debug!("No PowerOfTwo policies found, skipping load fetch"); + continue; + } + + let result = WorkerManager::get_all_worker_loads(&worker_registry, &client).await; + + let mut loads = HashMap::new(); + for load_info in result.loads { + loads.insert(load_info.worker, load_info.load); + } + + if !loads.is_empty() { + debug!( + "Fetched loads from {} workers, updating {} PowerOfTwo policies", + loads.len(), + power_of_two_policies.len() + ); + for policy in &power_of_two_policies { + policy.update_loads(&loads); + } + let _ = tx.send(loads); + } else { + warn!("No loads fetched from workers"); + } + } + } + + /// Check if monitoring is currently active + pub async fn is_running(&self) -> bool { + let handle_guard = self.monitor_handle.lock().await; + handle_guard.is_some() + } +} + +impl Drop for LoadMonitor { + fn drop(&mut self) { + if let Ok(mut handle_guard) = self.monitor_handle.try_lock() { + if let Some(handle) = handle_guard.take() { + handle.abort(); + } + } + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/sgl-router/src/policies/registry.rs b/sgl-router/src/policies/registry.rs index 3abd812b5..8d9de51d3 100644 --- a/sgl-router/src/policies/registry.rs +++ b/sgl-router/src/policies/registry.rs @@ -257,6 +257,47 @@ impl PolicyRegistry { .unwrap_or_else(|| self.get_default_policy()) } + /// Get all PowerOfTwo policies that need load updates + pub fn get_all_power_of_two_policies(&self) -> Vec> { + let mut power_of_two_policies = Vec::new(); + + if self.default_policy.name() == "power_of_two" { + power_of_two_policies.push(Arc::clone(&self.default_policy)); + } + + if let Some(ref policy) = *self.prefill_policy.read().unwrap() { + if policy.name() == "power_of_two" && !Arc::ptr_eq(policy, &self.default_policy) { + power_of_two_policies.push(Arc::clone(policy)); + } + } + + if let Some(ref policy) = *self.decode_policy.read().unwrap() { + if policy.name() == "power_of_two" + && !Arc::ptr_eq(policy, &self.default_policy) + && !self + .prefill_policy + .read() + .unwrap() + .as_ref() + .is_some_and(|p| Arc::ptr_eq(p, policy)) + { + power_of_two_policies.push(Arc::clone(policy)); + } + } + + let model_policies = self.model_policies.read().unwrap(); + for policy in model_policies.values() { + if policy.name() == "power_of_two" { + let already_added = power_of_two_policies.iter().any(|p| Arc::ptr_eq(p, policy)); + if !already_added { + power_of_two_policies.push(Arc::clone(policy)); + } + } + } + + power_of_two_policies + } + /// Initialize cache-aware policy with workers if applicable /// This should be called after workers are registered for a model pub fn init_cache_aware_policy(&self, model_id: &str, workers: &[Arc]) { diff --git a/sgl-router/src/routers/http/pd_router.rs b/sgl-router/src/routers/http/pd_router.rs index f655c7923..ab5e793d4 100644 --- a/sgl-router/src/routers/http/pd_router.rs +++ b/sgl-router/src/routers/http/pd_router.rs @@ -1,8 +1,7 @@ use super::pd_types::api_path; use crate::config::types::RetryConfig; use crate::core::{ - is_retryable_status, ConnectionMode, RetryExecutor, Worker, WorkerLoadGuard, WorkerManager, - WorkerRegistry, WorkerType, + is_retryable_status, RetryExecutor, Worker, WorkerLoadGuard, WorkerRegistry, WorkerType, }; use crate::metrics::RouterMetrics; use crate::policies::{LoadBalancingPolicy, PolicyRegistry}; @@ -23,18 +22,15 @@ use futures_util::StreamExt; use reqwest::Client; use serde::Serialize; use serde_json::{json, Value}; -use std::collections::HashMap; use std::sync::Arc; -use std::time::{Duration, Instant}; +use std::time::Instant; use tokio_stream::wrappers::UnboundedReceiverStream; -use tracing::{debug, error, info, warn}; +use tracing::{debug, error, warn}; #[derive(Debug)] pub struct PDRouter { pub worker_registry: Arc, pub policy_registry: Arc, - pub worker_loads: Arc>>, - pub load_monitor_handle: Option>>, pub client: Client, pub retry_config: RetryConfig, pub api_key: Option, @@ -124,71 +120,9 @@ impl PDRouter { } pub async fn new(ctx: &Arc) -> Result { - let prefill_workers = ctx.worker_registry.get_workers_filtered( - None, // any model - Some(WorkerType::Prefill { - bootstrap_port: None, - }), - Some(ConnectionMode::Http), - false, // include all workers - ); - - let decode_workers = ctx.worker_registry.get_workers_filtered( - None, // any model - Some(WorkerType::Decode), - Some(ConnectionMode::Http), - false, // include all workers - ); - - let all_urls: Vec = prefill_workers - .iter() - .chain(decode_workers.iter()) - .map(|w| w.url().to_string()) - .collect(); - let all_api_keys: Vec> = prefill_workers - .iter() - .chain(decode_workers.iter()) - .map(|w| w.api_key().clone()) - .collect(); - - let (tx, rx) = tokio::sync::watch::channel(HashMap::new()); - let worker_loads = Arc::new(rx); - - let prefill_policy = ctx.policy_registry.get_prefill_policy(); - let decode_policy = ctx.policy_registry.get_decode_policy(); - - let load_monitor_handle = - if prefill_policy.name() == "power_of_two" || decode_policy.name() == "power_of_two" { - let monitor_urls = all_urls.clone(); - let monitor_api_keys = all_api_keys.clone(); - let monitor_interval = ctx.router_config.worker_startup_check_interval_secs; - let monitor_client = ctx.client.clone(); - let prefill_policy_clone = Arc::clone(&prefill_policy); - let decode_policy_clone = Arc::clone(&decode_policy); - - Some(Arc::new(tokio::spawn(async move { - Self::monitor_worker_loads_with_client( - monitor_urls, - monitor_api_keys, - tx, - monitor_interval, - monitor_client, - prefill_policy_clone, - decode_policy_clone, - ) - .await; - }))) - } else { - None - }; - - // No longer need prefill drain channel - we'll wait for both responses - Ok(PDRouter { worker_registry: Arc::clone(&ctx.worker_registry), policy_registry: Arc::clone(&ctx.policy_registry), - worker_loads, - load_monitor_handle, client: ctx.client.clone(), retry_config: ctx.router_config.effective_retry_config(), api_key: ctx.router_config.api_key.clone(), @@ -708,55 +642,6 @@ impl PDRouter { Ok(available_workers[selected_idx].clone()) } - async fn monitor_worker_loads_with_client( - worker_urls: Vec, - worker_api_keys: Vec>, - tx: tokio::sync::watch::Sender>, - interval_secs: u64, - client: Client, - prefill_policy: Arc, - decode_policy: Arc, - ) { - loop { - let mut loads = HashMap::new(); - - let futures: Vec<_> = worker_urls - .iter() - .zip(worker_api_keys.iter()) - .map(|(url, api_key)| { - let client = client.clone(); - let url = url.clone(); - let api_key = api_key.clone(); - async move { - let load = - WorkerManager::get_worker_load(&url, api_key.as_deref(), &client) - .await - .unwrap_or(0); - (url, load) - } - }) - .collect(); - - let results = futures_util::future::join_all(futures).await; - - for (url, load) in results { - loads.insert(url, load); - } - - debug!("Worker loads updated: {:?}", loads); - - prefill_policy.update_loads(&loads); - decode_policy.update_loads(&loads); - - if tx.send(loads).is_err() { - info!("Load monitor receiver dropped, shutting down monitor task"); - break; - } - - tokio::time::sleep(Duration::from_secs(interval_secs)).await; - } - } - #[allow(clippy::too_many_arguments)] fn create_streaming_response( &self, @@ -1375,8 +1260,6 @@ mod tests { PDRouter { worker_registry, policy_registry, - worker_loads: Arc::new(tokio::sync::watch::channel(HashMap::new()).1), - load_monitor_handle: None, client: Client::new(), retry_config: RetryConfig::default(), api_key: Some("test_api_key".to_string()), @@ -1436,31 +1319,6 @@ mod tests { assert!(result.unwrap_err().contains("No prefill workers available")); } - #[tokio::test] - async fn test_load_monitor_updates() { - let power_of_two_policy = Arc::new(crate::policies::PowerOfTwoPolicy::new()); - let mut router = create_test_pd_router(); - router - .policy_registry - .set_prefill_policy(power_of_two_policy.clone()); - router - .policy_registry - .set_decode_policy(power_of_two_policy); - - let (tx, rx) = tokio::sync::watch::channel(HashMap::new()); - router.worker_loads = Arc::new(rx); - - let mut loads = HashMap::new(); - loads.insert("http://worker1".to_string(), 10); - loads.insert("http://worker2".to_string(), 5); - - let _ = tx.send(loads.clone()); - - let received = router.worker_loads.borrow().clone(); - assert_eq!(received.get("http://worker1"), Some(&10)); - assert_eq!(received.get("http://worker2"), Some(&5)); - } - #[test] fn test_worker_load_metrics() { let prefill_worker = create_test_worker( diff --git a/sgl-router/src/routers/http/router.rs b/sgl-router/src/routers/http/router.rs index 8be3490fb..af642cb2e 100644 --- a/sgl-router/src/routers/http/router.rs +++ b/sgl-router/src/routers/http/router.rs @@ -1,10 +1,9 @@ use crate::config::types::RetryConfig; use crate::core::{ - is_retryable_status, ConnectionMode, RetryExecutor, Worker, WorkerManager, WorkerRegistry, - WorkerType, + is_retryable_status, ConnectionMode, RetryExecutor, Worker, WorkerRegistry, WorkerType, }; use crate::metrics::RouterMetrics; -use crate::policies::{LoadBalancingPolicy, PolicyRegistry}; +use crate::policies::PolicyRegistry; use crate::protocols::spec::{ ChatCompletionRequest, CompletionRequest, EmbeddingRequest, GenerateRequest, GenerationRequest, RerankRequest, RerankResponse, RerankResult, ResponsesGetParams, ResponsesRequest, @@ -23,9 +22,8 @@ use axum::{ }; use futures_util::StreamExt; use reqwest::Client; -use std::collections::HashMap; use std::sync::Arc; -use std::time::{Duration, Instant}; +use std::time::Instant; use tokio_stream::wrappers::UnboundedReceiverStream; use tracing::{debug, error}; @@ -38,8 +36,6 @@ pub struct Router { dp_aware: bool, enable_igw: bool, retry_config: RetryConfig, - _worker_loads: Arc>>, - _load_monitor_handle: Option>>, } impl Router { @@ -54,42 +50,6 @@ impl Router { RouterMetrics::set_active_workers(workers.len()); - let worker_urls: Vec = workers.iter().map(|w| w.url().to_string()).collect(); - - let (tx, rx) = tokio::sync::watch::channel(HashMap::new()); - let worker_loads = Arc::new(rx); - - let default_policy = ctx.policy_registry.get_default_policy(); - - let load_monitor_handle = if default_policy.name() == "power_of_two" { - let monitor_urls = worker_urls.clone(); - let monitor_api_keys = monitor_urls - .iter() - .map(|url| { - ctx.worker_registry - .get_by_url(url) - .and_then(|w| w.api_key().clone()) - }) - .collect::>>(); - let monitor_interval = ctx.router_config.worker_startup_check_interval_secs; - let policy_clone = default_policy.clone(); - let client_clone = ctx.client.clone(); - - Some(Arc::new(tokio::spawn(async move { - Self::monitor_worker_loads( - monitor_urls, - monitor_api_keys, - tx, - monitor_interval, - policy_clone, - client_clone, - ) - .await; - }))) - } else { - None - }; - Ok(Router { worker_registry: ctx.worker_registry.clone(), policy_registry: ctx.policy_registry.clone(), @@ -97,8 +57,6 @@ impl Router { dp_aware: ctx.router_config.dp_aware, enable_igw: ctx.router_config.enable_igw, retry_config: ctx.router_config.effective_retry_config(), - _worker_loads: worker_loads, - _load_monitor_handle: load_monitor_handle, }) } @@ -661,42 +619,6 @@ impl Router { } } - // Background task to monitor worker loads - async fn monitor_worker_loads( - worker_urls: Vec, - worker_api_keys: Vec>, - tx: tokio::sync::watch::Sender>, - interval_secs: u64, - policy: Arc, - client: Client, - ) { - let mut interval = tokio::time::interval(Duration::from_secs(interval_secs)); - - loop { - interval.tick().await; - - let mut loads = HashMap::new(); - for (url, api_key) in worker_urls.iter().zip(worker_api_keys.iter()) { - // Use WorkerManager for consistent load fetching - if let Some(load) = - WorkerManager::get_worker_load(url, api_key.as_deref(), &client).await - { - loads.insert(url.clone(), load); - } - } - - if !loads.is_empty() { - // Update policy with new loads - policy.update_loads(&loads); - - // Send to watchers - if let Err(e) = tx.send(loads) { - error!("Failed to send load update: {}", e); - } - } - } - } - async fn build_rerank_response( req: &RerankRequest, response: Response, @@ -858,7 +780,6 @@ impl RouterTrait for Router { mod tests { use super::*; use crate::core::BasicWorkerBuilder; - use std::collections::HashMap; fn create_test_regular_router() -> Router { // Create registries @@ -877,15 +798,12 @@ mod tests { worker_registry.register(Arc::new(worker1)); worker_registry.register(Arc::new(worker2)); - let (_, rx) = tokio::sync::watch::channel(HashMap::new()); Router { worker_registry, policy_registry, dp_aware: false, client: Client::new(), retry_config: RetryConfig::default(), - _worker_loads: Arc::new(rx), - _load_monitor_handle: None, enable_igw: false, } } diff --git a/sgl-router/src/server.rs b/sgl-router/src/server.rs index 7a7c0be65..c14e074bc 100644 --- a/sgl-router/src/server.rs +++ b/sgl-router/src/server.rs @@ -1,6 +1,6 @@ use crate::{ config::{ConnectionMode, HistoryBackend, RouterConfig, RoutingMode}, - core::{WorkerManager, WorkerRegistry, WorkerType}, + core::{LoadMonitor, WorkerManager, WorkerRegistry, WorkerType}, data_connector::{ MemoryResponseStorage, NoOpResponseStorage, OracleResponseStorage, SharedResponseStorage, }, @@ -51,6 +51,7 @@ pub struct AppContext { pub policy_registry: Arc, pub router_manager: Option>, pub response_storage: SharedResponseStorage, + pub load_monitor: Option>, } impl AppContext { @@ -107,6 +108,13 @@ impl AppContext { } }; + let load_monitor = Some(Arc::new(LoadMonitor::new( + worker_registry.clone(), + policy_registry.clone(), + client.clone(), + router_config.worker_startup_check_interval_secs, + ))); + Ok(Self { client, router_config, @@ -118,6 +126,7 @@ impl AppContext { policy_registry, router_manager, response_storage, + load_monitor, }) } } @@ -727,6 +736,11 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box