diff --git a/sgl-router/py_src/sglang_router/launch_router.py b/sgl-router/py_src/sglang_router/launch_router.py index 28cd5d11f..384e3666d 100644 --- a/sgl-router/py_src/sglang_router/launch_router.py +++ b/sgl-router/py_src/sglang_router/launch_router.py @@ -33,6 +33,7 @@ class RouterArgs: # Routing policy policy: str = "cache_aware" + worker_startup_timeout_secs: int = 300 cache_threshold: float = 0.5 balance_abs_threshold: int = 32 balance_rel_threshold: float = 1.0001 @@ -87,6 +88,12 @@ class RouterArgs: choices=["random", "round_robin", "cache_aware"], help="Load balancing policy to use", ) + parser.add_argument( + f"--{prefix}worker-startup-timeout-secs", + type=int, + default=RouterArgs.worker_startup_timeout_secs, + help="Timeout in seconds for worker startup", + ) parser.add_argument( f"--{prefix}cache-threshold", type=float, @@ -147,6 +154,9 @@ class RouterArgs: host=args.host, port=args.port, policy=getattr(args, f"{prefix}policy"), + worker_startup_timeout_secs=getattr( + args, f"{prefix}worker_startup_timeout_secs" + ), cache_threshold=getattr(args, f"{prefix}cache_threshold"), balance_abs_threshold=getattr(args, f"{prefix}balance_abs_threshold"), balance_rel_threshold=getattr(args, f"{prefix}balance_rel_threshold"), @@ -188,9 +198,10 @@ def launch_router(args: argparse.Namespace) -> Optional[Router]: router = Router( worker_urls=router_args.worker_urls, - policy=policy_from_str(router_args.policy), host=router_args.host, port=router_args.port, + policy=policy_from_str(router_args.policy), + worker_startup_timeout_secs=router_args.worker_startup_timeout_secs, cache_threshold=router_args.cache_threshold, balance_abs_threshold=router_args.balance_abs_threshold, balance_rel_threshold=router_args.balance_rel_threshold, @@ -205,7 +216,7 @@ def launch_router(args: argparse.Namespace) -> Optional[Router]: except Exception as e: logger.error(f"Error starting router: {e}") - return None + raise e class CustomHelpFormatter( @@ -239,10 +250,7 @@ Examples: def main() -> None: router_args = parse_router_args(sys.argv[1:]) - router = launch_router(router_args) - - if router is None: - sys.exit(1) + launch_router(router_args) if __name__ == "__main__": diff --git a/sgl-router/py_src/sglang_router/launch_server.py b/sgl-router/py_src/sglang_router/launch_server.py index 93bc2345d..74353c21e 100644 --- a/sgl-router/py_src/sglang_router/launch_server.py +++ b/sgl-router/py_src/sglang_router/launch_server.py @@ -68,7 +68,7 @@ def run_server(server_args, dp_rank): # create new process group os.setpgrp() - setproctitle(f"sglang::server") + setproctitle("sglang::server") # Set SGLANG_DP_RANK environment variable os.environ["SGLANG_DP_RANK"] = str(dp_rank) @@ -120,9 +120,26 @@ def find_available_ports(base_port: int, count: int) -> List[int]: def cleanup_processes(processes: List[mp.Process]): for process in processes: - logger.info(f"Terminating process {process.pid}") - process.terminate() - logger.info("All processes terminated") + logger.info(f"Terminating process group {process.pid}") + try: + os.killpg(process.pid, signal.SIGTERM) + except ProcessLookupError: + # Process group may already be terminated + pass + + # Wait for processes to terminate + for process in processes: + process.join(timeout=5) + if process.is_alive(): + logger.warning( + f"Process {process.pid} did not terminate gracefully, forcing kill" + ) + try: + os.killpg(process.pid, signal.SIGKILL) + except ProcessLookupError: + pass + + logger.info("All process groups terminated") def main(): @@ -173,7 +190,12 @@ def main(): ] # Start the router - router = launch_router(router_args) + try: + launch_router(router_args) + except Exception as e: + logger.error(f"Failed to start router: {e}") + cleanup_processes(server_processes) + sys.exit(1) if __name__ == "__main__": diff --git a/sgl-router/py_src/sglang_router/router.py b/sgl-router/py_src/sglang_router/router.py index 5ce21c3d7..1665f8a67 100644 --- a/sgl-router/py_src/sglang_router/router.py +++ b/sgl-router/py_src/sglang_router/router.py @@ -17,6 +17,7 @@ class Router: - PolicyType.CacheAware: Distribute requests based on cache state and load balance host: Host address to bind the router server. Default: '127.0.0.1' port: Port number to bind the router server. Default: 3001 + worker_startup_timeout_secs: Timeout in seconds for worker startup. Default: 300 cache_threshold: Cache threshold (0.0-1.0) for cache-aware routing. Routes to cached worker if the match rate exceeds threshold, otherwise routes to the worker with the smallest tree. Default: 0.5 @@ -37,6 +38,7 @@ class Router: policy: PolicyType = PolicyType.RoundRobin, host: str = "127.0.0.1", port: int = 3001, + worker_startup_timeout_secs: int = 300, cache_threshold: float = 0.50, balance_abs_threshold: int = 32, balance_rel_threshold: float = 1.0001, @@ -50,6 +52,7 @@ class Router: policy=policy, host=host, port=port, + worker_startup_timeout_secs=worker_startup_timeout_secs, cache_threshold=cache_threshold, balance_abs_threshold=balance_abs_threshold, balance_rel_threshold=balance_rel_threshold, diff --git a/sgl-router/py_test/test_launch_router.py b/sgl-router/py_test/test_launch_router.py index 94912f694..15549cae7 100644 --- a/sgl-router/py_test/test_launch_router.py +++ b/sgl-router/py_test/test_launch_router.py @@ -28,6 +28,7 @@ class TestLaunchRouter(unittest.TestCase): host="127.0.0.1", port=30000, policy="cache_aware", + worker_startup_timeout_secs=600, cache_threshold=0.5, balance_abs_threshold=32, balance_rel_threshold=1.0001, diff --git a/sgl-router/src/lib.rs b/sgl-router/src/lib.rs index 2d8cf4c0c..8355f1352 100644 --- a/sgl-router/src/lib.rs +++ b/sgl-router/src/lib.rs @@ -17,6 +17,7 @@ struct Router { port: u16, worker_urls: Vec, policy: PolicyType, + worker_startup_timeout_secs: u64, cache_threshold: f32, balance_abs_threshold: usize, balance_rel_threshold: f32, @@ -34,6 +35,7 @@ impl Router { policy = PolicyType::RoundRobin, host = String::from("127.0.0.1"), port = 3001, + worker_startup_timeout_secs = 300, cache_threshold = 0.50, balance_abs_threshold = 32, balance_rel_threshold = 1.0001, @@ -47,6 +49,7 @@ impl Router { policy: PolicyType, host: String, port: u16, + worker_startup_timeout_secs: u64, cache_threshold: f32, balance_abs_threshold: usize, balance_rel_threshold: f32, @@ -60,6 +63,7 @@ impl Router { port, worker_urls, policy, + worker_startup_timeout_secs, cache_threshold, balance_abs_threshold, balance_rel_threshold, @@ -72,9 +76,14 @@ impl Router { fn start(&self) -> PyResult<()> { let policy_config = match &self.policy { - PolicyType::Random => router::PolicyConfig::RandomConfig, - PolicyType::RoundRobin => router::PolicyConfig::RoundRobinConfig, + PolicyType::Random => router::PolicyConfig::RandomConfig { + timeout_secs: self.worker_startup_timeout_secs, + }, + PolicyType::RoundRobin => router::PolicyConfig::RoundRobinConfig { + timeout_secs: self.worker_startup_timeout_secs, + }, PolicyType::CacheAware => router::PolicyConfig::CacheAwareConfig { + timeout_secs: self.worker_startup_timeout_secs, cache_threshold: self.cache_threshold, balance_abs_threshold: self.balance_abs_threshold, balance_rel_threshold: self.balance_rel_threshold, @@ -93,10 +102,9 @@ impl Router { max_payload_size: self.max_payload_size, }) .await - .unwrap(); - }); - - Ok(()) + .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?; + Ok(()) + }) } } diff --git a/sgl-router/src/router.rs b/sgl-router/src/router.rs index 08f6cdefa..6ea791685 100644 --- a/sgl-router/src/router.rs +++ b/sgl-router/src/router.rs @@ -3,7 +3,7 @@ use actix_web::http::header::{HeaderValue, CONTENT_TYPE}; use actix_web::{HttpRequest, HttpResponse}; use bytes::Bytes; use futures_util::{StreamExt, TryStreamExt}; -use log::{debug, info, warn}; +use log::{debug, error, info, warn}; use std::collections::HashMap; use std::fmt::Debug; use std::sync::atomic::AtomicUsize; @@ -17,9 +17,11 @@ pub enum Router { RoundRobin { worker_urls: Arc>>, current_index: AtomicUsize, + timeout_secs: u64, }, Random { worker_urls: Arc>>, + timeout_secs: u64, }, CacheAware { /* @@ -89,36 +91,51 @@ pub enum Router { cache_threshold: f32, balance_abs_threshold: usize, balance_rel_threshold: f32, + timeout_secs: u64, _eviction_thread: Option>, }, } #[derive(Debug, Clone)] pub enum PolicyConfig { - RandomConfig, - RoundRobinConfig, + RandomConfig { + timeout_secs: u64, + }, + RoundRobinConfig { + timeout_secs: u64, + }, CacheAwareConfig { cache_threshold: f32, balance_abs_threshold: usize, balance_rel_threshold: f32, eviction_interval_secs: u64, max_tree_size: usize, + timeout_secs: u64, }, } impl Router { pub fn new(worker_urls: Vec, policy_config: PolicyConfig) -> Result { + // Get timeout from policy config + let timeout_secs = match &policy_config { + PolicyConfig::RandomConfig { timeout_secs } => *timeout_secs, + PolicyConfig::RoundRobinConfig { timeout_secs } => *timeout_secs, + PolicyConfig::CacheAwareConfig { timeout_secs, .. } => *timeout_secs, + }; + // Wait until all workers are healthy - Self::wait_for_healthy_workers(&worker_urls, 300, 10)?; + Self::wait_for_healthy_workers(&worker_urls, timeout_secs, 10)?; // Create router based on policy... Ok(match policy_config { - PolicyConfig::RandomConfig => Router::Random { + PolicyConfig::RandomConfig { timeout_secs } => Router::Random { worker_urls: Arc::new(RwLock::new(worker_urls)), + timeout_secs, }, - PolicyConfig::RoundRobinConfig => Router::RoundRobin { + PolicyConfig::RoundRobinConfig { timeout_secs } => Router::RoundRobin { worker_urls: Arc::new(RwLock::new(worker_urls)), current_index: std::sync::atomic::AtomicUsize::new(0), + timeout_secs, }, PolicyConfig::CacheAwareConfig { cache_threshold, @@ -126,6 +143,7 @@ impl Router { balance_rel_threshold, eviction_interval_secs, max_tree_size, + timeout_secs, } => { let mut running_queue = HashMap::new(); for url in &worker_urls { @@ -176,6 +194,7 @@ impl Router { cache_threshold, balance_abs_threshold, balance_rel_threshold, + timeout_secs, _eviction_thread: Some(eviction_thread), } } @@ -192,6 +211,10 @@ impl Router { loop { if start_time.elapsed() > Duration::from_secs(timeout_secs) { + error!( + "Timeout {}s waiting for workers to become healthy", + timeout_secs + ); return Err(format!( "Timeout {}s waiting for workers to become healthy", timeout_secs @@ -238,7 +261,7 @@ impl Router { fn select_first_worker(&self) -> Result { match self { Router::RoundRobin { worker_urls, .. } - | Router::Random { worker_urls } + | Router::Random { worker_urls, .. } | Router::CacheAware { worker_urls, .. } => { if worker_urls.read().unwrap().is_empty() { Err("No workers are available".to_string()) @@ -349,6 +372,7 @@ impl Router { Router::RoundRobin { worker_urls, current_index, + .. } => { let idx = current_index .fetch_update( @@ -360,7 +384,7 @@ impl Router { worker_urls.read().unwrap()[idx].clone() } - Router::Random { worker_urls } => worker_urls.read().unwrap() + Router::Random { worker_urls, .. } => worker_urls.read().unwrap() [rand::random::() % worker_urls.read().unwrap().len()] .clone(), @@ -571,13 +595,21 @@ impl Router { pub async fn add_worker(&self, worker_url: &str) -> Result { let interval_secs = 10; // check every 10 seconds - let timeout_secs = 300; // 5 minutes + let timeout_secs = match self { + Router::Random { timeout_secs, .. } => *timeout_secs, + Router::RoundRobin { timeout_secs, .. } => *timeout_secs, + Router::CacheAware { timeout_secs, .. } => *timeout_secs, + }; let start_time = std::time::Instant::now(); let client = reqwest::Client::new(); loop { if start_time.elapsed() > Duration::from_secs(timeout_secs) { + error!( + "Timeout {}s waiting for worker {} to become healthy", + timeout_secs, worker_url + ); return Err(format!( "Timeout {}s waiting for worker {} to become healthy", timeout_secs, worker_url @@ -589,7 +621,7 @@ impl Router { if res.status().is_success() { match self { Router::RoundRobin { worker_urls, .. } - | Router::Random { worker_urls } + | Router::Random { worker_urls, .. } | Router::CacheAware { worker_urls, .. } => { info!("Worker {} health check passed", worker_url); let mut urls = worker_urls.write().unwrap(); @@ -663,7 +695,7 @@ impl Router { pub fn remove_worker(&self, worker_url: &str) { match self { Router::RoundRobin { worker_urls, .. } - | Router::Random { worker_urls } + | Router::Random { worker_urls, .. } | Router::CacheAware { worker_urls, .. } => { let mut urls = worker_urls.write().unwrap(); if let Some(index) = urls.iter().position(|url| url == &worker_url) { diff --git a/sgl-router/src/server.rs b/sgl-router/src/server.rs index 09878f07f..e3587389e 100644 --- a/sgl-router/src/server.rs +++ b/sgl-router/src/server.rs @@ -18,14 +18,10 @@ impl AppState { worker_urls: Vec, client: reqwest::Client, policy_config: PolicyConfig, - ) -> Self { + ) -> Result { // Create router based on policy - let router = match Router::new(worker_urls, policy_config) { - Ok(router) => router, - Err(error) => panic!("Failed to create router: {}", error), - }; - - Self { router, client } + let router = Router::new(worker_urls, policy_config)?; + Ok(Self { router, client }) } } @@ -131,6 +127,7 @@ pub struct ServerConfig { } pub async fn startup(config: ServerConfig) -> std::io::Result<()> { + // Initialize logger Builder::new() .format(|buf, record| { use chrono::Local; @@ -152,24 +149,30 @@ pub async fn startup(config: ServerConfig) -> std::io::Result<()> { ) .init(); + info!("🚧 Initializing router on {}:{}", config.host, config.port); + info!("🚧 Initializing workers on {:?}", config.worker_urls); + info!("🚧 Policy Config: {:?}", config.policy_config); + info!( + "🚧 Max payload size: {} MB", + config.max_payload_size / (1024 * 1024) + ); + let client = reqwest::Client::builder() .build() .expect("Failed to create HTTP client"); - let app_state = web::Data::new(AppState::new( - config.worker_urls.clone(), - client, - config.policy_config.clone(), - )); - - info!("✅ Starting router on {}:{}", config.host, config.port); - info!("✅ Serving Worker URLs: {:?}", config.worker_urls); - info!("✅ Policy Config: {:?}", config.policy_config); - info!( - "✅ Max payload size: {} MB", - config.max_payload_size / (1024 * 1024) + let app_state = web::Data::new( + AppState::new( + config.worker_urls.clone(), + client, + config.policy_config.clone(), + ) + .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?, ); + info!("✅ Serving router on {}:{}", config.host, config.port); + info!("✅ Serving workers on {:?}", config.worker_urls); + HttpServer::new(move || { App::new() .app_data(app_state.clone())