From 5ee777c98ff558d1acc089e162f22fb9cde1b3e0 Mon Sep 17 00:00:00 2001 From: Simo Lin Date: Mon, 6 Oct 2025 11:16:59 -0400 Subject: [PATCH] [router] add ipv6 support across all components (#11219) --- sgl-router/benches/request_processing.rs | 11 ++----- sgl-router/py_src/sglang_router/router.py | 2 +- .../py_src/sglang_router/router_args.py | 10 +++---- sgl-router/py_test/unit/test_arg_parser.py | 2 +- sgl-router/src/config/types.rs | 10 +++---- sgl-router/src/core/worker.rs | 27 ++++++++++------- sgl-router/src/core/worker_builder.rs | 29 +++++++++++++------ .../src/grpc_client/sglang_scheduler.rs | 19 ++++++++++-- sgl-router/src/lib.rs | 2 +- sgl-router/src/main.rs | 4 +-- sgl-router/src/routers/http/pd_router.rs | 14 +++------ sgl-router/src/routers/http/pd_types.rs | 8 ----- sgl-router/src/server.rs | 9 ++++-- sgl-router/tests/test_pd_routing.rs | 25 +++------------- 14 files changed, 84 insertions(+), 88 deletions(-) diff --git a/sgl-router/benches/request_processing.rs b/sgl-router/benches/request_processing.rs index 2a1163deb..ea3160218 100644 --- a/sgl-router/benches/request_processing.rs +++ b/sgl-router/benches/request_processing.rs @@ -7,9 +7,7 @@ use sglang_router_rs::protocols::spec::{ ChatCompletionRequest, ChatMessage, CompletionRequest, GenerateParameters, GenerateRequest, SamplingParams, StringOrArray, UserMessageContent, }; -use sglang_router_rs::routers::http::pd_types::{ - generate_room_id, get_hostname, RequestWithBootstrap, -}; +use sglang_router_rs::routers::http::pd_types::{generate_room_id, RequestWithBootstrap}; fn create_test_worker() -> BasicWorker { BasicWorkerBuilder::new("http://test-server:8000") @@ -21,11 +19,8 @@ fn create_test_worker() -> BasicWorker { // Helper function to get bootstrap info from worker fn get_bootstrap_info(worker: &BasicWorker) -> (String, Option) { - let hostname = get_hostname(worker.url()); - let bootstrap_port = match worker.worker_type() { - WorkerType::Prefill { bootstrap_port } => bootstrap_port, - _ => None, - }; + let hostname = worker.bootstrap_host().to_string(); + let bootstrap_port = worker.bootstrap_port(); (hostname, bootstrap_port) } diff --git a/sgl-router/py_src/sglang_router/router.py b/sgl-router/py_src/sglang_router/router.py index 4a9db82dc..74f2cef73 100644 --- a/sgl-router/py_src/sglang_router/router.py +++ b/sgl-router/py_src/sglang_router/router.py @@ -30,7 +30,7 @@ class Router: - PolicyType.RoundRobin: Distribute requests in round-robin fashion - PolicyType.CacheAware: Distribute requests based on cache state and load balance - PolicyType.PowerOfTwo: Select best of two random workers based on load (PD mode only) - host: Host address to bind the router server. Default: '127.0.0.1' + host: Host address to bind the router server. Supports IPv4, IPv6 (e.g., ::, ::1), or 0.0.0.0 for all interfaces. Default: '0.0.0.0' port: Port number to bind the router server. Default: 3001 worker_startup_timeout_secs: Timeout in seconds for worker startup. Default: 300 worker_startup_check_interval: Interval in seconds between checks for worker initialization. Default: 10 diff --git a/sgl-router/py_src/sglang_router/router_args.py b/sgl-router/py_src/sglang_router/router_args.py index d26afbadc..5aaa5d772 100644 --- a/sgl-router/py_src/sglang_router/router_args.py +++ b/sgl-router/py_src/sglang_router/router_args.py @@ -10,7 +10,7 @@ logger = logging.getLogger(__name__) class RouterArgs: # Worker configuration worker_urls: List[str] = dataclasses.field(default_factory=list) - host: str = "127.0.0.1" + host: str = "0.0.0.0" port: int = 30000 # PD-specific configuration @@ -109,7 +109,7 @@ class RouterArgs: "--host", type=str, default=RouterArgs.host, - help="Host address to bind the router server", + help="Host address to bind the router server. Supports IPv4, IPv6 (e.g., ::, ::1), or 0.0.0.0 for all interfaces", ) parser.add_argument( "--port", @@ -123,7 +123,7 @@ class RouterArgs: type=str, nargs="*", default=[], - help="List of worker URLs (e.g., http://worker1:8000 http://worker2:8000)", + help="List of worker URLs. Supports IPv4 and IPv6 addresses (use brackets for IPv6, e.g., http://[::1]:8000 http://192.168.1.1:8000)", ) # Routing policy configuration @@ -299,8 +299,8 @@ class RouterArgs: parser.add_argument( f"--{prefix}prometheus-host", type=str, - default="127.0.0.1", - help="Host address to bind the Prometheus metrics server", + default="0.0.0.0", + help="Host address to bind the Prometheus metrics server. Supports IPv4, IPv6 (e.g., ::, ::1), or 0.0.0.0 for all interfaces", ) parser.add_argument( f"--{prefix}request-id-headers", diff --git a/sgl-router/py_test/unit/test_arg_parser.py b/sgl-router/py_test/unit/test_arg_parser.py index 04d8a112d..0da764ddf 100644 --- a/sgl-router/py_test/unit/test_arg_parser.py +++ b/sgl-router/py_test/unit/test_arg_parser.py @@ -22,7 +22,7 @@ class TestRouterArgs: args = RouterArgs() # Test basic defaults - assert args.host == "127.0.0.1" + assert args.host == "0.0.0.0" assert args.port == 30000 assert args.policy == "cache_aware" assert args.worker_urls == [] diff --git a/sgl-router/src/config/types.rs b/sgl-router/src/config/types.rs index 3ecd9a7d3..0c3b69537 100644 --- a/sgl-router/src/config/types.rs +++ b/sgl-router/src/config/types.rs @@ -407,7 +407,7 @@ impl Default for MetricsConfig { fn default() -> Self { Self { port: 29000, - host: "127.0.0.1".to_string(), + host: "0.0.0.0".to_string(), } } } @@ -419,7 +419,7 @@ impl Default for RouterConfig { worker_urls: vec![], }, policy: PolicyConfig::Random, - host: "127.0.0.1".to_string(), + host: "0.0.0.0".to_string(), port: 3001, max_payload_size: 536_870_912, // 512MB request_timeout_secs: 1800, // 30 minutes @@ -522,7 +522,7 @@ mod tests { matches!(config.mode, RoutingMode::Regular { worker_urls } if worker_urls.is_empty()) ); assert!(matches!(config.policy, PolicyConfig::Random)); - assert_eq!(config.host, "127.0.0.1"); + assert_eq!(config.host, "0.0.0.0"); assert_eq!(config.port, 3001); assert_eq!(config.max_payload_size, 536_870_912); assert_eq!(config.request_timeout_secs, 1800); @@ -553,7 +553,7 @@ mod tests { } assert!(matches!(config.policy, PolicyConfig::RoundRobin)); - assert_eq!(config.host, "127.0.0.1"); + assert_eq!(config.host, "0.0.0.0"); assert_eq!(config.port, 3001); } @@ -800,7 +800,7 @@ mod tests { let config = MetricsConfig::default(); assert_eq!(config.port, 29000); - assert_eq!(config.host, "127.0.0.1"); + assert_eq!(config.host, "0.0.0.0"); } #[test] diff --git a/sgl-router/src/core/worker.rs b/sgl-router/src/core/worker.rs index 570244dc7..3f5f2bb76 100644 --- a/sgl-router/src/core/worker.rs +++ b/sgl-router/src/core/worker.rs @@ -356,17 +356,22 @@ impl fmt::Debug for BasicWorker { impl BasicWorker { pub fn normalised_url(&self) -> WorkerResult<&str> { if self.url().contains("@") { - let parts: Vec<&str> = self.url().split('@').collect(); - if parts.len() != 2 { - return Err(WorkerError::InvalidUrl { - url: self.url().to_string(), - }); - } - match parts[1].parse::() { - Ok(_) => Ok(parts[0]), - Err(_) => Err(WorkerError::InvalidUrl { - url: self.url().to_string(), - }), + // Use rfind to split from the right, handling IPv6 addresses with brackets + // e.g., "http://[::1]:8080@0" -> "http://[::1]:8080" and "0" + if let Some(at_pos) = self.url().rfind('@') { + let base_url = &self.url()[..at_pos]; + let rank_str = &self.url()[at_pos + 1..]; + + // Validate that the rank part is actually a number + match rank_str.parse::() { + Ok(_) => Ok(base_url), + Err(_) => { + // The '@' is not a DP rank separator, return full URL + Ok(self.url()) + } + } + } else { + Ok(self.url()) } } else { Ok(self.url()) diff --git a/sgl-router/src/core/worker_builder.rs b/sgl-router/src/core/worker_builder.rs index 69a4047b2..77863263b 100644 --- a/sgl-router/src/core/worker_builder.rs +++ b/sgl-router/src/core/worker_builder.rs @@ -96,22 +96,33 @@ impl BasicWorkerBuilder { /// Build the BasicWorker instance pub fn build(self) -> BasicWorker { - use std::borrow::Cow; use std::sync::{ atomic::{AtomicBool, AtomicUsize}, Arc, }; use tokio::sync::{Mutex, RwLock}; - let url_to_parse = if self.url.contains("://") { - Cow::from(&self.url) - } else { - Cow::from(format!("http://{}", self.url)) - }; - - let bootstrap_host = match url::Url::parse(&url_to_parse) { + let bootstrap_host = match url::Url::parse(&self.url) { Ok(parsed) => parsed.host_str().unwrap_or("localhost").to_string(), - Err(_) => "localhost".to_string(), + Err(_) if !self.url.contains("://") => { + match url::Url::parse(&format!("http://{}", self.url)) { + Ok(parsed) => parsed.host_str().unwrap_or("localhost").to_string(), + Err(_) => { + tracing::warn!( + "Failed to parse URL '{}', defaulting to localhost", + self.url + ); + "localhost".to_string() + } + } + } + Err(_) => { + tracing::warn!( + "Failed to parse URL '{}', defaulting to localhost", + self.url + ); + "localhost".to_string() + } }; let bootstrap_port = match self.worker_type { diff --git a/sgl-router/src/grpc_client/sglang_scheduler.rs b/sgl-router/src/grpc_client/sglang_scheduler.rs index 845c217be..1097a18f1 100644 --- a/sgl-router/src/grpc_client/sglang_scheduler.rs +++ b/sgl-router/src/grpc_client/sglang_scheduler.rs @@ -1,7 +1,7 @@ use std::convert::TryFrom; use std::time::Duration; use tonic::{transport::Channel, Request}; -use tracing::debug; +use tracing::{debug, warn}; use crate::protocols::spec::{ ChatCompletionRequest, GenerateRequest, ResponseFormat, @@ -27,9 +27,22 @@ impl SglangSchedulerClient { pub async fn connect(endpoint: &str) -> Result> { debug!("Connecting to SGLang scheduler at {}", endpoint); - // Convert grpc:// to http:// for tonic + // Convert grpc:// to http:// for tonic, preserving IPv6 bracket notation let http_endpoint = if endpoint.starts_with("grpc://") { - endpoint.replace("grpc://", "http://") + // Use proper URL parsing to preserve IPv6 brackets + match url::Url::parse(endpoint) { + Ok(mut parsed) => { + let _ = parsed.set_scheme("http"); + parsed.to_string() + } + Err(_) => { + warn!( + "Failed to parse gRPC endpoint '{}', using simple string replacement", + endpoint + ); + endpoint.replace("grpc://", "http://") + } + } } else { endpoint.to_string() }; diff --git a/sgl-router/src/lib.rs b/sgl-router/src/lib.rs index 43225d4cc..e5d254991 100644 --- a/sgl-router/src/lib.rs +++ b/sgl-router/src/lib.rs @@ -226,7 +226,7 @@ impl Router { #[pyo3(signature = ( worker_urls, policy = PolicyType::RoundRobin, - host = String::from("127.0.0.1"), + host = String::from("0.0.0.0"), port = 3001, worker_startup_timeout_secs = 600, worker_startup_check_interval = 30, diff --git a/sgl-router/src/main.rs b/sgl-router/src/main.rs index 888502b9e..57d7fc81d 100644 --- a/sgl-router/src/main.rs +++ b/sgl-router/src/main.rs @@ -99,7 +99,7 @@ Examples: "#)] struct CliArgs { - #[arg(long, default_value = "127.0.0.1")] + #[arg(long, default_value = "0.0.0.0")] host: String, #[arg(long, default_value_t = 30000)] @@ -183,7 +183,7 @@ struct CliArgs { #[arg(long, default_value_t = 29000)] prometheus_port: u16, - #[arg(long, default_value = "127.0.0.1")] + #[arg(long, default_value = "0.0.0.0")] prometheus_host: String, #[arg(long, num_args = 0..)] diff --git a/sgl-router/src/routers/http/pd_router.rs b/sgl-router/src/routers/http/pd_router.rs index 5d0a9a801..e18e856f3 100644 --- a/sgl-router/src/routers/http/pd_router.rs +++ b/sgl-router/src/routers/http/pd_router.rs @@ -186,12 +186,6 @@ impl PDRouter { prefill_worker: &dyn Worker, batch_size: Option, ) -> Result { - let bootstrap_port = match prefill_worker.worker_type() { - crate::core::WorkerType::Prefill { bootstrap_port } => bootstrap_port, - _ => None, - }; - let hostname = super::pd_types::get_hostname(prefill_worker.url()); - let obj = original .as_object_mut() .ok_or_else(|| "Request must be a JSON object".to_string())?; @@ -201,8 +195,8 @@ impl PDRouter { let mut ports = Vec::with_capacity(n); let mut rooms = Vec::with_capacity(n); for _ in 0..n { - hosts.push(hostname.clone()); - ports.push(bootstrap_port); + hosts.push(prefill_worker.bootstrap_host()); + ports.push(prefill_worker.bootstrap_port()); rooms.push(super::pd_types::generate_room_id()); } obj.insert( @@ -228,11 +222,11 @@ impl PDRouter { } else { obj.insert( "bootstrap_host".to_string(), - serde_json::Value::from(hostname), + serde_json::Value::from(prefill_worker.bootstrap_host()), ); obj.insert( "bootstrap_port".to_string(), - match bootstrap_port { + match prefill_worker.bootstrap_port() { Some(v) => serde_json::Value::from(v), None => Value::Null, }, diff --git a/sgl-router/src/routers/http/pd_types.rs b/sgl-router/src/routers/http/pd_types.rs index a2b28a57d..78c93d82e 100644 --- a/sgl-router/src/routers/http/pd_types.rs +++ b/sgl-router/src/routers/http/pd_types.rs @@ -32,14 +32,6 @@ pub fn api_path(url: &str, api_path: &str) -> String { } } -pub fn get_hostname(url: &str) -> String { - // Simple hostname extraction without external dependencies - let url = url - .trim_start_matches("http://") - .trim_start_matches("https://"); - url.split(':').next().unwrap_or("localhost").to_string() -} - use serde::Serialize; // Optimized bootstrap wrapper for single requests diff --git a/sgl-router/src/server.rs b/sgl-router/src/server.rs index 4dcc71fb1..e3c9309b3 100644 --- a/sgl-router/src/server.rs +++ b/sgl-router/src/server.rs @@ -807,9 +807,12 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box None, }; - single_json["bootstrap_host"] = json!(get_hostname(prefill_worker.url())); + single_json["bootstrap_host"] = json!(prefill_worker.bootstrap_host()); single_json["bootstrap_port"] = json!(bootstrap_port); single_json["bootstrap_room"] = json!(12345u64); // Random room ID @@ -301,7 +300,7 @@ mod test_pd_routing { }); let batch_size = 3; - let hostname = get_hostname(prefill_worker.url()); + let hostname = prefill_worker.bootstrap_host(); batch_json["bootstrap_host"] = json!(vec![hostname; batch_size]); batch_json["bootstrap_port"] = json!(vec![bootstrap_port; batch_size]); batch_json["bootstrap_room"] = json!(vec![111u64, 222u64, 333u64]); @@ -343,22 +342,6 @@ mod test_pd_routing { assert_eq!(parsed["bootstrap_room"], 12345); } - #[test] - fn test_hostname_extraction() { - let test_cases = vec![ - ("http://localhost:8080", "localhost"), - ("http://10.0.0.1:8080", "10.0.0.1"), - ("https://api.example.com:443", "api.example.com"), - ("http://prefill-server", "prefill-server"), - ("http://[::1]:8080", "["), // IPv6 edge case - ("prefill:8080", "prefill"), // No protocol - ]; - - for (url, expected_hostname) in test_cases { - assert_eq!(get_hostname(url), expected_hostname); - } - } - #[test] fn test_pd_request_edge_cases() { let empty_json = json!({}); @@ -644,7 +627,7 @@ mod test_pd_routing { _ => None, }; let batch_size = 16; - let hostname = get_hostname(prefill_worker.url()); + let hostname = prefill_worker.bootstrap_host(); benchmark_request["bootstrap_host"] = json!(vec![hostname; batch_size]); benchmark_request["bootstrap_port"] = json!(vec![bootstrap_port; batch_size]); @@ -769,7 +752,7 @@ mod test_pd_routing { WorkerType::Prefill { bootstrap_port } => bootstrap_port, _ => None, }; - let hostname = get_hostname(prefill_worker.url()); + let hostname = prefill_worker.bootstrap_host(); large_batch_request["bootstrap_host"] = json!(vec![hostname; batch_size]); large_batch_request["bootstrap_port"] = json!(vec![bootstrap_port; batch_size]);