[router] Add separate routing policies for prefill (P) and decode (D) nodes (#8395)
This commit is contained in:
@@ -46,6 +46,12 @@ pub enum RoutingMode {
|
||||
prefill_urls: Vec<(String, Option<u16>)>,
|
||||
/// Decode worker URLs
|
||||
decode_urls: Vec<String>,
|
||||
/// Optional separate policy for prefill workers
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
prefill_policy: Option<PolicyConfig>,
|
||||
/// Optional separate policy for decode workers
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
decode_policy: Option<PolicyConfig>,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -60,9 +66,32 @@ impl RoutingMode {
|
||||
RoutingMode::PrefillDecode {
|
||||
prefill_urls,
|
||||
decode_urls,
|
||||
..
|
||||
} => prefill_urls.len() + decode_urls.len(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the effective prefill policy for PD mode
|
||||
/// Falls back to the main policy if no specific prefill policy is set
|
||||
pub fn get_prefill_policy<'a>(&'a self, main_policy: &'a PolicyConfig) -> &'a PolicyConfig {
|
||||
match self {
|
||||
RoutingMode::PrefillDecode { prefill_policy, .. } => {
|
||||
prefill_policy.as_ref().unwrap_or(main_policy)
|
||||
}
|
||||
_ => main_policy,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the effective decode policy for PD mode
|
||||
/// Falls back to the main policy if no specific decode policy is set
|
||||
pub fn get_decode_policy<'a>(&'a self, main_policy: &'a PolicyConfig) -> &'a PolicyConfig {
|
||||
match self {
|
||||
RoutingMode::PrefillDecode { decode_policy, .. } => {
|
||||
decode_policy.as_ref().unwrap_or(main_policy)
|
||||
}
|
||||
_ => main_policy,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Policy configuration for routing
|
||||
@@ -307,6 +336,8 @@ mod tests {
|
||||
let pd = RoutingMode::PrefillDecode {
|
||||
prefill_urls: vec![("http://prefill1".to_string(), Some(8001))],
|
||||
decode_urls: vec!["http://decode1".to_string()],
|
||||
prefill_policy: None,
|
||||
decode_policy: None,
|
||||
};
|
||||
assert!(pd.is_pd_mode());
|
||||
}
|
||||
@@ -332,6 +363,8 @@ mod tests {
|
||||
"http://decode2".to_string(),
|
||||
"http://decode3".to_string(),
|
||||
],
|
||||
prefill_policy: None,
|
||||
decode_policy: None,
|
||||
};
|
||||
assert_eq!(pd.worker_count(), 5);
|
||||
|
||||
@@ -355,6 +388,8 @@ mod tests {
|
||||
let pd = RoutingMode::PrefillDecode {
|
||||
prefill_urls: vec![("http://prefill1".to_string(), Some(8001))],
|
||||
decode_urls: vec!["http://decode1".to_string()],
|
||||
prefill_policy: None,
|
||||
decode_policy: None,
|
||||
};
|
||||
let json = serde_json::to_string(&pd).unwrap();
|
||||
assert!(json.contains("\"type\":\"prefill_decode\""));
|
||||
@@ -551,6 +586,8 @@ mod tests {
|
||||
mode: RoutingMode::PrefillDecode {
|
||||
prefill_urls: vec![],
|
||||
decode_urls: vec![],
|
||||
prefill_policy: None,
|
||||
decode_policy: None,
|
||||
},
|
||||
..Default::default()
|
||||
};
|
||||
@@ -674,6 +711,8 @@ mod tests {
|
||||
"http://decode1:8000".to_string(),
|
||||
"http://decode2:8000".to_string(),
|
||||
],
|
||||
prefill_policy: None,
|
||||
decode_policy: None,
|
||||
},
|
||||
policy: PolicyConfig::PowerOfTwo {
|
||||
load_check_interval_secs: 30,
|
||||
@@ -800,4 +839,155 @@ mod tests {
|
||||
Some("production".to_string())
|
||||
);
|
||||
}
|
||||
|
||||
// ============= Policy Fallback Tests =============
|
||||
|
||||
#[test]
fn test_pd_policy_fallback_both_specified() {
    // PD mode with an explicit override on each side: neither getter may
    // fall back to the main policy.
    let prefill_override = PolicyConfig::CacheAware {
        cache_threshold: 0.5,
        balance_abs_threshold: 32,
        balance_rel_threshold: 1.1,
        eviction_interval_secs: 60,
        max_tree_size: 1000,
    };
    let decode_override = PolicyConfig::PowerOfTwo {
        load_check_interval_secs: 60,
    };
    let mode = RoutingMode::PrefillDecode {
        prefill_urls: vec![(String::from("http://prefill1"), None)],
        decode_urls: vec![String::from("http://decode1")],
        prefill_policy: Some(prefill_override),
        decode_policy: Some(decode_override),
    };

    // Deliberately a third, different variant so a fallback would be visible.
    let fallback = PolicyConfig::Random;

    assert!(
        matches!(
            mode.get_prefill_policy(&fallback),
            PolicyConfig::CacheAware { .. }
        ),
        "Expected CacheAware for prefill"
    );
    assert!(
        matches!(
            mode.get_decode_policy(&fallback),
            PolicyConfig::PowerOfTwo { .. }
        ),
        "Expected PowerOfTwo for decode"
    );
}
|
||||
|
||||
#[test]
fn test_pd_policy_fallback_only_prefill() {
    // Only the prefill side carries an override; decode must inherit the
    // main policy.
    let prefill_override = PolicyConfig::CacheAware {
        cache_threshold: 0.5,
        balance_abs_threshold: 32,
        balance_rel_threshold: 1.1,
        eviction_interval_secs: 60,
        max_tree_size: 1000,
    };
    let mode = RoutingMode::PrefillDecode {
        prefill_urls: vec![(String::from("http://prefill1"), None)],
        decode_urls: vec![String::from("http://decode1")],
        prefill_policy: Some(prefill_override),
        decode_policy: None,
    };

    let fallback = PolicyConfig::RoundRobin;

    // Prefill keeps its own policy.
    assert!(
        matches!(
            mode.get_prefill_policy(&fallback),
            PolicyConfig::CacheAware { .. }
        ),
        "Expected CacheAware for prefill"
    );

    // Decode resolves to the main policy.
    assert!(
        matches!(mode.get_decode_policy(&fallback), PolicyConfig::RoundRobin),
        "Expected RoundRobin for decode"
    );
}
|
||||
|
||||
#[test]
fn test_pd_policy_fallback_only_decode() {
    // Only the decode side carries an override; prefill must inherit the
    // main policy.
    let decode_override = PolicyConfig::PowerOfTwo {
        load_check_interval_secs: 60,
    };
    let mode = RoutingMode::PrefillDecode {
        prefill_urls: vec![(String::from("http://prefill1"), None)],
        decode_urls: vec![String::from("http://decode1")],
        prefill_policy: None,
        decode_policy: Some(decode_override),
    };

    let fallback = PolicyConfig::Random;

    // Prefill resolves to the main policy.
    assert!(
        matches!(mode.get_prefill_policy(&fallback), PolicyConfig::Random),
        "Expected Random for prefill"
    );

    // Decode keeps its own policy.
    assert!(
        matches!(
            mode.get_decode_policy(&fallback),
            PolicyConfig::PowerOfTwo { .. }
        ),
        "Expected PowerOfTwo for decode"
    );
}
|
||||
|
||||
#[test]
fn test_pd_policy_fallback_none_specified() {
    // No override on either side: both getters must hand back the main
    // policy, checked here through its distinctive cache threshold.
    let mode = RoutingMode::PrefillDecode {
        prefill_urls: vec![(String::from("http://prefill1"), None)],
        decode_urls: vec![String::from("http://decode1")],
        prefill_policy: None,
        decode_policy: None,
    };

    let fallback = PolicyConfig::CacheAware {
        cache_threshold: 0.7,
        balance_abs_threshold: 20,
        balance_rel_threshold: 1.5,
        eviction_interval_secs: 300,
        max_tree_size: 2000,
    };

    if let PolicyConfig::CacheAware {
        cache_threshold, ..
    } = mode.get_prefill_policy(&fallback)
    {
        assert!((cache_threshold - 0.7).abs() < 0.0001);
    } else {
        panic!("Expected CacheAware for prefill");
    }

    if let PolicyConfig::CacheAware {
        cache_threshold, ..
    } = mode.get_decode_policy(&fallback)
    {
        assert!((cache_threshold - 0.7).abs() < 0.0001);
    } else {
        panic!("Expected CacheAware for decode");
    }
}
|
||||
|
||||
#[test]
fn test_regular_mode_policy_fallback() {
    // Outside PD mode there is nothing to override, so both getters are
    // expected to return the main policy as-is.
    let mode = RoutingMode::Regular {
        worker_urls: vec![String::from("http://worker1")],
    };

    let fallback = PolicyConfig::RoundRobin;

    assert!(
        matches!(mode.get_prefill_policy(&fallback), PolicyConfig::RoundRobin),
        "Expected RoundRobin for regular mode"
    );
    assert!(
        matches!(mode.get_decode_policy(&fallback), PolicyConfig::RoundRobin),
        "Expected RoundRobin for regular mode"
    );
}
|
||||
}
|
||||
|
||||
@@ -41,6 +41,8 @@ impl ConfigValidator {
|
||||
RoutingMode::PrefillDecode {
|
||||
prefill_urls,
|
||||
decode_urls,
|
||||
prefill_policy,
|
||||
decode_policy,
|
||||
} => {
|
||||
// Only require URLs if service discovery is disabled
|
||||
if !has_service_discovery {
|
||||
@@ -78,6 +80,14 @@ impl ConfigValidator {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Validate optional prefill and decode policies
|
||||
if let Some(p_policy) = prefill_policy {
|
||||
Self::validate_policy(p_policy)?;
|
||||
}
|
||||
if let Some(d_policy) = decode_policy {
|
||||
Self::validate_policy(d_policy)?;
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
@@ -272,6 +282,35 @@ impl ConfigValidator {
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// For PD mode, validate that policies have sufficient workers
|
||||
if let RoutingMode::PrefillDecode {
|
||||
prefill_urls,
|
||||
decode_urls,
|
||||
prefill_policy,
|
||||
decode_policy,
|
||||
} = &config.mode
|
||||
{
|
||||
// Check power-of-two for prefill
|
||||
if let Some(PolicyConfig::PowerOfTwo { .. }) = prefill_policy {
|
||||
if prefill_urls.len() < 2 {
|
||||
return Err(ConfigError::IncompatibleConfig {
|
||||
reason: "Power-of-two policy for prefill requires at least 2 prefill workers".to_string(),
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// Check power-of-two for decode
|
||||
if let Some(PolicyConfig::PowerOfTwo { .. }) = decode_policy {
|
||||
if decode_urls.len() < 2 {
|
||||
return Err(ConfigError::IncompatibleConfig {
|
||||
reason:
|
||||
"Power-of-two policy for decode requires at least 2 decode workers"
|
||||
.to_string(),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
@@ -430,6 +469,8 @@ mod tests {
|
||||
RoutingMode::PrefillDecode {
|
||||
prefill_urls: vec![("http://prefill:8000".to_string(), Some(8081))],
|
||||
decode_urls: vec!["http://decode:8000".to_string()],
|
||||
prefill_policy: None,
|
||||
decode_policy: None,
|
||||
},
|
||||
PolicyConfig::Random,
|
||||
);
|
||||
@@ -444,6 +485,8 @@ mod tests {
|
||||
RoutingMode::PrefillDecode {
|
||||
prefill_urls: vec![("http://prefill:8000".to_string(), None)],
|
||||
decode_urls: vec!["http://decode:8000".to_string()],
|
||||
prefill_policy: None,
|
||||
decode_policy: None,
|
||||
},
|
||||
PolicyConfig::RoundRobin,
|
||||
);
|
||||
@@ -459,6 +502,8 @@ mod tests {
|
||||
RoutingMode::PrefillDecode {
|
||||
prefill_urls: vec![("http://prefill:8000".to_string(), None)],
|
||||
decode_urls: vec!["http://decode:8000".to_string()],
|
||||
prefill_policy: None,
|
||||
decode_policy: None,
|
||||
},
|
||||
PolicyConfig::CacheAware {
|
||||
cache_threshold: 0.5,
|
||||
@@ -491,4 +536,60 @@ mod tests {
|
||||
let result = ConfigValidator::validate(&config);
|
||||
assert!(result.is_ok());
|
||||
}
|
||||
|
||||
#[test]
fn test_validate_pd_mode_with_separate_policies() {
    // Two workers on each side with a different policy per side; the
    // validator is expected to accept this configuration.
    let prefill_urls = vec![
        (String::from("http://prefill1:8000"), None),
        (String::from("http://prefill2:8000"), None),
    ];
    let decode_urls = vec![
        String::from("http://decode1:8000"),
        String::from("http://decode2:8000"),
    ];
    let mode = RoutingMode::PrefillDecode {
        prefill_urls,
        decode_urls,
        prefill_policy: Some(PolicyConfig::CacheAware {
            cache_threshold: 0.5,
            balance_abs_threshold: 32,
            balance_rel_threshold: 1.1,
            eviction_interval_secs: 60,
            max_tree_size: 1000,
        }),
        decode_policy: Some(PolicyConfig::PowerOfTwo {
            load_check_interval_secs: 60,
        }),
    };

    // `Random` is the main policy, i.e. the fallback for either side.
    let config = RouterConfig::new(mode, PolicyConfig::Random);

    assert!(ConfigValidator::validate(&config).is_ok());
}
|
||||
|
||||
#[test]
fn test_validate_pd_mode_power_of_two_insufficient_workers() {
    // A power-of-two prefill policy paired with a single prefill worker
    // must be rejected by the validator.
    let mode = RoutingMode::PrefillDecode {
        // Only one prefill worker on purpose.
        prefill_urls: vec![(String::from("http://prefill1:8000"), None)],
        decode_urls: vec![
            String::from("http://decode1:8000"),
            String::from("http://decode2:8000"),
        ],
        // Needs at least two workers to choose between.
        prefill_policy: Some(PolicyConfig::PowerOfTwo {
            load_check_interval_secs: 60,
        }),
        decode_policy: None,
    };
    let config = RouterConfig::new(mode, PolicyConfig::Random);

    let outcome = ConfigValidator::validate(&config);
    assert!(outcome.is_err());

    // The error text must identify the prefill side as the culprit.
    if let Some(err) = outcome.err() {
        assert!(err.to_string().contains("prefill requires at least 2"));
    }
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user