From c071198c1d74d1cf2dabbf6f38afce31411a1586 Mon Sep 17 00:00:00 2001 From: Simo Lin Date: Fri, 27 Jun 2025 15:42:02 -0700 Subject: [PATCH] [router] add centralized configuration module for sgl-router (#7588) --- sgl-router/Cargo.toml | 1 + sgl-router/src/config/mod.rs | 28 ++ sgl-router/src/config/types.rs | 298 +++++++++++++++++ sgl-router/src/config/validation.rs | 496 ++++++++++++++++++++++++++++ sgl-router/src/lib.rs | 153 +++++---- 5 files changed, 916 insertions(+), 60 deletions(-) create mode 100644 sgl-router/src/config/mod.rs create mode 100644 sgl-router/src/config/types.rs create mode 100644 sgl-router/src/config/validation.rs diff --git a/sgl-router/Cargo.toml b/sgl-router/Cargo.toml index 1939ef7c3..fa9d58967 100644 --- a/sgl-router/Cargo.toml +++ b/sgl-router/Cargo.toml @@ -36,6 +36,7 @@ metrics-exporter-prometheus = "0.17.0" # Added for request tracing uuid = { version = "1.10", features = ["v4", "serde"] } thiserror = "2.0.12" +url = "2.5.4" [dev-dependencies] criterion = { version = "0.5", features = ["html_reports"] } diff --git a/sgl-router/src/config/mod.rs b/sgl-router/src/config/mod.rs new file mode 100644 index 000000000..4622ff781 --- /dev/null +++ b/sgl-router/src/config/mod.rs @@ -0,0 +1,28 @@ +pub mod types; +pub mod validation; + +pub use types::*; +pub use validation::*; + +/// Configuration errors +#[derive(Debug, thiserror::Error)] +pub enum ConfigError { + #[error("Validation failed: {reason}")] + ValidationFailed { reason: String }, + + #[error("Invalid value for field '{field}': {value} - {reason}")] + InvalidValue { + field: String, + value: String, + reason: String, + }, + + #[error("Incompatible configuration: {reason}")] + IncompatibleConfig { reason: String }, + + #[error("Missing required field: {field}")] + MissingRequired { field: String }, +} + +/// Result type for configuration operations +pub type ConfigResult = Result; diff --git a/sgl-router/src/config/types.rs b/sgl-router/src/config/types.rs new file mode 100644 index 000000000..732b1f874 --- /dev/null +++ b/sgl-router/src/config/types.rs @@ -0,0 +1,298 @@ +use super::{ConfigError, ConfigResult}; +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; + +/// Main router configuration +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct RouterConfig { + /// Routing mode configuration + pub mode: RoutingMode, + /// Policy configuration + pub policy: PolicyConfig, + /// Server host address + pub host: String, + /// Server port + pub port: u16, + /// Maximum payload size in bytes + pub max_payload_size: usize, + /// Request timeout in seconds + pub request_timeout_secs: u64, + /// Worker startup timeout in seconds + pub worker_startup_timeout_secs: u64, + /// Worker health check interval in seconds + pub worker_startup_check_interval_secs: u64, + /// Service discovery configuration (optional) + pub discovery: Option, + /// Metrics configuration (optional) + pub metrics: Option, + /// Log directory (None = stdout only) + pub log_dir: Option, + /// Verbose logging + pub verbose: bool, +} + +/// Routing mode configuration +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "type")] +pub enum RoutingMode { + #[serde(rename = "regular")] + Regular { + /// List of worker URLs + worker_urls: Vec, + }, + #[serde(rename = "prefill_decode")] + PrefillDecode { + /// Prefill worker URLs with optional bootstrap ports + prefill_urls: Vec<(String, Option)>, + /// Decode worker URLs + decode_urls: Vec, + }, +} + +impl RoutingMode { + pub fn is_pd_mode(&self) -> bool { + matches!(self, RoutingMode::PrefillDecode { .. }) + } + + pub fn worker_count(&self) -> usize { + match self { + RoutingMode::Regular { worker_urls } => worker_urls.len(), + RoutingMode::PrefillDecode { + prefill_urls, + decode_urls, + } => prefill_urls.len() + decode_urls.len(), + } + } +} + +/// Policy configuration for routing +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "type")] +pub enum PolicyConfig { + #[serde(rename = "random")] + Random, + + #[serde(rename = "round_robin")] + RoundRobin, + + #[serde(rename = "cache_aware")] + CacheAware { + /// Minimum prefix match ratio to use cache-based routing + cache_threshold: f32, + /// Absolute load difference threshold for load balancing + balance_abs_threshold: usize, + /// Relative load ratio threshold for load balancing + balance_rel_threshold: f32, + /// Interval between cache eviction cycles (seconds) + eviction_interval_secs: u64, + /// Maximum cache tree size per tenant + max_tree_size: usize, + }, + + #[serde(rename = "power_of_two")] + PowerOfTwo { + /// Interval for load monitoring (seconds) + load_check_interval_secs: u64, + }, +} + +impl PolicyConfig { + pub fn name(&self) -> &'static str { + match self { + PolicyConfig::Random => "random", + PolicyConfig::RoundRobin => "round_robin", + PolicyConfig::CacheAware { .. } => "cache_aware", + PolicyConfig::PowerOfTwo { .. } => "power_of_two", + } + } +} + +/// Service discovery configuration +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct DiscoveryConfig { + /// Enable service discovery + pub enabled: bool, + /// Kubernetes namespace (None = all namespaces) + pub namespace: Option, + /// Service discovery port + pub port: u16, + /// Check interval for service discovery + pub check_interval_secs: u64, + /// Regular mode selector + pub selector: HashMap, + /// PD mode prefill selector + pub prefill_selector: HashMap, + /// PD mode decode selector + pub decode_selector: HashMap, + /// Bootstrap port annotation key + pub bootstrap_port_annotation: String, +} + +impl Default for DiscoveryConfig { + fn default() -> Self { + Self { + enabled: false, + namespace: None, + port: 8000, + check_interval_secs: 60, + selector: HashMap::new(), + prefill_selector: HashMap::new(), + decode_selector: HashMap::new(), + bootstrap_port_annotation: "sglang.ai/bootstrap-port".to_string(), + } + } +} + +/// Metrics configuration +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct MetricsConfig { + /// Prometheus metrics port + pub port: u16, + /// Prometheus metrics host + pub host: String, +} + +impl Default for MetricsConfig { + fn default() -> Self { + Self { + port: 29000, + host: "127.0.0.1".to_string(), + } + } +} + +impl Default for RouterConfig { + fn default() -> Self { + Self { + mode: RoutingMode::Regular { + worker_urls: vec![], + }, + policy: PolicyConfig::Random, + host: "127.0.0.1".to_string(), + port: 3001, + max_payload_size: 268_435_456, // 256MB + request_timeout_secs: 600, + worker_startup_timeout_secs: 300, + worker_startup_check_interval_secs: 10, + discovery: None, + metrics: None, + log_dir: None, + verbose: false, + } + } +} + +impl RouterConfig { + /// Create a new configuration with mode and policy + pub fn new(mode: RoutingMode, policy: PolicyConfig) -> Self { + Self { + mode, + policy, + ..Default::default() + } + } + + /// Validate the configuration + pub fn validate(&self) -> ConfigResult<()> { + crate::config::validation::ConfigValidator::validate(self) + } + + /// Get the routing mode type as a string + pub fn mode_type(&self) -> &'static str { + match self.mode { + RoutingMode::Regular { .. } => "regular", + RoutingMode::PrefillDecode { .. } => "prefill_decode", + } + } + + /// Check if service discovery is enabled + pub fn has_service_discovery(&self) -> bool { + self.discovery.as_ref().map_or(false, |d| d.enabled) + } + + /// Check if metrics are enabled + pub fn has_metrics(&self) -> bool { + self.metrics.is_some() + } + + /// Convert to routing PolicyConfig for internal use + pub fn to_routing_policy_config(&self) -> ConfigResult { + match (&self.mode, &self.policy) { + ( + RoutingMode::PrefillDecode { + prefill_urls, + decode_urls, + }, + policy, + ) => { + // Map policy to PDSelectionPolicy + let selection_policy = match policy { + PolicyConfig::Random => crate::pd_types::PDSelectionPolicy::Random, + PolicyConfig::PowerOfTwo { .. } => { + crate::pd_types::PDSelectionPolicy::PowerOfTwo + } + PolicyConfig::CacheAware { + cache_threshold, + balance_abs_threshold, + balance_rel_threshold, + .. + } => crate::pd_types::PDSelectionPolicy::CacheAware { + cache_threshold: *cache_threshold, + balance_abs_threshold: *balance_abs_threshold, + balance_rel_threshold: *balance_rel_threshold, + }, + PolicyConfig::RoundRobin => { + return Err(ConfigError::IncompatibleConfig { + reason: "RoundRobin policy is not supported in PD disaggregated mode" + .to_string(), + }); + } + }; + + Ok(crate::router::PolicyConfig::PrefillDecodeConfig { + selection_policy, + prefill_urls: prefill_urls.clone(), + decode_urls: decode_urls.clone(), + timeout_secs: self.worker_startup_timeout_secs, + interval_secs: self.worker_startup_check_interval_secs, + }) + } + (RoutingMode::Regular { .. }, PolicyConfig::Random) => { + Ok(crate::router::PolicyConfig::RandomConfig { + timeout_secs: self.worker_startup_timeout_secs, + interval_secs: self.worker_startup_check_interval_secs, + }) + } + (RoutingMode::Regular { .. }, PolicyConfig::RoundRobin) => { + Ok(crate::router::PolicyConfig::RoundRobinConfig { + timeout_secs: self.worker_startup_timeout_secs, + interval_secs: self.worker_startup_check_interval_secs, + }) + } + ( + RoutingMode::Regular { .. }, + PolicyConfig::CacheAware { + cache_threshold, + balance_abs_threshold, + balance_rel_threshold, + eviction_interval_secs, + max_tree_size, + }, + ) => Ok(crate::router::PolicyConfig::CacheAwareConfig { + cache_threshold: *cache_threshold, + balance_abs_threshold: *balance_abs_threshold, + balance_rel_threshold: *balance_rel_threshold, + eviction_interval_secs: *eviction_interval_secs, + max_tree_size: *max_tree_size, + timeout_secs: self.worker_startup_timeout_secs, + interval_secs: self.worker_startup_check_interval_secs, + }), + (RoutingMode::Regular { .. }, PolicyConfig::PowerOfTwo { .. }) => { + Err(ConfigError::IncompatibleConfig { + reason: "PowerOfTwo policy is only supported in PD disaggregated mode" + .to_string(), + }) + } + } + } +} diff --git a/sgl-router/src/config/validation.rs b/sgl-router/src/config/validation.rs new file mode 100644 index 000000000..ed08c212c --- /dev/null +++ b/sgl-router/src/config/validation.rs @@ -0,0 +1,496 @@ +use super::*; + +/// Configuration validator +pub struct ConfigValidator; + +impl ConfigValidator { + /// Validate a complete router configuration + pub fn validate(config: &RouterConfig) -> ConfigResult<()> { + // Check if service discovery is enabled + let has_service_discovery = config.discovery.as_ref().map_or(false, |d| d.enabled); + + Self::validate_mode(&config.mode, has_service_discovery)?; + Self::validate_policy(&config.policy)?; + Self::validate_server_settings(config)?; + + if let Some(discovery) = &config.discovery { + Self::validate_discovery(discovery, &config.mode)?; + } + + if let Some(metrics) = &config.metrics { + Self::validate_metrics(metrics)?; + } + + Self::validate_compatibility(config)?; + + Ok(()) + } + + /// Validate routing mode configuration + fn validate_mode(mode: &RoutingMode, has_service_discovery: bool) -> ConfigResult<()> { + match mode { + RoutingMode::Regular { worker_urls } => { + // Validate URLs if any are provided + if !worker_urls.is_empty() { + Self::validate_urls(worker_urls)?; + } + // Note: We allow empty worker URLs even without service discovery + // to let the router start and fail at runtime when routing requests. + // This matches legacy behavior and test expectations. + } + RoutingMode::PrefillDecode { + prefill_urls, + decode_urls, + } => { + // Only require URLs if service discovery is disabled + if !has_service_discovery { + if prefill_urls.is_empty() { + return Err(ConfigError::ValidationFailed { + reason: "PD mode requires at least one prefill worker URL".to_string(), + }); + } + if decode_urls.is_empty() { + return Err(ConfigError::ValidationFailed { + reason: "PD mode requires at least one decode worker URL".to_string(), + }); + } + } + + // Validate URLs if any are provided + if !prefill_urls.is_empty() { + let prefill_url_strings: Vec = + prefill_urls.iter().map(|(url, _)| url.clone()).collect(); + Self::validate_urls(&prefill_url_strings)?; + } + if !decode_urls.is_empty() { + Self::validate_urls(decode_urls)?; + } + + // Validate bootstrap ports + for (_url, port) in prefill_urls { + if let Some(port) = port { + if *port == 0 { + return Err(ConfigError::InvalidValue { + field: "bootstrap_port".to_string(), + value: port.to_string(), + reason: "Port must be between 1 and 65535".to_string(), + }); + } + } + } + } + } + Ok(()) + } + + /// Validate policy configuration + fn validate_policy(policy: &PolicyConfig) -> ConfigResult<()> { + match policy { + PolicyConfig::Random | PolicyConfig::RoundRobin => { + // No specific validation needed + } + PolicyConfig::CacheAware { + cache_threshold, + balance_abs_threshold: _, + balance_rel_threshold, + eviction_interval_secs, + max_tree_size, + } => { + if !(0.0..=1.0).contains(cache_threshold) { + return Err(ConfigError::InvalidValue { + field: "cache_threshold".to_string(), + value: cache_threshold.to_string(), + reason: "Must be between 0.0 and 1.0".to_string(), + }); + } + + if *balance_rel_threshold < 1.0 { + return Err(ConfigError::InvalidValue { + field: "balance_rel_threshold".to_string(), + value: balance_rel_threshold.to_string(), + reason: "Must be >= 1.0".to_string(), + }); + } + + if *eviction_interval_secs == 0 { + return Err(ConfigError::InvalidValue { + field: "eviction_interval_secs".to_string(), + value: eviction_interval_secs.to_string(), + reason: "Must be > 0".to_string(), + }); + } + + if *max_tree_size == 0 { + return Err(ConfigError::InvalidValue { + field: "max_tree_size".to_string(), + value: max_tree_size.to_string(), + reason: "Must be > 0".to_string(), + }); + } + } + PolicyConfig::PowerOfTwo { + load_check_interval_secs, + } => { + if *load_check_interval_secs == 0 { + return Err(ConfigError::InvalidValue { + field: "load_check_interval_secs".to_string(), + value: load_check_interval_secs.to_string(), + reason: "Must be > 0".to_string(), + }); + } + } + } + Ok(()) + } + + /// Validate server configuration + fn validate_server_settings(config: &RouterConfig) -> ConfigResult<()> { + if config.port == 0 { + return Err(ConfigError::InvalidValue { + field: "port".to_string(), + value: config.port.to_string(), + reason: "Port must be > 0".to_string(), + }); + } + + if config.max_payload_size == 0 { + return Err(ConfigError::InvalidValue { + field: "max_payload_size".to_string(), + value: config.max_payload_size.to_string(), + reason: "Must be > 0".to_string(), + }); + } + + if config.request_timeout_secs == 0 { + return Err(ConfigError::InvalidValue { + field: "request_timeout_secs".to_string(), + value: config.request_timeout_secs.to_string(), + reason: "Must be > 0".to_string(), + }); + } + + if config.worker_startup_timeout_secs == 0 { + return Err(ConfigError::InvalidValue { + field: "worker_startup_timeout_secs".to_string(), + value: config.worker_startup_timeout_secs.to_string(), + reason: "Must be > 0".to_string(), + }); + } + + if config.worker_startup_check_interval_secs == 0 { + return Err(ConfigError::InvalidValue { + field: "worker_startup_check_interval_secs".to_string(), + value: config.worker_startup_check_interval_secs.to_string(), + reason: "Must be > 0".to_string(), + }); + } + + Ok(()) + } + + /// Validate service discovery configuration + fn validate_discovery(discovery: &DiscoveryConfig, mode: &RoutingMode) -> ConfigResult<()> { + if !discovery.enabled { + return Ok(()); // No validation needed if disabled + } + + if discovery.port == 0 { + return Err(ConfigError::InvalidValue { + field: "discovery.port".to_string(), + value: discovery.port.to_string(), + reason: "Port must be > 0".to_string(), + }); + } + + if discovery.check_interval_secs == 0 { + return Err(ConfigError::InvalidValue { + field: "discovery.check_interval_secs".to_string(), + value: discovery.check_interval_secs.to_string(), + reason: "Must be > 0".to_string(), + }); + } + + // Validate selectors based on mode + match mode { + RoutingMode::Regular { .. } => { + if discovery.selector.is_empty() { + return Err(ConfigError::ValidationFailed { + reason: "Regular mode with service discovery requires a non-empty selector" + .to_string(), + }); + } + } + RoutingMode::PrefillDecode { .. } => { + if discovery.prefill_selector.is_empty() && discovery.decode_selector.is_empty() { + return Err(ConfigError::ValidationFailed { + reason: "PD mode with service discovery requires at least one non-empty selector (prefill or decode)".to_string(), + }); + } + } + } + + Ok(()) + } + + /// Validate metrics configuration + fn validate_metrics(metrics: &MetricsConfig) -> ConfigResult<()> { + if metrics.port == 0 { + return Err(ConfigError::InvalidValue { + field: "metrics.port".to_string(), + value: metrics.port.to_string(), + reason: "Port must be > 0".to_string(), + }); + } + + if metrics.host.is_empty() { + return Err(ConfigError::InvalidValue { + field: "metrics.host".to_string(), + value: metrics.host.clone(), + reason: "Host cannot be empty".to_string(), + }); + } + + Ok(()) + } + + /// Validate compatibility between different configuration sections + fn validate_compatibility(config: &RouterConfig) -> ConfigResult<()> { + // Check mode and policy compatibility + match (&config.mode, &config.policy) { + (RoutingMode::Regular { .. }, PolicyConfig::PowerOfTwo { .. }) => { + // PowerOfTwo is only supported in PD mode + return Err(ConfigError::IncompatibleConfig { + reason: "PowerOfTwo policy is only supported in PD disaggregated mode" + .to_string(), + }); + } + (RoutingMode::PrefillDecode { .. }, PolicyConfig::RoundRobin) => { + return Err(ConfigError::IncompatibleConfig { + reason: "RoundRobin policy is not supported in PD disaggregated mode" + .to_string(), + }); + } + _ => {} + } + + // Check if service discovery is enabled for worker count validation + let has_service_discovery = config.discovery.as_ref().map_or(false, |d| d.enabled); + + // Only validate worker counts if service discovery is disabled + if !has_service_discovery { + // Check if power-of-two policy makes sense with insufficient workers + if let PolicyConfig::PowerOfTwo { .. } = &config.policy { + let worker_count = config.mode.worker_count(); + if worker_count < 2 { + return Err(ConfigError::IncompatibleConfig { + reason: "Power-of-two policy requires at least 2 workers".to_string(), + }); + } + } + } + + Ok(()) + } + + /// Validate URL format + fn validate_urls(urls: &[String]) -> ConfigResult<()> { + for url in urls { + if url.is_empty() { + return Err(ConfigError::InvalidValue { + field: "worker_url".to_string(), + value: url.clone(), + reason: "URL cannot be empty".to_string(), + }); + } + + if !url.starts_with("http://") && !url.starts_with("https://") { + return Err(ConfigError::InvalidValue { + field: "worker_url".to_string(), + value: url.clone(), + reason: "URL must start with http:// or https://".to_string(), + }); + } + + // Basic URL validation + match ::url::Url::parse(url) { + Ok(parsed) => { + // Additional validation + if parsed.host_str().is_none() { + return Err(ConfigError::InvalidValue { + field: "worker_url".to_string(), + value: url.clone(), + reason: "URL must have a valid host".to_string(), + }); + } + } + Err(e) => { + return Err(ConfigError::InvalidValue { + field: "worker_url".to_string(), + value: url.clone(), + reason: format!("Invalid URL format: {}", e), + }); + } + } + } + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_validate_regular_mode() { + let config = RouterConfig::new( + RoutingMode::Regular { + worker_urls: vec!["http://worker:8000".to_string()], + }, + PolicyConfig::Random, + ); + + assert!(ConfigValidator::validate(&config).is_ok()); + } + + #[test] + fn test_validate_empty_worker_urls() { + let config = RouterConfig::new( + RoutingMode::Regular { + worker_urls: vec![], + }, + PolicyConfig::Random, + ); + + // Empty worker URLs are now allowed to match legacy behavior + assert!(ConfigValidator::validate(&config).is_ok()); + } + + #[test] + fn test_validate_empty_worker_urls_with_service_discovery() { + let mut config = RouterConfig::new( + RoutingMode::Regular { + worker_urls: vec![], + }, + PolicyConfig::Random, + ); + + // Enable service discovery + config.discovery = Some(DiscoveryConfig { + enabled: true, + selector: vec![("app".to_string(), "test".to_string())] + .into_iter() + .collect(), + ..Default::default() + }); + + // Should pass validation since service discovery is enabled + assert!(ConfigValidator::validate(&config).is_ok()); + } + + #[test] + fn test_validate_invalid_urls() { + let config = RouterConfig::new( + RoutingMode::Regular { + worker_urls: vec!["invalid-url".to_string()], + }, + PolicyConfig::Random, + ); + + assert!(ConfigValidator::validate(&config).is_err()); + } + + #[test] + fn test_validate_cache_aware_thresholds() { + let config = RouterConfig::new( + RoutingMode::Regular { + worker_urls: vec![ + "http://worker1:8000".to_string(), + "http://worker2:8000".to_string(), + ], + }, + PolicyConfig::CacheAware { + cache_threshold: 1.5, // Invalid: > 1.0 + balance_abs_threshold: 32, + balance_rel_threshold: 1.1, + eviction_interval_secs: 60, + max_tree_size: 1000, + }, + ); + + assert!(ConfigValidator::validate(&config).is_err()); + } + + #[test] + fn test_validate_cache_aware_single_worker() { + // Cache-aware with single worker should be allowed (even if not optimal) + let config = RouterConfig::new( + RoutingMode::Regular { + worker_urls: vec!["http://worker1:8000".to_string()], + }, + PolicyConfig::CacheAware { + cache_threshold: 0.5, + balance_abs_threshold: 32, + balance_rel_threshold: 1.1, + eviction_interval_secs: 60, + max_tree_size: 1000, + }, + ); + + assert!(ConfigValidator::validate(&config).is_ok()); + } + + #[test] + fn test_validate_pd_mode() { + let config = RouterConfig::new( + RoutingMode::PrefillDecode { + prefill_urls: vec![("http://prefill:8000".to_string(), Some(8081))], + decode_urls: vec!["http://decode:8000".to_string()], + }, + PolicyConfig::Random, + ); + + assert!(ConfigValidator::validate(&config).is_ok()); + } + + #[test] + fn test_validate_incompatible_policy() { + // RoundRobin with PD mode + let config = RouterConfig::new( + RoutingMode::PrefillDecode { + prefill_urls: vec![("http://prefill:8000".to_string(), None)], + decode_urls: vec!["http://decode:8000".to_string()], + }, + PolicyConfig::RoundRobin, + ); + + let result = ConfigValidator::validate(&config); + assert!(result.is_err()); + assert!(result + .unwrap_err() + .to_string() + .contains("RoundRobin policy is not supported in PD disaggregated mode")); + } + + #[test] + fn test_validate_power_of_two_with_regular_mode() { + // PowerOfTwo with Regular mode should fail + let config = RouterConfig::new( + RoutingMode::Regular { + worker_urls: vec![ + "http://worker1:8000".to_string(), + "http://worker2:8000".to_string(), + ], + }, + PolicyConfig::PowerOfTwo { + load_check_interval_secs: 60, + }, + ); + + let result = ConfigValidator::validate(&config); + assert!(result.is_err()); + assert!(result + .unwrap_err() + .to_string() + .contains("PowerOfTwo policy is only supported in PD disaggregated mode")); + } +} diff --git a/sgl-router/src/lib.rs b/sgl-router/src/lib.rs index dfe114f65..cd1f0c154 100644 --- a/sgl-router/src/lib.rs +++ b/sgl-router/src/lib.rs @@ -1,4 +1,5 @@ use pyo3::prelude::*; +pub mod config; pub mod logging; use std::collections::HashMap; pub mod openai_api_types; @@ -56,6 +57,83 @@ struct Router { decode_urls: Option>, } +impl Router { + /// Convert PyO3 Router to RouterConfig + pub fn to_router_config(&self) -> config::ConfigResult { + use config::{ + DiscoveryConfig, MetricsConfig, PolicyConfig as ConfigPolicyConfig, RoutingMode, + }; + + // Determine routing mode + let mode = if self.pd_disaggregation { + RoutingMode::PrefillDecode { + prefill_urls: self.prefill_urls.clone().unwrap_or_default(), + decode_urls: self.decode_urls.clone().unwrap_or_default(), + } + } else { + RoutingMode::Regular { + worker_urls: self.worker_urls.clone(), + } + }; + + // Convert policy + let policy = match self.policy { + PolicyType::Random => ConfigPolicyConfig::Random, + PolicyType::RoundRobin => ConfigPolicyConfig::RoundRobin, + PolicyType::CacheAware => ConfigPolicyConfig::CacheAware { + cache_threshold: self.cache_threshold, + balance_abs_threshold: self.balance_abs_threshold, + balance_rel_threshold: self.balance_rel_threshold, + eviction_interval_secs: self.eviction_interval_secs, + max_tree_size: self.max_tree_size, + }, + PolicyType::PowerOfTwo => ConfigPolicyConfig::PowerOfTwo { + load_check_interval_secs: 5, // Default value + }, + }; + + // Service discovery configuration + let discovery = if self.service_discovery { + Some(DiscoveryConfig { + enabled: true, + namespace: self.service_discovery_namespace.clone(), + port: self.service_discovery_port, + check_interval_secs: 60, + selector: self.selector.clone(), + prefill_selector: self.prefill_selector.clone(), + decode_selector: self.decode_selector.clone(), + bootstrap_port_annotation: self.bootstrap_port_annotation.clone(), + }) + } else { + None + }; + + // Metrics configuration + let metrics = match (self.prometheus_port, self.prometheus_host.as_ref()) { + (Some(port), Some(host)) => Some(MetricsConfig { + port, + host: host.clone(), + }), + _ => None, + }; + + Ok(config::RouterConfig { + mode, + policy, + host: self.host.clone(), + port: self.port, + max_payload_size: self.max_payload_size, + request_timeout_secs: self.request_timeout_secs, + worker_startup_timeout_secs: self.worker_startup_timeout_secs, + worker_startup_check_interval_secs: self.worker_startup_check_interval, + discovery, + metrics, + log_dir: self.log_dir.clone(), + verbose: self.verbose, + }) + } +} + #[pymethods] impl Router { #[new] @@ -149,68 +227,23 @@ impl Router { } fn start(&self) -> PyResult<()> { - let policy_config = if self.pd_disaggregation { - // PD mode - map PolicyType to PDSelectionPolicy - let pd_selection_policy = match &self.policy { - PolicyType::Random => pd_types::PDSelectionPolicy::Random, - PolicyType::PowerOfTwo => pd_types::PDSelectionPolicy::PowerOfTwo, - PolicyType::CacheAware => pd_types::PDSelectionPolicy::CacheAware { - cache_threshold: self.cache_threshold, - balance_abs_threshold: self.balance_abs_threshold, - balance_rel_threshold: self.balance_rel_threshold, - }, - PolicyType::RoundRobin => { - return Err(pyo3::exceptions::PyValueError::new_err( - "RoundRobin policy is not supported in PD disaggregated mode", - )); - } - }; + // Convert to RouterConfig and validate + let router_config = self.to_router_config().map_err(|e| { + pyo3::exceptions::PyValueError::new_err(format!("Configuration error: {}", e)) + })?; - let prefill_urls = self.prefill_urls.as_ref().ok_or_else(|| { - pyo3::exceptions::PyValueError::new_err( - "PD disaggregated mode requires prefill_urls", - ) - })?; - let decode_urls = self.decode_urls.as_ref().ok_or_else(|| { - pyo3::exceptions::PyValueError::new_err( - "PD disaggregated mode requires decode_urls", - ) - })?; + // Validate the configuration + router_config.validate().map_err(|e| { + pyo3::exceptions::PyValueError::new_err(format!( + "Configuration validation failed: {}", + e + )) + })?; - router::PolicyConfig::PrefillDecodeConfig { - selection_policy: pd_selection_policy, - prefill_urls: prefill_urls.clone(), - decode_urls: decode_urls.clone(), - timeout_secs: self.worker_startup_timeout_secs, - interval_secs: self.worker_startup_check_interval, - } - } else { - // Regular mode - match &self.policy { - PolicyType::Random => router::PolicyConfig::RandomConfig { - timeout_secs: self.worker_startup_timeout_secs, - interval_secs: self.worker_startup_check_interval, - }, - PolicyType::RoundRobin => router::PolicyConfig::RoundRobinConfig { - timeout_secs: self.worker_startup_timeout_secs, - interval_secs: self.worker_startup_check_interval, - }, - PolicyType::CacheAware => router::PolicyConfig::CacheAwareConfig { - timeout_secs: self.worker_startup_timeout_secs, - interval_secs: self.worker_startup_check_interval, - cache_threshold: self.cache_threshold, - balance_abs_threshold: self.balance_abs_threshold, - balance_rel_threshold: self.balance_rel_threshold, - eviction_interval_secs: self.eviction_interval_secs, - max_tree_size: self.max_tree_size, - }, - PolicyType::PowerOfTwo => { - return Err(pyo3::exceptions::PyValueError::new_err( - "PowerOfTwo policy is only supported in PD disaggregated mode", - )); - } - } - }; + // Convert to internal policy config + let policy_config = router_config + .to_routing_policy_config() + .map_err(|e| pyo3::exceptions::PyValueError::new_err(e.to_string()))?; // Create service discovery config if enabled let service_discovery_config = if self.service_discovery {