[router] Add Rust Binary Entrypoint for SGLang Router (#9089)
This commit is contained in:
@@ -9,7 +9,12 @@ name = "sglang_router_rs"
|
|||||||
# Python/C binding + Rust library: Use ["cdylib", "rlib"]
|
# Python/C binding + Rust library: Use ["cdylib", "rlib"]
|
||||||
crate-type = ["cdylib", "rlib"]
|
crate-type = ["cdylib", "rlib"]
|
||||||
|
|
||||||
|
[[bin]]
|
||||||
|
name = "sglang-router"
|
||||||
|
path = "src/main.rs"
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
|
clap = { version = "4", features = ["derive"] }
|
||||||
axum = { version = "0.8.4", features = ["macros", "ws", "tracing"] }
|
axum = { version = "0.8.4", features = ["macros", "ws", "tracing"] }
|
||||||
tower = { version = "0.5", features = ["full"] }
|
tower = { version = "0.5", features = ["full"] }
|
||||||
tower-http = { version = "0.6", features = ["trace", "compression-gzip", "cors", "timeout", "limit", "request-id", "util"] }
|
tower-http = { version = "0.6", features = ["trace", "compression-gzip", "cors", "timeout", "limit", "request-id", "util"] }
|
||||||
|
|||||||
@@ -56,7 +56,21 @@ pip install -e .
|
|||||||
cargo build
|
cargo build
|
||||||
```
|
```
|
||||||
|
|
||||||
#### Launch Router with Worker URLs in regular mode
|
#### Using the Rust Binary Directly (Alternative to Python)
|
||||||
|
```bash
|
||||||
|
# Build the Rust binary
|
||||||
|
cargo build --release
|
||||||
|
|
||||||
|
# Launch router with worker URLs in regular mode
|
||||||
|
./target/release/sglang-router \
|
||||||
|
--worker-urls http://worker1:8000 http://worker2:8000
|
||||||
|
|
||||||
|
# Or use cargo run
|
||||||
|
cargo run --release -- \
|
||||||
|
--worker-urls http://worker1:8000 http://worker2:8000
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Launch Router with Python (Original Method)
|
||||||
```bash
|
```bash
|
||||||
# Launch router with worker URLs
|
# Launch router with worker URLs
|
||||||
python -m sglang_router.launch_router \
|
python -m sglang_router.launch_router \
|
||||||
@@ -68,7 +82,22 @@ python -m sglang_router.launch_router \
|
|||||||
# Note that the prefill and decode URLs must be provided in the following format:
|
# Note that the prefill and decode URLs must be provided in the following format:
|
||||||
# http://<ip>:<port> for decode nodes
|
# http://<ip>:<port> for decode nodes
|
||||||
# http://<ip>:<port> bootstrap-port for prefill nodes, where bootstrap-port is optional
|
# http://<ip>:<port> bootstrap-port for prefill nodes, where bootstrap-port is optional
|
||||||
# Launch router with worker URLs
|
|
||||||
|
# Using Rust binary directly
|
||||||
|
./target/release/sglang-router \
|
||||||
|
--pd-disaggregation \
|
||||||
|
--policy cache_aware \
|
||||||
|
--prefill http://127.0.0.1:30001 9001 \
|
||||||
|
--prefill http://127.0.0.2:30002 9002 \
|
||||||
|
--prefill http://127.0.0.3:30003 9003 \
|
||||||
|
--prefill http://127.0.0.4:30004 9004 \
|
||||||
|
--decode http://127.0.0.5:30005 \
|
||||||
|
--decode http://127.0.0.6:30006 \
|
||||||
|
--decode http://127.0.0.7:30007 \
|
||||||
|
--host 0.0.0.0 \
|
||||||
|
--port 8080
|
||||||
|
|
||||||
|
# Or using Python launcher
|
||||||
python -m sglang_router.launch_router \
|
python -m sglang_router.launch_router \
|
||||||
--pd-disaggregation \
|
--pd-disaggregation \
|
||||||
--policy cache_aware \
|
--policy cache_aware \
|
||||||
|
|||||||
490
sgl-router/src/main.rs
Normal file
490
sgl-router/src/main.rs
Normal file
@@ -0,0 +1,490 @@
|
|||||||
|
use clap::{ArgAction, Parser};
|
||||||
|
use sglang_router_rs::config::{
|
||||||
|
CircuitBreakerConfig, ConfigError, ConfigResult, DiscoveryConfig, MetricsConfig, PolicyConfig,
|
||||||
|
RetryConfig, RouterConfig, RoutingMode,
|
||||||
|
};
|
||||||
|
use sglang_router_rs::metrics::PrometheusConfig;
|
||||||
|
use sglang_router_rs::server::{self, ServerConfig};
|
||||||
|
use sglang_router_rs::service_discovery::ServiceDiscoveryConfig;
|
||||||
|
use std::collections::HashMap;
|
||||||
|
|
||||||
|
/// Collects every `--prefill <url> [bootstrap_port]` group from the process arguments.
///
/// This is a hand-rolled scan over `std::env::args()`; `main` runs it before clap
/// parsing and strips the same tokens from the argument list it hands to clap.
///
/// Returns one `(url, bootstrap_port)` pair per `--prefill` occurrence, in
/// command-line order. See [`parse_prefill_entries`] for the exact token rules.
fn parse_prefill_args() -> Vec<(String, Option<u16>)> {
    let args: Vec<String> = std::env::args().collect();
    parse_prefill_entries(&args)
}

/// Pure scanning logic behind [`parse_prefill_args`], separated from the process
/// environment so it can be driven with an arbitrary argument slice.
///
/// For each `--prefill` that is followed by at least one more token:
/// * the next token is taken as the URL, whatever it looks like;
/// * if the token after the URL parses as a `u16`, it is consumed as the bootstrap port;
/// * if that token is `none` (ASCII case-insensitive), it is consumed and the port is `None`;
/// * otherwise the port is `None` and the token is left for the next iteration.
///
/// A trailing `--prefill` with no URL is ignored.
fn parse_prefill_entries(args: &[String]) -> Vec<(String, Option<u16>)> {
    let mut entries = Vec::new();
    let mut i = 0;

    while i < args.len() {
        let url = match args.get(i + 1) {
            Some(url) if args[i] == "--prefill" => url,
            _ => {
                i += 1;
                continue;
            }
        };

        // A token starting with "--" can neither parse as u16 nor equal "none",
        // so it naturally falls into the "not consumed" arm.
        let (bootstrap_port, consumed) = match args.get(i + 2) {
            Some(token) => match token.parse::<u16>() {
                Ok(port) => (Some(port), 3),
                Err(_) if token.eq_ignore_ascii_case("none") => (None, 3),
                Err(_) => (None, 2),
            },
            None => (None, 2),
        };

        entries.push((url.clone(), bootstrap_port));
        i += consumed;
    }

    entries
}
|
||||||
|
|
||||||
|
#[derive(Parser, Debug)]
|
||||||
|
#[command(name = "sglang-router")]
|
||||||
|
#[command(about = "SGLang Router - High-performance request distribution across worker nodes")]
|
||||||
|
#[command(long_about = r#"
|
||||||
|
SGLang Router - High-performance request distribution across worker nodes
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
This launcher enables starting a router with individual worker instances. It is useful for
|
||||||
|
multi-node setups or when you want to start workers and router separately.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
# Regular mode
|
||||||
|
sglang-router --worker-urls http://worker1:8000 http://worker2:8000
|
||||||
|
|
||||||
|
# PD disaggregated mode with same policy for both
|
||||||
|
sglang-router --pd-disaggregation \
|
||||||
|
--prefill http://127.0.0.1:30001 9001 \
|
||||||
|
--prefill http://127.0.0.2:30002 9002 \
|
||||||
|
--decode http://127.0.0.3:30003 \
|
||||||
|
--decode http://127.0.0.4:30004 \
|
||||||
|
--policy cache_aware
|
||||||
|
|
||||||
|
# PD mode with different policies for prefill and decode
|
||||||
|
sglang-router --pd-disaggregation \
|
||||||
|
--prefill http://127.0.0.1:30001 9001 \
|
||||||
|
--prefill http://127.0.0.2:30002 \
|
||||||
|
--decode http://127.0.0.3:30003 \
|
||||||
|
--decode http://127.0.0.4:30004 \
|
||||||
|
--prefill-policy cache_aware --decode-policy power_of_two
|
||||||
|
"#)]
|
||||||
|
struct CliArgs {
|
||||||
|
/// Host address to bind the router server
|
||||||
|
#[arg(long, default_value = "127.0.0.1")]
|
||||||
|
host: String,
|
||||||
|
|
||||||
|
/// Port number to bind the router server
|
||||||
|
#[arg(long, default_value_t = 30000)]
|
||||||
|
port: u16,
|
||||||
|
|
||||||
|
/// List of worker URLs (e.g., http://worker1:8000 http://worker2:8000)
|
||||||
|
#[arg(long, num_args = 0..)]
|
||||||
|
worker_urls: Vec<String>,
|
||||||
|
|
||||||
|
/// Load balancing policy to use
|
||||||
|
#[arg(long, default_value = "cache_aware", value_parser = ["random", "round_robin", "cache_aware", "power_of_two"])]
|
||||||
|
policy: String,
|
||||||
|
|
||||||
|
/// Enable PD (Prefill-Decode) disaggregated mode
|
||||||
|
#[arg(long, default_value_t = false)]
|
||||||
|
pd_disaggregation: bool,
|
||||||
|
|
||||||
|
/// Decode server URL (can be specified multiple times)
|
||||||
|
#[arg(long, action = ArgAction::Append)]
|
||||||
|
decode: Vec<String>,
|
||||||
|
|
||||||
|
/// Specific policy for prefill nodes in PD mode
|
||||||
|
#[arg(long, value_parser = ["random", "round_robin", "cache_aware", "power_of_two"])]
|
||||||
|
prefill_policy: Option<String>,
|
||||||
|
|
||||||
|
/// Specific policy for decode nodes in PD mode
|
||||||
|
#[arg(long, value_parser = ["random", "round_robin", "cache_aware", "power_of_two"])]
|
||||||
|
decode_policy: Option<String>,
|
||||||
|
|
||||||
|
/// Timeout in seconds for worker startup
|
||||||
|
#[arg(long, default_value_t = 300)]
|
||||||
|
worker_startup_timeout_secs: u64,
|
||||||
|
|
||||||
|
/// Interval in seconds between checks for worker startup
|
||||||
|
#[arg(long, default_value_t = 10)]
|
||||||
|
worker_startup_check_interval: u64,
|
||||||
|
|
||||||
|
/// Cache threshold (0.0-1.0) for cache-aware routing
|
||||||
|
#[arg(long, default_value_t = 0.5)]
|
||||||
|
cache_threshold: f32,
|
||||||
|
|
||||||
|
/// Absolute threshold for load balancing
|
||||||
|
#[arg(long, default_value_t = 32)]
|
||||||
|
balance_abs_threshold: usize,
|
||||||
|
|
||||||
|
/// Relative threshold for load balancing
|
||||||
|
#[arg(long, default_value_t = 1.0001)]
|
||||||
|
balance_rel_threshold: f32,
|
||||||
|
|
||||||
|
/// Interval in seconds between cache eviction operations
|
||||||
|
#[arg(long, default_value_t = 60)]
|
||||||
|
eviction_interval: u64,
|
||||||
|
|
||||||
|
/// Maximum size of the approximation tree for cache-aware routing
|
||||||
|
#[arg(long, default_value_t = 16777216)] // 2^24
|
||||||
|
max_tree_size: usize,
|
||||||
|
|
||||||
|
/// Maximum payload size in bytes
|
||||||
|
#[arg(long, default_value_t = 268435456)] // 256MB
|
||||||
|
max_payload_size: usize,
|
||||||
|
|
||||||
|
/// Enable data parallelism aware schedule
|
||||||
|
#[arg(long, default_value_t = false)]
|
||||||
|
dp_aware: bool,
|
||||||
|
|
||||||
|
/// API key for worker authorization
|
||||||
|
#[arg(long)]
|
||||||
|
api_key: Option<String>,
|
||||||
|
|
||||||
|
/// Directory to store log files
|
||||||
|
#[arg(long)]
|
||||||
|
log_dir: Option<String>,
|
||||||
|
|
||||||
|
/// Set the logging level
|
||||||
|
#[arg(long, default_value = "info", value_parser = ["debug", "info", "warn", "error"])]
|
||||||
|
log_level: String,
|
||||||
|
|
||||||
|
/// Enable Kubernetes service discovery
|
||||||
|
#[arg(long, default_value_t = false)]
|
||||||
|
service_discovery: bool,
|
||||||
|
|
||||||
|
/// Label selector for Kubernetes service discovery (format: key1=value1 key2=value2)
|
||||||
|
#[arg(long, num_args = 0..)]
|
||||||
|
selector: Vec<String>,
|
||||||
|
|
||||||
|
/// Port to use for discovered worker pods
|
||||||
|
#[arg(long, default_value_t = 80)]
|
||||||
|
service_discovery_port: u16,
|
||||||
|
|
||||||
|
/// Kubernetes namespace to watch for pods
|
||||||
|
#[arg(long)]
|
||||||
|
service_discovery_namespace: Option<String>,
|
||||||
|
|
||||||
|
/// Label selector for prefill server pods in PD mode
|
||||||
|
#[arg(long, num_args = 0..)]
|
||||||
|
prefill_selector: Vec<String>,
|
||||||
|
|
||||||
|
/// Label selector for decode server pods in PD mode
|
||||||
|
#[arg(long, num_args = 0..)]
|
||||||
|
decode_selector: Vec<String>,
|
||||||
|
|
||||||
|
/// Port to expose Prometheus metrics
|
||||||
|
#[arg(long, default_value_t = 29000)]
|
||||||
|
prometheus_port: u16,
|
||||||
|
|
||||||
|
/// Host address to bind the Prometheus metrics server
|
||||||
|
#[arg(long, default_value = "127.0.0.1")]
|
||||||
|
prometheus_host: String,
|
||||||
|
|
||||||
|
/// Custom HTTP headers to check for request IDs
|
||||||
|
#[arg(long, num_args = 0..)]
|
||||||
|
request_id_headers: Vec<String>,
|
||||||
|
|
||||||
|
/// Request timeout in seconds
|
||||||
|
#[arg(long, default_value_t = 600)]
|
||||||
|
request_timeout_secs: u64,
|
||||||
|
|
||||||
|
/// Maximum number of concurrent requests allowed
|
||||||
|
#[arg(long, default_value_t = 64)]
|
||||||
|
max_concurrent_requests: usize,
|
||||||
|
|
||||||
|
/// CORS allowed origins
|
||||||
|
#[arg(long, num_args = 0..)]
|
||||||
|
cors_allowed_origins: Vec<String>,
|
||||||
|
|
||||||
|
// Retry configuration
|
||||||
|
/// Maximum number of retries
|
||||||
|
#[arg(long, default_value_t = 3)]
|
||||||
|
retry_max_retries: u32,
|
||||||
|
|
||||||
|
/// Initial backoff in milliseconds for retries
|
||||||
|
#[arg(long, default_value_t = 100)]
|
||||||
|
retry_initial_backoff_ms: u64,
|
||||||
|
|
||||||
|
/// Maximum backoff in milliseconds for retries
|
||||||
|
#[arg(long, default_value_t = 10000)]
|
||||||
|
retry_max_backoff_ms: u64,
|
||||||
|
|
||||||
|
/// Backoff multiplier for exponential backoff
|
||||||
|
#[arg(long, default_value_t = 2.0)]
|
||||||
|
retry_backoff_multiplier: f32,
|
||||||
|
|
||||||
|
/// Jitter factor for retry backoff
|
||||||
|
#[arg(long, default_value_t = 0.1)]
|
||||||
|
retry_jitter_factor: f32,
|
||||||
|
|
||||||
|
/// Disable retries
|
||||||
|
#[arg(long, default_value_t = false)]
|
||||||
|
disable_retries: bool,
|
||||||
|
|
||||||
|
// Circuit breaker configuration
|
||||||
|
/// Number of failures before circuit breaker opens
|
||||||
|
#[arg(long, default_value_t = 5)]
|
||||||
|
cb_failure_threshold: u32,
|
||||||
|
|
||||||
|
/// Number of successes before circuit breaker closes
|
||||||
|
#[arg(long, default_value_t = 2)]
|
||||||
|
cb_success_threshold: u32,
|
||||||
|
|
||||||
|
/// Timeout duration in seconds for circuit breaker
|
||||||
|
#[arg(long, default_value_t = 30)]
|
||||||
|
cb_timeout_duration_secs: u64,
|
||||||
|
|
||||||
|
/// Window duration in seconds for circuit breaker
|
||||||
|
#[arg(long, default_value_t = 60)]
|
||||||
|
cb_window_duration_secs: u64,
|
||||||
|
|
||||||
|
/// Disable circuit breaker
|
||||||
|
#[arg(long, default_value_t = false)]
|
||||||
|
disable_circuit_breaker: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl CliArgs {
|
||||||
|
/// Turns `key=value` selector strings into a lookup map.
///
/// Each entry is split at its **first** `=`, so a value may itself contain `=`
/// (`"a=b=c"` yields key `a`, value `b=c`). Entries without any `=` are skipped
/// silently, and when a key occurs more than once the last occurrence wins.
fn parse_selector(selector_list: &[String]) -> HashMap<String, String> {
    selector_list
        .iter()
        .filter_map(|entry| entry.split_once('='))
        .map(|(key, value)| (key.to_string(), value.to_string()))
        .collect()
}
|
||||||
|
|
||||||
|
/// Convert policy string to PolicyConfig
|
||||||
|
fn parse_policy(&self, policy_str: &str) -> PolicyConfig {
|
||||||
|
match policy_str {
|
||||||
|
"random" => PolicyConfig::Random,
|
||||||
|
"round_robin" => PolicyConfig::RoundRobin,
|
||||||
|
"cache_aware" => PolicyConfig::CacheAware {
|
||||||
|
cache_threshold: self.cache_threshold,
|
||||||
|
balance_abs_threshold: self.balance_abs_threshold,
|
||||||
|
balance_rel_threshold: self.balance_rel_threshold,
|
||||||
|
eviction_interval_secs: self.eviction_interval,
|
||||||
|
max_tree_size: self.max_tree_size,
|
||||||
|
},
|
||||||
|
"power_of_two" => PolicyConfig::PowerOfTwo {
|
||||||
|
load_check_interval_secs: 5, // Default value
|
||||||
|
},
|
||||||
|
_ => PolicyConfig::RoundRobin, // Fallback
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Convert CLI arguments to RouterConfig
|
||||||
|
fn to_router_config(
|
||||||
|
&self,
|
||||||
|
prefill_urls: Vec<(String, Option<u16>)>,
|
||||||
|
) -> ConfigResult<RouterConfig> {
|
||||||
|
// Determine routing mode
|
||||||
|
let mode = if self.pd_disaggregation {
|
||||||
|
let decode_urls = self.decode.clone();
|
||||||
|
|
||||||
|
// Validate PD configuration if not using service discovery
|
||||||
|
if !self.service_discovery && (prefill_urls.is_empty() || decode_urls.is_empty()) {
|
||||||
|
return Err(ConfigError::ValidationFailed {
|
||||||
|
reason: "PD disaggregation mode requires --prefill and --decode URLs when not using service discovery".to_string(),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
RoutingMode::PrefillDecode {
|
||||||
|
prefill_urls,
|
||||||
|
decode_urls,
|
||||||
|
prefill_policy: self.prefill_policy.as_ref().map(|p| self.parse_policy(p)),
|
||||||
|
decode_policy: self.decode_policy.as_ref().map(|p| self.parse_policy(p)),
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Regular mode
|
||||||
|
if !self.service_discovery && self.worker_urls.is_empty() {
|
||||||
|
return Err(ConfigError::ValidationFailed {
|
||||||
|
reason: "Regular mode requires --worker-urls when not using service discovery"
|
||||||
|
.to_string(),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
RoutingMode::Regular {
|
||||||
|
worker_urls: self.worker_urls.clone(),
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Main policy
|
||||||
|
let policy = self.parse_policy(&self.policy);
|
||||||
|
|
||||||
|
// Service discovery configuration
|
||||||
|
let discovery = if self.service_discovery {
|
||||||
|
Some(DiscoveryConfig {
|
||||||
|
enabled: true,
|
||||||
|
namespace: self.service_discovery_namespace.clone(),
|
||||||
|
port: self.service_discovery_port,
|
||||||
|
check_interval_secs: 60,
|
||||||
|
selector: Self::parse_selector(&self.selector),
|
||||||
|
prefill_selector: Self::parse_selector(&self.prefill_selector),
|
||||||
|
decode_selector: Self::parse_selector(&self.decode_selector),
|
||||||
|
bootstrap_port_annotation: "sglang.ai/bootstrap-port".to_string(),
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
|
||||||
|
// Metrics configuration
|
||||||
|
let metrics = Some(MetricsConfig {
|
||||||
|
port: self.prometheus_port,
|
||||||
|
host: self.prometheus_host.clone(),
|
||||||
|
});
|
||||||
|
|
||||||
|
// Build RouterConfig
|
||||||
|
Ok(RouterConfig {
|
||||||
|
mode,
|
||||||
|
policy,
|
||||||
|
host: self.host.clone(),
|
||||||
|
port: self.port,
|
||||||
|
max_payload_size: self.max_payload_size,
|
||||||
|
request_timeout_secs: self.request_timeout_secs,
|
||||||
|
worker_startup_timeout_secs: self.worker_startup_timeout_secs,
|
||||||
|
worker_startup_check_interval_secs: self.worker_startup_check_interval,
|
||||||
|
dp_aware: self.dp_aware,
|
||||||
|
api_key: self.api_key.clone(),
|
||||||
|
discovery,
|
||||||
|
metrics,
|
||||||
|
log_dir: self.log_dir.clone(),
|
||||||
|
log_level: Some(self.log_level.clone()),
|
||||||
|
request_id_headers: if self.request_id_headers.is_empty() {
|
||||||
|
None
|
||||||
|
} else {
|
||||||
|
Some(self.request_id_headers.clone())
|
||||||
|
},
|
||||||
|
max_concurrent_requests: self.max_concurrent_requests,
|
||||||
|
cors_allowed_origins: self.cors_allowed_origins.clone(),
|
||||||
|
retry: RetryConfig {
|
||||||
|
max_retries: self.retry_max_retries,
|
||||||
|
initial_backoff_ms: self.retry_initial_backoff_ms,
|
||||||
|
max_backoff_ms: self.retry_max_backoff_ms,
|
||||||
|
backoff_multiplier: self.retry_backoff_multiplier,
|
||||||
|
jitter_factor: self.retry_jitter_factor,
|
||||||
|
},
|
||||||
|
circuit_breaker: CircuitBreakerConfig {
|
||||||
|
failure_threshold: self.cb_failure_threshold,
|
||||||
|
success_threshold: self.cb_success_threshold,
|
||||||
|
timeout_duration_secs: self.cb_timeout_duration_secs,
|
||||||
|
window_duration_secs: self.cb_window_duration_secs,
|
||||||
|
},
|
||||||
|
disable_retries: self.disable_retries,
|
||||||
|
disable_circuit_breaker: self.disable_circuit_breaker,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create ServerConfig from CLI args and RouterConfig
|
||||||
|
fn to_server_config(&self, router_config: RouterConfig) -> ServerConfig {
|
||||||
|
// Create service discovery config if enabled
|
||||||
|
let service_discovery_config = if self.service_discovery {
|
||||||
|
Some(ServiceDiscoveryConfig {
|
||||||
|
enabled: true,
|
||||||
|
selector: Self::parse_selector(&self.selector),
|
||||||
|
check_interval: std::time::Duration::from_secs(60),
|
||||||
|
port: self.service_discovery_port,
|
||||||
|
namespace: self.service_discovery_namespace.clone(),
|
||||||
|
pd_mode: self.pd_disaggregation,
|
||||||
|
prefill_selector: Self::parse_selector(&self.prefill_selector),
|
||||||
|
decode_selector: Self::parse_selector(&self.decode_selector),
|
||||||
|
bootstrap_port_annotation: "sglang.ai/bootstrap-port".to_string(),
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
|
||||||
|
// Create Prometheus config
|
||||||
|
let prometheus_config = Some(PrometheusConfig {
|
||||||
|
port: self.prometheus_port,
|
||||||
|
host: self.prometheus_host.clone(),
|
||||||
|
});
|
||||||
|
|
||||||
|
ServerConfig {
|
||||||
|
host: self.host.clone(),
|
||||||
|
port: self.port,
|
||||||
|
router_config,
|
||||||
|
max_payload_size: self.max_payload_size,
|
||||||
|
log_dir: self.log_dir.clone(),
|
||||||
|
log_level: Some(self.log_level.clone()),
|
||||||
|
service_discovery_config,
|
||||||
|
prometheus_config,
|
||||||
|
request_timeout_secs: self.request_timeout_secs,
|
||||||
|
request_id_headers: if self.request_id_headers.is_empty() {
|
||||||
|
None
|
||||||
|
} else {
|
||||||
|
Some(self.request_id_headers.clone())
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||||
|
// Parse prefill arguments manually before clap parsing
|
||||||
|
let prefill_urls = parse_prefill_args();
|
||||||
|
|
||||||
|
// Filter out prefill arguments and their values before passing to clap
|
||||||
|
let mut filtered_args: Vec<String> = Vec::new();
|
||||||
|
let raw_args: Vec<String> = std::env::args().collect();
|
||||||
|
let mut i = 0;
|
||||||
|
|
||||||
|
while i < raw_args.len() {
|
||||||
|
if raw_args[i] == "--prefill" && i + 1 < raw_args.len() {
|
||||||
|
// Skip --prefill and its URL
|
||||||
|
i += 2;
|
||||||
|
// Also skip bootstrap port if present
|
||||||
|
if i < raw_args.len() && !raw_args[i].starts_with("--") {
|
||||||
|
if raw_args[i].parse::<u16>().is_ok() || raw_args[i].to_lowercase() == "none" {
|
||||||
|
i += 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
filtered_args.push(raw_args[i].clone());
|
||||||
|
i += 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse CLI arguments with clap using filtered args
|
||||||
|
let cli_args = CliArgs::parse_from(filtered_args);
|
||||||
|
|
||||||
|
// Print startup info
|
||||||
|
println!("SGLang Router starting...");
|
||||||
|
println!("Host: {}:{}", cli_args.host, cli_args.port);
|
||||||
|
println!(
|
||||||
|
"Mode: {}",
|
||||||
|
if cli_args.pd_disaggregation {
|
||||||
|
"PD Disaggregated"
|
||||||
|
} else {
|
||||||
|
"Regular"
|
||||||
|
}
|
||||||
|
);
|
||||||
|
println!("Policy: {}", cli_args.policy);
|
||||||
|
|
||||||
|
if cli_args.pd_disaggregation && !prefill_urls.is_empty() {
|
||||||
|
println!("Prefill nodes: {:?}", prefill_urls);
|
||||||
|
println!("Decode nodes: {:?}", cli_args.decode);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert to RouterConfig
|
||||||
|
let router_config = cli_args.to_router_config(prefill_urls)?;
|
||||||
|
|
||||||
|
// Validate configuration
|
||||||
|
router_config.validate()?;
|
||||||
|
|
||||||
|
// Create ServerConfig
|
||||||
|
let server_config = cli_args.to_server_config(router_config);
|
||||||
|
|
||||||
|
// Create a new runtime for the server (like Python binding does)
|
||||||
|
let runtime = tokio::runtime::Runtime::new()?;
|
||||||
|
|
||||||
|
// Block on the async startup function
|
||||||
|
runtime.block_on(async move { server::startup(server_config).await })?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
@@ -11,29 +11,32 @@ pub struct RouterFactory;
|
|||||||
|
|
||||||
impl RouterFactory {
|
impl RouterFactory {
|
||||||
/// Create a router instance from application context
|
/// Create a router instance from application context
|
||||||
pub fn create_router(ctx: &Arc<AppContext>) -> Result<Box<dyn RouterTrait>, String> {
|
pub async fn create_router(ctx: &Arc<AppContext>) -> Result<Box<dyn RouterTrait>, String> {
|
||||||
match &ctx.router_config.mode {
|
match &ctx.router_config.mode {
|
||||||
RoutingMode::Regular { worker_urls } => {
|
RoutingMode::Regular { worker_urls } => {
|
||||||
Self::create_regular_router(worker_urls, &ctx.router_config.policy, ctx)
|
Self::create_regular_router(worker_urls, &ctx.router_config.policy, ctx).await
|
||||||
}
|
}
|
||||||
RoutingMode::PrefillDecode {
|
RoutingMode::PrefillDecode {
|
||||||
prefill_urls,
|
prefill_urls,
|
||||||
decode_urls,
|
decode_urls,
|
||||||
prefill_policy,
|
prefill_policy,
|
||||||
decode_policy,
|
decode_policy,
|
||||||
} => Self::create_pd_router(
|
} => {
|
||||||
prefill_urls,
|
Self::create_pd_router(
|
||||||
decode_urls,
|
prefill_urls,
|
||||||
prefill_policy.as_ref(),
|
decode_urls,
|
||||||
decode_policy.as_ref(),
|
prefill_policy.as_ref(),
|
||||||
&ctx.router_config.policy,
|
decode_policy.as_ref(),
|
||||||
ctx,
|
&ctx.router_config.policy,
|
||||||
),
|
ctx,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Create a regular router with injected policy
|
/// Create a regular router with injected policy
|
||||||
fn create_regular_router(
|
async fn create_regular_router(
|
||||||
worker_urls: &[String],
|
worker_urls: &[String],
|
||||||
policy_config: &PolicyConfig,
|
policy_config: &PolicyConfig,
|
||||||
ctx: &Arc<AppContext>,
|
ctx: &Arc<AppContext>,
|
||||||
@@ -52,13 +55,14 @@ impl RouterFactory {
|
|||||||
ctx.router_config.api_key.clone(),
|
ctx.router_config.api_key.clone(),
|
||||||
ctx.router_config.retry.clone(),
|
ctx.router_config.retry.clone(),
|
||||||
ctx.router_config.circuit_breaker.clone(),
|
ctx.router_config.circuit_breaker.clone(),
|
||||||
)?;
|
)
|
||||||
|
.await?;
|
||||||
|
|
||||||
Ok(Box::new(router))
|
Ok(Box::new(router))
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Create a PD router with injected policy
|
/// Create a PD router with injected policy
|
||||||
fn create_pd_router(
|
async fn create_pd_router(
|
||||||
prefill_urls: &[(String, Option<u16>)],
|
prefill_urls: &[(String, Option<u16>)],
|
||||||
decode_urls: &[String],
|
decode_urls: &[String],
|
||||||
prefill_policy_config: Option<&PolicyConfig>,
|
prefill_policy_config: Option<&PolicyConfig>,
|
||||||
@@ -83,7 +87,8 @@ impl RouterFactory {
|
|||||||
ctx.router_config.worker_startup_check_interval_secs,
|
ctx.router_config.worker_startup_check_interval_secs,
|
||||||
ctx.router_config.retry.clone(),
|
ctx.router_config.retry.clone(),
|
||||||
ctx.router_config.circuit_breaker.clone(),
|
ctx.router_config.circuit_breaker.clone(),
|
||||||
)?;
|
)
|
||||||
|
.await?;
|
||||||
|
|
||||||
Ok(Box::new(router))
|
Ok(Box::new(router))
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -67,6 +67,7 @@ impl PDRouter {
|
|||||||
self.timeout_secs,
|
self.timeout_secs,
|
||||||
self.interval_secs,
|
self.interval_secs,
|
||||||
)
|
)
|
||||||
|
.await
|
||||||
.map_err(|_| PDRouterError::HealthCheckFailed {
|
.map_err(|_| PDRouterError::HealthCheckFailed {
|
||||||
url: url.to_string(),
|
url: url.to_string(),
|
||||||
})
|
})
|
||||||
@@ -349,7 +350,7 @@ impl PDRouter {
|
|||||||
Ok(format!("Successfully removed decode server: {}", url))
|
Ok(format!("Successfully removed decode server: {}", url))
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn new(
|
pub async fn new(
|
||||||
prefill_urls: Vec<(String, Option<u16>)>,
|
prefill_urls: Vec<(String, Option<u16>)>,
|
||||||
decode_urls: Vec<String>,
|
decode_urls: Vec<String>,
|
||||||
prefill_policy: Arc<dyn LoadBalancingPolicy>,
|
prefill_policy: Arc<dyn LoadBalancingPolicy>,
|
||||||
@@ -392,7 +393,8 @@ impl PDRouter {
|
|||||||
&all_urls,
|
&all_urls,
|
||||||
timeout_secs,
|
timeout_secs,
|
||||||
interval_secs,
|
interval_secs,
|
||||||
)?;
|
)
|
||||||
|
.await?;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Initialize cache-aware policies with workers
|
// Initialize cache-aware policies with workers
|
||||||
|
|||||||
@@ -17,7 +17,6 @@ use futures_util::StreamExt;
|
|||||||
use reqwest::Client;
|
use reqwest::Client;
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::sync::{Arc, RwLock};
|
use std::sync::{Arc, RwLock};
|
||||||
use std::thread;
|
|
||||||
use std::time::{Duration, Instant};
|
use std::time::{Duration, Instant};
|
||||||
use tokio_stream::wrappers::UnboundedReceiverStream;
|
use tokio_stream::wrappers::UnboundedReceiverStream;
|
||||||
use tracing::{debug, error, info, warn};
|
use tracing::{debug, error, info, warn};
|
||||||
@@ -52,7 +51,7 @@ pub struct Router {
|
|||||||
|
|
||||||
impl Router {
|
impl Router {
|
||||||
/// Create a new router with injected policy and client
|
/// Create a new router with injected policy and client
|
||||||
pub fn new(
|
pub async fn new(
|
||||||
worker_urls: Vec<String>,
|
worker_urls: Vec<String>,
|
||||||
policy: Arc<dyn LoadBalancingPolicy>,
|
policy: Arc<dyn LoadBalancingPolicy>,
|
||||||
client: Client,
|
client: Client,
|
||||||
@@ -68,7 +67,7 @@ impl Router {
|
|||||||
|
|
||||||
// Wait for workers to be healthy (skip if empty - for service discovery mode)
|
// Wait for workers to be healthy (skip if empty - for service discovery mode)
|
||||||
if !worker_urls.is_empty() {
|
if !worker_urls.is_empty() {
|
||||||
Self::wait_for_healthy_workers(&worker_urls, timeout_secs, interval_secs)?;
|
Self::wait_for_healthy_workers(&worker_urls, timeout_secs, interval_secs).await?;
|
||||||
}
|
}
|
||||||
|
|
||||||
let worker_urls = if dp_aware {
|
let worker_urls = if dp_aware {
|
||||||
@@ -156,7 +155,7 @@ impl Router {
|
|||||||
.collect()
|
.collect()
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn wait_for_healthy_workers(
|
pub async fn wait_for_healthy_workers(
|
||||||
worker_urls: &[String],
|
worker_urls: &[String],
|
||||||
timeout_secs: u64,
|
timeout_secs: u64,
|
||||||
interval_secs: u64,
|
interval_secs: u64,
|
||||||
@@ -167,9 +166,24 @@ impl Router {
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Perform health check asynchronously
|
||||||
|
Self::wait_for_healthy_workers_async(worker_urls, timeout_secs, interval_secs).await
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn wait_for_healthy_workers_async(
|
||||||
|
worker_urls: &[String],
|
||||||
|
timeout_secs: u64,
|
||||||
|
interval_secs: u64,
|
||||||
|
) -> Result<(), String> {
|
||||||
|
info!(
|
||||||
|
"Waiting for {} workers to become healthy (timeout: {}s)",
|
||||||
|
worker_urls.len(),
|
||||||
|
timeout_secs
|
||||||
|
);
|
||||||
|
|
||||||
let start_time = std::time::Instant::now();
|
let start_time = std::time::Instant::now();
|
||||||
let sync_client = reqwest::blocking::Client::builder()
|
let client = reqwest::Client::builder()
|
||||||
.timeout(Duration::from_secs(timeout_secs))
|
.timeout(Duration::from_secs(2))
|
||||||
.build()
|
.build()
|
||||||
.map_err(|e| format!("Failed to create HTTP client: {}", e))?;
|
.map_err(|e| format!("Failed to create HTTP client: {}", e))?;
|
||||||
|
|
||||||
@@ -185,20 +199,48 @@ impl Router {
|
|||||||
));
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Perform all health checks concurrently
|
||||||
|
let mut health_checks = Vec::new();
|
||||||
|
for url in worker_urls {
|
||||||
|
let client_clone = client.clone();
|
||||||
|
let url_clone = url.clone();
|
||||||
|
|
||||||
|
let check_health = tokio::spawn(async move {
|
||||||
|
let health_url = format!("{}/health", url_clone);
|
||||||
|
match client_clone.get(&health_url).send().await {
|
||||||
|
Ok(res) => {
|
||||||
|
if res.status().is_success() {
|
||||||
|
None
|
||||||
|
} else {
|
||||||
|
Some((url_clone, format!("status: {}", res.status())))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(_) => Some((url_clone, "not ready".to_string())),
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
health_checks.push(check_health);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait for all health checks to complete
|
||||||
|
let results = futures::future::join_all(health_checks).await;
|
||||||
|
|
||||||
let mut all_healthy = true;
|
let mut all_healthy = true;
|
||||||
let mut unhealthy_workers = Vec::new();
|
let mut unhealthy_workers = Vec::new();
|
||||||
|
|
||||||
for url in worker_urls {
|
for result in results {
|
||||||
match sync_client.get(&format!("{}/health", url)).send() {
|
match result {
|
||||||
Ok(res) => {
|
Ok(None) => {
|
||||||
if !res.status().is_success() {
|
// Worker is healthy
|
||||||
all_healthy = false;
|
|
||||||
unhealthy_workers.push((url, format!("status: {}", res.status())));
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
Err(_) => {
|
Ok(Some((url, reason))) => {
|
||||||
all_healthy = false;
|
all_healthy = false;
|
||||||
unhealthy_workers.push((url, "not ready".to_string()));
|
unhealthy_workers.push((url, reason));
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
all_healthy = false;
|
||||||
|
unhealthy_workers
|
||||||
|
.push(("unknown".to_string(), format!("task error: {}", e)));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -208,11 +250,12 @@ impl Router {
|
|||||||
return Ok(());
|
return Ok(());
|
||||||
} else {
|
} else {
|
||||||
debug!(
|
debug!(
|
||||||
"Waiting for {} workers to become healthy ({} unhealthy)",
|
"Waiting for {} workers to become healthy ({} unhealthy: {:?})",
|
||||||
worker_urls.len(),
|
worker_urls.len(),
|
||||||
unhealthy_workers.len()
|
unhealthy_workers.len(),
|
||||||
|
unhealthy_workers
|
||||||
);
|
);
|
||||||
thread::sleep(Duration::from_secs(interval_secs));
|
tokio::time::sleep(Duration::from_secs(interval_secs)).await;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -1246,19 +1289,19 @@ mod tests {
|
|||||||
assert_eq!(result.unwrap(), "http://worker1:8080");
|
assert_eq!(result.unwrap(), "http://worker1:8080");
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[tokio::test]
|
||||||
fn test_wait_for_healthy_workers_empty_list() {
|
async fn test_wait_for_healthy_workers_empty_list() {
|
||||||
// Empty list will timeout as there are no workers to check
|
// Empty list will return error immediately
|
||||||
let result = Router::wait_for_healthy_workers(&[], 1, 1);
|
let result = Router::wait_for_healthy_workers(&[], 1, 1).await;
|
||||||
assert!(result.is_err());
|
assert!(result.is_err());
|
||||||
assert!(result.unwrap_err().contains("Timeout"));
|
assert!(result.unwrap_err().contains("no workers provided"));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[tokio::test]
|
||||||
fn test_wait_for_healthy_workers_invalid_urls() {
|
async fn test_wait_for_healthy_workers_invalid_urls() {
|
||||||
// This test will timeout quickly since the URLs are invalid
|
// This test will timeout quickly since the URLs are invalid
|
||||||
let result =
|
let result =
|
||||||
Router::wait_for_healthy_workers(&["http://nonexistent:8080".to_string()], 1, 1);
|
Router::wait_for_healthy_workers(&["http://nonexistent:8080".to_string()], 1, 1).await;
|
||||||
assert!(result.is_err());
|
assert!(result.is_err());
|
||||||
assert!(result.unwrap_err().contains("Timeout"));
|
assert!(result.unwrap_err().contains("Timeout"));
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -285,7 +285,7 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box<dyn std::error::Err
|
|||||||
));
|
));
|
||||||
|
|
||||||
// Create router with the context
|
// Create router with the context
|
||||||
let router = RouterFactory::create_router(&app_context)?;
|
let router = RouterFactory::create_router(&app_context).await?;
|
||||||
|
|
||||||
// Create app state with router and context
|
// Create app state with router and context
|
||||||
let app_state = Arc::new(AppState {
|
let app_state = Arc::new(AppState {
|
||||||
|
|||||||
@@ -576,7 +576,7 @@ mod tests {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Helper to create a Router instance for testing event handlers
|
// Helper to create a Router instance for testing event handlers
|
||||||
fn create_test_router() -> Arc<dyn RouterTrait> {
|
async fn create_test_router() -> Arc<dyn RouterTrait> {
|
||||||
use crate::config::PolicyConfig;
|
use crate::config::PolicyConfig;
|
||||||
use crate::policies::PolicyFactory;
|
use crate::policies::PolicyFactory;
|
||||||
use crate::routers::router::Router;
|
use crate::routers::router::Router;
|
||||||
@@ -593,6 +593,7 @@ mod tests {
|
|||||||
crate::config::types::RetryConfig::default(),
|
crate::config::types::RetryConfig::default(),
|
||||||
crate::config::types::CircuitBreakerConfig::default(),
|
crate::config::types::CircuitBreakerConfig::default(),
|
||||||
)
|
)
|
||||||
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
Arc::new(router) as Arc<dyn RouterTrait>
|
Arc::new(router) as Arc<dyn RouterTrait>
|
||||||
}
|
}
|
||||||
@@ -896,7 +897,7 @@ mod tests {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_handle_pod_event_add_unhealthy_pod() {
|
async fn test_handle_pod_event_add_unhealthy_pod() {
|
||||||
let router = create_test_router();
|
let router = create_test_router().await;
|
||||||
let tracked_pods = Arc::new(Mutex::new(HashSet::new()));
|
let tracked_pods = Arc::new(Mutex::new(HashSet::new()));
|
||||||
let pod_info = PodInfo {
|
let pod_info = PodInfo {
|
||||||
name: "pod1".into(),
|
name: "pod1".into(),
|
||||||
@@ -925,7 +926,7 @@ mod tests {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_handle_pod_deletion_non_existing_pod() {
|
async fn test_handle_pod_deletion_non_existing_pod() {
|
||||||
let router = create_test_router();
|
let router = create_test_router().await;
|
||||||
let tracked_pods = Arc::new(Mutex::new(HashSet::new()));
|
let tracked_pods = Arc::new(Mutex::new(HashSet::new()));
|
||||||
let pod_info = PodInfo {
|
let pod_info = PodInfo {
|
||||||
name: "pod1".into(),
|
name: "pod1".into(),
|
||||||
@@ -952,7 +953,7 @@ mod tests {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_handle_pd_pod_event_prefill_pod() {
|
async fn test_handle_pd_pod_event_prefill_pod() {
|
||||||
let router = create_test_router();
|
let router = create_test_router().await;
|
||||||
let tracked_pods = Arc::new(Mutex::new(HashSet::new()));
|
let tracked_pods = Arc::new(Mutex::new(HashSet::new()));
|
||||||
let pod_info = PodInfo {
|
let pod_info = PodInfo {
|
||||||
name: "prefill-pod".into(),
|
name: "prefill-pod".into(),
|
||||||
@@ -981,7 +982,7 @@ mod tests {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_handle_pd_pod_event_decode_pod() {
|
async fn test_handle_pd_pod_event_decode_pod() {
|
||||||
let router = create_test_router();
|
let router = create_test_router().await;
|
||||||
let tracked_pods = Arc::new(Mutex::new(HashSet::new()));
|
let tracked_pods = Arc::new(Mutex::new(HashSet::new()));
|
||||||
let pod_info = PodInfo {
|
let pod_info = PodInfo {
|
||||||
name: "decode-pod".into(),
|
name: "decode-pod".into(),
|
||||||
@@ -1008,7 +1009,7 @@ mod tests {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_handle_pd_pod_deletion_tracked_pod() {
|
async fn test_handle_pd_pod_deletion_tracked_pod() {
|
||||||
let router = create_test_router();
|
let router = create_test_router().await;
|
||||||
let tracked_pods = Arc::new(Mutex::new(HashSet::new()));
|
let tracked_pods = Arc::new(Mutex::new(HashSet::new()));
|
||||||
let pod_info = PodInfo {
|
let pod_info = PodInfo {
|
||||||
name: "test-pod".into(),
|
name: "test-pod".into(),
|
||||||
@@ -1042,7 +1043,7 @@ mod tests {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_handle_pd_pod_deletion_untracked_pod() {
|
async fn test_handle_pd_pod_deletion_untracked_pod() {
|
||||||
let router = create_test_router();
|
let router = create_test_router().await;
|
||||||
let tracked_pods = Arc::new(Mutex::new(HashSet::new()));
|
let tracked_pods = Arc::new(Mutex::new(HashSet::new()));
|
||||||
let pod_info = PodInfo {
|
let pod_info = PodInfo {
|
||||||
name: "untracked-pod".into(),
|
name: "untracked-pod".into(),
|
||||||
@@ -1071,7 +1072,7 @@ mod tests {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_unified_handler_regular_mode() {
|
async fn test_unified_handler_regular_mode() {
|
||||||
let router = create_test_router();
|
let router = create_test_router().await;
|
||||||
let tracked_pods = Arc::new(Mutex::new(HashSet::new()));
|
let tracked_pods = Arc::new(Mutex::new(HashSet::new()));
|
||||||
let pod_info = PodInfo {
|
let pod_info = PodInfo {
|
||||||
name: "regular-pod".into(),
|
name: "regular-pod".into(),
|
||||||
@@ -1099,7 +1100,7 @@ mod tests {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_unified_handler_pd_mode_with_prefill() {
|
async fn test_unified_handler_pd_mode_with_prefill() {
|
||||||
let router = create_test_router();
|
let router = create_test_router().await;
|
||||||
let tracked_pods = Arc::new(Mutex::new(HashSet::new()));
|
let tracked_pods = Arc::new(Mutex::new(HashSet::new()));
|
||||||
let pod_info = PodInfo {
|
let pod_info = PodInfo {
|
||||||
name: "prefill-pod".into(),
|
name: "prefill-pod".into(),
|
||||||
@@ -1127,7 +1128,7 @@ mod tests {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_unified_handler_deletion_with_pd_mode() {
|
async fn test_unified_handler_deletion_with_pd_mode() {
|
||||||
let router = create_test_router();
|
let router = create_test_router().await;
|
||||||
let tracked_pods = Arc::new(Mutex::new(HashSet::new()));
|
let tracked_pods = Arc::new(Mutex::new(HashSet::new()));
|
||||||
let pod_info = PodInfo {
|
let pod_info = PodInfo {
|
||||||
name: "decode-pod".into(),
|
name: "decode-pod".into(),
|
||||||
|
|||||||
@@ -92,12 +92,8 @@ impl TestContext {
|
|||||||
// Create app context
|
// Create app context
|
||||||
let app_context = common::create_test_context(config.clone());
|
let app_context = common::create_test_context(config.clone());
|
||||||
|
|
||||||
// Create router using sync factory in a blocking context
|
// Create router
|
||||||
let router =
|
let router = RouterFactory::create_router(&app_context).await.unwrap();
|
||||||
tokio::task::spawn_blocking(move || RouterFactory::create_router(&app_context))
|
|
||||||
.await
|
|
||||||
.unwrap()
|
|
||||||
.unwrap();
|
|
||||||
let router = Arc::from(router);
|
let router = Arc::from(router);
|
||||||
|
|
||||||
// Wait for router to discover workers
|
// Wait for router to discover workers
|
||||||
@@ -1451,10 +1447,7 @@ mod pd_mode_tests {
|
|||||||
let app_context = common::create_test_context(config);
|
let app_context = common::create_test_context(config);
|
||||||
|
|
||||||
// Create router - this might fail due to health check issues
|
// Create router - this might fail due to health check issues
|
||||||
let router_result =
|
let router_result = RouterFactory::create_router(&app_context).await;
|
||||||
tokio::task::spawn_blocking(move || RouterFactory::create_router(&app_context))
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
// Clean up workers
|
// Clean up workers
|
||||||
prefill_worker.stop().await;
|
prefill_worker.stop().await;
|
||||||
|
|||||||
@@ -60,11 +60,7 @@ impl TestContext {
|
|||||||
config.mode = RoutingMode::Regular { worker_urls };
|
config.mode = RoutingMode::Regular { worker_urls };
|
||||||
|
|
||||||
let app_context = common::create_test_context(config);
|
let app_context = common::create_test_context(config);
|
||||||
let router =
|
let router = RouterFactory::create_router(&app_context).await.unwrap();
|
||||||
tokio::task::spawn_blocking(move || RouterFactory::create_router(&app_context))
|
|
||||||
.await
|
|
||||||
.unwrap()
|
|
||||||
.unwrap();
|
|
||||||
let router = Arc::from(router);
|
let router = Arc::from(router);
|
||||||
|
|
||||||
if !workers.is_empty() {
|
if !workers.is_empty() {
|
||||||
|
|||||||
@@ -61,11 +61,7 @@ impl TestContext {
|
|||||||
config.mode = RoutingMode::Regular { worker_urls };
|
config.mode = RoutingMode::Regular { worker_urls };
|
||||||
|
|
||||||
let app_context = common::create_test_context(config);
|
let app_context = common::create_test_context(config);
|
||||||
let router =
|
let router = RouterFactory::create_router(&app_context).await.unwrap();
|
||||||
tokio::task::spawn_blocking(move || RouterFactory::create_router(&app_context))
|
|
||||||
.await
|
|
||||||
.unwrap()
|
|
||||||
.unwrap();
|
|
||||||
let router = Arc::from(router);
|
let router = Arc::from(router);
|
||||||
|
|
||||||
if !workers.is_empty() {
|
if !workers.is_empty() {
|
||||||
|
|||||||
@@ -109,8 +109,8 @@ mod test_pd_routing {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[tokio::test]
|
||||||
fn test_pd_router_configuration() {
|
async fn test_pd_router_configuration() {
|
||||||
// Test PD router configuration with various policies
|
// Test PD router configuration with various policies
|
||||||
// In the new structure, RoutingMode and PolicyConfig are separate
|
// In the new structure, RoutingMode and PolicyConfig are separate
|
||||||
let test_cases = vec![
|
let test_cases = vec![
|
||||||
@@ -190,7 +190,7 @@ mod test_pd_routing {
|
|||||||
let app_context =
|
let app_context =
|
||||||
sglang_router_rs::server::AppContext::new(config, reqwest::Client::new(), 64);
|
sglang_router_rs::server::AppContext::new(config, reqwest::Client::new(), 64);
|
||||||
let app_context = std::sync::Arc::new(app_context);
|
let app_context = std::sync::Arc::new(app_context);
|
||||||
let result = RouterFactory::create_router(&app_context);
|
let result = RouterFactory::create_router(&app_context).await;
|
||||||
assert!(result.is_err());
|
assert!(result.is_err());
|
||||||
let error_msg = result.unwrap_err();
|
let error_msg = result.unwrap_err();
|
||||||
// Error should be about health/timeout, not configuration
|
// Error should be about health/timeout, not configuration
|
||||||
|
|||||||
Reference in New Issue
Block a user