[router] add different policies for p node and d node (#8395)

This commit is contained in:
Simo Lin
2025-07-27 00:39:20 -07:00
committed by GitHub
parent 0bcc195f4e
commit 2ab97023e3
10 changed files with 536 additions and 81 deletions

View File

@@ -46,6 +46,12 @@ pub enum RoutingMode {
prefill_urls: Vec<(String, Option<u16>)>,
/// Decode worker URLs
decode_urls: Vec<String>,
/// Optional separate policy for prefill workers
#[serde(skip_serializing_if = "Option::is_none")]
prefill_policy: Option<PolicyConfig>,
/// Optional separate policy for decode workers
#[serde(skip_serializing_if = "Option::is_none")]
decode_policy: Option<PolicyConfig>,
},
}
@@ -60,9 +66,32 @@ impl RoutingMode {
RoutingMode::PrefillDecode {
prefill_urls,
decode_urls,
..
} => prefill_urls.len() + decode_urls.len(),
}
}
/// Get the effective prefill policy for PD mode
/// Falls back to the main policy if no specific prefill policy is set
pub fn get_prefill_policy<'a>(&'a self, main_policy: &'a PolicyConfig) -> &'a PolicyConfig {
match self {
RoutingMode::PrefillDecode { prefill_policy, .. } => {
prefill_policy.as_ref().unwrap_or(main_policy)
}
_ => main_policy,
}
}
/// Get the effective decode policy for PD mode
/// Falls back to the main policy if no specific decode policy is set
pub fn get_decode_policy<'a>(&'a self, main_policy: &'a PolicyConfig) -> &'a PolicyConfig {
match self {
RoutingMode::PrefillDecode { decode_policy, .. } => {
decode_policy.as_ref().unwrap_or(main_policy)
}
_ => main_policy,
}
}
}
/// Policy configuration for routing
@@ -307,6 +336,8 @@ mod tests {
let pd = RoutingMode::PrefillDecode {
prefill_urls: vec![("http://prefill1".to_string(), Some(8001))],
decode_urls: vec!["http://decode1".to_string()],
prefill_policy: None,
decode_policy: None,
};
assert!(pd.is_pd_mode());
}
@@ -332,6 +363,8 @@ mod tests {
"http://decode2".to_string(),
"http://decode3".to_string(),
],
prefill_policy: None,
decode_policy: None,
};
assert_eq!(pd.worker_count(), 5);
@@ -355,6 +388,8 @@ mod tests {
let pd = RoutingMode::PrefillDecode {
prefill_urls: vec![("http://prefill1".to_string(), Some(8001))],
decode_urls: vec!["http://decode1".to_string()],
prefill_policy: None,
decode_policy: None,
};
let json = serde_json::to_string(&pd).unwrap();
assert!(json.contains("\"type\":\"prefill_decode\""));
@@ -551,6 +586,8 @@ mod tests {
mode: RoutingMode::PrefillDecode {
prefill_urls: vec![],
decode_urls: vec![],
prefill_policy: None,
decode_policy: None,
},
..Default::default()
};
@@ -674,6 +711,8 @@ mod tests {
"http://decode1:8000".to_string(),
"http://decode2:8000".to_string(),
],
prefill_policy: None,
decode_policy: None,
},
policy: PolicyConfig::PowerOfTwo {
load_check_interval_secs: 30,
@@ -800,4 +839,155 @@ mod tests {
Some("production".to_string())
);
}
// ============= Policy Fallback Tests =============
#[test]
fn test_pd_policy_fallback_both_specified() {
// When both prefill and decode policies are specified, they should be used
let pd = RoutingMode::PrefillDecode {
prefill_urls: vec![("http://prefill1".to_string(), None)],
decode_urls: vec!["http://decode1".to_string()],
prefill_policy: Some(PolicyConfig::CacheAware {
cache_threshold: 0.5,
balance_abs_threshold: 32,
balance_rel_threshold: 1.1,
eviction_interval_secs: 60,
max_tree_size: 1000,
}),
decode_policy: Some(PolicyConfig::PowerOfTwo {
load_check_interval_secs: 60,
}),
};
let main_policy = PolicyConfig::Random;
// Both specific policies should be used
match pd.get_prefill_policy(&main_policy) {
PolicyConfig::CacheAware { .. } => {} // Success
_ => panic!("Expected CacheAware for prefill"),
}
match pd.get_decode_policy(&main_policy) {
PolicyConfig::PowerOfTwo { .. } => {} // Success
_ => panic!("Expected PowerOfTwo for decode"),
}
}
#[test]
fn test_pd_policy_fallback_only_prefill() {
// When only prefill policy is specified, decode should use main policy
let pd = RoutingMode::PrefillDecode {
prefill_urls: vec![("http://prefill1".to_string(), None)],
decode_urls: vec!["http://decode1".to_string()],
prefill_policy: Some(PolicyConfig::CacheAware {
cache_threshold: 0.5,
balance_abs_threshold: 32,
balance_rel_threshold: 1.1,
eviction_interval_secs: 60,
max_tree_size: 1000,
}),
decode_policy: None,
};
let main_policy = PolicyConfig::RoundRobin;
// Prefill should use specific policy
match pd.get_prefill_policy(&main_policy) {
PolicyConfig::CacheAware { .. } => {} // Success
_ => panic!("Expected CacheAware for prefill"),
}
// Decode should fall back to main policy
match pd.get_decode_policy(&main_policy) {
PolicyConfig::RoundRobin => {} // Success
_ => panic!("Expected RoundRobin for decode"),
}
}
#[test]
fn test_pd_policy_fallback_only_decode() {
// When only decode policy is specified, prefill should use main policy
let pd = RoutingMode::PrefillDecode {
prefill_urls: vec![("http://prefill1".to_string(), None)],
decode_urls: vec!["http://decode1".to_string()],
prefill_policy: None,
decode_policy: Some(PolicyConfig::PowerOfTwo {
load_check_interval_secs: 60,
}),
};
let main_policy = PolicyConfig::Random;
// Prefill should fall back to main policy
match pd.get_prefill_policy(&main_policy) {
PolicyConfig::Random => {} // Success
_ => panic!("Expected Random for prefill"),
}
// Decode should use specific policy
match pd.get_decode_policy(&main_policy) {
PolicyConfig::PowerOfTwo { .. } => {} // Success
_ => panic!("Expected PowerOfTwo for decode"),
}
}
#[test]
fn test_pd_policy_fallback_none_specified() {
// When no specific policies are specified, both should use main policy
let pd = RoutingMode::PrefillDecode {
prefill_urls: vec![("http://prefill1".to_string(), None)],
decode_urls: vec!["http://decode1".to_string()],
prefill_policy: None,
decode_policy: None,
};
let main_policy = PolicyConfig::CacheAware {
cache_threshold: 0.7,
balance_abs_threshold: 20,
balance_rel_threshold: 1.5,
eviction_interval_secs: 300,
max_tree_size: 2000,
};
// Both should fall back to main policy
match pd.get_prefill_policy(&main_policy) {
PolicyConfig::CacheAware {
cache_threshold, ..
} => {
assert!((cache_threshold - 0.7).abs() < 0.0001);
}
_ => panic!("Expected CacheAware for prefill"),
}
match pd.get_decode_policy(&main_policy) {
PolicyConfig::CacheAware {
cache_threshold, ..
} => {
assert!((cache_threshold - 0.7).abs() < 0.0001);
}
_ => panic!("Expected CacheAware for decode"),
}
}
#[test]
fn test_regular_mode_policy_fallback() {
// For regular mode, the helper methods should just return the main policy
let regular = RoutingMode::Regular {
worker_urls: vec!["http://worker1".to_string()],
};
let main_policy = PolicyConfig::RoundRobin;
// Both methods should return main policy for regular mode
match regular.get_prefill_policy(&main_policy) {
PolicyConfig::RoundRobin => {} // Success
_ => panic!("Expected RoundRobin for regular mode"),
}
match regular.get_decode_policy(&main_policy) {
PolicyConfig::RoundRobin => {} // Success
_ => panic!("Expected RoundRobin for regular mode"),
}
}
}

