[router] add centralized configuration module for sgl-router (#7588)
This commit is contained in:
@@ -36,6 +36,7 @@ metrics-exporter-prometheus = "0.17.0"
|
|||||||
# Added for request tracing
|
# Added for request tracing
|
||||||
uuid = { version = "1.10", features = ["v4", "serde"] }
|
uuid = { version = "1.10", features = ["v4", "serde"] }
|
||||||
thiserror = "2.0.12"
|
thiserror = "2.0.12"
|
||||||
|
url = "2.5.4"
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
criterion = { version = "0.5", features = ["html_reports"] }
|
criterion = { version = "0.5", features = ["html_reports"] }
|
||||||
|
|||||||
28
sgl-router/src/config/mod.rs
Normal file
28
sgl-router/src/config/mod.rs
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
pub mod types;
|
||||||
|
pub mod validation;
|
||||||
|
|
||||||
|
pub use types::*;
|
||||||
|
pub use validation::*;
|
||||||
|
|
||||||
|
/// Configuration errors
|
||||||
|
#[derive(Debug, thiserror::Error)]
|
||||||
|
pub enum ConfigError {
|
||||||
|
#[error("Validation failed: {reason}")]
|
||||||
|
ValidationFailed { reason: String },
|
||||||
|
|
||||||
|
#[error("Invalid value for field '{field}': {value} - {reason}")]
|
||||||
|
InvalidValue {
|
||||||
|
field: String,
|
||||||
|
value: String,
|
||||||
|
reason: String,
|
||||||
|
},
|
||||||
|
|
||||||
|
#[error("Incompatible configuration: {reason}")]
|
||||||
|
IncompatibleConfig { reason: String },
|
||||||
|
|
||||||
|
#[error("Missing required field: {field}")]
|
||||||
|
MissingRequired { field: String },
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Result type for configuration operations
|
||||||
|
pub type ConfigResult<T> = Result<T, ConfigError>;
|
||||||
298
sgl-router/src/config/types.rs
Normal file
298
sgl-router/src/config/types.rs
Normal file
@@ -0,0 +1,298 @@
|
|||||||
|
use super::{ConfigError, ConfigResult};
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use std::collections::HashMap;
|
||||||
|
|
||||||
|
/// Main router configuration
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct RouterConfig {
|
||||||
|
/// Routing mode configuration
|
||||||
|
pub mode: RoutingMode,
|
||||||
|
/// Policy configuration
|
||||||
|
pub policy: PolicyConfig,
|
||||||
|
/// Server host address
|
||||||
|
pub host: String,
|
||||||
|
/// Server port
|
||||||
|
pub port: u16,
|
||||||
|
/// Maximum payload size in bytes
|
||||||
|
pub max_payload_size: usize,
|
||||||
|
/// Request timeout in seconds
|
||||||
|
pub request_timeout_secs: u64,
|
||||||
|
/// Worker startup timeout in seconds
|
||||||
|
pub worker_startup_timeout_secs: u64,
|
||||||
|
/// Worker health check interval in seconds
|
||||||
|
pub worker_startup_check_interval_secs: u64,
|
||||||
|
/// Service discovery configuration (optional)
|
||||||
|
pub discovery: Option<DiscoveryConfig>,
|
||||||
|
/// Metrics configuration (optional)
|
||||||
|
pub metrics: Option<MetricsConfig>,
|
||||||
|
/// Log directory (None = stdout only)
|
||||||
|
pub log_dir: Option<String>,
|
||||||
|
/// Verbose logging
|
||||||
|
pub verbose: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Routing mode configuration
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
#[serde(tag = "type")]
|
||||||
|
pub enum RoutingMode {
|
||||||
|
#[serde(rename = "regular")]
|
||||||
|
Regular {
|
||||||
|
/// List of worker URLs
|
||||||
|
worker_urls: Vec<String>,
|
||||||
|
},
|
||||||
|
#[serde(rename = "prefill_decode")]
|
||||||
|
PrefillDecode {
|
||||||
|
/// Prefill worker URLs with optional bootstrap ports
|
||||||
|
prefill_urls: Vec<(String, Option<u16>)>,
|
||||||
|
/// Decode worker URLs
|
||||||
|
decode_urls: Vec<String>,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
impl RoutingMode {
|
||||||
|
pub fn is_pd_mode(&self) -> bool {
|
||||||
|
matches!(self, RoutingMode::PrefillDecode { .. })
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn worker_count(&self) -> usize {
|
||||||
|
match self {
|
||||||
|
RoutingMode::Regular { worker_urls } => worker_urls.len(),
|
||||||
|
RoutingMode::PrefillDecode {
|
||||||
|
prefill_urls,
|
||||||
|
decode_urls,
|
||||||
|
} => prefill_urls.len() + decode_urls.len(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Policy configuration for routing
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
#[serde(tag = "type")]
|
||||||
|
pub enum PolicyConfig {
|
||||||
|
#[serde(rename = "random")]
|
||||||
|
Random,
|
||||||
|
|
||||||
|
#[serde(rename = "round_robin")]
|
||||||
|
RoundRobin,
|
||||||
|
|
||||||
|
#[serde(rename = "cache_aware")]
|
||||||
|
CacheAware {
|
||||||
|
/// Minimum prefix match ratio to use cache-based routing
|
||||||
|
cache_threshold: f32,
|
||||||
|
/// Absolute load difference threshold for load balancing
|
||||||
|
balance_abs_threshold: usize,
|
||||||
|
/// Relative load ratio threshold for load balancing
|
||||||
|
balance_rel_threshold: f32,
|
||||||
|
/// Interval between cache eviction cycles (seconds)
|
||||||
|
eviction_interval_secs: u64,
|
||||||
|
/// Maximum cache tree size per tenant
|
||||||
|
max_tree_size: usize,
|
||||||
|
},
|
||||||
|
|
||||||
|
#[serde(rename = "power_of_two")]
|
||||||
|
PowerOfTwo {
|
||||||
|
/// Interval for load monitoring (seconds)
|
||||||
|
load_check_interval_secs: u64,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
impl PolicyConfig {
|
||||||
|
pub fn name(&self) -> &'static str {
|
||||||
|
match self {
|
||||||
|
PolicyConfig::Random => "random",
|
||||||
|
PolicyConfig::RoundRobin => "round_robin",
|
||||||
|
PolicyConfig::CacheAware { .. } => "cache_aware",
|
||||||
|
PolicyConfig::PowerOfTwo { .. } => "power_of_two",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Service discovery configuration
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct DiscoveryConfig {
|
||||||
|
/// Enable service discovery
|
||||||
|
pub enabled: bool,
|
||||||
|
/// Kubernetes namespace (None = all namespaces)
|
||||||
|
pub namespace: Option<String>,
|
||||||
|
/// Service discovery port
|
||||||
|
pub port: u16,
|
||||||
|
/// Check interval for service discovery
|
||||||
|
pub check_interval_secs: u64,
|
||||||
|
/// Regular mode selector
|
||||||
|
pub selector: HashMap<String, String>,
|
||||||
|
/// PD mode prefill selector
|
||||||
|
pub prefill_selector: HashMap<String, String>,
|
||||||
|
/// PD mode decode selector
|
||||||
|
pub decode_selector: HashMap<String, String>,
|
||||||
|
/// Bootstrap port annotation key
|
||||||
|
pub bootstrap_port_annotation: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for DiscoveryConfig {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self {
|
||||||
|
enabled: false,
|
||||||
|
namespace: None,
|
||||||
|
port: 8000,
|
||||||
|
check_interval_secs: 60,
|
||||||
|
selector: HashMap::new(),
|
||||||
|
prefill_selector: HashMap::new(),
|
||||||
|
decode_selector: HashMap::new(),
|
||||||
|
bootstrap_port_annotation: "sglang.ai/bootstrap-port".to_string(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Metrics configuration
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct MetricsConfig {
|
||||||
|
/// Prometheus metrics port
|
||||||
|
pub port: u16,
|
||||||
|
/// Prometheus metrics host
|
||||||
|
pub host: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for MetricsConfig {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self {
|
||||||
|
port: 29000,
|
||||||
|
host: "127.0.0.1".to_string(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for RouterConfig {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self {
|
||||||
|
mode: RoutingMode::Regular {
|
||||||
|
worker_urls: vec![],
|
||||||
|
},
|
||||||
|
policy: PolicyConfig::Random,
|
||||||
|
host: "127.0.0.1".to_string(),
|
||||||
|
port: 3001,
|
||||||
|
max_payload_size: 268_435_456, // 256MB
|
||||||
|
request_timeout_secs: 600,
|
||||||
|
worker_startup_timeout_secs: 300,
|
||||||
|
worker_startup_check_interval_secs: 10,
|
||||||
|
discovery: None,
|
||||||
|
metrics: None,
|
||||||
|
log_dir: None,
|
||||||
|
verbose: false,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl RouterConfig {
|
||||||
|
/// Create a new configuration with mode and policy
|
||||||
|
pub fn new(mode: RoutingMode, policy: PolicyConfig) -> Self {
|
||||||
|
Self {
|
||||||
|
mode,
|
||||||
|
policy,
|
||||||
|
..Default::default()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Validate the configuration
|
||||||
|
pub fn validate(&self) -> ConfigResult<()> {
|
||||||
|
crate::config::validation::ConfigValidator::validate(self)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get the routing mode type as a string
|
||||||
|
pub fn mode_type(&self) -> &'static str {
|
||||||
|
match self.mode {
|
||||||
|
RoutingMode::Regular { .. } => "regular",
|
||||||
|
RoutingMode::PrefillDecode { .. } => "prefill_decode",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Check if service discovery is enabled
|
||||||
|
pub fn has_service_discovery(&self) -> bool {
|
||||||
|
self.discovery.as_ref().map_or(false, |d| d.enabled)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Check if metrics are enabled
|
||||||
|
pub fn has_metrics(&self) -> bool {
|
||||||
|
self.metrics.is_some()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Convert to routing PolicyConfig for internal use
|
||||||
|
pub fn to_routing_policy_config(&self) -> ConfigResult<crate::router::PolicyConfig> {
|
||||||
|
match (&self.mode, &self.policy) {
|
||||||
|
(
|
||||||
|
RoutingMode::PrefillDecode {
|
||||||
|
prefill_urls,
|
||||||
|
decode_urls,
|
||||||
|
},
|
||||||
|
policy,
|
||||||
|
) => {
|
||||||
|
// Map policy to PDSelectionPolicy
|
||||||
|
let selection_policy = match policy {
|
||||||
|
PolicyConfig::Random => crate::pd_types::PDSelectionPolicy::Random,
|
||||||
|
PolicyConfig::PowerOfTwo { .. } => {
|
||||||
|
crate::pd_types::PDSelectionPolicy::PowerOfTwo
|
||||||
|
}
|
||||||
|
PolicyConfig::CacheAware {
|
||||||
|
cache_threshold,
|
||||||
|
balance_abs_threshold,
|
||||||
|
balance_rel_threshold,
|
||||||
|
..
|
||||||
|
} => crate::pd_types::PDSelectionPolicy::CacheAware {
|
||||||
|
cache_threshold: *cache_threshold,
|
||||||
|
balance_abs_threshold: *balance_abs_threshold,
|
||||||
|
balance_rel_threshold: *balance_rel_threshold,
|
||||||
|
},
|
||||||
|
PolicyConfig::RoundRobin => {
|
||||||
|
return Err(ConfigError::IncompatibleConfig {
|
||||||
|
reason: "RoundRobin policy is not supported in PD disaggregated mode"
|
||||||
|
.to_string(),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(crate::router::PolicyConfig::PrefillDecodeConfig {
|
||||||
|
selection_policy,
|
||||||
|
prefill_urls: prefill_urls.clone(),
|
||||||
|
decode_urls: decode_urls.clone(),
|
||||||
|
timeout_secs: self.worker_startup_timeout_secs,
|
||||||
|
interval_secs: self.worker_startup_check_interval_secs,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
(RoutingMode::Regular { .. }, PolicyConfig::Random) => {
|
||||||
|
Ok(crate::router::PolicyConfig::RandomConfig {
|
||||||
|
timeout_secs: self.worker_startup_timeout_secs,
|
||||||
|
interval_secs: self.worker_startup_check_interval_secs,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
(RoutingMode::Regular { .. }, PolicyConfig::RoundRobin) => {
|
||||||
|
Ok(crate::router::PolicyConfig::RoundRobinConfig {
|
||||||
|
timeout_secs: self.worker_startup_timeout_secs,
|
||||||
|
interval_secs: self.worker_startup_check_interval_secs,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
(
|
||||||
|
RoutingMode::Regular { .. },
|
||||||
|
PolicyConfig::CacheAware {
|
||||||
|
cache_threshold,
|
||||||
|
balance_abs_threshold,
|
||||||
|
balance_rel_threshold,
|
||||||
|
eviction_interval_secs,
|
||||||
|
max_tree_size,
|
||||||
|
},
|
||||||
|
) => Ok(crate::router::PolicyConfig::CacheAwareConfig {
|
||||||
|
cache_threshold: *cache_threshold,
|
||||||
|
balance_abs_threshold: *balance_abs_threshold,
|
||||||
|
balance_rel_threshold: *balance_rel_threshold,
|
||||||
|
eviction_interval_secs: *eviction_interval_secs,
|
||||||
|
max_tree_size: *max_tree_size,
|
||||||
|
timeout_secs: self.worker_startup_timeout_secs,
|
||||||
|
interval_secs: self.worker_startup_check_interval_secs,
|
||||||
|
}),
|
||||||
|
(RoutingMode::Regular { .. }, PolicyConfig::PowerOfTwo { .. }) => {
|
||||||
|
Err(ConfigError::IncompatibleConfig {
|
||||||
|
reason: "PowerOfTwo policy is only supported in PD disaggregated mode"
|
||||||
|
.to_string(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
496
sgl-router/src/config/validation.rs
Normal file
496
sgl-router/src/config/validation.rs
Normal file
@@ -0,0 +1,496 @@
|
|||||||
|
use super::*;
|
||||||
|
|
||||||
|
/// Stateless validator for `RouterConfig`; its checks are exposed as
/// associated functions.
pub struct ConfigValidator;
|
||||||
|
|
||||||
|
impl ConfigValidator {
|
||||||
|
/// Validate a complete router configuration
|
||||||
|
pub fn validate(config: &RouterConfig) -> ConfigResult<()> {
|
||||||
|
// Check if service discovery is enabled
|
||||||
|
let has_service_discovery = config.discovery.as_ref().map_or(false, |d| d.enabled);
|
||||||
|
|
||||||
|
Self::validate_mode(&config.mode, has_service_discovery)?;
|
||||||
|
Self::validate_policy(&config.policy)?;
|
||||||
|
Self::validate_server_settings(config)?;
|
||||||
|
|
||||||
|
if let Some(discovery) = &config.discovery {
|
||||||
|
Self::validate_discovery(discovery, &config.mode)?;
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(metrics) = &config.metrics {
|
||||||
|
Self::validate_metrics(metrics)?;
|
||||||
|
}
|
||||||
|
|
||||||
|
Self::validate_compatibility(config)?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Validate routing mode configuration
|
||||||
|
fn validate_mode(mode: &RoutingMode, has_service_discovery: bool) -> ConfigResult<()> {
|
||||||
|
match mode {
|
||||||
|
RoutingMode::Regular { worker_urls } => {
|
||||||
|
// Validate URLs if any are provided
|
||||||
|
if !worker_urls.is_empty() {
|
||||||
|
Self::validate_urls(worker_urls)?;
|
||||||
|
}
|
||||||
|
// Note: We allow empty worker URLs even without service discovery
|
||||||
|
// to let the router start and fail at runtime when routing requests.
|
||||||
|
// This matches legacy behavior and test expectations.
|
||||||
|
}
|
||||||
|
RoutingMode::PrefillDecode {
|
||||||
|
prefill_urls,
|
||||||
|
decode_urls,
|
||||||
|
} => {
|
||||||
|
// Only require URLs if service discovery is disabled
|
||||||
|
if !has_service_discovery {
|
||||||
|
if prefill_urls.is_empty() {
|
||||||
|
return Err(ConfigError::ValidationFailed {
|
||||||
|
reason: "PD mode requires at least one prefill worker URL".to_string(),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
if decode_urls.is_empty() {
|
||||||
|
return Err(ConfigError::ValidationFailed {
|
||||||
|
reason: "PD mode requires at least one decode worker URL".to_string(),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate URLs if any are provided
|
||||||
|
if !prefill_urls.is_empty() {
|
||||||
|
let prefill_url_strings: Vec<String> =
|
||||||
|
prefill_urls.iter().map(|(url, _)| url.clone()).collect();
|
||||||
|
Self::validate_urls(&prefill_url_strings)?;
|
||||||
|
}
|
||||||
|
if !decode_urls.is_empty() {
|
||||||
|
Self::validate_urls(decode_urls)?;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate bootstrap ports
|
||||||
|
for (_url, port) in prefill_urls {
|
||||||
|
if let Some(port) = port {
|
||||||
|
if *port == 0 {
|
||||||
|
return Err(ConfigError::InvalidValue {
|
||||||
|
field: "bootstrap_port".to_string(),
|
||||||
|
value: port.to_string(),
|
||||||
|
reason: "Port must be between 1 and 65535".to_string(),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Validate policy configuration
|
||||||
|
fn validate_policy(policy: &PolicyConfig) -> ConfigResult<()> {
|
||||||
|
match policy {
|
||||||
|
PolicyConfig::Random | PolicyConfig::RoundRobin => {
|
||||||
|
// No specific validation needed
|
||||||
|
}
|
||||||
|
PolicyConfig::CacheAware {
|
||||||
|
cache_threshold,
|
||||||
|
balance_abs_threshold: _,
|
||||||
|
balance_rel_threshold,
|
||||||
|
eviction_interval_secs,
|
||||||
|
max_tree_size,
|
||||||
|
} => {
|
||||||
|
if !(0.0..=1.0).contains(cache_threshold) {
|
||||||
|
return Err(ConfigError::InvalidValue {
|
||||||
|
field: "cache_threshold".to_string(),
|
||||||
|
value: cache_threshold.to_string(),
|
||||||
|
reason: "Must be between 0.0 and 1.0".to_string(),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
if *balance_rel_threshold < 1.0 {
|
||||||
|
return Err(ConfigError::InvalidValue {
|
||||||
|
field: "balance_rel_threshold".to_string(),
|
||||||
|
value: balance_rel_threshold.to_string(),
|
||||||
|
reason: "Must be >= 1.0".to_string(),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
if *eviction_interval_secs == 0 {
|
||||||
|
return Err(ConfigError::InvalidValue {
|
||||||
|
field: "eviction_interval_secs".to_string(),
|
||||||
|
value: eviction_interval_secs.to_string(),
|
||||||
|
reason: "Must be > 0".to_string(),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
if *max_tree_size == 0 {
|
||||||
|
return Err(ConfigError::InvalidValue {
|
||||||
|
field: "max_tree_size".to_string(),
|
||||||
|
value: max_tree_size.to_string(),
|
||||||
|
reason: "Must be > 0".to_string(),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
PolicyConfig::PowerOfTwo {
|
||||||
|
load_check_interval_secs,
|
||||||
|
} => {
|
||||||
|
if *load_check_interval_secs == 0 {
|
||||||
|
return Err(ConfigError::InvalidValue {
|
||||||
|
field: "load_check_interval_secs".to_string(),
|
||||||
|
value: load_check_interval_secs.to_string(),
|
||||||
|
reason: "Must be > 0".to_string(),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Validate server configuration
|
||||||
|
fn validate_server_settings(config: &RouterConfig) -> ConfigResult<()> {
|
||||||
|
if config.port == 0 {
|
||||||
|
return Err(ConfigError::InvalidValue {
|
||||||
|
field: "port".to_string(),
|
||||||
|
value: config.port.to_string(),
|
||||||
|
reason: "Port must be > 0".to_string(),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
if config.max_payload_size == 0 {
|
||||||
|
return Err(ConfigError::InvalidValue {
|
||||||
|
field: "max_payload_size".to_string(),
|
||||||
|
value: config.max_payload_size.to_string(),
|
||||||
|
reason: "Must be > 0".to_string(),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
if config.request_timeout_secs == 0 {
|
||||||
|
return Err(ConfigError::InvalidValue {
|
||||||
|
field: "request_timeout_secs".to_string(),
|
||||||
|
value: config.request_timeout_secs.to_string(),
|
||||||
|
reason: "Must be > 0".to_string(),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
if config.worker_startup_timeout_secs == 0 {
|
||||||
|
return Err(ConfigError::InvalidValue {
|
||||||
|
field: "worker_startup_timeout_secs".to_string(),
|
||||||
|
value: config.worker_startup_timeout_secs.to_string(),
|
||||||
|
reason: "Must be > 0".to_string(),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
if config.worker_startup_check_interval_secs == 0 {
|
||||||
|
return Err(ConfigError::InvalidValue {
|
||||||
|
field: "worker_startup_check_interval_secs".to_string(),
|
||||||
|
value: config.worker_startup_check_interval_secs.to_string(),
|
||||||
|
reason: "Must be > 0".to_string(),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Validate service discovery configuration
|
||||||
|
fn validate_discovery(discovery: &DiscoveryConfig, mode: &RoutingMode) -> ConfigResult<()> {
|
||||||
|
if !discovery.enabled {
|
||||||
|
return Ok(()); // No validation needed if disabled
|
||||||
|
}
|
||||||
|
|
||||||
|
if discovery.port == 0 {
|
||||||
|
return Err(ConfigError::InvalidValue {
|
||||||
|
field: "discovery.port".to_string(),
|
||||||
|
value: discovery.port.to_string(),
|
||||||
|
reason: "Port must be > 0".to_string(),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
if discovery.check_interval_secs == 0 {
|
||||||
|
return Err(ConfigError::InvalidValue {
|
||||||
|
field: "discovery.check_interval_secs".to_string(),
|
||||||
|
value: discovery.check_interval_secs.to_string(),
|
||||||
|
reason: "Must be > 0".to_string(),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate selectors based on mode
|
||||||
|
match mode {
|
||||||
|
RoutingMode::Regular { .. } => {
|
||||||
|
if discovery.selector.is_empty() {
|
||||||
|
return Err(ConfigError::ValidationFailed {
|
||||||
|
reason: "Regular mode with service discovery requires a non-empty selector"
|
||||||
|
.to_string(),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
RoutingMode::PrefillDecode { .. } => {
|
||||||
|
if discovery.prefill_selector.is_empty() && discovery.decode_selector.is_empty() {
|
||||||
|
return Err(ConfigError::ValidationFailed {
|
||||||
|
reason: "PD mode with service discovery requires at least one non-empty selector (prefill or decode)".to_string(),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Validate metrics configuration
|
||||||
|
fn validate_metrics(metrics: &MetricsConfig) -> ConfigResult<()> {
|
||||||
|
if metrics.port == 0 {
|
||||||
|
return Err(ConfigError::InvalidValue {
|
||||||
|
field: "metrics.port".to_string(),
|
||||||
|
value: metrics.port.to_string(),
|
||||||
|
reason: "Port must be > 0".to_string(),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
if metrics.host.is_empty() {
|
||||||
|
return Err(ConfigError::InvalidValue {
|
||||||
|
field: "metrics.host".to_string(),
|
||||||
|
value: metrics.host.clone(),
|
||||||
|
reason: "Host cannot be empty".to_string(),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Validate compatibility between different configuration sections
|
||||||
|
fn validate_compatibility(config: &RouterConfig) -> ConfigResult<()> {
|
||||||
|
// Check mode and policy compatibility
|
||||||
|
match (&config.mode, &config.policy) {
|
||||||
|
(RoutingMode::Regular { .. }, PolicyConfig::PowerOfTwo { .. }) => {
|
||||||
|
// PowerOfTwo is only supported in PD mode
|
||||||
|
return Err(ConfigError::IncompatibleConfig {
|
||||||
|
reason: "PowerOfTwo policy is only supported in PD disaggregated mode"
|
||||||
|
.to_string(),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
(RoutingMode::PrefillDecode { .. }, PolicyConfig::RoundRobin) => {
|
||||||
|
return Err(ConfigError::IncompatibleConfig {
|
||||||
|
reason: "RoundRobin policy is not supported in PD disaggregated mode"
|
||||||
|
.to_string(),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
_ => {}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if service discovery is enabled for worker count validation
|
||||||
|
let has_service_discovery = config.discovery.as_ref().map_or(false, |d| d.enabled);
|
||||||
|
|
||||||
|
// Only validate worker counts if service discovery is disabled
|
||||||
|
if !has_service_discovery {
|
||||||
|
// Check if power-of-two policy makes sense with insufficient workers
|
||||||
|
if let PolicyConfig::PowerOfTwo { .. } = &config.policy {
|
||||||
|
let worker_count = config.mode.worker_count();
|
||||||
|
if worker_count < 2 {
|
||||||
|
return Err(ConfigError::IncompatibleConfig {
|
||||||
|
reason: "Power-of-two policy requires at least 2 workers".to_string(),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Validate URL format
|
||||||
|
fn validate_urls(urls: &[String]) -> ConfigResult<()> {
|
||||||
|
for url in urls {
|
||||||
|
if url.is_empty() {
|
||||||
|
return Err(ConfigError::InvalidValue {
|
||||||
|
field: "worker_url".to_string(),
|
||||||
|
value: url.clone(),
|
||||||
|
reason: "URL cannot be empty".to_string(),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
if !url.starts_with("http://") && !url.starts_with("https://") {
|
||||||
|
return Err(ConfigError::InvalidValue {
|
||||||
|
field: "worker_url".to_string(),
|
||||||
|
value: url.clone(),
|
||||||
|
reason: "URL must start with http:// or https://".to_string(),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
// Basic URL validation
|
||||||
|
match ::url::Url::parse(url) {
|
||||||
|
Ok(parsed) => {
|
||||||
|
// Additional validation
|
||||||
|
if parsed.host_str().is_none() {
|
||||||
|
return Err(ConfigError::InvalidValue {
|
||||||
|
field: "worker_url".to_string(),
|
||||||
|
value: url.clone(),
|
||||||
|
reason: "URL must have a valid host".to_string(),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
return Err(ConfigError::InvalidValue {
|
||||||
|
field: "worker_url".to_string(),
|
||||||
|
value: url.clone(),
|
||||||
|
reason: format!("Invalid URL format: {}", e),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
mod tests {
    use super::*;

    /// Regular routing mode over the given worker URLs.
    fn regular_mode(urls: &[&str]) -> RoutingMode {
        RoutingMode::Regular {
            worker_urls: urls.iter().map(|url| url.to_string()).collect(),
        }
    }

    /// PD routing mode with one prefill and one decode worker.
    fn pd_mode(bootstrap_port: Option<u16>) -> RoutingMode {
        RoutingMode::PrefillDecode {
            prefill_urls: vec![("http://prefill:8000".to_string(), bootstrap_port)],
            decode_urls: vec!["http://decode:8000".to_string()],
        }
    }

    /// Cache-aware policy that is valid apart from the supplied threshold.
    fn cache_aware(cache_threshold: f32) -> PolicyConfig {
        PolicyConfig::CacheAware {
            cache_threshold,
            balance_abs_threshold: 32,
            balance_rel_threshold: 1.1,
            eviction_interval_secs: 60,
            max_tree_size: 1000,
        }
    }

    #[test]
    fn test_validate_regular_mode() {
        let config = RouterConfig::new(regular_mode(&["http://worker:8000"]), PolicyConfig::Random);

        assert!(ConfigValidator::validate(&config).is_ok());
    }

    #[test]
    fn test_validate_empty_worker_urls() {
        let config = RouterConfig::new(regular_mode(&[]), PolicyConfig::Random);

        // Empty worker URLs are now allowed to match legacy behavior
        assert!(ConfigValidator::validate(&config).is_ok());
    }

    #[test]
    fn test_validate_empty_worker_urls_with_service_discovery() {
        let mut config = RouterConfig::new(regular_mode(&[]), PolicyConfig::Random);

        // Enable service discovery
        config.discovery = Some(DiscoveryConfig {
            enabled: true,
            selector: std::iter::once(("app".to_string(), "test".to_string())).collect(),
            ..Default::default()
        });

        // Should pass validation since service discovery is enabled
        assert!(ConfigValidator::validate(&config).is_ok());
    }

    #[test]
    fn test_validate_invalid_urls() {
        let config = RouterConfig::new(regular_mode(&["invalid-url"]), PolicyConfig::Random);

        assert!(ConfigValidator::validate(&config).is_err());
    }

    #[test]
    fn test_validate_cache_aware_thresholds() {
        let config = RouterConfig::new(
            regular_mode(&["http://worker1:8000", "http://worker2:8000"]),
            cache_aware(1.5), // Invalid: > 1.0
        );

        assert!(ConfigValidator::validate(&config).is_err());
    }

    #[test]
    fn test_validate_cache_aware_single_worker() {
        // Cache-aware with single worker should be allowed (even if not optimal)
        let config = RouterConfig::new(regular_mode(&["http://worker1:8000"]), cache_aware(0.5));

        assert!(ConfigValidator::validate(&config).is_ok());
    }

    #[test]
    fn test_validate_pd_mode() {
        let config = RouterConfig::new(pd_mode(Some(8081)), PolicyConfig::Random);

        assert!(ConfigValidator::validate(&config).is_ok());
    }

    #[test]
    fn test_validate_incompatible_policy() {
        // RoundRobin with PD mode
        let config = RouterConfig::new(pd_mode(None), PolicyConfig::RoundRobin);

        let error = ConfigValidator::validate(&config)
            .expect_err("RoundRobin must be rejected in PD mode");
        assert!(error
            .to_string()
            .contains("RoundRobin policy is not supported in PD disaggregated mode"));
    }

    #[test]
    fn test_validate_power_of_two_with_regular_mode() {
        // PowerOfTwo with Regular mode should fail
        let config = RouterConfig::new(
            regular_mode(&["http://worker1:8000", "http://worker2:8000"]),
            PolicyConfig::PowerOfTwo {
                load_check_interval_secs: 60,
            },
        );

        let error = ConfigValidator::validate(&config)
            .expect_err("PowerOfTwo must be rejected in regular mode");
        assert!(error
            .to_string()
            .contains("PowerOfTwo policy is only supported in PD disaggregated mode"));
    }
}
|
||||||
@@ -1,4 +1,5 @@
|
|||||||
use pyo3::prelude::*;
|
use pyo3::prelude::*;
|
||||||
|
pub mod config;
|
||||||
pub mod logging;
|
pub mod logging;
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
pub mod openai_api_types;
|
pub mod openai_api_types;
|
||||||
@@ -56,6 +57,83 @@ struct Router {
|
|||||||
decode_urls: Option<Vec<String>>,
|
decode_urls: Option<Vec<String>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl Router {
|
||||||
|
/// Convert PyO3 Router to RouterConfig
|
||||||
|
pub fn to_router_config(&self) -> config::ConfigResult<config::RouterConfig> {
|
||||||
|
use config::{
|
||||||
|
DiscoveryConfig, MetricsConfig, PolicyConfig as ConfigPolicyConfig, RoutingMode,
|
||||||
|
};
|
||||||
|
|
||||||
|
// Determine routing mode
|
||||||
|
let mode = if self.pd_disaggregation {
|
||||||
|
RoutingMode::PrefillDecode {
|
||||||
|
prefill_urls: self.prefill_urls.clone().unwrap_or_default(),
|
||||||
|
decode_urls: self.decode_urls.clone().unwrap_or_default(),
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
RoutingMode::Regular {
|
||||||
|
worker_urls: self.worker_urls.clone(),
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Convert policy
|
||||||
|
let policy = match self.policy {
|
||||||
|
PolicyType::Random => ConfigPolicyConfig::Random,
|
||||||
|
PolicyType::RoundRobin => ConfigPolicyConfig::RoundRobin,
|
||||||
|
PolicyType::CacheAware => ConfigPolicyConfig::CacheAware {
|
||||||
|
cache_threshold: self.cache_threshold,
|
||||||
|
balance_abs_threshold: self.balance_abs_threshold,
|
||||||
|
balance_rel_threshold: self.balance_rel_threshold,
|
||||||
|
eviction_interval_secs: self.eviction_interval_secs,
|
||||||
|
max_tree_size: self.max_tree_size,
|
||||||
|
},
|
||||||
|
PolicyType::PowerOfTwo => ConfigPolicyConfig::PowerOfTwo {
|
||||||
|
load_check_interval_secs: 5, // Default value
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
// Service discovery configuration
|
||||||
|
let discovery = if self.service_discovery {
|
||||||
|
Some(DiscoveryConfig {
|
||||||
|
enabled: true,
|
||||||
|
namespace: self.service_discovery_namespace.clone(),
|
||||||
|
port: self.service_discovery_port,
|
||||||
|
check_interval_secs: 60,
|
||||||
|
selector: self.selector.clone(),
|
||||||
|
prefill_selector: self.prefill_selector.clone(),
|
||||||
|
decode_selector: self.decode_selector.clone(),
|
||||||
|
bootstrap_port_annotation: self.bootstrap_port_annotation.clone(),
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
|
||||||
|
// Metrics configuration
|
||||||
|
let metrics = match (self.prometheus_port, self.prometheus_host.as_ref()) {
|
||||||
|
(Some(port), Some(host)) => Some(MetricsConfig {
|
||||||
|
port,
|
||||||
|
host: host.clone(),
|
||||||
|
}),
|
||||||
|
_ => None,
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(config::RouterConfig {
|
||||||
|
mode,
|
||||||
|
policy,
|
||||||
|
host: self.host.clone(),
|
||||||
|
port: self.port,
|
||||||
|
max_payload_size: self.max_payload_size,
|
||||||
|
request_timeout_secs: self.request_timeout_secs,
|
||||||
|
worker_startup_timeout_secs: self.worker_startup_timeout_secs,
|
||||||
|
worker_startup_check_interval_secs: self.worker_startup_check_interval,
|
||||||
|
discovery,
|
||||||
|
metrics,
|
||||||
|
log_dir: self.log_dir.clone(),
|
||||||
|
verbose: self.verbose,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[pymethods]
|
#[pymethods]
|
||||||
impl Router {
|
impl Router {
|
||||||
#[new]
|
#[new]
|
||||||
@@ -149,68 +227,23 @@ impl Router {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn start(&self) -> PyResult<()> {
|
fn start(&self) -> PyResult<()> {
|
||||||
let policy_config = if self.pd_disaggregation {
|
// Convert to RouterConfig and validate
|
||||||
// PD mode - map PolicyType to PDSelectionPolicy
|
let router_config = self.to_router_config().map_err(|e| {
|
||||||
let pd_selection_policy = match &self.policy {
|
pyo3::exceptions::PyValueError::new_err(format!("Configuration error: {}", e))
|
||||||
PolicyType::Random => pd_types::PDSelectionPolicy::Random,
|
|
||||||
PolicyType::PowerOfTwo => pd_types::PDSelectionPolicy::PowerOfTwo,
|
|
||||||
PolicyType::CacheAware => pd_types::PDSelectionPolicy::CacheAware {
|
|
||||||
cache_threshold: self.cache_threshold,
|
|
||||||
balance_abs_threshold: self.balance_abs_threshold,
|
|
||||||
balance_rel_threshold: self.balance_rel_threshold,
|
|
||||||
},
|
|
||||||
PolicyType::RoundRobin => {
|
|
||||||
return Err(pyo3::exceptions::PyValueError::new_err(
|
|
||||||
"RoundRobin policy is not supported in PD disaggregated mode",
|
|
||||||
));
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
let prefill_urls = self.prefill_urls.as_ref().ok_or_else(|| {
|
|
||||||
pyo3::exceptions::PyValueError::new_err(
|
|
||||||
"PD disaggregated mode requires prefill_urls",
|
|
||||||
)
|
|
||||||
})?;
|
|
||||||
let decode_urls = self.decode_urls.as_ref().ok_or_else(|| {
|
|
||||||
pyo3::exceptions::PyValueError::new_err(
|
|
||||||
"PD disaggregated mode requires decode_urls",
|
|
||||||
)
|
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
router::PolicyConfig::PrefillDecodeConfig {
|
// Validate the configuration
|
||||||
selection_policy: pd_selection_policy,
|
router_config.validate().map_err(|e| {
|
||||||
prefill_urls: prefill_urls.clone(),
|
pyo3::exceptions::PyValueError::new_err(format!(
|
||||||
decode_urls: decode_urls.clone(),
|
"Configuration validation failed: {}",
|
||||||
timeout_secs: self.worker_startup_timeout_secs,
|
e
|
||||||
interval_secs: self.worker_startup_check_interval,
|
))
|
||||||
}
|
})?;
|
||||||
} else {
|
|
||||||
// Regular mode
|
// Convert to internal policy config
|
||||||
match &self.policy {
|
let policy_config = router_config
|
||||||
PolicyType::Random => router::PolicyConfig::RandomConfig {
|
.to_routing_policy_config()
|
||||||
timeout_secs: self.worker_startup_timeout_secs,
|
.map_err(|e| pyo3::exceptions::PyValueError::new_err(e.to_string()))?;
|
||||||
interval_secs: self.worker_startup_check_interval,
|
|
||||||
},
|
|
||||||
PolicyType::RoundRobin => router::PolicyConfig::RoundRobinConfig {
|
|
||||||
timeout_secs: self.worker_startup_timeout_secs,
|
|
||||||
interval_secs: self.worker_startup_check_interval,
|
|
||||||
},
|
|
||||||
PolicyType::CacheAware => router::PolicyConfig::CacheAwareConfig {
|
|
||||||
timeout_secs: self.worker_startup_timeout_secs,
|
|
||||||
interval_secs: self.worker_startup_check_interval,
|
|
||||||
cache_threshold: self.cache_threshold,
|
|
||||||
balance_abs_threshold: self.balance_abs_threshold,
|
|
||||||
balance_rel_threshold: self.balance_rel_threshold,
|
|
||||||
eviction_interval_secs: self.eviction_interval_secs,
|
|
||||||
max_tree_size: self.max_tree_size,
|
|
||||||
},
|
|
||||||
PolicyType::PowerOfTwo => {
|
|
||||||
return Err(pyo3::exceptions::PyValueError::new_err(
|
|
||||||
"PowerOfTwo policy is only supported in PD disaggregated mode",
|
|
||||||
));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
// Create service discovery config if enabled
|
// Create service discovery config if enabled
|
||||||
let service_discovery_config = if self.service_discovery {
|
let service_discovery_config = if self.service_discovery {
|
||||||
|
|||||||
Reference in New Issue
Block a user