[router] add different policies for p node and d node (#8395)

This commit is contained in:
Simo Lin
2025-07-27 00:39:20 -07:00
committed by GitHub
parent 0bcc195f4e
commit 2ab97023e3
10 changed files with 536 additions and 81 deletions

View File

@@ -54,6 +54,8 @@ struct Router {
// PD-specific fields (only used when pd_disaggregation is true)
prefill_urls: Option<Vec<(String, Option<u16>)>>,
decode_urls: Option<Vec<String>>,
prefill_policy: Option<PolicyType>,
decode_policy: Option<PolicyType>,
}
impl Router {
@@ -63,11 +65,31 @@ impl Router {
DiscoveryConfig, MetricsConfig, PolicyConfig as ConfigPolicyConfig, RoutingMode,
};
// Helper closure: translates a Python-facing `PolicyType` into the matching
// `ConfigPolicyConfig` variant. It borrows `self` to read the cache-aware
// tuning fields, so the same conversion is reused for the main policy and
// for the optional prefill/decode policies.
let convert_policy = |policy: &PolicyType| -> ConfigPolicyConfig {
match policy {
PolicyType::Random => ConfigPolicyConfig::Random,
PolicyType::RoundRobin => ConfigPolicyConfig::RoundRobin,
// Cache-aware parameters are copied from this router's own fields, so
// every policy converted through this closure gets the same values.
PolicyType::CacheAware => ConfigPolicyConfig::CacheAware {
cache_threshold: self.cache_threshold,
balance_abs_threshold: self.balance_abs_threshold,
balance_rel_threshold: self.balance_rel_threshold,
eviction_interval_secs: self.eviction_interval_secs,
max_tree_size: self.max_tree_size,
},
// The load-check interval is hard-coded here rather than read from `self`.
PolicyType::PowerOfTwo => ConfigPolicyConfig::PowerOfTwo {
load_check_interval_secs: 5, // Default value
},
}
};
// Determine routing mode
let mode = if self.pd_disaggregation {
RoutingMode::PrefillDecode {
prefill_urls: self.prefill_urls.clone().unwrap_or_default(),
decode_urls: self.decode_urls.clone().unwrap_or_default(),
prefill_policy: self.prefill_policy.as_ref().map(convert_policy),
decode_policy: self.decode_policy.as_ref().map(convert_policy),
}
} else {
RoutingMode::Regular {
@@ -75,21 +97,8 @@ impl Router {
}
};
// Convert policy
let policy = match self.policy {
PolicyType::Random => ConfigPolicyConfig::Random,
PolicyType::RoundRobin => ConfigPolicyConfig::RoundRobin,
PolicyType::CacheAware => ConfigPolicyConfig::CacheAware {
cache_threshold: self.cache_threshold,
balance_abs_threshold: self.balance_abs_threshold,
balance_rel_threshold: self.balance_rel_threshold,
eviction_interval_secs: self.eviction_interval_secs,
max_tree_size: self.max_tree_size,
},
PolicyType::PowerOfTwo => ConfigPolicyConfig::PowerOfTwo {
load_check_interval_secs: 5, // Default value
},
};
// Convert main policy
let policy = convert_policy(&self.policy);
// Service discovery configuration
let discovery = if self.service_discovery {
@@ -163,7 +172,9 @@ impl Router {
request_timeout_secs = 600, // Add configurable request timeout
pd_disaggregation = false, // New flag for PD mode
prefill_urls = None,
decode_urls = None
decode_urls = None,
prefill_policy = None,
decode_policy = None
))]
fn new(
worker_urls: Vec<String>,
@@ -193,6 +204,8 @@ impl Router {
pd_disaggregation: bool,
prefill_urls: Option<Vec<(String, Option<u16>)>>,
decode_urls: Option<Vec<String>>,
prefill_policy: Option<PolicyType>,
decode_policy: Option<PolicyType>,
) -> PyResult<Self> {
Ok(Router {
host,
@@ -222,6 +235,8 @@ impl Router {
pd_disaggregation,
prefill_urls,
decode_urls,
prefill_policy,
decode_policy,
})
}