From 1d1ce62495db61ac74a2edc24983561f48fb1338 Mon Sep 17 00:00:00 2001 From: Simo Lin Date: Fri, 19 Sep 2025 23:54:40 -0400 Subject: [PATCH] [router] refactor router and worker management 2.5/n (#10677) --- sgl-router/src/policies/registry.rs | 76 ++++++++++++ sgl-router/src/routers/http/pd_router.rs | 69 +++-------- sgl-router/src/routers/http/router.rs | 75 +++--------- sgl-router/src/routers/worker_initializer.rs | 120 ++++++++++++++++++- sgl-router/src/server.rs | 12 +- sgl-router/tests/api_endpoints_test.rs | 2 +- sgl-router/tests/request_formats_test.rs | 2 +- sgl-router/tests/streaming_tests.rs | 2 +- 8 files changed, 235 insertions(+), 123 deletions(-) diff --git a/sgl-router/src/policies/registry.rs b/sgl-router/src/policies/registry.rs index 326b29d76..3abd812b5 100644 --- a/sgl-router/src/policies/registry.rs +++ b/sgl-router/src/policies/registry.rs @@ -9,6 +9,7 @@ use super::{ RoundRobinPolicy, }; use crate::config::types::PolicyConfig; +use crate::core::Worker; use std::collections::HashMap; use std::sync::{Arc, RwLock}; use tracing::{debug, info, warn}; @@ -255,6 +256,81 @@ impl PolicyRegistry { .map(Arc::clone) .unwrap_or_else(|| self.get_default_policy()) } + + /// Initialize cache-aware policy with workers if applicable + /// This should be called after workers are registered for a model + pub fn init_cache_aware_policy(&self, model_id: &str, workers: &[Arc]) { + // Get the policy for this model + if let Some(policy) = self.get_policy(model_id) { + if policy.name() == "cache_aware" { + if let Some(cache_aware) = policy.as_any().downcast_ref::() { + debug!( + "Initializing cache-aware policy with {} workers for model {}", + workers.len(), + model_id + ); + cache_aware.init_workers(workers); + } + } + } + } + + /// Remove a worker from cache-aware policy if applicable + /// This should be called when a worker is being removed + pub fn remove_worker_from_cache_aware(&self, model_id: &str, worker_url: &str) { + // Get the policy for this model + if let Some(policy) = self.get_policy(model_id) { + if policy.name() == "cache_aware" { + if let Some(cache_aware) = policy.as_any().downcast_ref::() { + cache_aware.remove_worker_by_url(worker_url); + debug!( + "Removed worker {} from cache-aware policy for model {}", + worker_url, model_id + ); + } + } + } + } + + /// Initialize cache-aware policies for PD mode (prefill and decode) + pub fn init_pd_cache_aware_policies( + &self, + prefill_workers: &[Arc], + decode_workers: &[Arc], + ) { + // Initialize prefill policy if it's cache-aware + if let Some(prefill_policy) = self.prefill_policy.read().unwrap().as_ref() { + if prefill_policy.name() == "cache_aware" { + if let Some(cache_aware) = + prefill_policy.as_any().downcast_ref::() + { + if !prefill_workers.is_empty() { + debug!( + "Initializing prefill cache-aware policy with {} workers", + prefill_workers.len() + ); + cache_aware.init_workers(prefill_workers); + } + } + } + } + + // Initialize decode policy if it's cache-aware + if let Some(decode_policy) = self.decode_policy.read().unwrap().as_ref() { + if decode_policy.name() == "cache_aware" { + if let Some(cache_aware) = decode_policy.as_any().downcast_ref::() + { + if !decode_workers.is_empty() { + debug!( + "Initializing decode cache-aware policy with {} workers", + decode_workers.len() + ); + cache_aware.init_workers(decode_workers); + } + } + } + } + } } impl std::fmt::Debug for PolicyRegistry { diff --git a/sgl-router/src/routers/http/pd_router.rs b/sgl-router/src/routers/http/pd_router.rs index b0002f62e..45a48b4d1 100644 --- a/sgl-router/src/routers/http/pd_router.rs +++ b/sgl-router/src/routers/http/pd_router.rs @@ -232,18 +232,12 @@ impl PDRouter { // Notify PolicyRegistry about the new worker let model_id = worker_arc.model_id(); - let policy = self.policy_registry.on_worker_added(model_id, None); + self.policy_registry.on_worker_added(model_id, None); - // If this is a cache-aware policy, update it with all workers for this model - if policy.name() == "cache_aware" { - if let Some(cache_aware) = policy - .as_any() - .downcast_ref::() - { - let model_workers = self.worker_registry.get_by_model_fast(model_id); - cache_aware.init_workers(&model_workers); - } - } + // Initialize cache-aware policy if applicable + let model_workers = self.worker_registry.get_by_model_fast(model_id); + self.policy_registry + .init_cache_aware_policy(model_id, &model_workers); info!("Added prefill server: {}", url); Ok(format!("Successfully added prefill server: {}", url)) @@ -272,18 +266,12 @@ impl PDRouter { // Notify PolicyRegistry about the new worker let model_id = worker_arc.model_id(); - let policy = self.policy_registry.on_worker_added(model_id, None); + self.policy_registry.on_worker_added(model_id, None); - // If this is a cache-aware policy, update it with all workers for this model - if policy.name() == "cache_aware" { - if let Some(cache_aware) = policy - .as_any() - .downcast_ref::() - { - let model_workers = self.worker_registry.get_by_model_fast(model_id); - cache_aware.init_workers(&model_workers); - } - } + // Initialize cache-aware policy if applicable + let model_workers = self.worker_registry.get_by_model_fast(model_id); + self.policy_registry + .init_cache_aware_policy(model_id, &model_workers); info!("Added decode server: {}", url); Ok(format!("Successfully added decode server: {}", url)) @@ -307,17 +295,9 @@ impl PDRouter { // Notify PolicyRegistry about the removed worker self.policy_registry.on_worker_removed(&model_id); - // Get the policy for this model to update cache-aware if needed - if let Some(policy) = self.policy_registry.get_policy(&model_id) { - if policy.name() == "cache_aware" { - if let Some(cache_aware) = policy - .as_any() - .downcast_ref::() - { - cache_aware.remove_worker_by_url(url); - } - } - } + // Remove from cache-aware policy if applicable + self.policy_registry + .remove_worker_from_cache_aware(&model_id, url); } if removed.is_some() { @@ -348,17 +328,9 @@ impl PDRouter { // Notify PolicyRegistry about the removed worker self.policy_registry.on_worker_removed(&model_id); - // Get the policy for this model to update cache-aware if needed - if let Some(policy) = self.policy_registry.get_policy(&model_id) { - if policy.name() == "cache_aware" { - if let Some(cache_aware) = policy - .as_any() - .downcast_ref::() - { - cache_aware.remove_worker_by_url(url); - } - } - } + // Remove from cache-aware policy if applicable + self.policy_registry + .remove_worker_from_cache_aware(&model_id, url); } if removed.is_some() { @@ -2226,15 +2198,6 @@ mod tests { assert_eq!(prefill_workers.len(), 1); } - // ============= Bootstrap Injection Tests ============= - // Note: These tests are commented out as we've moved to the optimized bootstrap injection - // approach that doesn't use the Bootstrap trait on GenerateReqInput anymore. - - // TODO: Add new tests for the optimized bootstrap injection approach using - // RequestWithBootstrap and BatchRequestWithBootstrap wrappers - - // ============= Worker Selection Tests ============= - #[tokio::test] async fn test_select_healthy_prefill_worker() { let router = create_test_pd_router(); diff --git a/sgl-router/src/routers/http/router.rs b/sgl-router/src/routers/http/router.rs index e5305cb56..36fd42a7f 100644 --- a/sgl-router/src/routers/http/router.rs +++ b/sgl-router/src/routers/http/router.rs @@ -70,21 +70,15 @@ impl Router { window_duration: Duration::from_secs(circuit_breaker_config.window_duration_secs), }; - // Initialize cache-aware policy with workers if needed - let default_policy = ctx.policy_registry.get_default_policy(); - if default_policy.name() == "cache_aware" { - if let Some(cache_aware) = default_policy - .as_any() - .downcast_ref::() - { - cache_aware.init_workers(&workers); - } - } + // Cache-aware policies are initialized in WorkerInitializer // Setup load monitoring for PowerOfTwo policy let (tx, rx) = tokio::sync::watch::channel(HashMap::new()); let worker_loads = Arc::new(rx); + // Get default policy to check if we need load monitoring + let default_policy = ctx.policy_registry.get_default_policy(); + // Check if default policy is power_of_two for load monitoring let load_monitor_handle = if default_policy.name() == "power_of_two" { let monitor_urls = worker_urls.clone(); @@ -964,19 +958,13 @@ impl Router { // Notify PolicyRegistry about the new worker let model_id = worker_arc.model_id(); - let policy = self.policy_registry.on_worker_added(model_id, None); + self.policy_registry.on_worker_added(model_id, None); - // If this is a cache-aware policy, update it with all workers for this model - if policy.name() == "cache_aware" { - if let Some(cache_aware) = policy - .as_any() - .downcast_ref::( - ) { - let model_workers = - self.worker_registry.get_by_model_fast(model_id); - cache_aware.init_workers(&model_workers); - } - } + // Initialize cache-aware policy if applicable + let model_workers = + self.worker_registry.get_by_model_fast(model_id); + self.policy_registry + .init_cache_aware_policy(model_id, &model_workers); worker_added = true; } @@ -1000,20 +988,12 @@ impl Router { // Notify PolicyRegistry about the new worker let model_id = worker_arc.model_id(); - let policy = self.policy_registry.on_worker_added(model_id, None); + self.policy_registry.on_worker_added(model_id, None); - // If this is a cache-aware policy, add this worker to it - if policy.name() == "cache_aware" { - if let Some(cache_aware) = policy - .as_any() - .downcast_ref::( - ) { - // Get all workers for this model - let model_workers = - self.worker_registry.get_by_model_fast(model_id); - cache_aware.init_workers(&model_workers); - } - } + // Initialize cache-aware policy if applicable + let model_workers = self.worker_registry.get_by_model_fast(model_id); + self.policy_registry + .init_cache_aware_policy(model_id, &model_workers); } RouterMetrics::set_active_workers(self.worker_registry.get_all().len()); @@ -1084,20 +1064,11 @@ impl Router { RouterMetrics::set_active_workers(self.worker_registry.get_all().len()); - // If any models are using cache aware policy, remove the workers from the tree - // Check each removed worker's model and get its policy for dp_url in removed_workers.iter() { if let Some(worker) = self.worker_registry.get_by_url(dp_url) { let model_id = worker.model_id(); - if let Some(policy) = self.policy_registry.get_policy(model_id) { - if let Some(cache_aware) = policy - .as_any() - .downcast_ref::() - { - cache_aware.remove_worker_by_url(dp_url); - info!("Removed worker from cache-aware tree: {}", dp_url); - } - } + self.policy_registry + .remove_worker_from_cache_aware(model_id, dp_url); } } } else { @@ -1118,16 +1089,8 @@ impl Router { RouterMetrics::set_active_workers(self.worker_registry.get_all().len()); } - // If the model is using cache aware policy, remove the worker from the tree - if let Some(policy) = self.policy_registry.get_policy(&model_id) { - if let Some(cache_aware) = policy - .as_any() - .downcast_ref::() - { - cache_aware.remove_worker_by_url(worker_url); - info!("Removed worker from cache-aware tree: {}", worker_url); - } - } + self.policy_registry + .remove_worker_from_cache_aware(&model_id, worker_url); } } diff --git a/sgl-router/src/routers/worker_initializer.rs b/sgl-router/src/routers/worker_initializer.rs index e71fba1c8..30e8fc865 100644 --- a/sgl-router/src/routers/worker_initializer.rs +++ b/sgl-router/src/routers/worker_initializer.rs @@ -3,9 +3,11 @@ use crate::config::types::{ConnectionMode as ConfigConnectionMode, RouterConfig, RoutingMode}; use crate::core::{ - BasicWorkerBuilder, CircuitBreakerConfig, ConnectionMode, HealthConfig, WorkerRegistry, + BasicWorkerBuilder, CircuitBreakerConfig, ConnectionMode, HealthConfig, Worker, WorkerRegistry, WorkerType, }; +use crate::policies::PolicyRegistry; +use std::collections::HashMap; use std::sync::Arc; use std::time::Duration; use tracing::{info, warn}; @@ -19,6 +21,7 @@ impl WorkerInitializer { pub async fn initialize_workers( config: &RouterConfig, worker_registry: &Arc, + policy_registry: Option<&Arc>, ) -> Result<(), String> { info!("Initializing workers for routing mode: {:?}", config.mode); @@ -29,6 +32,7 @@ impl WorkerInitializer { &config.connection_mode, config, worker_registry, + policy_registry, ) .await?; } @@ -42,6 +46,7 @@ impl WorkerInitializer { &config.connection_mode, config, worker_registry, + policy_registry, ) .await?; Self::create_decode_workers( @@ -49,6 +54,7 @@ impl WorkerInitializer { &config.connection_mode, config, worker_registry, + policy_registry, ) .await?; } @@ -76,6 +82,7 @@ impl WorkerInitializer { config_connection_mode: &ConfigConnectionMode, config: &RouterConfig, registry: &Arc, + policy_registry: Option<&Arc>, ) -> Result<(), String> { info!("Creating {} regular workers", urls.len()); @@ -100,6 +107,8 @@ impl WorkerInitializer { success_threshold: config.health_check.success_threshold, }; + let mut registered_workers: HashMap>> = HashMap::new(); + for url in urls { // TODO: Add DP-aware support when we have dp_rank/dp_size info let worker = BasicWorkerBuilder::new(url.clone()) @@ -109,8 +118,28 @@ impl WorkerInitializer { .health_config(health_config.clone()) .build(); - let worker_id = registry.register(Arc::new(worker)); + let worker_arc = Arc::new(worker) as Arc; + let model_id = worker_arc.model_id(); + let worker_id = registry.register(Arc::clone(&worker_arc)); info!("Registered regular worker {} with ID {:?}", url, worker_id); + + // Track workers by model for cache-aware policy initialization + registered_workers + .entry(model_id.to_string()) + .or_default() + .push(Arc::clone(&worker_arc)); + + // Notify policy registry about the worker + if let Some(policy_reg) = policy_registry { + policy_reg.on_worker_added(model_id, None); + } + } + + // Initialize cache-aware policies with all workers for each model + if let Some(policy_reg) = policy_registry { + for (model_id, workers) in registered_workers { + policy_reg.init_cache_aware_policy(&model_id, &workers); + } } Ok(()) @@ -122,6 +151,7 @@ impl WorkerInitializer { config_connection_mode: &ConfigConnectionMode, config: &RouterConfig, registry: &Arc, + policy_registry: Option<&Arc>, ) -> Result<(), String> { info!("Creating {} prefill workers", prefill_entries.len()); @@ -149,6 +179,8 @@ impl WorkerInitializer { success_threshold: config.health_check.success_threshold, }; + let mut registered_workers: HashMap>> = HashMap::new(); + for (url, bootstrap_port) in prefill_entries { // TODO: Add DP-aware support when we have dp_rank/dp_size info let worker = BasicWorkerBuilder::new(url.clone()) @@ -160,8 +192,33 @@ impl WorkerInitializer { .health_config(health_config.clone()) .build(); - let worker_id = registry.register(Arc::new(worker)); + let worker_arc = Arc::new(worker) as Arc; + let model_id = worker_arc.model_id(); + let worker_id = registry.register(Arc::clone(&worker_arc)); info!("Registered prefill worker {} with ID {:?}", url, worker_id); + + // Track workers by model for cache-aware policy initialization + registered_workers + .entry(model_id.to_string()) + .or_default() + .push(Arc::clone(&worker_arc)); + + // Notify policy registry about the worker + if let Some(policy_reg) = policy_registry { + policy_reg.on_worker_added(model_id, None); + } + } + + // Initialize cache-aware policies for PD mode + if let Some(policy_reg) = policy_registry { + // Collect all prefill workers + let all_prefill_workers: Vec> = registered_workers + .values() + .flat_map(|workers| workers.iter().cloned()) + .collect(); + + // Initialize PD policies (will handle both prefill and decode, but we only have prefill here) + policy_reg.init_pd_cache_aware_policies(&all_prefill_workers, &[]); } Ok(()) @@ -173,6 +230,7 @@ impl WorkerInitializer { config_connection_mode: &ConfigConnectionMode, config: &RouterConfig, registry: &Arc, + policy_registry: Option<&Arc>, ) -> Result<(), String> { info!("Creating {} decode workers", urls.len()); @@ -197,6 +255,8 @@ impl WorkerInitializer { success_threshold: config.health_check.success_threshold, }; + let mut registered_workers: HashMap>> = HashMap::new(); + for url in urls { // TODO: Add DP-aware support when we have dp_rank/dp_size info let worker = BasicWorkerBuilder::new(url.clone()) @@ -206,8 +266,33 @@ impl WorkerInitializer { .health_config(health_config.clone()) .build(); - let worker_id = registry.register(Arc::new(worker)); + let worker_arc = Arc::new(worker) as Arc; + let model_id = worker_arc.model_id(); + let worker_id = registry.register(Arc::clone(&worker_arc)); info!("Registered decode worker {} with ID {:?}", url, worker_id); + + // Track workers by model for cache-aware policy initialization + registered_workers + .entry(model_id.to_string()) + .or_default() + .push(Arc::clone(&worker_arc)); + + // Notify policy registry about the worker + if let Some(policy_reg) = policy_registry { + policy_reg.on_worker_added(model_id, None); + } + } + + // Initialize cache-aware policies for PD mode + if let Some(policy_reg) = policy_registry { + // Collect all decode workers + let all_decode_workers: Vec> = registered_workers + .values() + .flat_map(|workers| workers.iter().cloned()) + .collect(); + + // Initialize PD policies (will handle both prefill and decode, but we only have decode here) + policy_reg.init_pd_cache_aware_policies(&[], &all_decode_workers); } Ok(()) @@ -281,7 +366,8 @@ impl WorkerInitializer { worker_type: WorkerType, config: &RouterConfig, registry: &Arc, - grpc_clients: &mut std::collections::HashMap, + policy_registry: Option<&Arc>, + grpc_clients: &mut HashMap, ) -> Result<(), String> { info!( "Creating {} gRPC workers of type {:?}", @@ -307,6 +393,8 @@ impl WorkerInitializer { success_threshold: config.health_check.success_threshold, }; + let mut registered_workers: HashMap>> = HashMap::new(); + for url in worker_urls { if let Some(client) = grpc_clients.remove(url) { let worker = BasicWorkerBuilder::new(url.clone()) @@ -317,13 +405,33 @@ impl WorkerInitializer { .grpc_client(client) .build(); - let worker_id = registry.register(Arc::new(worker)); + let worker_arc = Arc::new(worker) as Arc; + let model_id = worker_arc.model_id(); + let worker_id = registry.register(Arc::clone(&worker_arc)); info!("Registered gRPC worker {} with ID {:?}", url, worker_id); + + // Track workers by model for cache-aware policy initialization + registered_workers + .entry(model_id.to_string()) + .or_default() + .push(Arc::clone(&worker_arc)); + + // Notify policy registry about the worker + if let Some(policy_reg) = policy_registry { + policy_reg.on_worker_added(model_id, None); + } } else { warn!("No gRPC client available for worker {}, skipping", url); } } + // Initialize cache-aware policies with all workers for each model + if let Some(policy_reg) = policy_registry { + for (model_id, workers) in registered_workers { + policy_reg.init_cache_aware_policy(&model_id, &workers); + } + } + Ok(()) } } diff --git a/sgl-router/src/server.rs b/sgl-router/src/server.rs index f04e08183..21ec16d5e 100644 --- a/sgl-router/src/server.rs +++ b/sgl-router/src/server.rs @@ -595,15 +595,17 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box