diff --git a/sgl-router/README.md b/sgl-router/README.md index c899a6f59..baa894e1f 100644 --- a/sgl-router/README.md +++ b/sgl-router/README.md @@ -120,6 +120,16 @@ python -m sglang_router.launch_router \ --prefill-selector app=sglang component=prefill \ --decode-selector app=sglang component=decode \ --service-discovery-namespace sglang-system + +# With separate routing policies: +python -m sglang_router.launch_router \ + --pd-disaggregation \ + --prefill-policy cache_aware \ + --decode-policy power_of_two \ + --service-discovery \ + --prefill-selector app=sglang component=prefill \ + --decode-selector app=sglang component=decode \ + --service-discovery-namespace sglang-system ``` #### Kubernetes Pod Configuration @@ -226,7 +236,9 @@ python -m sglang_router.launch_router \ - `--decode`: Initial decode server URL - `--prefill-selector`: Label selector for prefill pods - `--decode-selector`: Label selector for decode pods -- `--policy`: Routing policy (`cache_aware`, `random`, `power_of_two`) +- `--policy`: Routing policy (`cache_aware`, `random`, `power_of_two`, `round_robin`) +- `--prefill-policy`: Separate routing policy for prefill nodes (optional, overrides `--policy` for prefill) +- `--decode-policy`: Separate routing policy for decode nodes (optional, overrides `--policy` for decode) ## Development diff --git a/sgl-router/py_src/sglang_router/launch_router.py b/sgl-router/py_src/sglang_router/launch_router.py index f7aaf6dee..af1ce392c 100644 --- a/sgl-router/py_src/sglang_router/launch_router.py +++ b/sgl-router/py_src/sglang_router/launch_router.py @@ -40,6 +40,8 @@ class RouterArgs: # Routing policy policy: str = "cache_aware" + prefill_policy: Optional[str] = None # Specific policy for prefill nodes in PD mode + decode_policy: Optional[str] = None # Specific policy for decode nodes in PD mode worker_startup_timeout_secs: int = 300 worker_startup_check_interval: int = 10 cache_threshold: float = 0.5 @@ -108,7 +110,21 @@ class RouterArgs: type=str, default=RouterArgs.policy, choices=["random", "round_robin", "cache_aware", "power_of_two"], - help="Load balancing policy to use. Note: power_of_two is only available in PD disaggregated mode", + help="Load balancing policy to use. In PD mode, this is used for both prefill and decode unless overridden", + ) + parser.add_argument( + f"--{prefix}prefill-policy", + type=str, + default=None, + choices=["random", "round_robin", "cache_aware", "power_of_two"], + help="Specific policy for prefill nodes in PD mode. If not specified, uses the main policy", + ) + parser.add_argument( + f"--{prefix}decode-policy", + type=str, + default=None, + choices=["random", "round_robin", "cache_aware", "power_of_two"], + help="Specific policy for decode nodes in PD mode. If not specified, uses the main policy", ) # PD-specific arguments @@ -266,6 +282,8 @@ class RouterArgs: prefill_urls=prefill_urls, decode_urls=decode_urls, policy=getattr(args, f"{prefix}policy"), + prefill_policy=getattr(args, f"{prefix}prefill_policy", None), + decode_policy=getattr(args, f"{prefix}decode_policy", None), worker_startup_timeout_secs=getattr( args, f"{prefix}worker_startup_timeout_secs" ), @@ -389,6 +407,35 @@ def launch_router(args: argparse.Namespace) -> Optional[Router]: if not router_args.decode_urls: raise ValueError("PD disaggregation mode requires --decode") + # Warn about policy usage in PD mode + if ( + router_args.prefill_policy + and router_args.decode_policy + and router_args.policy + ): + logger.warning( + "Both --prefill-policy and --decode-policy are specified. " + "The main --policy flag will be ignored for PD mode." + ) + elif ( + router_args.prefill_policy + and not router_args.decode_policy + and router_args.policy + ): + logger.info( + f"Using --prefill-policy '{router_args.prefill_policy}' for prefill nodes " + f"and --policy '{router_args.policy}' for decode nodes." + ) + elif ( + router_args.decode_policy + and not router_args.prefill_policy + and router_args.policy + ): + logger.info( + f"Using --policy '{router_args.policy}' for prefill nodes " + f"and --decode-policy '{router_args.decode_policy}' for decode nodes." + ) + # Create router with unified constructor router = Router( worker_urls=( @@ -424,6 +471,16 @@ def launch_router(args: argparse.Namespace) -> Optional[Router]: decode_urls=( router_args.decode_urls if router_args.pd_disaggregation else None ), + prefill_policy=( + policy_from_str(router_args.prefill_policy) + if router_args.prefill_policy + else None + ), + decode_policy=( + policy_from_str(router_args.decode_policy) + if router_args.decode_policy + else None + ), ) router.start() @@ -455,12 +512,18 @@ Examples: # Regular mode python -m sglang_router.launch_router --worker-urls http://worker1:8000 http://worker2:8000 - # PD disaggregated mode + # PD disaggregated mode with same policy for both python -m sglang_router.launch_router --pd-disaggregation \\ --prefill http://prefill1:8000 9000 --prefill http://prefill2:8000 none \\ --decode http://decode1:8001 --decode http://decode2:8001 \\ --policy cache_aware + # PD mode with different policies for prefill and decode + python -m sglang_router.launch_router --pd-disaggregation \\ + --prefill http://prefill1:8000 9000 --prefill http://prefill2:8000 none \\ + --decode http://decode1:8001 --decode http://decode2:8001 \\ + --prefill-policy cache_aware --decode-policy power_of_two + """, formatter_class=CustomHelpFormatter, ) diff --git a/sgl-router/py_src/sglang_router/router.py b/sgl-router/py_src/sglang_router/router.py index 4c5eed796..cd10e8e69 100644 --- a/sgl-router/py_src/sglang_router/router.py +++ b/sgl-router/py_src/sglang_router/router.py @@ -50,6 +50,10 @@ class Router: pd_disaggregation: Enable PD (Prefill-Decode) disaggregated mode. Default: False prefill_urls: List of (url, bootstrap_port) tuples for prefill servers (PD mode only) decode_urls: List of URLs for decode servers (PD mode only) + prefill_policy: Specific load balancing policy for prefill nodes (PD mode only). + If not specified, uses the main policy. Default: None + decode_policy: Specific load balancing policy for decode nodes (PD mode only). + If not specified, uses the main policy. Default: None """ def __init__( @@ -79,6 +83,8 @@ class Router: pd_disaggregation: bool = False, prefill_urls: Optional[List[tuple]] = None, decode_urls: Optional[List[str]] = None, + prefill_policy: Optional[PolicyType] = None, + decode_policy: Optional[PolicyType] = None, ): if selector is None: selector = {} @@ -113,6 +119,8 @@ class Router: pd_disaggregation=pd_disaggregation, prefill_urls=prefill_urls, decode_urls=decode_urls, + prefill_policy=prefill_policy, + decode_policy=decode_policy, ) def start(self) -> None: diff --git a/sgl-router/src/config/types.rs b/sgl-router/src/config/types.rs index 5e25b2c3b..84075de4c 100644 --- a/sgl-router/src/config/types.rs +++ b/sgl-router/src/config/types.rs @@ -46,6 +46,12 @@ pub enum RoutingMode { prefill_urls: Vec<(String, Option)>, /// Decode worker URLs decode_urls: Vec, + /// Optional separate policy for prefill workers + #[serde(skip_serializing_if = "Option::is_none")] + prefill_policy: Option, + /// Optional separate policy for decode workers + #[serde(skip_serializing_if = "Option::is_none")] + decode_policy: Option, }, } @@ -60,9 +66,32 @@ impl RoutingMode { RoutingMode::PrefillDecode { prefill_urls, decode_urls, + .. } => prefill_urls.len() + decode_urls.len(), } } + + /// Get the effective prefill policy for PD mode + /// Falls back to the main policy if no specific prefill policy is set + pub fn get_prefill_policy<'a>(&'a self, main_policy: &'a PolicyConfig) -> &'a PolicyConfig { + match self { + RoutingMode::PrefillDecode { prefill_policy, .. } => { + prefill_policy.as_ref().unwrap_or(main_policy) + } + _ => main_policy, + } + } + + /// Get the effective decode policy for PD mode + /// Falls back to the main policy if no specific decode policy is set + pub fn get_decode_policy<'a>(&'a self, main_policy: &'a PolicyConfig) -> &'a PolicyConfig { + match self { + RoutingMode::PrefillDecode { decode_policy, .. } => { + decode_policy.as_ref().unwrap_or(main_policy) + } + _ => main_policy, + } + } } /// Policy configuration for routing @@ -307,6 +336,8 @@ mod tests { let pd = RoutingMode::PrefillDecode { prefill_urls: vec![("http://prefill1".to_string(), Some(8001))], decode_urls: vec!["http://decode1".to_string()], + prefill_policy: None, + decode_policy: None, }; assert!(pd.is_pd_mode()); } @@ -332,6 +363,8 @@ mod tests { "http://decode2".to_string(), "http://decode3".to_string(), ], + prefill_policy: None, + decode_policy: None, }; assert_eq!(pd.worker_count(), 5); @@ -355,6 +388,8 @@ mod tests { let pd = RoutingMode::PrefillDecode { prefill_urls: vec![("http://prefill1".to_string(), Some(8001))], decode_urls: vec!["http://decode1".to_string()], + prefill_policy: None, + decode_policy: None, }; let json = serde_json::to_string(&pd).unwrap(); assert!(json.contains("\"type\":\"prefill_decode\"")); @@ -551,6 +586,8 @@ mod tests { mode: RoutingMode::PrefillDecode { prefill_urls: vec![], decode_urls: vec![], + prefill_policy: None, + decode_policy: None, }, ..Default::default() }; @@ -674,6 +711,8 @@ mod tests { "http://decode1:8000".to_string(), "http://decode2:8000".to_string(), ], + prefill_policy: None, + decode_policy: None, }, policy: PolicyConfig::PowerOfTwo { load_check_interval_secs: 30, @@ -800,4 +839,155 @@ mod tests { Some("production".to_string()) ); } + + // ============= Policy Fallback Tests ============= + + #[test] + fn test_pd_policy_fallback_both_specified() { + // When both prefill and decode policies are specified, they should be used + let pd = RoutingMode::PrefillDecode { + prefill_urls: vec![("http://prefill1".to_string(), None)], + decode_urls: vec!["http://decode1".to_string()], + prefill_policy: Some(PolicyConfig::CacheAware { + cache_threshold: 0.5, + balance_abs_threshold: 32, + balance_rel_threshold: 1.1, + eviction_interval_secs: 60, + max_tree_size: 1000, + }), + decode_policy: Some(PolicyConfig::PowerOfTwo { + load_check_interval_secs: 60, + }), + }; + + let main_policy = PolicyConfig::Random; + + // Both specific policies should be used + match pd.get_prefill_policy(&main_policy) { + PolicyConfig::CacheAware { .. } => {} // Success + _ => panic!("Expected CacheAware for prefill"), + } + + match pd.get_decode_policy(&main_policy) { + PolicyConfig::PowerOfTwo { .. } => {} // Success + _ => panic!("Expected PowerOfTwo for decode"), + } + } + + #[test] + fn test_pd_policy_fallback_only_prefill() { + // When only prefill policy is specified, decode should use main policy + let pd = RoutingMode::PrefillDecode { + prefill_urls: vec![("http://prefill1".to_string(), None)], + decode_urls: vec!["http://decode1".to_string()], + prefill_policy: Some(PolicyConfig::CacheAware { + cache_threshold: 0.5, + balance_abs_threshold: 32, + balance_rel_threshold: 1.1, + eviction_interval_secs: 60, + max_tree_size: 1000, + }), + decode_policy: None, + }; + + let main_policy = PolicyConfig::RoundRobin; + + // Prefill should use specific policy + match pd.get_prefill_policy(&main_policy) { + PolicyConfig::CacheAware { .. } => {} // Success + _ => panic!("Expected CacheAware for prefill"), + } + + // Decode should fall back to main policy + match pd.get_decode_policy(&main_policy) { + PolicyConfig::RoundRobin => {} // Success + _ => panic!("Expected RoundRobin for decode"), + } + } + + #[test] + fn test_pd_policy_fallback_only_decode() { + // When only decode policy is specified, prefill should use main policy + let pd = RoutingMode::PrefillDecode { + prefill_urls: vec![("http://prefill1".to_string(), None)], + decode_urls: vec!["http://decode1".to_string()], + prefill_policy: None, + decode_policy: Some(PolicyConfig::PowerOfTwo { + load_check_interval_secs: 60, + }), + }; + + let main_policy = PolicyConfig::Random; + + // Prefill should fall back to main policy + match pd.get_prefill_policy(&main_policy) { + PolicyConfig::Random => {} // Success + _ => panic!("Expected Random for prefill"), + } + + // Decode should use specific policy + match pd.get_decode_policy(&main_policy) { + PolicyConfig::PowerOfTwo { .. } => {} // Success + _ => panic!("Expected PowerOfTwo for decode"), + } + } + + #[test] + fn test_pd_policy_fallback_none_specified() { + // When no specific policies are specified, both should use main policy + let pd = RoutingMode::PrefillDecode { + prefill_urls: vec![("http://prefill1".to_string(), None)], + decode_urls: vec!["http://decode1".to_string()], + prefill_policy: None, + decode_policy: None, + }; + + let main_policy = PolicyConfig::CacheAware { + cache_threshold: 0.7, + balance_abs_threshold: 20, + balance_rel_threshold: 1.5, + eviction_interval_secs: 300, + max_tree_size: 2000, + }; + + // Both should fall back to main policy + match pd.get_prefill_policy(&main_policy) { + PolicyConfig::CacheAware { + cache_threshold, .. + } => { + assert!((cache_threshold - 0.7).abs() < 0.0001); + } + _ => panic!("Expected CacheAware for prefill"), + } + + match pd.get_decode_policy(&main_policy) { + PolicyConfig::CacheAware { + cache_threshold, .. + } => { + assert!((cache_threshold - 0.7).abs() < 0.0001); + } + _ => panic!("Expected CacheAware for decode"), + } + } + + #[test] + fn test_regular_mode_policy_fallback() { + // For regular mode, the helper methods should just return the main policy + let regular = RoutingMode::Regular { + worker_urls: vec!["http://worker1".to_string()], + }; + + let main_policy = PolicyConfig::RoundRobin; + + // Both methods should return main policy for regular mode + match regular.get_prefill_policy(&main_policy) { + PolicyConfig::RoundRobin => {} // Success + _ => panic!("Expected RoundRobin for regular mode"), + } + + match regular.get_decode_policy(&main_policy) { + PolicyConfig::RoundRobin => {} // Success + _ => panic!("Expected RoundRobin for regular mode"), + } + } } diff --git a/sgl-router/src/config/validation.rs b/sgl-router/src/config/validation.rs index 381fcce07..1e78a0f10 100644 --- a/sgl-router/src/config/validation.rs +++ b/sgl-router/src/config/validation.rs @@ -41,6 +41,8 @@ impl ConfigValidator { RoutingMode::PrefillDecode { prefill_urls, decode_urls, + prefill_policy, + decode_policy, } => { // Only require URLs if service discovery is disabled if !has_service_discovery { @@ -78,6 +80,14 @@ impl ConfigValidator { } } } + + // Validate optional prefill and decode policies + if let Some(p_policy) = prefill_policy { + Self::validate_policy(p_policy)?; + } + if let Some(d_policy) = decode_policy { + Self::validate_policy(d_policy)?; + } } } Ok(()) @@ -272,6 +282,35 @@ impl ConfigValidator { }); } } + + // For PD mode, validate that policies have sufficient workers + if let RoutingMode::PrefillDecode { + prefill_urls, + decode_urls, + prefill_policy, + decode_policy, + } = &config.mode + { + // Check power-of-two for prefill + if let Some(PolicyConfig::PowerOfTwo { .. }) = prefill_policy { + if prefill_urls.len() < 2 { + return Err(ConfigError::IncompatibleConfig { + reason: "Power-of-two policy for prefill requires at least 2 prefill workers".to_string(), + }); + } + } + + // Check power-of-two for decode + if let Some(PolicyConfig::PowerOfTwo { .. }) = decode_policy { + if decode_urls.len() < 2 { + return Err(ConfigError::IncompatibleConfig { + reason: + "Power-of-two policy for decode requires at least 2 decode workers" + .to_string(), + }); + } + } + } } Ok(()) @@ -430,6 +469,8 @@ mod tests { RoutingMode::PrefillDecode { prefill_urls: vec![("http://prefill:8000".to_string(), Some(8081))], decode_urls: vec!["http://decode:8000".to_string()], + prefill_policy: None, + decode_policy: None, }, PolicyConfig::Random, ); @@ -444,6 +485,8 @@ mod tests { RoutingMode::PrefillDecode { prefill_urls: vec![("http://prefill:8000".to_string(), None)], decode_urls: vec!["http://decode:8000".to_string()], + prefill_policy: None, + decode_policy: None, }, PolicyConfig::RoundRobin, ); @@ -459,6 +502,8 @@ mod tests { RoutingMode::PrefillDecode { prefill_urls: vec![("http://prefill:8000".to_string(), None)], decode_urls: vec!["http://decode:8000".to_string()], + prefill_policy: None, + decode_policy: None, }, PolicyConfig::CacheAware { cache_threshold: 0.5, @@ -491,4 +536,60 @@ mod tests { let result = ConfigValidator::validate(&config); assert!(result.is_ok()); } + + #[test] + fn test_validate_pd_mode_with_separate_policies() { + // Test PD mode with different policies for prefill and decode + let config = RouterConfig::new( + RoutingMode::PrefillDecode { + prefill_urls: vec![ + ("http://prefill1:8000".to_string(), None), + ("http://prefill2:8000".to_string(), None), + ], + decode_urls: vec![ + "http://decode1:8000".to_string(), + "http://decode2:8000".to_string(), + ], + prefill_policy: Some(PolicyConfig::CacheAware { + cache_threshold: 0.5, + balance_abs_threshold: 32, + balance_rel_threshold: 1.1, + eviction_interval_secs: 60, + max_tree_size: 1000, + }), + decode_policy: Some(PolicyConfig::PowerOfTwo { + load_check_interval_secs: 60, + }), + }, + PolicyConfig::Random, // Main policy as fallback + ); + + let result = ConfigValidator::validate(&config); + assert!(result.is_ok()); + } + + #[test] + fn test_validate_pd_mode_power_of_two_insufficient_workers() { + // Test that power-of-two policy requires at least 2 workers + let config = RouterConfig::new( + RoutingMode::PrefillDecode { + prefill_urls: vec![("http://prefill1:8000".to_string(), None)], // Only 1 prefill + decode_urls: vec![ + "http://decode1:8000".to_string(), + "http://decode2:8000".to_string(), + ], + prefill_policy: Some(PolicyConfig::PowerOfTwo { + load_check_interval_secs: 60, + }), // Requires 2+ workers + decode_policy: None, + }, + PolicyConfig::Random, + ); + + let result = ConfigValidator::validate(&config); + assert!(result.is_err()); + if let Err(e) = result { + assert!(e.to_string().contains("prefill requires at least 2")); + } + } } diff --git a/sgl-router/src/lib.rs b/sgl-router/src/lib.rs index a37a4b474..0c03bd497 100644 --- a/sgl-router/src/lib.rs +++ b/sgl-router/src/lib.rs @@ -54,6 +54,8 @@ struct Router { // PD-specific fields (only used when pd_disaggregation is true) prefill_urls: Option)>>, decode_urls: Option>, + prefill_policy: Option, + decode_policy: Option, } impl Router { @@ -63,11 +65,31 @@ impl Router { DiscoveryConfig, MetricsConfig, PolicyConfig as ConfigPolicyConfig, RoutingMode, }; + // Convert policy helper function + let convert_policy = |policy: &PolicyType| -> ConfigPolicyConfig { + match policy { + PolicyType::Random => ConfigPolicyConfig::Random, + PolicyType::RoundRobin => ConfigPolicyConfig::RoundRobin, + PolicyType::CacheAware => ConfigPolicyConfig::CacheAware { + cache_threshold: self.cache_threshold, + balance_abs_threshold: self.balance_abs_threshold, + balance_rel_threshold: self.balance_rel_threshold, + eviction_interval_secs: self.eviction_interval_secs, + max_tree_size: self.max_tree_size, + }, + PolicyType::PowerOfTwo => ConfigPolicyConfig::PowerOfTwo { + load_check_interval_secs: 5, // Default value + }, + } + }; + // Determine routing mode let mode = if self.pd_disaggregation { RoutingMode::PrefillDecode { prefill_urls: self.prefill_urls.clone().unwrap_or_default(), decode_urls: self.decode_urls.clone().unwrap_or_default(), + prefill_policy: self.prefill_policy.as_ref().map(convert_policy), + decode_policy: self.decode_policy.as_ref().map(convert_policy), } } else { RoutingMode::Regular { @@ -75,21 +97,8 @@ impl Router { } }; - // Convert policy - let policy = match self.policy { - PolicyType::Random => ConfigPolicyConfig::Random, - PolicyType::RoundRobin => ConfigPolicyConfig::RoundRobin, - PolicyType::CacheAware => ConfigPolicyConfig::CacheAware { - cache_threshold: self.cache_threshold, - balance_abs_threshold: self.balance_abs_threshold, - balance_rel_threshold: self.balance_rel_threshold, - eviction_interval_secs: self.eviction_interval_secs, - max_tree_size: self.max_tree_size, - }, - PolicyType::PowerOfTwo => ConfigPolicyConfig::PowerOfTwo { - load_check_interval_secs: 5, // Default value - }, - }; + // Convert main policy + let policy = convert_policy(&self.policy); // Service discovery configuration let discovery = if self.service_discovery { @@ -163,7 +172,9 @@ impl Router { request_timeout_secs = 600, // Add configurable request timeout pd_disaggregation = false, // New flag for PD mode prefill_urls = None, - decode_urls = None + decode_urls = None, + prefill_policy = None, + decode_policy = None ))] fn new( worker_urls: Vec, @@ -193,6 +204,8 @@ impl Router { pd_disaggregation: bool, prefill_urls: Option)>>, decode_urls: Option>, + prefill_policy: Option, + decode_policy: Option, ) -> PyResult { Ok(Router { host, @@ -222,6 +235,8 @@ impl Router { pd_disaggregation, prefill_urls, decode_urls, + prefill_policy, + decode_policy, }) } diff --git a/sgl-router/src/policies/cache_aware.rs b/sgl-router/src/policies/cache_aware.rs index 9e30c0d01..bfbe4b93a 100644 --- a/sgl-router/src/policies/cache_aware.rs +++ b/sgl-router/src/policies/cache_aware.rs @@ -254,7 +254,11 @@ impl LoadBalancingPolicy for CacheAwarePolicy { decode_workers: &[Box], request_text: Option<&str>, ) -> Option<(usize, usize)> { - // In PD mode: + // DEPRECATED: This method is no longer used when separate policies are configured. + // The PD router now uses separate policies for prefill and decode selection. + // This implementation remains for backward compatibility when a single policy is used. + + // In PD mode with single policy: // - Prefill: Use cache-aware routing for better cache utilization // - Decode: Use least-load routing for better load distribution diff --git a/sgl-router/src/routers/factory.rs b/sgl-router/src/routers/factory.rs index 201240121..edf063440 100644 --- a/sgl-router/src/routers/factory.rs +++ b/sgl-router/src/routers/factory.rs @@ -17,7 +17,16 @@ impl RouterFactory { RoutingMode::PrefillDecode { prefill_urls, decode_urls, - } => Self::create_pd_router(prefill_urls, decode_urls, &config.policy, config), + prefill_policy, + decode_policy, + } => Self::create_pd_router( + prefill_urls, + decode_urls, + prefill_policy.as_ref(), + decode_policy.as_ref(), + &config.policy, + config, + ), } } @@ -45,18 +54,23 @@ impl RouterFactory { fn create_pd_router( prefill_urls: &[(String, Option)], decode_urls: &[String], - policy_config: &PolicyConfig, + prefill_policy_config: Option<&PolicyConfig>, + decode_policy_config: Option<&PolicyConfig>, + main_policy_config: &PolicyConfig, router_config: &RouterConfig, ) -> Result, String> { - // Create policy directly from PolicyConfig - // All policies now support PD mode through the select_worker_pair method - let policy = PolicyFactory::create_from_config(policy_config); + // Create policies - use specific policies if provided, otherwise fall back to main policy + let prefill_policy = + PolicyFactory::create_from_config(prefill_policy_config.unwrap_or(main_policy_config)); + let decode_policy = + PolicyFactory::create_from_config(decode_policy_config.unwrap_or(main_policy_config)); - // Create PD router with injected policy + // Create PD router with separate policies let router = PDRouter::new( prefill_urls.to_vec(), decode_urls.to_vec(), - policy, + prefill_policy, + decode_policy, router_config.worker_startup_timeout_secs, router_config.worker_startup_check_interval_secs, )?; diff --git a/sgl-router/src/routers/pd_router.rs b/sgl-router/src/routers/pd_router.rs index ab9927d24..507ac1f42 100644 --- a/sgl-router/src/routers/pd_router.rs +++ b/sgl-router/src/routers/pd_router.rs @@ -22,8 +22,10 @@ use uuid::Uuid; pub struct PDRouter { pub prefill_workers: Arc>>>, pub decode_workers: Arc>>>, - pub policy: Arc, + pub prefill_policy: Arc, + pub decode_policy: Arc, pub prefill_tree: Option>>, + pub decode_tree: Option>>, pub timeout_secs: u64, pub interval_secs: u64, pub worker_loads: Arc>>, @@ -66,7 +68,7 @@ impl PDRouter { workers.push(worker); - // Add to cache tree if using cache-aware policy + // Add to cache tree if using cache-aware policy for prefill if let Some(ref tree) = self.prefill_tree { tree.lock().unwrap().insert("", &url); } @@ -102,6 +104,11 @@ impl PDRouter { workers.push(worker); + // Add to cache tree if using cache-aware policy for decode + if let Some(ref tree) = self.decode_tree { + tree.lock().unwrap().insert("", &url); + } + info!("Added decode server: {}", url); Ok(format!("Successfully added decode server: {}", url)) } @@ -126,12 +133,7 @@ impl PDRouter { // Remove from cache tree if using cache-aware policy if let Some(ref tree) = self.prefill_tree { - // Note: Tree doesn't have a remove method, so we rebuild it - let mut tree_guard = tree.lock().unwrap(); - *tree_guard = Tree::new(); - for worker in workers.iter() { - tree_guard.insert("", worker.url()); - } + tree.lock().unwrap().remove_tenant(url); } info!("Removed prefill server: {}", url); @@ -156,6 +158,11 @@ impl PDRouter { }); } + // Remove from the cache tree if using cache-aware policy for decode + if let Some(ref tree) = self.decode_tree { + tree.lock().unwrap().remove_tenant(url); + } + info!("Removed decode server: {}", url); Ok(format!("Successfully removed decode server: {}", url)) } @@ -163,7 +170,8 @@ impl PDRouter { pub fn new( prefill_urls: Vec<(String, Option)>, decode_urls: Vec, - policy: Arc, + prefill_policy: Arc, + decode_policy: Arc, timeout_secs: u64, interval_secs: u64, ) -> Result { @@ -192,10 +200,10 @@ impl PDRouter { )?; } - // Initialize cache-aware components if needed - let prefill_tree = if policy.name() == "cache_aware" { + // Initialize cache-aware components if needed for prefill policy + let prefill_tree = if prefill_policy.name() == "cache_aware" { // Initialize the policy's internal tree with prefill workers - if let Some(cache_policy) = policy + if let Some(cache_policy) = prefill_policy .as_any() .downcast_ref::() { @@ -212,6 +220,26 @@ impl PDRouter { None }; + // Initialize cache-aware components if needed for decode policy + let decode_tree = if decode_policy.name() == "cache_aware" { + // Initialize the policy's internal tree with decode workers + if let Some(cache_policy) = decode_policy + .as_any() + .downcast_ref::() + { + cache_policy.init_workers(&decode_workers); + } + + let tree = Arc::new(Mutex::new(Tree::new())); + // Initialize tree with decode workers + for worker in &decode_workers { + tree.lock().unwrap().insert("", worker.url()); + } + Some(tree) + } else { + None + }; + // Set up background load monitoring for power-of-two selection let (tx, rx) = tokio::sync::watch::channel(HashMap::new()); let worker_loads = Arc::new(rx); @@ -222,25 +250,28 @@ impl PDRouter { .build() .map_err(|e| format!("Failed to create HTTP client: {}", e))?; - let load_monitor_handle = if policy.name() == "power_of_two" { - let monitor_urls = all_urls.clone(); - let monitor_interval = interval_secs; - let monitor_client = http_client.clone(); - let policy_clone = Arc::clone(&policy); + let load_monitor_handle = + if prefill_policy.name() == "power_of_two" || decode_policy.name() == "power_of_two" { + let monitor_urls = all_urls.clone(); + let monitor_interval = interval_secs; + let monitor_client = http_client.clone(); + let prefill_policy_clone = Arc::clone(&prefill_policy); + let decode_policy_clone = Arc::clone(&decode_policy); - Some(Arc::new(tokio::spawn(async move { - Self::monitor_worker_loads_with_client( - monitor_urls, - tx, - monitor_interval, - monitor_client, - policy_clone, - ) - .await; - }))) - } else { - None - }; + Some(Arc::new(tokio::spawn(async move { + Self::monitor_worker_loads_with_client( + monitor_urls, + tx, + monitor_interval, + monitor_client, + prefill_policy_clone, + decode_policy_clone, + ) + .await; + }))) + } else { + None + }; let prefill_workers = Arc::new(RwLock::new(prefill_workers)); let decode_workers = Arc::new(RwLock::new(decode_workers)); @@ -254,8 +285,10 @@ impl PDRouter { Ok(PDRouter { prefill_workers, decode_workers, - policy, + prefill_policy, + decode_policy, prefill_tree, + decode_tree, timeout_secs, interval_secs, worker_loads, @@ -736,18 +769,21 @@ impl PDRouter { return Err("No decode workers available. Please check if decode servers are configured and healthy.".to_string()); } - // Use the policy to select worker pair - match self - .policy - .select_worker_pair(&prefill_workers, &decode_workers, request_text) - { - Some((prefill_idx, decode_idx)) => { - let prefill = prefill_workers[prefill_idx].clone_worker(); - let decode = decode_workers[decode_idx].clone_worker(); - Ok((prefill, decode)) - } - None => Err("Failed to select worker pair".to_string()), - } + // Select prefill worker using prefill policy + let prefill_idx = self + .prefill_policy + .select_worker(&prefill_workers, request_text) + .ok_or("Failed to select prefill worker")?; + + // Select decode worker using decode policy + let decode_idx = self + .decode_policy + .select_worker(&decode_workers, request_text) + .ok_or("Failed to select decode worker")?; + + let prefill = prefill_workers[prefill_idx].clone_worker(); + let decode = decode_workers[decode_idx].clone_worker(); + Ok((prefill, decode)) } // Background task to monitor worker loads with shared client @@ -756,7 +792,8 @@ impl PDRouter { tx: tokio::sync::watch::Sender>, interval_secs: u64, client: reqwest::Client, - policy: Arc, + prefill_policy: Arc, + decode_policy: Arc, ) { loop { let mut loads = HashMap::new(); @@ -781,8 +818,9 @@ impl PDRouter { debug!("Worker loads updated: {:?}", loads); - // Update the policy with current loads - policy.update_loads(&loads); + // Update both policies with current loads + prefill_policy.update_loads(&loads); + decode_policy.update_loads(&loads); // Check if receiver is still active if tx.send(loads).is_err() { @@ -1463,13 +1501,16 @@ mod tests { use actix_web::test::TestRequest; fn create_test_pd_router() -> PDRouter { - let policy = Arc::new(RandomPolicy::new()); + let prefill_policy = Arc::new(RandomPolicy::new()); + let decode_policy = Arc::new(RandomPolicy::new()); PDRouter { prefill_workers: Arc::new(RwLock::new(vec![])), decode_workers: Arc::new(RwLock::new(vec![])), - policy, + prefill_policy, + decode_policy, prefill_tree: None, + decode_tree: None, timeout_secs: 5, interval_secs: 1, worker_loads: Arc::new(tokio::sync::watch::channel(HashMap::new()).1), @@ -1608,9 +1649,9 @@ mod tests { #[tokio::test] async fn test_cache_tree_operations() { - let policy = Arc::new(CacheAwarePolicy::new()); + let cache_policy = Arc::new(CacheAwarePolicy::new()); let mut router = create_test_pd_router(); - router.policy = policy; + router.prefill_policy = cache_policy; // Initialize cache tree let tree = Arc::new(Mutex::new(Tree::new())); @@ -1638,9 +1679,9 @@ mod tests { #[tokio::test] async fn test_cache_tree_rebuild_on_remove() { - let policy = Arc::new(CacheAwarePolicy::new()); + let cache_policy = Arc::new(CacheAwarePolicy::new()); let mut router = create_test_pd_router(); - router.policy = policy; + router.prefill_policy = cache_policy; // Initialize cache tree let tree = Arc::new(Mutex::new(Tree::new())); @@ -1880,9 +1921,10 @@ mod tests { #[tokio::test] async fn test_load_monitor_updates() { - let policy = Arc::new(crate::policies::PowerOfTwoPolicy::new()); + let power_of_two_policy = Arc::new(crate::policies::PowerOfTwoPolicy::new()); let mut router = create_test_pd_router(); - router.policy = policy; + router.prefill_policy = power_of_two_policy.clone(); + router.decode_policy = power_of_two_policy; // Create load channel let (tx, rx) = tokio::sync::watch::channel(HashMap::new()); diff --git a/sgl-router/tests/test_pd_routing.rs b/sgl-router/tests/test_pd_routing.rs index a2c0d7e31..24571eb24 100644 --- a/sgl-router/tests/test_pd_routing.rs +++ b/sgl-router/tests/test_pd_routing.rs @@ -122,6 +122,8 @@ mod test_pd_routing { "http://decode1:8080".to_string(), "http://decode2:8080".to_string(), ], + prefill_policy: None, + decode_policy: None, }, PolicyConfig::Random, ), @@ -129,6 +131,8 @@ mod test_pd_routing { RoutingMode::PrefillDecode { prefill_urls: vec![("http://prefill:8080".to_string(), Some(9000))], decode_urls: vec!["http://decode:8080".to_string()], + prefill_policy: None, + decode_policy: None, }, PolicyConfig::PowerOfTwo { load_check_interval_secs: 5, @@ -142,6 +146,8 @@ mod test_pd_routing { ("http://p3:8080".to_string(), Some(9002)), ], decode_urls: vec!["http://d1:8080".to_string(), "http://d2:8080".to_string()], + prefill_policy: None, + decode_policy: None, }, PolicyConfig::CacheAware { cache_threshold: 0.7,