[router] add ipv6 support across all components (#11219)

This commit is contained in:
Simo Lin
2025-10-06 11:16:59 -04:00
committed by GitHub
parent a4a3d82393
commit 5ee777c98f
14 changed files with 84 additions and 88 deletions

View File

@@ -407,7 +407,7 @@ impl Default for MetricsConfig {
fn default() -> Self {
Self {
port: 29000,
host: "127.0.0.1".to_string(),
host: "0.0.0.0".to_string(),
}
}
}
@@ -419,7 +419,7 @@ impl Default for RouterConfig {
worker_urls: vec![],
},
policy: PolicyConfig::Random,
host: "127.0.0.1".to_string(),
host: "0.0.0.0".to_string(),
port: 3001,
max_payload_size: 536_870_912, // 512MB
request_timeout_secs: 1800, // 30 minutes
@@ -522,7 +522,7 @@ mod tests {
matches!(config.mode, RoutingMode::Regular { worker_urls } if worker_urls.is_empty())
);
assert!(matches!(config.policy, PolicyConfig::Random));
assert_eq!(config.host, "127.0.0.1");
assert_eq!(config.host, "0.0.0.0");
assert_eq!(config.port, 3001);
assert_eq!(config.max_payload_size, 536_870_912);
assert_eq!(config.request_timeout_secs, 1800);
@@ -553,7 +553,7 @@ mod tests {
}
assert!(matches!(config.policy, PolicyConfig::RoundRobin));
assert_eq!(config.host, "127.0.0.1");
assert_eq!(config.host, "0.0.0.0");
assert_eq!(config.port, 3001);
}
@@ -800,7 +800,7 @@ mod tests {
let config = MetricsConfig::default();
assert_eq!(config.port, 29000);
assert_eq!(config.host, "127.0.0.1");
assert_eq!(config.host, "0.0.0.0");
}
#[test]

View File

