[router] remove old/oudated/useless comments (#10967)

This commit is contained in:
Simo Lin
2025-09-26 12:45:15 -04:00
committed by GitHub
parent be059b83d6
commit a7fe6e10a1
5 changed files with 28 additions and 306 deletions

View File

@@ -9,7 +9,6 @@ use sglang_router_rs::server::{self, ServerConfig};
use sglang_router_rs::service_discovery::ServiceDiscoveryConfig;
use std::collections::HashMap;
// Helper function to parse prefill arguments from command line
fn parse_prefill_args() -> Vec<(String, Option<u16>)> {
let args: Vec<String> = std::env::args().collect();
let mut prefill_entries = Vec::new();
@@ -19,12 +18,11 @@ fn parse_prefill_args() -> Vec<(String, Option<u16>)> {
if args[i] == "--prefill" && i + 1 < args.len() {
let url = args[i + 1].clone();
let bootstrap_port = if i + 2 < args.len() && !args[i + 2].starts_with("--") {
// Check if next arg is a port number
if let Ok(port) = args[i + 2].parse::<u16>() {
i += 1; // Skip the port argument
i += 1;
Some(port)
} else if args[i + 2].to_lowercase() == "none" {
i += 1; // Skip the "none" argument
i += 1;
None
} else {
None
@@ -33,7 +31,7 @@ fn parse_prefill_args() -> Vec<(String, Option<u16>)> {
None
};
prefill_entries.push((url, bootstrap_port));
i += 2; // Skip --prefill and URL
i += 2;
} else {
i += 1;
}
@@ -101,252 +99,186 @@ Examples:
"#)]
struct CliArgs {
/// Host address to bind the router server
#[arg(long, default_value = "127.0.0.1")]
host: String,
/// Port number to bind the router server
#[arg(long, default_value_t = 30000)]
port: u16,
/// List of worker URLs (e.g., http://worker1:8000 http://worker2:8000)
#[arg(long, num_args = 0..)]
worker_urls: Vec<String>,
/// Load balancing policy to use
#[arg(long, default_value = "cache_aware", value_parser = ["random", "round_robin", "cache_aware", "power_of_two"])]
policy: String,
/// Enable PD (Prefill-Decode) disaggregated mode
#[arg(long, default_value_t = false)]
pd_disaggregation: bool,
/// Decode server URL (can be specified multiple times)
#[arg(long, action = ArgAction::Append)]
decode: Vec<String>,
/// Specific policy for prefill nodes in PD mode
#[arg(long, value_parser = ["random", "round_robin", "cache_aware", "power_of_two"])]
prefill_policy: Option<String>,
/// Specific policy for decode nodes in PD mode
#[arg(long, value_parser = ["random", "round_robin", "cache_aware", "power_of_two"])]
decode_policy: Option<String>,
/// Timeout in seconds for worker startup
#[arg(long, default_value_t = 600)]
worker_startup_timeout_secs: u64,
/// Interval in seconds between checks for worker startup
#[arg(long, default_value_t = 30)]
worker_startup_check_interval: u64,
/// Cache threshold (0.0-1.0) for cache-aware routing
#[arg(long, default_value_t = 0.3)]
cache_threshold: f32,
/// Absolute threshold for load balancing
#[arg(long, default_value_t = 64)]
balance_abs_threshold: usize,
/// Relative threshold for load balancing
#[arg(long, default_value_t = 1.5)]
balance_rel_threshold: f32,
/// Interval in seconds between cache eviction operations
#[arg(long, default_value_t = 120)]
eviction_interval: u64,
/// Maximum size of the approximation tree for cache-aware routing
#[arg(long, default_value_t = 67108864)] // 2^26
#[arg(long, default_value_t = 67108864)]
max_tree_size: usize,
/// Maximum payload size in bytes
#[arg(long, default_value_t = 536870912)] // 512MB
#[arg(long, default_value_t = 536870912)]
max_payload_size: usize,
/// Enable data parallelism aware schedule
#[arg(long, default_value_t = false)]
dp_aware: bool,
/// API key for worker authorization
#[arg(long)]
api_key: Option<String>,
/// Backend to route requests to (sglang, vllm, trtllm, openai, anthropic)
#[arg(long, value_enum, default_value_t = Backend::Sglang, alias = "runtime")]
backend: Backend,
/// Directory to store log files
#[arg(long)]
log_dir: Option<String>,
/// Set the logging level
#[arg(long, default_value = "info", value_parser = ["debug", "info", "warn", "error"])]
log_level: String,
/// Enable Kubernetes service discovery
#[arg(long, default_value_t = false)]
service_discovery: bool,
/// Label selector for Kubernetes service discovery (format: key1=value1 key2=value2)
#[arg(long, num_args = 0..)]
selector: Vec<String>,
/// Port to use for discovered worker pods
#[arg(long, default_value_t = 80)]
service_discovery_port: u16,
/// Kubernetes namespace to watch for pods
#[arg(long)]
service_discovery_namespace: Option<String>,
/// Label selector for prefill server pods in PD mode
#[arg(long, num_args = 0..)]
prefill_selector: Vec<String>,
/// Label selector for decode server pods in PD mode
#[arg(long, num_args = 0..)]
decode_selector: Vec<String>,
/// Port to expose Prometheus metrics
#[arg(long, default_value_t = 29000)]
prometheus_port: u16,
/// Host address to bind the Prometheus metrics server
#[arg(long, default_value = "127.0.0.1")]
prometheus_host: String,
/// Custom HTTP headers to check for request IDs
#[arg(long, num_args = 0..)]
request_id_headers: Vec<String>,
/// Request timeout in seconds
#[arg(long, default_value_t = 1800)]
request_timeout_secs: u64,
/// Maximum number of concurrent requests allowed
#[arg(long, default_value_t = 256)]
max_concurrent_requests: usize,
/// CORS allowed origins
#[arg(long, num_args = 0..)]
cors_allowed_origins: Vec<String>,
// Retry configuration
/// Maximum number of retries
#[arg(long, default_value_t = 5)]
retry_max_retries: u32,
/// Initial backoff in milliseconds for retries
#[arg(long, default_value_t = 50)]
retry_initial_backoff_ms: u64,
/// Maximum backoff in milliseconds for retries
#[arg(long, default_value_t = 30000)]
retry_max_backoff_ms: u64,
/// Backoff multiplier for exponential backoff
#[arg(long, default_value_t = 1.5)]
retry_backoff_multiplier: f32,
/// Jitter factor for retry backoff
#[arg(long, default_value_t = 0.2)]
retry_jitter_factor: f32,
/// Disable retries
#[arg(long, default_value_t = false)]
disable_retries: bool,
// Circuit breaker configuration
/// Number of failures before circuit breaker opens
#[arg(long, default_value_t = 10)]
cb_failure_threshold: u32,
/// Number of successes before circuit breaker closes
#[arg(long, default_value_t = 3)]
cb_success_threshold: u32,
/// Timeout duration in seconds for circuit breaker
#[arg(long, default_value_t = 60)]
cb_timeout_duration_secs: u64,
/// Window duration in seconds for circuit breaker
#[arg(long, default_value_t = 120)]
cb_window_duration_secs: u64,
/// Disable circuit breaker
#[arg(long, default_value_t = false)]
disable_circuit_breaker: bool,
// Health check configuration
/// Number of consecutive health check failures before marking worker unhealthy
#[arg(long, default_value_t = 3)]
health_failure_threshold: u32,
/// Number of consecutive health check successes before marking worker healthy
#[arg(long, default_value_t = 2)]
health_success_threshold: u32,
/// Timeout in seconds for health check requests
#[arg(long, default_value_t = 5)]
health_check_timeout_secs: u64,
/// Interval in seconds between runtime health checks
#[arg(long, default_value_t = 60)]
health_check_interval_secs: u64,
/// Health check endpoint path
#[arg(long, default_value = "/health")]
health_check_endpoint: String,
// IGW (Inference Gateway) configuration
/// Enable Inference Gateway mode
#[arg(long, default_value_t = false)]
enable_igw: bool,
// Tokenizer configuration
/// Model path for loading tokenizer (HuggingFace model ID or local path)
#[arg(long)]
model_path: Option<String>,
/// Explicit tokenizer path (overrides model_path tokenizer if provided)
#[arg(long)]
tokenizer_path: Option<String>,
/// History backend configuration (memory, none, or oracle)
#[arg(long, default_value = "memory", value_parser = ["memory", "none", "oracle"])]
history_backend: String,
/// Directory containing the Oracle ATP wallet/config files (optional)
#[arg(long, env = "ATP_WALLET_PATH")]
oracle_wallet_path: Option<String>,
/// Wallet TNS alias to use (e.g. `<db_name>_low`)
#[arg(long, env = "ATP_TNS_ALIAS")]
oracle_tns_alias: Option<String>,
/// Oracle connection descriptor / DSN (e.g. `tcps://host:port/service_name`)
#[arg(long, env = "ATP_DSN")]
oracle_dsn: Option<String>,
/// Oracle ATP username
#[arg(long, env = "ATP_USER")]
oracle_user: Option<String>,
/// Oracle ATP password
#[arg(long, env = "ATP_PASSWORD")]
oracle_password: Option<String>,
/// Minimum number of pooled ATP connections (defaults to 1 when omitted)
#[arg(long, env = "ATP_POOL_MIN")]
oracle_pool_min: Option<usize>,
/// Maximum number of pooled ATP connections (defaults to 16 when omitted)
#[arg(long, env = "ATP_POOL_MAX")]
oracle_pool_max: Option<usize>,
/// Connection acquisition timeout in seconds (defaults to 30 when omitted)
#[arg(long, env = "ATP_POOL_TIMEOUT_SECS")]
oracle_pool_timeout_secs: Option<u64>,
}
@@ -357,19 +289,15 @@ enum OracleConnectSource {
}
impl CliArgs {
/// Determine connection mode from worker URLs
fn determine_connection_mode(worker_urls: &[String]) -> ConnectionMode {
// Only consider it gRPC if explicitly specified with grpc:// or grpcs:// scheme
for url in worker_urls {
if url.starts_with("grpc://") || url.starts_with("grpcs://") {
return ConnectionMode::Grpc;
}
}
// Default to HTTP for all other cases (including http://, https://, or no scheme)
ConnectionMode::Http
}
/// Parse selector strings into HashMap
fn parse_selector(selector_list: &[String]) -> HashMap<String, String> {
let mut map = HashMap::new();
for item in selector_list {
@@ -382,7 +310,6 @@ impl CliArgs {
map
}
/// Convert policy string to PolicyConfig
fn parse_policy(&self, policy_str: &str) -> PolicyConfig {
match policy_str {
"random" => PolicyConfig::Random,
@@ -395,9 +322,9 @@ impl CliArgs {
max_tree_size: self.max_tree_size,
},
"power_of_two" => PolicyConfig::PowerOfTwo {
load_check_interval_secs: 5, // Default value
load_check_interval_secs: 5,
},
_ => PolicyConfig::RoundRobin, // Fallback
_ => PolicyConfig::RoundRobin,
}
}
@@ -482,26 +409,21 @@ impl CliArgs {
})
}
/// Convert CLI arguments to RouterConfig
fn to_router_config(
&self,
prefill_urls: Vec<(String, Option<u16>)>,
) -> ConfigResult<RouterConfig> {
// Determine routing mode
let mode = if self.enable_igw {
// IGW mode - routing mode is not used in IGW, but we need to provide a placeholder
RoutingMode::Regular {
worker_urls: vec![],
}
} else if matches!(self.backend, Backend::Openai) {
// OpenAI backend mode - use worker_urls as base(s)
RoutingMode::OpenAI {
worker_urls: self.worker_urls.clone(),
}
} else if self.pd_disaggregation {
let decode_urls = self.decode.clone();
// Validate PD configuration if not using service discovery
if !self.service_discovery && (prefill_urls.is_empty() || decode_urls.is_empty()) {
return Err(ConfigError::ValidationFailed {
reason: "PD disaggregation mode requires --prefill and --decode URLs when not using service discovery".to_string(),
@@ -515,7 +437,6 @@ impl CliArgs {
decode_policy: self.decode_policy.as_ref().map(|p| self.parse_policy(p)),
}
} else {
// Regular mode
if !self.service_discovery && self.worker_urls.is_empty() {
return Err(ConfigError::ValidationFailed {
reason: "Regular mode requires --worker-urls when not using service discovery"
@@ -527,10 +448,8 @@ impl CliArgs {
}
};
// Main policy
let policy = self.parse_policy(&self.policy);
// Service discovery configuration
let discovery = if self.service_discovery {
Some(DiscoveryConfig {
enabled: true,
@@ -546,13 +465,11 @@ impl CliArgs {
None
};
// Metrics configuration
let metrics = Some(MetricsConfig {
port: self.prometheus_port,
host: self.prometheus_host.clone(),
});
// Determine connection mode from all worker URLs
let mut all_urls = Vec::new();
match &mode {
RoutingMode::Regular { worker_urls } => {
@@ -568,9 +485,7 @@ impl CliArgs {
}
all_urls.extend(decode_urls.clone());
}
RoutingMode::OpenAI { .. } => {
// For connection-mode detection, skip URLs; OpenAI forces HTTP below.
}
RoutingMode::OpenAI { .. } => {}
}
let connection_mode = match &mode {
RoutingMode::OpenAI { .. } => ConnectionMode::Http,
@@ -589,7 +504,6 @@ impl CliArgs {
None
};
// Build RouterConfig
Ok(RouterConfig {
mode,
policy,
@@ -612,8 +526,8 @@ impl CliArgs {
Some(self.request_id_headers.clone())
},
max_concurrent_requests: self.max_concurrent_requests,
queue_size: 100, // Default queue size
queue_timeout_secs: 60, // Default timeout
queue_size: 100,
queue_timeout_secs: 60,
cors_allowed_origins: self.cors_allowed_origins.clone(),
retry: RetryConfig {
max_retries: self.retry_max_retries,
@@ -646,9 +560,7 @@ impl CliArgs {
})
}
/// Create ServerConfig from CLI args and RouterConfig
fn to_server_config(&self, router_config: RouterConfig) -> ServerConfig {
// Create service discovery config if enabled
let service_discovery_config = if self.service_discovery {
Some(ServiceDiscoveryConfig {
enabled: true,
@@ -665,7 +577,6 @@ impl CliArgs {
None
};
// Create Prometheus config
let prometheus_config = Some(PrometheusConfig {
port: self.prometheus_port,
host: self.prometheus_host.clone(),
@@ -691,19 +602,15 @@ impl CliArgs {
}
fn main() -> Result<(), Box<dyn std::error::Error>> {
// Parse prefill arguments manually before clap parsing
let prefill_urls = parse_prefill_args();
// Filter out prefill arguments and their values before passing to clap
let mut filtered_args: Vec<String> = Vec::new();
let raw_args: Vec<String> = std::env::args().collect();
let mut i = 0;
while i < raw_args.len() {
if raw_args[i] == "--prefill" && i + 1 < raw_args.len() {
// Skip --prefill and its URL
i += 2;
// Also skip bootstrap port if present
if i < raw_args.len()
&& !raw_args[i].starts_with("--")
&& (raw_args[i].parse::<u16>().is_ok() || raw_args[i].to_lowercase() == "none")
@@ -716,10 +623,8 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
}
}
// Parse CLI arguments with clap using filtered args
let cli_args = CliArgs::parse_from(filtered_args);
// Print startup info
println!("SGLang Router starting...");
println!("Host: {}:{}", cli_args.host, cli_args.port);
let mode_str = if cli_args.enable_igw {
@@ -733,7 +638,6 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
};
println!("Mode: {}", mode_str);
// Warn for runtimes that are parsed but not yet implemented
match cli_args.backend {
Backend::Vllm | Backend::Trtllm | Backend::Anthropic => {
println!(
@@ -754,19 +658,10 @@ Provide --worker-urls or PD flags as usual.",
}
}
// Convert to RouterConfig
let router_config = cli_args.to_router_config(prefill_urls)?;
// Validate configuration
router_config.validate()?;
// Create ServerConfig
let server_config = cli_args.to_server_config(router_config);
// Create a new runtime for the server (like Python binding does)
let runtime = tokio::runtime::Runtime::new()?;
// Block on the async startup function
runtime.block_on(async move { server::startup(server_config).await })?;
Ok(())