[router] add ipv6 support across all components (#11219)
This commit is contained in:
@@ -7,9 +7,7 @@ use sglang_router_rs::protocols::spec::{
|
|||||||
ChatCompletionRequest, ChatMessage, CompletionRequest, GenerateParameters, GenerateRequest,
|
ChatCompletionRequest, ChatMessage, CompletionRequest, GenerateParameters, GenerateRequest,
|
||||||
SamplingParams, StringOrArray, UserMessageContent,
|
SamplingParams, StringOrArray, UserMessageContent,
|
||||||
};
|
};
|
||||||
use sglang_router_rs::routers::http::pd_types::{
|
use sglang_router_rs::routers::http::pd_types::{generate_room_id, RequestWithBootstrap};
|
||||||
generate_room_id, get_hostname, RequestWithBootstrap,
|
|
||||||
};
|
|
||||||
|
|
||||||
fn create_test_worker() -> BasicWorker {
|
fn create_test_worker() -> BasicWorker {
|
||||||
BasicWorkerBuilder::new("http://test-server:8000")
|
BasicWorkerBuilder::new("http://test-server:8000")
|
||||||
@@ -21,11 +19,8 @@ fn create_test_worker() -> BasicWorker {
|
|||||||
|
|
||||||
// Helper function to get bootstrap info from worker
|
// Helper function to get bootstrap info from worker
|
||||||
fn get_bootstrap_info(worker: &BasicWorker) -> (String, Option<u16>) {
|
fn get_bootstrap_info(worker: &BasicWorker) -> (String, Option<u16>) {
|
||||||
let hostname = get_hostname(worker.url());
|
let hostname = worker.bootstrap_host().to_string();
|
||||||
let bootstrap_port = match worker.worker_type() {
|
let bootstrap_port = worker.bootstrap_port();
|
||||||
WorkerType::Prefill { bootstrap_port } => bootstrap_port,
|
|
||||||
_ => None,
|
|
||||||
};
|
|
||||||
(hostname, bootstrap_port)
|
(hostname, bootstrap_port)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -30,7 +30,7 @@ class Router:
|
|||||||
- PolicyType.RoundRobin: Distribute requests in round-robin fashion
|
- PolicyType.RoundRobin: Distribute requests in round-robin fashion
|
||||||
- PolicyType.CacheAware: Distribute requests based on cache state and load balance
|
- PolicyType.CacheAware: Distribute requests based on cache state and load balance
|
||||||
- PolicyType.PowerOfTwo: Select best of two random workers based on load (PD mode only)
|
- PolicyType.PowerOfTwo: Select best of two random workers based on load (PD mode only)
|
||||||
host: Host address to bind the router server. Default: '127.0.0.1'
|
host: Host address to bind the router server. Supports IPv4, IPv6 (e.g., ::, ::1), or 0.0.0.0 for all interfaces. Default: '0.0.0.0'
|
||||||
port: Port number to bind the router server. Default: 3001
|
port: Port number to bind the router server. Default: 3001
|
||||||
worker_startup_timeout_secs: Timeout in seconds for worker startup. Default: 300
|
worker_startup_timeout_secs: Timeout in seconds for worker startup. Default: 300
|
||||||
worker_startup_check_interval: Interval in seconds between checks for worker initialization. Default: 10
|
worker_startup_check_interval: Interval in seconds between checks for worker initialization. Default: 10
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ logger = logging.getLogger(__name__)
|
|||||||
class RouterArgs:
|
class RouterArgs:
|
||||||
# Worker configuration
|
# Worker configuration
|
||||||
worker_urls: List[str] = dataclasses.field(default_factory=list)
|
worker_urls: List[str] = dataclasses.field(default_factory=list)
|
||||||
host: str = "127.0.0.1"
|
host: str = "0.0.0.0"
|
||||||
port: int = 30000
|
port: int = 30000
|
||||||
|
|
||||||
# PD-specific configuration
|
# PD-specific configuration
|
||||||
@@ -109,7 +109,7 @@ class RouterArgs:
|
|||||||
"--host",
|
"--host",
|
||||||
type=str,
|
type=str,
|
||||||
default=RouterArgs.host,
|
default=RouterArgs.host,
|
||||||
help="Host address to bind the router server",
|
help="Host address to bind the router server. Supports IPv4, IPv6 (e.g., ::, ::1), or 0.0.0.0 for all interfaces",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--port",
|
"--port",
|
||||||
@@ -123,7 +123,7 @@ class RouterArgs:
|
|||||||
type=str,
|
type=str,
|
||||||
nargs="*",
|
nargs="*",
|
||||||
default=[],
|
default=[],
|
||||||
help="List of worker URLs (e.g., http://worker1:8000 http://worker2:8000)",
|
help="List of worker URLs. Supports IPv4 and IPv6 addresses (use brackets for IPv6, e.g., http://[::1]:8000 http://192.168.1.1:8000)",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Routing policy configuration
|
# Routing policy configuration
|
||||||
@@ -299,8 +299,8 @@ class RouterArgs:
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
f"--{prefix}prometheus-host",
|
f"--{prefix}prometheus-host",
|
||||||
type=str,
|
type=str,
|
||||||
default="127.0.0.1",
|
default="0.0.0.0",
|
||||||
help="Host address to bind the Prometheus metrics server",
|
help="Host address to bind the Prometheus metrics server. Supports IPv4, IPv6 (e.g., ::, ::1), or 0.0.0.0 for all interfaces",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
f"--{prefix}request-id-headers",
|
f"--{prefix}request-id-headers",
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ class TestRouterArgs:
|
|||||||
args = RouterArgs()
|
args = RouterArgs()
|
||||||
|
|
||||||
# Test basic defaults
|
# Test basic defaults
|
||||||
assert args.host == "127.0.0.1"
|
assert args.host == "0.0.0.0"
|
||||||
assert args.port == 30000
|
assert args.port == 30000
|
||||||
assert args.policy == "cache_aware"
|
assert args.policy == "cache_aware"
|
||||||
assert args.worker_urls == []
|
assert args.worker_urls == []
|
||||||
|
|||||||
@@ -407,7 +407,7 @@ impl Default for MetricsConfig {
|
|||||||
fn default() -> Self {
|
fn default() -> Self {
|
||||||
Self {
|
Self {
|
||||||
port: 29000,
|
port: 29000,
|
||||||
host: "127.0.0.1".to_string(),
|
host: "0.0.0.0".to_string(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -419,7 +419,7 @@ impl Default for RouterConfig {
|
|||||||
worker_urls: vec![],
|
worker_urls: vec![],
|
||||||
},
|
},
|
||||||
policy: PolicyConfig::Random,
|
policy: PolicyConfig::Random,
|
||||||
host: "127.0.0.1".to_string(),
|
host: "0.0.0.0".to_string(),
|
||||||
port: 3001,
|
port: 3001,
|
||||||
max_payload_size: 536_870_912, // 512MB
|
max_payload_size: 536_870_912, // 512MB
|
||||||
request_timeout_secs: 1800, // 30 minutes
|
request_timeout_secs: 1800, // 30 minutes
|
||||||
@@ -522,7 +522,7 @@ mod tests {
|
|||||||
matches!(config.mode, RoutingMode::Regular { worker_urls } if worker_urls.is_empty())
|
matches!(config.mode, RoutingMode::Regular { worker_urls } if worker_urls.is_empty())
|
||||||
);
|
);
|
||||||
assert!(matches!(config.policy, PolicyConfig::Random));
|
assert!(matches!(config.policy, PolicyConfig::Random));
|
||||||
assert_eq!(config.host, "127.0.0.1");
|
assert_eq!(config.host, "0.0.0.0");
|
||||||
assert_eq!(config.port, 3001);
|
assert_eq!(config.port, 3001);
|
||||||
assert_eq!(config.max_payload_size, 536_870_912);
|
assert_eq!(config.max_payload_size, 536_870_912);
|
||||||
assert_eq!(config.request_timeout_secs, 1800);
|
assert_eq!(config.request_timeout_secs, 1800);
|
||||||
@@ -553,7 +553,7 @@ mod tests {
|
|||||||
}
|
}
|
||||||
|
|
||||||
assert!(matches!(config.policy, PolicyConfig::RoundRobin));
|
assert!(matches!(config.policy, PolicyConfig::RoundRobin));
|
||||||
assert_eq!(config.host, "127.0.0.1");
|
assert_eq!(config.host, "0.0.0.0");
|
||||||
assert_eq!(config.port, 3001);
|
assert_eq!(config.port, 3001);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -800,7 +800,7 @@ mod tests {
|
|||||||
let config = MetricsConfig::default();
|
let config = MetricsConfig::default();
|
||||||
|
|
||||||
assert_eq!(config.port, 29000);
|
assert_eq!(config.port, 29000);
|
||||||
assert_eq!(config.host, "127.0.0.1");
|
assert_eq!(config.host, "0.0.0.0");
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
|||||||
@@ -356,17 +356,22 @@ impl fmt::Debug for BasicWorker {
|
|||||||
impl BasicWorker {
|
impl BasicWorker {
|
||||||
pub fn normalised_url(&self) -> WorkerResult<&str> {
|
pub fn normalised_url(&self) -> WorkerResult<&str> {
|
||||||
if self.url().contains("@") {
|
if self.url().contains("@") {
|
||||||
let parts: Vec<&str> = self.url().split('@').collect();
|
// Use rfind to split from the right, handling IPv6 addresses with brackets
|
||||||
if parts.len() != 2 {
|
// e.g., "http://[::1]:8080@0" -> "http://[::1]:8080" and "0"
|
||||||
return Err(WorkerError::InvalidUrl {
|
if let Some(at_pos) = self.url().rfind('@') {
|
||||||
url: self.url().to_string(),
|
let base_url = &self.url()[..at_pos];
|
||||||
});
|
let rank_str = &self.url()[at_pos + 1..];
|
||||||
}
|
|
||||||
match parts[1].parse::<usize>() {
|
// Validate that the rank part is actually a number
|
||||||
Ok(_) => Ok(parts[0]),
|
match rank_str.parse::<usize>() {
|
||||||
Err(_) => Err(WorkerError::InvalidUrl {
|
Ok(_) => Ok(base_url),
|
||||||
url: self.url().to_string(),
|
Err(_) => {
|
||||||
}),
|
// The '@' is not a DP rank separator, return full URL
|
||||||
|
Ok(self.url())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
Ok(self.url())
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
Ok(self.url())
|
Ok(self.url())
|
||||||
|
|||||||
@@ -96,22 +96,33 @@ impl BasicWorkerBuilder {
|
|||||||
|
|
||||||
/// Build the BasicWorker instance
|
/// Build the BasicWorker instance
|
||||||
pub fn build(self) -> BasicWorker {
|
pub fn build(self) -> BasicWorker {
|
||||||
use std::borrow::Cow;
|
|
||||||
use std::sync::{
|
use std::sync::{
|
||||||
atomic::{AtomicBool, AtomicUsize},
|
atomic::{AtomicBool, AtomicUsize},
|
||||||
Arc,
|
Arc,
|
||||||
};
|
};
|
||||||
use tokio::sync::{Mutex, RwLock};
|
use tokio::sync::{Mutex, RwLock};
|
||||||
|
|
||||||
let url_to_parse = if self.url.contains("://") {
|
let bootstrap_host = match url::Url::parse(&self.url) {
|
||||||
Cow::from(&self.url)
|
|
||||||
} else {
|
|
||||||
Cow::from(format!("http://{}", self.url))
|
|
||||||
};
|
|
||||||
|
|
||||||
let bootstrap_host = match url::Url::parse(&url_to_parse) {
|
|
||||||
Ok(parsed) => parsed.host_str().unwrap_or("localhost").to_string(),
|
Ok(parsed) => parsed.host_str().unwrap_or("localhost").to_string(),
|
||||||
Err(_) => "localhost".to_string(),
|
Err(_) if !self.url.contains("://") => {
|
||||||
|
match url::Url::parse(&format!("http://{}", self.url)) {
|
||||||
|
Ok(parsed) => parsed.host_str().unwrap_or("localhost").to_string(),
|
||||||
|
Err(_) => {
|
||||||
|
tracing::warn!(
|
||||||
|
"Failed to parse URL '{}', defaulting to localhost",
|
||||||
|
self.url
|
||||||
|
);
|
||||||
|
"localhost".to_string()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(_) => {
|
||||||
|
tracing::warn!(
|
||||||
|
"Failed to parse URL '{}', defaulting to localhost",
|
||||||
|
self.url
|
||||||
|
);
|
||||||
|
"localhost".to_string()
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
let bootstrap_port = match self.worker_type {
|
let bootstrap_port = match self.worker_type {
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
use std::convert::TryFrom;
|
use std::convert::TryFrom;
|
||||||
use std::time::Duration;
|
use std::time::Duration;
|
||||||
use tonic::{transport::Channel, Request};
|
use tonic::{transport::Channel, Request};
|
||||||
use tracing::debug;
|
use tracing::{debug, warn};
|
||||||
|
|
||||||
use crate::protocols::spec::{
|
use crate::protocols::spec::{
|
||||||
ChatCompletionRequest, GenerateRequest, ResponseFormat,
|
ChatCompletionRequest, GenerateRequest, ResponseFormat,
|
||||||
@@ -27,9 +27,22 @@ impl SglangSchedulerClient {
|
|||||||
pub async fn connect(endpoint: &str) -> Result<Self, Box<dyn std::error::Error + Send + Sync>> {
|
pub async fn connect(endpoint: &str) -> Result<Self, Box<dyn std::error::Error + Send + Sync>> {
|
||||||
debug!("Connecting to SGLang scheduler at {}", endpoint);
|
debug!("Connecting to SGLang scheduler at {}", endpoint);
|
||||||
|
|
||||||
// Convert grpc:// to http:// for tonic
|
// Convert grpc:// to http:// for tonic, preserving IPv6 bracket notation
|
||||||
let http_endpoint = if endpoint.starts_with("grpc://") {
|
let http_endpoint = if endpoint.starts_with("grpc://") {
|
||||||
endpoint.replace("grpc://", "http://")
|
// Use proper URL parsing to preserve IPv6 brackets
|
||||||
|
match url::Url::parse(endpoint) {
|
||||||
|
Ok(mut parsed) => {
|
||||||
|
let _ = parsed.set_scheme("http");
|
||||||
|
parsed.to_string()
|
||||||
|
}
|
||||||
|
Err(_) => {
|
||||||
|
warn!(
|
||||||
|
"Failed to parse gRPC endpoint '{}', using simple string replacement",
|
||||||
|
endpoint
|
||||||
|
);
|
||||||
|
endpoint.replace("grpc://", "http://")
|
||||||
|
}
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
endpoint.to_string()
|
endpoint.to_string()
|
||||||
};
|
};
|
||||||
|
|||||||
@@ -226,7 +226,7 @@ impl Router {
|
|||||||
#[pyo3(signature = (
|
#[pyo3(signature = (
|
||||||
worker_urls,
|
worker_urls,
|
||||||
policy = PolicyType::RoundRobin,
|
policy = PolicyType::RoundRobin,
|
||||||
host = String::from("127.0.0.1"),
|
host = String::from("0.0.0.0"),
|
||||||
port = 3001,
|
port = 3001,
|
||||||
worker_startup_timeout_secs = 600,
|
worker_startup_timeout_secs = 600,
|
||||||
worker_startup_check_interval = 30,
|
worker_startup_check_interval = 30,
|
||||||
|
|||||||
@@ -99,7 +99,7 @@ Examples:
|
|||||||
|
|
||||||
"#)]
|
"#)]
|
||||||
struct CliArgs {
|
struct CliArgs {
|
||||||
#[arg(long, default_value = "127.0.0.1")]
|
#[arg(long, default_value = "0.0.0.0")]
|
||||||
host: String,
|
host: String,
|
||||||
|
|
||||||
#[arg(long, default_value_t = 30000)]
|
#[arg(long, default_value_t = 30000)]
|
||||||
@@ -183,7 +183,7 @@ struct CliArgs {
|
|||||||
#[arg(long, default_value_t = 29000)]
|
#[arg(long, default_value_t = 29000)]
|
||||||
prometheus_port: u16,
|
prometheus_port: u16,
|
||||||
|
|
||||||
#[arg(long, default_value = "127.0.0.1")]
|
#[arg(long, default_value = "0.0.0.0")]
|
||||||
prometheus_host: String,
|
prometheus_host: String,
|
||||||
|
|
||||||
#[arg(long, num_args = 0..)]
|
#[arg(long, num_args = 0..)]
|
||||||
|
|||||||
@@ -186,12 +186,6 @@ impl PDRouter {
|
|||||||
prefill_worker: &dyn Worker,
|
prefill_worker: &dyn Worker,
|
||||||
batch_size: Option<usize>,
|
batch_size: Option<usize>,
|
||||||
) -> Result<Value, String> {
|
) -> Result<Value, String> {
|
||||||
let bootstrap_port = match prefill_worker.worker_type() {
|
|
||||||
crate::core::WorkerType::Prefill { bootstrap_port } => bootstrap_port,
|
|
||||||
_ => None,
|
|
||||||
};
|
|
||||||
let hostname = super::pd_types::get_hostname(prefill_worker.url());
|
|
||||||
|
|
||||||
let obj = original
|
let obj = original
|
||||||
.as_object_mut()
|
.as_object_mut()
|
||||||
.ok_or_else(|| "Request must be a JSON object".to_string())?;
|
.ok_or_else(|| "Request must be a JSON object".to_string())?;
|
||||||
@@ -201,8 +195,8 @@ impl PDRouter {
|
|||||||
let mut ports = Vec::with_capacity(n);
|
let mut ports = Vec::with_capacity(n);
|
||||||
let mut rooms = Vec::with_capacity(n);
|
let mut rooms = Vec::with_capacity(n);
|
||||||
for _ in 0..n {
|
for _ in 0..n {
|
||||||
hosts.push(hostname.clone());
|
hosts.push(prefill_worker.bootstrap_host());
|
||||||
ports.push(bootstrap_port);
|
ports.push(prefill_worker.bootstrap_port());
|
||||||
rooms.push(super::pd_types::generate_room_id());
|
rooms.push(super::pd_types::generate_room_id());
|
||||||
}
|
}
|
||||||
obj.insert(
|
obj.insert(
|
||||||
@@ -228,11 +222,11 @@ impl PDRouter {
|
|||||||
} else {
|
} else {
|
||||||
obj.insert(
|
obj.insert(
|
||||||
"bootstrap_host".to_string(),
|
"bootstrap_host".to_string(),
|
||||||
serde_json::Value::from(hostname),
|
serde_json::Value::from(prefill_worker.bootstrap_host()),
|
||||||
);
|
);
|
||||||
obj.insert(
|
obj.insert(
|
||||||
"bootstrap_port".to_string(),
|
"bootstrap_port".to_string(),
|
||||||
match bootstrap_port {
|
match prefill_worker.bootstrap_port() {
|
||||||
Some(v) => serde_json::Value::from(v),
|
Some(v) => serde_json::Value::from(v),
|
||||||
None => Value::Null,
|
None => Value::Null,
|
||||||
},
|
},
|
||||||
|
|||||||
@@ -32,14 +32,6 @@ pub fn api_path(url: &str, api_path: &str) -> String {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn get_hostname(url: &str) -> String {
|
|
||||||
// Simple hostname extraction without external dependencies
|
|
||||||
let url = url
|
|
||||||
.trim_start_matches("http://")
|
|
||||||
.trim_start_matches("https://");
|
|
||||||
url.split(':').next().unwrap_or("localhost").to_string()
|
|
||||||
}
|
|
||||||
|
|
||||||
use serde::Serialize;
|
use serde::Serialize;
|
||||||
|
|
||||||
// Optimized bootstrap wrapper for single requests
|
// Optimized bootstrap wrapper for single requests
|
||||||
|
|||||||
@@ -807,9 +807,12 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box<dyn std::error::Err
|
|||||||
config.router_config.cors_allowed_origins.clone(),
|
config.router_config.cors_allowed_origins.clone(),
|
||||||
);
|
);
|
||||||
|
|
||||||
let addr = format!("{}:{}", config.host, config.port);
|
// TcpListener::bind accepts &str and handles IPv4/IPv6 via ToSocketAddrs
|
||||||
let listener = TcpListener::bind(&addr).await?;
|
let bind_addr = format!("{}:{}", config.host, config.port);
|
||||||
info!("Starting server on {}", addr);
|
info!("Starting server on {}", bind_addr);
|
||||||
|
let listener = TcpListener::bind(&bind_addr)
|
||||||
|
.await
|
||||||
|
.map_err(|e| format!("Failed to bind to {}: {}", bind_addr, e))?;
|
||||||
serve(listener, app)
|
serve(listener, app)
|
||||||
.with_graceful_shutdown(shutdown_signal())
|
.with_graceful_shutdown(shutdown_signal())
|
||||||
.await
|
.await
|
||||||
|
|||||||
@@ -5,7 +5,6 @@ mod test_pd_routing {
|
|||||||
CircuitBreakerConfig, ConnectionMode, PolicyConfig, RetryConfig, RouterConfig, RoutingMode,
|
CircuitBreakerConfig, ConnectionMode, PolicyConfig, RetryConfig, RouterConfig, RoutingMode,
|
||||||
};
|
};
|
||||||
use sglang_router_rs::core::{BasicWorkerBuilder, Worker, WorkerType};
|
use sglang_router_rs::core::{BasicWorkerBuilder, Worker, WorkerType};
|
||||||
use sglang_router_rs::routers::http::pd_types::get_hostname;
|
|
||||||
use sglang_router_rs::routers::http::pd_types::PDSelectionPolicy;
|
use sglang_router_rs::routers::http::pd_types::PDSelectionPolicy;
|
||||||
use sglang_router_rs::routers::RouterFactory;
|
use sglang_router_rs::routers::RouterFactory;
|
||||||
|
|
||||||
@@ -286,7 +285,7 @@ mod test_pd_routing {
|
|||||||
_ => None,
|
_ => None,
|
||||||
};
|
};
|
||||||
|
|
||||||
single_json["bootstrap_host"] = json!(get_hostname(prefill_worker.url()));
|
single_json["bootstrap_host"] = json!(prefill_worker.bootstrap_host());
|
||||||
single_json["bootstrap_port"] = json!(bootstrap_port);
|
single_json["bootstrap_port"] = json!(bootstrap_port);
|
||||||
single_json["bootstrap_room"] = json!(12345u64); // Random room ID
|
single_json["bootstrap_room"] = json!(12345u64); // Random room ID
|
||||||
|
|
||||||
@@ -301,7 +300,7 @@ mod test_pd_routing {
|
|||||||
});
|
});
|
||||||
|
|
||||||
let batch_size = 3;
|
let batch_size = 3;
|
||||||
let hostname = get_hostname(prefill_worker.url());
|
let hostname = prefill_worker.bootstrap_host();
|
||||||
batch_json["bootstrap_host"] = json!(vec![hostname; batch_size]);
|
batch_json["bootstrap_host"] = json!(vec![hostname; batch_size]);
|
||||||
batch_json["bootstrap_port"] = json!(vec![bootstrap_port; batch_size]);
|
batch_json["bootstrap_port"] = json!(vec![bootstrap_port; batch_size]);
|
||||||
batch_json["bootstrap_room"] = json!(vec![111u64, 222u64, 333u64]);
|
batch_json["bootstrap_room"] = json!(vec![111u64, 222u64, 333u64]);
|
||||||
@@ -343,22 +342,6 @@ mod test_pd_routing {
|
|||||||
assert_eq!(parsed["bootstrap_room"], 12345);
|
assert_eq!(parsed["bootstrap_room"], 12345);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_hostname_extraction() {
|
|
||||||
let test_cases = vec![
|
|
||||||
("http://localhost:8080", "localhost"),
|
|
||||||
("http://10.0.0.1:8080", "10.0.0.1"),
|
|
||||||
("https://api.example.com:443", "api.example.com"),
|
|
||||||
("http://prefill-server", "prefill-server"),
|
|
||||||
("http://[::1]:8080", "["), // IPv6 edge case
|
|
||||||
("prefill:8080", "prefill"), // No protocol
|
|
||||||
];
|
|
||||||
|
|
||||||
for (url, expected_hostname) in test_cases {
|
|
||||||
assert_eq!(get_hostname(url), expected_hostname);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_pd_request_edge_cases() {
|
fn test_pd_request_edge_cases() {
|
||||||
let empty_json = json!({});
|
let empty_json = json!({});
|
||||||
@@ -644,7 +627,7 @@ mod test_pd_routing {
|
|||||||
_ => None,
|
_ => None,
|
||||||
};
|
};
|
||||||
let batch_size = 16;
|
let batch_size = 16;
|
||||||
let hostname = get_hostname(prefill_worker.url());
|
let hostname = prefill_worker.bootstrap_host();
|
||||||
|
|
||||||
benchmark_request["bootstrap_host"] = json!(vec![hostname; batch_size]);
|
benchmark_request["bootstrap_host"] = json!(vec![hostname; batch_size]);
|
||||||
benchmark_request["bootstrap_port"] = json!(vec![bootstrap_port; batch_size]);
|
benchmark_request["bootstrap_port"] = json!(vec![bootstrap_port; batch_size]);
|
||||||
@@ -769,7 +752,7 @@ mod test_pd_routing {
|
|||||||
WorkerType::Prefill { bootstrap_port } => bootstrap_port,
|
WorkerType::Prefill { bootstrap_port } => bootstrap_port,
|
||||||
_ => None,
|
_ => None,
|
||||||
};
|
};
|
||||||
let hostname = get_hostname(prefill_worker.url());
|
let hostname = prefill_worker.bootstrap_host();
|
||||||
|
|
||||||
large_batch_request["bootstrap_host"] = json!(vec![hostname; batch_size]);
|
large_batch_request["bootstrap_host"] = json!(vec![hostname; batch_size]);
|
||||||
large_batch_request["bootstrap_port"] = json!(vec![bootstrap_port; batch_size]);
|
large_batch_request["bootstrap_port"] = json!(vec![bootstrap_port; batch_size]);
|
||||||
|
|||||||
Reference in New Issue
Block a user