View File

@@ -41,6 +41,8 @@ impl ConfigValidator {
RoutingMode::PrefillDecode {
prefill_urls,
decode_urls,
prefill_policy,
decode_policy,
} => {
// Only require URLs if service discovery is disabled
if !has_service_discovery {
@@ -78,6 +80,14 @@ impl ConfigValidator {
}
}
}
// Validate optional prefill and decode policies
if let Some(p_policy) = prefill_policy {
Self::validate_policy(p_policy)?;
}
if let Some(d_policy) = decode_policy {
Self::validate_policy(d_policy)?;
}
}
}
Ok(())
@@ -272,6 +282,35 @@ impl ConfigValidator {
});
}
}
// For PD mode, validate that policies have sufficient workers
if let RoutingMode::PrefillDecode {
prefill_urls,
decode_urls,
prefill_policy,
decode_policy,
} = &config.mode
{
// Check power-of-two for prefill
if let Some(PolicyConfig::PowerOfTwo { .. }) = prefill_policy {
if prefill_urls.len() < 2 {
return Err(ConfigError::IncompatibleConfig {
reason: "Power-of-two policy for prefill requires at least 2 prefill workers".to_string(),
});
}
}
// Check power-of-two for decode
if let Some(PolicyConfig::PowerOfTwo { .. }) = decode_policy {
if decode_urls.len() < 2 {
return Err(ConfigError::IncompatibleConfig {
reason:
"Power-of-two policy for decode requires at least 2 decode workers"
.to_string(),
});
}
}
}
}
Ok(())
@@ -430,6 +469,8 @@ mod tests {
RoutingMode::PrefillDecode {
prefill_urls: vec![("http://prefill:8000".to_string(), Some(8081))],
decode_urls: vec!["http://decode:8000".to_string()],
prefill_policy: None,
decode_policy: None,
},
PolicyConfig::Random,
);
@@ -444,6 +485,8 @@ mod tests {
RoutingMode::PrefillDecode {
prefill_urls: vec![("http://prefill:8000".to_string(), None)],
decode_urls: vec!["http://decode:8000".to_string()],
prefill_policy: None,
decode_policy: None,
},
PolicyConfig::RoundRobin,
);
@@ -459,6 +502,8 @@ mod tests {
RoutingMode::PrefillDecode {
prefill_urls: vec![("http://prefill:8000".to_string(), None)],
decode_urls: vec!["http://decode:8000".to_string()],
prefill_policy: None,
decode_policy: None,
},
PolicyConfig::CacheAware {
cache_threshold: 0.5,
@@ -491,4 +536,60 @@ mod tests {
let result = ConfigValidator::validate(&config);
assert!(result.is_ok());
}
#[test]
fn test_validate_pd_mode_with_separate_policies() {
// Test PD mode with different policies for prefill and decode
let config = RouterConfig::new(
RoutingMode::PrefillDecode {
prefill_urls: vec![
("http://prefill1:8000".to_string(), None),
("http://prefill2:8000".to_string(), None),
],
decode_urls: vec![
"http://decode1:8000".to_string(),
"http://decode2:8000".to_string(),
],
prefill_policy: Some(PolicyConfig::CacheAware {
cache_threshold: 0.5,
balance_abs_threshold: 32,
balance_rel_threshold: 1.1,
eviction_interval_secs: 60,
max_tree_size: 1000,
}),
decode_policy: Some(PolicyConfig::PowerOfTwo {
load_check_interval_secs: 60,
}),
},
PolicyConfig::Random, // Main policy as fallback
);
let result = ConfigValidator::validate(&config);
assert!(result.is_ok());
}
#[test]
fn test_validate_pd_mode_power_of_two_insufficient_workers() {
// Test that power-of-two policy requires at least 2 workers
let config = RouterConfig::new(
RoutingMode::PrefillDecode {
prefill_urls: vec![("http://prefill1:8000".to_string(), None)], // Only 1 prefill
decode_urls: vec![
"http://decode1:8000".to_string(),
"http://decode2:8000".to_string(),
],
prefill_policy: Some(PolicyConfig::PowerOfTwo {
load_check_interval_secs: 60,
}), // Requires 2+ workers
decode_policy: None,
},
PolicyConfig::Random,
);
let result = ConfigValidator::validate(&config);
assert!(result.is_err());
if let Err(e) = result {
assert!(e.to_string().contains("prefill requires at least 2"));
}
}
}