[router] add different policies for p node and d node (#8395)
This commit is contained in:
@@ -120,6 +120,16 @@ python -m sglang_router.launch_router \
|
|||||||
--prefill-selector app=sglang component=prefill \
|
--prefill-selector app=sglang component=prefill \
|
||||||
--decode-selector app=sglang component=decode \
|
--decode-selector app=sglang component=decode \
|
||||||
--service-discovery-namespace sglang-system
|
--service-discovery-namespace sglang-system
|
||||||
|
|
||||||
|
# With separate routing policies:
|
||||||
|
python -m sglang_router.launch_router \
|
||||||
|
--pd-disaggregation \
|
||||||
|
--prefill-policy cache_aware \
|
||||||
|
--decode-policy power_of_two \
|
||||||
|
--service-discovery \
|
||||||
|
--prefill-selector app=sglang component=prefill \
|
||||||
|
--decode-selector app=sglang component=decode \
|
||||||
|
--service-discovery-namespace sglang-system
|
||||||
```
|
```
|
||||||
|
|
||||||
#### Kubernetes Pod Configuration
|
#### Kubernetes Pod Configuration
|
||||||
@@ -226,7 +236,9 @@ python -m sglang_router.launch_router \
|
|||||||
- `--decode`: Initial decode server URL
|
- `--decode`: Initial decode server URL
|
||||||
- `--prefill-selector`: Label selector for prefill pods
|
- `--prefill-selector`: Label selector for prefill pods
|
||||||
- `--decode-selector`: Label selector for decode pods
|
- `--decode-selector`: Label selector for decode pods
|
||||||
- `--policy`: Routing policy (`cache_aware`, `random`, `power_of_two`)
|
- `--policy`: Routing policy (`cache_aware`, `random`, `power_of_two`, `round_robin`)
|
||||||
|
- `--prefill-policy`: Separate routing policy for prefill nodes (optional, overrides `--policy` for prefill)
|
||||||
|
- `--decode-policy`: Separate routing policy for decode nodes (optional, overrides `--policy` for decode)
|
||||||
|
|
||||||
## Development
|
## Development
|
||||||
|
|
||||||
|
|||||||
@@ -40,6 +40,8 @@ class RouterArgs:
|
|||||||
|
|
||||||
# Routing policy
|
# Routing policy
|
||||||
policy: str = "cache_aware"
|
policy: str = "cache_aware"
|
||||||
|
prefill_policy: Optional[str] = None # Specific policy for prefill nodes in PD mode
|
||||||
|
decode_policy: Optional[str] = None # Specific policy for decode nodes in PD mode
|
||||||
worker_startup_timeout_secs: int = 300
|
worker_startup_timeout_secs: int = 300
|
||||||
worker_startup_check_interval: int = 10
|
worker_startup_check_interval: int = 10
|
||||||
cache_threshold: float = 0.5
|
cache_threshold: float = 0.5
|
||||||
@@ -108,7 +110,21 @@ class RouterArgs:
|
|||||||
type=str,
|
type=str,
|
||||||
default=RouterArgs.policy,
|
default=RouterArgs.policy,
|
||||||
choices=["random", "round_robin", "cache_aware", "power_of_two"],
|
choices=["random", "round_robin", "cache_aware", "power_of_two"],
|
||||||
help="Load balancing policy to use. Note: power_of_two is only available in PD disaggregated mode",
|
help="Load balancing policy to use. In PD mode, this is used for both prefill and decode unless overridden",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
f"--{prefix}prefill-policy",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
choices=["random", "round_robin", "cache_aware", "power_of_two"],
|
||||||
|
help="Specific policy for prefill nodes in PD mode. If not specified, uses the main policy",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
f"--{prefix}decode-policy",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
choices=["random", "round_robin", "cache_aware", "power_of_two"],
|
||||||
|
help="Specific policy for decode nodes in PD mode. If not specified, uses the main policy",
|
||||||
)
|
)
|
||||||
|
|
||||||
# PD-specific arguments
|
# PD-specific arguments
|
||||||
@@ -266,6 +282,8 @@ class RouterArgs:
|
|||||||
prefill_urls=prefill_urls,
|
prefill_urls=prefill_urls,
|
||||||
decode_urls=decode_urls,
|
decode_urls=decode_urls,
|
||||||
policy=getattr(args, f"{prefix}policy"),
|
policy=getattr(args, f"{prefix}policy"),
|
||||||
|
prefill_policy=getattr(args, f"{prefix}prefill_policy", None),
|
||||||
|
decode_policy=getattr(args, f"{prefix}decode_policy", None),
|
||||||
worker_startup_timeout_secs=getattr(
|
worker_startup_timeout_secs=getattr(
|
||||||
args, f"{prefix}worker_startup_timeout_secs"
|
args, f"{prefix}worker_startup_timeout_secs"
|
||||||
),
|
),
|
||||||
@@ -389,6 +407,35 @@ def launch_router(args: argparse.Namespace) -> Optional[Router]:
|
|||||||
if not router_args.decode_urls:
|
if not router_args.decode_urls:
|
||||||
raise ValueError("PD disaggregation mode requires --decode")
|
raise ValueError("PD disaggregation mode requires --decode")
|
||||||
|
|
||||||
|
# Warn about policy usage in PD mode
|
||||||
|
if (
|
||||||
|
router_args.prefill_policy
|
||||||
|
and router_args.decode_policy
|
||||||
|
and router_args.policy
|
||||||
|
):
|
||||||
|
logger.warning(
|
||||||
|
"Both --prefill-policy and --decode-policy are specified. "
|
||||||
|
"The main --policy flag will be ignored for PD mode."
|
||||||
|
)
|
||||||
|
elif (
|
||||||
|
router_args.prefill_policy
|
||||||
|
and not router_args.decode_policy
|
||||||
|
and router_args.policy
|
||||||
|
):
|
||||||
|
logger.info(
|
||||||
|
f"Using --prefill-policy '{router_args.prefill_policy}' for prefill nodes "
|
||||||
|
f"and --policy '{router_args.policy}' for decode nodes."
|
||||||
|
)
|
||||||
|
elif (
|
||||||
|
router_args.decode_policy
|
||||||
|
and not router_args.prefill_policy
|
||||||
|
and router_args.policy
|
||||||
|
):
|
||||||
|
logger.info(
|
||||||
|
f"Using --policy '{router_args.policy}' for prefill nodes "
|
||||||
|
f"and --decode-policy '{router_args.decode_policy}' for decode nodes."
|
||||||
|
)
|
||||||
|
|
||||||
# Create router with unified constructor
|
# Create router with unified constructor
|
||||||
router = Router(
|
router = Router(
|
||||||
worker_urls=(
|
worker_urls=(
|
||||||
@@ -424,6 +471,16 @@ def launch_router(args: argparse.Namespace) -> Optional[Router]:
|
|||||||
decode_urls=(
|
decode_urls=(
|
||||||
router_args.decode_urls if router_args.pd_disaggregation else None
|
router_args.decode_urls if router_args.pd_disaggregation else None
|
||||||
),
|
),
|
||||||
|
prefill_policy=(
|
||||||
|
policy_from_str(router_args.prefill_policy)
|
||||||
|
if router_args.prefill_policy
|
||||||
|
else None
|
||||||
|
),
|
||||||
|
decode_policy=(
|
||||||
|
policy_from_str(router_args.decode_policy)
|
||||||
|
if router_args.decode_policy
|
||||||
|
else None
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
router.start()
|
router.start()
|
||||||
@@ -455,12 +512,18 @@ Examples:
|
|||||||
# Regular mode
|
# Regular mode
|
||||||
python -m sglang_router.launch_router --worker-urls http://worker1:8000 http://worker2:8000
|
python -m sglang_router.launch_router --worker-urls http://worker1:8000 http://worker2:8000
|
||||||
|
|
||||||
# PD disaggregated mode
|
# PD disaggregated mode with same policy for both
|
||||||
python -m sglang_router.launch_router --pd-disaggregation \\
|
python -m sglang_router.launch_router --pd-disaggregation \\
|
||||||
--prefill http://prefill1:8000 9000 --prefill http://prefill2:8000 none \\
|
--prefill http://prefill1:8000 9000 --prefill http://prefill2:8000 none \\
|
||||||
--decode http://decode1:8001 --decode http://decode2:8001 \\
|
--decode http://decode1:8001 --decode http://decode2:8001 \\
|
||||||
--policy cache_aware
|
--policy cache_aware
|
||||||
|
|
||||||
|
# PD mode with different policies for prefill and decode
|
||||||
|
python -m sglang_router.launch_router --pd-disaggregation \\
|
||||||
|
--prefill http://prefill1:8000 9000 --prefill http://prefill2:8000 none \\
|
||||||
|
--decode http://decode1:8001 --decode http://decode2:8001 \\
|
||||||
|
--prefill-policy cache_aware --decode-policy power_of_two
|
||||||
|
|
||||||
""",
|
""",
|
||||||
formatter_class=CustomHelpFormatter,
|
formatter_class=CustomHelpFormatter,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -50,6 +50,10 @@ class Router:
|
|||||||
pd_disaggregation: Enable PD (Prefill-Decode) disaggregated mode. Default: False
|
pd_disaggregation: Enable PD (Prefill-Decode) disaggregated mode. Default: False
|
||||||
prefill_urls: List of (url, bootstrap_port) tuples for prefill servers (PD mode only)
|
prefill_urls: List of (url, bootstrap_port) tuples for prefill servers (PD mode only)
|
||||||
decode_urls: List of URLs for decode servers (PD mode only)
|
decode_urls: List of URLs for decode servers (PD mode only)
|
||||||
|
prefill_policy: Specific load balancing policy for prefill nodes (PD mode only).
|
||||||
|
If not specified, uses the main policy. Default: None
|
||||||
|
decode_policy: Specific load balancing policy for decode nodes (PD mode only).
|
||||||
|
If not specified, uses the main policy. Default: None
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -79,6 +83,8 @@ class Router:
|
|||||||
pd_disaggregation: bool = False,
|
pd_disaggregation: bool = False,
|
||||||
prefill_urls: Optional[List[tuple]] = None,
|
prefill_urls: Optional[List[tuple]] = None,
|
||||||
decode_urls: Optional[List[str]] = None,
|
decode_urls: Optional[List[str]] = None,
|
||||||
|
prefill_policy: Optional[PolicyType] = None,
|
||||||
|
decode_policy: Optional[PolicyType] = None,
|
||||||
):
|
):
|
||||||
if selector is None:
|
if selector is None:
|
||||||
selector = {}
|
selector = {}
|
||||||
@@ -113,6 +119,8 @@ class Router:
|
|||||||
pd_disaggregation=pd_disaggregation,
|
pd_disaggregation=pd_disaggregation,
|
||||||
prefill_urls=prefill_urls,
|
prefill_urls=prefill_urls,
|
||||||
decode_urls=decode_urls,
|
decode_urls=decode_urls,
|
||||||
|
prefill_policy=prefill_policy,
|
||||||
|
decode_policy=decode_policy,
|
||||||
)
|
)
|
||||||
|
|
||||||
def start(self) -> None:
|
def start(self) -> None:
|
||||||
|
|||||||
@@ -46,6 +46,12 @@ pub enum RoutingMode {
|
|||||||
prefill_urls: Vec<(String, Option<u16>)>,
|
prefill_urls: Vec<(String, Option<u16>)>,
|
||||||
/// Decode worker URLs
|
/// Decode worker URLs
|
||||||
decode_urls: Vec<String>,
|
decode_urls: Vec<String>,
|
||||||
|
/// Optional separate policy for prefill workers
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
prefill_policy: Option<PolicyConfig>,
|
||||||
|
/// Optional separate policy for decode workers
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
decode_policy: Option<PolicyConfig>,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -60,9 +66,32 @@ impl RoutingMode {
|
|||||||
RoutingMode::PrefillDecode {
|
RoutingMode::PrefillDecode {
|
||||||
prefill_urls,
|
prefill_urls,
|
||||||
decode_urls,
|
decode_urls,
|
||||||
|
..
|
||||||
} => prefill_urls.len() + decode_urls.len(),
|
} => prefill_urls.len() + decode_urls.len(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Get the effective prefill policy for PD mode
|
||||||
|
/// Falls back to the main policy if no specific prefill policy is set
|
||||||
|
pub fn get_prefill_policy<'a>(&'a self, main_policy: &'a PolicyConfig) -> &'a PolicyConfig {
|
||||||
|
match self {
|
||||||
|
RoutingMode::PrefillDecode { prefill_policy, .. } => {
|
||||||
|
prefill_policy.as_ref().unwrap_or(main_policy)
|
||||||
|
}
|
||||||
|
_ => main_policy,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get the effective decode policy for PD mode
|
||||||
|
/// Falls back to the main policy if no specific decode policy is set
|
||||||
|
pub fn get_decode_policy<'a>(&'a self, main_policy: &'a PolicyConfig) -> &'a PolicyConfig {
|
||||||
|
match self {
|
||||||
|
RoutingMode::PrefillDecode { decode_policy, .. } => {
|
||||||
|
decode_policy.as_ref().unwrap_or(main_policy)
|
||||||
|
}
|
||||||
|
_ => main_policy,
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Policy configuration for routing
|
/// Policy configuration for routing
|
||||||
@@ -307,6 +336,8 @@ mod tests {
|
|||||||
let pd = RoutingMode::PrefillDecode {
|
let pd = RoutingMode::PrefillDecode {
|
||||||
prefill_urls: vec![("http://prefill1".to_string(), Some(8001))],
|
prefill_urls: vec![("http://prefill1".to_string(), Some(8001))],
|
||||||
decode_urls: vec!["http://decode1".to_string()],
|
decode_urls: vec!["http://decode1".to_string()],
|
||||||
|
prefill_policy: None,
|
||||||
|
decode_policy: None,
|
||||||
};
|
};
|
||||||
assert!(pd.is_pd_mode());
|
assert!(pd.is_pd_mode());
|
||||||
}
|
}
|
||||||
@@ -332,6 +363,8 @@ mod tests {
|
|||||||
"http://decode2".to_string(),
|
"http://decode2".to_string(),
|
||||||
"http://decode3".to_string(),
|
"http://decode3".to_string(),
|
||||||
],
|
],
|
||||||
|
prefill_policy: None,
|
||||||
|
decode_policy: None,
|
||||||
};
|
};
|
||||||
assert_eq!(pd.worker_count(), 5);
|
assert_eq!(pd.worker_count(), 5);
|
||||||
|
|
||||||
@@ -355,6 +388,8 @@ mod tests {
|
|||||||
let pd = RoutingMode::PrefillDecode {
|
let pd = RoutingMode::PrefillDecode {
|
||||||
prefill_urls: vec![("http://prefill1".to_string(), Some(8001))],
|
prefill_urls: vec![("http://prefill1".to_string(), Some(8001))],
|
||||||
decode_urls: vec!["http://decode1".to_string()],
|
decode_urls: vec!["http://decode1".to_string()],
|
||||||
|
prefill_policy: None,
|
||||||
|
decode_policy: None,
|
||||||
};
|
};
|
||||||
let json = serde_json::to_string(&pd).unwrap();
|
let json = serde_json::to_string(&pd).unwrap();
|
||||||
assert!(json.contains("\"type\":\"prefill_decode\""));
|
assert!(json.contains("\"type\":\"prefill_decode\""));
|
||||||
@@ -551,6 +586,8 @@ mod tests {
|
|||||||
mode: RoutingMode::PrefillDecode {
|
mode: RoutingMode::PrefillDecode {
|
||||||
prefill_urls: vec![],
|
prefill_urls: vec![],
|
||||||
decode_urls: vec![],
|
decode_urls: vec![],
|
||||||
|
prefill_policy: None,
|
||||||
|
decode_policy: None,
|
||||||
},
|
},
|
||||||
..Default::default()
|
..Default::default()
|
||||||
};
|
};
|
||||||
@@ -674,6 +711,8 @@ mod tests {
|
|||||||
"http://decode1:8000".to_string(),
|
"http://decode1:8000".to_string(),
|
||||||
"http://decode2:8000".to_string(),
|
"http://decode2:8000".to_string(),
|
||||||
],
|
],
|
||||||
|
prefill_policy: None,
|
||||||
|
decode_policy: None,
|
||||||
},
|
},
|
||||||
policy: PolicyConfig::PowerOfTwo {
|
policy: PolicyConfig::PowerOfTwo {
|
||||||
load_check_interval_secs: 30,
|
load_check_interval_secs: 30,
|
||||||
@@ -800,4 +839,155 @@ mod tests {
|
|||||||
Some("production".to_string())
|
Some("production".to_string())
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ============= Policy Fallback Tests =============
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_pd_policy_fallback_both_specified() {
|
||||||
|
// When both prefill and decode policies are specified, they should be used
|
||||||
|
let pd = RoutingMode::PrefillDecode {
|
||||||
|
prefill_urls: vec![("http://prefill1".to_string(), None)],
|
||||||
|
decode_urls: vec!["http://decode1".to_string()],
|
||||||
|
prefill_policy: Some(PolicyConfig::CacheAware {
|
||||||
|
cache_threshold: 0.5,
|
||||||
|
balance_abs_threshold: 32,
|
||||||
|
balance_rel_threshold: 1.1,
|
||||||
|
eviction_interval_secs: 60,
|
||||||
|
max_tree_size: 1000,
|
||||||
|
}),
|
||||||
|
decode_policy: Some(PolicyConfig::PowerOfTwo {
|
||||||
|
load_check_interval_secs: 60,
|
||||||
|
}),
|
||||||
|
};
|
||||||
|
|
||||||
|
let main_policy = PolicyConfig::Random;
|
||||||
|
|
||||||
|
// Both specific policies should be used
|
||||||
|
match pd.get_prefill_policy(&main_policy) {
|
||||||
|
PolicyConfig::CacheAware { .. } => {} // Success
|
||||||
|
_ => panic!("Expected CacheAware for prefill"),
|
||||||
|
}
|
||||||
|
|
||||||
|
match pd.get_decode_policy(&main_policy) {
|
||||||
|
PolicyConfig::PowerOfTwo { .. } => {} // Success
|
||||||
|
_ => panic!("Expected PowerOfTwo for decode"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_pd_policy_fallback_only_prefill() {
|
||||||
|
// When only prefill policy is specified, decode should use main policy
|
||||||
|
let pd = RoutingMode::PrefillDecode {
|
||||||
|
prefill_urls: vec![("http://prefill1".to_string(), None)],
|
||||||
|
decode_urls: vec!["http://decode1".to_string()],
|
||||||
|
prefill_policy: Some(PolicyConfig::CacheAware {
|
||||||
|
cache_threshold: 0.5,
|
||||||
|
balance_abs_threshold: 32,
|
||||||
|
balance_rel_threshold: 1.1,
|
||||||
|
eviction_interval_secs: 60,
|
||||||
|
max_tree_size: 1000,
|
||||||
|
}),
|
||||||
|
decode_policy: None,
|
||||||
|
};
|
||||||
|
|
||||||
|
let main_policy = PolicyConfig::RoundRobin;
|
||||||
|
|
||||||
|
// Prefill should use specific policy
|
||||||
|
match pd.get_prefill_policy(&main_policy) {
|
||||||
|
PolicyConfig::CacheAware { .. } => {} // Success
|
||||||
|
_ => panic!("Expected CacheAware for prefill"),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Decode should fall back to main policy
|
||||||
|
match pd.get_decode_policy(&main_policy) {
|
||||||
|
PolicyConfig::RoundRobin => {} // Success
|
||||||
|
_ => panic!("Expected RoundRobin for decode"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_pd_policy_fallback_only_decode() {
|
||||||
|
// When only decode policy is specified, prefill should use main policy
|
||||||
|
let pd = RoutingMode::PrefillDecode {
|
||||||
|
prefill_urls: vec![("http://prefill1".to_string(), None)],
|
||||||
|
decode_urls: vec!["http://decode1".to_string()],
|
||||||
|
prefill_policy: None,
|
||||||
|
decode_policy: Some(PolicyConfig::PowerOfTwo {
|
||||||
|
load_check_interval_secs: 60,
|
||||||
|
}),
|
||||||
|
};
|
||||||
|
|
||||||
|
let main_policy = PolicyConfig::Random;
|
||||||
|
|
||||||
|
// Prefill should fall back to main policy
|
||||||
|
match pd.get_prefill_policy(&main_policy) {
|
||||||
|
PolicyConfig::Random => {} // Success
|
||||||
|
_ => panic!("Expected Random for prefill"),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Decode should use specific policy
|
||||||
|
match pd.get_decode_policy(&main_policy) {
|
||||||
|
PolicyConfig::PowerOfTwo { .. } => {} // Success
|
||||||
|
_ => panic!("Expected PowerOfTwo for decode"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_pd_policy_fallback_none_specified() {
|
||||||
|
// When no specific policies are specified, both should use main policy
|
||||||
|
let pd = RoutingMode::PrefillDecode {
|
||||||
|
prefill_urls: vec![("http://prefill1".to_string(), None)],
|
||||||
|
decode_urls: vec!["http://decode1".to_string()],
|
||||||
|
prefill_policy: None,
|
||||||
|
decode_policy: None,
|
||||||
|
};
|
||||||
|
|
||||||
|
let main_policy = PolicyConfig::CacheAware {
|
||||||
|
cache_threshold: 0.7,
|
||||||
|
balance_abs_threshold: 20,
|
||||||
|
balance_rel_threshold: 1.5,
|
||||||
|
eviction_interval_secs: 300,
|
||||||
|
max_tree_size: 2000,
|
||||||
|
};
|
||||||
|
|
||||||
|
// Both should fall back to main policy
|
||||||
|
match pd.get_prefill_policy(&main_policy) {
|
||||||
|
PolicyConfig::CacheAware {
|
||||||
|
cache_threshold, ..
|
||||||
|
} => {
|
||||||
|
assert!((cache_threshold - 0.7).abs() < 0.0001);
|
||||||
|
}
|
||||||
|
_ => panic!("Expected CacheAware for prefill"),
|
||||||
|
}
|
||||||
|
|
||||||
|
match pd.get_decode_policy(&main_policy) {
|
||||||
|
PolicyConfig::CacheAware {
|
||||||
|
cache_threshold, ..
|
||||||
|
} => {
|
||||||
|
assert!((cache_threshold - 0.7).abs() < 0.0001);
|
||||||
|
}
|
||||||
|
_ => panic!("Expected CacheAware for decode"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_regular_mode_policy_fallback() {
|
||||||
|
// For regular mode, the helper methods should just return the main policy
|
||||||
|
let regular = RoutingMode::Regular {
|
||||||
|
worker_urls: vec!["http://worker1".to_string()],
|
||||||
|
};
|
||||||
|
|
||||||
|
let main_policy = PolicyConfig::RoundRobin;
|
||||||
|
|
||||||
|
// Both methods should return main policy for regular mode
|
||||||
|
match regular.get_prefill_policy(&main_policy) {
|
||||||
|
PolicyConfig::RoundRobin => {} // Success
|
||||||
|
_ => panic!("Expected RoundRobin for regular mode"),
|
||||||
|
}
|
||||||
|
|
||||||
|
match regular.get_decode_policy(&main_policy) {
|
||||||
|
PolicyConfig::RoundRobin => {} // Success
|
||||||
|
_ => panic!("Expected RoundRobin for regular mode"),
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -41,6 +41,8 @@ impl ConfigValidator {
|
|||||||
RoutingMode::PrefillDecode {
|
RoutingMode::PrefillDecode {
|
||||||
prefill_urls,
|
prefill_urls,
|
||||||
decode_urls,
|
decode_urls,
|
||||||
|
prefill_policy,
|
||||||
|
decode_policy,
|
||||||
} => {
|
} => {
|
||||||
// Only require URLs if service discovery is disabled
|
// Only require URLs if service discovery is disabled
|
||||||
if !has_service_discovery {
|
if !has_service_discovery {
|
||||||
@@ -78,6 +80,14 @@ impl ConfigValidator {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Validate optional prefill and decode policies
|
||||||
|
if let Some(p_policy) = prefill_policy {
|
||||||
|
Self::validate_policy(p_policy)?;
|
||||||
|
}
|
||||||
|
if let Some(d_policy) = decode_policy {
|
||||||
|
Self::validate_policy(d_policy)?;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
@@ -272,6 +282,35 @@ impl ConfigValidator {
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// For PD mode, validate that policies have sufficient workers
|
||||||
|
if let RoutingMode::PrefillDecode {
|
||||||
|
prefill_urls,
|
||||||
|
decode_urls,
|
||||||
|
prefill_policy,
|
||||||
|
decode_policy,
|
||||||
|
} = &config.mode
|
||||||
|
{
|
||||||
|
// Check power-of-two for prefill
|
||||||
|
if let Some(PolicyConfig::PowerOfTwo { .. }) = prefill_policy {
|
||||||
|
if prefill_urls.len() < 2 {
|
||||||
|
return Err(ConfigError::IncompatibleConfig {
|
||||||
|
reason: "Power-of-two policy for prefill requires at least 2 prefill workers".to_string(),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check power-of-two for decode
|
||||||
|
if let Some(PolicyConfig::PowerOfTwo { .. }) = decode_policy {
|
||||||
|
if decode_urls.len() < 2 {
|
||||||
|
return Err(ConfigError::IncompatibleConfig {
|
||||||
|
reason:
|
||||||
|
"Power-of-two policy for decode requires at least 2 decode workers"
|
||||||
|
.to_string(),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
@@ -430,6 +469,8 @@ mod tests {
|
|||||||
RoutingMode::PrefillDecode {
|
RoutingMode::PrefillDecode {
|
||||||
prefill_urls: vec![("http://prefill:8000".to_string(), Some(8081))],
|
prefill_urls: vec![("http://prefill:8000".to_string(), Some(8081))],
|
||||||
decode_urls: vec!["http://decode:8000".to_string()],
|
decode_urls: vec!["http://decode:8000".to_string()],
|
||||||
|
prefill_policy: None,
|
||||||
|
decode_policy: None,
|
||||||
},
|
},
|
||||||
PolicyConfig::Random,
|
PolicyConfig::Random,
|
||||||
);
|
);
|
||||||
@@ -444,6 +485,8 @@ mod tests {
|
|||||||
RoutingMode::PrefillDecode {
|
RoutingMode::PrefillDecode {
|
||||||
prefill_urls: vec![("http://prefill:8000".to_string(), None)],
|
prefill_urls: vec![("http://prefill:8000".to_string(), None)],
|
||||||
decode_urls: vec!["http://decode:8000".to_string()],
|
decode_urls: vec!["http://decode:8000".to_string()],
|
||||||
|
prefill_policy: None,
|
||||||
|
decode_policy: None,
|
||||||
},
|
},
|
||||||
PolicyConfig::RoundRobin,
|
PolicyConfig::RoundRobin,
|
||||||
);
|
);
|
||||||
@@ -459,6 +502,8 @@ mod tests {
|
|||||||
RoutingMode::PrefillDecode {
|
RoutingMode::PrefillDecode {
|
||||||
prefill_urls: vec![("http://prefill:8000".to_string(), None)],
|
prefill_urls: vec![("http://prefill:8000".to_string(), None)],
|
||||||
decode_urls: vec!["http://decode:8000".to_string()],
|
decode_urls: vec!["http://decode:8000".to_string()],
|
||||||
|
prefill_policy: None,
|
||||||
|
decode_policy: None,
|
||||||
},
|
},
|
||||||
PolicyConfig::CacheAware {
|
PolicyConfig::CacheAware {
|
||||||
cache_threshold: 0.5,
|
cache_threshold: 0.5,
|
||||||
@@ -491,4 +536,60 @@ mod tests {
|
|||||||
let result = ConfigValidator::validate(&config);
|
let result = ConfigValidator::validate(&config);
|
||||||
assert!(result.is_ok());
|
assert!(result.is_ok());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_validate_pd_mode_with_separate_policies() {
|
||||||
|
// Test PD mode with different policies for prefill and decode
|
||||||
|
let config = RouterConfig::new(
|
||||||
|
RoutingMode::PrefillDecode {
|
||||||
|
prefill_urls: vec![
|
||||||
|
("http://prefill1:8000".to_string(), None),
|
||||||
|
("http://prefill2:8000".to_string(), None),
|
||||||
|
],
|
||||||
|
decode_urls: vec![
|
||||||
|
"http://decode1:8000".to_string(),
|
||||||
|
"http://decode2:8000".to_string(),
|
||||||
|
],
|
||||||
|
prefill_policy: Some(PolicyConfig::CacheAware {
|
||||||
|
cache_threshold: 0.5,
|
||||||
|
balance_abs_threshold: 32,
|
||||||
|
balance_rel_threshold: 1.1,
|
||||||
|
eviction_interval_secs: 60,
|
||||||
|
max_tree_size: 1000,
|
||||||
|
}),
|
||||||
|
decode_policy: Some(PolicyConfig::PowerOfTwo {
|
||||||
|
load_check_interval_secs: 60,
|
||||||
|
}),
|
||||||
|
},
|
||||||
|
PolicyConfig::Random, // Main policy as fallback
|
||||||
|
);
|
||||||
|
|
||||||
|
let result = ConfigValidator::validate(&config);
|
||||||
|
assert!(result.is_ok());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_validate_pd_mode_power_of_two_insufficient_workers() {
|
||||||
|
// Test that power-of-two policy requires at least 2 workers
|
||||||
|
let config = RouterConfig::new(
|
||||||
|
RoutingMode::PrefillDecode {
|
||||||
|
prefill_urls: vec![("http://prefill1:8000".to_string(), None)], // Only 1 prefill
|
||||||
|
decode_urls: vec![
|
||||||
|
"http://decode1:8000".to_string(),
|
||||||
|
"http://decode2:8000".to_string(),
|
||||||
|
],
|
||||||
|
prefill_policy: Some(PolicyConfig::PowerOfTwo {
|
||||||
|
load_check_interval_secs: 60,
|
||||||
|
}), // Requires 2+ workers
|
||||||
|
decode_policy: None,
|
||||||
|
},
|
||||||
|
PolicyConfig::Random,
|
||||||
|
);
|
||||||
|
|
||||||
|
let result = ConfigValidator::validate(&config);
|
||||||
|
assert!(result.is_err());
|
||||||
|
if let Err(e) = result {
|
||||||
|
assert!(e.to_string().contains("prefill requires at least 2"));
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -54,6 +54,8 @@ struct Router {
|
|||||||
// PD-specific fields (only used when pd_disaggregation is true)
|
// PD-specific fields (only used when pd_disaggregation is true)
|
||||||
prefill_urls: Option<Vec<(String, Option<u16>)>>,
|
prefill_urls: Option<Vec<(String, Option<u16>)>>,
|
||||||
decode_urls: Option<Vec<String>>,
|
decode_urls: Option<Vec<String>>,
|
||||||
|
prefill_policy: Option<PolicyType>,
|
||||||
|
decode_policy: Option<PolicyType>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Router {
|
impl Router {
|
||||||
@@ -63,20 +65,9 @@ impl Router {
|
|||||||
DiscoveryConfig, MetricsConfig, PolicyConfig as ConfigPolicyConfig, RoutingMode,
|
DiscoveryConfig, MetricsConfig, PolicyConfig as ConfigPolicyConfig, RoutingMode,
|
||||||
};
|
};
|
||||||
|
|
||||||
// Determine routing mode
|
// Convert policy helper function
|
||||||
let mode = if self.pd_disaggregation {
|
let convert_policy = |policy: &PolicyType| -> ConfigPolicyConfig {
|
||||||
RoutingMode::PrefillDecode {
|
match policy {
|
||||||
prefill_urls: self.prefill_urls.clone().unwrap_or_default(),
|
|
||||||
decode_urls: self.decode_urls.clone().unwrap_or_default(),
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
RoutingMode::Regular {
|
|
||||||
worker_urls: self.worker_urls.clone(),
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
// Convert policy
|
|
||||||
let policy = match self.policy {
|
|
||||||
PolicyType::Random => ConfigPolicyConfig::Random,
|
PolicyType::Random => ConfigPolicyConfig::Random,
|
||||||
PolicyType::RoundRobin => ConfigPolicyConfig::RoundRobin,
|
PolicyType::RoundRobin => ConfigPolicyConfig::RoundRobin,
|
||||||
PolicyType::CacheAware => ConfigPolicyConfig::CacheAware {
|
PolicyType::CacheAware => ConfigPolicyConfig::CacheAware {
|
||||||
@@ -89,8 +80,26 @@ impl Router {
|
|||||||
PolicyType::PowerOfTwo => ConfigPolicyConfig::PowerOfTwo {
|
PolicyType::PowerOfTwo => ConfigPolicyConfig::PowerOfTwo {
|
||||||
load_check_interval_secs: 5, // Default value
|
load_check_interval_secs: 5, // Default value
|
||||||
},
|
},
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Determine routing mode
|
||||||
|
let mode = if self.pd_disaggregation {
|
||||||
|
RoutingMode::PrefillDecode {
|
||||||
|
prefill_urls: self.prefill_urls.clone().unwrap_or_default(),
|
||||||
|
decode_urls: self.decode_urls.clone().unwrap_or_default(),
|
||||||
|
prefill_policy: self.prefill_policy.as_ref().map(convert_policy),
|
||||||
|
decode_policy: self.decode_policy.as_ref().map(convert_policy),
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
RoutingMode::Regular {
|
||||||
|
worker_urls: self.worker_urls.clone(),
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Convert main policy
|
||||||
|
let policy = convert_policy(&self.policy);
|
||||||
|
|
||||||
// Service discovery configuration
|
// Service discovery configuration
|
||||||
let discovery = if self.service_discovery {
|
let discovery = if self.service_discovery {
|
||||||
Some(DiscoveryConfig {
|
Some(DiscoveryConfig {
|
||||||
@@ -163,7 +172,9 @@ impl Router {
|
|||||||
request_timeout_secs = 600, // Add configurable request timeout
|
request_timeout_secs = 600, // Add configurable request timeout
|
||||||
pd_disaggregation = false, // New flag for PD mode
|
pd_disaggregation = false, // New flag for PD mode
|
||||||
prefill_urls = None,
|
prefill_urls = None,
|
||||||
decode_urls = None
|
decode_urls = None,
|
||||||
|
prefill_policy = None,
|
||||||
|
decode_policy = None
|
||||||
))]
|
))]
|
||||||
fn new(
|
fn new(
|
||||||
worker_urls: Vec<String>,
|
worker_urls: Vec<String>,
|
||||||
@@ -193,6 +204,8 @@ impl Router {
|
|||||||
pd_disaggregation: bool,
|
pd_disaggregation: bool,
|
||||||
prefill_urls: Option<Vec<(String, Option<u16>)>>,
|
prefill_urls: Option<Vec<(String, Option<u16>)>>,
|
||||||
decode_urls: Option<Vec<String>>,
|
decode_urls: Option<Vec<String>>,
|
||||||
|
prefill_policy: Option<PolicyType>,
|
||||||
|
decode_policy: Option<PolicyType>,
|
||||||
) -> PyResult<Self> {
|
) -> PyResult<Self> {
|
||||||
Ok(Router {
|
Ok(Router {
|
||||||
host,
|
host,
|
||||||
@@ -222,6 +235,8 @@ impl Router {
|
|||||||
pd_disaggregation,
|
pd_disaggregation,
|
||||||
prefill_urls,
|
prefill_urls,
|
||||||
decode_urls,
|
decode_urls,
|
||||||
|
prefill_policy,
|
||||||
|
decode_policy,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -254,7 +254,11 @@ impl LoadBalancingPolicy for CacheAwarePolicy {
|
|||||||
decode_workers: &[Box<dyn Worker>],
|
decode_workers: &[Box<dyn Worker>],
|
||||||
request_text: Option<&str>,
|
request_text: Option<&str>,
|
||||||
) -> Option<(usize, usize)> {
|
) -> Option<(usize, usize)> {
|
||||||
// In PD mode:
|
// DEPRECATED: This method is no longer used when separate policies are configured.
|
||||||
|
// The PD router now uses separate policies for prefill and decode selection.
|
||||||
|
// This implementation remains for backward compatibility when a single policy is used.
|
||||||
|
|
||||||
|
// In PD mode with single policy:
|
||||||
// - Prefill: Use cache-aware routing for better cache utilization
|
// - Prefill: Use cache-aware routing for better cache utilization
|
||||||
// - Decode: Use least-load routing for better load distribution
|
// - Decode: Use least-load routing for better load distribution
|
||||||
|
|
||||||
|
|||||||
@@ -17,7 +17,16 @@ impl RouterFactory {
|
|||||||
RoutingMode::PrefillDecode {
|
RoutingMode::PrefillDecode {
|
||||||
prefill_urls,
|
prefill_urls,
|
||||||
decode_urls,
|
decode_urls,
|
||||||
} => Self::create_pd_router(prefill_urls, decode_urls, &config.policy, config),
|
prefill_policy,
|
||||||
|
decode_policy,
|
||||||
|
} => Self::create_pd_router(
|
||||||
|
prefill_urls,
|
||||||
|
decode_urls,
|
||||||
|
prefill_policy.as_ref(),
|
||||||
|
decode_policy.as_ref(),
|
||||||
|
&config.policy,
|
||||||
|
config,
|
||||||
|
),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -45,18 +54,23 @@ impl RouterFactory {
|
|||||||
fn create_pd_router(
|
fn create_pd_router(
|
||||||
prefill_urls: &[(String, Option<u16>)],
|
prefill_urls: &[(String, Option<u16>)],
|
||||||
decode_urls: &[String],
|
decode_urls: &[String],
|
||||||
policy_config: &PolicyConfig,
|
prefill_policy_config: Option<&PolicyConfig>,
|
||||||
|
decode_policy_config: Option<&PolicyConfig>,
|
||||||
|
main_policy_config: &PolicyConfig,
|
||||||
router_config: &RouterConfig,
|
router_config: &RouterConfig,
|
||||||
) -> Result<Box<dyn RouterTrait>, String> {
|
) -> Result<Box<dyn RouterTrait>, String> {
|
||||||
// Create policy directly from PolicyConfig
|
// Create policies - use specific policies if provided, otherwise fall back to main policy
|
||||||
// All policies now support PD mode through the select_worker_pair method
|
let prefill_policy =
|
||||||
let policy = PolicyFactory::create_from_config(policy_config);
|
PolicyFactory::create_from_config(prefill_policy_config.unwrap_or(main_policy_config));
|
||||||
|
let decode_policy =
|
||||||
|
PolicyFactory::create_from_config(decode_policy_config.unwrap_or(main_policy_config));
|
||||||
|
|
||||||
// Create PD router with injected policy
|
// Create PD router with separate policies
|
||||||
let router = PDRouter::new(
|
let router = PDRouter::new(
|
||||||
prefill_urls.to_vec(),
|
prefill_urls.to_vec(),
|
||||||
decode_urls.to_vec(),
|
decode_urls.to_vec(),
|
||||||
policy,
|
prefill_policy,
|
||||||
|
decode_policy,
|
||||||
router_config.worker_startup_timeout_secs,
|
router_config.worker_startup_timeout_secs,
|
||||||
router_config.worker_startup_check_interval_secs,
|
router_config.worker_startup_check_interval_secs,
|
||||||
)?;
|
)?;
|
||||||
|
|||||||
@@ -22,8 +22,10 @@ use uuid::Uuid;
|
|||||||
pub struct PDRouter {
|
pub struct PDRouter {
|
||||||
pub prefill_workers: Arc<RwLock<Vec<Box<dyn Worker>>>>,
|
pub prefill_workers: Arc<RwLock<Vec<Box<dyn Worker>>>>,
|
||||||
pub decode_workers: Arc<RwLock<Vec<Box<dyn Worker>>>>,
|
pub decode_workers: Arc<RwLock<Vec<Box<dyn Worker>>>>,
|
||||||
pub policy: Arc<dyn LoadBalancingPolicy>,
|
pub prefill_policy: Arc<dyn LoadBalancingPolicy>,
|
||||||
|
pub decode_policy: Arc<dyn LoadBalancingPolicy>,
|
||||||
pub prefill_tree: Option<Arc<Mutex<Tree>>>,
|
pub prefill_tree: Option<Arc<Mutex<Tree>>>,
|
||||||
|
pub decode_tree: Option<Arc<Mutex<Tree>>>,
|
||||||
pub timeout_secs: u64,
|
pub timeout_secs: u64,
|
||||||
pub interval_secs: u64,
|
pub interval_secs: u64,
|
||||||
pub worker_loads: Arc<tokio::sync::watch::Receiver<HashMap<String, isize>>>,
|
pub worker_loads: Arc<tokio::sync::watch::Receiver<HashMap<String, isize>>>,
|
||||||
@@ -66,7 +68,7 @@ impl PDRouter {
|
|||||||
|
|
||||||
workers.push(worker);
|
workers.push(worker);
|
||||||
|
|
||||||
// Add to cache tree if using cache-aware policy
|
// Add to cache tree if using cache-aware policy for prefill
|
||||||
if let Some(ref tree) = self.prefill_tree {
|
if let Some(ref tree) = self.prefill_tree {
|
||||||
tree.lock().unwrap().insert("", &url);
|
tree.lock().unwrap().insert("", &url);
|
||||||
}
|
}
|
||||||
@@ -102,6 +104,11 @@ impl PDRouter {
|
|||||||
|
|
||||||
workers.push(worker);
|
workers.push(worker);
|
||||||
|
|
||||||
|
// Add to cache tree if using cache-aware policy for decode
|
||||||
|
if let Some(ref tree) = self.decode_tree {
|
||||||
|
tree.lock().unwrap().insert("", &url);
|
||||||
|
}
|
||||||
|
|
||||||
info!("Added decode server: {}", url);
|
info!("Added decode server: {}", url);
|
||||||
Ok(format!("Successfully added decode server: {}", url))
|
Ok(format!("Successfully added decode server: {}", url))
|
||||||
}
|
}
|
||||||
@@ -126,12 +133,7 @@ impl PDRouter {
|
|||||||
|
|
||||||
// Remove from cache tree if using cache-aware policy
|
// Remove from cache tree if using cache-aware policy
|
||||||
if let Some(ref tree) = self.prefill_tree {
|
if let Some(ref tree) = self.prefill_tree {
|
||||||
// Note: Tree doesn't have a remove method, so we rebuild it
|
tree.lock().unwrap().remove_tenant(url);
|
||||||
let mut tree_guard = tree.lock().unwrap();
|
|
||||||
*tree_guard = Tree::new();
|
|
||||||
for worker in workers.iter() {
|
|
||||||
tree_guard.insert("", worker.url());
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
info!("Removed prefill server: {}", url);
|
info!("Removed prefill server: {}", url);
|
||||||
@@ -156,6 +158,11 @@ impl PDRouter {
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Remove from the cache tree if using cache-aware policy for decode
|
||||||
|
if let Some(ref tree) = self.decode_tree {
|
||||||
|
tree.lock().unwrap().remove_tenant(url);
|
||||||
|
}
|
||||||
|
|
||||||
info!("Removed decode server: {}", url);
|
info!("Removed decode server: {}", url);
|
||||||
Ok(format!("Successfully removed decode server: {}", url))
|
Ok(format!("Successfully removed decode server: {}", url))
|
||||||
}
|
}
|
||||||
@@ -163,7 +170,8 @@ impl PDRouter {
|
|||||||
pub fn new(
|
pub fn new(
|
||||||
prefill_urls: Vec<(String, Option<u16>)>,
|
prefill_urls: Vec<(String, Option<u16>)>,
|
||||||
decode_urls: Vec<String>,
|
decode_urls: Vec<String>,
|
||||||
policy: Arc<dyn LoadBalancingPolicy>,
|
prefill_policy: Arc<dyn LoadBalancingPolicy>,
|
||||||
|
decode_policy: Arc<dyn LoadBalancingPolicy>,
|
||||||
timeout_secs: u64,
|
timeout_secs: u64,
|
||||||
interval_secs: u64,
|
interval_secs: u64,
|
||||||
) -> Result<Self, String> {
|
) -> Result<Self, String> {
|
||||||
@@ -192,10 +200,10 @@ impl PDRouter {
|
|||||||
)?;
|
)?;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Initialize cache-aware components if needed
|
// Initialize cache-aware components if needed for prefill policy
|
||||||
let prefill_tree = if policy.name() == "cache_aware" {
|
let prefill_tree = if prefill_policy.name() == "cache_aware" {
|
||||||
// Initialize the policy's internal tree with prefill workers
|
// Initialize the policy's internal tree with prefill workers
|
||||||
if let Some(cache_policy) = policy
|
if let Some(cache_policy) = prefill_policy
|
||||||
.as_any()
|
.as_any()
|
||||||
.downcast_ref::<crate::policies::CacheAwarePolicy>()
|
.downcast_ref::<crate::policies::CacheAwarePolicy>()
|
||||||
{
|
{
|
||||||
@@ -212,6 +220,26 @@ impl PDRouter {
|
|||||||
None
|
None
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Initialize cache-aware components if needed for decode policy
|
||||||
|
let decode_tree = if decode_policy.name() == "cache_aware" {
|
||||||
|
// Initialize the policy's internal tree with decode workers
|
||||||
|
if let Some(cache_policy) = decode_policy
|
||||||
|
.as_any()
|
||||||
|
.downcast_ref::<crate::policies::CacheAwarePolicy>()
|
||||||
|
{
|
||||||
|
cache_policy.init_workers(&decode_workers);
|
||||||
|
}
|
||||||
|
|
||||||
|
let tree = Arc::new(Mutex::new(Tree::new()));
|
||||||
|
// Initialize tree with decode workers
|
||||||
|
for worker in &decode_workers {
|
||||||
|
tree.lock().unwrap().insert("", worker.url());
|
||||||
|
}
|
||||||
|
Some(tree)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
|
||||||
// Set up background load monitoring for power-of-two selection
|
// Set up background load monitoring for power-of-two selection
|
||||||
let (tx, rx) = tokio::sync::watch::channel(HashMap::new());
|
let (tx, rx) = tokio::sync::watch::channel(HashMap::new());
|
||||||
let worker_loads = Arc::new(rx);
|
let worker_loads = Arc::new(rx);
|
||||||
@@ -222,11 +250,13 @@ impl PDRouter {
|
|||||||
.build()
|
.build()
|
||||||
.map_err(|e| format!("Failed to create HTTP client: {}", e))?;
|
.map_err(|e| format!("Failed to create HTTP client: {}", e))?;
|
||||||
|
|
||||||
let load_monitor_handle = if policy.name() == "power_of_two" {
|
let load_monitor_handle =
|
||||||
|
if prefill_policy.name() == "power_of_two" || decode_policy.name() == "power_of_two" {
|
||||||
let monitor_urls = all_urls.clone();
|
let monitor_urls = all_urls.clone();
|
||||||
let monitor_interval = interval_secs;
|
let monitor_interval = interval_secs;
|
||||||
let monitor_client = http_client.clone();
|
let monitor_client = http_client.clone();
|
||||||
let policy_clone = Arc::clone(&policy);
|
let prefill_policy_clone = Arc::clone(&prefill_policy);
|
||||||
|
let decode_policy_clone = Arc::clone(&decode_policy);
|
||||||
|
|
||||||
Some(Arc::new(tokio::spawn(async move {
|
Some(Arc::new(tokio::spawn(async move {
|
||||||
Self::monitor_worker_loads_with_client(
|
Self::monitor_worker_loads_with_client(
|
||||||
@@ -234,7 +264,8 @@ impl PDRouter {
|
|||||||
tx,
|
tx,
|
||||||
monitor_interval,
|
monitor_interval,
|
||||||
monitor_client,
|
monitor_client,
|
||||||
policy_clone,
|
prefill_policy_clone,
|
||||||
|
decode_policy_clone,
|
||||||
)
|
)
|
||||||
.await;
|
.await;
|
||||||
})))
|
})))
|
||||||
@@ -254,8 +285,10 @@ impl PDRouter {
|
|||||||
Ok(PDRouter {
|
Ok(PDRouter {
|
||||||
prefill_workers,
|
prefill_workers,
|
||||||
decode_workers,
|
decode_workers,
|
||||||
policy,
|
prefill_policy,
|
||||||
|
decode_policy,
|
||||||
prefill_tree,
|
prefill_tree,
|
||||||
|
decode_tree,
|
||||||
timeout_secs,
|
timeout_secs,
|
||||||
interval_secs,
|
interval_secs,
|
||||||
worker_loads,
|
worker_loads,
|
||||||
@@ -736,19 +769,22 @@ impl PDRouter {
|
|||||||
return Err("No decode workers available. Please check if decode servers are configured and healthy.".to_string());
|
return Err("No decode workers available. Please check if decode servers are configured and healthy.".to_string());
|
||||||
}
|
}
|
||||||
|
|
||||||
// Use the policy to select worker pair
|
// Select prefill worker using prefill policy
|
||||||
match self
|
let prefill_idx = self
|
||||||
.policy
|
.prefill_policy
|
||||||
.select_worker_pair(&prefill_workers, &decode_workers, request_text)
|
.select_worker(&prefill_workers, request_text)
|
||||||
{
|
.ok_or("Failed to select prefill worker")?;
|
||||||
Some((prefill_idx, decode_idx)) => {
|
|
||||||
|
// Select decode worker using decode policy
|
||||||
|
let decode_idx = self
|
||||||
|
.decode_policy
|
||||||
|
.select_worker(&decode_workers, request_text)
|
||||||
|
.ok_or("Failed to select decode worker")?;
|
||||||
|
|
||||||
let prefill = prefill_workers[prefill_idx].clone_worker();
|
let prefill = prefill_workers[prefill_idx].clone_worker();
|
||||||
let decode = decode_workers[decode_idx].clone_worker();
|
let decode = decode_workers[decode_idx].clone_worker();
|
||||||
Ok((prefill, decode))
|
Ok((prefill, decode))
|
||||||
}
|
}
|
||||||
None => Err("Failed to select worker pair".to_string()),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Background task to monitor worker loads with shared client
|
// Background task to monitor worker loads with shared client
|
||||||
async fn monitor_worker_loads_with_client(
|
async fn monitor_worker_loads_with_client(
|
||||||
@@ -756,7 +792,8 @@ impl PDRouter {
|
|||||||
tx: tokio::sync::watch::Sender<HashMap<String, isize>>,
|
tx: tokio::sync::watch::Sender<HashMap<String, isize>>,
|
||||||
interval_secs: u64,
|
interval_secs: u64,
|
||||||
client: reqwest::Client,
|
client: reqwest::Client,
|
||||||
policy: Arc<dyn LoadBalancingPolicy>,
|
prefill_policy: Arc<dyn LoadBalancingPolicy>,
|
||||||
|
decode_policy: Arc<dyn LoadBalancingPolicy>,
|
||||||
) {
|
) {
|
||||||
loop {
|
loop {
|
||||||
let mut loads = HashMap::new();
|
let mut loads = HashMap::new();
|
||||||
@@ -781,8 +818,9 @@ impl PDRouter {
|
|||||||
|
|
||||||
debug!("Worker loads updated: {:?}", loads);
|
debug!("Worker loads updated: {:?}", loads);
|
||||||
|
|
||||||
// Update the policy with current loads
|
// Update both policies with current loads
|
||||||
policy.update_loads(&loads);
|
prefill_policy.update_loads(&loads);
|
||||||
|
decode_policy.update_loads(&loads);
|
||||||
|
|
||||||
// Check if receiver is still active
|
// Check if receiver is still active
|
||||||
if tx.send(loads).is_err() {
|
if tx.send(loads).is_err() {
|
||||||
@@ -1463,13 +1501,16 @@ mod tests {
|
|||||||
use actix_web::test::TestRequest;
|
use actix_web::test::TestRequest;
|
||||||
|
|
||||||
fn create_test_pd_router() -> PDRouter {
|
fn create_test_pd_router() -> PDRouter {
|
||||||
let policy = Arc::new(RandomPolicy::new());
|
let prefill_policy = Arc::new(RandomPolicy::new());
|
||||||
|
let decode_policy = Arc::new(RandomPolicy::new());
|
||||||
|
|
||||||
PDRouter {
|
PDRouter {
|
||||||
prefill_workers: Arc::new(RwLock::new(vec![])),
|
prefill_workers: Arc::new(RwLock::new(vec![])),
|
||||||
decode_workers: Arc::new(RwLock::new(vec![])),
|
decode_workers: Arc::new(RwLock::new(vec![])),
|
||||||
policy,
|
prefill_policy,
|
||||||
|
decode_policy,
|
||||||
prefill_tree: None,
|
prefill_tree: None,
|
||||||
|
decode_tree: None,
|
||||||
timeout_secs: 5,
|
timeout_secs: 5,
|
||||||
interval_secs: 1,
|
interval_secs: 1,
|
||||||
worker_loads: Arc::new(tokio::sync::watch::channel(HashMap::new()).1),
|
worker_loads: Arc::new(tokio::sync::watch::channel(HashMap::new()).1),
|
||||||
@@ -1608,9 +1649,9 @@ mod tests {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_cache_tree_operations() {
|
async fn test_cache_tree_operations() {
|
||||||
let policy = Arc::new(CacheAwarePolicy::new());
|
let cache_policy = Arc::new(CacheAwarePolicy::new());
|
||||||
let mut router = create_test_pd_router();
|
let mut router = create_test_pd_router();
|
||||||
router.policy = policy;
|
router.prefill_policy = cache_policy;
|
||||||
|
|
||||||
// Initialize cache tree
|
// Initialize cache tree
|
||||||
let tree = Arc::new(Mutex::new(Tree::new()));
|
let tree = Arc::new(Mutex::new(Tree::new()));
|
||||||
@@ -1638,9 +1679,9 @@ mod tests {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_cache_tree_rebuild_on_remove() {
|
async fn test_cache_tree_rebuild_on_remove() {
|
||||||
let policy = Arc::new(CacheAwarePolicy::new());
|
let cache_policy = Arc::new(CacheAwarePolicy::new());
|
||||||
let mut router = create_test_pd_router();
|
let mut router = create_test_pd_router();
|
||||||
router.policy = policy;
|
router.prefill_policy = cache_policy;
|
||||||
|
|
||||||
// Initialize cache tree
|
// Initialize cache tree
|
||||||
let tree = Arc::new(Mutex::new(Tree::new()));
|
let tree = Arc::new(Mutex::new(Tree::new()));
|
||||||
@@ -1880,9 +1921,10 @@ mod tests {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_load_monitor_updates() {
|
async fn test_load_monitor_updates() {
|
||||||
let policy = Arc::new(crate::policies::PowerOfTwoPolicy::new());
|
let power_of_two_policy = Arc::new(crate::policies::PowerOfTwoPolicy::new());
|
||||||
let mut router = create_test_pd_router();
|
let mut router = create_test_pd_router();
|
||||||
router.policy = policy;
|
router.prefill_policy = power_of_two_policy.clone();
|
||||||
|
router.decode_policy = power_of_two_policy;
|
||||||
|
|
||||||
// Create load channel
|
// Create load channel
|
||||||
let (tx, rx) = tokio::sync::watch::channel(HashMap::new());
|
let (tx, rx) = tokio::sync::watch::channel(HashMap::new());
|
||||||
|
|||||||
@@ -122,6 +122,8 @@ mod test_pd_routing {
|
|||||||
"http://decode1:8080".to_string(),
|
"http://decode1:8080".to_string(),
|
||||||
"http://decode2:8080".to_string(),
|
"http://decode2:8080".to_string(),
|
||||||
],
|
],
|
||||||
|
prefill_policy: None,
|
||||||
|
decode_policy: None,
|
||||||
},
|
},
|
||||||
PolicyConfig::Random,
|
PolicyConfig::Random,
|
||||||
),
|
),
|
||||||
@@ -129,6 +131,8 @@ mod test_pd_routing {
|
|||||||
RoutingMode::PrefillDecode {
|
RoutingMode::PrefillDecode {
|
||||||
prefill_urls: vec![("http://prefill:8080".to_string(), Some(9000))],
|
prefill_urls: vec![("http://prefill:8080".to_string(), Some(9000))],
|
||||||
decode_urls: vec!["http://decode:8080".to_string()],
|
decode_urls: vec!["http://decode:8080".to_string()],
|
||||||
|
prefill_policy: None,
|
||||||
|
decode_policy: None,
|
||||||
},
|
},
|
||||||
PolicyConfig::PowerOfTwo {
|
PolicyConfig::PowerOfTwo {
|
||||||
load_check_interval_secs: 5,
|
load_check_interval_secs: 5,
|
||||||
@@ -142,6 +146,8 @@ mod test_pd_routing {
|
|||||||
("http://p3:8080".to_string(), Some(9002)),
|
("http://p3:8080".to_string(), Some(9002)),
|
||||||
],
|
],
|
||||||
decode_urls: vec!["http://d1:8080".to_string(), "http://d2:8080".to_string()],
|
decode_urls: vec!["http://d1:8080".to_string(), "http://d2:8080".to_string()],
|
||||||
|
prefill_policy: None,
|
||||||
|
decode_policy: None,
|
||||||
},
|
},
|
||||||
PolicyConfig::CacheAware {
|
PolicyConfig::CacheAware {
|
||||||
cache_threshold: 0.7,
|
cache_threshold: 0.7,
|
||||||
|
|||||||
Reference in New Issue
Block a user