From a7fe6e10a1c7c6b6be5dbd2aee54fb9ca4e08378 Mon Sep 17 00:00:00 2001 From: Simo Lin Date: Fri, 26 Sep 2025 12:45:15 -0400 Subject: [PATCH] [router] remove old/oudated/useless comments (#10967) --- sgl-router/src/lib.rs | 40 ++------- sgl-router/src/logging.rs | 35 +------- sgl-router/src/main.rs | 125 +++------------------------- sgl-router/src/metrics.rs | 62 +------------- sgl-router/src/service_discovery.rs | 72 +++------------- 5 files changed, 28 insertions(+), 306 deletions(-) diff --git a/sgl-router/src/lib.rs b/sgl-router/src/lib.rs index a1d2cabc4..43225d4cc 100644 --- a/sgl-router/src/lib.rs +++ b/sgl-router/src/lib.rs @@ -67,58 +67,47 @@ struct Router { decode_policy: Option, max_concurrent_requests: usize, cors_allowed_origins: Vec, - // Retry configuration retry_max_retries: u32, retry_initial_backoff_ms: u64, retry_max_backoff_ms: u64, retry_backoff_multiplier: f32, retry_jitter_factor: f32, disable_retries: bool, - // Circuit breaker configuration cb_failure_threshold: u32, cb_success_threshold: u32, cb_timeout_duration_secs: u64, cb_window_duration_secs: u64, disable_circuit_breaker: bool, - // Health check configuration health_failure_threshold: u32, health_success_threshold: u32, health_check_timeout_secs: u64, health_check_interval_secs: u64, health_check_endpoint: String, - // IGW (Inference Gateway) configuration enable_igw: bool, queue_size: usize, queue_timeout_secs: u64, rate_limit_tokens_per_second: Option, - // Connection mode (determined from worker URLs) connection_mode: config::ConnectionMode, - // Model path for tokenizer model_path: Option, - // Explicit tokenizer path tokenizer_path: Option, } impl Router { /// Determine connection mode from worker URLs fn determine_connection_mode(worker_urls: &[String]) -> config::ConnectionMode { - // Only consider it gRPC if explicitly specified with grpc:// or grpcs:// scheme for url in worker_urls { if url.starts_with("grpc://") || url.starts_with("grpcs://") { return config::ConnectionMode::Grpc; } } - // Default to HTTP for all other cases (including http://, https://, or no scheme) config::ConnectionMode::Http } - /// Convert PyO3 Router to RouterConfig pub fn to_router_config(&self) -> config::ConfigResult { use config::{ DiscoveryConfig, MetricsConfig, PolicyConfig as ConfigPolicyConfig, RoutingMode, }; - // Convert policy helper function let convert_policy = |policy: &PolicyType| -> ConfigPolicyConfig { match policy { PolicyType::Random => ConfigPolicyConfig::Random, @@ -131,14 +120,12 @@ impl Router { max_tree_size: self.max_tree_size, }, PolicyType::PowerOfTwo => ConfigPolicyConfig::PowerOfTwo { - load_check_interval_secs: 5, // Default value + load_check_interval_secs: 5, }, } }; - // Determine routing mode let mode = if self.enable_igw { - // IGW mode - routing mode is not used in IGW, but we need to provide a placeholder RoutingMode::Regular { worker_urls: vec![], } @@ -155,10 +142,8 @@ impl Router { } }; - // Convert main policy let policy = convert_policy(&self.policy); - // Service discovery configuration let discovery = if self.service_discovery { Some(DiscoveryConfig { enabled: true, @@ -174,7 +159,6 @@ impl Router { None }; - // Metrics configuration let metrics = match (self.prometheus_port, self.prometheus_host.as_ref()) { (Some(port), Some(host)) => Some(MetricsConfig { port, @@ -251,7 +235,7 @@ impl Router { balance_rel_threshold = 1.5, eviction_interval_secs = 120, max_tree_size = 2usize.pow(26), - max_payload_size = 512 * 1024 * 1024, // 512MB default for large batches + max_payload_size = 512 * 1024 * 1024, dp_aware = false, api_key = None, log_dir = None, @@ -265,40 +249,35 @@ impl Router { bootstrap_port_annotation = String::from("sglang.ai/bootstrap-port"), prometheus_port = None, prometheus_host = None, - request_timeout_secs = 1800, // Add configurable request timeout - request_id_headers = None, // Custom request ID headers - pd_disaggregation = false, // New flag for PD mode + request_timeout_secs = 1800, + request_id_headers = None, + pd_disaggregation = false, prefill_urls = None, decode_urls = None, prefill_policy = None, decode_policy = None, max_concurrent_requests = 256, cors_allowed_origins = vec![], - // Retry defaults retry_max_retries = 5, retry_initial_backoff_ms = 50, retry_max_backoff_ms = 30_000, retry_backoff_multiplier = 1.5, retry_jitter_factor = 0.2, disable_retries = false, - // Circuit breaker defaults cb_failure_threshold = 10, cb_success_threshold = 3, cb_timeout_duration_secs = 60, cb_window_duration_secs = 120, disable_circuit_breaker = false, - // Health check defaults health_failure_threshold = 3, health_success_threshold = 2, health_check_timeout_secs = 5, health_check_interval_secs = 60, health_check_endpoint = String::from("/health"), - // IGW defaults enable_igw = false, queue_size = 100, queue_timeout_secs = 60, rate_limit_tokens_per_second = None, - // Tokenizer defaults model_path = None, tokenizer_path = None, ))] @@ -361,17 +340,14 @@ impl Router { model_path: Option, tokenizer_path: Option, ) -> PyResult { - // Determine connection mode from worker URLs let mut all_urls = worker_urls.clone(); - // Add prefill URLs if in PD mode if let Some(ref prefill_urls) = prefill_urls { for (url, _) in prefill_urls { all_urls.push(url.clone()); } } - // Add decode URLs if in PD mode if let Some(ref decode_urls) = decode_urls { all_urls.extend(decode_urls.clone()); } @@ -440,12 +416,10 @@ impl Router { } fn start(&self) -> PyResult<()> { - // Convert to RouterConfig and validate let router_config = self.to_router_config().map_err(|e| { pyo3::exceptions::PyValueError::new_err(format!("Configuration error: {}", e)) })?; - // Validate the configuration router_config.validate().map_err(|e| { pyo3::exceptions::PyValueError::new_err(format!( "Configuration validation failed: {}", @@ -453,7 +427,6 @@ impl Router { )) })?; - // Create service discovery config if enabled let service_discovery_config = if self.service_discovery { Some(service_discovery::ServiceDiscoveryConfig { enabled: true, @@ -470,7 +443,6 @@ impl Router { None }; - // Create Prometheus config if enabled let prometheus_config = Some(PrometheusConfig { port: self.prometheus_port.unwrap_or(29000), host: self @@ -479,11 +451,9 @@ impl Router { .unwrap_or_else(|| "127.0.0.1".to_string()), }); - // Use tokio runtime instead of actix-web System for better compatibility let runtime = tokio::runtime::Runtime::new() .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?; - // Block on the async startup function runtime.block_on(async move { server::startup(server::ServerConfig { host: self.host.clone(), diff --git a/sgl-router/src/logging.rs b/sgl-router/src/logging.rs index 5c5b63e0e..c92139ec0 100644 --- a/sgl-router/src/logging.rs +++ b/sgl-router/src/logging.rs @@ -8,20 +8,13 @@ use tracing_subscriber::layer::SubscriberExt; use tracing_subscriber::util::SubscriberInitExt; use tracing_subscriber::{EnvFilter, Layer}; -/// Configuration for the logging system #[derive(Debug, Clone)] pub struct LoggingConfig { - /// Log level for the application (default: INFO) pub level: Level, - /// Whether to use json format for logs (default: false) pub json_format: bool, - /// Path to store log files. If None, logs will only go to stdout/stderr pub log_dir: Option, - /// Whether to colorize logs when output is a terminal (default: true) pub colorize: bool, - /// Log file name to use if log_dir is specified (default: "sgl-router") pub log_file_name: String, - /// Custom log targets to filter (default: "sglang_router_rs") pub log_targets: Option>, } @@ -38,30 +31,14 @@ impl Default for LoggingConfig { } } -/// Guard that keeps the file appender worker thread alive -/// -/// This must be kept in scope for the duration of the program -/// to ensure logs are properly written to files #[allow(dead_code)] pub struct LogGuard { _file_guard: Option, } -/// Initialize the logging system with the given configuration -/// -/// # Arguments -/// * `config` - Configuration for the logging system -/// -/// # Returns -/// A LogGuard that must be kept alive for the duration of the program -/// -/// # Panics -/// Will not panic, as initialization errors are handled gracefully pub fn init_logging(config: LoggingConfig) -> LogGuard { - // Forward logs to tracing - ignore errors to allow for multiple initialization let _ = LogTracer::init(); - // Convert log level to filter string let level_filter = match config.level { Level::TRACE => "trace", Level::DEBUG => "debug", @@ -70,9 +47,7 @@ pub fn init_logging(config: LoggingConfig) -> LogGuard { Level::ERROR => "error", }; - // Create env filter let env_filter = EnvFilter::try_from_default_env().unwrap_or_else(|_| { - // Format: =,=,... let filter_string = if let Some(targets) = &config.log_targets { targets .iter() @@ -92,13 +67,10 @@ pub fn init_logging(config: LoggingConfig) -> LogGuard { EnvFilter::new(filter_string) }); - // Setup stdout/stderr layer let mut layers = Vec::new(); - // Standard timestamp format: YYYY-MM-DD HH:MM:SS let time_format = "%Y-%m-%d %H:%M:%S".to_string(); - // Configure the console stdout layer let stdout_layer = tracing_subscriber::fmt::layer() .with_ansi(config.colorize) .with_file(true) @@ -113,14 +85,12 @@ pub fn init_logging(config: LoggingConfig) -> LogGuard { layers.push(stdout_layer); - // Create a file appender if log_dir is specified let mut file_guard = None; if let Some(log_dir) = &config.log_dir { let file_name = config.log_file_name.clone(); let log_dir = PathBuf::from(log_dir); - // Create log directory if it doesn't exist if !log_dir.exists() { if let Err(e) = std::fs::create_dir_all(&log_dir) { eprintln!("Failed to create log directory: {}", e); @@ -134,7 +104,7 @@ pub fn init_logging(config: LoggingConfig) -> LogGuard { file_guard = Some(guard); let file_layer = tracing_subscriber::fmt::layer() - .with_ansi(false) // Never use ANSI colors in log files + .with_ansi(false) .with_file(true) .with_line_number(true) .with_timer(ChronoUtc::new(time_format)) @@ -149,14 +119,11 @@ pub fn init_logging(config: LoggingConfig) -> LogGuard { layers.push(file_layer); } - // Initialize the subscriber with all layers - // Use try_init to handle errors gracefully in case another subscriber is already set let _ = tracing_subscriber::registry() .with(env_filter) .with(layers) .try_init(); - // Return the guard to keep the file appender worker thread alive LogGuard { _file_guard: file_guard, } diff --git a/sgl-router/src/main.rs b/sgl-router/src/main.rs index 3380ab8a1..888502b9e 100644 --- a/sgl-router/src/main.rs +++ b/sgl-router/src/main.rs @@ -9,7 +9,6 @@ use sglang_router_rs::server::{self, ServerConfig}; use sglang_router_rs::service_discovery::ServiceDiscoveryConfig; use std::collections::HashMap; -// Helper function to parse prefill arguments from command line fn parse_prefill_args() -> Vec<(String, Option)> { let args: Vec = std::env::args().collect(); let mut prefill_entries = Vec::new(); @@ -19,12 +18,11 @@ fn parse_prefill_args() -> Vec<(String, Option)> { if args[i] == "--prefill" && i + 1 < args.len() { let url = args[i + 1].clone(); let bootstrap_port = if i + 2 < args.len() && !args[i + 2].starts_with("--") { - // Check if next arg is a port number if let Ok(port) = args[i + 2].parse::() { - i += 1; // Skip the port argument + i += 1; Some(port) } else if args[i + 2].to_lowercase() == "none" { - i += 1; // Skip the "none" argument + i += 1; None } else { None @@ -33,7 +31,7 @@ fn parse_prefill_args() -> Vec<(String, Option)> { None }; prefill_entries.push((url, bootstrap_port)); - i += 2; // Skip --prefill and URL + i += 2; } else { i += 1; } @@ -101,252 +99,186 @@ Examples: "#)] struct CliArgs { - /// Host address to bind the router server #[arg(long, default_value = "127.0.0.1")] host: String, - /// Port number to bind the router server #[arg(long, default_value_t = 30000)] port: u16, - /// List of worker URLs (e.g., http://worker1:8000 http://worker2:8000) #[arg(long, num_args = 0..)] worker_urls: Vec, - /// Load balancing policy to use #[arg(long, default_value = "cache_aware", value_parser = ["random", "round_robin", "cache_aware", "power_of_two"])] policy: String, - /// Enable PD (Prefill-Decode) disaggregated mode #[arg(long, default_value_t = false)] pd_disaggregation: bool, - /// Decode server URL (can be specified multiple times) #[arg(long, action = ArgAction::Append)] decode: Vec, - /// Specific policy for prefill nodes in PD mode #[arg(long, value_parser = ["random", "round_robin", "cache_aware", "power_of_two"])] prefill_policy: Option, - /// Specific policy for decode nodes in PD mode #[arg(long, value_parser = ["random", "round_robin", "cache_aware", "power_of_two"])] decode_policy: Option, - /// Timeout in seconds for worker startup #[arg(long, default_value_t = 600)] worker_startup_timeout_secs: u64, - /// Interval in seconds between checks for worker startup #[arg(long, default_value_t = 30)] worker_startup_check_interval: u64, - /// Cache threshold (0.0-1.0) for cache-aware routing #[arg(long, default_value_t = 0.3)] cache_threshold: f32, - /// Absolute threshold for load balancing #[arg(long, default_value_t = 64)] balance_abs_threshold: usize, - /// Relative threshold for load balancing #[arg(long, default_value_t = 1.5)] balance_rel_threshold: f32, - /// Interval in seconds between cache eviction operations #[arg(long, default_value_t = 120)] eviction_interval: u64, - /// Maximum size of the approximation tree for cache-aware routing - #[arg(long, default_value_t = 67108864)] // 2^26 + #[arg(long, default_value_t = 67108864)] max_tree_size: usize, - /// Maximum payload size in bytes - #[arg(long, default_value_t = 536870912)] // 512MB + #[arg(long, default_value_t = 536870912)] max_payload_size: usize, - /// Enable data parallelism aware schedule #[arg(long, default_value_t = false)] dp_aware: bool, - /// API key for worker authorization #[arg(long)] api_key: Option, - /// Backend to route requests to (sglang, vllm, trtllm, openai, anthropic) #[arg(long, value_enum, default_value_t = Backend::Sglang, alias = "runtime")] backend: Backend, - /// Directory to store log files #[arg(long)] log_dir: Option, - /// Set the logging level #[arg(long, default_value = "info", value_parser = ["debug", "info", "warn", "error"])] log_level: String, - /// Enable Kubernetes service discovery #[arg(long, default_value_t = false)] service_discovery: bool, - /// Label selector for Kubernetes service discovery (format: key1=value1 key2=value2) #[arg(long, num_args = 0..)] selector: Vec, - /// Port to use for discovered worker pods #[arg(long, default_value_t = 80)] service_discovery_port: u16, - /// Kubernetes namespace to watch for pods #[arg(long)] service_discovery_namespace: Option, - /// Label selector for prefill server pods in PD mode #[arg(long, num_args = 0..)] prefill_selector: Vec, - /// Label selector for decode server pods in PD mode #[arg(long, num_args = 0..)] decode_selector: Vec, - /// Port to expose Prometheus metrics #[arg(long, default_value_t = 29000)] prometheus_port: u16, - /// Host address to bind the Prometheus metrics server #[arg(long, default_value = "127.0.0.1")] prometheus_host: String, - /// Custom HTTP headers to check for request IDs #[arg(long, num_args = 0..)] request_id_headers: Vec, - /// Request timeout in seconds #[arg(long, default_value_t = 1800)] request_timeout_secs: u64, - /// Maximum number of concurrent requests allowed #[arg(long, default_value_t = 256)] max_concurrent_requests: usize, - /// CORS allowed origins #[arg(long, num_args = 0..)] cors_allowed_origins: Vec, - // Retry configuration - /// Maximum number of retries #[arg(long, default_value_t = 5)] retry_max_retries: u32, - /// Initial backoff in milliseconds for retries #[arg(long, default_value_t = 50)] retry_initial_backoff_ms: u64, - /// Maximum backoff in milliseconds for retries #[arg(long, default_value_t = 30000)] retry_max_backoff_ms: u64, - /// Backoff multiplier for exponential backoff #[arg(long, default_value_t = 1.5)] retry_backoff_multiplier: f32, - /// Jitter factor for retry backoff #[arg(long, default_value_t = 0.2)] retry_jitter_factor: f32, - /// Disable retries #[arg(long, default_value_t = false)] disable_retries: bool, - // Circuit breaker configuration - /// Number of failures before circuit breaker opens #[arg(long, default_value_t = 10)] cb_failure_threshold: u32, - /// Number of successes before circuit breaker closes #[arg(long, default_value_t = 3)] cb_success_threshold: u32, - /// Timeout duration in seconds for circuit breaker #[arg(long, default_value_t = 60)] cb_timeout_duration_secs: u64, - /// Window duration in seconds for circuit breaker #[arg(long, default_value_t = 120)] cb_window_duration_secs: u64, - /// Disable circuit breaker #[arg(long, default_value_t = false)] disable_circuit_breaker: bool, - // Health check configuration - /// Number of consecutive health check failures before marking worker unhealthy #[arg(long, default_value_t = 3)] health_failure_threshold: u32, - /// Number of consecutive health check successes before marking worker healthy #[arg(long, default_value_t = 2)] health_success_threshold: u32, - /// Timeout in seconds for health check requests #[arg(long, default_value_t = 5)] health_check_timeout_secs: u64, - /// Interval in seconds between runtime health checks #[arg(long, default_value_t = 60)] health_check_interval_secs: u64, - /// Health check endpoint path #[arg(long, default_value = "/health")] health_check_endpoint: String, - // IGW (Inference Gateway) configuration - /// Enable Inference Gateway mode #[arg(long, default_value_t = false)] enable_igw: bool, - // Tokenizer configuration - /// Model path for loading tokenizer (HuggingFace model ID or local path) #[arg(long)] model_path: Option, - /// Explicit tokenizer path (overrides model_path tokenizer if provided) #[arg(long)] tokenizer_path: Option, - /// History backend configuration (memory, none, or oracle) #[arg(long, default_value = "memory", value_parser = ["memory", "none", "oracle"])] history_backend: String, - /// Directory containing the Oracle ATP wallet/config files (optional) #[arg(long, env = "ATP_WALLET_PATH")] oracle_wallet_path: Option, - /// Wallet TNS alias to use (e.g. `_low`) #[arg(long, env = "ATP_TNS_ALIAS")] oracle_tns_alias: Option, - /// Oracle connection descriptor / DSN (e.g. `tcps://host:port/service_name`) #[arg(long, env = "ATP_DSN")] oracle_dsn: Option, - /// Oracle ATP username #[arg(long, env = "ATP_USER")] oracle_user: Option, - /// Oracle ATP password #[arg(long, env = "ATP_PASSWORD")] oracle_password: Option, - /// Minimum number of pooled ATP connections (defaults to 1 when omitted) #[arg(long, env = "ATP_POOL_MIN")] oracle_pool_min: Option, - /// Maximum number of pooled ATP connections (defaults to 16 when omitted) #[arg(long, env = "ATP_POOL_MAX")] oracle_pool_max: Option, - /// Connection acquisition timeout in seconds (defaults to 30 when omitted) #[arg(long, env = "ATP_POOL_TIMEOUT_SECS")] oracle_pool_timeout_secs: Option, } @@ -357,19 +289,15 @@ enum OracleConnectSource { } impl CliArgs { - /// Determine connection mode from worker URLs fn determine_connection_mode(worker_urls: &[String]) -> ConnectionMode { - // Only consider it gRPC if explicitly specified with grpc:// or grpcs:// scheme for url in worker_urls { if url.starts_with("grpc://") || url.starts_with("grpcs://") { return ConnectionMode::Grpc; } } - // Default to HTTP for all other cases (including http://, https://, or no scheme) ConnectionMode::Http } - /// Parse selector strings into HashMap fn parse_selector(selector_list: &[String]) -> HashMap { let mut map = HashMap::new(); for item in selector_list { @@ -382,7 +310,6 @@ impl CliArgs { map } - /// Convert policy string to PolicyConfig fn parse_policy(&self, policy_str: &str) -> PolicyConfig { match policy_str { "random" => PolicyConfig::Random, @@ -395,9 +322,9 @@ impl CliArgs { max_tree_size: self.max_tree_size, }, "power_of_two" => PolicyConfig::PowerOfTwo { - load_check_interval_secs: 5, // Default value + load_check_interval_secs: 5, }, - _ => PolicyConfig::RoundRobin, // Fallback + _ => PolicyConfig::RoundRobin, } } @@ -482,26 +409,21 @@ impl CliArgs { }) } - /// Convert CLI arguments to RouterConfig fn to_router_config( &self, prefill_urls: Vec<(String, Option)>, ) -> ConfigResult { - // Determine routing mode let mode = if self.enable_igw { - // IGW mode - routing mode is not used in IGW, but we need to provide a placeholder RoutingMode::Regular { worker_urls: vec![], } } else if matches!(self.backend, Backend::Openai) { - // OpenAI backend mode - use worker_urls as base(s) RoutingMode::OpenAI { worker_urls: self.worker_urls.clone(), } } else if self.pd_disaggregation { let decode_urls = self.decode.clone(); - // Validate PD configuration if not using service discovery if !self.service_discovery && (prefill_urls.is_empty() || decode_urls.is_empty()) { return Err(ConfigError::ValidationFailed { reason: "PD disaggregation mode requires --prefill and --decode URLs when not using service discovery".to_string(), @@ -515,7 +437,6 @@ impl CliArgs { decode_policy: self.decode_policy.as_ref().map(|p| self.parse_policy(p)), } } else { - // Regular mode if !self.service_discovery && self.worker_urls.is_empty() { return Err(ConfigError::ValidationFailed { reason: "Regular mode requires --worker-urls when not using service discovery" @@ -527,10 +448,8 @@ impl CliArgs { } }; - // Main policy let policy = self.parse_policy(&self.policy); - // Service discovery configuration let discovery = if self.service_discovery { Some(DiscoveryConfig { enabled: true, @@ -546,13 +465,11 @@ impl CliArgs { None }; - // Metrics configuration let metrics = Some(MetricsConfig { port: self.prometheus_port, host: self.prometheus_host.clone(), }); - // Determine connection mode from all worker URLs let mut all_urls = Vec::new(); match &mode { RoutingMode::Regular { worker_urls } => { @@ -568,9 +485,7 @@ impl CliArgs { } all_urls.extend(decode_urls.clone()); } - RoutingMode::OpenAI { .. } => { - // For connection-mode detection, skip URLs; OpenAI forces HTTP below. - } + RoutingMode::OpenAI { .. } => {} } let connection_mode = match &mode { RoutingMode::OpenAI { .. } => ConnectionMode::Http, @@ -589,7 +504,6 @@ impl CliArgs { None }; - // Build RouterConfig Ok(RouterConfig { mode, policy, @@ -612,8 +526,8 @@ impl CliArgs { Some(self.request_id_headers.clone()) }, max_concurrent_requests: self.max_concurrent_requests, - queue_size: 100, // Default queue size - queue_timeout_secs: 60, // Default timeout + queue_size: 100, + queue_timeout_secs: 60, cors_allowed_origins: self.cors_allowed_origins.clone(), retry: RetryConfig { max_retries: self.retry_max_retries, @@ -646,9 +560,7 @@ impl CliArgs { }) } - /// Create ServerConfig from CLI args and RouterConfig fn to_server_config(&self, router_config: RouterConfig) -> ServerConfig { - // Create service discovery config if enabled let service_discovery_config = if self.service_discovery { Some(ServiceDiscoveryConfig { enabled: true, @@ -665,7 +577,6 @@ impl CliArgs { None }; - // Create Prometheus config let prometheus_config = Some(PrometheusConfig { port: self.prometheus_port, host: self.prometheus_host.clone(), @@ -691,19 +602,15 @@ impl CliArgs { } fn main() -> Result<(), Box> { - // Parse prefill arguments manually before clap parsing let prefill_urls = parse_prefill_args(); - // Filter out prefill arguments and their values before passing to clap let mut filtered_args: Vec = Vec::new(); let raw_args: Vec = std::env::args().collect(); let mut i = 0; while i < raw_args.len() { if raw_args[i] == "--prefill" && i + 1 < raw_args.len() { - // Skip --prefill and its URL i += 2; - // Also skip bootstrap port if present if i < raw_args.len() && !raw_args[i].starts_with("--") && (raw_args[i].parse::().is_ok() || raw_args[i].to_lowercase() == "none") @@ -716,10 +623,8 @@ fn main() -> Result<(), Box> { } } - // Parse CLI arguments with clap using filtered args let cli_args = CliArgs::parse_from(filtered_args); - // Print startup info println!("SGLang Router starting..."); println!("Host: {}:{}", cli_args.host, cli_args.port); let mode_str = if cli_args.enable_igw { @@ -733,7 +638,6 @@ fn main() -> Result<(), Box> { }; println!("Mode: {}", mode_str); - // Warn for runtimes that are parsed but not yet implemented match cli_args.backend { Backend::Vllm | Backend::Trtllm | Backend::Anthropic => { println!( @@ -754,19 +658,10 @@ Provide --worker-urls or PD flags as usual.", } } - // Convert to RouterConfig let router_config = cli_args.to_router_config(prefill_urls)?; - - // Validate configuration router_config.validate()?; - - // Create ServerConfig let server_config = cli_args.to_server_config(router_config); - - // Create a new runtime for the server (like Python binding does) let runtime = tokio::runtime::Runtime::new()?; - - // Block on the async startup function runtime.block_on(async move { server::startup(server_config).await })?; Ok(()) diff --git a/sgl-router/src/metrics.rs b/sgl-router/src/metrics.rs index 7235370fe..2a715c423 100644 --- a/sgl-router/src/metrics.rs +++ b/sgl-router/src/metrics.rs @@ -19,7 +19,6 @@ impl Default for PrometheusConfig { } pub fn init_metrics() { - // Request metrics describe_counter!( "sgl_router_requests_total", "Total number of requests by route and method" @@ -45,7 +44,6 @@ pub fn init_metrics() { "Total number of requests that exhausted retries by route" ); - // Circuit breaker metrics describe_gauge!( "sgl_router_cb_state", "Circuit breaker state per worker (0=closed, 1=open, 2=half_open)" @@ -59,7 +57,6 @@ pub fn init_metrics() { "Total number of circuit breaker outcomes by worker and outcome type (success/failure)" ); - // Worker metrics describe_gauge!( "sgl_router_active_workers", "Number of currently active workers" @@ -74,7 +71,6 @@ pub fn init_metrics() { "Total requests processed by each worker" ); - // Policy metrics describe_counter!( "sgl_router_policy_decisions_total", "Total routing policy decisions by policy and worker" @@ -92,7 +88,6 @@ pub fn init_metrics() { describe_gauge!("sgl_router_max_load", "Maximum worker load"); describe_gauge!("sgl_router_min_load", "Minimum worker load"); - // PD-specific metrics describe_counter!("sgl_router_pd_requests_total", "Total PD requests by route"); describe_counter!( "sgl_router_pd_prefill_requests_total", @@ -123,7 +118,6 @@ pub fn init_metrics() { "PD request duration by route" ); - // Service discovery metrics describe_counter!( "sgl_router_discovery_updates_total", "Total service discovery update events" @@ -137,13 +131,11 @@ pub fn init_metrics() { "Number of workers removed in last discovery update" ); - // Generate request specific metrics describe_histogram!( "sgl_router_generate_duration_seconds", "Generate request duration" ); - // Embedding request specific metrics describe_counter!("sgl_router_embeddings_total", "Total embedding requests"); describe_histogram!( "sgl_router_embeddings_duration_seconds", @@ -155,13 +147,11 @@ pub fn init_metrics() { ); describe_gauge!("sgl_router_embeddings_queue_size", "Embedding queue size"); - // Running requests gauge for cache-aware policy describe_gauge!( "sgl_router_running_requests", "Number of running requests per worker" ); - // Tokenizer metrics describe_histogram!( "sgl_tokenizer_encode_duration_seconds", "Time to encode text to tokens" @@ -207,7 +197,6 @@ pub fn init_metrics() { "Vocabulary size of the loaded tokenizer" ); - // Stop sequence detection metrics describe_counter!( "sgl_tokenizer_stop_sequences_detected_total", "Total stop sequences detected by type" @@ -221,7 +210,6 @@ pub fn init_metrics() { "Time to check for stop sequences per token" ); - // Streaming decode metrics describe_counter!( "sgl_tokenizer_stream_tokens_total", "Total tokens processed in streaming decode" @@ -235,7 +223,6 @@ pub fn init_metrics() { "Time per streaming decode step" ); - // Factory metrics describe_counter!( "sgl_tokenizer_factory_loads_total", "Total tokenizer loads by file type" @@ -251,7 +238,6 @@ pub fn init_metrics() { } pub fn start_prometheus(config: PrometheusConfig) { - // Initialize metric descriptions init_metrics(); let duration_matcher = Matcher::Suffix(String::from("duration_seconds")); @@ -280,7 +266,6 @@ pub struct RouterMetrics; pub struct TokenizerMetrics; impl RouterMetrics { - // Request metrics pub fn record_request(route: &str) { counter!("sgl_router_requests_total", "route" => route.to_string() @@ -324,7 +309,6 @@ impl RouterMetrics { .increment(1); } - // Worker metrics pub fn set_active_workers(count: usize) { gauge!("sgl_router_active_workers").set(count as f64); } @@ -350,7 +334,6 @@ impl RouterMetrics { .increment(1); } - // Policy metrics pub fn record_policy_decision(policy: &str, worker: &str) { counter!("sgl_router_policy_decisions_total", "policy" => policy.to_string(), @@ -383,7 +366,6 @@ impl RouterMetrics { gauge!("sgl_router_min_load").set(min_load as f64); } - // PD-specific metrics pub fn record_pd_request(route: &str) { counter!("sgl_router_pd_requests_total", "route" => route.to_string() @@ -440,19 +422,16 @@ impl RouterMetrics { .increment(1); } - // Service discovery metrics pub fn record_discovery_update(added: usize, removed: usize) { counter!("sgl_router_discovery_updates_total").increment(1); gauge!("sgl_router_discovery_workers_added").set(added as f64); gauge!("sgl_router_discovery_workers_removed").set(removed as f64); } - // Generate request metrics pub fn record_generate_duration(duration: Duration) { histogram!("sgl_router_generate_duration_seconds").record(duration.as_secs_f64()); } - // Embeddings metrics pub fn record_embeddings_request() { counter!("sgl_router_embeddings_total").increment(1); } @@ -473,7 +452,6 @@ impl RouterMetrics { gauge!("sgl_router_embeddings_queue_size").set(size as f64); } - // Running requests for cache-aware policy pub fn set_running_requests(worker: &str, count: usize) { gauge!("sgl_router_running_requests", "worker" => worker.to_string() @@ -481,7 +459,6 @@ impl RouterMetrics { .set(count as f64); } - // Circuit breaker metrics pub fn set_cb_state(worker: &str, state_code: u8) { gauge!("sgl_router_cb_state", "worker" => worker.to_string() @@ -508,7 +485,6 @@ impl RouterMetrics { } impl TokenizerMetrics { - // Encoding metrics pub fn record_encode_request(tokenizer_type: &str) { counter!("sgl_tokenizer_encode_requests_total", "tokenizer_type" => tokenizer_type.to_string() @@ -535,7 +511,6 @@ impl TokenizerMetrics { histogram!("sgl_tokenizer_chars_per_encode").record(char_count as f64); } - // Decoding metrics pub fn record_decode_request(tokenizer_type: &str) { counter!("sgl_tokenizer_decode_requests_total", "tokenizer_type" => tokenizer_type.to_string() @@ -558,7 +533,6 @@ impl TokenizerMetrics { histogram!("sgl_tokenizer_tokens_per_decode").record(token_count as f64); } - // Batch encoding metrics pub fn record_encode_batch_duration(duration: Duration, batch_size: usize) { histogram!("sgl_tokenizer_encode_batch_duration_seconds", "batch_size" => batch_size.to_string() @@ -566,7 +540,6 @@ impl TokenizerMetrics { .record(duration.as_secs_f64()); } - // Stop sequence detection metrics pub fn record_stop_sequence_detected(stop_type: &str) { counter!("sgl_tokenizer_stop_sequences_detected_total", "type" => stop_type.to_string() @@ -582,7 +555,6 @@ impl TokenizerMetrics { histogram!("sgl_tokenizer_stop_detection_duration_seconds").record(duration.as_secs_f64()); } - // Streaming decode metrics pub fn record_stream_token() { counter!("sgl_tokenizer_stream_tokens_total").increment(1); } @@ -595,7 +567,6 @@ impl TokenizerMetrics { histogram!("sgl_tokenizer_stream_step_duration_seconds").record(duration.as_secs_f64()); } - // Factory metrics pub fn record_factory_load(file_type: &str) { counter!("sgl_tokenizer_factory_loads_total", "file_type" => file_type.to_string() @@ -614,7 +585,6 @@ impl TokenizerMetrics { histogram!("sgl_tokenizer_factory_load_duration_seconds").record(duration.as_secs_f64()); } - // Vocabulary metrics pub fn set_vocab_size(tokenizer_type: &str, size: usize) { gauge!("sgl_tokenizer_vocab_size", "tokenizer_type" => tokenizer_type.to_string() @@ -705,7 +675,6 @@ mod tests { .parse() .unwrap_or(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0))); - // Should fall back to 0.0.0.0 assert_eq!(ip_addr, IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0))); } } @@ -780,7 +749,6 @@ mod tests { fn test_duration_suffix_matcher() { let matcher = Matcher::Suffix(String::from("duration_seconds")); - // Test matching behavior let _matching_metrics = [ "request_duration_seconds", "response_duration_seconds", @@ -789,8 +757,6 @@ mod tests { let _non_matching_metrics = ["duration_total", "duration_seconds_total", "other_metric"]; - // Note: We can't directly test Matcher matching without the internals, - // but we can verify the matcher is created correctly match matcher { Matcher::Suffix(suffix) => assert_eq!(suffix, "duration_seconds"), _ => panic!("Expected Suffix matcher"), @@ -801,7 +767,6 @@ mod tests { #[test] fn test_prometheus_builder_configuration() { - // This test verifies the builder configuration without actually starting Prometheus let _config = PrometheusConfig::default(); let duration_matcher = Matcher::Suffix(String::from("duration_seconds")); @@ -810,10 +775,8 @@ mod tests { 60.0, 90.0, 120.0, 180.0, 240.0, ]; - // Verify bucket configuration assert_eq!(duration_bucket.len(), 20); - // Verify matcher is suffix type match duration_matcher { Matcher::Suffix(s) => assert_eq!(s, "duration_seconds"), _ => panic!("Expected Suffix matcher"), @@ -832,14 +795,12 @@ mod tests { #[test] fn test_custom_buckets_for_different_metrics() { - // Test that we can create different bucket configurations let request_buckets = [0.001, 0.01, 0.1, 1.0, 10.0]; let generate_buckets = [0.1, 0.5, 1.0, 5.0, 30.0, 60.0]; assert_eq!(request_buckets.len(), 5); assert_eq!(generate_buckets.len(), 6); - // Verify each set is sorted for i in 1..request_buckets.len() { assert!(request_buckets[i] > request_buckets[i - 1]); } @@ -853,7 +814,6 @@ mod tests { #[test] fn test_metrics_static_methods() { - // Test that all static methods can be called without panic RouterMetrics::record_request("/generate"); RouterMetrics::record_request_duration("/generate", Duration::from_millis(100)); RouterMetrics::record_request_error("/generate", "timeout"); @@ -887,41 +847,32 @@ mod tests { #[test] fn test_tokenizer_metrics_static_methods() { - // Test that all tokenizer metric methods can be called without panic - - // Encoding metrics TokenizerMetrics::record_encode_request("huggingface"); TokenizerMetrics::record_encode_duration(Duration::from_millis(10)); TokenizerMetrics::record_encode_error("invalid_input"); TokenizerMetrics::record_tokens_per_encode(100); TokenizerMetrics::record_chars_per_encode(500); - // Decoding metrics TokenizerMetrics::record_decode_request("huggingface"); TokenizerMetrics::record_decode_duration(Duration::from_millis(5)); TokenizerMetrics::record_decode_error("invalid_tokens"); TokenizerMetrics::record_tokens_per_decode(50); - // Batch encoding TokenizerMetrics::record_encode_batch_duration(Duration::from_millis(100), 10); - // Stop sequence detection TokenizerMetrics::record_stop_sequence_detected("token"); TokenizerMetrics::record_stop_sequence_detected("string"); TokenizerMetrics::record_partial_match(); TokenizerMetrics::record_stop_detection_duration(Duration::from_micros(100)); - // Streaming decode TokenizerMetrics::record_stream_token(); TokenizerMetrics::record_incomplete_utf8(); TokenizerMetrics::record_stream_step_duration(Duration::from_micros(50)); - // Factory metrics TokenizerMetrics::record_factory_load("json"); TokenizerMetrics::record_factory_error("unsupported_format"); TokenizerMetrics::record_factory_load_duration(Duration::from_millis(200)); - // Vocabulary metrics TokenizerMetrics::set_vocab_size("huggingface", 50000); } @@ -929,17 +880,14 @@ mod tests { #[test] fn test_port_already_in_use() { - // Skip this test if we can't bind to the port - let port = 29123; // Use a different port to avoid conflicts + let port = 29123; if let Ok(_listener) = TcpListener::bind(("127.0.0.1", port)) { - // Port is available, we can test let config = PrometheusConfig { port, host: "127.0.0.1".to_string(), }; - // Just verify config is created correctly assert_eq!(config.port, port); } } @@ -948,8 +896,6 @@ mod tests { #[test] fn test_metrics_endpoint_accessibility() { - // This would be an integration test in practice - // Here we just verify the configuration let config = PrometheusConfig { port: 29000, host: "127.0.0.1".to_string(), @@ -963,7 +909,6 @@ mod tests { #[test] fn test_concurrent_metric_updates() { - // Test that metric updates can be called concurrently use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; use std::thread; @@ -984,11 +929,9 @@ mod tests { handles.push(handle); } - // Let threads run briefly thread::sleep(Duration::from_millis(10)); done.store(true, Ordering::Relaxed); - // Wait for all threads for handle in handles { handle.join().unwrap(); } @@ -998,7 +941,6 @@ mod tests { #[test] fn test_empty_string_metrics() { - // Test that empty strings don't cause issues RouterMetrics::record_request(""); RouterMetrics::set_worker_health("", true); RouterMetrics::record_policy_decision("", ""); @@ -1030,7 +972,6 @@ mod tests { #[test] fn test_extreme_metric_values() { - // Test extreme values RouterMetrics::set_active_workers(0); RouterMetrics::set_active_workers(usize::MAX); @@ -1038,7 +979,6 @@ mod tests { RouterMetrics::set_worker_load("worker", usize::MAX); RouterMetrics::record_request_duration("route", Duration::from_nanos(1)); - // 24 hours RouterMetrics::record_request_duration("route", Duration::from_secs(86400)); } } diff --git a/sgl-router/src/service_discovery.rs b/sgl-router/src/service_discovery.rs index 734f24fde..76dd25429 100644 --- a/sgl-router/src/service_discovery.rs +++ b/sgl-router/src/service_discovery.rs @@ -19,7 +19,6 @@ use tokio::task; use tokio::time; use tracing::{debug, error, info, warn}; -/// Represents the service discovery configuration #[derive(Debug, Clone)] pub struct ServiceDiscoveryConfig { pub enabled: bool, @@ -41,8 +40,8 @@ impl Default for ServiceDiscoveryConfig { enabled: false, selector: HashMap::new(), check_interval: Duration::from_secs(60), - port: 8000, // Standard port for modern services - namespace: None, // None means watch all namespaces + port: 8000, + namespace: None, pd_mode: false, prefill_selector: HashMap::new(), decode_selector: HashMap::new(), @@ -51,7 +50,6 @@ impl Default for ServiceDiscoveryConfig { } } -/// Pod type for PD mode service discovery #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub enum PodType { Prefill, @@ -59,7 +57,6 @@ pub enum PodType { Regular, } -/// Represents a Kubernetes pod's information used for worker management #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct PodInfo { pub name: String, @@ -71,7 +68,6 @@ pub struct PodInfo { } impl PodInfo { - /// Check if a pod matches any of the given selectors fn matches_selector(pod: &Pod, selector: &HashMap) -> bool { if selector.is_empty() { return false; @@ -83,19 +79,15 @@ impl PodInfo { .is_some_and(|labels| selector.iter().all(|(k, v)| labels.get(k) == Some(v))) } - /// Check if a pod should be included in service discovery pub fn should_include(pod: &Pod, config: &ServiceDiscoveryConfig) -> bool { if config.pd_mode { - // In PD mode, at least one selector must be non-empty if config.prefill_selector.is_empty() && config.decode_selector.is_empty() { warn!("PD mode enabled but both prefill_selector and decode_selector are empty"); return false; } - // In PD mode, pod must match either prefill or decode selector Self::matches_selector(pod, &config.prefill_selector) || Self::matches_selector(pod, &config.decode_selector) } else { - // In regular mode, pod must match the general selector if config.selector.is_empty() { warn!("Regular mode enabled but selector is empty"); return false; @@ -104,7 +96,6 @@ impl PodInfo { } } - /// Unified PodInfo creation with optional PD configuration pub fn from_pod(pod: &Pod, config: Option<&ServiceDiscoveryConfig>) -> Option { let name = pod.metadata.name.clone()?; let status = pod.status.clone()?; @@ -120,10 +111,8 @@ impl PodInfo { let pod_status = status.phase.unwrap_or_else(|| "Unknown".to_string()); - // Determine pod type based on labels if config is provided and in PD mode let pod_type = if let Some(config) = config { if config.pd_mode { - // Use simplified helper methods for cleaner logic if Self::matches_selector(pod, &config.prefill_selector) { Some(PodType::Prefill) } else if Self::matches_selector(pod, &config.decode_selector) { @@ -135,11 +124,9 @@ impl PodInfo { Some(PodType::Regular) } } else { - // No config provided, default to None (for backwards compatibility) None }; - // Extract bootstrap port from annotations for prefill pods let bootstrap_port = if matches!(pod_type, Some(PodType::Prefill)) { if let Some(config) = config { pod.metadata @@ -164,12 +151,10 @@ impl PodInfo { }) } - /// Returns true if the pod is in a state where it can accept traffic pub fn is_healthy(&self) -> bool { self.is_ready && self.status == "Running" } - /// Generates a worker URL for this pod pub fn worker_url(&self, port: u16) -> String { format!("http://{}:{}", self.ip, port) } @@ -179,9 +164,7 @@ pub async fn start_service_discovery( config: ServiceDiscoveryConfig, app_context: Arc, ) -> Result, kube::Error> { - // Don't initialize anything if service discovery is disabled if !config.enabled { - // Return a generic error when service discovery is disabled return Err(kube::Error::Api(kube::error::ErrorResponse { status: "Disabled".to_string(), message: "Service discovery is disabled".to_string(), @@ -192,7 +175,6 @@ pub async fn start_service_discovery( let _ = rustls::crypto::ring::default_provider().install_default(); - // Initialize Kubernetes client let client = Client::try_default().await?; // Log the appropriate selectors based on mode @@ -229,12 +211,9 @@ pub async fn start_service_discovery( ); } - // Create the task that will run in the background let handle = task::spawn(async move { - // We'll track pods we've already added to avoid duplicates let tracked_pods = Arc::new(Mutex::new(HashSet::new())); - // Create a watcher for pods let pods: Api = if let Some(namespace) = &config.namespace { Api::namespaced(client, namespace) } else { @@ -243,23 +222,19 @@ pub async fn start_service_discovery( debug!("K8s service discovery initialized"); - // Create Arcs for configuration data let config_arc = Arc::new(config.clone()); let port = config.port; let mut retry_delay = Duration::from_secs(1); - const MAX_RETRY_DELAY: Duration = Duration::from_secs(300); // 5 minutes max + const MAX_RETRY_DELAY: Duration = Duration::from_secs(300); loop { - // Create a watcher with the proper parameters according to the kube-rs API let watcher_config = Config::default(); let watcher_stream = watcher(pods.clone(), watcher_config).applied_objects(); - // Clone Arcs for the closures let config_clone = Arc::clone(&config_arc); let tracked_pods_clone = Arc::clone(&tracked_pods); - // Simplified label selector filter using helper method let filtered_stream = watcher_stream.filter_map(move |obj_res| { let config_inner = Arc::clone(&config_clone); @@ -277,7 +252,6 @@ pub async fn start_service_discovery( } }); - // Clone again for the next closure let tracked_pods_clone2 = Arc::clone(&tracked_pods_clone); let app_context_clone = Arc::clone(&app_context); let config_clone2 = Arc::clone(&config_arc); @@ -317,7 +291,6 @@ pub async fn start_service_discovery( .await { Ok(_) => { - // Reset retry delay on success retry_delay = Duration::from_secs(1); } Err(err) => { @@ -328,12 +301,10 @@ pub async fn start_service_discovery( ); time::sleep(retry_delay).await; - // Exponential backoff with jitter retry_delay = std::cmp::min(retry_delay * 2, MAX_RETRY_DELAY); } } - // If the watcher exits for some reason, wait a bit before restarting warn!( "Kubernetes watcher exited, restarting in {} seconds", config_arc.check_interval.as_secs() @@ -354,9 +325,7 @@ async fn handle_pod_event( ) { let worker_url = pod_info.worker_url(port); - // If pod is healthy, try to add it (with atomic check-and-insert) if pod_info.is_healthy() { - // Atomic check-and-insert to prevent race conditions let should_add = { let mut tracker = match tracked_pods.lock() { Ok(tracker) => tracker, @@ -367,9 +336,8 @@ async fn handle_pod_event( }; if tracker.contains(pod_info) { - false // Already tracked + false } else { - // Reserve the spot to prevent other threads from adding the same pod tracker.insert(pod_info.clone()); true } @@ -381,7 +349,6 @@ async fn handle_pod_event( pod_info.name, pod_info.pod_type, worker_url ); - // Build worker config based on pod type and routing mode let worker_type = if pd_mode { match &pod_info.pod_type { Some(PodType::Prefill) => Some("prefill".to_string()), @@ -392,7 +359,6 @@ async fn handle_pod_event( None }; - // Only set bootstrap_port for prefill workers in PD mode let bootstrap_port = if pd_mode { match &pod_info.pod_type { Some(PodType::Prefill) => pod_info.bootstrap_port, @@ -425,7 +391,6 @@ async fn handle_pod_event( } Err(e) => { error!("Failed to add worker {} to router: {}", worker_url, e); - // Remove from tracking since addition failed if let Ok(mut tracker) = tracked_pods.lock() { tracker.remove(pod_info); } @@ -464,8 +429,6 @@ async fn handle_pod_deletion( error!("Failed to remove worker {}: {}", worker_url, e); } } else { - // This case might occur if a pod is deleted before it was ever marked healthy and added. - // Or if the event is duplicated. No action needed on the router if it wasn't tracked (and thus not added). debug!( "Pod deletion event for untracked/already removed pod: {} (type: {:?}). Worker URL: {}", pod_info.name, pod_info.pod_type, worker_url @@ -480,7 +443,6 @@ mod tests { use k8s_openapi::apimachinery::pkg::apis::meta::v1::ObjectMeta; use k8s_openapi::apimachinery::pkg::apis::meta::v1::Time; - // Helper function to create a Pod for testing PodInfo::from_pod fn create_k8s_pod( name: Option<&str>, ip: Option<&str>, @@ -523,7 +485,6 @@ mod tests { pod } - // Helper function to create a Pod with PD-specific labels and annotations fn create_pd_k8s_pod(name: &str, ip: &str, pod_type: &str, bootstrap_port: Option) -> Pod { let mut labels = std::collections::BTreeMap::new(); labels.insert("app".to_string(), "sglang".to_string()); @@ -559,18 +520,15 @@ mod tests { } } - // Helper to create an AppContext instance for testing event handlers async fn create_test_app_context() -> Arc { use crate::config::RouterConfig; use crate::middleware::TokenBucket; - // Create a minimal RouterConfig for testing with very short timeout let router_config = RouterConfig { worker_startup_timeout_secs: 1, ..Default::default() - }; // Very short timeout for tests + }; - // Create AppContext with minimal components Arc::new(AppContext { client: reqwest::Client::new(), router_config: router_config.clone(), @@ -579,16 +537,15 @@ mod tests { policy_registry: Arc::new(crate::policies::PolicyRegistry::new( router_config.policy.clone(), )), - tokenizer: None, // HTTP mode doesn't need tokenizer - reasoning_parser_factory: None, // HTTP mode doesn't need reasoning parser - tool_parser_registry: None, // HTTP mode doesn't need tool parser - router_manager: None, // Test doesn't need router manager + tokenizer: None, + reasoning_parser_factory: None, + tool_parser_registry: None, + router_manager: None, response_storage: Arc::new(crate::data_connector::MemoryResponseStorage::new()), load_monitor: None, }) } - // Helper to create a PD config for testing fn create_pd_config() -> ServiceDiscoveryConfig { let mut prefill_selector = HashMap::new(); prefill_selector.insert("app".to_string(), "sglang".to_string()); @@ -615,19 +572,15 @@ mod tests { fn test_pod_info_should_include() { let config = create_pd_config(); - // Test prefill pod should be included let prefill_pod = create_pd_k8s_pod("prefill-pod", "10.0.0.1", "prefill", Some(8081)); assert!(PodInfo::should_include(&prefill_pod, &config)); - // Test decode pod should be included let decode_pod = create_pd_k8s_pod("decode-pod", "10.0.0.2", "decode", None); assert!(PodInfo::should_include(&decode_pod, &config)); - // Test unmatched pod should not be included let unmatched_pod = create_pd_k8s_pod("other-pod", "10.0.0.3", "other", None); assert!(!PodInfo::should_include(&unmatched_pod, &config)); - // Test regular mode let mut regular_config = ServiceDiscoveryConfig::default(); regular_config .selector @@ -654,7 +607,6 @@ mod tests { #[test] fn test_pod_type_enum() { - // Test that PodType enum has expected variants let prefill = PodType::Prefill; let decode = PodType::Decode; let regular = PodType::Regular; @@ -714,7 +666,7 @@ mod tests { fn test_pod_info_from_pod_with_pd_config_regular_mode() { let k8s_pod = create_pd_k8s_pod("regular-pod", "10.0.0.3", "worker", None); let mut config = create_pd_config(); - config.pd_mode = false; // Set to regular mode + config.pd_mode = false; let pod_info = PodInfo::from_pod(&k8s_pod, Some(&config)).unwrap(); assert_eq!(pod_info.name, "regular-pod"); @@ -742,7 +694,6 @@ mod tests { #[test] fn test_pod_info_from_pod_with_pd_config_invalid_bootstrap_port() { let mut pod = create_pd_k8s_pod("prefill-pod", "10.0.0.1", "prefill", None); - // Add invalid bootstrap port annotation pod.metadata.annotations.as_mut().unwrap().insert( "sglang.ai/bootstrap-port".to_string(), "invalid".to_string(), @@ -751,7 +702,7 @@ mod tests { let pod_info = PodInfo::from_pod(&pod, Some(&config)).unwrap(); assert_eq!(pod_info.pod_type, Some(PodType::Prefill)); - assert!(pod_info.bootstrap_port.is_none()); // Should be None for invalid port + assert!(pod_info.bootstrap_port.is_none()); } #[test] @@ -1077,7 +1028,6 @@ mod tests { ) .await; - // Pod should not be tracked since add_worker_from_url will fail for non-running server assert!(!tracked_pods.lock().unwrap().contains(&pod_info)); }