sglangv0.5.2 & support Qwen3-Next-80B-A3B-Instruct
This commit is contained in:
28
sgl-router/src/config/mod.rs
Normal file
28
sgl-router/src/config/mod.rs
Normal file
@@ -0,0 +1,28 @@
|
||||
pub mod types;
|
||||
pub mod validation;
|
||||
|
||||
pub use types::*;
|
||||
pub use validation::*;
|
||||
|
||||
/// Configuration errors
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum ConfigError {
|
||||
#[error("Validation failed: {reason}")]
|
||||
ValidationFailed { reason: String },
|
||||
|
||||
#[error("Invalid value for field '{field}': {value} - {reason}")]
|
||||
InvalidValue {
|
||||
field: String,
|
||||
value: String,
|
||||
reason: String,
|
||||
},
|
||||
|
||||
#[error("Incompatible configuration: {reason}")]
|
||||
IncompatibleConfig { reason: String },
|
||||
|
||||
#[error("Missing required field: {field}")]
|
||||
MissingRequired { field: String },
|
||||
}
|
||||
|
||||
/// Result type for configuration operations
|
||||
pub type ConfigResult<T> = Result<T, ConfigError>;
|
||||
1241
sgl-router/src/config/types.rs
Normal file
1241
sgl-router/src/config/types.rs
Normal file
File diff suppressed because it is too large
Load Diff
776
sgl-router/src/config/validation.rs
Normal file
776
sgl-router/src/config/validation.rs
Normal file
@@ -0,0 +1,776 @@
|
||||
use super::*;
|
||||
|
||||
/// Configuration validator
|
||||
pub struct ConfigValidator;
|
||||
|
||||
impl ConfigValidator {
|
||||
/// Validate a complete router configuration
|
||||
pub fn validate(config: &RouterConfig) -> ConfigResult<()> {
|
||||
// Check if service discovery is enabled
|
||||
let has_service_discovery = config.discovery.as_ref().is_some_and(|d| d.enabled);
|
||||
|
||||
Self::validate_mode(&config.mode, has_service_discovery)?;
|
||||
Self::validate_policy(&config.policy)?;
|
||||
Self::validate_server_settings(config)?;
|
||||
|
||||
if let Some(discovery) = &config.discovery {
|
||||
Self::validate_discovery(discovery, &config.mode)?;
|
||||
}
|
||||
|
||||
if let Some(metrics) = &config.metrics {
|
||||
Self::validate_metrics(metrics)?;
|
||||
}
|
||||
|
||||
Self::validate_compatibility(config)?;
|
||||
|
||||
// Validate effective retry/CB configs (respect disable flags)
|
||||
let retry_cfg = config.effective_retry_config();
|
||||
let cb_cfg = config.effective_circuit_breaker_config();
|
||||
Self::validate_retry(&retry_cfg)?;
|
||||
Self::validate_circuit_breaker(&cb_cfg)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Validate routing mode configuration
|
||||
fn validate_mode(mode: &RoutingMode, has_service_discovery: bool) -> ConfigResult<()> {
|
||||
match mode {
|
||||
RoutingMode::Regular { worker_urls } => {
|
||||
// Validate URLs if any are provided
|
||||
if !worker_urls.is_empty() {
|
||||
Self::validate_urls(worker_urls)?;
|
||||
}
|
||||
// Note: We allow empty worker URLs even without service discovery
|
||||
// to let the router start and fail at runtime when routing requests.
|
||||
// This matches legacy behavior and test expectations.
|
||||
}
|
||||
RoutingMode::PrefillDecode {
|
||||
prefill_urls,
|
||||
decode_urls,
|
||||
prefill_policy,
|
||||
decode_policy,
|
||||
} => {
|
||||
// Only require URLs if service discovery is disabled
|
||||
if !has_service_discovery {
|
||||
if prefill_urls.is_empty() {
|
||||
return Err(ConfigError::ValidationFailed {
|
||||
reason: "PD mode requires at least one prefill worker URL".to_string(),
|
||||
});
|
||||
}
|
||||
if decode_urls.is_empty() {
|
||||
return Err(ConfigError::ValidationFailed {
|
||||
reason: "PD mode requires at least one decode worker URL".to_string(),
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// Validate URLs if any are provided
|
||||
if !prefill_urls.is_empty() {
|
||||
let prefill_url_strings: Vec<String> =
|
||||
prefill_urls.iter().map(|(url, _)| url.clone()).collect();
|
||||
Self::validate_urls(&prefill_url_strings)?;
|
||||
}
|
||||
if !decode_urls.is_empty() {
|
||||
Self::validate_urls(decode_urls)?;
|
||||
}
|
||||
|
||||
// Validate bootstrap ports
|
||||
for (_url, port) in prefill_urls {
|
||||
if let Some(port) = port {
|
||||
if *port == 0 {
|
||||
return Err(ConfigError::InvalidValue {
|
||||
field: "bootstrap_port".to_string(),
|
||||
value: port.to_string(),
|
||||
reason: "Port must be between 1 and 65535".to_string(),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Validate optional prefill and decode policies
|
||||
if let Some(p_policy) = prefill_policy {
|
||||
Self::validate_policy(p_policy)?;
|
||||
}
|
||||
if let Some(d_policy) = decode_policy {
|
||||
Self::validate_policy(d_policy)?;
|
||||
}
|
||||
}
|
||||
RoutingMode::OpenAI { worker_urls } => {
|
||||
// Require exactly one worker URL for OpenAI router
|
||||
if worker_urls.len() != 1 {
|
||||
return Err(ConfigError::ValidationFailed {
|
||||
reason: "OpenAI mode requires exactly one --worker-urls entry".to_string(),
|
||||
});
|
||||
}
|
||||
// Validate URL format
|
||||
if let Err(e) = url::Url::parse(&worker_urls[0]) {
|
||||
return Err(ConfigError::ValidationFailed {
|
||||
reason: format!("Invalid OpenAI worker URL '{}': {}", &worker_urls[0], e),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Validate policy configuration
|
||||
fn validate_policy(policy: &PolicyConfig) -> ConfigResult<()> {
|
||||
match policy {
|
||||
PolicyConfig::Random | PolicyConfig::RoundRobin => {
|
||||
// No specific validation needed
|
||||
}
|
||||
PolicyConfig::CacheAware {
|
||||
cache_threshold,
|
||||
balance_abs_threshold: _,
|
||||
balance_rel_threshold,
|
||||
eviction_interval_secs,
|
||||
max_tree_size,
|
||||
} => {
|
||||
if !(0.0..=1.0).contains(cache_threshold) {
|
||||
return Err(ConfigError::InvalidValue {
|
||||
field: "cache_threshold".to_string(),
|
||||
value: cache_threshold.to_string(),
|
||||
reason: "Must be between 0.0 and 1.0".to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
if *balance_rel_threshold < 1.0 {
|
||||
return Err(ConfigError::InvalidValue {
|
||||
field: "balance_rel_threshold".to_string(),
|
||||
value: balance_rel_threshold.to_string(),
|
||||
reason: "Must be >= 1.0".to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
if *eviction_interval_secs == 0 {
|
||||
return Err(ConfigError::InvalidValue {
|
||||
field: "eviction_interval_secs".to_string(),
|
||||
value: eviction_interval_secs.to_string(),
|
||||
reason: "Must be > 0".to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
if *max_tree_size == 0 {
|
||||
return Err(ConfigError::InvalidValue {
|
||||
field: "max_tree_size".to_string(),
|
||||
value: max_tree_size.to_string(),
|
||||
reason: "Must be > 0".to_string(),
|
||||
});
|
||||
}
|
||||
}
|
||||
PolicyConfig::PowerOfTwo {
|
||||
load_check_interval_secs,
|
||||
} => {
|
||||
if *load_check_interval_secs == 0 {
|
||||
return Err(ConfigError::InvalidValue {
|
||||
field: "load_check_interval_secs".to_string(),
|
||||
value: load_check_interval_secs.to_string(),
|
||||
reason: "Must be > 0".to_string(),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Validate server configuration
|
||||
fn validate_server_settings(config: &RouterConfig) -> ConfigResult<()> {
|
||||
if config.port == 0 {
|
||||
return Err(ConfigError::InvalidValue {
|
||||
field: "port".to_string(),
|
||||
value: config.port.to_string(),
|
||||
reason: "Port must be > 0".to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
if config.max_payload_size == 0 {
|
||||
return Err(ConfigError::InvalidValue {
|
||||
field: "max_payload_size".to_string(),
|
||||
value: config.max_payload_size.to_string(),
|
||||
reason: "Must be > 0".to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
if config.request_timeout_secs == 0 {
|
||||
return Err(ConfigError::InvalidValue {
|
||||
field: "request_timeout_secs".to_string(),
|
||||
value: config.request_timeout_secs.to_string(),
|
||||
reason: "Must be > 0".to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
if config.worker_startup_timeout_secs == 0 {
|
||||
return Err(ConfigError::InvalidValue {
|
||||
field: "worker_startup_timeout_secs".to_string(),
|
||||
value: config.worker_startup_timeout_secs.to_string(),
|
||||
reason: "Must be > 0".to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
if config.worker_startup_check_interval_secs == 0 {
|
||||
return Err(ConfigError::InvalidValue {
|
||||
field: "worker_startup_check_interval_secs".to_string(),
|
||||
value: config.worker_startup_check_interval_secs.to_string(),
|
||||
reason: "Must be > 0".to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Validate service discovery configuration
|
||||
fn validate_discovery(discovery: &DiscoveryConfig, mode: &RoutingMode) -> ConfigResult<()> {
|
||||
if !discovery.enabled {
|
||||
return Ok(()); // No validation needed if disabled
|
||||
}
|
||||
|
||||
if discovery.port == 0 {
|
||||
return Err(ConfigError::InvalidValue {
|
||||
field: "discovery.port".to_string(),
|
||||
value: discovery.port.to_string(),
|
||||
reason: "Port must be > 0".to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
if discovery.check_interval_secs == 0 {
|
||||
return Err(ConfigError::InvalidValue {
|
||||
field: "discovery.check_interval_secs".to_string(),
|
||||
value: discovery.check_interval_secs.to_string(),
|
||||
reason: "Must be > 0".to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
// Validate selectors based on mode
|
||||
match mode {
|
||||
RoutingMode::Regular { .. } => {
|
||||
if discovery.selector.is_empty() {
|
||||
return Err(ConfigError::ValidationFailed {
|
||||
reason: "Regular mode with service discovery requires a non-empty selector"
|
||||
.to_string(),
|
||||
});
|
||||
}
|
||||
}
|
||||
RoutingMode::PrefillDecode { .. } => {
|
||||
if discovery.prefill_selector.is_empty() && discovery.decode_selector.is_empty() {
|
||||
return Err(ConfigError::ValidationFailed {
|
||||
reason: "PD mode with service discovery requires at least one non-empty selector (prefill or decode)".to_string(),
|
||||
});
|
||||
}
|
||||
}
|
||||
RoutingMode::OpenAI { .. } => {
|
||||
// OpenAI mode doesn't use service discovery
|
||||
return Err(ConfigError::ValidationFailed {
|
||||
reason: "OpenAI mode does not support service discovery".to_string(),
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Validate metrics configuration
|
||||
fn validate_metrics(metrics: &MetricsConfig) -> ConfigResult<()> {
|
||||
if metrics.port == 0 {
|
||||
return Err(ConfigError::InvalidValue {
|
||||
field: "metrics.port".to_string(),
|
||||
value: metrics.port.to_string(),
|
||||
reason: "Port must be > 0".to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
if metrics.host.is_empty() {
|
||||
return Err(ConfigError::InvalidValue {
|
||||
field: "metrics.host".to_string(),
|
||||
value: metrics.host.clone(),
|
||||
reason: "Host cannot be empty".to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Validate retry configuration
|
||||
fn validate_retry(retry: &RetryConfig) -> ConfigResult<()> {
|
||||
if retry.max_retries < 1 {
|
||||
return Err(ConfigError::InvalidValue {
|
||||
field: "retry.max_retries".to_string(),
|
||||
value: retry.max_retries.to_string(),
|
||||
reason: "Must be >= 1 (set to 1 to effectively disable retries)".to_string(),
|
||||
});
|
||||
}
|
||||
if retry.initial_backoff_ms == 0 {
|
||||
return Err(ConfigError::InvalidValue {
|
||||
field: "retry.initial_backoff_ms".to_string(),
|
||||
value: retry.initial_backoff_ms.to_string(),
|
||||
reason: "Must be > 0".to_string(),
|
||||
});
|
||||
}
|
||||
if retry.max_backoff_ms < retry.initial_backoff_ms {
|
||||
return Err(ConfigError::InvalidValue {
|
||||
field: "retry.max_backoff_ms".to_string(),
|
||||
value: retry.max_backoff_ms.to_string(),
|
||||
reason: "Must be >= initial_backoff_ms".to_string(),
|
||||
});
|
||||
}
|
||||
if retry.backoff_multiplier < 1.0 {
|
||||
return Err(ConfigError::InvalidValue {
|
||||
field: "retry.backoff_multiplier".to_string(),
|
||||
value: retry.backoff_multiplier.to_string(),
|
||||
reason: "Must be >= 1.0".to_string(),
|
||||
});
|
||||
}
|
||||
if !(0.0..=1.0).contains(&retry.jitter_factor) {
|
||||
return Err(ConfigError::InvalidValue {
|
||||
field: "retry.jitter_factor".to_string(),
|
||||
value: retry.jitter_factor.to_string(),
|
||||
reason: "Must be between 0.0 and 1.0".to_string(),
|
||||
});
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Validate circuit breaker configuration
|
||||
fn validate_circuit_breaker(cb: &CircuitBreakerConfig) -> ConfigResult<()> {
|
||||
if cb.failure_threshold < 1 {
|
||||
return Err(ConfigError::InvalidValue {
|
||||
field: "circuit_breaker.failure_threshold".to_string(),
|
||||
value: cb.failure_threshold.to_string(),
|
||||
reason: "Must be >= 1 (set to u32::MAX to effectively disable CB)".to_string(),
|
||||
});
|
||||
}
|
||||
if cb.success_threshold < 1 {
|
||||
return Err(ConfigError::InvalidValue {
|
||||
field: "circuit_breaker.success_threshold".to_string(),
|
||||
value: cb.success_threshold.to_string(),
|
||||
reason: "Must be >= 1".to_string(),
|
||||
});
|
||||
}
|
||||
if cb.timeout_duration_secs == 0 {
|
||||
return Err(ConfigError::InvalidValue {
|
||||
field: "circuit_breaker.timeout_duration_secs".to_string(),
|
||||
value: cb.timeout_duration_secs.to_string(),
|
||||
reason: "Must be > 0".to_string(),
|
||||
});
|
||||
}
|
||||
if cb.window_duration_secs == 0 {
|
||||
return Err(ConfigError::InvalidValue {
|
||||
field: "circuit_breaker.window_duration_secs".to_string(),
|
||||
value: cb.window_duration_secs.to_string(),
|
||||
reason: "Must be > 0".to_string(),
|
||||
});
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Validate compatibility between different configuration sections
|
||||
fn validate_compatibility(config: &RouterConfig) -> ConfigResult<()> {
|
||||
// IGW mode is independent - skip other compatibility checks when enabled
|
||||
if config.enable_igw {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Validate gRPC connection mode requires tokenizer configuration
|
||||
if config.connection_mode == ConnectionMode::Grpc
|
||||
&& config.tokenizer_path.is_none()
|
||||
&& config.model_path.is_none()
|
||||
{
|
||||
return Err(ConfigError::ValidationFailed {
|
||||
reason: "gRPC connection mode requires either --tokenizer-path or --model-path to be specified".to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
// All policies are now supported for both router types thanks to the unified trait design
|
||||
// No mode/policy restrictions needed anymore
|
||||
|
||||
// Check if service discovery is enabled for worker count validation
|
||||
let has_service_discovery = config.discovery.as_ref().is_some_and(|d| d.enabled);
|
||||
|
||||
// Only validate worker counts if service discovery is disabled
|
||||
if !has_service_discovery {
|
||||
// Check if power-of-two policy makes sense with insufficient workers
|
||||
if let PolicyConfig::PowerOfTwo { .. } = &config.policy {
|
||||
let worker_count = config.mode.worker_count();
|
||||
if worker_count < 2 {
|
||||
return Err(ConfigError::IncompatibleConfig {
|
||||
reason: "Power-of-two policy requires at least 2 workers".to_string(),
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// For PD mode, validate that policies have sufficient workers
|
||||
if let RoutingMode::PrefillDecode {
|
||||
prefill_urls,
|
||||
decode_urls,
|
||||
prefill_policy,
|
||||
decode_policy,
|
||||
} = &config.mode
|
||||
{
|
||||
// Check power-of-two for prefill
|
||||
if let Some(PolicyConfig::PowerOfTwo { .. }) = prefill_policy {
|
||||
if prefill_urls.len() < 2 {
|
||||
return Err(ConfigError::IncompatibleConfig {
|
||||
reason: "Power-of-two policy for prefill requires at least 2 prefill workers".to_string(),
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// Check power-of-two for decode
|
||||
if let Some(PolicyConfig::PowerOfTwo { .. }) = decode_policy {
|
||||
if decode_urls.len() < 2 {
|
||||
return Err(ConfigError::IncompatibleConfig {
|
||||
reason:
|
||||
"Power-of-two policy for decode requires at least 2 decode workers"
|
||||
.to_string(),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Service discovery is conflict with dp_aware routing for now
|
||||
// since it's not fully supported yet
|
||||
if has_service_discovery && config.dp_aware {
|
||||
return Err(ConfigError::IncompatibleConfig {
|
||||
reason: "DP-aware routing is not compatible with service discovery".to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Validate URL format
|
||||
fn validate_urls(urls: &[String]) -> ConfigResult<()> {
|
||||
for url in urls {
|
||||
if url.is_empty() {
|
||||
return Err(ConfigError::InvalidValue {
|
||||
field: "worker_url".to_string(),
|
||||
value: url.clone(),
|
||||
reason: "URL cannot be empty".to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
if !url.starts_with("http://")
|
||||
&& !url.starts_with("https://")
|
||||
&& !url.starts_with("grpc://")
|
||||
{
|
||||
return Err(ConfigError::InvalidValue {
|
||||
field: "worker_url".to_string(),
|
||||
value: url.clone(),
|
||||
reason: "URL must start with http://, https://, or grpc://".to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
// Basic URL validation
|
||||
match ::url::Url::parse(url) {
|
||||
Ok(parsed) => {
|
||||
// Additional validation
|
||||
if parsed.host_str().is_none() {
|
||||
return Err(ConfigError::InvalidValue {
|
||||
field: "worker_url".to_string(),
|
||||
value: url.clone(),
|
||||
reason: "URL must have a valid host".to_string(),
|
||||
});
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
return Err(ConfigError::InvalidValue {
|
||||
field: "worker_url".to_string(),
|
||||
value: url.clone(),
|
||||
reason: format!("Invalid URL format: {}", e),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_validate_regular_mode() {
|
||||
let config = RouterConfig::new(
|
||||
RoutingMode::Regular {
|
||||
worker_urls: vec!["http://worker:8000".to_string()],
|
||||
},
|
||||
PolicyConfig::Random,
|
||||
);
|
||||
|
||||
assert!(ConfigValidator::validate(&config).is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_empty_worker_urls() {
|
||||
let config = RouterConfig::new(
|
||||
RoutingMode::Regular {
|
||||
worker_urls: vec![],
|
||||
},
|
||||
PolicyConfig::Random,
|
||||
);
|
||||
|
||||
// Empty worker URLs are now allowed to match legacy behavior
|
||||
assert!(ConfigValidator::validate(&config).is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_empty_worker_urls_with_service_discovery() {
|
||||
let mut config = RouterConfig::new(
|
||||
RoutingMode::Regular {
|
||||
worker_urls: vec![],
|
||||
},
|
||||
PolicyConfig::Random,
|
||||
);
|
||||
|
||||
// Enable service discovery
|
||||
config.discovery = Some(DiscoveryConfig {
|
||||
enabled: true,
|
||||
selector: vec![("app".to_string(), "test".to_string())]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
..Default::default()
|
||||
});
|
||||
|
||||
// Should pass validation since service discovery is enabled
|
||||
assert!(ConfigValidator::validate(&config).is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_invalid_urls() {
|
||||
let config = RouterConfig::new(
|
||||
RoutingMode::Regular {
|
||||
worker_urls: vec!["invalid-url".to_string()],
|
||||
},
|
||||
PolicyConfig::Random,
|
||||
);
|
||||
|
||||
assert!(ConfigValidator::validate(&config).is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_cache_aware_thresholds() {
|
||||
let config = RouterConfig::new(
|
||||
RoutingMode::Regular {
|
||||
worker_urls: vec![
|
||||
"http://worker1:8000".to_string(),
|
||||
"http://worker2:8000".to_string(),
|
||||
],
|
||||
},
|
||||
PolicyConfig::CacheAware {
|
||||
cache_threshold: 1.5, // Invalid: > 1.0
|
||||
balance_abs_threshold: 32,
|
||||
balance_rel_threshold: 1.1,
|
||||
eviction_interval_secs: 60,
|
||||
max_tree_size: 1000,
|
||||
},
|
||||
);
|
||||
|
||||
assert!(ConfigValidator::validate(&config).is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_cache_aware_single_worker() {
|
||||
// Cache-aware with single worker should be allowed (even if not optimal)
|
||||
let config = RouterConfig::new(
|
||||
RoutingMode::Regular {
|
||||
worker_urls: vec!["http://worker1:8000".to_string()],
|
||||
},
|
||||
PolicyConfig::CacheAware {
|
||||
cache_threshold: 0.5,
|
||||
balance_abs_threshold: 32,
|
||||
balance_rel_threshold: 1.1,
|
||||
eviction_interval_secs: 60,
|
||||
max_tree_size: 1000,
|
||||
},
|
||||
);
|
||||
|
||||
assert!(ConfigValidator::validate(&config).is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_pd_mode() {
|
||||
let config = RouterConfig::new(
|
||||
RoutingMode::PrefillDecode {
|
||||
prefill_urls: vec![("http://prefill:8000".to_string(), Some(8081))],
|
||||
decode_urls: vec!["http://decode:8000".to_string()],
|
||||
prefill_policy: None,
|
||||
decode_policy: None,
|
||||
},
|
||||
PolicyConfig::Random,
|
||||
);
|
||||
|
||||
assert!(ConfigValidator::validate(&config).is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_roundrobin_with_pd_mode() {
|
||||
// RoundRobin with PD mode is now supported
|
||||
let config = RouterConfig::new(
|
||||
RoutingMode::PrefillDecode {
|
||||
prefill_urls: vec![("http://prefill:8000".to_string(), None)],
|
||||
decode_urls: vec!["http://decode:8000".to_string()],
|
||||
prefill_policy: None,
|
||||
decode_policy: None,
|
||||
},
|
||||
PolicyConfig::RoundRobin,
|
||||
);
|
||||
|
||||
let result = ConfigValidator::validate(&config);
|
||||
assert!(result.is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_cache_aware_with_pd_mode() {
|
||||
// CacheAware with PD mode is now supported
|
||||
let config = RouterConfig::new(
|
||||
RoutingMode::PrefillDecode {
|
||||
prefill_urls: vec![("http://prefill:8000".to_string(), None)],
|
||||
decode_urls: vec!["http://decode:8000".to_string()],
|
||||
prefill_policy: None,
|
||||
decode_policy: None,
|
||||
},
|
||||
PolicyConfig::CacheAware {
|
||||
cache_threshold: 0.5,
|
||||
balance_abs_threshold: 32,
|
||||
balance_rel_threshold: 1.1,
|
||||
eviction_interval_secs: 60,
|
||||
max_tree_size: 1000,
|
||||
},
|
||||
);
|
||||
|
||||
let result = ConfigValidator::validate(&config);
|
||||
assert!(result.is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_power_of_two_with_regular_mode() {
|
||||
// PowerOfTwo with Regular mode is now supported
|
||||
let config = RouterConfig::new(
|
||||
RoutingMode::Regular {
|
||||
worker_urls: vec![
|
||||
"http://worker1:8000".to_string(),
|
||||
"http://worker2:8000".to_string(),
|
||||
],
|
||||
},
|
||||
PolicyConfig::PowerOfTwo {
|
||||
load_check_interval_secs: 60,
|
||||
},
|
||||
);
|
||||
|
||||
let result = ConfigValidator::validate(&config);
|
||||
assert!(result.is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_pd_mode_with_separate_policies() {
|
||||
// Test PD mode with different policies for prefill and decode
|
||||
let config = RouterConfig::new(
|
||||
RoutingMode::PrefillDecode {
|
||||
prefill_urls: vec![
|
||||
("http://prefill1:8000".to_string(), None),
|
||||
("http://prefill2:8000".to_string(), None),
|
||||
],
|
||||
decode_urls: vec![
|
||||
"http://decode1:8000".to_string(),
|
||||
"http://decode2:8000".to_string(),
|
||||
],
|
||||
prefill_policy: Some(PolicyConfig::CacheAware {
|
||||
cache_threshold: 0.5,
|
||||
balance_abs_threshold: 32,
|
||||
balance_rel_threshold: 1.1,
|
||||
eviction_interval_secs: 60,
|
||||
max_tree_size: 1000,
|
||||
}),
|
||||
decode_policy: Some(PolicyConfig::PowerOfTwo {
|
||||
load_check_interval_secs: 60,
|
||||
}),
|
||||
},
|
||||
PolicyConfig::Random, // Main policy as fallback
|
||||
);
|
||||
|
||||
let result = ConfigValidator::validate(&config);
|
||||
assert!(result.is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_pd_mode_power_of_two_insufficient_workers() {
|
||||
// Test that power-of-two policy requires at least 2 workers
|
||||
let config = RouterConfig::new(
|
||||
RoutingMode::PrefillDecode {
|
||||
prefill_urls: vec![("http://prefill1:8000".to_string(), None)], // Only 1 prefill
|
||||
decode_urls: vec![
|
||||
"http://decode1:8000".to_string(),
|
||||
"http://decode2:8000".to_string(),
|
||||
],
|
||||
prefill_policy: Some(PolicyConfig::PowerOfTwo {
|
||||
load_check_interval_secs: 60,
|
||||
}), // Requires 2+ workers
|
||||
decode_policy: None,
|
||||
},
|
||||
PolicyConfig::Random,
|
||||
);
|
||||
|
||||
let result = ConfigValidator::validate(&config);
|
||||
assert!(result.is_err());
|
||||
if let Err(e) = result {
|
||||
assert!(e.to_string().contains("prefill requires at least 2"));
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_grpc_requires_tokenizer() {
|
||||
// Test that gRPC connection mode requires tokenizer configuration
|
||||
let mut config = RouterConfig::new(
|
||||
RoutingMode::Regular {
|
||||
worker_urls: vec!["grpc://worker:50051".to_string()],
|
||||
},
|
||||
PolicyConfig::Random,
|
||||
);
|
||||
|
||||
// Set connection mode to gRPC without tokenizer config
|
||||
config.connection_mode = ConnectionMode::Grpc;
|
||||
config.tokenizer_path = None;
|
||||
config.model_path = None;
|
||||
|
||||
let result = ConfigValidator::validate(&config);
|
||||
assert!(result.is_err());
|
||||
if let Err(e) = result {
|
||||
assert!(e.to_string().contains("gRPC connection mode requires"));
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_grpc_with_model_path() {
|
||||
// Test that gRPC works with model_path
|
||||
let mut config = RouterConfig::new(
|
||||
RoutingMode::Regular {
|
||||
worker_urls: vec!["grpc://worker:50051".to_string()],
|
||||
},
|
||||
PolicyConfig::Random,
|
||||
);
|
||||
|
||||
config.connection_mode = ConnectionMode::Grpc;
|
||||
config.model_path = Some("meta-llama/Llama-3-8B".to_string());
|
||||
|
||||
let result = ConfigValidator::validate(&config);
|
||||
assert!(result.is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_grpc_with_tokenizer_path() {
|
||||
// Test that gRPC works with tokenizer_path
|
||||
let mut config = RouterConfig::new(
|
||||
RoutingMode::Regular {
|
||||
worker_urls: vec!["grpc://worker:50051".to_string()],
|
||||
},
|
||||
PolicyConfig::Random,
|
||||
);
|
||||
|
||||
config.connection_mode = ConnectionMode::Grpc;
|
||||
config.tokenizer_path = Some("/path/to/tokenizer.json".to_string());
|
||||
|
||||
let result = ConfigValidator::validate(&config);
|
||||
assert!(result.is_ok());
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user