diff --git a/sgl-router/Cargo.toml b/sgl-router/Cargo.toml index fa9d58967..b23b6d7ac 100644 --- a/sgl-router/Cargo.toml +++ b/sgl-router/Cargo.toml @@ -30,6 +30,8 @@ tracing-appender = "0.2.3" kube = { version = "0.88.1", features = ["runtime", "derive"] } k8s-openapi = { version = "0.21.0", features = ["v1_29"] } futures = "0.3" +async-trait = "0.1" +once_cell = "1.21" # Added for metrics metrics = "0.24.2" metrics-exporter-prometheus = "0.17.0" diff --git a/sgl-router/src/core/error.rs b/sgl-router/src/core/error.rs new file mode 100644 index 000000000..02a87dbbc --- /dev/null +++ b/sgl-router/src/core/error.rs @@ -0,0 +1,57 @@ +//! Error types for the SGLang router core +//! +//! This module defines error types used throughout the router for worker operations. + +use std::fmt; + +/// Worker-related errors +#[derive(Debug)] +pub enum WorkerError { + /// Health check failed + HealthCheckFailed { url: String, reason: String }, + /// Worker not found + WorkerNotFound { url: String }, + /// Invalid worker configuration + InvalidConfiguration { message: String }, + /// Network error + NetworkError { url: String, error: String }, + /// Worker is at capacity + WorkerAtCapacity { url: String }, +} + +impl fmt::Display for WorkerError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + WorkerError::HealthCheckFailed { url, reason } => { + write!(f, "Health check failed for worker {}: {}", url, reason) + } + WorkerError::WorkerNotFound { url } => { + write!(f, "Worker not found: {}", url) + } + WorkerError::InvalidConfiguration { message } => { + write!(f, "Invalid worker configuration: {}", message) + } + WorkerError::NetworkError { url, error } => { + write!(f, "Network error for worker {}: {}", url, error) + } + WorkerError::WorkerAtCapacity { url } => { + write!(f, "Worker at capacity: {}", url) + } + } + } +} + +impl std::error::Error for WorkerError {} + +/// Result type for worker operations +pub type WorkerResult = Result; + +/// Convert from reqwest errors to worker errors +impl From for WorkerError { + fn from(err: reqwest::Error) -> Self { + WorkerError::NetworkError { + url: err.url().map(|u| u.to_string()).unwrap_or_default(), + error: err.to_string(), + } + } +} diff --git a/sgl-router/src/core/mod.rs b/sgl-router/src/core/mod.rs new file mode 100644 index 000000000..aefbc2000 --- /dev/null +++ b/sgl-router/src/core/mod.rs @@ -0,0 +1,16 @@ +//! Core abstractions for the SGLang router +//! +//! This module contains the fundamental types and traits used throughout the router: +//! - Worker trait and implementations +//! - Error types +//! - Common utilities + +pub mod error; +pub mod worker; + +// Re-export commonly used types at the module level +pub use error::{WorkerError, WorkerResult}; +pub use worker::{ + start_health_checker, BasicWorker, HealthChecker, Worker, WorkerCollection, WorkerFactory, + WorkerLoadGuard, WorkerType, +}; diff --git a/sgl-router/src/core/worker.rs b/sgl-router/src/core/worker.rs new file mode 100644 index 000000000..ae88bdd1c --- /dev/null +++ b/sgl-router/src/core/worker.rs @@ -0,0 +1,454 @@ +use super::{WorkerError, WorkerResult}; +use async_trait::async_trait; +use once_cell::sync::Lazy; +use std::fmt; +use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; +use std::sync::Arc; + +// Shared HTTP client for health checks +static HEALTH_CHECK_CLIENT: Lazy = Lazy::new(|| { + reqwest::Client::builder() + .timeout(std::time::Duration::from_secs(30)) // Default timeout, overridden per request + .build() + .expect("Failed to create health check HTTP client") +}); + +/// Core worker abstraction that represents a backend service +#[async_trait] +pub trait Worker: Send + Sync + fmt::Debug { + /// Get the worker's URL + fn url(&self) -> &str; + + /// Get the worker's type (Regular, Prefill, or Decode) + fn worker_type(&self) -> WorkerType; + + /// Check if the worker is currently healthy + fn is_healthy(&self) -> bool; + + /// Set the worker's health status + fn set_healthy(&self, healthy: bool); + + /// Perform an async health check on the worker + async fn check_health_async(&self) -> WorkerResult<()>; + + /// Synchronous health check wrapper (for compatibility) + fn check_health(&self) -> WorkerResult<()> { + // Use a small runtime for synchronous contexts + tokio::runtime::Builder::new_current_thread() + .enable_all() + .build() + .map_err(|e| WorkerError::HealthCheckFailed { + url: self.url().to_string(), + reason: format!("Failed to create runtime: {}", e), + })? + .block_on(self.check_health_async()) + } + + /// Get the current load (number of active requests) + fn load(&self) -> usize; + + /// Increment the load counter + fn increment_load(&self); + + /// Decrement the load counter + fn decrement_load(&self); + + /// Get the number of processed requests + fn processed_requests(&self) -> usize; + + /// Increment the processed requests counter + fn increment_processed(&self); + + /// Get worker-specific metadata + fn metadata(&self) -> &WorkerMetadata; + + /// Clone the worker (for trait objects) + fn clone_worker(&self) -> Box; +} + +/// Worker type classification +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub enum WorkerType { + /// Regular worker for standard routing + Regular, + /// Prefill worker for PD disaggregated mode + Prefill { + /// Bootstrap port for communication with decode workers + bootstrap_port: Option, + }, + /// Decode worker for PD disaggregated mode + Decode, +} + +impl fmt::Display for WorkerType { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + WorkerType::Regular => write!(f, "Regular"), + WorkerType::Prefill { bootstrap_port } => match bootstrap_port { + Some(port) => write!(f, "Prefill(bootstrap:{})", port), + None => write!(f, "Prefill"), + }, + WorkerType::Decode => write!(f, "Decode"), + } + } +} + +/// Health check configuration +#[derive(Debug, Clone)] +pub struct HealthConfig { + /// Timeout for health checks in seconds + pub timeout_secs: u64, + /// Interval between health checks in seconds + pub check_interval_secs: u64, + /// Health check endpoint path + pub endpoint: String, +} + +impl Default for HealthConfig { + fn default() -> Self { + Self { + timeout_secs: 5, + check_interval_secs: 30, + endpoint: "/health".to_string(), + } + } +} + +/// Metadata associated with a worker +#[derive(Debug, Clone)] +pub struct WorkerMetadata { + /// Worker URL + pub url: String, + /// Worker type + pub worker_type: WorkerType, + /// Additional labels/tags + pub labels: std::collections::HashMap, + /// Health check configuration + pub health_config: HealthConfig, +} + +/// Basic worker implementation +#[derive(Debug, Clone)] +pub struct BasicWorker { + metadata: WorkerMetadata, + load_counter: Arc, + processed_counter: Arc, + healthy: Arc, +} + +impl BasicWorker { + pub fn new(url: String, worker_type: WorkerType) -> Self { + let metadata = WorkerMetadata { + url: url.clone(), + worker_type, + labels: std::collections::HashMap::new(), + health_config: HealthConfig::default(), + }; + + Self { + metadata, + load_counter: Arc::new(AtomicUsize::new(0)), + processed_counter: Arc::new(AtomicUsize::new(0)), + healthy: Arc::new(AtomicBool::new(true)), + } + } + + pub fn with_labels(mut self, labels: std::collections::HashMap) -> Self { + self.metadata.labels = labels; + self + } + + pub fn with_health_config(mut self, config: HealthConfig) -> Self { + self.metadata.health_config = config; + self + } +} + +#[async_trait] +impl Worker for BasicWorker { + fn url(&self) -> &str { + &self.metadata.url + } + + fn worker_type(&self) -> WorkerType { + self.metadata.worker_type.clone() + } + + fn is_healthy(&self) -> bool { + self.healthy.load(Ordering::Acquire) + } + + fn set_healthy(&self, healthy: bool) { + self.healthy.store(healthy, Ordering::Release); + } + + async fn check_health_async(&self) -> WorkerResult<()> { + use std::time::Duration; + + // Perform actual HTTP health check + let health_url = format!("{}{}", self.url(), self.metadata.health_config.endpoint); + let timeout = Duration::from_secs(self.metadata.health_config.timeout_secs); + + // Use the shared client with a custom timeout for this request + match HEALTH_CHECK_CLIENT + .get(&health_url) + .timeout(timeout) + .send() + .await + { + Ok(response) => { + if response.status().is_success() { + self.set_healthy(true); + Ok(()) + } else { + self.set_healthy(false); + Err(WorkerError::HealthCheckFailed { + url: self.url().to_string(), + reason: format!("Health check returned status: {}", response.status()), + }) + } + } + Err(e) => { + self.set_healthy(false); + Err(WorkerError::HealthCheckFailed { + url: self.url().to_string(), + reason: format!("Health check request failed: {}", e), + }) + } + } + } + + fn load(&self) -> usize { + self.load_counter.load(Ordering::Relaxed) + } + + fn increment_load(&self) { + self.load_counter.fetch_add(1, Ordering::Relaxed); + } + + fn decrement_load(&self) { + self.load_counter + .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |current| { + current.checked_sub(1) + }) + .ok(); + } + + fn processed_requests(&self) -> usize { + self.processed_counter.load(Ordering::Relaxed) + } + + fn increment_processed(&self) { + self.processed_counter.fetch_add(1, Ordering::Relaxed); + } + + fn metadata(&self) -> &WorkerMetadata { + &self.metadata + } + + fn clone_worker(&self) -> Box { + Box::new(self.clone()) + } +} + +/// Worker factory for creating workers of different types +pub struct WorkerFactory; + +impl WorkerFactory { + /// Create a regular worker + pub fn create_regular(url: String) -> Box { + Box::new(BasicWorker::new(url, WorkerType::Regular)) + } + + /// Create a prefill worker with optional bootstrap port + pub fn create_prefill(url: String, bootstrap_port: Option) -> Box { + Box::new(BasicWorker::new( + url, + WorkerType::Prefill { bootstrap_port }, + )) + } + + /// Create a decode worker + pub fn create_decode(url: String) -> Box { + Box::new(BasicWorker::new(url, WorkerType::Decode)) + } + + /// Create workers from URLs with automatic type detection + pub fn create_from_urls( + regular_urls: Vec, + prefill_urls: Vec<(String, Option)>, + decode_urls: Vec, + ) -> ( + Vec>, + Vec>, + Vec>, + ) { + let regular_workers: Vec> = + regular_urls.into_iter().map(Self::create_regular).collect(); + + let prefill_workers: Vec> = prefill_urls + .into_iter() + .map(|(url, port)| Self::create_prefill(url, port)) + .collect(); + + let decode_workers: Vec> = + decode_urls.into_iter().map(Self::create_decode).collect(); + + (regular_workers, prefill_workers, decode_workers) + } +} + +/// Helper trait for collections of workers +pub trait WorkerCollection { + fn healthy_workers(&self) -> Vec<&dyn Worker>; + fn total_load(&self) -> usize; + fn find_worker(&self, url: &str) -> Option<&dyn Worker>; + fn find_worker_mut(&mut self, url: &str) -> Option<&mut Box>; +} + +impl WorkerCollection for Vec> { + fn healthy_workers(&self) -> Vec<&dyn Worker> { + self.iter() + .filter(|w| w.is_healthy()) + .map(|w| w.as_ref()) + .collect() + } + + fn total_load(&self) -> usize { + self.iter().map(|w| w.load()).sum() + } + + fn find_worker(&self, url: &str) -> Option<&dyn Worker> { + self.iter().find(|w| w.url() == url).map(|w| w.as_ref()) + } + + fn find_worker_mut(&mut self, url: &str) -> Option<&mut Box> { + self.iter_mut().find(|w| w.url() == url) + } +} + +/// Convert a list of worker URLs to worker trait objects +pub fn urls_to_workers(urls: Vec) -> Vec> { + urls.into_iter() + .map(WorkerFactory::create_regular) + .collect() +} + +/// Convert worker trait objects back to URLs +pub fn workers_to_urls(workers: &[Box]) -> Vec { + workers.iter().map(|w| w.url().to_string()).collect() +} + +/// RAII guard for worker load management +pub struct WorkerLoadGuard<'a> { + workers: Vec<&'a dyn Worker>, +} + +impl<'a> WorkerLoadGuard<'a> { + /// Create a new load guard for a single worker + pub fn new(worker: &'a dyn Worker) -> Self { + worker.increment_load(); + Self { + workers: vec![worker], + } + } + + /// Create a new load guard for multiple workers + pub fn new_multi(workers: Vec<&'a dyn Worker>) -> Self { + // Increment load counters for all workers + for worker in &workers { + worker.increment_load(); + } + Self { workers } + } +} + +impl<'a> Drop for WorkerLoadGuard<'a> { + fn drop(&mut self) { + // Decrement load counters for all workers + for worker in &self.workers { + worker.decrement_load(); + } + } +} + +/// Health checker handle with graceful shutdown +pub struct HealthChecker { + handle: tokio::task::JoinHandle<()>, + shutdown: Arc, +} + +impl fmt::Debug for HealthChecker { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("HealthChecker") + .field("shutdown", &self.shutdown.load(Ordering::Relaxed)) + .finish() + } +} + +impl HealthChecker { + /// Shutdown the health checker gracefully + pub async fn shutdown(self) { + self.shutdown.store(true, Ordering::Release); + let _ = self.handle.await; + } +} + +/// Start an async background health checker for a collection of workers +pub fn start_health_checker( + workers: std::sync::Arc>>>, + check_interval_secs: u64, +) -> HealthChecker { + let shutdown = Arc::new(AtomicBool::new(false)); + let shutdown_clone = shutdown.clone(); + + let handle = tokio::spawn(async move { + let mut interval = + tokio::time::interval(tokio::time::Duration::from_secs(check_interval_secs)); + + loop { + interval.tick().await; + + // Check for shutdown signal + if shutdown_clone.load(Ordering::Acquire) { + tracing::info!("Health checker shutting down"); + break; + } + + // Check health of all workers + let workers_to_check = match workers.read() { + Ok(guard) => guard.iter().map(|w| w.clone_worker()).collect::>(), + Err(poisoned) => { + tracing::error!("Worker lock poisoned: {}", poisoned); + continue; + } + }; + + // Perform health checks concurrently + let health_checks = workers_to_check.iter().map(|worker| { + let worker_url = worker.url().to_string(); + let was_healthy = worker.is_healthy(); + + async move { + match worker.check_health_async().await { + Ok(_) => { + if !was_healthy { + tracing::info!("Worker {} is now healthy", worker_url); + } + } + Err(e) => { + if was_healthy { + tracing::warn!("Worker {} health check failed: {}", worker_url, e); + } + } + } + } + }); + + // Execute all health checks concurrently + futures::future::join_all(health_checks).await; + } + }); + + HealthChecker { handle, shutdown } +} diff --git a/sgl-router/src/lib.rs b/sgl-router/src/lib.rs index fb90c635e..2b1bcffce 100644 --- a/sgl-router/src/lib.rs +++ b/sgl-router/src/lib.rs @@ -2,6 +2,7 @@ use pyo3::prelude::*; pub mod config; pub mod logging; use std::collections::HashMap; +pub mod core; pub mod openai_api_types; pub mod pd_router; pub mod pd_types; diff --git a/sgl-router/src/pd_router.rs b/sgl-router/src/pd_router.rs index 4157395fb..a1f04c7d2 100644 --- a/sgl-router/src/pd_router.rs +++ b/sgl-router/src/pd_router.rs @@ -1,8 +1,9 @@ // PD (Prefill-Decode) Router Implementation // This module handles routing for disaggregated prefill-decode systems +use crate::core::{HealthChecker, Worker, WorkerFactory, WorkerLoadGuard}; use crate::pd_types::{ - Bootstrap, ChatReqInput, EngineInfo, GenerateReqInput, PDRouterError, PDSelectionPolicy, + api_path, Bootstrap, ChatReqInput, GenerateReqInput, PDRouterError, PDSelectionPolicy, }; use crate::tree::Tree; use actix_web::http::header::{HeaderValue, CONTENT_TYPE}; @@ -11,7 +12,6 @@ use futures_util::{StreamExt, TryStreamExt}; use metrics::{counter, histogram}; use serde_json::Value; use std::collections::HashMap; -use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::{Arc, Mutex, RwLock}; use std::time::{Duration, Instant}; use tracing::{debug, error, info, warn}; @@ -21,49 +21,17 @@ use uuid::Uuid; #[derive(Debug)] pub struct PDRouter { - pub prefill_workers: Arc>>, - pub decode_workers: Arc>>, + pub prefill_workers: Arc>>>, + pub decode_workers: Arc>>>, pub selection_policy: PDSelectionPolicy, - pub load_tracking: Arc>>, pub prefill_tree: Option>>, pub timeout_secs: u64, pub interval_secs: u64, pub worker_loads: Arc>>, pub load_monitor_handle: Option>>, pub http_client: reqwest::Client, -} - -// RAII guard for load tracking to ensure cleanup even on panic -struct LoadGuard<'a> { - tracking: &'a Arc>>, - urls: Vec, -} - -impl<'a> LoadGuard<'a> { - fn new( - tracking: &'a Arc>>, - urls: Vec, - ) -> Self { - // Increment counters - for url in &urls { - let counter = tracking - .entry(url.clone()) - .or_insert_with(|| Arc::new(AtomicUsize::new(0))); - counter.fetch_add(1, Ordering::Relaxed); - } - LoadGuard { tracking, urls } - } -} - -impl Drop for LoadGuard<'_> { - fn drop(&mut self) { - // Guaranteed cleanup even on panic - for url in &self.urls { - if let Some(counter) = self.tracking.get(url) { - counter.fetch_sub(1, Ordering::Relaxed); - } - } - } + _prefill_health_checker: Option, + _decode_health_checker: Option, } impl PDRouter { @@ -73,9 +41,6 @@ impl PDRouter { url: String, bootstrap_port: Option, ) -> Result { - // Create EngineInfo for the new prefill server - let engine_info = EngineInfo::new_prefill(url.clone(), bootstrap_port); - // Wait for the new server to be healthy crate::router::Router::wait_for_healthy_workers( &[url.clone()], @@ -84,6 +49,9 @@ impl PDRouter { ) .map_err(|_| PDRouterError::HealthCheckFailed { url: url.clone() })?; + // Create Worker for the new prefill server + let worker = WorkerFactory::create_prefill(url.clone(), bootstrap_port); + // Add to prefill workers list let mut workers = self .prefill_workers @@ -93,15 +61,11 @@ impl PDRouter { })?; // Check if already exists - if workers.iter().any(|w| w.url == url) { + if workers.iter().any(|w| w.url() == &url) { return Err(PDRouterError::WorkerAlreadyExists { url: url.clone() }); } - workers.push(engine_info); - - // Initialize load tracking - self.load_tracking - .insert(url.clone(), Arc::new(AtomicUsize::new(0))); + workers.push(worker); // Add to cache tree if using cache-aware policy if let Some(ref tree) = self.prefill_tree { @@ -113,9 +77,6 @@ impl PDRouter { } pub async fn add_decode_server(&self, url: String) -> Result { - // Create EngineInfo for the new decode server - let engine_info = EngineInfo::new_decode(url.clone()); - // Wait for the new server to be healthy crate::router::Router::wait_for_healthy_workers( &[url.clone()], @@ -124,6 +85,9 @@ impl PDRouter { ) .map_err(|_| PDRouterError::HealthCheckFailed { url: url.clone() })?; + // Create Worker for the new decode server + let worker = WorkerFactory::create_decode(url.clone()); + // Add to decode workers list let mut workers = self .decode_workers @@ -133,15 +97,14 @@ impl PDRouter { })?; // Check if already exists - if workers.iter().any(|w| w.url == url) { + if workers.iter().any(|w| w.url() == &url) { return Err(PDRouterError::WorkerAlreadyExists { url: url.clone() }); } - workers.push(engine_info); + workers.push(worker); // Initialize load tracking - self.load_tracking - .insert(url.clone(), Arc::new(AtomicUsize::new(0))); + // Worker tracks its own load internally info!("Added decode server: {}", url); Ok(format!("Successfully added decode server: {}", url)) @@ -157,7 +120,7 @@ impl PDRouter { // Find and remove the server let initial_len = workers.len(); - workers.retain(|w| w.url != url); + workers.retain(|w| w.url() != url); if workers.len() == initial_len { return Err(PDRouterError::WorkerNotFound { @@ -166,7 +129,7 @@ impl PDRouter { } // Remove from load tracking - self.load_tracking.remove(url); + // Worker load tracking is internal // Remove from cache tree if using cache-aware policy if let Some(ref tree) = self.prefill_tree { @@ -174,7 +137,7 @@ impl PDRouter { let mut tree_guard = tree.lock().unwrap(); *tree_guard = Tree::new(); for worker in workers.iter() { - tree_guard.insert("", &worker.url); + tree_guard.insert("", worker.url()); } } @@ -192,7 +155,7 @@ impl PDRouter { // Find and remove the server let initial_len = workers.len(); - workers.retain(|w| w.url != url); + workers.retain(|w| w.url() != url); if workers.len() == initial_len { return Err(PDRouterError::WorkerNotFound { @@ -200,9 +163,6 @@ impl PDRouter { }); } - // Remove from load tracking - self.load_tracking.remove(url); - info!("Removed decode server: {}", url); Ok(format!("Successfully removed decode server: {}", url)) } @@ -214,41 +174,32 @@ impl PDRouter { timeout_secs: u64, interval_secs: u64, ) -> Result { - // Convert URLs to EngineInfo - let prefill_workers: Vec = prefill_urls + // Convert URLs to Worker trait objects + let prefill_workers: Vec> = prefill_urls .into_iter() - .map(|(url, port)| EngineInfo::new_prefill(url, port)) + .map(|(url, port)| WorkerFactory::create_prefill(url, port)) .collect(); - let decode_workers: Vec = decode_urls + let decode_workers: Vec> = decode_urls .into_iter() - .map(EngineInfo::new_decode) + .map(WorkerFactory::create_decode) .collect(); // Wait for PD workers to be healthy let all_urls: Vec = prefill_workers .iter() .chain(decode_workers.iter()) - .map(|engine| engine.url.clone()) + .map(|worker| worker.url().to_string()) .collect(); crate::router::Router::wait_for_healthy_workers(&all_urls, timeout_secs, interval_secs)?; - // Initialize load tracking with atomic counters - let load_tracking = Arc::new(dashmap::DashMap::new()); - for engine in &prefill_workers { - load_tracking.insert(engine.url.clone(), Arc::new(AtomicUsize::new(0))); - } - for engine in &decode_workers { - load_tracking.insert(engine.url.clone(), Arc::new(AtomicUsize::new(0))); - } - // Initialize cache-aware components if needed let prefill_tree = match &selection_policy { PDSelectionPolicy::CacheAware { .. } => { let tree = Arc::new(Mutex::new(Tree::new())); // Initialize tree with prefill workers - for engine in &prefill_workers { - tree.lock().unwrap().insert("", &engine.url); + for worker in &prefill_workers { + tree.lock().unwrap().insert("", worker.url()); } Some(tree) } @@ -283,17 +234,27 @@ impl PDRouter { None }; + let prefill_workers = Arc::new(RwLock::new(prefill_workers)); + let decode_workers = Arc::new(RwLock::new(decode_workers)); + + // Start health checkers for both worker pools + let prefill_health_checker = + crate::core::start_health_checker(Arc::clone(&prefill_workers), interval_secs); + let decode_health_checker = + crate::core::start_health_checker(Arc::clone(&decode_workers), interval_secs); + Ok(PDRouter { - prefill_workers: Arc::new(RwLock::new(prefill_workers)), - decode_workers: Arc::new(RwLock::new(decode_workers)), + prefill_workers, + decode_workers, selection_policy, - load_tracking, prefill_tree, timeout_secs, interval_secs, worker_loads, load_monitor_handle, http_client, + _prefill_health_checker: Some(prefill_health_checker), + _decode_health_checker: Some(decode_health_checker), }) } @@ -330,11 +291,13 @@ impl PDRouter { // Log routing decision info!( "PD routing: {} -> prefill={}, decode={}", - route, prefill.url, decode.url + route, + prefill.url(), + decode.url() ); // Add bootstrap info using the trait method - if let Err(e) = typed_req.add_bootstrap_info(&prefill) { + if let Err(e) = typed_req.add_bootstrap_info(prefill.as_ref()) { error!("Failed to add bootstrap info: {}", e); counter!("sgl_router_pd_errors_total", "error" => "bootstrap_injection").increment(1); return HttpResponse::InternalServerError() @@ -356,8 +319,8 @@ impl PDRouter { req, json_with_bootstrap, route, - &prefill, - &decode, + prefill.as_ref(), + decode.as_ref(), is_stream, return_logprob, start, @@ -397,11 +360,13 @@ impl PDRouter { // Log routing decision info!( "PD routing: {} -> prefill={}, decode={}", - route, prefill.url, decode.url + route, + prefill.url(), + decode.url() ); // Add bootstrap info using the trait method - if let Err(e) = typed_req.add_bootstrap_info(&prefill) { + if let Err(e) = typed_req.add_bootstrap_info(prefill.as_ref()) { error!("Failed to add bootstrap info: {}", e); counter!("sgl_router_pd_errors_total", "error" => "bootstrap_injection").increment(1); return HttpResponse::InternalServerError() @@ -423,8 +388,8 @@ impl PDRouter { req, json_with_bootstrap, route, - &prefill, - &decode, + prefill.as_ref(), + decode.as_ref(), is_stream, return_logprob, start, @@ -440,22 +405,23 @@ impl PDRouter { req: &HttpRequest, json_request: serde_json::Value, route: &str, - prefill: &EngineInfo, - decode: &EngineInfo, + prefill: &dyn Worker, + decode: &dyn Worker, is_stream: bool, return_logprob: bool, start_time: Instant, ) -> HttpResponse { // Update load tracking for both workers - let _guard = LoadGuard::new( - &self.load_tracking, - vec![prefill.url.clone(), decode.url.clone()], - ); + let _guard = WorkerLoadGuard::new_multi(vec![prefill, decode]); // Build requests using .json() method - let mut prefill_request = client.post(prefill.api_path(route)).json(&json_request); + let mut prefill_request = client + .post(api_path(prefill.url(), route)) + .json(&json_request); - let mut decode_request = client.post(decode.api_path(route)).json(&json_request); + let mut decode_request = client + .post(api_path(decode.url(), route)) + .json(&json_request); // Copy headers from original request for (name, value) in crate::router::copy_request_headers(req) { @@ -474,9 +440,9 @@ impl PDRouter { histogram!("sgl_router_pd_request_duration_seconds", "route" => route.to_string()) .record(duration.as_secs_f64()); counter!("sgl_router_pd_requests_total", "route" => route.to_string()).increment(1); - counter!("sgl_router_pd_prefill_requests_total", "worker" => prefill.url.to_string()) + counter!("sgl_router_pd_prefill_requests_total", "worker" => prefill.url().to_string()) .increment(1); - counter!("sgl_router_pd_decode_requests_total", "worker" => decode.url.to_string()) + counter!("sgl_router_pd_decode_requests_total", "worker" => decode.url().to_string()) .increment(1); // Process decode response @@ -486,10 +452,11 @@ impl PDRouter { .unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR); if !status.is_success() { - counter!("sgl_router_pd_decode_errors_total", "worker" => decode.url.to_string()).increment(1); + counter!("sgl_router_pd_decode_errors_total", "worker" => decode.url().to_string()).increment(1); error!( "Decode server {} returned error status: {}", - decode.url, status + decode.url(), + status ); // Return the error response from decode server @@ -508,9 +475,10 @@ impl PDRouter { if let Err(e) = &prefill_result { error!( "Prefill server {} failed (non-critical): {}", - prefill.url, e + prefill.url(), + e ); - counter!("sgl_router_pd_prefill_errors_total", "worker" => prefill.url.to_string()).increment(1); + counter!("sgl_router_pd_prefill_errors_total", "worker" => prefill.url().to_string()).increment(1); } if is_stream { @@ -559,7 +527,7 @@ impl PDRouter { HttpResponse::build(status) .insert_header((CONTENT_TYPE, HeaderValue::from_static("text/event-stream"))) .streaming({ - let decode_url = decode.url.clone(); + let decode_url = decode.url().to_string(); res.bytes_stream().map_err(move |e| { error!("Stream error from decode server {}: {}", decode_url, e); counter!("sgl_router_pd_stream_errors_total", "worker" => decode_url.to_string()).increment(1); @@ -587,7 +555,7 @@ impl PDRouter { } Err(e) => { error!("Decode request failed: {}", e); - counter!("sgl_router_pd_decode_errors_total", "worker" => decode.url.to_string()) + counter!("sgl_router_pd_decode_errors_total", "worker" => decode.url().to_string()) .increment(1); HttpResponse::BadGateway().body(format!("Decode server error: {}", e)) } @@ -652,7 +620,7 @@ impl PDRouter { async fn select_pd_pair( &self, _client: &reqwest::Client, - ) -> Result<(EngineInfo, EngineInfo), String> { + ) -> Result<(Box, Box), String> { // Check we have workers if self .prefill_workers @@ -681,17 +649,17 @@ impl PDRouter { } } - fn select_random(&self) -> Result<(EngineInfo, EngineInfo), String> { + fn select_random(&self) -> Result<(Box, Box), String> { let prefill_list = self.prefill_workers.read().map_err(|_| "Lock error")?; let decode_list = self.decode_workers.read().map_err(|_| "Lock error")?; - let prefill = prefill_list[rand::random::() % prefill_list.len()].clone(); - let decode = decode_list[rand::random::() % decode_list.len()].clone(); + let prefill = prefill_list[rand::random::() % prefill_list.len()].clone_worker(); + let decode = decode_list[rand::random::() % decode_list.len()].clone_worker(); Ok((prefill, decode)) } - async fn select_power_of_two(&self) -> Result<(EngineInfo, EngineInfo), String> { + async fn select_power_of_two(&self) -> Result<(Box, Box), String> { let prefill_list = self.prefill_workers.read().map_err(|_| "Lock error")?; let decode_list = self.decode_workers.read().map_err(|_| "Lock error")?; @@ -700,33 +668,45 @@ impl PDRouter { let loads = self.worker_loads.borrow(); - let p1_load = loads.get(&prefill_list[p1_idx].url).copied().unwrap_or(0); - let p2_load = loads.get(&prefill_list[p2_idx].url).copied().unwrap_or(0); - let d1_load = loads.get(&decode_list[d1_idx].url).copied().unwrap_or(0); - let d2_load = loads.get(&decode_list[d2_idx].url).copied().unwrap_or(0); + let p1_load = loads + .get(prefill_list[p1_idx].url()) + .copied() + .unwrap_or(isize::MAX); + let p2_load = loads + .get(prefill_list[p2_idx].url()) + .copied() + .unwrap_or(isize::MAX); + let d1_load = loads + .get(decode_list[d1_idx].url()) + .copied() + .unwrap_or(isize::MAX); + let d2_load = loads + .get(decode_list[d2_idx].url()) + .copied() + .unwrap_or(isize::MAX); info!( "Power-of-two selection - Prefill: {}={} vs {}={} | Decode: {}={} vs {}={}", - prefill_list[p1_idx].url, + prefill_list[p1_idx].url(), p1_load, - prefill_list[p2_idx].url, + prefill_list[p2_idx].url(), p2_load, - decode_list[d1_idx].url, + decode_list[d1_idx].url(), d1_load, - decode_list[d2_idx].url, + decode_list[d2_idx].url(), d2_load ); let selected_prefill = if p1_load <= p2_load { - prefill_list[p1_idx].clone() + prefill_list[p1_idx].clone_worker() } else { - prefill_list[p2_idx].clone() + prefill_list[p2_idx].clone_worker() }; let selected_decode = if d1_load <= d2_load { - decode_list[d1_idx].clone() + decode_list[d1_idx].clone_worker() } else { - decode_list[d2_idx].clone() + decode_list[d2_idx].clone_worker() }; Ok((selected_prefill, selected_decode)) @@ -868,11 +848,11 @@ impl PDRouter { let mut worker_infos = Vec::new(); for worker in self.prefill_workers.read().unwrap().iter() { - worker_infos.push((worker.url.clone(), "prefill")); + worker_infos.push((worker.url().to_string(), "prefill")); } for worker in self.decode_workers.read().unwrap().iter() { - worker_infos.push((worker.url.clone(), "decode")); + worker_infos.push((worker.url().to_string(), "decode")); } // Create tasks with URL tracking @@ -922,7 +902,7 @@ impl PDRouter { pub async fn get_server_info(&self, client: &reqwest::Client) -> HttpResponse { // Get info from the first decode server to match sglang's server info format let first_decode_url = if let Ok(workers) = self.decode_workers.read() { - workers.first().map(|w| w.url.clone()) + workers.first().map(|w| w.url().to_string()) } else { return HttpResponse::InternalServerError().body("Failed to access decode workers"); }; @@ -967,7 +947,7 @@ impl PDRouter { pub async fn get_models(&self, client: &reqwest::Client, req: &HttpRequest) -> HttpResponse { // Get first prefill worker URL to avoid holding lock across await let first_worker_url = if let Ok(workers) = self.prefill_workers.read() { - workers.first().map(|w| w.url.clone()) + workers.first().map(|w| w.url().to_string()) } else { return HttpResponse::InternalServerError().body("Failed to access prefill workers"); }; @@ -1005,14 +985,14 @@ impl PDRouter { .read() .unwrap() .iter() - .map(|w| w.url.clone()) + .map(|w| w.url().to_string()) .collect(); let d_urls: Vec<_> = self .decode_workers .read() .unwrap() .iter() - .map(|w| w.url.clone()) + .map(|w| w.url().to_string()) .collect(); let mut prefill_loads = Vec::new(); @@ -1048,7 +1028,7 @@ impl PDRouter { // Get model info from the first prefill server (matches original Rust PDLB behavior) // Get first prefill worker URL to avoid holding lock across await let first_worker_url = if let Ok(workers) = self.prefill_workers.read() { - workers.first().map(|w| w.url.clone()) + workers.first().map(|w| w.url().to_string()) } else { return HttpResponse::InternalServerError().body("Failed to access prefill workers"); }; @@ -1084,13 +1064,13 @@ impl PDRouter { // Flush cache on all prefill servers for worker in self.prefill_workers.read().unwrap().iter() { - let url = format!("{}/flush_cache", worker.url); + let url = format!("{}/flush_cache", worker.url()); tasks.push(client.post(&url).send()); } // Flush cache on all decode servers for worker in self.decode_workers.read().unwrap().iter() { - let url = format!("{}/flush_cache", worker.url); + let url = format!("{}/flush_cache", worker.url()); tasks.push(client.post(&url).send()); } diff --git a/sgl-router/src/pd_types.rs b/sgl-router/src/pd_types.rs index 16dc18267..75473b0e3 100644 --- a/sgl-router/src/pd_types.rs +++ b/sgl-router/src/pd_types.rs @@ -1,5 +1,6 @@ // Essential PDLB types extracted for PD routing +use crate::core::{Worker, WorkerType}; use serde::{Deserialize, Serialize}; use serde_json::Value; @@ -28,52 +29,21 @@ pub enum PDRouterError { Timeout { url: String }, } -#[derive(Debug, Clone)] -pub enum EngineType { - Prefill, - Decode, +// Helper functions for workers +pub fn api_path(url: &str, api_path: &str) -> String { + if api_path.starts_with("/") { + format!("{}{}", url, api_path) + } else { + format!("{}/{}", url, api_path) + } } -#[derive(Debug, Clone)] -pub struct EngineInfo { - pub engine_type: EngineType, - pub url: String, - pub bootstrap_port: Option, -} - -impl EngineInfo { - pub fn new_prefill(url: String, bootstrap_port: Option) -> Self { - EngineInfo { - engine_type: EngineType::Prefill, - url, - bootstrap_port, - } - } - - pub fn new_decode(url: String) -> Self { - EngineInfo { - engine_type: EngineType::Decode, - url, - bootstrap_port: None, - } - } - - pub fn api_path(&self, api_path: &str) -> String { - if api_path.starts_with("/") { - format!("{}{}", self.url, api_path) - } else { - format!("{}/{}", self.url, api_path) - } - } - - pub fn get_hostname(&self) -> String { - // Simple hostname extraction without external dependencies - let url = self - .url - .trim_start_matches("http://") - .trim_start_matches("https://"); - url.split(':').next().unwrap_or("localhost").to_string() - } +pub fn get_hostname(url: &str) -> String { + // Simple hostname extraction without external dependencies + let url = url + .trim_start_matches("http://") + .trim_start_matches("https://"); + url.split(':').next().unwrap_or("localhost").to_string() } // PD-specific routing policies @@ -112,12 +82,21 @@ pub trait Bootstrap: Send + Sync { bootstrap_room: BootstrapRoom, ); - fn add_bootstrap_info(&mut self, prefill_info: &EngineInfo) -> Result<(), String> { + fn add_bootstrap_info(&mut self, prefill_worker: &dyn Worker) -> Result<(), String> { let batch_size = self.get_batch_size()?; + + // Extract bootstrap port from prefill worker if it's a prefill type + let bootstrap_port = match prefill_worker.worker_type() { + WorkerType::Prefill { bootstrap_port } => bootstrap_port, + _ => None, + }; + + let hostname = get_hostname(prefill_worker.url()); + if let Some(batch_size) = batch_size { self.set_bootstrap_info( - BootstrapHost::Batch(vec![prefill_info.get_hostname(); batch_size]), - BootstrapPort::Batch(vec![prefill_info.bootstrap_port; batch_size]), + BootstrapHost::Batch(vec![hostname; batch_size]), + BootstrapPort::Batch(vec![bootstrap_port; batch_size]), // Use high-quality random numbers to minimize collision risk BootstrapRoom::Batch( (0..batch_size) @@ -132,8 +111,8 @@ pub trait Bootstrap: Send + Sync { ); } else { self.set_bootstrap_info( - BootstrapHost::Single(prefill_info.get_hostname()), - BootstrapPort::Single(prefill_info.bootstrap_port), + BootstrapHost::Single(hostname), + BootstrapPort::Single(bootstrap_port), BootstrapRoom::Single({ // Use high-quality random number for single requests too let r1 = rand::random::(); diff --git a/sgl-router/src/router.rs b/sgl-router/src/router.rs index f04653855..e8b68d7c5 100644 --- a/sgl-router/src/router.rs +++ b/sgl-router/src/router.rs @@ -1,3 +1,4 @@ +use crate::core::{HealthChecker, Worker, WorkerFactory}; use crate::pd_router::PDRouter; use crate::pd_types::PDSelectionPolicy; use crate::tree::Tree; @@ -5,7 +6,6 @@ use ::metrics::{counter, gauge, histogram}; use actix_web::http::header::{HeaderValue, CONTENT_TYPE}; use actix_web::{HttpRequest, HttpResponse}; use futures_util::{StreamExt, TryStreamExt}; -use std::collections::HashMap; use std::fmt::Debug; use std::sync::atomic::AtomicUsize; use std::sync::{Arc, Mutex, RwLock}; @@ -30,15 +30,17 @@ pub fn copy_request_headers(req: &HttpRequest) -> Vec<(String, String)> { #[derive(Debug)] pub enum Router { RoundRobin { - worker_urls: Arc>>, + workers: Arc>>>, current_index: AtomicUsize, timeout_secs: u64, interval_secs: u64, + _health_checker: Option, }, Random { - worker_urls: Arc>>, + workers: Arc>>>, timeout_secs: u64, interval_secs: u64, + _health_checker: Option, }, PrefillDecode { pd_router: Arc, @@ -104,16 +106,15 @@ pub enum Router { Maximum nodes per tree. When exceeded, LRU leaf nodes are evicted during the next eviction cycle. */ - worker_urls: Arc>>, + workers: Arc>>>, tree: Arc>, - running_queue: Arc>>, - processed_queue: Arc>>, cache_threshold: f32, balance_abs_threshold: usize, balance_rel_threshold: f32, timeout_secs: u64, interval_secs: u64, _eviction_thread: Option>, + _health_checker: Option, }, } @@ -192,25 +193,43 @@ impl Router { } } + // Create Worker trait objects from URLs + let workers: Vec> = worker_urls + .iter() + .map(|url| WorkerFactory::create_regular(url.clone())) + .collect(); + // Create router based on policy... Ok(match policy_config { PolicyConfig::RandomConfig { timeout_secs, interval_secs, - } => Router::Random { - worker_urls: Arc::new(RwLock::new(worker_urls)), - timeout_secs, - interval_secs, - }, + } => { + let workers = Arc::new(RwLock::new(workers)); + let health_checker = + crate::core::start_health_checker(Arc::clone(&workers), interval_secs); + Router::Random { + workers, + timeout_secs, + interval_secs, + _health_checker: Some(health_checker), + } + } PolicyConfig::RoundRobinConfig { timeout_secs, interval_secs, - } => Router::RoundRobin { - worker_urls: Arc::new(RwLock::new(worker_urls)), - current_index: std::sync::atomic::AtomicUsize::new(0), - timeout_secs, - interval_secs, - }, + } => { + let workers = Arc::new(RwLock::new(workers)); + let health_checker = + crate::core::start_health_checker(Arc::clone(&workers), interval_secs); + Router::RoundRobin { + workers, + current_index: std::sync::atomic::AtomicUsize::new(0), + timeout_secs, + interval_secs, + _health_checker: Some(health_checker), + } + } PolicyConfig::CacheAwareConfig { cache_threshold, balance_abs_threshold, @@ -220,24 +239,12 @@ impl Router { timeout_secs, interval_secs, } => { - let mut running_queue = HashMap::new(); - for url in &worker_urls { - running_queue.insert(url.clone(), 0); - } - - let mut processed_queue = HashMap::new(); - for url in &worker_urls { - processed_queue.insert(url.clone(), 0); - } - let tree = Arc::new(Mutex::new(Tree::new())); - let running_queue = Arc::new(Mutex::new(running_queue)); - let processed_queue = Arc::new(Mutex::new(processed_queue)); // Create background eviction thread let tree_clone = Arc::clone(&tree); - let processed_queue_clone = Arc::clone(&processed_queue); - let running_queue_clone = Arc::clone(&running_queue); + let workers = Arc::new(RwLock::new(workers)); + let workers_clone = Arc::clone(&workers); let eviction_thread = thread::spawn(move || { loop { // Sleep for the specified interval @@ -246,32 +253,41 @@ impl Router { let locked_tree_clone = tree_clone.lock().unwrap(); // Run eviction locked_tree_clone.evict_tenant_by_size(max_tree_size); + drop(locked_tree_clone); - // Print the process queue - let locked_processed_queue = processed_queue_clone.lock().unwrap(); - info!("Processed Queue: {:?}", locked_processed_queue); + // Log worker loads and processed requests + let workers_guard = workers_clone.read().unwrap(); + let loads: Vec<(String, usize)> = workers_guard + .iter() + .map(|w| (w.url().to_string(), w.load())) + .collect(); + info!("Worker loads: {:?}", loads); - // Print the running queue - let locked_running_queue = running_queue_clone.lock().unwrap(); - info!("Running Queue: {:?}", locked_running_queue); + let processed: Vec<(String, usize)> = workers_guard + .iter() + .map(|w| (w.url().to_string(), w.processed_requests())) + .collect(); + info!("Processed requests: {:?}", processed); } }); - for url in &worker_urls { - tree.lock().unwrap().insert("", url); + for worker in workers.read().unwrap().iter() { + tree.lock().unwrap().insert("", worker.url()); } + let health_checker = + crate::core::start_health_checker(Arc::clone(&workers), interval_secs); + Router::CacheAware { - worker_urls: Arc::new(RwLock::new(worker_urls)), + workers, tree, - running_queue, - processed_queue, cache_threshold, balance_abs_threshold, balance_rel_threshold, timeout_secs, interval_secs, _eviction_thread: Some(eviction_thread), + _health_checker: Some(health_checker), } } PolicyConfig::PrefillDecodeConfig { @@ -297,16 +313,18 @@ impl Router { }) } - /// Get a reference to the worker URLs shared across threads - pub fn get_worker_urls(&self) -> Arc>> { + /// Get the current list of worker URLs + pub fn get_worker_urls(&self) -> Vec { match self { - Router::RoundRobin { worker_urls, .. } => Arc::clone(worker_urls), - Router::Random { worker_urls, .. } => Arc::clone(worker_urls), - Router::CacheAware { worker_urls, .. } => Arc::clone(worker_urls), - Router::PrefillDecode { .. } => { - // For PD mode, return empty list since we manage workers differently - Arc::new(RwLock::new(Vec::new())) - } + Router::RoundRobin { workers, .. } + | Router::Random { workers, .. } + | Router::CacheAware { workers, .. } => workers + .read() + .unwrap() + .iter() + .map(|w| w.url().to_string()) + .collect(), + Router::PrefillDecode { .. } => Vec::new(), } } @@ -373,13 +391,14 @@ impl Router { fn select_first_worker(&self) -> Result { match self { - Router::RoundRobin { worker_urls, .. } - | Router::Random { worker_urls, .. } - | Router::CacheAware { worker_urls, .. } => { - if worker_urls.read().unwrap().is_empty() { + Router::RoundRobin { workers, .. } + | Router::Random { workers, .. } + | Router::CacheAware { workers, .. } => { + let workers_guard = workers.read().unwrap(); + if workers_guard.is_empty() { Err("No workers are available".to_string()) } else { - Ok(worker_urls.read().unwrap()[0].clone()) + Ok(workers_guard[0].url().to_string()) } } Router::PrefillDecode { .. } => { @@ -514,7 +533,7 @@ impl Router { return HttpResponse::NotImplemented() .body("route_to_all not implemented for PrefillDecode mode"); } - _ => self.get_worker_urls().read().unwrap().clone(), + _ => self.get_worker_urls(), }; // Send requests to all workers concurrently @@ -562,7 +581,7 @@ impl Router { } } - let urls = self.get_worker_urls().read().unwrap().clone(); + let urls = self.get_worker_urls(); let prefill_urls: Vec = Vec::new(); let decode_urls = urls; @@ -631,6 +650,24 @@ impl Router { .increment(1); } + // For CacheAware router, increment load before request + let load_incremented = match self { + Router::CacheAware { workers, .. } => { + let workers_guard = workers.read().unwrap(); + if let Some(worker) = + workers_guard.iter().find(|w| w.url() == &worker_url) + { + worker.increment_load(); + gauge!("sgl_router_running_requests", "worker" => worker_url.to_string()) + .set(worker.load() as f64); + true + } else { + false + } + } + _ => false, + }; + // Send typed request directly let response = self .send_typed_request( @@ -640,6 +677,7 @@ impl Router { route, &worker_url, is_stream, + load_incremented, ) .await; @@ -684,44 +722,47 @@ impl Router { } } - // Helper method to select worker from text + // Helper method to select worker from text (returns index for RoundRobin/Random, URL for CacheAware) fn select_generate_worker_from_text(&self, text: &str) -> String { match self { Router::RoundRobin { - worker_urls, + workers, current_index, .. } => { + let workers_guard = workers.read().unwrap(); let idx = current_index .fetch_update( std::sync::atomic::Ordering::SeqCst, std::sync::atomic::Ordering::SeqCst, - |x| Some((x + 1) % worker_urls.read().unwrap().len()), + |x| Some((x + 1) % workers_guard.len()), ) .unwrap(); - worker_urls.read().unwrap()[idx].clone() + workers_guard[idx].url().to_string() } - Router::Random { worker_urls, .. } => worker_urls.read().unwrap() - [rand::random::() % worker_urls.read().unwrap().len()] - .clone(), + Router::Random { workers, .. } => { + let workers_guard = workers.read().unwrap(); + workers_guard[rand::random::() % workers_guard.len()] + .url() + .to_string() + } Router::CacheAware { - worker_urls, + workers, tree, - running_queue, - processed_queue, cache_threshold, balance_abs_threshold, balance_rel_threshold, .. } => { let tree = tree.lock().unwrap(); - let mut running_queue = running_queue.lock().unwrap(); + let workers_guard = workers.read().unwrap(); - // Get current load statistics - let max_load = *running_queue.values().max().unwrap_or(&0); - let min_load = *running_queue.values().min().unwrap_or(&0); + // Get current load statistics from workers + let loads: Vec = workers_guard.iter().map(|w| w.load()).collect(); + let max_load = *loads.iter().max().unwrap_or(&0); + let min_load = *loads.iter().min().unwrap_or(&0); // Load is considered imbalanced if: // 1. (max - min) > abs_threshold AND @@ -731,11 +772,16 @@ impl Router { let selected_url = if is_imbalanced { // Log load balancing trigger and current queue state + let worker_loads: Vec<(String, usize)> = workers_guard + .iter() + .map(|w| (w.url().to_string(), w.load())) + .collect(); + info!( "Load balancing triggered due to workload imbalance:\n\ Max load: {}, Min load: {}\n\ - Current running queue: {:?}", - max_load, min_load, running_queue + Current worker loads: {:?}", + max_load, min_load, worker_loads ); counter!("sgl_router_load_balancing_events_total").increment(1); @@ -743,11 +789,11 @@ impl Router { gauge!("sgl_router_min_load").set(min_load as f64); // Use shortest queue routing when load is imbalanced - running_queue + workers_guard .iter() - .min_by_key(|(_url, &count)| count) - .map(|(url, _)| url.clone()) - .unwrap_or_else(|| worker_urls.read().unwrap()[0].clone()) + .min_by_key(|w| w.load()) + .map(|w| w.url().to_string()) + .unwrap_or_else(|| workers_guard[0].url().to_string()) } else { // Use cache-aware routing when load is balanced let (matched_text, matched_worker) = tree.prefix_match(&text); @@ -763,18 +809,12 @@ impl Router { } }; - // Update queues and tree - *running_queue.get_mut(&selected_url).unwrap() += 1; - - *processed_queue - .lock() - .unwrap() - .get_mut(&selected_url) - .unwrap() += 1; - - gauge!("sgl_router_running_requests", "worker" => selected_url.to_string()) - .set(*running_queue.get(&selected_url).unwrap() as f64); - counter!("sgl_router_processed_requests_total", "worker" => selected_url.to_string()).increment(1); + // Find the selected worker and increment processed counter only + if let Some(worker) = workers_guard.iter().find(|w| w.url() == &selected_url) { + worker.increment_processed(); + counter!("sgl_router_processed_requests_total", "worker" => selected_url.to_string()) + .increment(1); + } tree.insert(&text, &selected_url); @@ -796,6 +836,7 @@ impl Router { route: &str, worker_url: &str, is_stream: bool, + load_incremented: bool, // Whether load was incremented for this request ) -> HttpResponse { let start = Instant::now(); @@ -820,6 +861,22 @@ impl Router { Ok(res) => res, Err(e) => { error!("Failed to send request to {}: {}", worker_url, e); + + // Decrement load on error for CacheAware router + if load_incremented { + if let Router::CacheAware { workers, .. } = self { + if let Ok(workers_guard) = workers.read() { + if let Some(worker) = + workers_guard.iter().find(|w| w.url() == worker_url) + { + worker.decrement_load(); + gauge!("sgl_router_running_requests", "worker" => worker_url.to_string()) + .set(worker.load() as f64); + } + } + } + } + return HttpResponse::InternalServerError().body(format!("Request failed: {}", e)); } }; @@ -837,13 +894,15 @@ impl Router { } }; - // Then decrement running queue counter if using CacheAware - if let Router::CacheAware { running_queue, .. } = self { - if let Ok(mut queue) = running_queue.lock() { - if let Some(count) = queue.get_mut(worker_url) { - *count = count.saturating_sub(1); - gauge!("sgl_router_running_requests", "worker" => worker_url.to_string()) - .set(*count as f64); + // Decrement load counter for non-streaming CacheAware requests + if load_incremented && !is_stream { + if let Router::CacheAware { workers, .. } = self { + if let Ok(workers_guard) = workers.read() { + if let Some(worker) = workers_guard.iter().find(|w| w.url() == worker_url) { + worker.decrement_load(); + gauge!("sgl_router_running_requests", "worker" => worker_url.to_string()) + .set(worker.load() as f64); + } } } } @@ -855,8 +914,9 @@ impl Router { counter!("sgl_router_requests_total", "route" => route.to_string()).increment(1); response - } else if let Router::CacheAware { running_queue, .. } = self { - let running_queue = Arc::clone(running_queue); + } else if let Router::CacheAware { workers, .. } = self { + // For streaming with CacheAware router, we need to manually decrement when done + let workers = Arc::clone(workers); let worker_url = worker_url.to_string(); HttpResponse::build(status) @@ -867,21 +927,28 @@ impl Router { actix_web::error::ErrorInternalServerError("Failed to read stream") }) .inspect(move |bytes| { - let bytes = bytes.as_ref().unwrap(); - if bytes - .as_ref() - .windows(12) - .any(|window| window == b"data: [DONE]") - { - let mut locked_queue = running_queue.lock().unwrap(); - let count = locked_queue.get_mut(&worker_url).unwrap(); - *count = count.saturating_sub(1); - gauge!("sgl_router_running_requests", "worker" => worker_url.to_string()).set(*count as f64); - debug!("Streaming is done!!") + if let Ok(bytes) = bytes { + if bytes + .as_ref() + .windows(12) + .any(|window| window == b"data: [DONE]") + { + if let Ok(workers_guard) = workers.read() { + if let Some(worker) = + workers_guard.iter().find(|w| w.url() == &worker_url) + { + worker.decrement_load(); + gauge!("sgl_router_running_requests", "worker" => worker_url.to_string()) + .set(worker.load() as f64); + debug!("Streaming is done!!") + } + } + } } }), ) } else { + // For non-CacheAware routers, just stream without load tracking HttpResponse::build(status) .insert_header((CONTENT_TYPE, HeaderValue::from_static("text/event-stream"))) .streaming(res.bytes_stream().map_err(|_| { @@ -935,43 +1002,27 @@ impl Router { Ok(res) => { if res.status().is_success() { match self { - Router::RoundRobin { worker_urls, .. } - | Router::Random { worker_urls, .. } - | Router::CacheAware { worker_urls, .. } => { + Router::RoundRobin { workers, .. } + | Router::Random { workers, .. } + | Router::CacheAware { workers, .. } => { info!("Worker {} health check passed", worker_url); - let mut urls = worker_urls.write().unwrap(); - if urls.contains(&worker_url.to_string()) { + let mut workers_guard = workers.write().unwrap(); + if workers_guard.iter().any(|w| w.url() == worker_url) { return Err(format!("Worker {} already exists", worker_url)); } info!("Added worker: {}", worker_url); - urls.push(worker_url.to_string()); - gauge!("sgl_router_active_workers").set(urls.len() as f64); + let new_worker = + WorkerFactory::create_regular(worker_url.to_string()); + workers_guard.push(new_worker); + gauge!("sgl_router_active_workers").set(workers_guard.len() as f64); } Router::PrefillDecode { .. } => { return Err("Adding workers to PrefillDecode router not supported via add_worker. Use dedicated PD management methods.".to_string()); } } - // If cache aware, initialize the queues for the new worker - if let Router::CacheAware { - running_queue, - processed_queue, - tree, - .. - } = self - { - // Add worker to running queue with initial count of 0 - running_queue - .lock() - .unwrap() - .insert(worker_url.to_string(), 0); - - // Add worker to processed queue with initial count of 0 - processed_queue - .lock() - .unwrap() - .insert(worker_url.to_string(), 0); - + // If cache aware, add worker to tree + if let Router::CacheAware { tree, .. } = self { // Add worker to tree tree.lock().unwrap().insert("", worker_url); } @@ -1013,14 +1064,14 @@ impl Router { pub fn remove_worker(&self, worker_url: &str) { match self { - Router::RoundRobin { worker_urls, .. } - | Router::Random { worker_urls, .. } - | Router::CacheAware { worker_urls, .. } => { - let mut urls = worker_urls.write().unwrap(); - if let Some(index) = urls.iter().position(|url| url == &worker_url) { - urls.remove(index); + Router::RoundRobin { workers, .. } + | Router::Random { workers, .. } + | Router::CacheAware { workers, .. } => { + let mut workers_guard = workers.write().unwrap(); + if let Some(index) = workers_guard.iter().position(|w| w.url() == worker_url) { + workers_guard.remove(index); info!("Removed worker: {}", worker_url); - gauge!("sgl_router_active_workers").set(urls.len() as f64); + gauge!("sgl_router_active_workers").set(workers_guard.len() as f64); } else { warn!("Worker {} not found, skipping removal", worker_url); return; @@ -1033,26 +1084,9 @@ impl Router { } // if cache aware, remove the worker from the tree - if let Router::CacheAware { - tree, - running_queue, - processed_queue, - .. - } = self - { + if let Router::CacheAware { tree, .. } = self { tree.lock().unwrap().remove_tenant(&worker_url); - running_queue - .lock() - .unwrap() - .remove(&worker_url.to_string()); - processed_queue - .lock() - .unwrap() - .remove(&worker_url.to_string()); - info!( - "Removed worker from tree and cleaned up queues: {}", - worker_url - ); + info!("Removed worker from tree: {}", worker_url); } } @@ -1241,21 +1275,22 @@ mod tests { use crate::service_discovery::PodType; fn create_test_regular_router() -> Router { + let workers = vec![ + WorkerFactory::create_regular("http://worker1:8080".to_string()), + WorkerFactory::create_regular("http://worker2:8080".to_string()), + ]; Router::Random { - worker_urls: Arc::new(RwLock::new(vec![ - "http://worker1:8080".to_string(), - "http://worker2:8080".to_string(), - ])), + workers: Arc::new(RwLock::new(workers)), timeout_secs: 5, interval_secs: 1, + _health_checker: None, } } #[test] fn test_router_get_worker_urls_regular() { let router = create_test_regular_router(); - let worker_urls = router.get_worker_urls(); - let urls = worker_urls.read().unwrap(); + let urls = router.get_worker_urls(); assert_eq!(urls.len(), 2); assert!(urls.contains(&"http://worker1:8080".to_string())); diff --git a/sgl-router/src/server.rs b/sgl-router/src/server.rs index e87405517..bb2695b93 100644 --- a/sgl-router/src/server.rs +++ b/sgl-router/src/server.rs @@ -236,8 +236,7 @@ async fn add_worker( #[get("/list_workers")] async fn list_workers(data: web::Data) -> impl Responder { - let workers = data.router.get_worker_urls(); - let worker_list = workers.read().unwrap().clone(); + let worker_list = data.router.get_worker_urls(); HttpResponse::Ok().json(serde_json::json!({ "urls": worker_list })) } @@ -381,7 +380,7 @@ pub async fn startup(config: ServerConfig) -> std::io::Result<()> { info!("✅ Serving router on {}:{}", config.host, config.port); info!( "✅ Serving workers on {:?}", - app_state.router.get_worker_urls().read().unwrap() + app_state.router.get_worker_urls() ); HttpServer::new(move || { diff --git a/sgl-router/src/service_discovery.rs b/sgl-router/src/service_discovery.rs index b7104de11..0e78717ce 100644 --- a/sgl-router/src/service_discovery.rs +++ b/sgl-router/src/service_discovery.rs @@ -547,11 +547,12 @@ mod tests { // Helper to create a Router instance for testing event handlers fn create_test_router() -> Arc { - let worker_urls = Arc::new(RwLock::new(Vec::new())); + let workers = Arc::new(RwLock::new(Vec::new())); Arc::new(Router::Random { - worker_urls, + workers, timeout_secs: 5, interval_secs: 1, + _health_checker: None, }) } @@ -878,8 +879,6 @@ mod tests { assert!(!tracked_pods.lock().unwrap().contains(&pod_info)); assert!(!router .get_worker_urls() - .read() - .unwrap() .contains(&pod_info.worker_url(port))); } @@ -907,7 +906,7 @@ mod tests { .await; assert!(tracked_pods.lock().unwrap().is_empty()); - assert!(router.get_worker_urls().read().unwrap().is_empty()); + assert!(router.get_worker_urls().is_empty()); } #[tokio::test] diff --git a/sgl-router/tests/test_pd_routing.rs b/sgl-router/tests/test_pd_routing.rs index 5a1e65790..02b8c99f5 100644 --- a/sgl-router/tests/test_pd_routing.rs +++ b/sgl-router/tests/test_pd_routing.rs @@ -12,7 +12,7 @@ mod test_pd_routing { use rand::Rng; use serde_json::json; - use sglang_router_rs::pd_types::{EngineInfo, EngineType, PDSelectionPolicy}; + use sglang_router_rs::pd_types::PDSelectionPolicy; use sglang_router_rs::router::{PolicyConfig, Router}; // Test-only struct to help validate PD request parsing @@ -51,40 +51,35 @@ mod test_pd_routing { // ======================================================================== #[test] - fn test_engine_info_creation() { - // Test EngineInfo creation for prefill servers - let prefill_engine = EngineInfo::new_prefill("http://prefill:8080".to_string(), Some(9000)); - match prefill_engine.engine_type { - EngineType::Prefill => (), - _ => panic!("Expected Prefill engine type"), - } - assert_eq!(prefill_engine.url, "http://prefill:8080"); - assert_eq!(prefill_engine.bootstrap_port, Some(9000)); - assert_eq!(prefill_engine.get_hostname(), "prefill"); + fn test_worker_types() { + use sglang_router_rs::core::{WorkerFactory, WorkerType}; - // Test EngineInfo creation for decode servers - let decode_engine = EngineInfo::new_decode("http://decode:8080".to_string()); - match decode_engine.engine_type { - EngineType::Decode => (), - _ => panic!("Expected Decode engine type"), + // Test worker creation for prefill servers + let prefill_worker = + WorkerFactory::create_prefill("http://prefill:8080".to_string(), Some(9000)); + assert_eq!(prefill_worker.url(), "http://prefill:8080"); + match prefill_worker.worker_type() { + WorkerType::Prefill { bootstrap_port } => { + assert_eq!(bootstrap_port, Some(9000)); + } + _ => panic!("Expected Prefill worker type"), } - assert_eq!(decode_engine.url, "http://decode:8080"); - assert_eq!(decode_engine.bootstrap_port, None); - assert_eq!(decode_engine.get_hostname(), "decode"); - // Test API path generation - assert_eq!( - prefill_engine.api_path("/generate"), - "http://prefill:8080/generate" - ); - assert_eq!( - prefill_engine.api_path("health"), - "http://prefill:8080/health" - ); - assert_eq!( - decode_engine.api_path("/v1/chat/completions"), - "http://decode:8080/v1/chat/completions" - ); + // Test worker creation for decode servers + let decode_worker = WorkerFactory::create_decode("http://decode:8080".to_string()); + assert_eq!(decode_worker.url(), "http://decode:8080"); + match decode_worker.worker_type() { + WorkerType::Decode => (), + _ => panic!("Expected Decode worker type"), + } + + // Test regular worker creation + let regular_worker = WorkerFactory::create_regular("http://regular:8080".to_string()); + assert_eq!(regular_worker.url(), "http://regular:8080"); + match regular_worker.worker_type() { + WorkerType::Regular => (), + _ => panic!("Expected Regular worker type"), + } } #[test] @@ -230,6 +225,9 @@ mod test_pd_routing { #[test] fn test_bootstrap_injection_simulation() { + use sglang_router_rs::core::{WorkerFactory, WorkerType}; + use sglang_router_rs::pd_types::get_hostname; + // Since we can't test the actual inject_bootstrap_fields function here // (it's private in the router module), we'll test the expected behavior @@ -240,15 +238,24 @@ mod test_pd_routing { "temperature": 0.7 }); + // Create a prefill worker to simulate injection + let prefill_worker = + WorkerFactory::create_prefill("http://prefill1:8080".to_string(), Some(9000)); + + // Extract bootstrap port from worker type + let bootstrap_port = match prefill_worker.worker_type() { + WorkerType::Prefill { bootstrap_port } => bootstrap_port, + _ => None, + }; + // Simulate what inject_bootstrap_fields would do - let prefill_info = EngineInfo::new_prefill("http://prefill1:8080".to_string(), Some(9000)); - single_json["bootstrap_host"] = json!(prefill_info.get_hostname()); - single_json["bootstrap_port"] = json!(prefill_info.bootstrap_port); + single_json["bootstrap_host"] = json!(get_hostname(prefill_worker.url())); + single_json["bootstrap_port"] = json!(bootstrap_port); single_json["bootstrap_room"] = json!(12345u64); // Random room ID // Verify bootstrap fields are added correctly assert_eq!(single_json["bootstrap_host"], "prefill1"); - assert_eq!(single_json["bootstrap_port"], 9000); + assert_eq!(single_json["bootstrap_port"], json!(Some(9000))); assert!(single_json["bootstrap_room"].is_u64()); assert_eq!(single_json["temperature"], 0.7); // Original field preserved @@ -259,8 +266,9 @@ mod test_pd_routing { }); let batch_size = 3; - batch_json["bootstrap_host"] = json!(vec![prefill_info.get_hostname(); batch_size]); - batch_json["bootstrap_port"] = json!(vec![prefill_info.bootstrap_port; batch_size]); + let hostname = get_hostname(prefill_worker.url()); + batch_json["bootstrap_host"] = json!(vec![hostname; batch_size]); + batch_json["bootstrap_port"] = json!(vec![bootstrap_port; batch_size]); batch_json["bootstrap_room"] = json!(vec![111u64, 222u64, 333u64]); // Verify batch bootstrap fields @@ -306,7 +314,9 @@ mod test_pd_routing { } #[test] - fn test_engine_info_hostname_extraction() { + fn test_hostname_extraction() { + use sglang_router_rs::pd_types::get_hostname; + // Test various URL formats let test_cases = vec![ ("http://localhost:8080", "localhost"), @@ -318,8 +328,7 @@ mod test_pd_routing { ]; for (url, expected_hostname) in test_cases { - let engine = EngineInfo::new_prefill(url.to_string(), None); - assert_eq!(engine.get_hostname(), expected_hostname); + assert_eq!(get_hostname(url), expected_hostname); } } @@ -652,6 +661,9 @@ mod test_pd_routing { #[test] fn test_bootstrap_injection_with_benchmark_requests() { + use sglang_router_rs::core::{WorkerFactory, WorkerType}; + use sglang_router_rs::pd_types::get_hostname; + // Test bootstrap injection with actual benchmark request patterns let mut benchmark_request = json!({ "input_ids": vec![vec![1, 2, 3, 4]; 16], // Batch size 16 @@ -664,12 +676,20 @@ mod test_pd_routing { "stream": true }); - // Simulate bootstrap injection - let prefill_info = EngineInfo::new_prefill("http://prefill:8080".to_string(), Some(9000)); - let batch_size = 16; + // Create a prefill worker to simulate injection + let prefill_worker = + WorkerFactory::create_prefill("http://prefill:8080".to_string(), Some(9000)); - benchmark_request["bootstrap_host"] = json!(vec![prefill_info.get_hostname(); batch_size]); - benchmark_request["bootstrap_port"] = json!(vec![prefill_info.bootstrap_port; batch_size]); + // Extract bootstrap port from worker type + let bootstrap_port = match prefill_worker.worker_type() { + WorkerType::Prefill { bootstrap_port } => bootstrap_port, + _ => None, + }; + let batch_size = 16; + let hostname = get_hostname(prefill_worker.url()); + + benchmark_request["bootstrap_host"] = json!(vec![hostname; batch_size]); + benchmark_request["bootstrap_port"] = json!(vec![bootstrap_port; batch_size]); benchmark_request["bootstrap_room"] = json!((0..batch_size).map(|_| 12345u64).collect::>()); @@ -770,6 +790,9 @@ mod test_pd_routing { #[test] fn test_large_batch_bootstrap_injection() { + use sglang_router_rs::core::{WorkerFactory, WorkerType}; + use sglang_router_rs::pd_types::get_hostname; + // Test bootstrap injection performance with very large batches // This simulates the bench_one_batch_server.py scenario let large_batch_sizes = vec![1024, 4096, 8192]; @@ -787,14 +810,19 @@ mod test_pd_routing { "stream": true }); - // Simulate bootstrap injection - let prefill_info = - EngineInfo::new_prefill("http://prefill:8080".to_string(), Some(9000)); + // Create a prefill worker to simulate injection + let prefill_worker = + WorkerFactory::create_prefill("http://prefill:8080".to_string(), Some(9000)); - large_batch_request["bootstrap_host"] = - json!(vec![prefill_info.get_hostname(); batch_size]); - large_batch_request["bootstrap_port"] = - json!(vec![prefill_info.bootstrap_port; batch_size]); + // Extract bootstrap port from worker type + let bootstrap_port = match prefill_worker.worker_type() { + WorkerType::Prefill { bootstrap_port } => bootstrap_port, + _ => None, + }; + let hostname = get_hostname(prefill_worker.url()); + + large_batch_request["bootstrap_host"] = json!(vec![hostname; batch_size]); + large_batch_request["bootstrap_port"] = json!(vec![bootstrap_port; batch_size]); large_batch_request["bootstrap_room"] = json!((0..batch_size) .map(|_| rand::thread_rng().gen::()) .collect::>());