Merge PDLB (Prefill-Decode Load Balancer) into SGLang Router (#7096)

This commit is contained in:
Simo Lin
2025-06-18 11:28:15 -07:00
committed by GitHub
parent 712bf9ec9b
commit 09ae5b20f3
13 changed files with 4045 additions and 187 deletions

View File

@@ -1,7 +1,11 @@
use pyo3::prelude::*;
pub mod logging;
use std::collections::HashMap;
pub mod openai_api_types;
pub mod pd_router;
pub mod pd_types;
pub mod prometheus;
pub mod request_adapter;
pub mod router;
pub mod server;
pub mod service_discovery;
@@ -14,6 +18,7 @@ pub enum PolicyType {
Random,
RoundRobin,
CacheAware,
PowerOfTwo, // Moved from PD-specific, now shared
}
#[pyclass]
@@ -39,6 +44,12 @@ struct Router {
service_discovery_namespace: Option<String>,
prometheus_port: Option<u16>,
prometheus_host: Option<String>,
request_timeout_secs: u64,
// PD mode flag
pd_disaggregated: bool,
// PD-specific fields (only used when pd_disaggregated is true)
prefill_urls: Option<Vec<(String, Option<u16>)>>,
decode_urls: Option<Vec<String>>,
}
#[pymethods]
@@ -56,7 +67,7 @@ impl Router {
balance_rel_threshold = 1.0001,
eviction_interval_secs = 60,
max_tree_size = 2usize.pow(24),
max_payload_size = 4 * 1024 * 1024,
max_payload_size = 256 * 1024 * 1024, // 256MB default for large batches
verbose = false,
log_dir = None,
service_discovery = false,
@@ -64,7 +75,11 @@ impl Router {
service_discovery_port = 80,
service_discovery_namespace = None,
prometheus_port = None,
prometheus_host = None
prometheus_host = None,
request_timeout_secs = 600, // Add configurable request timeout
pd_disaggregated = false, // New flag for PD mode
prefill_urls = None,
decode_urls = None
))]
fn new(
worker_urls: Vec<String>,
@@ -87,6 +102,10 @@ impl Router {
service_discovery_namespace: Option<String>,
prometheus_port: Option<u16>,
prometheus_host: Option<String>,
request_timeout_secs: u64,
pd_disaggregated: bool,
prefill_urls: Option<Vec<(String, Option<u16>)>>,
decode_urls: Option<Vec<String>>,
) -> PyResult<Self> {
Ok(Router {
host,
@@ -109,28 +128,75 @@ impl Router {
service_discovery_namespace,
prometheus_port,
prometheus_host,
request_timeout_secs,
pd_disaggregated,
prefill_urls,
decode_urls,
})
}
fn start(&self) -> PyResult<()> {
let policy_config = match &self.policy {
PolicyType::Random => router::PolicyConfig::RandomConfig {
let policy_config = if self.pd_disaggregated {
// PD mode - map PolicyType to PDSelectionPolicy
let pd_selection_policy = match &self.policy {
PolicyType::Random => pd_types::PDSelectionPolicy::Random,
PolicyType::PowerOfTwo => pd_types::PDSelectionPolicy::PowerOfTwo,
PolicyType::CacheAware => pd_types::PDSelectionPolicy::CacheAware {
cache_threshold: self.cache_threshold,
balance_abs_threshold: self.balance_abs_threshold,
balance_rel_threshold: self.balance_rel_threshold,
},
PolicyType::RoundRobin => {
return Err(pyo3::exceptions::PyValueError::new_err(
"RoundRobin policy is not supported in PD disaggregated mode",
));
}
};
let prefill_urls = self.prefill_urls.as_ref().ok_or_else(|| {
pyo3::exceptions::PyValueError::new_err(
"PD disaggregated mode requires prefill_urls",
)
})?;
let decode_urls = self.decode_urls.as_ref().ok_or_else(|| {
pyo3::exceptions::PyValueError::new_err(
"PD disaggregated mode requires decode_urls",
)
})?;
router::PolicyConfig::PrefillDecodeConfig {
selection_policy: pd_selection_policy,
prefill_urls: prefill_urls.clone(),
decode_urls: decode_urls.clone(),
timeout_secs: self.worker_startup_timeout_secs,
interval_secs: self.worker_startup_check_interval,
},
PolicyType::RoundRobin => router::PolicyConfig::RoundRobinConfig {
timeout_secs: self.worker_startup_timeout_secs,
interval_secs: self.worker_startup_check_interval,
},
PolicyType::CacheAware => router::PolicyConfig::CacheAwareConfig {
timeout_secs: self.worker_startup_timeout_secs,
interval_secs: self.worker_startup_check_interval,
cache_threshold: self.cache_threshold,
balance_abs_threshold: self.balance_abs_threshold,
balance_rel_threshold: self.balance_rel_threshold,
eviction_interval_secs: self.eviction_interval_secs,
max_tree_size: self.max_tree_size,
},
}
} else {
// Regular mode
match &self.policy {
PolicyType::Random => router::PolicyConfig::RandomConfig {
timeout_secs: self.worker_startup_timeout_secs,
interval_secs: self.worker_startup_check_interval,
},
PolicyType::RoundRobin => router::PolicyConfig::RoundRobinConfig {
timeout_secs: self.worker_startup_timeout_secs,
interval_secs: self.worker_startup_check_interval,
},
PolicyType::CacheAware => router::PolicyConfig::CacheAwareConfig {
timeout_secs: self.worker_startup_timeout_secs,
interval_secs: self.worker_startup_check_interval,
cache_threshold: self.cache_threshold,
balance_abs_threshold: self.balance_abs_threshold,
balance_rel_threshold: self.balance_rel_threshold,
eviction_interval_secs: self.eviction_interval_secs,
max_tree_size: self.max_tree_size,
},
PolicyType::PowerOfTwo => {
return Err(pyo3::exceptions::PyValueError::new_err(
"PowerOfTwo policy is only supported in PD disaggregated mode",
));
}
}
};
// Create service discovery config if enabled
@@ -166,6 +232,7 @@ impl Router {
log_dir: self.log_dir.clone(),
service_discovery_config,
prometheus_config,
request_timeout_secs: self.request_timeout_secs,
})
.await
.map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?;