From 2f173ea0744d2dca264afb4ad835c8e7dc2eb3b8 Mon Sep 17 00:00:00 2001 From: Simo Lin Date: Fri, 12 Sep 2025 19:18:27 -0400 Subject: [PATCH] [router] allow one router to support different model families and serving mode (#10244) --- sgl-router/py_src/sglang_router/router.py | 3 + .../py_src/sglang_router/router_args.py | 6 + sgl-router/py_test/e2e/conftest.py | 3 + .../py_test/integration/test_retries.py | 2 +- sgl-router/src/core/mod.rs | 2 + sgl-router/src/core/worker.rs | 110 ++- sgl-router/src/core/worker_registry.rs | 526 ++++++++++ sgl-router/src/policies/cache_aware.rs | 248 +++-- sgl-router/src/policies/mod.rs | 20 +- sgl-router/src/policies/power_of_two.rs | 16 +- sgl-router/src/policies/random.rs | 19 +- sgl-router/src/policies/registry.rs | 333 +++++++ sgl-router/src/policies/round_robin.rs | 25 +- sgl-router/src/protocols/mod.rs | 1 + sgl-router/src/protocols/worker_spec.rs | 198 ++++ sgl-router/src/routers/factory.rs | 45 +- sgl-router/src/routers/grpc/pd_router.rs | 17 +- sgl-router/src/routers/grpc/router.rs | 11 +- sgl-router/src/routers/http/openai_router.rs | 11 +- sgl-router/src/routers/http/pd_router.rs | 929 +++++++++--------- sgl-router/src/routers/http/router.rs | 457 +++++---- sgl-router/src/routers/mod.rs | 19 +- sgl-router/src/routers/router_manager.rs | 766 +++++++++++++++ sgl-router/src/server.rs | 272 ++++- sgl-router/src/service_discovery.rs | 13 +- .../tests/cache_aware_backward_compat_test.rs | 129 +++ .../tests/policy_registry_integration.rs | 168 ++++ sgl-router/tests/test_openai_routing.rs | 16 +- 28 files changed, 3528 insertions(+), 837 deletions(-) create mode 100644 sgl-router/src/core/worker_registry.rs create mode 100644 sgl-router/src/policies/registry.rs create mode 100644 sgl-router/src/protocols/worker_spec.rs create mode 100644 sgl-router/src/routers/router_manager.rs create mode 100644 sgl-router/tests/cache_aware_backward_compat_test.rs create mode 100644 sgl-router/tests/policy_registry_integration.rs diff --git 
a/sgl-router/py_src/sglang_router/router.py b/sgl-router/py_src/sglang_router/router.py index 72a99ffbb..fea9e9dcd 100644 --- a/sgl-router/py_src/sglang_router/router.py +++ b/sgl-router/py_src/sglang_router/router.py @@ -46,6 +46,9 @@ class Router: max_payload_size: Maximum payload size in bytes. Default: 256MB max_tree_size: Maximum size of the approximation tree for cache-aware routing. Default: 2^24 dp_aware: Enable data parallelism aware schedule. Default: False + enable_igw: Enable IGW (Inference-Gateway) mode for multi-model support. When enabled, + the router can manage multiple models simultaneously with per-model load balancing + policies. Default: False api_key: The api key used for the authorization with the worker. Useful when the dp aware scheduling strategy is enabled. Default: None diff --git a/sgl-router/py_src/sglang_router/router_args.py b/sgl-router/py_src/sglang_router/router_args.py index ad0a2ac9f..2b2c7427a 100644 --- a/sgl-router/py_src/sglang_router/router_args.py +++ b/sgl-router/py_src/sglang_router/router_args.py @@ -34,6 +34,7 @@ class RouterArgs: max_tree_size: int = 2**26 max_payload_size: int = 512 * 1024 * 1024 # 512MB default for large batches dp_aware: bool = False + enable_igw: bool = False # Enable IGW (Inference-Gateway) mode for multi-model support api_key: Optional[str] = None log_dir: Optional[str] = None log_level: Optional[str] = None @@ -227,6 +228,11 @@ class RouterArgs: action="store_true", help="Enable data parallelism aware schedule", ) + parser.add_argument( + f"--{prefix}enable-igw", + action="store_true", + help="Enable IGW (Inference-Gateway) mode for multi-model support", + ) parser.add_argument( f"--{prefix}api-key", type=str, diff --git a/sgl-router/py_test/e2e/conftest.py b/sgl-router/py_test/e2e/conftest.py index 866b27531..3acec82b2 100644 --- a/sgl-router/py_test/e2e/conftest.py +++ b/sgl-router/py_test/e2e/conftest.py @@ -128,6 +128,7 @@ def _popen_launch_router_only( timeout: float = 120.0, *, dp_aware: bool
= False, + enable_igw: bool = False, api_key: str | None = None, ) -> subprocess.Popen: host, port = _parse_url(base_url) @@ -146,6 +147,8 @@ def _popen_launch_router_only( ] if dp_aware: cmd += ["--dp-aware"] + if enable_igw: + cmd += ["--enable-igw"] if api_key is not None: cmd += ["--api-key", api_key] cmd += [ diff --git a/sgl-router/py_test/integration/test_retries.py b/sgl-router/py_test/integration/test_retries.py index 5f3d4ffee..30826a665 100644 --- a/sgl-router/py_test/integration/test_retries.py +++ b/sgl-router/py_test/integration/test_retries.py @@ -35,7 +35,7 @@ def test_retry_reroutes_to_healthy_worker(router_manager, mock_workers): ) assert r.status_code == 200 wid = r.headers.get("X-Worker-Id") or r.json().get("worker_id") - assert wid == id_b # should have retried onto healthy worker + assert wid in [id_b, id_c] # should have retried onto a healthy worker (B or C) # mock_workers fixture handles cleanup diff --git a/sgl-router/src/core/mod.rs b/sgl-router/src/core/mod.rs index b46810b4c..00e879449 100644 --- a/sgl-router/src/core/mod.rs +++ b/sgl-router/src/core/mod.rs @@ -11,6 +11,7 @@ pub mod error; pub mod retry; pub mod token_bucket; pub mod worker; +pub mod worker_registry; // Re-export commonly used types at the module level pub use circuit_breaker::{ @@ -22,3 +23,4 @@ pub use worker::{ start_health_checker, BasicWorker, ConnectionMode, DPAwareWorker, HealthChecker, HealthConfig, Worker, WorkerCollection, WorkerFactory, WorkerLoadGuard, WorkerType, }; +pub use worker_registry::{WorkerId, WorkerRegistry, WorkerRegistryStats}; diff --git a/sgl-router/src/core/worker.rs b/sgl-router/src/core/worker.rs index 51c3cdd65..07279255f 100644 --- a/sgl-router/src/core/worker.rs +++ b/sgl-router/src/core/worker.rs @@ -155,6 +155,82 @@ pub trait Worker: Send + Sync + fmt::Debug { fn can_handle(&self, _req: &serde_json::Value) -> bool { true } + + // === Multi-router support === + + // TODO: - Enhanced Worker Discovery + // The Worker trait should handle 
async discovery of metadata from the worker itself + // rather than having service discovery or other components query /get_server_info. + // This keeps service discovery decoupled from worker-specific APIs. + // + // Proposed additions: + // - async fn discover_metadata(&mut self) -> Result<(), Error> + // Query /get_server_info and populate metadata labels with model_id, priority, cost, etc. + // - async fn validate_configuration(&self) -> Result<(), Error> + // Ensure worker has required configuration for its mode (e.g., tokenizer for gRPC) + // - Make worker creation async to allow metadata discovery during initialization + // + // This way service discovery just calls router.add_worker() and the worker + // handles its own metadata discovery internally. + + /// Get the model ID this worker serves + fn model_id(&self) -> &str { + self.metadata() + .labels + .get("model_id") + .map(|s| s.as_str()) + .unwrap_or("unknown") + } + + /// Get the priority of this worker (higher value = higher priority) + fn priority(&self) -> u32 { + self.metadata() + .labels + .get("priority") + .and_then(|s| s.parse().ok()) + .unwrap_or(50) // Default priority is 50 (mid-range) + } + + /// Get the cost factor of this worker (1.0 = baseline) + fn cost(&self) -> f32 { + self.metadata() + .labels + .get("cost") + .and_then(|s| s.parse().ok()) + .unwrap_or(1.0) + } + + /// Get the tokenizer path for this worker (gRPC mode only) + fn tokenizer_path(&self) -> Option<&str> { + self.metadata() + .labels + .get("tokenizer_path") + .map(|s| s.as_str()) + } + + /// Get the reasoning parser type for this worker (gRPC mode only) + fn reasoning_parser(&self) -> Option<&str> { + self.metadata() + .labels + .get("reasoning_parser") + .map(|s| s.as_str()) + } + + /// Get the tool parser type for this worker (gRPC mode only) + fn tool_parser(&self) -> Option<&str> { + self.metadata() + .labels + .get("tool_parser") + .map(|s| s.as_str()) + } + + /// Get the chat template for this worker (gRPC mode 
only) + fn chat_template(&self) -> Option<&str> { + self.metadata() + .labels + .get("chat_template") + .map(|s| s.as_str()) + } } /// Connection mode for worker communication @@ -724,6 +800,21 @@ impl WorkerFactory { ) } + /// Create a regular worker with custom labels (for multi-router support) + pub fn create_regular_with_labels( + url: String, + labels: std::collections::HashMap, + circuit_breaker_config: CircuitBreakerConfig, + ) -> Box { + let mut worker = BasicWorker::new(url.clone(), WorkerType::Regular) + .with_circuit_breaker_config(circuit_breaker_config); + + // Add labels to metadata + worker.metadata.labels = labels; + + Box::new(worker) + } + /// Create a DP-aware worker of specified type pub fn create_dp_aware( base_url: String, @@ -941,6 +1032,11 @@ impl fmt::Debug for HealthChecker { } impl HealthChecker { + /// Create a new HealthChecker + pub fn new(handle: tokio::task::JoinHandle<()>, shutdown: Arc) -> Self { + Self { handle, shutdown } + } + /// Shutdown the health checker gracefully pub async fn shutdown(self) { self.shutdown.store(true, Ordering::Release); @@ -950,7 +1046,7 @@ impl HealthChecker { /// Start an async background health checker for a collection of workers pub fn start_health_checker( - workers: std::sync::Arc>>>, + workers: std::sync::Arc>>>, check_interval_secs: u64, ) -> HealthChecker { let shutdown = Arc::new(AtomicBool::new(false)); @@ -1602,9 +1698,11 @@ mod tests { // Test HealthChecker background task #[tokio::test] async fn test_health_checker_startup() { - let workers = Arc::new(RwLock::new(vec![WorkerFactory::create_regular( + let worker = Arc::new(BasicWorker::new( "http://w1:8080".to_string(), - )])); + WorkerType::Regular, + )) as Arc; + let workers = Arc::new(RwLock::new(vec![worker])); let checker = start_health_checker(workers.clone(), 60); @@ -1617,9 +1715,11 @@ mod tests { #[tokio::test] async fn test_health_checker_shutdown() { - let workers = Arc::new(RwLock::new(vec![WorkerFactory::create_regular( + let 
worker = Arc::new(BasicWorker::new( "http://w1:8080".to_string(), - )])); + WorkerType::Regular, + )) as Arc; + let workers = Arc::new(RwLock::new(vec![worker])); let checker = start_health_checker(workers.clone(), 60); diff --git a/sgl-router/src/core/worker_registry.rs b/sgl-router/src/core/worker_registry.rs new file mode 100644 index 000000000..65a74cc02 --- /dev/null +++ b/sgl-router/src/core/worker_registry.rs @@ -0,0 +1,526 @@ +//! Worker Registry for multi-router support +//! +//! Provides centralized registry for workers with model-based indexing + +use crate::core::{ConnectionMode, Worker, WorkerType}; +use dashmap::DashMap; +use std::sync::{Arc, RwLock}; +use uuid::Uuid; + +/// Unique identifier for a worker +#[derive(Debug, Clone, Hash, Eq, PartialEq)] +pub struct WorkerId(String); + +impl WorkerId { + /// Create a new worker ID + pub fn new() -> Self { + Self(Uuid::new_v4().to_string()) + } + + /// Create a worker ID from a string + pub fn from_string(s: String) -> Self { + Self(s) + } + + /// Get the ID as a string + pub fn as_str(&self) -> &str { + &self.0 + } +} + +impl Default for WorkerId { + fn default() -> Self { + Self::new() + } +} + +/// Type alias for the model index to reduce complexity +type ModelIndex = Arc>>>>>; + +/// Worker registry with model-based indexing +#[derive(Debug)] +pub struct WorkerRegistry { + /// All workers indexed by ID + workers: Arc>>, + + /// Workers indexed by model ID (stores WorkerId for reference) + model_workers: Arc>>, + + /// Optimized model index for O(1) lookups (stores Arc directly) + model_index: ModelIndex, + + /// Workers indexed by worker type + type_workers: Arc>>, + + /// Workers indexed by connection mode + connection_workers: Arc>>, + + /// URL to worker ID mapping (for backward compatibility) + url_to_id: Arc>, +} + +impl WorkerRegistry { + /// Create a new worker registry + pub fn new() -> Self { + Self { + workers: Arc::new(DashMap::new()), + model_workers: Arc::new(DashMap::new()), + 
model_index: Arc::new(DashMap::new()), + type_workers: Arc::new(DashMap::new()), + connection_workers: Arc::new(DashMap::new()), + url_to_id: Arc::new(DashMap::new()), + } + } + + /// Register a new worker + pub fn register(&self, worker: Arc) -> WorkerId { + let worker_id = if let Some(existing_id) = self.url_to_id.get(worker.url()) { + // Worker with this URL already exists, update it + existing_id.clone() + } else { + WorkerId::new() + }; + + // Store worker + self.workers.insert(worker_id.clone(), worker.clone()); + + // Update URL mapping + self.url_to_id + .insert(worker.url().to_string(), worker_id.clone()); + + // Update model index (both ID-based and optimized) + let model_id = worker.model_id().to_string(); + self.model_workers + .entry(model_id.clone()) + .or_default() + .push(worker_id.clone()); + + // Update optimized model index for O(1) lookups + self.model_index + .entry(model_id) + .or_insert_with(|| Arc::new(RwLock::new(Vec::new()))) + .write() + .expect("RwLock for model_index is poisoned") + .push(worker.clone()); + + // Update type index + self.type_workers + .entry(worker.worker_type()) + .or_default() + .push(worker_id.clone()); + + // Update connection mode index + self.connection_workers + .entry(worker.connection_mode()) + .or_default() + .push(worker_id.clone()); + + worker_id + } + + /// Remove a worker by ID + pub fn remove(&self, worker_id: &WorkerId) -> Option> { + if let Some((_, worker)) = self.workers.remove(worker_id) { + // Remove from URL mapping + self.url_to_id.remove(worker.url()); + + // Remove from model index (both ID-based and optimized) + if let Some(mut model_workers) = self.model_workers.get_mut(worker.model_id()) { + model_workers.retain(|id| id != worker_id); + } + + // Remove from optimized model index + if let Some(model_index_entry) = self.model_index.get(worker.model_id()) { + let worker_url = worker.url(); + model_index_entry + .write() + .expect("RwLock for model_index is poisoned") + .retain(|w| w.url() != 
worker_url); + } + + // Remove from type index + if let Some(mut type_workers) = self.type_workers.get_mut(&worker.worker_type()) { + type_workers.retain(|id| id != worker_id); + } + + // Remove from connection mode index + if let Some(mut conn_workers) = + self.connection_workers.get_mut(&worker.connection_mode()) + { + conn_workers.retain(|id| id != worker_id); + } + + Some(worker) + } else { + None + } + } + + /// Remove a worker by URL + pub fn remove_by_url(&self, url: &str) -> Option> { + if let Some((_, worker_id)) = self.url_to_id.remove(url) { + self.remove(&worker_id) + } else { + None + } + } + + /// Get a worker by ID + pub fn get(&self, worker_id: &WorkerId) -> Option> { + self.workers.get(worker_id).map(|entry| entry.clone()) + } + + /// Get a worker by URL + pub fn get_by_url(&self, url: &str) -> Option> { + self.url_to_id.get(url).and_then(|id| self.get(&id)) + } + + /// Get all workers for a model + pub fn get_by_model(&self, model_id: &str) -> Vec> { + self.model_workers + .get(model_id) + .map(|ids| ids.iter().filter_map(|id| self.get(id)).collect()) + .unwrap_or_default() + } + + /// Get all workers for a model (O(1) optimized version) + /// This method uses the pre-indexed model_index for fast lookups + pub fn get_by_model_fast(&self, model_id: &str) -> Vec> { + self.model_index + .get(model_id) + .map(|workers| { + workers + .read() + .expect("RwLock for model_index is poisoned") + .clone() + }) + .unwrap_or_default() + } + + /// Get all workers by worker type + pub fn get_by_type(&self, worker_type: &WorkerType) -> Vec> { + self.type_workers + .get(worker_type) + .map(|ids| ids.iter().filter_map(|id| self.get(id)).collect()) + .unwrap_or_default() + } + + /// Get all prefill workers (regardless of bootstrap_port) + pub fn get_prefill_workers(&self) -> Vec> { + self.workers + .iter() + .filter_map(|entry| { + let worker = entry.value(); + match worker.worker_type() { + WorkerType::Prefill { .. 
} => Some(worker.clone()), + _ => None, + } + }) + .collect() + } + + /// Get all decode workers + pub fn get_decode_workers(&self) -> Vec> { + self.get_by_type(&WorkerType::Decode) + } + + /// Get all workers by connection mode + pub fn get_by_connection(&self, connection_mode: &ConnectionMode) -> Vec> { + self.connection_workers + .get(connection_mode) + .map(|ids| ids.iter().filter_map(|id| self.get(id)).collect()) + .unwrap_or_default() + } + + /// Get all workers + pub fn get_all(&self) -> Vec> { + self.workers + .iter() + .map(|entry| entry.value().clone()) + .collect() + } + + /// Get all workers with their IDs + pub fn get_all_with_ids(&self) -> Vec<(WorkerId, Arc)> { + self.workers + .iter() + .map(|entry| (entry.key().clone(), entry.value().clone())) + .collect() + } + + /// Get all worker URLs + pub fn get_all_urls(&self) -> Vec { + self.workers + .iter() + .map(|entry| entry.value().url().to_string()) + .collect() + } + + /// Get all model IDs with workers + pub fn get_models(&self) -> Vec { + self.model_workers + .iter() + .filter(|entry| !entry.value().is_empty()) + .map(|entry| entry.key().clone()) + .collect() + } + + /// Get workers filtered by multiple criteria + /// + /// This method allows flexible filtering of workers based on: + /// - model_id: Filter by specific model + /// - worker_type: Filter by worker type (Regular, Prefill, Decode) + /// - connection_mode: Filter by connection mode (Http, Grpc) + /// - healthy_only: Only return healthy workers + pub fn get_workers_filtered( + &self, + model_id: Option<&str>, + worker_type: Option, + connection_mode: Option, + healthy_only: bool, + ) -> Vec> { + // Start with the most efficient collection based on filters + // Use model index when possible as it's O(1) lookup + let workers = if let Some(model) = model_id { + self.get_by_model_fast(model) + } else { + self.get_all() + }; + + // Apply remaining filters + workers + .into_iter() + .filter(|w| { + // Check worker_type if specified + if let 
Some(ref wtype) = worker_type { + if w.worker_type() != *wtype { + return false; + } + } + + // Check connection_mode if specified + if let Some(ref conn) = connection_mode { + if w.connection_mode() != *conn { + return false; + } + } + + // Check health if required + if healthy_only && !w.is_healthy() { + return false; + } + + true + }) + .collect() + } + + /// Get worker statistics + pub fn stats(&self) -> WorkerRegistryStats { + let total_workers = self.workers.len(); + let total_models = self.get_models().len(); + + let mut healthy_count = 0; + let mut total_load = 0; + let mut regular_count = 0; + let mut prefill_count = 0; + let mut decode_count = 0; + + for worker in self.get_all() { + if worker.is_healthy() { + healthy_count += 1; + } + total_load += worker.load(); + + match worker.worker_type() { + WorkerType::Regular => regular_count += 1, + WorkerType::Prefill { .. } => prefill_count += 1, + WorkerType::Decode => decode_count += 1, + } + } + + WorkerRegistryStats { + total_workers, + total_models, + healthy_workers: healthy_count, + total_load, + regular_workers: regular_count, + prefill_workers: prefill_count, + decode_workers: decode_count, + } + } + + /// Start a health checker for all workers in the registry + /// This should be called once after the registry is populated with workers + pub fn start_health_checker(&self, check_interval_secs: u64) -> crate::core::HealthChecker { + use std::sync::atomic::{AtomicBool, Ordering}; + use std::sync::Arc; + + let shutdown = Arc::new(AtomicBool::new(false)); + let shutdown_clone = shutdown.clone(); + let workers_ref = self.workers.clone(); + + let handle = tokio::spawn(async move { + let mut interval = + tokio::time::interval(tokio::time::Duration::from_secs(check_interval_secs)); + + // Counter for periodic load reset (every 10 health check cycles) + let mut check_count = 0u64; + const LOAD_RESET_INTERVAL: u64 = 10; + + loop { + interval.tick().await; + + // Check for shutdown signal + if 
shutdown_clone.load(Ordering::Acquire) { + tracing::debug!("Registry health checker shutting down"); + break; + } + + // Get all workers from registry + let workers: Vec> = workers_ref + .iter() + .map(|entry| entry.value().clone()) + .collect(); + + // Perform health checks + for worker in &workers { + let _ = worker.check_health_async().await; // Use async version directly + } + + // Reset loads periodically + check_count += 1; + if check_count % LOAD_RESET_INTERVAL == 0 { + tracing::debug!("Resetting worker loads (cycle {})", check_count); + for worker in &workers { + worker.reset_load(); + } + } + } + }); + + crate::core::HealthChecker::new(handle, shutdown) + } +} + +impl Default for WorkerRegistry { + fn default() -> Self { + Self::new() + } +} + +/// Statistics for the worker registry +#[derive(Debug, Clone)] +pub struct WorkerRegistryStats { + pub total_workers: usize, + pub total_models: usize, + pub healthy_workers: usize, + pub total_load: usize, + pub regular_workers: usize, + pub prefill_workers: usize, + pub decode_workers: usize, +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::core::{CircuitBreakerConfig, WorkerFactory}; + use std::collections::HashMap; + + #[test] + fn test_worker_registry() { + let registry = WorkerRegistry::new(); + + // Create a worker with labels + let mut labels = HashMap::new(); + labels.insert("model_id".to_string(), "llama-3-8b".to_string()); + labels.insert("priority".to_string(), "50".to_string()); + labels.insert("cost".to_string(), "0.8".to_string()); + + let worker = WorkerFactory::create_regular_with_labels( + "http://worker1:8080".to_string(), + labels, + CircuitBreakerConfig::default(), + ); + + // Register worker (WorkerFactory returns Box, convert to Arc) + let worker_id = registry.register(Arc::from(worker)); + + // Verify registration + assert!(registry.get(&worker_id).is_some()); + assert!(registry.get_by_url("http://worker1:8080").is_some()); + assert_eq!(registry.get_by_model("llama-3-8b").len(), 
1); + assert_eq!(registry.get_by_type(&WorkerType::Regular).len(), 1); + assert_eq!(registry.get_by_connection(&ConnectionMode::Http).len(), 1); + + // Test stats + let stats = registry.stats(); + assert_eq!(stats.total_workers, 1); + assert_eq!(stats.total_models, 1); + + // Remove worker + registry.remove(&worker_id); + assert!(registry.get(&worker_id).is_none()); + } + + #[test] + fn test_model_index_fast_lookup() { + let registry = WorkerRegistry::new(); + + // Create workers for different models + let mut labels1 = HashMap::new(); + labels1.insert("model_id".to_string(), "llama-3".to_string()); + let worker1 = WorkerFactory::create_regular_with_labels( + "http://worker1:8080".to_string(), + labels1, + CircuitBreakerConfig::default(), + ); + + let mut labels2 = HashMap::new(); + labels2.insert("model_id".to_string(), "llama-3".to_string()); + let worker2 = WorkerFactory::create_regular_with_labels( + "http://worker2:8080".to_string(), + labels2, + CircuitBreakerConfig::default(), + ); + + let mut labels3 = HashMap::new(); + labels3.insert("model_id".to_string(), "gpt-4".to_string()); + let worker3 = WorkerFactory::create_regular_with_labels( + "http://worker3:8080".to_string(), + labels3, + CircuitBreakerConfig::default(), + ); + + // Register workers + registry.register(Arc::from(worker1)); + registry.register(Arc::from(worker2)); + registry.register(Arc::from(worker3)); + + // Test get_by_model_fast for llama-3 + let llama_workers = registry.get_by_model_fast("llama-3"); + assert_eq!(llama_workers.len(), 2); + let urls: Vec = llama_workers.iter().map(|w| w.url().to_string()).collect(); + assert!(urls.contains(&"http://worker1:8080".to_string())); + assert!(urls.contains(&"http://worker2:8080".to_string())); + + // Test get_by_model_fast for gpt-4 + let gpt_workers = registry.get_by_model_fast("gpt-4"); + assert_eq!(gpt_workers.len(), 1); + assert_eq!(gpt_workers[0].url(), "http://worker3:8080"); + + // Test get_by_model_fast for non-existent model + let 
unknown_workers = registry.get_by_model_fast("unknown-model"); + assert_eq!(unknown_workers.len(), 0); + + // Test that both get_by_model and get_by_model_fast return same results + let llama_workers_slow = registry.get_by_model("llama-3"); + assert_eq!(llama_workers.len(), llama_workers_slow.len()); + + // Test removal updates the model index + registry.remove_by_url("http://worker1:8080"); + let llama_workers_after = registry.get_by_model_fast("llama-3"); + assert_eq!(llama_workers_after.len(), 1); + assert_eq!(llama_workers_after[0].url(), "http://worker2:8080"); + } +} diff --git a/sgl-router/src/policies/cache_aware.rs b/sgl-router/src/policies/cache_aware.rs index 47d95c835..cf59c5e07 100644 --- a/sgl-router/src/policies/cache_aware.rs +++ b/sgl-router/src/policies/cache_aware.rs @@ -63,6 +63,7 @@ use super::{get_healthy_worker_indices, CacheAwareConfig, LoadBalancingPolicy}; use crate::core::Worker; use crate::metrics::RouterMetrics; use crate::tree::Tree; +use std::collections::HashMap; use std::sync::{Arc, Mutex}; use std::thread; use std::time::Duration; @@ -72,10 +73,11 @@ use tracing::debug; /// /// Routes requests based on cache affinity when load is balanced, /// switches to shortest-queue routing when load is imbalanced. +/// Maintains separate trees per model for multi-model support. 
#[derive(Debug)] pub struct CacheAwarePolicy { config: CacheAwareConfig, - tree: Arc>, + trees: Arc>>, // model_id -> Tree eviction_handle: Option>, } @@ -85,20 +87,26 @@ impl CacheAwarePolicy { } pub fn with_config(config: CacheAwareConfig) -> Self { - let tree = Arc::new(Mutex::new(Tree::new())); + let trees = Arc::new(Mutex::new(HashMap::::new())); // Start background eviction thread if configured let eviction_handle = if config.eviction_interval_secs > 0 { - let tree_clone = Arc::clone(&tree); + let trees_clone = Arc::clone(&trees); let max_tree_size = config.max_tree_size; let interval = config.eviction_interval_secs; Some(thread::spawn(move || loop { thread::sleep(Duration::from_secs(interval)); - if let Ok(tree_guard) = tree_clone.lock() { - tree_guard.evict_tenant_by_size(max_tree_size); - debug!("Cache eviction completed, max_size: {}", max_tree_size); + if let Ok(mut trees_guard) = trees_clone.lock() { + // Evict for all model trees + for (model_id, tree) in trees_guard.iter_mut() { + tree.evict_tenant_by_size(max_tree_size); + debug!( + "Cache eviction completed for model {}, max_size: {}", + model_id, max_tree_size + ); + } } })) } else { @@ -107,38 +115,97 @@ impl CacheAwarePolicy { Self { config, - tree, + trees, eviction_handle, } } /// Initialize the tree with worker URLs (used only during initial setup) - pub fn init_workers(&self, workers: &[Box]) { - if let Ok(tree) = self.tree.lock() { + pub fn init_workers(&self, workers: &[Arc]) { + if let Ok(mut trees) = self.trees.lock() { + // Group workers by model + let mut model_workers: HashMap>> = HashMap::new(); for worker in workers { - tree.insert("", worker.url()); + // Use "default" for unknown/empty model_ids for backward compatibility + let model_id = worker.model_id(); + let tree_key = if model_id.is_empty() || model_id == "unknown" { + "default".to_string() + } else { + model_id.to_string() + }; + model_workers.entry(tree_key).or_default().push(worker); + } + + // Initialize tree for each 
model + for (tree_key, model_workers) in model_workers { + let tree = trees.entry(tree_key).or_insert_with(Tree::new); + for worker in model_workers { + tree.insert("", worker.url()); + } } } } /// Add a single worker to the tree (incremental update) - pub fn add_worker(&self, url: &str) { - if let Ok(tree) = self.tree.lock() { + pub fn add_worker(&self, worker: &dyn Worker) { + if let Ok(mut trees) = self.trees.lock() { + // For backward compatibility: if model_id is "unknown" or empty, + // use a default tree. This preserves existing behavior for single-model routers. + let model_id = worker.model_id(); + let tree_key = if model_id.is_empty() || model_id == "unknown" { + "default".to_string() + } else { + model_id.to_string() + }; + let tree = trees.entry(tree_key).or_insert_with(Tree::new); + tree.insert("", worker.url()); + } + } + + /// Add a worker by URL and model (for backward compatibility) + pub fn add_worker_by_url(&self, url: &str, model_id: &str) { + if let Ok(mut trees) = self.trees.lock() { + let tree = trees.entry(model_id.to_string()).or_insert_with(Tree::new); tree.insert("", url); } } /// Remove a worker from the tree - pub fn remove_worker(&self, url: &str) { - if let Ok(tree) = self.tree.lock() { - tree.remove_tenant(url); + pub fn remove_worker(&self, worker: &dyn Worker) { + if let Ok(mut trees) = self.trees.lock() { + // Use same logic as add_worker for consistency + let model_id = worker.model_id(); + let tree_key = if model_id.is_empty() || model_id == "unknown" { + "default".to_string() + } else { + model_id.to_string() + }; + if let Some(tree) = trees.get_mut(&tree_key) { + tree.remove_tenant(worker.url()); + } + } + } + + /// Remove a worker by URL (removes from all model trees for backward compatibility) + pub fn remove_worker_by_url(&self, url: &str) { + if let Ok(mut trees) = self.trees.lock() { + // Remove from all trees since we don't know which model it belongs to + for (_model_id, tree) in trees.iter_mut() { + 
tree.remove_tenant(url); + } } } /// Run cache eviction to prevent unbounded growth pub fn evict_cache(&self, max_size: usize) { - if let Ok(tree) = self.tree.lock() { - tree.evict_tenant_by_size(max_size); + if let Ok(mut trees) = self.trees.lock() { + for (model_id, tree) in trees.iter_mut() { + tree.evict_tenant_by_size(max_size); + debug!( + "Cache eviction for model {}, max_size: {}", + model_id, max_size + ); + } } } } @@ -146,7 +213,7 @@ impl CacheAwarePolicy { impl LoadBalancingPolicy for CacheAwarePolicy { fn select_worker( &self, - workers: &[Box], + workers: &[Arc], request_text: Option<&str>, ) -> Option { let healthy_indices = get_healthy_worker_indices(workers); @@ -155,6 +222,18 @@ impl LoadBalancingPolicy for CacheAwarePolicy { return None; } + // Group workers by model (using "default" for unknown/empty model_ids) + let mut model_workers: HashMap> = HashMap::new(); + for idx in &healthy_indices { + let model_id = workers[*idx].model_id(); + let tree_key = if model_id.is_empty() || model_id == "unknown" { + "default".to_string() + } else { + model_id.to_string() + }; + model_workers.entry(tree_key).or_default().push(*idx); + } + // Get current load statistics let loads: Vec = workers.iter().map(|w| w.load()).collect(); let max_load = *loads.iter().max().unwrap_or(&0); @@ -187,7 +266,14 @@ impl LoadBalancingPolicy for CacheAwarePolicy { // Even in imbalanced mode, update the tree to maintain cache state if let Some(text) = request_text { - if let Ok(tree) = self.tree.lock() { + if let Ok(mut trees) = self.trees.lock() { + let model_id = workers[min_load_idx].model_id(); + let tree_key = if model_id.is_empty() || model_id == "unknown" { + "default".to_string() + } else { + model_id.to_string() + }; + let tree = trees.entry(tree_key).or_insert_with(Tree::new); tree.insert(text, workers[min_load_idx].url()); } } @@ -203,43 +289,85 @@ impl LoadBalancingPolicy for CacheAwarePolicy { // Use cache-aware routing when balanced let text = 
request_text.unwrap_or(""); - if let Ok(tree) = self.tree.lock() { - let (matched_text, matched_worker) = tree.prefix_match(text); - let match_rate = if text.is_empty() { - 0.0 - } else { - matched_text.chars().count() as f32 / text.chars().count() as f32 - }; + if let Ok(mut trees) = self.trees.lock() { + let mut best_match_idx: Option = None; + let mut best_match_rate: f32 = 0.0; - let selected_url = if match_rate > self.config.cache_threshold { - RouterMetrics::record_cache_hit(); - matched_worker.to_string() - } else { - RouterMetrics::record_cache_miss(); - tree.get_smallest_tenant() - }; + // Find best match across all models + for (model_id, worker_indices) in &model_workers { + let tree = trees.entry(model_id.clone()).or_insert_with(Tree::new); - // Find the index of the selected worker - if let Some(selected_idx) = workers.iter().position(|w| w.url() == selected_url) { - // Only proceed if the worker is healthy - if workers[selected_idx].is_healthy() { - // Update the tree with this request - tree.insert(text, &selected_url); + let (matched_text, matched_worker) = tree.prefix_match(text); + let match_rate = if text.is_empty() { + 0.0 + } else { + matched_text.chars().count() as f32 / text.chars().count() as f32 + }; - // Increment processed counter - workers[selected_idx].increment_processed(); - RouterMetrics::record_processed_request(&selected_url); - - return Some(selected_idx); + // Check if this model has the best match + if match_rate > best_match_rate { + // Find the worker index for this URL + if let Some(idx) = worker_indices + .iter() + .find(|&&idx| workers[idx].url() == matched_worker) + { + best_match_idx = Some(*idx); + best_match_rate = match_rate; + } } - } else { - // Selected worker no longer exists, remove it from tree - tree.remove_tenant(&selected_url); - debug!("Removed stale worker {} from cache tree", selected_url); } - // Fallback to first healthy worker - return healthy_indices.first().copied(); + // Select worker based on cache 
threshold + let selected_idx = if let (Some(idx), true) = ( + best_match_idx, + best_match_rate > self.config.cache_threshold, + ) { + RouterMetrics::record_cache_hit(); + idx + } else { + RouterMetrics::record_cache_miss(); + + // Find model with smallest tree (most cache capacity) + let mut smallest_tree_model = String::new(); + let mut smallest_tree_size = usize::MAX; + + for model_id in model_workers.keys() { + let tree = trees.entry(model_id.clone()).or_insert_with(Tree::new); + let size = tree.get_used_size_per_tenant().values().sum::(); + if size < smallest_tree_size { + smallest_tree_size = size; + smallest_tree_model = model_id.clone(); + } + } + + // Select least loaded worker from model with most cache capacity + if let Some(worker_indices) = model_workers.get(&smallest_tree_model) { + worker_indices + .iter() + .min_by_key(|&&idx| workers[idx].load()) + .copied() + .unwrap_or(healthy_indices[0]) + } else { + healthy_indices[0] + } + }; + + // Update the tree with this request + let model_id = workers[selected_idx].model_id(); + let tree_key = if model_id.is_empty() || model_id == "unknown" { + "default".to_string() + } else { + model_id.to_string() + }; + let tree = trees.entry(tree_key).or_insert_with(Tree::new); + tree.insert(text, workers[selected_idx].url()); + + // Increment processed counter + workers[selected_idx].increment_processed(); + RouterMetrics::record_processed_request(workers[selected_idx].url()); + RouterMetrics::record_policy_decision(self.name(), workers[selected_idx].url()); + + return Some(selected_idx); } // Fallback to first healthy worker if tree operations fail @@ -272,8 +400,8 @@ impl LoadBalancingPolicy for CacheAwarePolicy { fn select_worker_pair( &self, - prefill_workers: &[Box], - decode_workers: &[Box], + prefill_workers: &[Arc], + decode_workers: &[Arc], request_text: Option<&str>, ) -> Option<(usize, usize)> { // DEPRECATED: This method is no longer used when separate policies are configured. 
@@ -333,12 +461,12 @@ mod tests { ..Default::default() }; let policy = CacheAwarePolicy::with_config(config); - let workers: Vec> = vec![ - Box::new(BasicWorker::new( + let workers: Vec> = vec![ + Arc::new(BasicWorker::new( "http://w1:8000".to_string(), WorkerType::Regular, )), - Box::new(BasicWorker::new( + Arc::new(BasicWorker::new( "http://w2:8000".to_string(), WorkerType::Regular, )), @@ -378,7 +506,7 @@ mod tests { } // worker2 has load 0 - let workers: Vec> = vec![Box::new(worker1), Box::new(worker2)]; + let workers: Vec> = vec![Arc::new(worker1), Arc::new(worker2)]; policy.init_workers(&workers); // Should select worker2 (lower load) despite cache affinity @@ -395,12 +523,12 @@ mod tests { ..Default::default() }; let policy = CacheAwarePolicy::with_config(config); - let workers: Vec> = vec![ - Box::new(BasicWorker::new( + let workers: Vec> = vec![ + Arc::new(BasicWorker::new( "http://w1:8000".to_string(), WorkerType::Regular, )), - Box::new(BasicWorker::new( + Arc::new(BasicWorker::new( "http://w2:8000".to_string(), WorkerType::Regular, )), @@ -413,7 +541,7 @@ mod tests { policy.select_worker(&workers, Some("test2")); // Remove a worker - policy.remove_worker("http://w1:8000"); + policy.remove_worker_by_url("http://w1:8000"); workers[0].set_healthy(false); // All requests should now go to worker2 diff --git a/sgl-router/src/policies/mod.rs b/sgl-router/src/policies/mod.rs index 97ce9ca6f..7fdf03ba3 100644 --- a/sgl-router/src/policies/mod.rs +++ b/sgl-router/src/policies/mod.rs @@ -5,17 +5,20 @@ use crate::core::Worker; use std::fmt::Debug; +use std::sync::Arc; mod cache_aware; mod factory; mod power_of_two; mod random; +mod registry; mod round_robin; pub use cache_aware::CacheAwarePolicy; pub use factory::PolicyFactory; pub use power_of_two::PowerOfTwoPolicy; pub use random::RandomPolicy; +pub use registry::PolicyRegistry; pub use round_robin::RoundRobinPolicy; /// Core trait for load balancing policies @@ -26,9 +29,10 @@ pub trait LoadBalancingPolicy: Send 
+ Sync + Debug { /// Select a single worker from the available workers /// /// This is used for regular routing mode where requests go to a single worker. + /// Now uses Arc for better performance and to avoid unnecessary cloning. fn select_worker( &self, - workers: &[Box], + workers: &[Arc], request_text: Option<&str>, ) -> Option; @@ -38,8 +42,8 @@ pub trait LoadBalancingPolicy: Send + Sync + Debug { /// Default implementation uses select_worker for each array independently. fn select_worker_pair( &self, - prefill_workers: &[Box], - decode_workers: &[Box], + prefill_workers: &[Arc], + decode_workers: &[Arc], request_text: Option<&str>, ) -> Option<(usize, usize)> { // Default implementation: independently select from each pool @@ -105,7 +109,7 @@ impl Default for CacheAwareConfig { } /// Helper function to filter healthy workers and return their indices -pub(crate) fn get_healthy_worker_indices(workers: &[Box]) -> Vec { +pub(crate) fn get_healthy_worker_indices(workers: &[Arc]) -> Vec { workers .iter() .enumerate() @@ -121,16 +125,16 @@ mod tests { #[test] fn test_get_healthy_worker_indices() { - let workers: Vec> = vec![ - Box::new(BasicWorker::new( + let workers: Vec> = vec![ + Arc::new(BasicWorker::new( "http://w1:8000".to_string(), WorkerType::Regular, )), - Box::new(BasicWorker::new( + Arc::new(BasicWorker::new( "http://w2:8000".to_string(), WorkerType::Regular, )), - Box::new(BasicWorker::new( + Arc::new(BasicWorker::new( "http://w3:8000".to_string(), WorkerType::Regular, )), diff --git a/sgl-router/src/policies/power_of_two.rs b/sgl-router/src/policies/power_of_two.rs index c10fc2949..6452cdc6f 100644 --- a/sgl-router/src/policies/power_of_two.rs +++ b/sgl-router/src/policies/power_of_two.rs @@ -5,7 +5,7 @@ use crate::core::Worker; use crate::metrics::RouterMetrics; use rand::Rng; use std::collections::HashMap; -use std::sync::RwLock; +use std::sync::{Arc, RwLock}; use tracing::info; /// Power-of-two choices policy @@ -41,7 +41,7 @@ impl PowerOfTwoPolicy { 
impl LoadBalancingPolicy for PowerOfTwoPolicy { fn select_worker( &self, - workers: &[Box], + workers: &[Arc], _request_text: Option<&str>, ) -> Option { let healthy_indices = get_healthy_worker_indices(workers); @@ -137,8 +137,8 @@ mod tests { } // worker3 has load 0 - let workers: Vec> = - vec![Box::new(worker1), Box::new(worker2), Box::new(worker3)]; + let workers: Vec> = + vec![Arc::new(worker1), Arc::new(worker2), Arc::new(worker3)]; // Run multiple selections let mut selected_counts = [0; 3]; @@ -156,12 +156,12 @@ mod tests { #[test] fn test_power_of_two_with_cached_loads() { let policy = PowerOfTwoPolicy::new(); - let workers: Vec> = vec![ - Box::new(BasicWorker::new( + let workers: Vec> = vec![ + Arc::new(BasicWorker::new( "http://w1:8000".to_string(), WorkerType::Regular, )), - Box::new(BasicWorker::new( + Arc::new(BasicWorker::new( "http://w2:8000".to_string(), WorkerType::Regular, )), @@ -190,7 +190,7 @@ mod tests { #[test] fn test_power_of_two_single_worker() { let policy = PowerOfTwoPolicy::new(); - let workers: Vec> = vec![Box::new(BasicWorker::new( + let workers: Vec> = vec![Arc::new(BasicWorker::new( "http://w1:8000".to_string(), WorkerType::Regular, ))]; diff --git a/sgl-router/src/policies/random.rs b/sgl-router/src/policies/random.rs index 4912d0dd2..11636c045 100644 --- a/sgl-router/src/policies/random.rs +++ b/sgl-router/src/policies/random.rs @@ -4,6 +4,7 @@ use super::{get_healthy_worker_indices, LoadBalancingPolicy}; use crate::core::Worker; use crate::metrics::RouterMetrics; use rand::Rng; +use std::sync::Arc; /// Random selection policy /// @@ -20,7 +21,7 @@ impl RandomPolicy { impl LoadBalancingPolicy for RandomPolicy { fn select_worker( &self, - workers: &[Box], + workers: &[Arc], _request_text: Option<&str>, ) -> Option { let healthy_indices = get_healthy_worker_indices(workers); @@ -56,16 +57,16 @@ mod tests { #[test] fn test_random_selection() { let policy = RandomPolicy::new(); - let workers: Vec> = vec![ - Box::new(BasicWorker::new( 
+ let workers: Vec> = vec![ + Arc::new(BasicWorker::new( "http://w1:8000".to_string(), WorkerType::Regular, )), - Box::new(BasicWorker::new( + Arc::new(BasicWorker::new( "http://w2:8000".to_string(), WorkerType::Regular, )), - Box::new(BasicWorker::new( + Arc::new(BasicWorker::new( "http://w3:8000".to_string(), WorkerType::Regular, )), @@ -87,12 +88,12 @@ mod tests { #[test] fn test_random_with_unhealthy_workers() { let policy = RandomPolicy::new(); - let workers: Vec> = vec![ - Box::new(BasicWorker::new( + let workers: Vec> = vec![ + Arc::new(BasicWorker::new( "http://w1:8000".to_string(), WorkerType::Regular, )), - Box::new(BasicWorker::new( + Arc::new(BasicWorker::new( "http://w2:8000".to_string(), WorkerType::Regular, )), @@ -110,7 +111,7 @@ mod tests { #[test] fn test_random_no_healthy_workers() { let policy = RandomPolicy::new(); - let workers: Vec> = vec![Box::new(BasicWorker::new( + let workers: Vec> = vec![Arc::new(BasicWorker::new( "http://w1:8000".to_string(), WorkerType::Regular, ))]; diff --git a/sgl-router/src/policies/registry.rs b/sgl-router/src/policies/registry.rs new file mode 100644 index 000000000..326b29d76 --- /dev/null +++ b/sgl-router/src/policies/registry.rs @@ -0,0 +1,333 @@ +/// Policy Registry for managing model-to-policy mappings +/// +/// This registry manages the dynamic assignment of load balancing policies to models. +/// When the first worker of a new model is added, it determines the policy for that model. +/// All subsequent workers of the same model use the established policy. +/// When the last worker of a model is removed, the policy mapping is cleaned up. 
+use super::{ + CacheAwareConfig, CacheAwarePolicy, LoadBalancingPolicy, PowerOfTwoPolicy, RandomPolicy, + RoundRobinPolicy, +}; +use crate::config::types::PolicyConfig; +use std::collections::HashMap; +use std::sync::{Arc, RwLock}; +use tracing::{debug, info, warn}; + +/// Registry for managing model-to-policy mappings +#[derive(Clone)] +pub struct PolicyRegistry { + /// Model ID -> Policy instance mapping + model_policies: Arc>>>, + + /// Model ID -> Worker count for cleanup tracking + model_worker_counts: Arc>>, + + /// Default policy instance (cached) + default_policy: Arc, + + /// Prefill policy for PD mode + prefill_policy: Arc>>>, + + /// Decode policy for PD mode + decode_policy: Arc>>>, +} + +impl PolicyRegistry { + /// Create a new PolicyRegistry with a default policy + pub fn new(default_policy_config: PolicyConfig) -> Self { + let default_policy = Self::create_policy_from_config(&default_policy_config); + + Self { + model_policies: Arc::new(RwLock::new(HashMap::new())), + model_worker_counts: Arc::new(RwLock::new(HashMap::new())), + default_policy, + prefill_policy: Arc::new(RwLock::new(None)), + decode_policy: Arc::new(RwLock::new(None)), + } + } + + /// Called when a worker is added + /// Returns the policy that should be used for this worker's model + pub fn on_worker_added( + &self, + model_id: &str, + policy_hint: Option<&str>, + ) -> Arc { + // Increment worker count + { + let mut counts = self.model_worker_counts.write().unwrap(); + *counts.entry(model_id.to_string()).or_insert(0) += 1; + debug!( + "Worker added for model {}, count: {}", + model_id, + counts.get(model_id).unwrap() + ); + } + + // Check if model already has a policy + { + let policies = self.model_policies.read().unwrap(); + if let Some(existing_policy) = policies.get(model_id) { + debug!( + "Model {} already has policy: {}", + model_id, + existing_policy.name() + ); + return Arc::clone(existing_policy); + } + } + + // New model - determine policy + let policy = 
self.determine_policy_for_model(model_id, policy_hint); + + info!( + "Assigning policy {} to new model {}", + policy.name(), + model_id + ); + + // Store policy for this model + { + let mut policies = self.model_policies.write().unwrap(); + policies.insert(model_id.to_string(), Arc::clone(&policy)); + } + + policy + } + + /// Called when a worker is removed + pub fn on_worker_removed(&self, model_id: &str) { + let should_cleanup = { + let mut counts = self.model_worker_counts.write().unwrap(); + if let Some(count) = counts.get_mut(model_id) { + *count = count.saturating_sub(1); + debug!("Worker removed for model {}, count: {}", model_id, *count); + if *count == 0 { + counts.remove(model_id); + true + } else { + false + } + } else { + warn!( + "Attempted to remove worker for model {} with no registered workers", + model_id + ); + false + } + }; + + // Clean up policy if this was the last worker + if should_cleanup { + let mut policies = self.model_policies.write().unwrap(); + if let Some(policy) = policies.remove(model_id) { + info!( + "Removed policy {} for model {} (last worker removed)", + policy.name(), + model_id + ); + // Policy will be dropped here, cleaning up any resources + drop(policy); + } + } + } + + /// Get the policy for a model + pub fn get_policy(&self, model_id: &str) -> Option> { + self.model_policies.read().unwrap().get(model_id).cloned() + } + + /// Get the default policy + pub fn get_default_policy(&self) -> Arc { + Arc::clone(&self.default_policy) + } + + /// Get policy for a model, or default if not found + pub fn get_policy_or_default(&self, model_id: &str) -> Arc { + self.get_policy(model_id) + .unwrap_or_else(|| self.get_default_policy()) + } + + /// Determine policy for a new model + fn determine_policy_for_model( + &self, + model_id: &str, + policy_hint: Option<&str>, + ) -> Arc { + // 1. 
Check policy hint from worker + if let Some(policy_type) = policy_hint { + debug!("Using policy hint '{}' for model {}", policy_type, model_id); + return self.create_policy_from_type(policy_type); + } + + // 2. Use default policy + debug!("Using default policy for model {}", model_id); + Arc::clone(&self.default_policy) + } + + /// Create a policy from a type string + fn create_policy_from_type(&self, policy_type: &str) -> Arc { + match policy_type { + "round_robin" => Arc::new(RoundRobinPolicy::new()), + "random" => Arc::new(RandomPolicy::new()), + "cache_aware" => Arc::new(CacheAwarePolicy::new()), + "power_of_two" => Arc::new(PowerOfTwoPolicy::new()), + _ => { + warn!("Unknown policy type '{}', using default", policy_type); + Arc::clone(&self.default_policy) + } + } + } + + /// Create a policy from a PolicyConfig + fn create_policy_from_config(config: &PolicyConfig) -> Arc { + match config { + PolicyConfig::RoundRobin => Arc::new(RoundRobinPolicy::new()), + PolicyConfig::Random => Arc::new(RandomPolicy::new()), + PolicyConfig::CacheAware { + cache_threshold, + balance_abs_threshold, + balance_rel_threshold, + eviction_interval_secs, + max_tree_size, + } => { + let cache_config = CacheAwareConfig { + cache_threshold: *cache_threshold, + balance_abs_threshold: *balance_abs_threshold, + balance_rel_threshold: *balance_rel_threshold, + eviction_interval_secs: *eviction_interval_secs, + max_tree_size: *max_tree_size, + }; + Arc::new(CacheAwarePolicy::with_config(cache_config)) + } + PolicyConfig::PowerOfTwo { .. 
} => Arc::new(PowerOfTwoPolicy::new()), + } + } + + /// Get current model->policy mappings (for debugging/monitoring) + pub fn get_all_mappings(&self) -> HashMap { + let policies = self.model_policies.read().unwrap(); + policies + .iter() + .map(|(model, policy)| (model.clone(), policy.name().to_string())) + .collect() + } + + /// Get worker counts per model + pub fn get_worker_counts(&self) -> HashMap { + self.model_worker_counts.read().unwrap().clone() + } + + /// Clear all policies (useful for testing) + pub fn clear(&self) { + let mut policies = self.model_policies.write().unwrap(); + policies.clear(); + let mut counts = self.model_worker_counts.write().unwrap(); + counts.clear(); + } + + /// Set the prefill policy for PD mode + pub fn set_prefill_policy(&self, policy: Arc) { + let mut prefill_policy = self.prefill_policy.write().unwrap(); + *prefill_policy = Some(policy); + } + + /// Set the decode policy for PD mode + pub fn set_decode_policy(&self, policy: Arc) { + let mut decode_policy = self.decode_policy.write().unwrap(); + *decode_policy = Some(policy); + } + + /// Get the prefill policy for PD mode, or default if not set + pub fn get_prefill_policy(&self) -> Arc { + let prefill_policy = self.prefill_policy.read().unwrap(); + prefill_policy + .as_ref() + .map(Arc::clone) + .unwrap_or_else(|| self.get_default_policy()) + } + + /// Get the decode policy for PD mode, or default if not set + pub fn get_decode_policy(&self) -> Arc { + let decode_policy = self.decode_policy.read().unwrap(); + decode_policy + .as_ref() + .map(Arc::clone) + .unwrap_or_else(|| self.get_default_policy()) + } +} + +impl std::fmt::Debug for PolicyRegistry { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("PolicyRegistry") + .field("model_policies", &self.model_policies) + .field("model_worker_counts", &self.model_worker_counts) + .field("default_policy", &self.default_policy.name()) + .finish() + } +} + +#[cfg(test)] +mod tests { + use 
super::*; + + #[test] + fn test_policy_registry_basic() { + let registry = PolicyRegistry::new(PolicyConfig::RoundRobin); + + // First worker of a model sets the policy + let policy1 = registry.on_worker_added("llama-3", Some("cache_aware")); + assert_eq!(policy1.name(), "cache_aware"); + + // Second worker of same model uses existing policy + let policy2 = registry.on_worker_added("llama-3", Some("round_robin")); + assert_eq!(policy2.name(), "cache_aware"); // Ignores hint, uses existing + + // Different model can have different policy + let policy3 = registry.on_worker_added("gpt-4", Some("random")); + assert_eq!(policy3.name(), "random"); + + // Check mappings + let mappings = registry.get_all_mappings(); + assert_eq!(mappings.get("llama-3").unwrap(), "cache_aware"); + assert_eq!(mappings.get("gpt-4").unwrap(), "random"); + + // Check worker counts + let counts = registry.get_worker_counts(); + assert_eq!(*counts.get("llama-3").unwrap(), 2); + assert_eq!(*counts.get("gpt-4").unwrap(), 1); + } + + #[test] + fn test_policy_registry_cleanup() { + let registry = PolicyRegistry::new(PolicyConfig::RoundRobin); + + // Add workers + registry.on_worker_added("llama-3", Some("cache_aware")); + registry.on_worker_added("llama-3", None); + assert_eq!(registry.get_worker_counts().get("llama-3"), Some(&2)); + + // Remove one worker - policy should remain + registry.on_worker_removed("llama-3"); + assert!(registry.get_policy("llama-3").is_some()); + assert_eq!(registry.get_worker_counts().get("llama-3"), Some(&1)); + + // Remove last worker - policy should be cleaned up + registry.on_worker_removed("llama-3"); + assert!(registry.get_policy("llama-3").is_none()); + assert_eq!(registry.get_worker_counts().get("llama-3"), None); + } + + #[test] + fn test_default_policy() { + let registry = PolicyRegistry::new(PolicyConfig::RoundRobin); + + // No hint, no template - uses default + let policy = registry.on_worker_added("unknown-model", None); + assert_eq!(policy.name(), 
"round_robin"); + + // Get default directly + let default = registry.get_default_policy(); + assert_eq!(default.name(), "round_robin"); + } +} diff --git a/sgl-router/src/policies/round_robin.rs b/sgl-router/src/policies/round_robin.rs index fcb60233f..1b4087224 100644 --- a/sgl-router/src/policies/round_robin.rs +++ b/sgl-router/src/policies/round_robin.rs @@ -4,6 +4,7 @@ use super::{get_healthy_worker_indices, LoadBalancingPolicy}; use crate::core::Worker; use crate::metrics::RouterMetrics; use std::sync::atomic::{AtomicUsize, Ordering}; +use std::sync::Arc; /// Round-robin selection policy /// @@ -24,7 +25,7 @@ impl RoundRobinPolicy { impl LoadBalancingPolicy for RoundRobinPolicy { fn select_worker( &self, - workers: &[Box], + workers: &[Arc], _request_text: Option<&str>, ) -> Option { let healthy_indices = get_healthy_worker_indices(workers); @@ -64,16 +65,16 @@ mod tests { #[test] fn test_round_robin_selection() { let policy = RoundRobinPolicy::new(); - let workers: Vec> = vec![ - Box::new(BasicWorker::new( + let workers: Vec> = vec![ + Arc::new(BasicWorker::new( "http://w1:8000".to_string(), WorkerType::Regular, )), - Box::new(BasicWorker::new( + Arc::new(BasicWorker::new( "http://w2:8000".to_string(), WorkerType::Regular, )), - Box::new(BasicWorker::new( + Arc::new(BasicWorker::new( "http://w3:8000".to_string(), WorkerType::Regular, )), @@ -90,16 +91,16 @@ mod tests { #[test] fn test_round_robin_with_unhealthy_workers() { let policy = RoundRobinPolicy::new(); - let workers: Vec> = vec![ - Box::new(BasicWorker::new( + let workers: Vec> = vec![ + Arc::new(BasicWorker::new( "http://w1:8000".to_string(), WorkerType::Regular, )), - Box::new(BasicWorker::new( + Arc::new(BasicWorker::new( "http://w2:8000".to_string(), WorkerType::Regular, )), - Box::new(BasicWorker::new( + Arc::new(BasicWorker::new( "http://w3:8000".to_string(), WorkerType::Regular, )), @@ -118,12 +119,12 @@ mod tests { #[test] fn test_round_robin_reset() { let policy = RoundRobinPolicy::new(); - 
let workers: Vec> = vec![ - Box::new(BasicWorker::new( + let workers: Vec> = vec![ + Arc::new(BasicWorker::new( "http://w1:8000".to_string(), WorkerType::Regular, )), - Box::new(BasicWorker::new( + Arc::new(BasicWorker::new( "http://w2:8000".to_string(), WorkerType::Regular, )), diff --git a/sgl-router/src/protocols/mod.rs b/sgl-router/src/protocols/mod.rs index 5243c645f..7359a3d2e 100644 --- a/sgl-router/src/protocols/mod.rs +++ b/sgl-router/src/protocols/mod.rs @@ -3,3 +3,4 @@ pub mod spec; pub mod validation; +pub mod worker_spec; diff --git a/sgl-router/src/protocols/worker_spec.rs b/sgl-router/src/protocols/worker_spec.rs new file mode 100644 index 000000000..f6f8021a1 --- /dev/null +++ b/sgl-router/src/protocols/worker_spec.rs @@ -0,0 +1,198 @@ +//! Worker management API specifications +//! +//! Defines the request/response structures for worker management endpoints + +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; + +/// Worker configuration for API requests +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct WorkerConfigRequest { + /// Worker URL (required) + pub url: String, + + /// Model ID (optional, will query from server if not provided) + #[serde(skip_serializing_if = "Option::is_none")] + pub model_id: Option, + + /// Worker priority (optional, default: 50, higher = preferred) + #[serde(skip_serializing_if = "Option::is_none")] + pub priority: Option, + + /// Worker cost factor (optional, default: 1.0) + #[serde(skip_serializing_if = "Option::is_none")] + pub cost: Option, + + /// Worker type (optional: "regular", "prefill", "decode") + #[serde(skip_serializing_if = "Option::is_none")] + pub worker_type: Option, + + /// Bootstrap port for prefill workers (optional) + #[serde(skip_serializing_if = "Option::is_none")] + pub bootstrap_port: Option, + + // gRPC-specific configuration (optional, ignored in HTTP mode) + /// Tokenizer path for gRPC mode + #[serde(skip_serializing_if = "Option::is_none")] + pub tokenizer_path: 
Option, + + /// Reasoning parser type for gRPC mode + #[serde(skip_serializing_if = "Option::is_none")] + pub reasoning_parser: Option, + + /// Tool parser type for gRPC mode + #[serde(skip_serializing_if = "Option::is_none")] + pub tool_parser: Option, + + /// Chat template for gRPC mode + #[serde(skip_serializing_if = "Option::is_none")] + pub chat_template: Option, + + /// Additional labels (optional) + #[serde(default, skip_serializing_if = "HashMap::is_empty")] + pub labels: HashMap, +} + +/// Worker information for API responses +#[derive(Debug, Clone, Serialize)] +pub struct WorkerInfo { + /// Worker unique identifier + pub id: String, + + /// Worker URL + pub url: String, + + /// Model ID this worker serves + pub model_id: String, + + /// Worker priority + pub priority: u32, + + /// Worker cost factor + pub cost: f32, + + /// Worker type + pub worker_type: String, + + /// Whether the worker is healthy + pub is_healthy: bool, + + /// Current load on the worker + pub load: usize, + + /// Connection mode (http or grpc) + pub connection_mode: String, + + // gRPC-specific fields (None for HTTP workers) + #[serde(skip_serializing_if = "Option::is_none")] + pub tokenizer_path: Option, + + #[serde(skip_serializing_if = "Option::is_none")] + pub reasoning_parser: Option, + + #[serde(skip_serializing_if = "Option::is_none")] + pub tool_parser: Option, + + #[serde(skip_serializing_if = "Option::is_none")] + pub chat_template: Option, + + /// Additional metadata + #[serde(skip_serializing_if = "HashMap::is_empty")] + pub metadata: HashMap, +} + +/// Worker list response +#[derive(Debug, Clone, Serialize)] +pub struct WorkerListResponse { + /// List of workers + pub workers: Vec, + + /// Total count + pub total: usize, + + /// Statistics + pub stats: WorkerStats, +} + +/// Worker statistics +#[derive(Debug, Clone, Serialize)] +pub struct WorkerStats { + pub total_workers: usize, + pub healthy_workers: usize, + pub total_models: usize, + pub total_load: usize, + pub 
by_type: WorkerTypeStats, +} + +/// Worker statistics by type +#[derive(Debug, Clone, Serialize)] +pub struct WorkerTypeStats { + pub regular: usize, + pub prefill: usize, + pub decode: usize, +} + +/// Worker update request +#[derive(Debug, Clone, Deserialize)] +pub struct WorkerUpdateRequest { + /// Update priority + #[serde(skip_serializing_if = "Option::is_none")] + pub priority: Option, + + /// Update cost + #[serde(skip_serializing_if = "Option::is_none")] + pub cost: Option, + + /// Update labels + #[serde(skip_serializing_if = "Option::is_none")] + pub labels: Option>, +} + +/// Generic API response +#[derive(Debug, Clone, Serialize)] +pub struct WorkerApiResponse { + pub success: bool, + pub message: String, + + #[serde(skip_serializing_if = "Option::is_none")] + pub worker: Option, +} + +/// Error response +#[derive(Debug, Clone, Serialize)] +pub struct WorkerErrorResponse { + pub error: String, + pub code: String, +} + +/// Server info response from /get_server_info endpoint +#[derive(Debug, Clone, Deserialize)] +pub struct ServerInfo { + #[serde(skip_serializing_if = "Option::is_none")] + pub model_id: Option, + + #[serde(skip_serializing_if = "Option::is_none")] + pub model_path: Option, + + #[serde(skip_serializing_if = "Option::is_none")] + pub priority: Option, + + #[serde(skip_serializing_if = "Option::is_none")] + pub cost: Option, + + #[serde(skip_serializing_if = "Option::is_none")] + pub worker_type: Option, + + // gRPC-specific + #[serde(skip_serializing_if = "Option::is_none")] + pub tokenizer_path: Option, + + #[serde(skip_serializing_if = "Option::is_none")] + pub reasoning_parser: Option, + + #[serde(skip_serializing_if = "Option::is_none")] + pub tool_parser: Option, + + #[serde(skip_serializing_if = "Option::is_none")] + pub chat_template: Option, +} diff --git a/sgl-router/src/routers/factory.rs b/sgl-router/src/routers/factory.rs index d1bdc0fce..9fec8be13 100644 --- a/sgl-router/src/routers/factory.rs +++ 
b/sgl-router/src/routers/factory.rs @@ -15,11 +15,6 @@ pub struct RouterFactory; impl RouterFactory { /// Create a router instance from application context pub async fn create_router(ctx: &Arc) -> Result, String> { - // Check if IGW mode is enabled - if ctx.router_config.enable_igw { - return Self::create_igw_router(ctx).await; - } - // Check connection mode and route to appropriate implementation match ctx.router_config.connection_mode { ConnectionMode::Grpc => { @@ -53,8 +48,7 @@ impl RouterFactory { // Route to HTTP implementation based on routing mode match &ctx.router_config.mode { RoutingMode::Regular { worker_urls } => { - Self::create_regular_router(worker_urls, &ctx.router_config.policy, ctx) - .await + Self::create_regular_router(worker_urls, ctx).await } RoutingMode::PrefillDecode { prefill_urls, @@ -80,23 +74,19 @@ impl RouterFactory { } } - /// Create a regular router with injected policy - async fn create_regular_router( + /// Create a regular router + pub async fn create_regular_router( worker_urls: &[String], - policy_config: &PolicyConfig, ctx: &Arc, ) -> Result, String> { - // Create policy - let policy = PolicyFactory::create_from_config(policy_config); - - // Create regular router with injected policy and context - let router = Router::new(worker_urls.to_vec(), policy, ctx).await?; + // Create regular router with context + let router = Router::new(worker_urls.to_vec(), ctx).await?; Ok(Box::new(router)) } /// Create a PD router with injected policy - async fn create_pd_router( + pub async fn create_pd_router( prefill_urls: &[(String, Option)], decode_urls: &[String], prefill_policy_config: Option<&PolicyConfig>, @@ -104,21 +94,18 @@ impl RouterFactory { main_policy_config: &PolicyConfig, ctx: &Arc, ) -> Result, String> { - // Create policies - use specific policies if provided, otherwise fall back to main policy + // Initialize policies in PolicyRegistry - use specific policies if provided, otherwise fall back to main policy let prefill_policy = 
PolicyFactory::create_from_config(prefill_policy_config.unwrap_or(main_policy_config)); let decode_policy = PolicyFactory::create_from_config(decode_policy_config.unwrap_or(main_policy_config)); - // Create PD router with separate policies and context - let router = PDRouter::new( - prefill_urls.to_vec(), - decode_urls.to_vec(), - prefill_policy, - decode_policy, - ctx, - ) - .await?; + // Set the prefill and decode policies in the registry + ctx.policy_registry.set_prefill_policy(prefill_policy); + ctx.policy_registry.set_decode_policy(decode_policy); + + // Create PD router with context (policies are in PolicyRegistry) + let router = PDRouter::new(prefill_urls.to_vec(), decode_urls.to_vec(), ctx).await?; Ok(Box::new(router)) } @@ -186,10 +173,4 @@ impl RouterFactory { Ok(Box::new(router)) } - - /// Create an IGW router (placeholder for future implementation) - async fn create_igw_router(_ctx: &Arc) -> Result, String> { - // For now, return an error indicating IGW is not yet implemented - Err("IGW mode is not yet implemented".to_string()) - } } diff --git a/sgl-router/src/routers/grpc/pd_router.rs b/sgl-router/src/routers/grpc/pd_router.rs index a0a3c7911..3efb9ca87 100644 --- a/sgl-router/src/routers/grpc/pd_router.rs +++ b/sgl-router/src/routers/grpc/pd_router.rs @@ -27,9 +27,9 @@ use tracing::{info, warn}; #[allow(dead_code)] // Fields will be used once implementation is complete pub struct GrpcPDRouter { /// Prefill worker connections - prefill_workers: Arc>>>, + prefill_workers: Arc>>>, /// Decode worker connections - decode_workers: Arc>>>, + decode_workers: Arc>>>, /// gRPC clients for prefill workers prefill_grpc_clients: Arc>>, /// gRPC clients for decode workers @@ -127,7 +127,7 @@ impl GrpcPDRouter { } // Create Prefill Worker trait objects with gRPC connection mode - let prefill_workers: Vec> = prefill_urls + let prefill_workers: Vec> = prefill_urls .iter() .map(|(url, bootstrap_port)| { let worker = BasicWorker::with_connection_mode( @@ -147,12 
+147,12 @@ impl GrpcPDRouter { failure_threshold: ctx.router_config.health_check.failure_threshold, success_threshold: ctx.router_config.health_check.success_threshold, }); - Box::new(worker) as Box + Arc::new(worker) as Arc }) .collect(); // Create Decode Worker trait objects with gRPC connection mode - let decode_workers: Vec> = decode_urls + let decode_workers: Vec> = decode_urls .iter() .map(|url| { let worker = BasicWorker::with_connection_mode( @@ -168,7 +168,7 @@ impl GrpcPDRouter { failure_threshold: ctx.router_config.health_check.failure_threshold, success_threshold: ctx.router_config.health_check.success_threshold, }); - Box::new(worker) as Box + Arc::new(worker) as Arc }) .collect(); @@ -269,6 +269,7 @@ impl RouterTrait for GrpcPDRouter { &self, _headers: Option<&HeaderMap>, _body: &crate::protocols::spec::GenerateRequest, + _model_id: Option<&str>, ) -> Response { (StatusCode::NOT_IMPLEMENTED).into_response() } @@ -277,6 +278,7 @@ impl RouterTrait for GrpcPDRouter { &self, _headers: Option<&HeaderMap>, _body: &crate::protocols::spec::ChatCompletionRequest, + _model_id: Option<&str>, ) -> Response { (StatusCode::NOT_IMPLEMENTED).into_response() } @@ -285,6 +287,7 @@ impl RouterTrait for GrpcPDRouter { &self, _headers: Option<&HeaderMap>, _body: &crate::protocols::spec::CompletionRequest, + _model_id: Option<&str>, ) -> Response { (StatusCode::NOT_IMPLEMENTED).into_response() } @@ -293,6 +296,7 @@ impl RouterTrait for GrpcPDRouter { &self, _headers: Option<&HeaderMap>, _body: &crate::protocols::spec::ResponsesRequest, + _model_id: Option<&str>, ) -> Response { (StatusCode::NOT_IMPLEMENTED).into_response() } @@ -305,6 +309,7 @@ impl RouterTrait for GrpcPDRouter { &self, _headers: Option<&HeaderMap>, _body: &crate::protocols::spec::RerankRequest, + _model_id: Option<&str>, ) -> Response { (StatusCode::NOT_IMPLEMENTED).into_response() } diff --git a/sgl-router/src/routers/grpc/router.rs b/sgl-router/src/routers/grpc/router.rs index 245513b37..cb4bab412 
100644 --- a/sgl-router/src/routers/grpc/router.rs +++ b/sgl-router/src/routers/grpc/router.rs @@ -27,7 +27,7 @@ use tracing::{info, warn}; #[allow(dead_code)] // Fields will be used once implementation is complete pub struct GrpcRouter { /// Worker connections - workers: Arc>>>, + workers: Arc>>>, /// gRPC clients for each worker grpc_clients: Arc>>, /// Load balancing policy @@ -103,7 +103,7 @@ impl GrpcRouter { } // Create Worker trait objects with gRPC connection mode - let mut workers: Vec> = Vec::new(); + let mut workers: Vec> = Vec::new(); // Move clients from the HashMap to the workers for url in &worker_urls { @@ -123,7 +123,7 @@ impl GrpcRouter { }) .with_grpc_client(client); - workers.push(Box::new(worker) as Box); + workers.push(Arc::new(worker) as Arc); } else { warn!("No gRPC client for worker {}, skipping", url); } @@ -202,6 +202,7 @@ impl RouterTrait for GrpcRouter { &self, _headers: Option<&HeaderMap>, _body: &crate::protocols::spec::GenerateRequest, + _model_id: Option<&str>, ) -> Response { (StatusCode::NOT_IMPLEMENTED).into_response() } @@ -210,6 +211,7 @@ impl RouterTrait for GrpcRouter { &self, _headers: Option<&HeaderMap>, _body: &crate::protocols::spec::ChatCompletionRequest, + _model_id: Option<&str>, ) -> Response { (StatusCode::NOT_IMPLEMENTED).into_response() } @@ -218,6 +220,7 @@ impl RouterTrait for GrpcRouter { &self, _headers: Option<&HeaderMap>, _body: &crate::protocols::spec::CompletionRequest, + _model_id: Option<&str>, ) -> Response { (StatusCode::NOT_IMPLEMENTED).into_response() } @@ -226,6 +229,7 @@ impl RouterTrait for GrpcRouter { &self, _headers: Option<&HeaderMap>, _body: &crate::protocols::spec::ResponsesRequest, + _model_id: Option<&str>, ) -> Response { (StatusCode::NOT_IMPLEMENTED).into_response() } @@ -238,6 +242,7 @@ impl RouterTrait for GrpcRouter { &self, _headers: Option<&HeaderMap>, _body: &crate::protocols::spec::RerankRequest, + _model_id: Option<&str>, ) -> Response { 
(StatusCode::NOT_IMPLEMENTED).into_response() } diff --git a/sgl-router/src/routers/http/openai_router.rs b/sgl-router/src/routers/http/openai_router.rs index 0f5a56974..e75cb794a 100644 --- a/sgl-router/src/routers/http/openai_router.rs +++ b/sgl-router/src/routers/http/openai_router.rs @@ -186,6 +186,7 @@ impl super::super::RouterTrait for OpenAIRouter { &self, _headers: Option<&HeaderMap>, _body: &GenerateRequest, + _model_id: Option<&str>, ) -> Response { // Generate endpoint is SGLang-specific, not supported for OpenAI backend ( @@ -199,6 +200,7 @@ impl super::super::RouterTrait for OpenAIRouter { &self, headers: Option<&HeaderMap>, body: &ChatCompletionRequest, + _model_id: Option<&str>, ) -> Response { if !self.circuit_breaker.can_execute() { return (StatusCode::SERVICE_UNAVAILABLE, "Circuit breaker open").into_response(); @@ -326,6 +328,7 @@ impl super::super::RouterTrait for OpenAIRouter { &self, _headers: Option<&HeaderMap>, _body: &CompletionRequest, + _model_id: Option<&str>, ) -> Response { // Completion endpoint not implemented for OpenAI backend ( @@ -339,6 +342,7 @@ impl super::super::RouterTrait for OpenAIRouter { &self, _headers: Option<&HeaderMap>, _body: &crate::protocols::spec::ResponsesRequest, + _model_id: Option<&str>, ) -> Response { ( StatusCode::NOT_IMPLEMENTED, @@ -383,7 +387,12 @@ impl super::super::RouterTrait for OpenAIRouter { .into_response() } - async fn route_rerank(&self, _headers: Option<&HeaderMap>, _body: &RerankRequest) -> Response { + async fn route_rerank( + &self, + _headers: Option<&HeaderMap>, + _body: &RerankRequest, + _model_id: Option<&str>, + ) -> Response { ( StatusCode::NOT_IMPLEMENTED, "Rerank endpoint not implemented for OpenAI backend", diff --git a/sgl-router/src/routers/http/pd_router.rs b/sgl-router/src/routers/http/pd_router.rs index af4d605f0..4f31cc225 100644 --- a/sgl-router/src/routers/http/pd_router.rs +++ b/sgl-router/src/routers/http/pd_router.rs @@ -3,11 +3,11 @@ use super::pd_types::{api_path, 
PDRouterError}; use crate::config::types::RetryConfig; use crate::core::{ - is_retryable_status, BasicWorker, CircuitBreakerConfig, HealthChecker, HealthConfig, - RetryExecutor, Worker, WorkerFactory, WorkerLoadGuard, WorkerType, + is_retryable_status, BasicWorker, CircuitBreakerConfig, HealthConfig, RetryExecutor, Worker, + WorkerFactory, WorkerLoadGuard, WorkerRegistry, WorkerType, }; use crate::metrics::RouterMetrics; -use crate::policies::LoadBalancingPolicy; +use crate::policies::{LoadBalancingPolicy, PolicyRegistry}; use crate::protocols::spec::{ ChatCompletionRequest, ChatMessage, CompletionRequest, GenerateRequest, RerankRequest, ResponsesRequest, StringOrArray, UserMessageContent, @@ -27,7 +27,7 @@ use reqwest::Client; use serde::Serialize; use serde_json::{json, Value}; use std::collections::HashMap; -use std::sync::{Arc, RwLock}; +use std::sync::Arc; use std::time::{Duration, Instant}; use tokio::sync::mpsc; use tokio_stream::wrappers::UnboundedReceiverStream; @@ -35,10 +35,8 @@ use tracing::{debug, error, info, warn}; #[derive(Debug)] pub struct PDRouter { - pub prefill_workers: Arc>>>, - pub decode_workers: Arc>>>, - pub prefill_policy: Arc, - pub decode_policy: Arc, + pub worker_registry: Arc, + pub policy_registry: Arc, pub worker_startup_timeout_secs: u64, pub worker_startup_check_interval_secs: u64, pub worker_loads: Arc>>, @@ -48,25 +46,22 @@ pub struct PDRouter { pub prefill_client: Client, pub retry_config: RetryConfig, pub circuit_breaker_config: CircuitBreakerConfig, - _prefill_health_checker: Option, - _decode_health_checker: Option, // Channel for sending prefill responses to background workers for draining prefill_drain_tx: mpsc::Sender, } // Request context for PD router operations #[derive(Clone)] -struct PDRequestContext { +struct PDRequestContext<'a> { route: &'static str, batch_size: Option, is_stream: bool, return_logprob: bool, request_text: Option, + model_id: Option<&'a str>, } impl PDRouter { - // Dynamic worker management methods 
for service discovery - // Private helper method to perform health check on a new server async fn wait_for_server_health(&self, url: &str) -> Result<(), PDRouterError> { crate::routers::http::router::Router::wait_for_healthy_workers( @@ -83,24 +78,16 @@ impl PDRouter { // Generic helper for processing all workers with an endpoint async fn process_workers( &self, - workers: &RwLock>>, + worker_type_enum: WorkerType, worker_type: &str, endpoint: &str, ) -> (Vec, Vec) { let mut results = Vec::new(); let mut errors = Vec::new(); - // Get worker URLs first to avoid holding lock across await - let urls = match workers.read() { - Ok(workers) => workers - .iter() - .map(|w| w.url().to_string()) - .collect::>(), - Err(_) => { - errors.push(format!("Failed to access {} workers", worker_type)); - Vec::new() - } - }; + // Get workers from registry based on type + let workers = self.worker_registry.get_by_type(&worker_type_enum); + let urls: Vec = workers.iter().map(|w| w.url().to_string()).collect(); // Process each worker for worker_url in urls { @@ -126,101 +113,98 @@ impl PDRouter { (results, errors) } - // Helper to get worker URLs from a worker collection - fn get_worker_urls( - workers: &RwLock>>, - worker_type: &str, - ) -> Result, String> { - workers - .read() - .map(|workers| { - workers - .iter() - .map(|w| w.url().to_string()) - .collect::>() - }) - .map_err(|_| format!("Failed to access {} workers", worker_type)) + // Helper to get prefill worker URLs + fn get_prefill_worker_urls(&self) -> Vec { + self.worker_registry + .get_prefill_workers() + .iter() + .map(|w| w.url().to_string()) + .collect() } - // Generic helper for proxying requests to the first worker - async fn proxy_to_first_worker( + // Helper to get decode worker URLs + fn get_decode_worker_urls(&self) -> Vec { + self.worker_registry + .get_decode_workers() + .iter() + .map(|w| w.url().to_string()) + .collect() + } + + // Helper for proxying requests to the first prefill worker + async fn 
proxy_to_first_prefill_worker( &self, - workers: &RwLock>>, endpoint: &str, - worker_type: &str, headers: Option>, ) -> Response { - // Get first worker URL to avoid holding lock across await - let first_worker_url = match workers.read() { - Ok(workers) => workers.first().map(|w| w.url().to_string()), - Err(_) => { - return ( - StatusCode::INTERNAL_SERVER_ERROR, - format!("Failed to access {} workers", worker_type), - ) - .into_response(); - } - }; + let workers = self.worker_registry.get_prefill_workers(); + let first_worker_url = workers.first().map(|w| w.url().to_string()); if let Some(worker_url) = first_worker_url { - let url = format!("{}/{}", worker_url, endpoint); - let mut request_builder = self.client.get(&url); - - // Add headers if provided - if let Some(headers) = headers { - for (name, value) in headers { - request_builder = request_builder.header(name, value); - } - } - - match request_builder.send().await { - Ok(res) if res.status().is_success() => { - let response_headers = header_utils::preserve_response_headers(res.headers()); - - match res.bytes().await { - Ok(body) => { - let mut response = Response::new(axum::body::Body::from(body)); - *response.status_mut() = StatusCode::OK; - *response.headers_mut() = response_headers; - response - } - Err(e) => { - error!("Failed to read response body: {}", e); - ( - StatusCode::INTERNAL_SERVER_ERROR, - format!("Failed to read response body: {}", e), - ) - .into_response() - } - } - } - Ok(res) => { - let status = StatusCode::from_u16(res.status().as_u16()) - .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR); - ( - status, - format!("{} server returned status: {}", worker_type, res.status()), - ) - .into_response() - } - Err(e) => { - error!("Failed to proxy request to {} server: {}", worker_type, e); - ( - StatusCode::INTERNAL_SERVER_ERROR, - format!("Failed to proxy request: {}", e), - ) - .into_response() - } - } + self.proxy_to_worker(worker_url, endpoint, headers).await } else { ( 
StatusCode::SERVICE_UNAVAILABLE, - format!("No {} servers available", worker_type), + "No prefill servers available".to_string(), ) .into_response() } } + // Generic helper for proxying to a specific worker + async fn proxy_to_worker( + &self, + worker_url: String, + endpoint: &str, + headers: Option>, + ) -> Response { + let url = format!("{}/{}", worker_url, endpoint); + let mut request_builder = self.client.get(&url); + + // Add headers if provided + if let Some(headers) = headers { + for (name, value) in headers { + request_builder = request_builder.header(name, value); + } + } + + match request_builder.send().await { + Ok(res) if res.status().is_success() => { + let response_headers = header_utils::preserve_response_headers(res.headers()); + + match res.bytes().await { + Ok(body) => { + let mut response = Response::new(axum::body::Body::from(body)); + *response.status_mut() = StatusCode::OK; + *response.headers_mut() = response_headers; + response + } + Err(e) => { + error!("Failed to read response body: {}", e); + ( + StatusCode::INTERNAL_SERVER_ERROR, + format!("Failed to read response body: {}", e), + ) + .into_response() + } + } + } + Ok(res) => { + let status = StatusCode::from_u16(res.status().as_u16()) + .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR); + (status, format!("{} server returned status: ", res.status())).into_response() + } + Err(e) => { + error!("Failed to proxy request server: {}", e); + ( + StatusCode::INTERNAL_SERVER_ERROR, + format!("Failed to proxy request: {}", e), + ) + .into_response() + } + } + } + pub async fn add_prefill_server( &self, url: String, @@ -229,36 +213,37 @@ impl PDRouter { // Wait for the new server to be healthy self.wait_for_server_health(&url).await?; + // Check if already exists + if self.worker_registry.get_by_url(&url).is_some() { + return Err(PDRouterError::WorkerAlreadyExists { url: url.clone() }); + } + // Create Worker for the new prefill server with circuit breaker configuration + // TODO: In IGW mode, fetch 
model_id from worker's /get_model_info endpoint let worker = WorkerFactory::create_prefill_with_config( url.clone(), bootstrap_port, self.circuit_breaker_config.clone(), ); - // Add to prefill workers list - let mut workers = self - .prefill_workers - .write() - .map_err(|_| PDRouterError::LockError { - operation: "prefill_workers write".to_string(), - })?; + let worker_arc: Arc = Arc::from(worker); - // Check if already exists - if workers.iter().any(|w| w.url() == url) { - return Err(PDRouterError::WorkerAlreadyExists { url: url.clone() }); - } + // Register the worker in the registry + self.worker_registry.register(worker_arc.clone()); - workers.push(worker); + // Notify PolicyRegistry about the new worker + let model_id = worker_arc.model_id(); + let policy = self.policy_registry.on_worker_added(model_id, None); - // Update cache-aware policy if applicable - drop(workers); // Release write lock - if let Some(cache_policy) = self - .prefill_policy - .as_any() - .downcast_ref::() - { - cache_policy.add_worker(&url); + // If this is a cache-aware policy, update it with all workers for this model + if policy.name() == "cache_aware" { + if let Some(cache_aware) = policy + .as_any() + .downcast_ref::() + { + let model_workers = self.worker_registry.get_by_model_fast(model_id); + cache_aware.init_workers(&model_workers); + } } info!("Added prefill server: {}", url); @@ -269,35 +254,36 @@ impl PDRouter { // Wait for the new server to be healthy self.wait_for_server_health(&url).await?; + // Check if already exists + if self.worker_registry.get_by_url(&url).is_some() { + return Err(PDRouterError::WorkerAlreadyExists { url: url.clone() }); + } + // Create Worker for the new decode server with circuit breaker configuration + // TODO: In IGW mode, fetch model_id from worker's /get_model_info endpoint let worker = WorkerFactory::create_decode_with_config( url.clone(), self.circuit_breaker_config.clone(), ); - // Add to decode workers list - let mut workers = self - 
.decode_workers - .write() - .map_err(|_| PDRouterError::LockError { - operation: "decode_workers write".to_string(), - })?; + let worker_arc: Arc = Arc::from(worker); - // Check if already exists - if workers.iter().any(|w| w.url() == url) { - return Err(PDRouterError::WorkerAlreadyExists { url: url.clone() }); - } + // Register the worker in the registry + self.worker_registry.register(worker_arc.clone()); - workers.push(worker); + // Notify PolicyRegistry about the new worker + let model_id = worker_arc.model_id(); + let policy = self.policy_registry.on_worker_added(model_id, None); - // Update cache-aware policy if applicable - drop(workers); // Release write lock - if let Some(cache_policy) = self - .decode_policy - .as_any() - .downcast_ref::() - { - cache_policy.add_worker(&url); + // If this is a cache-aware policy, update it with all workers for this model + if policy.name() == "cache_aware" { + if let Some(cache_aware) = policy + .as_any() + .downcast_ref::() + { + let model_workers = self.worker_registry.get_by_model_fast(model_id); + cache_aware.init_workers(&model_workers); + } } info!("Added decode server: {}", url); @@ -305,73 +291,91 @@ impl PDRouter { } pub async fn remove_prefill_server(&self, url: &str) -> Result { - let mut workers = self - .prefill_workers - .write() - .map_err(|_| PDRouterError::LockError { - operation: "prefill_workers write".to_string(), - })?; + // Check if worker exists and get model_id + let model_id = match self.worker_registry.get_by_url(url) { + Some(worker) => worker.model_id().to_string(), + None => { + return Err(PDRouterError::WorkerNotFound { + url: url.to_string(), + }); + } + }; - // Find and remove the server - let initial_len = workers.len(); - workers.retain(|w| w.url() != url); + // Remove from registry + let removed = self.worker_registry.remove_by_url(url); - if workers.len() == initial_len { - return Err(PDRouterError::WorkerNotFound { + if removed.is_some() { + // Notify PolicyRegistry about the removed 
worker + self.policy_registry.on_worker_removed(&model_id); + + // Get the policy for this model to update cache-aware if needed + if let Some(policy) = self.policy_registry.get_policy(&model_id) { + if policy.name() == "cache_aware" { + if let Some(cache_aware) = policy + .as_any() + .downcast_ref::() + { + cache_aware.remove_worker_by_url(url); + } + } + } + } + + if removed.is_some() { + info!("Removed prefill server: {}", url); + Ok(format!("Successfully removed prefill server: {}", url)) + } else { + Err(PDRouterError::WorkerNotFound { url: url.to_string(), - }); + }) } - - // Remove from cache-aware policy if applicable - if let Some(cache_policy) = self - .prefill_policy - .as_any() - .downcast_ref::() - { - cache_policy.remove_worker(url); - } - - info!("Removed prefill server: {}", url); - Ok(format!("Successfully removed prefill server: {}", url)) } pub async fn remove_decode_server(&self, url: &str) -> Result { - let mut workers = self - .decode_workers - .write() - .map_err(|_| PDRouterError::LockError { - operation: "decode_workers write".to_string(), - })?; + // Check if worker exists and get model_id + let model_id = match self.worker_registry.get_by_url(url) { + Some(worker) => worker.model_id().to_string(), + None => { + return Err(PDRouterError::WorkerNotFound { + url: url.to_string(), + }); + } + }; - // Find and remove the server - let initial_len = workers.len(); - workers.retain(|w| w.url() != url); + // Remove from registry + let removed = self.worker_registry.remove_by_url(url); - if workers.len() == initial_len { - return Err(PDRouterError::WorkerNotFound { + if removed.is_some() { + // Notify PolicyRegistry about the removed worker + self.policy_registry.on_worker_removed(&model_id); + + // Get the policy for this model to update cache-aware if needed + if let Some(policy) = self.policy_registry.get_policy(&model_id) { + if policy.name() == "cache_aware" { + if let Some(cache_aware) = policy + .as_any() + .downcast_ref::() + { + 
cache_aware.remove_worker_by_url(url); + } + } + } + } + + if removed.is_some() { + info!("Removed decode server: {}", url); + Ok(format!("Successfully removed decode server: {}", url)) + } else { + Err(PDRouterError::WorkerNotFound { url: url.to_string(), - }); + }) } - - // Remove from cache-aware policy if applicable - if let Some(cache_policy) = self - .decode_policy - .as_any() - .downcast_ref::() - { - cache_policy.remove_worker(url); - } - - info!("Removed decode server: {}", url); - Ok(format!("Successfully removed decode server: {}", url)) } #[allow(clippy::too_many_arguments)] pub async fn new( prefill_urls: Vec<(String, Option)>, decode_urls: Vec, - prefill_policy: Arc, - decode_policy: Arc, ctx: &Arc, ) -> Result { // Convert config CircuitBreakerConfig to core CircuitBreakerConfig @@ -383,16 +387,28 @@ impl PDRouter { window_duration: Duration::from_secs(circuit_breaker_config.window_duration_secs), }; - // Convert URLs to Worker trait objects with health check config - let prefill_workers: Vec> = prefill_urls - .into_iter() - .map(|(url, port)| { - let worker = BasicWorker::new( - url, - WorkerType::Prefill { - bootstrap_port: port, - }, - ) + // Register prefill workers in the registry + for (url, port) in prefill_urls { + let worker = BasicWorker::new( + url, + WorkerType::Prefill { + bootstrap_port: port, + }, + ) + .with_circuit_breaker_config(core_cb_config.clone()) + .with_health_config(HealthConfig { + timeout_secs: ctx.router_config.health_check.timeout_secs, + check_interval_secs: ctx.router_config.health_check.check_interval_secs, + endpoint: ctx.router_config.health_check.endpoint.clone(), + failure_threshold: ctx.router_config.health_check.failure_threshold, + success_threshold: ctx.router_config.health_check.success_threshold, + }); + ctx.worker_registry.register(Arc::new(worker)); + } + + // Register decode workers in the registry + for url in decode_urls { + let worker = BasicWorker::new(url, WorkerType::Decode) 
.with_circuit_breaker_config(core_cb_config.clone()) .with_health_config(HealthConfig { timeout_secs: ctx.router_config.health_check.timeout_secs, @@ -401,30 +417,13 @@ impl PDRouter { failure_threshold: ctx.router_config.health_check.failure_threshold, success_threshold: ctx.router_config.health_check.success_threshold, }); - Box::new(worker) as Box - }) - .collect(); + ctx.worker_registry.register(Arc::new(worker)); + } - let decode_workers: Vec> = decode_urls - .into_iter() - .map(|url| { - let worker = BasicWorker::new(url, WorkerType::Decode) - .with_circuit_breaker_config(core_cb_config.clone()) - .with_health_config(HealthConfig { - timeout_secs: ctx.router_config.health_check.timeout_secs, - check_interval_secs: ctx.router_config.health_check.check_interval_secs, - endpoint: ctx.router_config.health_check.endpoint.clone(), - failure_threshold: ctx.router_config.health_check.failure_threshold, - success_threshold: ctx.router_config.health_check.success_threshold, - }); - Box::new(worker) as Box - }) - .collect(); - - // Wait for PD workers to be healthy (skip if empty - for service discovery mode) - let all_urls: Vec = prefill_workers + // Get all workers from registry for health check + let all_workers = ctx.worker_registry.get_all(); + let all_urls: Vec = all_workers .iter() - .chain(decode_workers.iter()) .map(|worker| worker.url().to_string()) .collect(); if !all_urls.is_empty() { @@ -436,25 +435,19 @@ impl PDRouter { .await?; } - // Initialize cache-aware policies with workers - if let Some(cache_policy) = prefill_policy - .as_any() - .downcast_ref::() - { - cache_policy.init_workers(&prefill_workers); - } - - if let Some(cache_policy) = decode_policy - .as_any() - .downcast_ref::() - { - cache_policy.init_workers(&decode_workers); - } + // Initialize cache-aware policies with workers from registry + // Note: We need to get workers by type and convert to Box for CacheAwarePolicy + // This is a temporary workaround until CacheAwarePolicy is updated to 
work with Arc + // TODO: Update CacheAwarePolicy to accept Arc instead of Box // Set up background load monitoring for power-of-two selection let (tx, rx) = tokio::sync::watch::channel(HashMap::new()); let worker_loads = Arc::new(rx); + // Get policies from registry to check if we need load monitoring + let prefill_policy = ctx.policy_registry.get_prefill_policy(); + let decode_policy = ctx.policy_registry.get_decode_policy(); + let load_monitor_handle = if prefill_policy.name() == "power_of_two" || decode_policy.name() == "power_of_two" { let monitor_urls = all_urls.clone(); @@ -478,18 +471,8 @@ impl PDRouter { None }; - let prefill_workers = Arc::new(RwLock::new(prefill_workers)); - let decode_workers = Arc::new(RwLock::new(decode_workers)); - - // Start health checkers for both worker pools - let prefill_health_checker = crate::core::start_health_checker( - Arc::clone(&prefill_workers), - ctx.router_config.health_check.check_interval_secs, - ); - let decode_health_checker = crate::core::start_health_checker( - Arc::clone(&decode_workers), - ctx.router_config.health_check.check_interval_secs, - ); + // Note: Health checking is now handled centrally by RouterManager + // Individual routers no longer need to manage health checkers // Build a dedicated prefill client for fire-and-forget semantics let prefill_client = reqwest::Client::builder() @@ -570,10 +553,8 @@ impl PDRouter { }); Ok(PDRouter { - prefill_workers, - decode_workers, - prefill_policy, - decode_policy, + worker_registry: Arc::clone(&ctx.worker_registry), + policy_registry: Arc::clone(&ctx.policy_registry), worker_startup_timeout_secs: ctx.router_config.worker_startup_timeout_secs, worker_startup_check_interval_secs: ctx .router_config @@ -585,8 +566,6 @@ impl PDRouter { prefill_drain_tx, retry_config: ctx.router_config.effective_retry_config(), circuit_breaker_config: core_cb_config, - _prefill_health_checker: Some(prefill_health_checker), - _decode_health_checker: Some(decode_health_checker), }) } 
@@ -721,7 +700,7 @@ impl PDRouter { &self, headers: Option<&HeaderMap>, original_request: &T, - context: PDRequestContext, + context: PDRequestContext<'_>, ) -> Response { let start_time = Instant::now(); @@ -736,14 +715,16 @@ impl PDRouter { let context = context.clone(); async move { // Select workers fresh for each attempt - let (prefill, decode) = - match self.select_pd_pair(context.request_text.as_deref()).await { - Ok(pair) => pair, - Err(e) => { - RouterMetrics::record_pd_error("server_selection"); - return Self::handle_server_selection_error(e); - } - }; + let (prefill, decode) = match self + .select_pd_pair(context.request_text.as_deref(), context.model_id) + .await + { + Ok(pair) => pair, + Err(e) => { + RouterMetrics::record_pd_error("server_selection"); + return Self::handle_server_selection_error(e); + } + }; debug!( "PD retry attempt {} using prefill={} decode={}", @@ -806,7 +787,7 @@ impl PDRouter { async fn handle_decode_error_response( &self, res: reqwest::Response, - context: &PDRequestContext, + context: &PDRequestContext<'_>, prefill: &dyn Worker, decode: &dyn Worker, ) -> Response { @@ -859,7 +840,7 @@ impl PDRouter { &self, headers: Option<&HeaderMap>, json_request: Value, - context: PDRequestContext, + context: PDRequestContext<'_>, prefill: &dyn Worker, decode: &dyn Worker, start_time: Instant, @@ -1131,35 +1112,56 @@ impl PDRouter { // Check if either prefill or decode policy needs request text fn policies_need_request_text(&self) -> bool { - self.prefill_policy.needs_request_text() || self.decode_policy.needs_request_text() + // Check both prefill and decode policies + let prefill_policy = self.policy_registry.get_prefill_policy(); + let decode_policy = self.policy_registry.get_decode_policy(); + prefill_policy.needs_request_text() || decode_policy.needs_request_text() } // Select a pair of prefill and decode servers considering circuit breaker state async fn select_pd_pair( &self, request_text: Option<&str>, - ) -> Result<(Box, Box), 
String> { - // Get read locks for both worker lists - let prefill_workers = self - .prefill_workers - .read() - .map_err(|e| format!("Failed to acquire prefill workers lock: {}", e))?; - let decode_workers = self - .decode_workers - .read() - .map_err(|e| format!("Failed to acquire decode workers lock: {}", e))?; + model_id: Option<&str>, + ) -> Result<(Arc, Arc), String> { + // Get workers from registry - filter by model if provided + let prefill_workers = if let Some(model) = model_id { + // Get model-specific workers and filter for prefill type + self.worker_registry + .get_by_model_fast(model) + .into_iter() + .filter(|w| matches!(w.worker_type(), WorkerType::Prefill { .. })) + .collect() + } else { + self.worker_registry.get_prefill_workers() + }; + + let decode_workers = if let Some(model) = model_id { + // Get model-specific workers and filter for decode type + self.worker_registry + .get_by_model_fast(model) + .into_iter() + .filter(|w| matches!(w.worker_type(), WorkerType::Decode)) + .collect() + } else { + self.worker_registry.get_decode_workers() + }; // Select workers using helper function - let prefill = Self::pick_worker_by_policy( + // Use separate policies for prefill and decode to avoid counter conflicts + let prefill_policy = self.policy_registry.get_prefill_policy(); + let decode_policy = self.policy_registry.get_decode_policy(); + + let prefill = Self::pick_worker_by_policy_arc( &prefill_workers, - &*self.prefill_policy, + &*prefill_policy, request_text, "prefill", )?; - let decode = Self::pick_worker_by_policy( + let decode = Self::pick_worker_by_policy_arc( &decode_workers, - &*self.decode_policy, + &*decode_policy, request_text, "decode", )?; @@ -1167,13 +1169,13 @@ impl PDRouter { Ok((prefill, decode)) } - // Helper function to select a worker using the policy - fn pick_worker_by_policy( - workers: &[Box], + // Helper function to select a worker using the policy (Arc version) + fn pick_worker_by_policy_arc( + workers: &[Arc], policy: &dyn 
LoadBalancingPolicy, request_text: Option<&str>, worker_type: &str, - ) -> Result, String> { + ) -> Result, String> { // Check if we have any workers if workers.is_empty() { return Err(format!( @@ -1183,10 +1185,10 @@ impl PDRouter { } // Filter available workers (healthy + circuit breaker not open) - let available_workers: Vec> = workers + let available_workers: Vec> = workers .iter() .filter(|w| w.is_available()) - .map(|w| w.clone_worker()) + .cloned() .collect(); if available_workers.is_empty() { @@ -1196,11 +1198,19 @@ impl PDRouter { )); } - // Let policy select from available workers only - match policy.select_worker(&available_workers, request_text) { - Some(idx) => Ok(available_workers[idx].clone_worker()), - None => Err(format!("Policy could not select a {} worker", worker_type)), - } + // Let policy select from available workers (no conversion needed now!) + let selected_idx = policy + .select_worker(&available_workers, request_text) + .ok_or_else(|| { + format!( + "Policy {} failed to select a {} worker", + policy.name(), + worker_type + ) + })?; + + // Return the selected Arc worker + Ok(available_workers[selected_idx].clone()) } // Background task to monitor worker loads with shared client @@ -1272,9 +1282,8 @@ impl PDRouter { let (tx, rx) = tokio::sync::mpsc::unbounded_channel(); - // Clone the worker collections for the spawned task - let prefill_workers = self.prefill_workers.clone(); - let decode_workers = self.decode_workers.clone(); + // Clone the registry for the spawned task + let registry = self.worker_registry.clone(); tokio::spawn(async move { // Use a flag to track whether stream completed successfully @@ -1321,31 +1330,21 @@ impl PDRouter { // Always decrement load after streaming (either completes or errors) // Find and decrement prefill worker - if let Ok(prefill_workers_guard) = prefill_workers.read() { - for worker in prefill_workers_guard.iter() { - if worker.url() == prefill_url.as_str() { - worker.decrement_load(); - debug!( - 
"Decremented load for prefill worker: {} (stream_completed: {})", - prefill_url, stream_completed - ); - break; - } - } + if let Some(worker) = registry.get_by_url(&prefill_url) { + worker.decrement_load(); + debug!( + "Decremented load for prefill worker: {} (stream_completed: {})", + prefill_url, stream_completed + ); } // Find and decrement decode worker - if let Ok(decode_workers_guard) = decode_workers.read() { - for worker in decode_workers_guard.iter() { - if worker.url() == decode_url_str.as_str() { - worker.decrement_load(); - debug!( - "Decremented load for decode worker: {} (stream_completed: {})", - decode_url_str, stream_completed - ); - break; - } - } + if let Some(worker) = registry.get_by_url(&decode_url_str) { + worker.decrement_load(); + debug!( + "Decremented load for decode worker: {} (stream_completed: {})", + decode_url_str, stream_completed + ); } }); @@ -1626,42 +1625,24 @@ impl WorkerManagement for PDRouter { } fn remove_worker(&self, worker_url: &str) { - // For PD router, we would need to know if it's a prefill or decode server - // For now, try both - if let Ok(mut workers) = self.prefill_workers.write() { - if let Some(index) = workers.iter().position(|w| w.url() == worker_url) { - workers.remove(index); - info!("Removed prefill worker: {}", worker_url); - return; - } - } - - if let Ok(mut workers) = self.decode_workers.write() { - if let Some(index) = workers.iter().position(|w| w.url() == worker_url) { - workers.remove(index); - info!("Removed decode worker: {}", worker_url); + // Remove from registry + if let Some(worker) = self.worker_registry.remove_by_url(worker_url) { + match worker.worker_type() { + WorkerType::Prefill { .. 
} => { + info!("Removed prefill worker: {}", worker_url); + } + WorkerType::Decode => { + info!("Removed decode worker: {}", worker_url); + } + _ => { + info!("Removed worker: {}", worker_url); + } } } } fn get_worker_urls(&self) -> Vec { - let mut urls = Vec::new(); - - // Add prefill worker URLs - if let Ok(workers) = self.prefill_workers.read() { - for worker in workers.iter() { - urls.push(worker.url().to_string()); - } - } - - // Add decode worker URLs - if let Ok(workers) = self.decode_workers.read() { - for worker in workers.iter() { - urls.push(worker.url().to_string()); - } - } - - urls + self.worker_registry.get_all_urls() } } @@ -1677,19 +1658,16 @@ impl RouterTrait for PDRouter { let mut all_healthy = true; let mut unhealthy_servers = Vec::new(); - // Check prefill servers - for worker in self.prefill_workers.read().unwrap().iter() { + // Check all workers + for worker in self.worker_registry.get_all() { if !worker.is_healthy() { all_healthy = false; - unhealthy_servers.push(format!("Prefill: {}", worker.url())); - } - } - - // Check decode servers - for worker in self.decode_workers.read().unwrap().iter() { - if !worker.is_healthy() { - all_healthy = false; - unhealthy_servers.push(format!("Decode: {}", worker.url())); + let worker_type = match worker.worker_type() { + WorkerType::Prefill { .. 
} => "Prefill", + WorkerType::Decode => "Decode", + _ => "Worker", + }; + unhealthy_servers.push(format!("{}: {}", worker_type, worker.url())); } } @@ -1709,7 +1687,7 @@ impl RouterTrait for PDRouter { // Note: This endpoint actually causes the model to generate tokens, so we only test one pair // Select a random worker pair using the policy - let (prefill, decode) = match self.select_pd_pair(None).await { + let (prefill, decode) = match self.select_pd_pair(None, None).await { Ok(pair) => pair, Err(e) => { return ( @@ -1789,7 +1767,7 @@ impl RouterTrait for PDRouter { async fn get_server_info(&self, _req: Request) -> Response { // Get info from the first decode server to match sglang's server info format // Note: We use decode workers for server info to match expected format - self.proxy_to_first_worker(&self.decode_workers, "get_server_info", "decode", None) + self.proxy_to_first_prefill_worker("get_server_info", None) .await } @@ -1798,7 +1776,7 @@ impl RouterTrait for PDRouter { let headers = header_utils::copy_request_headers(&req); // Proxy to first prefill worker - self.proxy_to_first_worker(&self.prefill_workers, "v1/models", "prefill", Some(headers)) + self.proxy_to_first_prefill_worker("v1/models", Some(headers)) .await } @@ -1807,19 +1785,15 @@ impl RouterTrait for PDRouter { let headers = header_utils::copy_request_headers(&req); // Proxy to first prefill worker - self.proxy_to_first_worker( - &self.prefill_workers, - "get_model_info", - "prefill", - Some(headers), - ) - .await + self.proxy_to_first_prefill_worker("get_model_info", Some(headers)) + .await } async fn route_generate( &self, headers: Option<&HeaderMap>, body: &GenerateRequest, + model_id: Option<&str>, ) -> Response { // Extract parameters let is_stream = body.stream; @@ -1850,6 +1824,7 @@ impl RouterTrait for PDRouter { is_stream, return_logprob, request_text, + model_id, }; // Execute with retry and bootstrap injection @@ -1860,6 +1835,7 @@ impl RouterTrait for PDRouter { &self, headers: 
Option<&HeaderMap>, body: &ChatCompletionRequest, + model_id: Option<&str>, ) -> Response { // Extract parameters let is_stream = body.stream; @@ -1889,6 +1865,7 @@ impl RouterTrait for PDRouter { is_stream, return_logprob, request_text, + model_id, }; // Execute with retry and bootstrap injection @@ -1899,6 +1876,7 @@ impl RouterTrait for PDRouter { &self, headers: Option<&HeaderMap>, body: &CompletionRequest, + model_id: Option<&str>, ) -> Response { // Extract parameters let is_stream = body.stream; @@ -1924,6 +1902,7 @@ impl RouterTrait for PDRouter { is_stream, return_logprob, request_text, + model_id, }; // Execute with retry and bootstrap injection @@ -1934,6 +1913,7 @@ impl RouterTrait for PDRouter { &self, _headers: Option<&HeaderMap>, _body: &ResponsesRequest, + _model_id: Option<&str>, ) -> Response { ( StatusCode::NOT_IMPLEMENTED, @@ -1946,7 +1926,12 @@ impl RouterTrait for PDRouter { todo!() } - async fn route_rerank(&self, headers: Option<&HeaderMap>, body: &RerankRequest) -> Response { + async fn route_rerank( + &self, + headers: Option<&HeaderMap>, + body: &RerankRequest, + model_id: Option<&str>, + ) -> Response { // Extract text for cache-aware routing let req_text = if self.policies_need_request_text() { Some(body.query.clone()) @@ -1961,6 +1946,7 @@ impl RouterTrait for PDRouter { is_stream: false, return_logprob: false, request_text: req_text, + model_id, }; // Execute with retry and bootstrap injection @@ -1970,10 +1956,16 @@ impl RouterTrait for PDRouter { async fn flush_cache(&self) -> Response { // Process both prefill and decode workers let (prefill_results, prefill_errors) = self - .process_workers(&self.prefill_workers, "Prefill", "flush_cache") + .process_workers( + WorkerType::Prefill { + bootstrap_port: None, + }, + "Prefill", + "flush_cache", + ) .await; let (decode_results, decode_errors) = self - .process_workers(&self.decode_workers, "Decode", "flush_cache") + .process_workers(WorkerType::Decode, "Decode", "flush_cache") .await; 
// Combine results and errors @@ -2005,37 +1997,29 @@ impl RouterTrait for PDRouter { let mut errors = Vec::new(); // Process prefill workers - match Self::get_worker_urls(&self.prefill_workers, "prefill") { - Ok(urls) => { - for worker_url in urls { - match get_worker_load(&self.client, &worker_url).await { - Some(load) => { - loads.insert(format!("prefill_{}", worker_url), load); - } - None => { - errors.push(format!("Failed to get load from prefill {}", worker_url)); - } - } + let prefill_urls = self.get_prefill_worker_urls(); + for worker_url in prefill_urls { + match get_worker_load(&self.client, &worker_url).await { + Some(load) => { + loads.insert(format!("prefill_{}", worker_url), load); + } + None => { + errors.push(format!("Failed to get load from prefill {}", worker_url)); } } - Err(e) => errors.push(e), } // Process decode workers - match Self::get_worker_urls(&self.decode_workers, "decode") { - Ok(urls) => { - for worker_url in urls { - match get_worker_load(&self.client, &worker_url).await { - Some(load) => { - loads.insert(format!("decode_{}", worker_url), load); - } - None => { - errors.push(format!("Failed to get load from decode {}", worker_url)); - } - } + let decode_urls = self.get_decode_worker_urls(); + for worker_url in decode_urls { + match get_worker_load(&self.client, &worker_url).await { + Some(load) => { + loads.insert(format!("decode_{}", worker_url), load); + } + None => { + errors.push(format!("Failed to get load from decode {}", worker_url)); } } - Err(e) => errors.push(e), } let response_data = serde_json::json!({ @@ -2052,24 +2036,15 @@ impl RouterTrait for PDRouter { fn readiness(&self) -> Response { // PD router is ready if it has at least one healthy prefill AND one healthy decode worker - let healthy_prefill_count = self - .prefill_workers - .read() - .unwrap() - .iter() - .filter(|w| w.is_healthy()) - .count(); + let prefill_workers = self.worker_registry.get_prefill_workers(); + let decode_workers = 
self.worker_registry.get_decode_workers(); - let healthy_decode_count = self - .decode_workers - .read() - .unwrap() - .iter() - .filter(|w| w.is_healthy()) - .count(); + let healthy_prefill_count = prefill_workers.iter().filter(|w| w.is_healthy()).count(); - let total_prefill = self.prefill_workers.read().unwrap().len(); - let total_decode = self.decode_workers.read().unwrap().len(); + let healthy_decode_count = decode_workers.iter().filter(|w| w.is_healthy()).count(); + + let total_prefill = prefill_workers.len(); + let total_decode = decode_workers.len(); if healthy_prefill_count > 0 && healthy_decode_count > 0 { Json(json!({ @@ -2117,17 +2092,15 @@ impl RouterTrait for PDRouter { mod tests { use super::*; use crate::core::{BasicWorker, WorkerType}; - use crate::policies::RandomPolicy; fn create_test_pd_router() -> PDRouter { - let prefill_policy = Arc::new(RandomPolicy::new()); - let decode_policy = Arc::new(RandomPolicy::new()); + let worker_registry = Arc::new(WorkerRegistry::new()); + let policy_registry = + Arc::new(PolicyRegistry::new(crate::config::PolicyConfig::RoundRobin)); PDRouter { - prefill_workers: Arc::new(RwLock::new(vec![])), - decode_workers: Arc::new(RwLock::new(vec![])), - prefill_policy, - decode_policy, + worker_registry, + policy_registry, worker_startup_timeout_secs: 5, worker_startup_check_interval_secs: 1, worker_loads: Arc::new(tokio::sync::watch::channel(HashMap::new()).1), @@ -2137,8 +2110,6 @@ mod tests { prefill_drain_tx: mpsc::channel(100).0, retry_config: RetryConfig::default(), circuit_breaker_config: CircuitBreakerConfig::default(), - _prefill_health_checker: None, - _decode_health_checker: None, } } @@ -2162,12 +2133,14 @@ mod tests { }, true, ); - router.prefill_workers.write().unwrap().push(worker); + router.worker_registry.register(Arc::from(worker)); // Try to add the same URL again - this would fail during health check in real scenario // For unit test, we test the duplicate check logic - let workers = 
router.prefill_workers.read().unwrap(); - let exists = workers.iter().any(|w| w.url() == "http://localhost:8000"); + let exists = router + .worker_registry + .get_by_url("http://localhost:8000") + .is_some(); assert!(exists); } @@ -2191,8 +2164,8 @@ mod tests { true, ); - router.prefill_workers.write().unwrap().push(worker1); - router.prefill_workers.write().unwrap().push(worker2); + router.worker_registry.register(Arc::from(worker1)); + router.worker_registry.register(Arc::from(worker2)); // Remove one let result = router.remove_prefill_server("http://worker1").await; @@ -2200,7 +2173,7 @@ mod tests { assert!(result.is_ok()); assert!(result.unwrap().contains("Successfully removed")); - let workers = router.prefill_workers.read().unwrap(); + let workers = router.worker_registry.get_prefill_workers(); assert_eq!(workers.len(), 1); assert_eq!(workers[0].url(), "http://worker2"); } @@ -2226,44 +2199,42 @@ mod tests { // Add server first let worker = create_test_worker("http://decode1".to_string(), WorkerType::Decode, true); - router.decode_workers.write().unwrap().push(worker); + router.worker_registry.register(Arc::from(worker)); let result = router.remove_decode_server("http://decode1").await; assert!(result.is_ok()); assert!(result.unwrap().contains("Successfully removed")); - let workers = router.decode_workers.read().unwrap(); + let workers = router.worker_registry.get_decode_workers(); assert_eq!(workers.len(), 0); } // ============= Lock Error Handling Tests ============= #[test] - fn test_lock_operations() { + fn test_registry_operations() { let router = create_test_pd_router(); - // Test read/write locks work correctly - { - let read_guard = router.prefill_workers.read().unwrap(); - assert_eq!(read_guard.len(), 0); - } + // Test registry operations + let workers = router.worker_registry.get_all(); + assert_eq!(workers.len(), 0); - { - let mut write_guard = router.prefill_workers.write().unwrap(); - write_guard.push(create_test_worker( - 
"http://test".to_string(), - WorkerType::Prefill { - bootstrap_port: None, - }, - true, - )); - } + // Add a worker + let worker = create_test_worker( + "http://test".to_string(), + WorkerType::Prefill { + bootstrap_port: None, + }, + true, + ); + router.worker_registry.register(Arc::from(worker)); - { - let read_guard = router.prefill_workers.read().unwrap(); - assert_eq!(read_guard.len(), 1); - } + let workers = router.worker_registry.get_all(); + assert_eq!(workers.len(), 1); + + let prefill_workers = router.worker_registry.get_prefill_workers(); + assert_eq!(prefill_workers.len(), 1); } // ============= Bootstrap Injection Tests ============= @@ -2297,15 +2268,11 @@ mod tests { let decode_worker = create_test_worker("http://decode".to_string(), WorkerType::Decode, true); - router - .prefill_workers - .write() - .unwrap() - .push(unhealthy_worker); - router.prefill_workers.write().unwrap().push(healthy_worker); - router.decode_workers.write().unwrap().push(decode_worker); + router.worker_registry.register(Arc::from(unhealthy_worker)); + router.worker_registry.register(Arc::from(healthy_worker)); + router.worker_registry.register(Arc::from(decode_worker)); - let result = router.select_pd_pair(None).await; + let result = router.select_pd_pair(None, None).await; assert!(result.is_ok()); let (prefill, _decode) = result.unwrap(); @@ -2319,7 +2286,7 @@ mod tests { async fn test_empty_worker_lists() { let router = create_test_pd_router(); - let result = router.select_pd_pair(None).await; + let result = router.select_pd_pair(None, None).await; assert!(result.is_err()); assert!(result.unwrap_err().contains("No prefill workers available")); @@ -2331,7 +2298,7 @@ mod tests { async fn test_health_endpoints() { let router = create_test_pd_router(); - // Add healthy workers + // Add healthy workers - create_test_worker returns Box, convert to Arc let prefill_worker = create_test_worker( "http://localhost:8000".to_string(), WorkerType::Prefill { @@ -2345,8 +2312,8 @@ mod tests 
{ true, ); - router.prefill_workers.write().unwrap().push(prefill_worker); - router.decode_workers.write().unwrap().push(decode_worker); + router.worker_registry.register(Arc::from(prefill_worker)); + router.worker_registry.register(Arc::from(decode_worker)); // Test health endpoint let http_req = axum::http::Request::builder() @@ -2367,8 +2334,13 @@ mod tests { async fn test_load_monitor_updates() { let power_of_two_policy = Arc::new(crate::policies::PowerOfTwoPolicy::new()); let mut router = create_test_pd_router(); - router.prefill_policy = power_of_two_policy.clone(); - router.decode_policy = power_of_two_policy; + // Set power_of_two policies in the registry + router + .policy_registry + .set_prefill_policy(power_of_two_policy.clone()); + router + .policy_registry + .set_decode_policy(power_of_two_policy); // Create load channel let (tx, rx) = tokio::sync::watch::channel(HashMap::new()); @@ -2423,7 +2395,7 @@ mod tests { let router = create_test_pd_router(); - // Add workers + // Add workers - create_test_worker returns Box, convert to Arc let prefill_worker = create_test_worker( "http://prefill".to_string(), WorkerType::Prefill { @@ -2434,18 +2406,15 @@ mod tests { let decode_worker = create_test_worker("http://decode".to_string(), WorkerType::Decode, true); - router.prefill_workers.write().unwrap().push(prefill_worker); - router.decode_workers.write().unwrap().push(decode_worker); + router.worker_registry.register(Arc::from(prefill_worker)); + router.worker_registry.register(Arc::from(decode_worker)); - // Get references to the workers - clone to avoid holding lock across await - let (prefill_ref, decode_ref) = { - let workers = router.prefill_workers.read().unwrap(); - let prefill = workers[0].clone_worker(); - drop(workers); - let workers = router.decode_workers.read().unwrap(); - let decode = workers[0].clone_worker(); - (prefill, decode) - }; + // Get references to the workers from registry + let prefill_workers = 
router.worker_registry.get_prefill_workers(); + let decode_workers = router.worker_registry.get_decode_workers(); + + let prefill_ref = prefill_workers[0].clone(); + let decode_ref = decode_workers[0].clone(); // Initially load should be 0 assert_eq!(prefill_ref.load(), 0); @@ -2512,7 +2481,7 @@ mod tests { }, true, ); - router_clone.prefill_workers.write().unwrap().push(worker); + router_clone.worker_registry.register(Arc::from(worker)); }); handles.push(handle); } @@ -2523,7 +2492,7 @@ mod tests { } // Check final state - let workers = router.prefill_workers.read().unwrap(); + let workers = router.worker_registry.get_prefill_workers(); assert_eq!(workers.len(), 5); } } diff --git a/sgl-router/src/routers/http/router.rs b/sgl-router/src/routers/http/router.rs index ca1b4d68f..8b928ea37 100644 --- a/sgl-router/src/routers/http/router.rs +++ b/sgl-router/src/routers/http/router.rs @@ -1,10 +1,10 @@ use crate::config::types::RetryConfig; use crate::core::{ - is_retryable_status, BasicWorker, CircuitBreakerConfig, HealthChecker, HealthConfig, - RetryExecutor, Worker, WorkerFactory, WorkerType, + is_retryable_status, BasicWorker, CircuitBreakerConfig, HealthConfig, RetryExecutor, Worker, + WorkerRegistry, WorkerType, }; use crate::metrics::RouterMetrics; -use crate::policies::LoadBalancingPolicy; +use crate::policies::{LoadBalancingPolicy, PolicyRegistry}; use crate::protocols::spec::{ ChatCompletionRequest, CompletionRequest, GenerateRequest, GenerationRequest, RerankRequest, RerankResponse, RerankResult, ResponsesRequest, @@ -22,7 +22,7 @@ use axum::{ use futures_util::StreamExt; use reqwest::Client; use std::collections::HashMap; -use std::sync::{Arc, RwLock}; +use std::sync::Arc; use std::time::{Duration, Instant}; use tokio_stream::wrappers::UnboundedReceiverStream; use tracing::{debug, error, info, warn}; @@ -30,8 +30,8 @@ use tracing::{debug, error, info, warn}; /// Regular router that uses injected load balancing policies #[derive(Debug)] pub struct Router { - 
workers: Arc>>>, - policy: Arc, + worker_registry: Arc, + policy_registry: Arc, client: Client, worker_startup_timeout_secs: u64, worker_startup_check_interval_secs: u64, @@ -41,7 +41,6 @@ pub struct Router { circuit_breaker_config: CircuitBreakerConfig, _worker_loads: Arc>>, _load_monitor_handle: Option>>, - _health_checker: Option, } impl Router { @@ -49,7 +48,6 @@ impl Router { #[allow(clippy::too_many_arguments)] pub async fn new( worker_urls: Vec, - policy: Arc, ctx: &Arc, ) -> Result { // Update active workers gauge @@ -82,45 +80,51 @@ impl Router { window_duration: Duration::from_secs(circuit_breaker_config.window_duration_secs), }; - // Create Worker trait objects from URLs with health check config - let workers: Vec> = worker_urls - .iter() - .map(|url| { - let worker = BasicWorker::new(url.clone(), WorkerType::Regular) - .with_circuit_breaker_config(core_cb_config.clone()) - .with_health_config(HealthConfig { - timeout_secs: ctx.router_config.health_check.timeout_secs, - check_interval_secs: ctx.router_config.health_check.check_interval_secs, - endpoint: ctx.router_config.health_check.endpoint.clone(), - failure_threshold: ctx.router_config.health_check.failure_threshold, - success_threshold: ctx.router_config.health_check.success_threshold, - }); - Box::new(worker) as Box - }) - .collect(); + // Register workers in the registry + // In IGW mode, we need to fetch model info from workers + for url in &worker_urls { + // TODO: In IGW mode, fetch model_id from worker's /get_model_info endpoint + // For now, create worker without model_id + let worker = BasicWorker::new(url.clone(), WorkerType::Regular) + .with_circuit_breaker_config(core_cb_config.clone()) + .with_health_config(HealthConfig { + timeout_secs: ctx.router_config.health_check.timeout_secs, + check_interval_secs: ctx.router_config.health_check.check_interval_secs, + endpoint: ctx.router_config.health_check.endpoint.clone(), + failure_threshold: ctx.router_config.health_check.failure_threshold, + 
success_threshold: ctx.router_config.health_check.success_threshold, + }); - // Initialize policy with workers if needed (e.g., for cache-aware) - if let Some(cache_aware) = policy - .as_any() - .downcast_ref::() - { - cache_aware.init_workers(&workers); + let worker_arc = Arc::new(worker); + ctx.worker_registry.register(worker_arc.clone()); + + // Notify PolicyRegistry about the new worker + let model_id = worker_arc.model_id(); + let policy = ctx.policy_registry.on_worker_added(model_id, None); + + // If this is a cache-aware policy and it's the first worker for this model, + // initialize it with the worker + if policy.name() == "cache_aware" { + if let Some(cache_aware) = policy + .as_any() + .downcast_ref::() + { + let worker_dyn: Arc = worker_arc.clone(); + cache_aware.init_workers(std::slice::from_ref(&worker_dyn)); + } + } } - let workers = Arc::new(RwLock::new(workers)); - let health_checker = crate::core::start_health_checker( - Arc::clone(&workers), - ctx.router_config.worker_startup_check_interval_secs, - ); - // Setup load monitoring for PowerOfTwo policy let (tx, rx) = tokio::sync::watch::channel(HashMap::new()); let worker_loads = Arc::new(rx); - let load_monitor_handle = if policy.name() == "power_of_two" { + // Check if default policy is power_of_two for load monitoring + let default_policy = ctx.policy_registry.get_default_policy(); + let load_monitor_handle = if default_policy.name() == "power_of_two" { let monitor_urls = worker_urls.clone(); let monitor_interval = ctx.router_config.worker_startup_check_interval_secs; - let policy_clone = Arc::clone(&policy); + let policy_clone = default_policy.clone(); let client_clone = ctx.client.clone(); Some(Arc::new(tokio::spawn(async move { @@ -138,8 +142,8 @@ impl Router { }; Ok(Router { - workers, - policy, + worker_registry: ctx.worker_registry.clone(), + policy_registry: ctx.policy_registry.clone(), client: ctx.client.clone(), worker_startup_timeout_secs: ctx.router_config.worker_startup_timeout_secs, 
worker_startup_check_interval_secs: ctx @@ -151,18 +155,21 @@ impl Router { circuit_breaker_config: core_cb_config, _worker_loads: worker_loads, _load_monitor_handle: load_monitor_handle, - _health_checker: Some(health_checker), }) } /// Get the current list of worker URLs pub fn get_worker_urls(&self) -> Vec { - self.workers - .read() - .unwrap() - .iter() - .map(|w| w.url().to_string()) - .collect() + self.worker_registry.get_all_urls() + } + + /// Get worker URLs for a specific model + pub fn get_worker_urls_for_model(&self, model_id: Option<&str>) -> Vec { + let workers = match model_id { + Some(model) => self.worker_registry.get_by_model_fast(model), + None => self.worker_registry.get_all(), + }; + workers.iter().map(|w| w.url().to_string()).collect() } pub async fn wait_for_healthy_workers( @@ -332,11 +339,27 @@ impl Router { } fn select_first_worker(&self) -> Result { - let workers_guard = self.workers.read().unwrap(); - if workers_guard.is_empty() { + let workers = self.worker_registry.get_all(); + if workers.is_empty() { Err("No workers are available".to_string()) } else { - Ok(workers_guard[0].url().to_string()) + Ok(workers[0].url().to_string()) + } + } + + #[allow(dead_code)] + fn select_first_worker_for_model(&self, model_id: Option<&str>) -> Result { + let workers = match model_id { + Some(model) => self.worker_registry.get_by_model_fast(model), + None => self.worker_registry.get_all(), + }; + if workers.is_empty() { + Err(format!( + "No workers are available for model: {:?}", + model_id + )) + } else { + Ok(workers[0].url().to_string()) } } @@ -447,20 +470,35 @@ impl Router { } } - // New method to route typed requests directly - /// Select worker considering circuit breaker state - fn select_worker_with_circuit_breaker(&self, text: Option<&str>) -> Option> { - let workers = self.workers.read().ok()?; - let available: Vec> = workers + /// Select worker for a specific model considering circuit breaker state + fn select_worker_for_model( + &self, + 
model_id: Option<&str>, + text: Option<&str>, + ) -> Option> { + // Get workers for the specified model (O(1) lookup if model_id is provided) + let workers = match model_id { + Some(model) => self.worker_registry.get_by_model_fast(model), + None => self.worker_registry.get_all(), + }; + + let available: Vec> = workers .iter() .filter(|w| w.is_available()) - .map(|w| w.clone_worker()) + .cloned() .collect(); if available.is_empty() { return None; } - let idx = self.policy.select_worker(&available, text)?; - Some(available[idx].clone_worker()) + + // Get the appropriate policy for this model + let policy = match model_id { + Some(model) => self.policy_registry.get_policy_or_default(model), + None => self.policy_registry.get_default_policy(), + }; + + let idx = policy.select_worker(&available, text)?; + Some(available[idx].clone()) } pub async fn route_typed_request( @@ -468,6 +506,7 @@ impl Router { headers: Option<&HeaderMap>, typed_req: &T, route: &str, + model_id: Option<&str>, ) -> Response { let start = Instant::now(); let is_stream = typed_req.is_stream(); @@ -477,7 +516,7 @@ impl Router { &self.retry_config, // operation per attempt |_: u32| async { - let worker = match self.select_worker_with_circuit_breaker(Some(&text)) { + let worker = match self.select_worker_for_model(model_id, Some(&text)) { Some(w) => w, None => { RouterMetrics::record_request_error(route, "no_available_workers"); @@ -490,7 +529,13 @@ impl Router { }; // Optional load tracking for cache-aware policy - let load_incremented = if self.policy.name() == "cache_aware" { + // Get the policy for this model to check if it's cache-aware + let policy = match model_id { + Some(model) => self.policy_registry.get_policy_or_default(model), + None => self.policy_registry.get_default_policy(), + }; + + let load_incremented = if policy.name() == "cache_aware" { worker.increment_load(); RouterMetrics::set_running_requests(worker.url(), worker.load()); true @@ -654,11 +699,9 @@ impl Router { // Decrement 
load on error if it was incremented if load_incremented { - if let Ok(workers_guard) = self.workers.read() { - if let Some(worker) = workers_guard.iter().find(|w| w.url() == worker_url) { - worker.decrement_load(); - RouterMetrics::set_running_requests(worker_url, worker.load()); - } + if let Some(worker) = self.worker_registry.get_by_url(worker_url) { + worker.decrement_load(); + RouterMetrics::set_running_requests(worker_url, worker.load()); } } @@ -687,13 +730,9 @@ impl Router { Err(e) => { // IMPORTANT: Decrement load on error before returning if load_incremented { - if let Ok(workers_guard) = self.workers.read() { - if let Some(worker) = - workers_guard.iter().find(|w| w.url() == worker_url) - { - worker.decrement_load(); - RouterMetrics::set_running_requests(worker_url, worker.load()); - } + if let Some(worker) = self.worker_registry.get_by_url(worker_url) { + worker.decrement_load(); + RouterMetrics::set_running_requests(worker_url, worker.load()); } } @@ -704,18 +743,16 @@ impl Router { // Decrement load counter for non-streaming requests if it was incremented if load_incremented { - if let Ok(workers_guard) = self.workers.read() { - if let Some(worker) = workers_guard.iter().find(|w| w.url() == worker_url) { - worker.decrement_load(); - RouterMetrics::set_running_requests(worker_url, worker.load()); - } + if let Some(worker) = self.worker_registry.get_by_url(worker_url) { + worker.decrement_load(); + RouterMetrics::set_running_requests(worker_url, worker.load()); } } response } else if load_incremented { // For streaming with load tracking, we need to manually decrement when done - let workers = Arc::clone(&self.workers); + let registry = Arc::clone(&self.worker_registry); let worker_url = worker_url.to_string(); // Preserve headers for streaming response @@ -739,17 +776,10 @@ impl Router { .windows(12) .any(|window| window == b"data: [DONE]") { - if let Ok(workers_guard) = workers.read() { - if let Some(worker) = - workers_guard.iter().find(|w| w.url() == 
worker_url) - { - worker.decrement_load(); - RouterMetrics::set_running_requests( - &worker_url, - worker.load(), - ); - decremented = true; - } + if let Some(worker) = registry.get_by_url(&worker_url) { + worker.decrement_load(); + RouterMetrics::set_running_requests(&worker_url, worker.load()); + decremented = true; } } if tx.send(Ok(bytes)).is_err() { @@ -763,11 +793,9 @@ impl Router { } } if !decremented { - if let Ok(workers_guard) = workers.read() { - if let Some(worker) = workers_guard.iter().find(|w| w.url() == worker_url) { - worker.decrement_load(); - RouterMetrics::set_running_requests(&worker_url, worker.load()); - } + if let Some(worker) = registry.get_by_url(&worker_url) { + worker.decrement_load(); + RouterMetrics::set_running_requests(&worker_url, worker.load()); } } }); @@ -839,7 +867,6 @@ impl Router { match client.get(format!("{}/health", worker_url)).send().await { Ok(res) => { if res.status().is_success() { - let mut workers_guard = self.workers.write().unwrap(); if self.dp_aware { // Need to contact the worker to extract the dp_size, // and add them as multiple workers @@ -848,46 +875,77 @@ impl Router { .map_err(|e| format!("Failed to get dp-aware workers: {}", e))?; let mut worker_added: bool = false; for dp_url in &dp_url_vec { - if workers_guard.iter().any(|w| w.url() == dp_url) { + if self.worker_registry.get_by_url(dp_url).is_some() { warn!("Worker {} already exists", dp_url); continue; } info!("Added worker: {}", dp_url); - let new_worker = WorkerFactory::create_regular_with_config( - dp_url.to_string(), - self.circuit_breaker_config.clone(), - ); - workers_guard.push(new_worker); + // TODO: In IGW mode, fetch model_id from worker's /get_model_info endpoint + let new_worker = + BasicWorker::new(dp_url.to_string(), WorkerType::Regular) + .with_circuit_breaker_config( + self.circuit_breaker_config.clone(), + ); + + let worker_arc = Arc::new(new_worker); + self.worker_registry.register(worker_arc.clone()); + + // Notify PolicyRegistry 
about the new worker + let model_id = worker_arc.model_id(); + let policy = self.policy_registry.on_worker_added(model_id, None); + + // If this is a cache-aware policy, update it with all workers for this model + if policy.name() == "cache_aware" { + if let Some(cache_aware) = policy + .as_any() + .downcast_ref::( + ) { + let model_workers = + self.worker_registry.get_by_model_fast(model_id); + cache_aware.init_workers(&model_workers); + } + } + worker_added = true; } if !worker_added { return Err(format!("No worker added for {}", worker_url)); } } else { - if workers_guard.iter().any(|w| w.url() == worker_url) { + if self.worker_registry.get_by_url(worker_url).is_some() { return Err(format!("Worker {} already exists", worker_url)); } info!("Added worker: {}", worker_url); - let new_worker = WorkerFactory::create_regular_with_config( - worker_url.to_string(), - self.circuit_breaker_config.clone(), - ); - workers_guard.push(new_worker); + + // TODO: In IGW mode, fetch model_id from worker's /get_model_info endpoint + let new_worker = + BasicWorker::new(worker_url.to_string(), WorkerType::Regular) + .with_circuit_breaker_config( + self.circuit_breaker_config.clone(), + ); + + let worker_arc = Arc::new(new_worker); + self.worker_registry.register(worker_arc.clone()); + + // Notify PolicyRegistry about the new worker + let model_id = worker_arc.model_id(); + let policy = self.policy_registry.on_worker_added(model_id, None); + + // If this is a cache-aware policy, add this worker to it + if policy.name() == "cache_aware" { + if let Some(cache_aware) = policy + .as_any() + .downcast_ref::( + ) { + // Get all workers for this model + let model_workers = + self.worker_registry.get_by_model_fast(model_id); + cache_aware.init_workers(&model_workers); + } + } } - RouterMetrics::set_active_workers(workers_guard.len()); - - // If cache aware policy, initialize the worker in the tree - if let Some(cache_aware) = - self.policy - .as_any() - .downcast_ref::() - { - // Get updated 
workers after adding - drop(workers_guard); - let workers_guard = self.workers.read().unwrap(); - cache_aware.init_workers(&workers_guard); - } + RouterMetrics::set_active_workers(self.worker_registry.get_all().len()); return Ok(format!("Successfully added worker: {}", worker_url)); } else { @@ -931,66 +989,73 @@ impl Router { if self.dp_aware { // remove dp-aware workers in a prefix-matching fashion // without contacting the remote worker - let mut candidate_workers: Vec = Vec::new(); let mut removed_workers: Vec = Vec::new(); let worker_url_prefix = format!("{}@", worker_url); - { - // find the candidate workers to be removed - let workers_guard = self.workers.read().unwrap(); - for w in workers_guard.iter() { - if w.url().starts_with(&worker_url_prefix) { - candidate_workers.push(w.url().to_string()); - } - } - } + // Find and remove all workers with matching prefix + let all_workers = self.worker_registry.get_all(); + for w in all_workers.iter() { + if w.url().starts_with(&worker_url_prefix) { + // Get model_id before removing + let model_id = w.model_id().to_string(); - { - // do the removing on the worker_urls - let mut workers_guard = self.workers.write().unwrap(); - for dp_url in candidate_workers.iter() { - if let Some(index) = workers_guard.iter().position(|w| w.url() == dp_url) { - workers_guard.remove(index); - info!("Removed worker: {}", dp_url); - removed_workers.push(dp_url.to_string()); + if self.worker_registry.remove_by_url(w.url()).is_some() { + info!("Removed worker: {}", w.url()); + removed_workers.push(w.url().to_string()); + + // Notify PolicyRegistry about the removed worker + self.policy_registry.on_worker_removed(&model_id); } else { - warn!("Worker {} not found, skipping removal", dp_url); - continue; + warn!("Worker {} not found, skipping removal", w.url()); } } - RouterMetrics::set_active_workers(workers_guard.len()); } - // If cache aware policy, remove the workers from the tree - if let Some(cache_aware) = self - .policy - .as_any() - 
.downcast_ref::() - { - for dp_url in removed_workers.iter() { - cache_aware.remove_worker(dp_url); - info!("Removed worker from tree: {}", dp_url); + RouterMetrics::set_active_workers(self.worker_registry.get_all().len()); + + // If any models are using cache aware policy, remove the workers from the tree + // Check each removed worker's model and get its policy + for dp_url in removed_workers.iter() { + if let Some(worker) = self.worker_registry.get_by_url(dp_url) { + let model_id = worker.model_id(); + if let Some(policy) = self.policy_registry.get_policy(model_id) { + if let Some(cache_aware) = policy + .as_any() + .downcast_ref::() + { + cache_aware.remove_worker_by_url(dp_url); + info!("Removed worker from cache-aware tree: {}", dp_url); + } + } } } } else { - let mut workers_guard = self.workers.write().unwrap(); - if let Some(index) = workers_guard.iter().position(|w| w.url() == worker_url) { - workers_guard.remove(index); - info!("Removed worker: {}", worker_url); - RouterMetrics::set_active_workers(workers_guard.len()); + // Get the worker first to extract model_id + let model_id = if let Some(worker) = self.worker_registry.get_by_url(worker_url) { + worker.model_id().to_string() } else { warn!("Worker {} not found, skipping removal", worker_url); return; + }; + + if self.worker_registry.remove_by_url(worker_url).is_some() { + info!("Removed worker: {}", worker_url); + + // Notify PolicyRegistry about the removed worker + self.policy_registry.on_worker_removed(&model_id); + + RouterMetrics::set_active_workers(self.worker_registry.get_all().len()); } - // If cache aware policy, remove the workers from the tree - if let Some(cache_aware) = self - .policy - .as_any() - .downcast_ref::() - { - cache_aware.remove_worker(worker_url); - info!("Removed worker from tree: {}", worker_url); + // If the model is using cache aware policy, remove the worker from the tree + if let Some(policy) = self.policy_registry.get_policy(&model_id) { + if let Some(cache_aware) = 
policy + .as_any() + .downcast_ref::() + { + cache_aware.remove_worker_by_url(worker_url); + info!("Removed worker from cache-aware tree: {}", worker_url); + } } } } @@ -1171,7 +1236,7 @@ impl RouterTrait for Router { } async fn health(&self, _req: Request) -> Response { - let workers = self.workers.read().unwrap(); + let workers = self.worker_registry.get_all(); let unhealthy_servers: Vec<_> = workers .iter() .filter(|w| !w.is_healthy()) @@ -1209,16 +1274,19 @@ impl RouterTrait for Router { &self, headers: Option<&HeaderMap>, body: &GenerateRequest, + model_id: Option<&str>, ) -> Response { - self.route_typed_request(headers, body, "/generate").await + self.route_typed_request(headers, body, "/generate", model_id) + .await } async fn route_chat( &self, headers: Option<&HeaderMap>, body: &ChatCompletionRequest, + model_id: Option<&str>, ) -> Response { - self.route_typed_request(headers, body, "/v1/chat/completions") + self.route_typed_request(headers, body, "/v1/chat/completions", model_id) .await } @@ -1226,8 +1294,9 @@ impl RouterTrait for Router { &self, headers: Option<&HeaderMap>, body: &CompletionRequest, + model_id: Option<&str>, ) -> Response { - self.route_typed_request(headers, body, "/v1/completions") + self.route_typed_request(headers, body, "/v1/completions", model_id) .await } @@ -1235,8 +1304,9 @@ impl RouterTrait for Router { &self, headers: Option<&HeaderMap>, body: &ResponsesRequest, + model_id: Option<&str>, ) -> Response { - self.route_typed_request(headers, body, "/v1/responses") + self.route_typed_request(headers, body, "/v1/responses", model_id) .await } @@ -1244,11 +1314,18 @@ impl RouterTrait for Router { todo!() } - async fn route_rerank(&self, headers: Option<&HeaderMap>, body: &RerankRequest) -> Response { + async fn route_rerank( + &self, + headers: Option<&HeaderMap>, + body: &RerankRequest, + model_id: Option<&str>, + ) -> Response { if let Err(e) = body.validate() { return (StatusCode::BAD_REQUEST, e).into_response(); } - let 
response = self.route_typed_request(headers, body, "/v1/rerank").await; + let response = self + .route_typed_request(headers, body, "/v1/rerank", model_id) + .await; if response.status().is_success() { match Self::build_rerank_response(body, response).await { Ok(rerank_response) => rerank_response, @@ -1340,19 +1417,15 @@ impl RouterTrait for Router { fn readiness(&self) -> Response { // Regular router is ready if it has at least one healthy worker - let healthy_count = self - .workers - .read() - .unwrap() - .iter() - .filter(|w| w.is_healthy()) - .count(); + let workers = self.worker_registry.get_all(); + let healthy_count = workers.iter().filter(|w| w.is_healthy()).count(); + let total_workers = workers.len(); if healthy_count > 0 { Json(serde_json::json!({ "status": "ready", "healthy_workers": healthy_count, - "total_workers": self.workers.read().unwrap().len() + "total_workers": total_workers })) .into_response() } else { @@ -1361,7 +1434,7 @@ impl RouterTrait for Router { Json(serde_json::json!({ "status": "not_ready", "reason": "no healthy workers available", - "total_workers": self.workers.read().unwrap().len() + "total_workers": total_workers })), ) .into_response() @@ -1372,18 +1445,25 @@ impl RouterTrait for Router { #[cfg(test)] mod tests { use super::*; - use crate::policies::RandomPolicy; use std::collections::HashMap; fn create_test_regular_router() -> Router { - let workers = vec![ - WorkerFactory::create_regular("http://worker1:8080".to_string()), - WorkerFactory::create_regular("http://worker2:8080".to_string()), - ]; + // Create registries + let worker_registry = Arc::new(WorkerRegistry::new()); + let policy_registry = Arc::new(PolicyRegistry::new( + crate::config::types::PolicyConfig::RoundRobin, + )); + + // Register test workers + let worker1 = BasicWorker::new("http://worker1:8080".to_string(), WorkerType::Regular); + let worker2 = BasicWorker::new("http://worker2:8080".to_string(), WorkerType::Regular); + 
worker_registry.register(Arc::new(worker1)); + worker_registry.register(Arc::new(worker2)); + let (_, rx) = tokio::sync::watch::channel(HashMap::new()); Router { - workers: Arc::new(RwLock::new(workers)), - policy: Arc::new(RandomPolicy::new()), + worker_registry, + policy_registry, worker_startup_timeout_secs: 5, worker_startup_check_interval_secs: 1, dp_aware: false, @@ -1393,7 +1473,6 @@ mod tests { circuit_breaker_config: CircuitBreakerConfig::default(), _worker_loads: Arc::new(rx), _load_monitor_handle: None, - _health_checker: None, } } @@ -1413,7 +1492,9 @@ mod tests { let result = router.select_first_worker(); assert!(result.is_ok()); - assert_eq!(result.unwrap(), "http://worker1:8080"); + let url = result.unwrap(); + // DashMap doesn't guarantee order, so just check we get one of the workers + assert!(url == "http://worker1:8080" || url == "http://worker2:8080"); } #[tokio::test] diff --git a/sgl-router/src/routers/mod.rs b/sgl-router/src/routers/mod.rs index 3fe339d8f..fba121002 100644 --- a/sgl-router/src/routers/mod.rs +++ b/sgl-router/src/routers/mod.rs @@ -17,6 +17,7 @@ pub mod factory; pub mod grpc; pub mod header_utils; pub mod http; +pub mod router_manager; pub use factory::RouterFactory; // Re-export HTTP routers for convenience (keeps routers::openai_router path working) @@ -63,14 +64,19 @@ pub trait RouterTrait: Send + Sync + Debug + WorkerManagement { async fn get_model_info(&self, req: Request) -> Response; /// Route a generate request - async fn route_generate(&self, headers: Option<&HeaderMap>, body: &GenerateRequest) - -> Response; + async fn route_generate( + &self, + headers: Option<&HeaderMap>, + body: &GenerateRequest, + model_id: Option<&str>, + ) -> Response; /// Route a chat completion request async fn route_chat( &self, headers: Option<&HeaderMap>, body: &ChatCompletionRequest, + model_id: Option<&str>, ) -> Response; /// Route a completion request @@ -78,6 +84,7 @@ pub trait RouterTrait: Send + Sync + Debug + WorkerManagement { 
&self, headers: Option<&HeaderMap>, body: &CompletionRequest, + model_id: Option<&str>, ) -> Response; /// Route a responses request @@ -85,11 +92,17 @@ pub trait RouterTrait: Send + Sync + Debug + WorkerManagement { &self, headers: Option<&HeaderMap>, body: &ResponsesRequest, + model_id: Option<&str>, ) -> Response; async fn route_embeddings(&self, headers: Option<&HeaderMap>, body: Body) -> Response; - async fn route_rerank(&self, headers: Option<&HeaderMap>, body: &RerankRequest) -> Response; + async fn route_rerank( + &self, + headers: Option<&HeaderMap>, + body: &RerankRequest, + model_id: Option<&str>, + ) -> Response; /// Flush cache on all workers async fn flush_cache(&self) -> Response; diff --git a/sgl-router/src/routers/router_manager.rs b/sgl-router/src/routers/router_manager.rs new file mode 100644 index 000000000..6ec325f40 --- /dev/null +++ b/sgl-router/src/routers/router_manager.rs @@ -0,0 +1,766 @@ +//! Router Manager for coordinating multiple routers and workers +//! +//! Provides centralized management based on enable_igw flag: +//! - Single Router Mode (enable_igw=false): Router owns workers directly +//! 
- Multi-Router Mode (enable_igw=true): RouterManager coordinates everything + +use crate::config::RouterConfig; +use crate::core::{CircuitBreakerConfig, Worker, WorkerFactory, WorkerRegistry}; +use crate::protocols::spec::{ + ChatCompletionRequest, CompletionRequest, GenerateRequest, RerankRequest, ResponsesRequest, +}; +use crate::protocols::worker_spec::{ + ServerInfo, WorkerApiResponse, WorkerConfigRequest, WorkerErrorResponse, WorkerInfo, + WorkerListResponse, WorkerStats, WorkerTypeStats, +}; +use crate::routers::{RouterTrait, WorkerManagement}; +use async_trait::async_trait; +use axum::{ + body::Body, + extract::Request, + http::{HeaderMap, StatusCode}, + response::{IntoResponse, Response}, +}; +use dashmap::DashMap; +use std::sync::Arc; +use tracing::{info, warn}; + +/// Router identifier +#[derive(Debug, Clone, Hash, Eq, PartialEq)] +pub struct RouterId(String); + +impl RouterId { + pub fn new(id: String) -> Self { + Self(id) + } + + pub fn as_str(&self) -> &str { + &self.0 + } +} + +/// Router Manager - Central coordinator for routers and workers +/// Only created when enable_igw=true +pub struct RouterManager { + /// Worker registry (single source of truth in multi-router mode) + worker_registry: Arc, + + /// Policy registry for managing model-to-policy mappings + policy_registry: Arc, + + /// All routers managed by this manager (max 4 routers in Phase 2) + /// RouterId examples: "http-regular", "http-pd", "grpc-regular", "grpc-pd" + routers: Arc>>, + + /// Default router for requests without specific routing + default_router: Option, + + /// Model to router mapping for model-aware routing + /// Multiple models can be served by the same router + model_routers: Arc>>, + + /// HTTP client for querying worker info + client: reqwest::Client, + + /// Configuration + #[allow(dead_code)] // May be used in future enhancements + config: RouterConfig, +} + +impl RouterManager { + /// Create a new router manager with shared registries + pub fn new( + config: 
RouterConfig, + client: reqwest::Client, + worker_registry: Arc, + policy_registry: Arc, + ) -> Self { + Self { + worker_registry, + policy_registry, + routers: Arc::new(DashMap::new()), + default_router: None, + model_routers: Arc::new(DashMap::new()), + client, + config, + } + } + + /// Register a router with the manager + pub fn register_router( + &mut self, + id: RouterId, + router: Arc, + models: Vec, + ) { + // Store router + self.routers.insert(id.clone(), router); + + // Update model mappings + for model in models { + self.model_routers + .entry(model) + .or_default() + .push(id.clone()); + } + + // Set as default if first router + if self.default_router.is_none() { + self.default_router = Some(id.clone()); + info!("Set default router to {}", id.as_str()); + } + } + + /// Set the default router + pub fn set_default_router(&mut self, id: RouterId) { + self.default_router = Some(id); + } + + /// Get the number of registered routers + pub fn router_count(&self) -> usize { + self.routers.len() + } + + /// Get router for a specific model + pub fn get_router_for_model(&self, model_id: &str) -> Option> { + // First try model-specific routers + if let Some(router_ids) = self.model_routers.get(model_id) { + if let Some(router_id) = router_ids.first() { + if let Some(router) = self.routers.get(router_id) { + return Some(router.clone()); + } + } + } + + // Fall back to default router + if let Some(ref default_id) = self.default_router { + self.routers.get(default_id).map(|r| r.clone()) + } else { + None + } + } + + /// Get workers for routing decision + pub fn get_workers_for_request(&self, model_id: Option<&str>) -> Vec> { + if let Some(model) = model_id { + self.worker_registry.get_by_model(model) + } else { + self.worker_registry.get_all() + } + } + + /// Add a worker to the registry + pub async fn add_worker( + &self, + config: WorkerConfigRequest, + ) -> Result { + // Build labels from configuration + let mut labels = config.labels.clone(); + + // Query server 
info if model_id not provided + let model_id = if let Some(model_id) = config.model_id { + model_id + } else { + match self.query_server_info(&config.url).await { + Ok(info) => { + // Extract model_id from server info + info.model_id + .or_else(|| { + info.model_path + .as_ref() + .and_then(|path| path.split('/').next_back().map(|s| s.to_string())) + }) + .unwrap_or_else(|| "unknown".to_string()) + } + Err(e) => { + warn!("Failed to query server info from {}: {}", config.url, e); + "unknown".to_string() + } + } + }; + + // Add configuration to labels + labels.insert("model_id".to_string(), model_id.clone()); + + if let Some(priority) = config.priority { + labels.insert("priority".to_string(), priority.to_string()); + } + + if let Some(cost) = config.cost { + labels.insert("cost".to_string(), cost.to_string()); + } + + // Add gRPC-specific configuration if provided + if let Some(tokenizer_path) = config.tokenizer_path { + labels.insert("tokenizer_path".to_string(), tokenizer_path); + } + + if let Some(reasoning_parser) = config.reasoning_parser { + labels.insert("reasoning_parser".to_string(), reasoning_parser); + } + + if let Some(tool_parser) = config.tool_parser { + labels.insert("tool_parser".to_string(), tool_parser); + } + + if let Some(chat_template) = config.chat_template { + labels.insert("chat_template".to_string(), chat_template); + } + + // Create worker based on type + // Note: For prefill and decode workers, we can't easily add labels after creation + // since they return Box. We'll need to enhance WorkerFactory in the future. 
+ let worker = match config.worker_type.as_deref() { + Some("prefill") => { + // For now, prefill workers won't have custom labels + // TODO: Enhance WorkerFactory to accept labels for prefill workers + WorkerFactory::create_prefill(config.url.clone(), config.bootstrap_port) + } + Some("decode") => { + // For now, decode workers won't have custom labels + // TODO: Enhance WorkerFactory to accept labels for decode workers + WorkerFactory::create_decode(config.url.clone()) + } + _ => { + // Regular workers can have labels + WorkerFactory::create_regular_with_labels( + config.url.clone(), + labels.clone(), + CircuitBreakerConfig::default(), + ) + } + }; + + // Register worker + let worker_id = self.worker_registry.register(Arc::from(worker)); + + // Notify PolicyRegistry about the new worker + // Extract policy hint from labels if provided + let policy_hint = labels.get("policy").map(|s| s.as_str()); + let policy = self.policy_registry.on_worker_added(&model_id, policy_hint); + + info!( + "Added worker {} with URL {} for model {} using policy {}", + worker_id.as_str(), + config.url, + model_id, + policy.name() + ); + + // Return worker info + let worker_arc = self.worker_registry.get(&worker_id).unwrap(); + let worker_info = self.worker_to_info(worker_id.as_str(), &worker_arc); + + Ok(WorkerApiResponse { + success: true, + message: format!("Worker {} added successfully", worker_id.as_str()), + worker: Some(worker_info), + }) + } + + /// Remove a worker from the registry + pub fn remove_worker_from_registry( + &self, + url: &str, + ) -> Result { + // Get worker to extract model_id before removing + let model_id = self + .worker_registry + .get_by_url(url) + .map(|worker| worker.model_id().to_string()); + + if let Some(_worker) = self.worker_registry.remove_by_url(url) { + // Notify PolicyRegistry about worker removal + if let Some(model_id) = model_id { + self.policy_registry.on_worker_removed(&model_id); + info!("Removed worker with URL {} for model {}", url, 
model_id); + } else { + info!("Removed worker with URL {}", url); + } + + Ok(WorkerApiResponse { + success: true, + message: format!("Worker {} removed successfully", url), + worker: None, + }) + } else { + Err(WorkerErrorResponse { + error: format!("Worker with URL {} not found", url), + code: "WORKER_NOT_FOUND".to_string(), + }) + } + } + + /// List all workers + pub fn list_workers(&self) -> WorkerListResponse { + let workers = self.worker_registry.get_all_with_ids(); + let worker_infos: Vec = workers + .iter() + .map(|(id, w)| self.worker_to_info(id.as_str(), w)) + .collect(); + + let total = worker_infos.len(); + + // Get stats from the worker registry + let registry_stats = self.worker_registry.stats(); + + // Convert WorkerRegistryStats to WorkerStats + let stats = WorkerStats { + total_workers: registry_stats.total_workers, + healthy_workers: registry_stats.healthy_workers, + total_models: registry_stats.total_models, + total_load: registry_stats.total_load, + by_type: WorkerTypeStats { + regular: registry_stats.regular_workers, + prefill: registry_stats.prefill_workers, + decode: registry_stats.decode_workers, + }, + }; + + WorkerListResponse { + workers: worker_infos, + total, + stats, + } + } + + /// Get worker by URL + pub fn get_worker(&self, url: &str) -> Option { + self.worker_registry + .get_by_url(url) + .map(|w| self.worker_to_info("unknown", &w)) + } + + /// Query server info from a worker URL + async fn query_server_info(&self, url: &str) -> Result { + let info_url = format!("{}/get_server_info", url.trim_end_matches('/')); + + match self.client.get(&info_url).send().await { + Ok(response) => { + if response.status().is_success() { + response + .json::() + .await + .map_err(|e| format!("Failed to parse server info: {}", e)) + } else { + Err(format!("Server returned status: {}", response.status())) + } + } + Err(e) => Err(format!("Failed to connect to server: {}", e)), + } + } + + /// Convert Worker to WorkerInfo + fn worker_to_info(&self, id: 
&str, worker: &Arc) -> WorkerInfo { + let metadata = worker.metadata(); + + WorkerInfo { + id: id.to_string(), + url: worker.url().to_string(), + model_id: worker.model_id().to_string(), + priority: worker.priority(), + cost: worker.cost(), + worker_type: format!("{:?}", worker.worker_type()), + is_healthy: worker.is_healthy(), + load: worker.load(), + connection_mode: format!("{:?}", worker.connection_mode()), + tokenizer_path: worker.tokenizer_path().map(|s| s.to_string()), + reasoning_parser: worker.reasoning_parser().map(|s| s.to_string()), + tool_parser: worker.tool_parser().map(|s| s.to_string()), + chat_template: worker.chat_template().map(|s| s.to_string()), + metadata: metadata.labels.clone(), + } + } + + // Note: calculate_stats removed - using WorkerRegistry::stats() instead + + // === Phase 2: Router Management === + // Note: Dynamic router creation removed - routers are created and registered externally + + /// Get the appropriate router for a request based on headers and request content + pub fn select_router_for_request( + &self, + headers: Option<&HeaderMap>, + model_id: Option<&str>, + ) -> Option> { + // Extract priority and cost preferences from headers if available + let _priority_threshold = headers.and_then(|h| { + h.get("x-worker-priority") + .and_then(|v| v.to_str().ok()) + .and_then(|s| s.parse::().ok()) + }); + + let _max_cost = headers.and_then(|h| { + h.get("x-max-cost") + .and_then(|v| v.to_str().ok()) + .and_then(|s| s.parse::().ok()) + }); + + // Check if PD (prefill-decode) mode is preferred from headers + let prefer_pd = headers + .and_then(|h| { + h.get("x-prefer-pd") + .and_then(|v| v.to_str().ok()) + .map(|s| s == "true" || s == "1") + }) + .unwrap_or(false); + + // If model specified, find routers serving that model + let candidate_routers = if let Some(model) = model_id { + // Get routers for specific model + if let Some(router_ids) = self.model_routers.get(model) { + router_ids + .iter() + .filter_map(|id| 
self.routers.get(id).map(|r| r.clone())) + .collect::>() + } else { + Vec::new() + } + } else { + // No model specified, consider all routers + self.routers + .iter() + .map(|entry| entry.value().clone()) + .collect::>() + }; + + if candidate_routers.is_empty() { + // No routers found for the specified model + return None; + } + + // Score routers based on worker attributes and request preferences + let mut best_router = None; + let mut best_score = 0.0; + + for router in candidate_routers { + let mut score = 1.0; + + // Check if this is a PD router + let is_pd = router.is_pd_mode(); + if prefer_pd && is_pd { + score += 2.0; // Bonus for matching PD preference + } else if !prefer_pd && !is_pd { + score += 1.0; // Bonus for matching regular preference + } + + // Get workers for this router and evaluate based on priority/cost + // Note: This would require routers to expose their workers or stats + // For now, we'll use a simple selection based on router type + + // TODO: Once routers expose worker stats, we can evaluate: + // - Average worker priority vs priority_threshold + // - Average worker cost vs max_cost + // - Current load and health status + + if score > best_score { + best_score = score; + best_router = Some(router); + } + } + + best_router + } +} + +// Note: Default implementation removed as RouterManager now requires AppContext +// which cannot be defaulted. RouterManager must be created with explicit context. 
+ +// === Phase 2: RouterManager as RouterTrait === + +/// RouterManager implements RouterTrait to act as a meta-router +/// that delegates requests to the appropriate underlying router +#[async_trait] +impl WorkerManagement for RouterManager { + /// Add a worker - in multi-router mode, this adds to the registry + async fn add_worker(&self, worker_url: &str) -> Result { + // Create a basic worker config request + let config = WorkerConfigRequest { + url: worker_url.to_string(), + model_id: None, + worker_type: None, + priority: None, + cost: None, + labels: std::collections::HashMap::new(), + bootstrap_port: None, + tokenizer_path: None, + reasoning_parser: None, + tool_parser: None, + chat_template: None, + }; + + match self.add_worker(config).await { + Ok(response) => Ok(response.message), + Err(e) => Err(e.error), + } + } + + /// Remove a worker from the registry + fn remove_worker(&self, worker_url: &str) { + let _ = self.remove_worker_from_registry(worker_url); + } + + /// Get all worker URLs from the registry + fn get_worker_urls(&self) -> Vec { + self.worker_registry.get_all_urls() + } +} + +#[async_trait] +impl RouterTrait for RouterManager { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + /// Health check - return 503 if no routers available + async fn health(&self, _req: Request) -> Response { + // Health check should succeed if RouterManager exists, even without routers + // Individual router health can be checked via specific endpoints + (StatusCode::OK, "RouterManager is healthy").into_response() + } + + /// Health generate - check if any router can handle generate requests + async fn health_generate(&self, _req: Request) -> Response { + // Return 503 since we have no routers with workers + // TODO: Should check if any router has healthy workers + ( + StatusCode::SERVICE_UNAVAILABLE, + "No routers with healthy workers available", + ) + .into_response() + } + + /// Get server information - aggregate from all routers + async fn 
get_server_info(&self, _req: Request) -> Response { + // TODO: Aggregate info from all routers with healthy workers + // For now, return basic info about the RouterManager + ( + StatusCode::OK, + serde_json::json!({ + "router_manager": true, + "routers_count": self.routers.len(), + "workers_count": self.worker_registry.get_all().len() + }) + .to_string(), + ) + .into_response() + } + + /// Get available models - aggregate from all routers + async fn get_models(&self, _req: Request) -> Response { + // Return models that have registered routers + let models = self + .model_routers + .iter() + .map(|entry| entry.key().clone()) + .collect::>(); + + if models.is_empty() { + (StatusCode::SERVICE_UNAVAILABLE, "No models available").into_response() + } else { + ( + StatusCode::OK, + serde_json::json!({ + "models": models + }) + .to_string(), + ) + .into_response() + } + } + + /// Get model information + async fn get_model_info(&self, _req: Request) -> Response { + // TODO: Extract model from request and route to appropriate router + // For now, return not implemented + ( + StatusCode::NOT_IMPLEMENTED, + "Model info endpoint not yet implemented in RouterManager", + ) + .into_response() + } + + /// Route a generate request + async fn route_generate( + &self, + headers: Option<&HeaderMap>, + body: &GenerateRequest, + _model_id: Option<&str>, + ) -> Response { + // Select router based on headers + // GenerateRequest doesn't have a model field + let router = self.select_router_for_request(headers, None); + + if let Some(router) = router { + // In multi-model mode, pass None since GenerateRequest doesn't have model field + router.route_generate(headers, body, None).await + } else { + // Return 404 when no router is available for the request + ( + StatusCode::NOT_FOUND, + "No router available for this request", + ) + .into_response() + } + } + + /// Route a chat completion request + async fn route_chat( + &self, + headers: Option<&HeaderMap>, + body: &ChatCompletionRequest, + 
_model_id: Option<&str>, + ) -> Response { + // Select router based on headers and model + let router = self.select_router_for_request(headers, Some(&body.model)); + + if let Some(router) = router { + // In multi-model mode, pass the model_id to the router + router.route_chat(headers, body, Some(&body.model)).await + } else { + // Return 404 when the specified model is not found + ( + StatusCode::NOT_FOUND, + format!("Model '{}' not found or no router available", body.model), + ) + .into_response() + } + } + + /// Route a completion request + async fn route_completion( + &self, + headers: Option<&HeaderMap>, + body: &CompletionRequest, + _model_id: Option<&str>, + ) -> Response { + // Select router based on headers and model + let router = self.select_router_for_request(headers, Some(&body.model)); + + if let Some(router) = router { + // In multi-model mode, pass the model_id to the router + router + .route_completion(headers, body, Some(&body.model)) + .await + } else { + // Return 404 when the specified model is not found + ( + StatusCode::NOT_FOUND, + format!("Model '{}' not found or no router available", body.model), + ) + .into_response() + } + } + + async fn route_responses( + &self, + _headers: Option<&HeaderMap>, + _body: &ResponsesRequest, + _model_id: Option<&str>, + ) -> Response { + todo!() + } + + /// Route embeddings request + async fn route_embeddings(&self, headers: Option<&HeaderMap>, body: Body) -> Response { + // Try to select a router based on headers + let router = self.select_router_for_request(headers, None); + + if let Some(router) = router { + router.route_embeddings(headers, body).await + } else { + ( + StatusCode::NOT_FOUND, + "No router available for embeddings request", + ) + .into_response() + } + } + + /// Route rerank request + async fn route_rerank( + &self, + headers: Option<&HeaderMap>, + body: &RerankRequest, + model_id: Option<&str>, + ) -> Response { + // Try to select a router based on headers + let router = 
self.select_router_for_request(headers, None); + + if let Some(router) = router { + router.route_rerank(headers, body, model_id).await + } else { + ( + StatusCode::NOT_FOUND, + "No router available for rerank request", + ) + .into_response() + } + } + + /// Flush cache on all routers and workers + async fn flush_cache(&self) -> Response { + // TODO: Call flush_cache on all routers that have workers + // For now, return success if we have any routers + if self.routers.is_empty() { + (StatusCode::SERVICE_UNAVAILABLE, "No routers configured").into_response() + } else { + // TODO: Actually flush cache on all routers + (StatusCode::OK, "Cache flush requested").into_response() + } + } + + /// Get worker loads from all routers + async fn get_worker_loads(&self) -> Response { + // Return worker loads from the registry + let workers = self.worker_registry.get_all(); + let loads: Vec = workers + .iter() + .map(|w| { + serde_json::json!({ + "url": w.url(), + "model": w.model_id(), + "load": w.load(), + "is_healthy": w.is_healthy() + }) + }) + .collect(); + + ( + StatusCode::OK, + serde_json::json!({ + "workers": loads + }) + .to_string(), + ) + .into_response() + } + + /// Get router type name + fn router_type(&self) -> &'static str { + "manager" + } + + /// Server readiness check - check if any router is ready + fn readiness(&self) -> Response { + if self.routers.is_empty() { + (StatusCode::SERVICE_UNAVAILABLE, "No routers configured").into_response() + } else { + // TODO: Check readiness of all routers + (StatusCode::OK, "Ready").into_response() + } + } +} + +// Note: get_first_available_router removed - we now properly handle +// router selection based on model and worker availability + +impl std::fmt::Debug for RouterManager { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("RouterManager") + .field("routers_count", &self.routers.len()) + .field("workers_count", &self.worker_registry.get_all().len()) + .field("default_router", 
&self.default_router) + .finish() + } +} diff --git a/sgl-router/src/server.rs b/sgl-router/src/server.rs index f44924e38..acaf6a19d 100644 --- a/sgl-router/src/server.rs +++ b/sgl-router/src/server.rs @@ -1,12 +1,16 @@ use crate::config::RouterConfig; +use crate::core::WorkerRegistry; use crate::logging::{self, LoggingConfig}; use crate::metrics::{self, PrometheusConfig}; use crate::middleware::TokenBucket; +use crate::policies::PolicyRegistry; use crate::protocols::spec::{ ChatCompletionRequest, CompletionRequest, GenerateRequest, RerankRequest, ResponsesRequest, V1RerankReqInput, }; +use crate::protocols::worker_spec::{WorkerApiResponse, WorkerConfigRequest, WorkerErrorResponse}; use crate::reasoning_parser::ParserFactory; +use crate::routers::router_manager::{RouterId, RouterManager}; use crate::routers::{RouterFactory, RouterTrait}; use crate::service_discovery::{start_service_discovery, ServiceDiscoveryConfig}; use crate::tokenizer::{factory as tokenizer_factory, traits::Tokenizer}; @@ -36,6 +40,9 @@ pub struct AppContext { pub tokenizer: Option>, pub reasoning_parser_factory: Option, pub tool_parser_registry: Option<&'static ParserRegistry>, + pub worker_registry: Arc, // Shared worker registry + pub policy_registry: Arc, // Shared policy registry + pub router_manager: Option>, // Only present when enable_igw=true } impl AppContext { @@ -75,6 +82,15 @@ impl AppContext { (None, None, None) }; + // Initialize shared registries + let worker_registry = Arc::new(WorkerRegistry::new()); + let policy_registry = Arc::new(PolicyRegistry::new( + router_config.policy.clone(), // Use default policy from config + )); + + // Initialize RouterManager only when enable_igw is true + let router_manager = None; // Will be initialized in startup() based on config + Ok(Self { client, router_config, @@ -82,6 +98,9 @@ impl AppContext { tokenizer, reasoning_parser_factory, tool_parser_registry, + worker_registry, + policy_registry, + router_manager, }) } } @@ -134,7 +153,10 @@ 
async fn generate( headers: http::HeaderMap, Json(body): Json, ) -> Response { - state.router.route_generate(Some(&headers), &body).await + state + .router + .route_generate(Some(&headers), &body, None) + .await } async fn v1_chat_completions( @@ -142,7 +164,7 @@ async fn v1_chat_completions( headers: http::HeaderMap, Json(body): Json, ) -> Response { - state.router.route_chat(Some(&headers), &body).await + state.router.route_chat(Some(&headers), &body, None).await } async fn v1_completions( @@ -150,7 +172,10 @@ async fn v1_completions( headers: http::HeaderMap, Json(body): Json, ) -> Response { - state.router.route_completion(Some(&headers), &body).await + state + .router + .route_completion(Some(&headers), &body, None) + .await } async fn rerank( @@ -158,7 +183,7 @@ async fn rerank( headers: http::HeaderMap, Json(body): Json, ) -> Response { - state.router.route_rerank(Some(&headers), &body).await + state.router.route_rerank(Some(&headers), &body, None).await } async fn v1_rerank( @@ -168,7 +193,7 @@ async fn v1_rerank( ) -> Response { state .router - .route_rerank(Some(&headers), &body.into()) + .route_rerank(Some(&headers), &body.into(), None) .await } @@ -177,7 +202,10 @@ async fn v1_responses( headers: http::HeaderMap, Json(body): Json, ) -> Response { - state.router.route_responses(Some(&headers), &body).await + state + .router + .route_responses(Some(&headers), &body, None) + .await } // Worker management endpoints @@ -232,6 +260,137 @@ async fn get_loads(State(state): State>, _req: Request) -> Respons state.router.get_worker_loads().await } +// New RESTful worker management endpoints (when enable_igw=true) + +/// POST /workers - Add a new worker with full configuration +async fn create_worker( + State(state): State>, + Json(config): Json, +) -> Response { + // Check if RouterManager is available (enable_igw=true) + if let Some(router_manager) = &state.context.router_manager { + match router_manager.add_worker(config).await { + Ok(response) => 
(StatusCode::OK, Json(response)).into_response(), + Err(error) => (StatusCode::BAD_REQUEST, Json(error)).into_response(), + } + } else { + // In single router mode, use the router's add_worker with basic config + match state.router.add_worker(&config.url).await { + Ok(message) => { + let response = WorkerApiResponse { + success: true, + message, + worker: None, + }; + (StatusCode::OK, Json(response)).into_response() + } + Err(error) => { + let error_response = WorkerErrorResponse { + error, + code: "ADD_WORKER_FAILED".to_string(), + }; + (StatusCode::BAD_REQUEST, Json(error_response)).into_response() + } + } + } +} + +/// GET /workers - List all workers with details +async fn list_workers_rest(State(state): State>) -> Response { + if let Some(router_manager) = &state.context.router_manager { + let response = router_manager.list_workers(); + Json(response).into_response() + } else { + // In single router mode, get detailed worker info from registry + let workers = state.context.worker_registry.get_all(); + let response = serde_json::json!({ + "workers": workers.iter().map(|worker| { + let mut worker_info = serde_json::json!({ + "url": worker.url(), + "model_id": worker.model_id(), + "worker_type": format!("{:?}", worker.worker_type()), + "is_healthy": worker.is_healthy(), + "load": worker.load(), + "connection_mode": format!("{:?}", worker.connection_mode()), + "priority": worker.priority(), + "cost": worker.cost(), + }); + + // Add bootstrap_port for Prefill workers + if let crate::core::WorkerType::Prefill { bootstrap_port } = worker.worker_type() { + worker_info["bootstrap_port"] = serde_json::json!(bootstrap_port); + } + + worker_info + }).collect::>(), + "total": workers.len(), + "stats": { + "prefill_count": state.context.worker_registry.get_prefill_workers().len(), + "decode_count": state.context.worker_registry.get_decode_workers().len(), + "regular_count": state.context.worker_registry.get_by_type(&crate::core::WorkerType::Regular).len(), + } + }); + 
Json(response).into_response() + } +} + +/// GET /workers/{url} - Get specific worker info +async fn get_worker( + State(state): State>, + axum::extract::Path(url): axum::extract::Path, +) -> Response { + if let Some(router_manager) = &state.context.router_manager { + if let Some(worker) = router_manager.get_worker(&url) { + Json(worker).into_response() + } else { + let error = WorkerErrorResponse { + error: format!("Worker {} not found", url), + code: "WORKER_NOT_FOUND".to_string(), + }; + (StatusCode::NOT_FOUND, Json(error)).into_response() + } + } else { + // In single router mode, check if worker exists + let workers = state.router.get_worker_urls(); + if workers.contains(&url) { + let worker_info = serde_json::json!({ + "url": url, + "model_id": "unknown", + "is_healthy": true + }); + Json(worker_info).into_response() + } else { + let error = WorkerErrorResponse { + error: format!("Worker {} not found", url), + code: "WORKER_NOT_FOUND".to_string(), + }; + (StatusCode::NOT_FOUND, Json(error)).into_response() + } + } +} + +/// DELETE /workers/{url} - Remove a worker +async fn delete_worker( + State(state): State>, + axum::extract::Path(url): axum::extract::Path, +) -> Response { + if let Some(router_manager) = &state.context.router_manager { + match router_manager.remove_worker_from_registry(&url) { + Ok(response) => (StatusCode::OK, Json(response)).into_response(), + Err(error) => (StatusCode::BAD_REQUEST, Json(error)).into_response(), + } + } else { + // In single router mode, use router's remove_worker + state.router.remove_worker(&url); + let response = WorkerApiResponse { + success: true, + message: format!("Worker {} removed successfully", url), + worker: None, + }; + (StatusCode::OK, Json(response)).into_response() + } +} + pub struct ServerConfig { pub host: String, pub port: u16, @@ -281,11 +440,19 @@ pub fn build_app( .route("/flush_cache", post(flush_cache)) .route("/get_loads", get(get_loads)); + // Worker management routes + let worker_routes = 
Router::new() + .route("/workers", post(create_worker)) + .route("/workers", get(list_workers_rest)) + .route("/workers/{url}", get(get_worker)) + .route("/workers/{url}", axum::routing::delete(delete_worker)); + // Build app with all routes and middleware Router::new() .merge(protected_routes) .merge(public_routes) .merge(admin_routes) + .merge(worker_routes) // Request body size limiting .layer(tower_http::limit::RequestBodyLimitLayer::new( max_payload_size, @@ -355,15 +522,100 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box = if config.router_config.enable_igw { + info!("Multi-router mode enabled (enable_igw=true)"); + + // Create RouterManager with shared registries from AppContext + let mut router_manager = RouterManager::new( + config.router_config.clone(), + client.clone(), + app_context.worker_registry.clone(), + app_context.policy_registry.clone(), + ); + + // Create HTTP routers at startup (with empty worker lists) + // Workers will be added to these routers dynamically via RouterManager's worker registry + + // 1. HTTP Regular Router + match RouterFactory::create_regular_router( + &[], // Empty worker list - workers added later + &app_context, + ) + .await + { + Ok(http_regular) => { + info!("Created HTTP Regular router"); + router_manager.register_router( + RouterId::new("http-regular".to_string()), + Arc::from(http_regular), + vec![], // Models will be determined by workers + ); + } + Err(e) => { + warn!("Failed to create HTTP Regular router: {}", e); + } + } + + // 2. 
HTTP PD Router + match RouterFactory::create_pd_router( + &[], // Empty prefill URLs + &[], // Empty decode URLs + None, // Use default prefill policy + None, // Use default decode policy + &config.router_config.policy, + &app_context, + ) + .await + { + Ok(http_pd) => { + info!("Created HTTP PD router"); + router_manager.register_router( + RouterId::new("http-pd".to_string()), + Arc::from(http_pd), + vec![], + ); + } + Err(e) => { + warn!("Failed to create HTTP PD router: {}", e); + } + } + + // TODO: Add gRPC routers once we have dynamic tokenizer loading + // Currently gRPC routers require tokenizer to be initialized first, + // but each model needs its own tokenizer. Once we implement dynamic + // tokenizer loading per model, we can enable gRPC routers here: + // - RouterType::GrpcRegular (RouterId: "grpc-regular") + // - RouterType::GrpcPd (RouterId: "grpc-pd") + + info!( + "RouterManager initialized with {} routers", + router_manager.router_count() + ); + Box::new(router_manager) + } else { + info!("Single router mode (enable_igw=false)"); + // Create single router with the context + RouterFactory::create_router(&app_context).await? 
+ }; + + // Start health checker for all workers in the registry + let _health_checker = app_context + .worker_registry + .start_health_checker(config.router_config.health_check.check_interval_secs); + info!( + "Started health checker for workers with {}s interval", + config.router_config.health_check.check_interval_secs + ); // Set up concurrency limiter with queue if configured let (limiter, processor) = crate::middleware::ConcurrencyLimiter::new( diff --git a/sgl-router/src/service_discovery.rs b/sgl-router/src/service_discovery.rs index c27317f86..3392bf705 100644 --- a/sgl-router/src/service_discovery.rs +++ b/sgl-router/src/service_discovery.rs @@ -579,9 +579,8 @@ mod tests { // Helper to create a Router instance for testing event handlers async fn create_test_router() -> Arc { - use crate::config::{PolicyConfig, RouterConfig}; + use crate::config::RouterConfig; use crate::middleware::TokenBucket; - use crate::policies::PolicyFactory; use crate::routers::http::router::Router; use crate::server::AppContext; @@ -591,15 +590,19 @@ mod tests { // Create AppContext with minimal components let app_context = Arc::new(AppContext { client: reqwest::Client::new(), - router_config, + router_config: router_config.clone(), rate_limiter: Arc::new(TokenBucket::new(1000, 1000)), + worker_registry: Arc::new(crate::core::WorkerRegistry::new()), + policy_registry: Arc::new(crate::policies::PolicyRegistry::new( + router_config.policy.clone(), + )), tokenizer: None, // HTTP mode doesn't need tokenizer reasoning_parser_factory: None, // HTTP mode doesn't need reasoning parser tool_parser_registry: None, // HTTP mode doesn't need tool parser + router_manager: None, // Test doesn't need router manager }); - let policy = PolicyFactory::create_from_config(&PolicyConfig::Random); - let router = Router::new(vec![], policy, &app_context).await.unwrap(); + let router = Router::new(vec![], &app_context).await.unwrap(); Arc::new(router) as Arc } diff --git 
a/sgl-router/tests/cache_aware_backward_compat_test.rs b/sgl-router/tests/cache_aware_backward_compat_test.rs new file mode 100644 index 000000000..07baa9648 --- /dev/null +++ b/sgl-router/tests/cache_aware_backward_compat_test.rs @@ -0,0 +1,129 @@ +use sglang_router_rs::core::{BasicWorker, Worker, WorkerType}; +use sglang_router_rs::policies::{CacheAwareConfig, CacheAwarePolicy, LoadBalancingPolicy}; +use std::collections::HashMap; +use std::sync::Arc; + +#[test] +fn test_backward_compatibility_with_empty_model_id() { + let config = CacheAwareConfig { + cache_threshold: 0.5, + balance_abs_threshold: 2, + balance_rel_threshold: 1.5, + eviction_interval_secs: 0, // Disable background eviction for testing + max_tree_size: 100, + }; + + let policy = CacheAwarePolicy::with_config(config); + + // Create workers with empty model_id (simulating existing routers) + let worker1 = BasicWorker::new("http://worker1:8080".to_string(), WorkerType::Regular); + // No model_id label - should default to "unknown" + + let mut labels2 = HashMap::new(); + labels2.insert("model_id".to_string(), "unknown".to_string()); + let worker2 = BasicWorker::new("http://worker2:8080".to_string(), WorkerType::Regular) + .with_labels(labels2); + + // Add workers - should both go to "default" tree + policy.add_worker(&worker1); + policy.add_worker(&worker2); + + // Create worker list + let workers: Vec> = vec![Arc::new(worker1.clone()), Arc::new(worker2.clone())]; + + // Select worker - should work without errors + let selected = policy.select_worker(&workers, Some("test request")); + assert!(selected.is_some(), "Should select a worker"); + + // Remove workers - should work without errors + policy.remove_worker(&worker1); + policy.remove_worker(&worker2); +} + +#[test] +fn test_mixed_model_ids() { + let config = CacheAwareConfig { + cache_threshold: 0.5, + balance_abs_threshold: 2, + balance_rel_threshold: 1.5, + eviction_interval_secs: 0, + max_tree_size: 100, + }; + + let policy = 
CacheAwarePolicy::with_config(config); + + // Create workers with different model_id scenarios + let worker1 = BasicWorker::new("http://worker1:8080".to_string(), WorkerType::Regular); + // No model_id label - defaults to "unknown" which goes to "default" tree + + let mut labels2 = HashMap::new(); + labels2.insert("model_id".to_string(), "llama-3".to_string()); + let worker2 = BasicWorker::new("http://worker2:8080".to_string(), WorkerType::Regular) + .with_labels(labels2); + + let mut labels3 = HashMap::new(); + labels3.insert("model_id".to_string(), "unknown".to_string()); + let worker3 = BasicWorker::new("http://worker3:8080".to_string(), WorkerType::Regular) + .with_labels(labels3); + + let mut labels4 = HashMap::new(); + labels4.insert("model_id".to_string(), "llama-3".to_string()); + let worker4 = BasicWorker::new("http://worker4:8080".to_string(), WorkerType::Regular) + .with_labels(labels4); + + // Add all workers + policy.add_worker(&worker1); + policy.add_worker(&worker2); + policy.add_worker(&worker3); + policy.add_worker(&worker4); + + // Test selection with default workers only + let default_workers: Vec> = + vec![Arc::new(worker1.clone()), Arc::new(worker3.clone())]; + let selected = policy.select_worker(&default_workers, Some("test request")); + assert!(selected.is_some(), "Should select from default workers"); + + // Test selection with specific model workers only + let llama_workers: Vec> = + vec![Arc::new(worker2.clone()), Arc::new(worker4.clone())]; + let selected = policy.select_worker(&llama_workers, Some("test request")); + assert!(selected.is_some(), "Should select from llama-3 workers"); + + // Test selection with mixed workers + let all_workers: Vec> = vec![ + Arc::new(worker1.clone()), + Arc::new(worker2.clone()), + Arc::new(worker3.clone()), + Arc::new(worker4.clone()), + ]; + let selected = policy.select_worker(&all_workers, Some("test request")); + assert!(selected.is_some(), "Should select from all workers"); +} + +#[test] +fn 
test_remove_worker_by_url_backward_compat() { + let config = CacheAwareConfig::default(); + let policy = CacheAwarePolicy::with_config(config); + + // Create workers with different model_ids + let mut labels1 = HashMap::new(); + labels1.insert("model_id".to_string(), "llama-3".to_string()); + let worker1 = BasicWorker::new("http://worker1:8080".to_string(), WorkerType::Regular) + .with_labels(labels1); + + let worker2 = BasicWorker::new("http://worker2:8080".to_string(), WorkerType::Regular); + // No model_id label - defaults to "unknown" + + // Add workers + policy.add_worker(&worker1); + policy.add_worker(&worker2); + + // Remove by URL (backward compatibility method) + // Should remove from all trees since we don't know the model + policy.remove_worker_by_url("http://worker1:8080"); + + // Verify removal worked + let workers: Vec> = vec![Arc::new(worker2.clone())]; + let selected = policy.select_worker(&workers, Some("test")); + assert_eq!(selected, Some(0), "Should only have worker2 left"); +} diff --git a/sgl-router/tests/policy_registry_integration.rs b/sgl-router/tests/policy_registry_integration.rs new file mode 100644 index 000000000..a90890d53 --- /dev/null +++ b/sgl-router/tests/policy_registry_integration.rs @@ -0,0 +1,168 @@ +//! 
Integration tests for PolicyRegistry with RouterManager + +use sglang_router_rs::config::{PolicyConfig, RouterConfig}; +use sglang_router_rs::core::WorkerRegistry; +use sglang_router_rs::policies::PolicyRegistry; +use sglang_router_rs::protocols::worker_spec::WorkerConfigRequest; +use sglang_router_rs::routers::router_manager::RouterManager; +use std::collections::HashMap; +use std::sync::Arc; + +#[tokio::test] +async fn test_policy_registry_with_router_manager() { + // Create RouterConfig + let config = RouterConfig { + enable_igw: true, + policy: PolicyConfig::RoundRobin, + ..Default::default() + }; + + // Create HTTP client + let client = reqwest::Client::new(); + + // Create shared registries + let worker_registry = Arc::new(WorkerRegistry::new()); + let policy_registry = Arc::new(PolicyRegistry::new(PolicyConfig::RoundRobin)); + + // Create RouterManager with shared registries + let _router_manager = RouterManager::new( + config, + client, + worker_registry.clone(), + policy_registry.clone(), + ); + + // Test adding workers with different models and policies + + // Add first worker for llama-3 with cache_aware policy hint + let mut labels1 = HashMap::new(); + labels1.insert("policy".to_string(), "cache_aware".to_string()); + + let _worker1_config = WorkerConfigRequest { + url: "http://worker1:8000".to_string(), + model_id: Some("llama-3".to_string()), + worker_type: None, + priority: None, + cost: None, + labels: labels1, + bootstrap_port: None, + tokenizer_path: None, + reasoning_parser: None, + tool_parser: None, + chat_template: None, + }; + + // This would normally connect to a real worker, but for testing we'll just verify the structure + // In a real test, we'd need to mock the worker or use a test server + + // Verify PolicyRegistry has the correct policy for llama-3 + let _llama_policy = policy_registry.get_policy("llama-3"); + // After first worker is added, llama-3 should have a policy + + // Add second worker for llama-3 with different policy hint 
(should be ignored) + let mut labels2 = HashMap::new(); + labels2.insert("policy".to_string(), "random".to_string()); + + let _worker2_config = WorkerConfigRequest { + url: "http://worker2:8000".to_string(), + model_id: Some("llama-3".to_string()), + worker_type: None, + priority: None, + cost: None, + labels: labels2, + bootstrap_port: None, + tokenizer_path: None, + reasoning_parser: None, + tool_parser: None, + chat_template: None, + }; + + // The second worker should use the same policy as the first (cache_aware) + + // Add worker for different model (gpt-4) with random policy + let mut labels3 = HashMap::new(); + labels3.insert("policy".to_string(), "random".to_string()); + + let _worker3_config = WorkerConfigRequest { + url: "http://worker3:8000".to_string(), + model_id: Some("gpt-4".to_string()), + worker_type: None, + priority: None, + cost: None, + labels: labels3, + bootstrap_port: None, + tokenizer_path: None, + reasoning_parser: None, + tool_parser: None, + chat_template: None, + }; + + // Verify gpt-4 has random policy + let _gpt_policy = policy_registry.get_policy("gpt-4"); + + // Test removing workers + // When we remove both llama-3 workers, the policy should be cleaned up + + println!("PolicyRegistry integration test structure created"); + println!("Note: This test requires mocking or test servers to fully execute"); +} + +#[test] +fn test_policy_registry_cleanup() { + use sglang_router_rs::config::PolicyConfig; + use sglang_router_rs::policies::PolicyRegistry; + + let registry = PolicyRegistry::new(PolicyConfig::RoundRobin); + + // Add workers for a model + let policy1 = registry.on_worker_added("model-1", Some("cache_aware")); + assert_eq!(policy1.name(), "cache_aware"); + + // Second worker uses existing policy + let policy2 = registry.on_worker_added("model-1", Some("random")); + assert_eq!(policy2.name(), "cache_aware"); // Should still be cache_aware + + // Verify policy exists + assert!(registry.get_policy("model-1").is_some()); + + // 
Remove first worker - policy should remain + registry.on_worker_removed("model-1"); + assert!(registry.get_policy("model-1").is_some()); + + // Remove second worker - policy should be cleaned up + registry.on_worker_removed("model-1"); + assert!(registry.get_policy("model-1").is_none()); + + println!("✓ PolicyRegistry cleanup test passed"); +} + +#[test] +fn test_policy_registry_multiple_models() { + use sglang_router_rs::config::PolicyConfig; + use sglang_router_rs::policies::PolicyRegistry; + + let registry = PolicyRegistry::new(PolicyConfig::RoundRobin); + + // Add workers for different models with different policies + let llama_policy = registry.on_worker_added("llama-3", Some("cache_aware")); + let gpt_policy = registry.on_worker_added("gpt-4", Some("random")); + let mistral_policy = registry.on_worker_added("mistral", None); // Uses default + + assert_eq!(llama_policy.name(), "cache_aware"); + assert_eq!(gpt_policy.name(), "random"); + assert_eq!(mistral_policy.name(), "round_robin"); // Default + + // Verify all policies are stored + assert!(registry.get_policy("llama-3").is_some()); + assert!(registry.get_policy("gpt-4").is_some()); + assert!(registry.get_policy("mistral").is_some()); + + // Get all mappings + let mappings = registry.get_all_mappings(); + assert_eq!(mappings.len(), 3); + assert_eq!(mappings.get("llama-3").unwrap(), "cache_aware"); + assert_eq!(mappings.get("gpt-4").unwrap(), "random"); + assert_eq!(mappings.get("mistral").unwrap(), "round_robin"); + + println!("✓ PolicyRegistry multiple models test passed"); +} diff --git a/sgl-router/tests/test_openai_routing.rs b/sgl-router/tests/test_openai_routing.rs index ec38a6dd5..366c455f8 100644 --- a/sgl-router/tests/test_openai_routing.rs +++ b/sgl-router/tests/test_openai_routing.rs @@ -197,12 +197,14 @@ async fn test_unsupported_endpoints() { rid: None, }; - let response = router.route_generate(None, &generate_request).await; + let response = router.route_generate(None, &generate_request, 
None).await; assert_eq!(response.status(), StatusCode::NOT_IMPLEMENTED); // Test completion endpoint (should also not be supported) let completion_request = create_minimal_completion_request(); - let response = router.route_completion(None, &completion_request).await; + let response = router + .route_completion(None, &completion_request, None) + .await; assert_eq!(response.status(), StatusCode::NOT_IMPLEMENTED); } @@ -228,7 +230,7 @@ async fn test_openai_router_chat_completion_with_mock() { chat_request.temperature = Some(0.7); // Route the request - let response = router.route_chat(None, &chat_request).await; + let response = router.route_chat(None, &chat_request, None).await; // Should get a successful response from mock server assert_eq!(response.status(), StatusCode::OK); @@ -269,7 +271,9 @@ async fn test_openai_e2e_with_server() { let chat_request: ChatCompletionRequest = serde_json::from_str(&body_str).unwrap(); - router.route_chat(Some(&parts.headers), &chat_request).await + router + .route_chat(Some(&parts.headers), &chat_request, None) + .await } } }), @@ -327,7 +331,7 @@ async fn test_openai_router_chat_streaming_with_mock() { }); let chat_request: ChatCompletionRequest = serde_json::from_value(val).unwrap(); - let response = router.route_chat(None, &chat_request).await; + let response = router.route_chat(None, &chat_request, None).await; assert_eq!(response.status(), StatusCode::OK); // Should be SSE @@ -371,7 +375,7 @@ async fn test_openai_router_circuit_breaker() { // First few requests should fail and record failures for _ in 0..3 { - let response = router.route_chat(None, &chat_request).await; + let response = router.route_chat(None, &chat_request, None).await; // Should get either an error or circuit breaker response assert!( response.status() == StatusCode::INTERNAL_SERVER_ERROR