[router] add ut for pd request, metrics and config (#8184)

This commit is contained in:
Simo Lin
2025-07-20 10:53:42 -07:00
committed by GitHub
parent 465968b2e3
commit 1fc455e8b6
4 changed files with 2005 additions and 74 deletions

View File

@@ -214,83 +214,590 @@ impl RouterConfig {
pub fn has_metrics(&self) -> bool {
// True exactly when the optional `metrics` field holds a value (`Some`).
self.metrics.is_some()
}
}
/* Commented out - no longer needed without compatibility layer
/// Convert to routing PolicyConfig for internal use
pub fn to_routing_policy_config(&self) -> ConfigResult<crate::router::PolicyConfig> {
match (&self.mode, &self.policy) {
(
RoutingMode::PrefillDecode {
prefill_urls,
decode_urls,
},
policy,
) => {
// Map policy to PDSelectionPolicy
let selection_policy = match policy {
PolicyConfig::Random => crate::pd_types::PDSelectionPolicy::Random,
PolicyConfig::PowerOfTwo { .. } => {
crate::pd_types::PDSelectionPolicy::PowerOfTwo
}
PolicyConfig::CacheAware { .. } => {
return Err(ConfigError::IncompatibleConfig {
reason: "CacheAware policy is not supported in PD disaggregated mode"
.to_string(),
});
}
PolicyConfig::RoundRobin => {
return Err(ConfigError::IncompatibleConfig {
reason: "RoundRobin policy is not supported in PD disaggregated mode"
.to_string(),
});
}
};
#[cfg(test)]
mod tests {
use super::*;
Ok(crate::router::PolicyConfig::PrefillDecodeConfig {
selection_policy,
prefill_urls: prefill_urls.clone(),
decode_urls: decode_urls.clone(),
timeout_secs: self.worker_startup_timeout_secs,
interval_secs: self.worker_startup_check_interval_secs,
})
// ============= RouterConfig Tests =============
#[test]
fn test_router_config_default() {
    // Checks every field of a default-constructed `RouterConfig`.
    let defaults = RouterConfig::default();

    // Routing: regular mode with an empty worker list, random policy.
    let starts_without_workers = matches!(
        &defaults.mode,
        RoutingMode::Regular { worker_urls } if worker_urls.is_empty()
    );
    assert!(starts_without_workers);
    assert!(matches!(defaults.policy, PolicyConfig::Random));

    // Listening endpoint.
    assert_eq!(defaults.host, "127.0.0.1");
    assert_eq!(defaults.port, 3001);

    // Size and timing limits (268_435_456 bytes is 256 MiB).
    assert_eq!(defaults.max_payload_size, 268_435_456);
    assert_eq!(defaults.request_timeout_secs, 600);
    assert_eq!(defaults.worker_startup_timeout_secs, 300);
    assert_eq!(defaults.worker_startup_check_interval_secs, 10);

    // All optional sections are unset by default.
    assert!(defaults.discovery.is_none());
    assert!(defaults.metrics.is_none());
    assert!(defaults.log_dir.is_none());
    assert!(defaults.log_level.is_none());
}
#[test]
fn test_router_config_new() {
let mode = RoutingMode::Regular {
worker_urls: vec!["http://worker1".to_string(), "http://worker2".to_string()],
};
let policy = PolicyConfig::RoundRobin;
let config = RouterConfig::new(mode, policy);
match config.mode {
RoutingMode::Regular { worker_urls } => {
assert_eq!(worker_urls.len(), 2);
assert_eq!(worker_urls[0], "http://worker1");
assert_eq!(worker_urls[1], "http://worker2");
}
(RoutingMode::Regular { .. }, PolicyConfig::Random) => {
Ok(crate::router::PolicyConfig::RandomConfig {
timeout_secs: self.worker_startup_timeout_secs,
interval_secs: self.worker_startup_check_interval_secs,
})
}
(RoutingMode::Regular { .. }, PolicyConfig::RoundRobin) => {
Ok(crate::router::PolicyConfig::RoundRobinConfig {
timeout_secs: self.worker_startup_timeout_secs,
interval_secs: self.worker_startup_check_interval_secs,
})
}
(
RoutingMode::Regular { .. },
PolicyConfig::CacheAware {
cache_threshold,
balance_abs_threshold,
balance_rel_threshold,
eviction_interval_secs,
max_tree_size,
},
) => Ok(crate::router::PolicyConfig::CacheAwareConfig {
cache_threshold: *cache_threshold,
balance_abs_threshold: *balance_abs_threshold,
balance_rel_threshold: *balance_rel_threshold,
eviction_interval_secs: *eviction_interval_secs,
max_tree_size: *max_tree_size,
timeout_secs: self.worker_startup_timeout_secs,
interval_secs: self.worker_startup_check_interval_secs,
}),
(RoutingMode::Regular { .. }, PolicyConfig::PowerOfTwo { .. }) => {
Err(ConfigError::IncompatibleConfig {
reason: "PowerOfTwo policy is only supported in PD disaggregated mode"
.to_string(),
})
_ => panic!("Expected Regular mode"),
}
assert!(matches!(config.policy, PolicyConfig::RoundRobin));
// Other fields should be default
assert_eq!(config.host, "127.0.0.1");
assert_eq!(config.port, 3001);
}
#[test]
fn test_router_config_serialization() {
    // Fully populated config so the JSON round trip also covers the optional sections.
    let original = RouterConfig {
        mode: RoutingMode::Regular {
            worker_urls: vec![String::from("http://worker1")],
        },
        policy: PolicyConfig::Random,
        host: String::from("0.0.0.0"),
        port: 8080,
        max_payload_size: 1024,
        request_timeout_secs: 30,
        worker_startup_timeout_secs: 60,
        worker_startup_check_interval_secs: 5,
        discovery: Some(DiscoveryConfig::default()),
        metrics: Some(MetricsConfig::default()),
        log_dir: Some(String::from("/var/log")),
        log_level: Some(String::from("debug")),
    };

    // Serialize to JSON, then parse the result straight back.
    let encoded = serde_json::to_string(&original).unwrap();
    let restored: RouterConfig = serde_json::from_str(&encoded).unwrap();

    // Scalar fields are unchanged and the optional sections are still present.
    assert_eq!(original.host, restored.host);
    assert_eq!(original.port, restored.port);
    assert_eq!(original.max_payload_size, restored.max_payload_size);
    assert!(restored.discovery.is_some());
    assert!(restored.metrics.is_some());
}
// ============= RoutingMode Tests =============
#[test]
fn test_routing_mode_is_pd_mode() {
    // Build one instance of each routing mode up front.
    let regular_mode = RoutingMode::Regular {
        worker_urls: vec![String::from("http://worker1")],
    };
    let pd_mode = RoutingMode::PrefillDecode {
        prefill_urls: vec![(String::from("http://prefill1"), Some(8001))],
        decode_urls: vec![String::from("http://decode1")],
    };

    // Only the prefill/decode variant reports PD mode.
    assert!(!regular_mode.is_pd_mode());
    assert!(pd_mode.is_pd_mode());
}
#[test]
fn test_routing_mode_worker_count() {
    // Turns a slice of literals into owned URL strings.
    fn owned(urls: &[&str]) -> Vec<String> {
        urls.iter().map(|url| url.to_string()).collect()
    }

    // Regular mode: one worker per URL.
    let three_workers = RoutingMode::Regular {
        worker_urls: owned(&["http://worker1", "http://worker2", "http://worker3"]),
    };
    assert_eq!(three_workers.worker_count(), 3);

    // PD mode: two prefill entries plus three decode entries give five.
    let two_plus_three = RoutingMode::PrefillDecode {
        prefill_urls: vec![
            ("http://prefill1".to_string(), Some(8001)),
            ("http://prefill2".to_string(), None),
        ],
        decode_urls: owned(&["http://decode1", "http://decode2", "http://decode3"]),
    };
    assert_eq!(two_plus_three.worker_count(), 5);

    // An empty worker list counts as zero.
    let no_workers = RoutingMode::Regular {
        worker_urls: Vec::new(),
    };
    assert_eq!(no_workers.worker_count(), 0);
}
#[test]
fn test_routing_mode_serialization() {
    // Regular mode: the JSON carries the type tag and the worker list key.
    let regular_json = serde_json::to_string(&RoutingMode::Regular {
        worker_urls: vec!["http://worker1".to_string()],
    })
    .unwrap();
    for fragment in ["\"type\":\"regular\"", "\"worker_urls\""].iter() {
        assert!(regular_json.contains(*fragment));
    }

    // PrefillDecode mode: the JSON carries the type tag and both URL list keys.
    let pd_json = serde_json::to_string(&RoutingMode::PrefillDecode {
        prefill_urls: vec![("http://prefill1".to_string(), Some(8001))],
        decode_urls: vec!["http://decode1".to_string()],
    })
    .unwrap();
    let pd_fragments = [
        "\"type\":\"prefill_decode\"",
        "\"prefill_urls\"",
        "\"decode_urls\"",
    ];
    for fragment in pd_fragments.iter() {
        assert!(pd_json.contains(*fragment));
    }
}
// ============= PolicyConfig Tests =============
#[test]
fn test_policy_config_name() {
    // Each policy variant paired with the name it is expected to report.
    let expectations = [
        (PolicyConfig::Random, "random"),
        (PolicyConfig::RoundRobin, "round_robin"),
        (
            PolicyConfig::CacheAware {
                cache_threshold: 0.8,
                balance_abs_threshold: 10,
                balance_rel_threshold: 1.5,
                eviction_interval_secs: 300,
                max_tree_size: 1000,
            },
            "cache_aware",
        ),
        (
            PolicyConfig::PowerOfTwo {
                load_check_interval_secs: 60,
            },
            "power_of_two",
        ),
    ];

    for (policy, expected_name) in expectations.iter() {
        assert_eq!(policy.name(), *expected_name);
    }
}
#[test]
fn test_policy_config_serialization() {
    // Unit variant: the whole JSON document is just the type tag.
    let random_json = serde_json::to_string(&PolicyConfig::Random).unwrap();
    assert_eq!(random_json, r#"{"type":"random"}"#);

    // CacheAware with every parameter set: spot-check the tag and two of the fields.
    let cache_aware_json = serde_json::to_string(&PolicyConfig::CacheAware {
        cache_threshold: 0.8,
        balance_abs_threshold: 10,
        balance_rel_threshold: 1.5,
        eviction_interval_secs: 300,
        max_tree_size: 1000,
    })
    .unwrap();
    let cache_aware_fragments = [
        "\"type\":\"cache_aware\"",
        "\"cache_threshold\":0.8",
        "\"balance_abs_threshold\":10",
    ];
    for fragment in cache_aware_fragments.iter() {
        assert!(cache_aware_json.contains(*fragment));
    }

    // PowerOfTwo: the tag and its single parameter.
    let power_of_two_json = serde_json::to_string(&PolicyConfig::PowerOfTwo {
        load_check_interval_secs: 60,
    })
    .unwrap();
    let power_of_two_fragments = [
        "\"type\":\"power_of_two\"",
        "\"load_check_interval_secs\":60",
    ];
    for fragment in power_of_two_fragments.iter() {
        assert!(power_of_two_json.contains(*fragment));
    }
}
#[test]
fn test_cache_aware_parameters() {
    let policy = PolicyConfig::CacheAware {
        cache_threshold: 0.75,
        balance_abs_threshold: 20,
        balance_rel_threshold: 2.0,
        eviction_interval_secs: 600,
        max_tree_size: 5000,
    };

    // Destructure the variant and check that each parameter comes back as written.
    if let PolicyConfig::CacheAware {
        cache_threshold: threshold,
        balance_abs_threshold: abs_threshold,
        balance_rel_threshold: rel_threshold,
        eviction_interval_secs: eviction_secs,
        max_tree_size: tree_size,
    } = policy
    {
        // Floats are compared with a small tolerance rather than exact equality.
        assert!((threshold - 0.75).abs() < 0.0001);
        assert_eq!(abs_threshold, 20);
        assert!((rel_threshold - 2.0).abs() < 0.0001);
        assert_eq!(eviction_secs, 600);
        assert_eq!(tree_size, 5000);
    } else {
        panic!("Expected CacheAware");
    }
}
*/
#[test]
fn test_power_of_two_parameters() {
    let policy = PolicyConfig::PowerOfTwo {
        load_check_interval_secs: 120,
    };

    // The single parameter must come back exactly as it was written.
    if let PolicyConfig::PowerOfTwo {
        load_check_interval_secs: interval_secs,
    } = policy
    {
        assert_eq!(interval_secs, 120);
    } else {
        panic!("Expected PowerOfTwo");
    }
}
// ============= DiscoveryConfig Tests =============
#[test]
fn test_discovery_config_default() {
    // Checks every field of a default-constructed `DiscoveryConfig`.
    let defaults = DiscoveryConfig::default();

    // Disabled, with no namespace set.
    assert!(!defaults.enabled);
    assert!(defaults.namespace.is_none());

    // Port and polling interval.
    assert_eq!(defaults.port, 8000);
    assert_eq!(defaults.check_interval_secs, 60);

    // All three selector maps start empty.
    assert!(defaults.selector.is_empty());
    assert!(defaults.prefill_selector.is_empty());
    assert!(defaults.decode_selector.is_empty());

    assert_eq!(defaults.bootstrap_port_annotation, "sglang.ai/bootstrap-port");
}
#[test]
fn test_discovery_config_with_selectors() {
    // One label set, reused for the general, prefill and decode selectors.
    let labels: HashMap<String, String> = [("app", "sglang"), ("role", "worker")]
        .iter()
        .map(|(key, value)| (key.to_string(), value.to_string()))
        .collect();

    let discovery = DiscoveryConfig {
        enabled: true,
        namespace: Some("default".to_string()),
        port: 9000,
        check_interval_secs: 30,
        selector: labels.clone(),
        prefill_selector: labels.clone(),
        decode_selector: labels,
        bootstrap_port_annotation: "custom.io/port".to_string(),
    };

    assert!(discovery.enabled);
    assert_eq!(discovery.namespace, Some("default".to_string()));
    assert_eq!(discovery.port, 9000);
    assert_eq!(discovery.selector.len(), 2);
    assert_eq!(discovery.selector.get("app"), Some(&"sglang".to_string()));
}
#[test]
fn test_discovery_config_namespace() {
    // Namespace left unset (the original test labels this case "all namespaces").
    let unscoped = DiscoveryConfig {
        namespace: None,
        ..Default::default()
    };
    assert!(unscoped.namespace.is_none());

    // Namespace pinned to a specific value.
    let scoped = DiscoveryConfig {
        namespace: Some(String::from("production")),
        ..Default::default()
    };
    assert_eq!(scoped.namespace, Some("production".to_string()));
}
// ============= MetricsConfig Tests =============
#[test]
fn test_metrics_config_default() {
    // Default metrics endpoint: host 127.0.0.1, port 29000.
    let defaults = MetricsConfig::default();
    assert_eq!(defaults.host, "127.0.0.1");
    assert_eq!(defaults.port, 29000);
}
#[test]
fn test_metrics_config_custom() {
    // Explicitly chosen values are stored as given.
    let custom = MetricsConfig {
        port: 9090,
        host: String::from("0.0.0.0"),
    };
    assert_eq!(custom.host, "0.0.0.0");
    assert_eq!(custom.port, 9090);
}
// ============= RouterConfig Utility Methods Tests =============
#[test]
fn test_mode_type() {
    // Regular mode reports "regular", even with no workers configured.
    let regular_config = RouterConfig {
        mode: RoutingMode::Regular {
            worker_urls: Vec::new(),
        },
        ..Default::default()
    };
    assert_eq!(regular_config.mode_type(), "regular");

    // Prefill/decode mode reports "prefill_decode", even with empty URL lists.
    let pd_config = RouterConfig {
        mode: RoutingMode::PrefillDecode {
            prefill_urls: Vec::new(),
            decode_urls: Vec::new(),
        },
        ..Default::default()
    };
    assert_eq!(pd_config.mode_type(), "prefill_decode");
}
#[test]
fn test_has_service_discovery() {
    // With no discovery section at all, service discovery is reported as off.
    let bare = RouterConfig::default();
    assert!(!bare.has_service_discovery());

    // With a discovery section present, the result follows its `enabled` flag.
    for &enabled in [false, true].iter() {
        let config = RouterConfig {
            discovery: Some(DiscoveryConfig {
                enabled,
                ..Default::default()
            }),
            ..Default::default()
        };
        assert_eq!(config.has_service_discovery(), enabled);
    }
}
#[test]
fn test_has_metrics() {
    // The default config has no metrics section.
    let without_metrics = RouterConfig::default();
    assert!(!without_metrics.has_metrics());

    // Supplying any metrics section flips the result.
    let with_metrics = RouterConfig {
        metrics: Some(MetricsConfig::default()),
        ..Default::default()
    };
    assert!(with_metrics.has_metrics());
}
// ============= Edge Cases =============
#[test]
fn test_large_worker_lists() {
    // One thousand generated worker URLs.
    let thousand_urls: Vec<String> = (0..1000).map(|i| format!("http://worker{}", i)).collect();
    let big_mode = RoutingMode::Regular {
        worker_urls: thousand_urls.clone(),
    };
    assert_eq!(big_mode.worker_count(), 1000);

    // The full list must survive a JSON round trip of the enclosing config.
    let encoded = serde_json::to_string(&RouterConfig {
        mode: big_mode,
        ..Default::default()
    })
    .unwrap();
    let restored: RouterConfig = serde_json::from_str(&encoded).unwrap();

    if let RoutingMode::Regular { worker_urls } = restored.mode {
        assert_eq!(worker_urls.len(), 1000);
    } else {
        panic!("Expected Regular mode");
    }
}
#[test]
fn test_unicode_in_config() {
    // Non-ASCII text in both the worker URLs and the log directory.
    let original = RouterConfig {
        mode: RoutingMode::Regular {
            worker_urls: vec![
                String::from("http://работник1"),
                String::from("http://工作者2"),
            ],
        },
        log_dir: Some(String::from("/日志/目录")),
        ..Default::default()
    };

    // Round trip through JSON.
    let encoded = serde_json::to_string(&original).unwrap();
    let restored: RouterConfig = serde_json::from_str(&encoded).unwrap();

    // Every string must come back character for character.
    if let RoutingMode::Regular { worker_urls } = restored.mode {
        assert_eq!(worker_urls[0], "http://работник1");
        assert_eq!(worker_urls[1], "http://工作者2");
    } else {
        panic!("Expected Regular mode");
    }
    assert_eq!(restored.log_dir, Some("/日志/目录".to_string()));
}
#[test]
fn test_empty_string_fields() {
    // Empty strings are stored as given; `Some("")` stays distinct from `None`.
    let blank = RouterConfig {
        host: String::new(),
        log_dir: Some(String::new()),
        log_level: Some(String::new()),
        ..Default::default()
    };

    assert_eq!(blank.host, "");
    assert_eq!(blank.log_dir, Some("".to_string()));
    assert_eq!(blank.log_level, Some("".to_string()));
}
// ============= Complex Configuration Tests =============
#[test]
fn test_full_pd_mode_config() {
    // Two prefill workers (one with a bootstrap port, one without) and two decode workers.
    let pd_mode = RoutingMode::PrefillDecode {
        prefill_urls: vec![
            (String::from("http://prefill1:8000"), Some(8001)),
            (String::from("http://prefill2:8000"), None),
        ],
        decode_urls: vec![
            String::from("http://decode1:8000"),
            String::from("http://decode2:8000"),
        ],
    };

    // Optional sections, built separately to keep the main literal readable.
    let discovery = DiscoveryConfig {
        enabled: true,
        namespace: Some(String::from("sglang")),
        ..Default::default()
    };
    let metrics = MetricsConfig {
        port: 9090,
        host: String::from("0.0.0.0"),
    };

    let full = RouterConfig {
        mode: pd_mode,
        policy: PolicyConfig::PowerOfTwo {
            load_check_interval_secs: 30,
        },
        host: String::from("0.0.0.0"),
        port: 3000,
        max_payload_size: 1048576,
        request_timeout_secs: 120,
        worker_startup_timeout_secs: 60,
        worker_startup_check_interval_secs: 5,
        discovery: Some(discovery),
        metrics: Some(metrics),
        log_dir: Some(String::from("/var/log/sglang")),
        log_level: Some(String::from("info")),
    };

    // Derived accessors must agree with what was configured above.
    assert!(full.mode.is_pd_mode());
    assert_eq!(full.mode.worker_count(), 4);
    assert_eq!(full.policy.name(), "power_of_two");
    assert!(full.has_service_discovery());
    assert!(full.has_metrics());
}
#[test]
fn test_full_regular_mode_config() {
    // Single-label selector used for service discovery.
    let app_selector: HashMap<String, String> =
        std::iter::once(("app".to_string(), "sglang".to_string())).collect();

    // Optional discovery section; selectors not listed here keep their defaults.
    let discovery = DiscoveryConfig {
        enabled: true,
        namespace: None,
        port: 8080,
        check_interval_secs: 45,
        selector: app_selector,
        ..Default::default()
    };

    let full = RouterConfig {
        mode: RoutingMode::Regular {
            worker_urls: vec![
                String::from("http://worker1:8000"),
                String::from("http://worker2:8000"),
                String::from("http://worker3:8000"),
            ],
        },
        policy: PolicyConfig::CacheAware {
            cache_threshold: 0.9,
            balance_abs_threshold: 5,
            balance_rel_threshold: 1.2,
            eviction_interval_secs: 600,
            max_tree_size: 10000,
        },
        host: String::from("0.0.0.0"),
        port: 3001,
        max_payload_size: 536870912,
        request_timeout_secs: 300,
        worker_startup_timeout_secs: 180,
        worker_startup_check_interval_secs: 15,
        discovery: Some(discovery),
        metrics: Some(MetricsConfig::default()),
        log_dir: None,
        log_level: Some(String::from("debug")),
    };

    // Derived accessors must agree with what was configured above.
    assert!(!full.mode.is_pd_mode());
    assert_eq!(full.mode.worker_count(), 3);
    assert_eq!(full.policy.name(), "cache_aware");
    assert!(full.has_service_discovery());
    assert!(full.has_metrics());
}
#[test]
fn test_config_with_all_options() {
    // Label set shared by all three discovery selectors.
    let mut labels = HashMap::new();
    labels.insert(String::from("env"), String::from("prod"));
    labels.insert(String::from("version"), String::from("v1"));

    let discovery = DiscoveryConfig {
        enabled: true,
        namespace: Some(String::from("production")),
        port: 8443,
        check_interval_secs: 120,
        selector: labels.clone(),
        prefill_selector: labels.clone(),
        decode_selector: labels,
        bootstrap_port_annotation: String::from("mycompany.io/bootstrap"),
    };
    let metrics = MetricsConfig {
        port: 9999,
        host: String::from("::"), // IPv6 any
    };

    let full = RouterConfig {
        mode: RoutingMode::Regular {
            worker_urls: vec![String::from("http://worker1")],
        },
        policy: PolicyConfig::RoundRobin,
        host: String::from("::1"), // IPv6
        port: 8888,
        max_payload_size: 1024 * 1024 * 512, // 512MB
        request_timeout_secs: 900,
        worker_startup_timeout_secs: 600,
        worker_startup_check_interval_secs: 20,
        discovery: Some(discovery),
        metrics: Some(metrics),
        log_dir: Some(String::from("/opt/logs/sglang")),
        log_level: Some(String::from("trace")),
    };

    assert!(full.has_service_discovery());
    assert!(full.has_metrics());
    assert_eq!(full.mode_type(), "regular");

    // Round trip through pretty-printed JSON and spot-check the result.
    let pretty = serde_json::to_string_pretty(&full).unwrap();
    let restored: RouterConfig = serde_json::from_str(&pretty).unwrap();
    assert_eq!(restored.host, "::1");
    assert_eq!(restored.port, 8888);
    let restored_discovery = restored.discovery.unwrap();
    assert_eq!(
        restored_discovery.namespace,
        Some("production".to_string())
    );
}
}