@@ -356,17 +356,22 @@ impl fmt::Debug for BasicWorker {
impl BasicWorker {
pub fn normalised_url(&self) -> WorkerResult<&str> {
if self.url().contains("@") {
let parts: Vec<&str> = self.url().split('@').collect();
if parts.len() != 2 {
return Err(WorkerError::InvalidUrl {
url: self.url().to_string(),
});
}
match parts[1].parse::<usize>() {
Ok(_) => Ok(parts[0]),
Err(_) => Err(WorkerError::InvalidUrl {
url: self.url().to_string(),
}),
// Use rfind to split from the right, handling IPv6 addresses with brackets
// e.g., "http://[::1]:8080@0" -> "http://[::1]:8080" and "0"
if let Some(at_pos) = self.url().rfind('@') {
let base_url = &self.url()[..at_pos];
let rank_str = &self.url()[at_pos + 1..];
// Validate that the rank part is actually a number
match rank_str.parse::<usize>() {
Ok(_) => Ok(base_url),
Err(_) => {
// The '@' is not a DP rank separator, return full URL
Ok(self.url())
}
}
} else {
Ok(self.url())
}
} else {
Ok(self.url())

View File

@@ -96,22 +96,33 @@ impl BasicWorkerBuilder {
/// Build the BasicWorker instance
pub fn build(self) -> BasicWorker {
use std::borrow::Cow;
use std::sync::{
atomic::{AtomicBool, AtomicUsize},
Arc,
};
use tokio::sync::{Mutex, RwLock};
let url_to_parse = if self.url.contains("://") {
Cow::from(&self.url)
} else {
Cow::from(format!("http://{}", self.url))
};
let bootstrap_host = match url::Url::parse(&url_to_parse) {
let bootstrap_host = match url::Url::parse(&self.url) {
Ok(parsed) => parsed.host_str().unwrap_or("localhost").to_string(),
Err(_) => "localhost".to_string(),
Err(_) if !self.url.contains("://") => {
match url::Url::parse(&format!("http://{}", self.url)) {
Ok(parsed) => parsed.host_str().unwrap_or("localhost").to_string(),
Err(_) => {
tracing::warn!(
"Failed to parse URL '{}', defaulting to localhost",
self.url
);
"localhost".to_string()
}
}
}
Err(_) => {
tracing::warn!(
"Failed to parse URL '{}', defaulting to localhost",
self.url
);
"localhost".to_string()
}
};
let bootstrap_port = match self.worker_type {

View File

@@ -1,7 +1,7 @@
use std::convert::TryFrom;
use std::time::Duration;
use tonic::{transport::Channel, Request};
use tracing::debug;
use tracing::{debug, warn};
use crate::protocols::spec::{
ChatCompletionRequest, GenerateRequest, ResponseFormat,
@@ -27,9 +27,22 @@ impl SglangSchedulerClient {
pub async fn connect(endpoint: &str) -> Result<Self, Box<dyn std::error::Error + Send + Sync>> {
debug!("Connecting to SGLang scheduler at {}", endpoint);
// Convert grpc:// to http:// for tonic
// Convert grpc:// to http:// for tonic, preserving IPv6 bracket notation
let http_endpoint = if endpoint.starts_with("grpc://") {
endpoint.replace("grpc://", "http://")
// Use proper URL parsing to preserve IPv6 brackets
match url::Url::parse(endpoint) {
Ok(mut parsed) => {
let _ = parsed.set_scheme("http");
parsed.to_string()
}
Err(_) => {
warn!(
"Failed to parse gRPC endpoint '{}', using simple string replacement",
endpoint
);
endpoint.replace("grpc://", "http://")
}
}
} else {
endpoint.to_string()
};

View File

@@ -226,7 +226,7 @@ impl Router {
#[pyo3(signature = (
worker_urls,
policy = PolicyType::RoundRobin,
host = String::from("127.0.0.1"),
host = String::from("0.0.0.0"),
port = 3001,
worker_startup_timeout_secs = 600,
worker_startup_check_interval = 30,

View File

@@ -99,7 +99,7 @@ Examples:
"#)]
struct CliArgs {
#[arg(long, default_value = "127.0.0.1")]
#[arg(long, default_value = "0.0.0.0")]
host: String,
#[arg(long, default_value_t = 30000)]
@@ -183,7 +183,7 @@ struct CliArgs {
#[arg(long, default_value_t = 29000)]
prometheus_port: u16,
#[arg(long, default_value = "127.0.0.1")]
#[arg(long, default_value = "0.0.0.0")]
prometheus_host: String,
#[arg(long, num_args = 0..)]

View File

@@ -186,12 +186,6 @@ impl PDRouter {
prefill_worker: &dyn Worker,
batch_size: Option<usize>,
) -> Result<Value, String> {
let bootstrap_port = match prefill_worker.worker_type() {
crate::core::WorkerType::Prefill { bootstrap_port } => bootstrap_port,
_ => None,
};
let hostname = super::pd_types::get_hostname(prefill_worker.url());
let obj = original
.as_object_mut()
.ok_or_else(|| "Request must be a JSON object".to_string())?;
@@ -201,8 +195,8 @@ impl PDRouter {
let mut ports = Vec::with_capacity(n);
let mut rooms = Vec::with_capacity(n);
for _ in 0..n {
hosts.push(hostname.clone());
ports.push(bootstrap_port);
hosts.push(prefill_worker.bootstrap_host());
ports.push(prefill_worker.bootstrap_port());
rooms.push(super::pd_types::generate_room_id());
}
obj.insert(
@@ -228,11 +222,11 @@ impl PDRouter {
} else {
obj.insert(
"bootstrap_host".to_string(),
serde_json::Value::from(hostname),
serde_json::Value::from(prefill_worker.bootstrap_host()),
);
obj.insert(
"bootstrap_port".to_string(),
match bootstrap_port {
match prefill_worker.bootstrap_port() {
Some(v) => serde_json::Value::from(v),
None => Value::Null,
},

View File

@@ -32,14 +32,6 @@ pub fn api_path(url: &str, api_path: &str) -> String {
}
}
/// Extracts the host portion of a worker URL without external dependencies.
///
/// A leading `http://` or `https://` prefix is stripped, then the port suffix
/// is dropped. Bracketed IPv6 literals are returned whole, brackets included,
/// so `"http://[::1]:8080"` yields `"[::1]"` rather than being cut at the
/// first `:` inside the address.
///
/// For non-bracketed input the result is everything before the first `:`,
/// which means a path is only removed when it follows a port
/// (`"http://host:80/x"` -> `"host"`, `"http://host/x"` -> `"host/x"`).
pub fn get_hostname(url: &str) -> String {
    // Simple hostname extraction without external dependencies
    let rest = url
        .trim_start_matches("http://")
        .trim_start_matches("https://");

    // IPv6 literals contain ':' themselves, so they must be delimited by the
    // closing bracket instead of by the first colon.
    if rest.starts_with('[') {
        if let Some(close) = rest.find(']') {
            // `]` is a single byte, so the inclusive range ends on a char boundary.
            return rest[..=close].to_string();
        }
        // No closing bracket: malformed literal, fall through to the legacy split.
    }

    rest.split(':').next().unwrap_or("localhost").to_string()
}
use serde::Serialize;
// Optimized bootstrap wrapper for single requests

View File

@@ -807,9 +807,12 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box<dyn std::error::Err
config.router_config.cors_allowed_origins.clone(),
);
let addr = format!("{}:{}", config.host, config.port);
let listener = TcpListener::bind(&addr).await?;
info!("Starting server on {}", addr);
// TcpListener::bind accepts &str and handles IPv4/IPv6 via ToSocketAddrs
let bind_addr = format!("{}:{}", config.host, config.port);
info!("Starting server on {}", bind_addr);
let listener = TcpListener::bind(&bind_addr)
.await
.map_err(|e| format!("Failed to bind to {}: {}", bind_addr, e))?;
serve(listener, app)
.with_graceful_shutdown(shutdown_signal())
.await