Merge PDLB (Prefill-Decode Load Balancer) into SGLang Router (#7096)
This commit is contained in:
@@ -1,7 +1,11 @@
|
||||
use pyo3::prelude::*;
|
||||
pub mod logging;
|
||||
use std::collections::HashMap;
|
||||
pub mod openai_api_types;
|
||||
pub mod pd_router;
|
||||
pub mod pd_types;
|
||||
pub mod prometheus;
|
||||
pub mod request_adapter;
|
||||
pub mod router;
|
||||
pub mod server;
|
||||
pub mod service_discovery;
|
||||
@@ -14,6 +18,7 @@ pub enum PolicyType {
|
||||
Random,
|
||||
RoundRobin,
|
||||
CacheAware,
|
||||
PowerOfTwo, // Moved from PD-specific, now shared
|
||||
}
|
||||
|
||||
#[pyclass]
|
||||
@@ -39,6 +44,12 @@ struct Router {
|
||||
service_discovery_namespace: Option<String>,
|
||||
prometheus_port: Option<u16>,
|
||||
prometheus_host: Option<String>,
|
||||
request_timeout_secs: u64,
|
||||
// PD mode flag
|
||||
pd_disaggregated: bool,
|
||||
// PD-specific fields (only used when pd_disaggregated is true)
|
||||
prefill_urls: Option<Vec<(String, Option<u16>)>>,
|
||||
decode_urls: Option<Vec<String>>,
|
||||
}
|
||||
|
||||
#[pymethods]
|
||||
@@ -56,7 +67,7 @@ impl Router {
|
||||
balance_rel_threshold = 1.0001,
|
||||
eviction_interval_secs = 60,
|
||||
max_tree_size = 2usize.pow(24),
|
||||
max_payload_size = 4 * 1024 * 1024,
|
||||
max_payload_size = 256 * 1024 * 1024, // 256MB default for large batches
|
||||
verbose = false,
|
||||
log_dir = None,
|
||||
service_discovery = false,
|
||||
@@ -64,7 +75,11 @@ impl Router {
|
||||
service_discovery_port = 80,
|
||||
service_discovery_namespace = None,
|
||||
prometheus_port = None,
|
||||
prometheus_host = None
|
||||
prometheus_host = None,
|
||||
request_timeout_secs = 600, // Add configurable request timeout
|
||||
pd_disaggregated = false, // New flag for PD mode
|
||||
prefill_urls = None,
|
||||
decode_urls = None
|
||||
))]
|
||||
fn new(
|
||||
worker_urls: Vec<String>,
|
||||
@@ -87,6 +102,10 @@ impl Router {
|
||||
service_discovery_namespace: Option<String>,
|
||||
prometheus_port: Option<u16>,
|
||||
prometheus_host: Option<String>,
|
||||
request_timeout_secs: u64,
|
||||
pd_disaggregated: bool,
|
||||
prefill_urls: Option<Vec<(String, Option<u16>)>>,
|
||||
decode_urls: Option<Vec<String>>,
|
||||
) -> PyResult<Self> {
|
||||
Ok(Router {
|
||||
host,
|
||||
@@ -109,28 +128,75 @@ impl Router {
|
||||
service_discovery_namespace,
|
||||
prometheus_port,
|
||||
prometheus_host,
|
||||
request_timeout_secs,
|
||||
pd_disaggregated,
|
||||
prefill_urls,
|
||||
decode_urls,
|
||||
})
|
||||
}
|
||||
|
||||
fn start(&self) -> PyResult<()> {
|
||||
let policy_config = match &self.policy {
|
||||
PolicyType::Random => router::PolicyConfig::RandomConfig {
|
||||
let policy_config = if self.pd_disaggregated {
|
||||
// PD mode - map PolicyType to PDSelectionPolicy
|
||||
let pd_selection_policy = match &self.policy {
|
||||
PolicyType::Random => pd_types::PDSelectionPolicy::Random,
|
||||
PolicyType::PowerOfTwo => pd_types::PDSelectionPolicy::PowerOfTwo,
|
||||
PolicyType::CacheAware => pd_types::PDSelectionPolicy::CacheAware {
|
||||
cache_threshold: self.cache_threshold,
|
||||
balance_abs_threshold: self.balance_abs_threshold,
|
||||
balance_rel_threshold: self.balance_rel_threshold,
|
||||
},
|
||||
PolicyType::RoundRobin => {
|
||||
return Err(pyo3::exceptions::PyValueError::new_err(
|
||||
"RoundRobin policy is not supported in PD disaggregated mode",
|
||||
));
|
||||
}
|
||||
};
|
||||
|
||||
let prefill_urls = self.prefill_urls.as_ref().ok_or_else(|| {
|
||||
pyo3::exceptions::PyValueError::new_err(
|
||||
"PD disaggregated mode requires prefill_urls",
|
||||
)
|
||||
})?;
|
||||
let decode_urls = self.decode_urls.as_ref().ok_or_else(|| {
|
||||
pyo3::exceptions::PyValueError::new_err(
|
||||
"PD disaggregated mode requires decode_urls",
|
||||
)
|
||||
})?;
|
||||
|
||||
router::PolicyConfig::PrefillDecodeConfig {
|
||||
selection_policy: pd_selection_policy,
|
||||
prefill_urls: prefill_urls.clone(),
|
||||
decode_urls: decode_urls.clone(),
|
||||
timeout_secs: self.worker_startup_timeout_secs,
|
||||
interval_secs: self.worker_startup_check_interval,
|
||||
},
|
||||
PolicyType::RoundRobin => router::PolicyConfig::RoundRobinConfig {
|
||||
timeout_secs: self.worker_startup_timeout_secs,
|
||||
interval_secs: self.worker_startup_check_interval,
|
||||
},
|
||||
PolicyType::CacheAware => router::PolicyConfig::CacheAwareConfig {
|
||||
timeout_secs: self.worker_startup_timeout_secs,
|
||||
interval_secs: self.worker_startup_check_interval,
|
||||
cache_threshold: self.cache_threshold,
|
||||
balance_abs_threshold: self.balance_abs_threshold,
|
||||
balance_rel_threshold: self.balance_rel_threshold,
|
||||
eviction_interval_secs: self.eviction_interval_secs,
|
||||
max_tree_size: self.max_tree_size,
|
||||
},
|
||||
}
|
||||
} else {
|
||||
// Regular mode
|
||||
match &self.policy {
|
||||
PolicyType::Random => router::PolicyConfig::RandomConfig {
|
||||
timeout_secs: self.worker_startup_timeout_secs,
|
||||
interval_secs: self.worker_startup_check_interval,
|
||||
},
|
||||
PolicyType::RoundRobin => router::PolicyConfig::RoundRobinConfig {
|
||||
timeout_secs: self.worker_startup_timeout_secs,
|
||||
interval_secs: self.worker_startup_check_interval,
|
||||
},
|
||||
PolicyType::CacheAware => router::PolicyConfig::CacheAwareConfig {
|
||||
timeout_secs: self.worker_startup_timeout_secs,
|
||||
interval_secs: self.worker_startup_check_interval,
|
||||
cache_threshold: self.cache_threshold,
|
||||
balance_abs_threshold: self.balance_abs_threshold,
|
||||
balance_rel_threshold: self.balance_rel_threshold,
|
||||
eviction_interval_secs: self.eviction_interval_secs,
|
||||
max_tree_size: self.max_tree_size,
|
||||
},
|
||||
PolicyType::PowerOfTwo => {
|
||||
return Err(pyo3::exceptions::PyValueError::new_err(
|
||||
"PowerOfTwo policy is only supported in PD disaggregated mode",
|
||||
));
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Create service discovery config if enabled
|
||||
@@ -166,6 +232,7 @@ impl Router {
|
||||
log_dir: self.log_dir.clone(),
|
||||
service_discovery_config,
|
||||
prometheus_config,
|
||||
request_timeout_secs: self.request_timeout_secs,
|
||||
})
|
||||
.await
|
||||
.map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?;
|
||||
|
||||
Reference in New Issue
Block a user