[misc] Add PD service discovery support in router (#7361)

This commit is contained in:
Simo Lin
2025-06-22 17:54:14 -07:00
committed by GitHub
parent bd4f581896
commit 30f2a44a96
11 changed files with 1362 additions and 120 deletions

View File

@@ -42,12 +42,16 @@ struct Router {
selector: HashMap<String, String>,
service_discovery_port: u16,
service_discovery_namespace: Option<String>,
// PD service discovery fields
prefill_selector: HashMap<String, String>,
decode_selector: HashMap<String, String>,
bootstrap_port_annotation: String,
prometheus_port: Option<u16>,
prometheus_host: Option<String>,
request_timeout_secs: u64,
// PD mode flag
pd_disaggregated: bool,
// PD-specific fields (only used when pd_disaggregated is true)
pd_disaggregation: bool,
// PD-specific fields (only used when pd_disaggregation is true)
prefill_urls: Option<Vec<(String, Option<u16>)>>,
decode_urls: Option<Vec<String>>,
}
@@ -74,10 +78,13 @@ impl Router {
selector = HashMap::new(),
service_discovery_port = 80,
service_discovery_namespace = None,
prefill_selector = HashMap::new(),
decode_selector = HashMap::new(),
bootstrap_port_annotation = String::from("sglang.ai/bootstrap-port"),
prometheus_port = None,
prometheus_host = None,
request_timeout_secs = 600, // Add configurable request timeout
pd_disaggregated = false, // New flag for PD mode
pd_disaggregation = false, // New flag for PD mode
prefill_urls = None,
decode_urls = None
))]
@@ -100,10 +107,13 @@ impl Router {
selector: HashMap<String, String>,
service_discovery_port: u16,
service_discovery_namespace: Option<String>,
prefill_selector: HashMap<String, String>,
decode_selector: HashMap<String, String>,
bootstrap_port_annotation: String,
prometheus_port: Option<u16>,
prometheus_host: Option<String>,
request_timeout_secs: u64,
pd_disaggregated: bool,
pd_disaggregation: bool,
prefill_urls: Option<Vec<(String, Option<u16>)>>,
decode_urls: Option<Vec<String>>,
) -> PyResult<Self> {
@@ -126,17 +136,20 @@ impl Router {
selector,
service_discovery_port,
service_discovery_namespace,
prefill_selector,
decode_selector,
bootstrap_port_annotation,
prometheus_port,
prometheus_host,
request_timeout_secs,
pd_disaggregated,
pd_disaggregation,
prefill_urls,
decode_urls,
})
}
fn start(&self) -> PyResult<()> {
let policy_config = if self.pd_disaggregated {
let policy_config = if self.pd_disaggregation {
// PD mode - map PolicyType to PDSelectionPolicy
let pd_selection_policy = match &self.policy {
PolicyType::Random => pd_types::PDSelectionPolicy::Random,
@@ -207,6 +220,11 @@ impl Router {
check_interval: std::time::Duration::from_secs(60),
port: self.service_discovery_port,
namespace: self.service_discovery_namespace.clone(),
// PD mode configuration
pd_mode: self.pd_disaggregation,
prefill_selector: self.prefill_selector.clone(),
decode_selector: self.decode_selector.clone(),
bootstrap_port_annotation: self.bootstrap_port_annotation.clone(),
})
} else {
None

View File

@@ -1,7 +1,9 @@
// PD (Prefill-Decode) Router Implementation
// This module handles routing for disaggregated prefill-decode systems
use crate::pd_types::{Bootstrap, ChatReqInput, EngineInfo, GenerateReqInput, PDSelectionPolicy};
use crate::pd_types::{
Bootstrap, ChatReqInput, EngineInfo, GenerateReqInput, PDRouterError, PDSelectionPolicy,
};
use crate::tree::Tree;
use actix_web::http::header::{HeaderValue, CONTENT_TYPE};
use actix_web::{HttpRequest, HttpResponse};
@@ -65,12 +67,145 @@ impl Drop for LoadGuard<'_> {
}
impl PDRouter {
// TODO: Add methods for dynamic worker management to support /register endpoint:
// - add_prefill_server(url: String, bootstrap_port: Option<u16>)
// - add_decode_server(url: String)
// - remove_prefill_server(url: &str)
// - remove_decode_server(url: &str)
// These methods will be used when service discovery is implemented for PD mode
// Dynamic worker management methods for service discovery
pub async fn add_prefill_server(
&self,
url: String,
bootstrap_port: Option<u16>,
) -> Result<String, PDRouterError> {
// Create EngineInfo for the new prefill server
let engine_info = EngineInfo::new_prefill(url.clone(), bootstrap_port);
// Wait for the new server to be healthy
crate::router::Router::wait_for_healthy_workers(
&[url.clone()],
self.timeout_secs,
self.interval_secs,
)
.map_err(|_| PDRouterError::HealthCheckFailed { url: url.clone() })?;
// Add to prefill workers list
let mut workers = self
.prefill_workers
.write()
.map_err(|_| PDRouterError::LockError {
operation: "prefill_workers write".to_string(),
})?;
// Check if already exists
if workers.iter().any(|w| w.url == url) {
return Err(PDRouterError::WorkerAlreadyExists { url: url.clone() });
}
workers.push(engine_info);
// Initialize load tracking
self.load_tracking
.insert(url.clone(), Arc::new(AtomicUsize::new(0)));
// Add to cache tree if using cache-aware policy
if let Some(ref tree) = self.prefill_tree {
tree.lock().unwrap().insert("", &url);
}
info!("Added prefill server: {}", url);
Ok(format!("Successfully added prefill server: {}", url))
}
/// Registers a new decode server with this PD router.
///
/// Steps, in order:
/// 1. Reject the URL early if it is already in `decode_workers`.
/// 2. Wait for the server to pass the health check.
/// 3. Under the `decode_workers` write lock, re-check for duplicates, then
///    add the worker and its load counter.
///
/// # Errors
/// * `PDRouterError::WorkerAlreadyExists` - the URL is already registered.
/// * `PDRouterError::HealthCheckFailed` - the health wait returned an error.
/// * `PDRouterError::LockError` - the worker list lock is poisoned.
pub async fn add_decode_server(&self, url: String) -> Result<String, PDRouterError> {
    // Cheap early exit: a URL that is already registered must not pay for the
    // health wait below.
    {
        let workers = self
            .decode_workers
            .read()
            .map_err(|_| PDRouterError::LockError {
                operation: "decode_workers read".to_string(),
            })?;
        if workers.iter().any(|w| w.url == url) {
            return Err(PDRouterError::WorkerAlreadyExists { url });
        }
    }

    // Wait for the new server to be healthy.
    // NOTE(review): this helper is invoked synchronously (no `.await`) inside an
    // `async fn`; if it blocks while polling it will stall the executor thread.
    // Confirm, and consider moving it onto a blocking task.
    crate::router::Router::wait_for_healthy_workers(
        &[url.clone()],
        self.timeout_secs,
        self.interval_secs,
    )
    .map_err(|_| PDRouterError::HealthCheckFailed { url: url.clone() })?;

    let mut workers = self
        .decode_workers
        .write()
        .map_err(|_| PDRouterError::LockError {
            operation: "decode_workers write".to_string(),
        })?;

    // Re-check under the write lock: another caller may have registered the
    // same URL while the health check was running.
    if workers.iter().any(|w| w.url == url) {
        return Err(PDRouterError::WorkerAlreadyExists { url });
    }

    workers.push(EngineInfo::new_decode(url.clone()));

    // Initialize load tracking with a zeroed counter.
    self.load_tracking
        .insert(url.clone(), Arc::new(AtomicUsize::new(0)));

    info!("Added decode server: {}", url);
    Ok(format!("Successfully added decode server: {}", url))
}
pub async fn remove_prefill_server(&self, url: &str) -> Result<String, PDRouterError> {
let mut workers = self
.prefill_workers
.write()
.map_err(|_| PDRouterError::LockError {
operation: "prefill_workers write".to_string(),
})?;
// Find and remove the server
let initial_len = workers.len();
workers.retain(|w| w.url != url);
if workers.len() == initial_len {
return Err(PDRouterError::WorkerNotFound {
url: url.to_string(),
});
}
// Remove from load tracking
self.load_tracking.remove(url);
// Remove from cache tree if using cache-aware policy
if let Some(ref tree) = self.prefill_tree {
// Note: Tree doesn't have a remove method, so we rebuild it
let mut tree_guard = tree.lock().unwrap();
*tree_guard = Tree::new();
for worker in workers.iter() {
tree_guard.insert("", &worker.url);
}
}
info!("Removed prefill server: {}", url);
Ok(format!("Successfully removed prefill server: {}", url))
}
/// Unregisters a decode server from this PD router.
///
/// Every worker whose URL equals `url` is dropped from `decode_workers` and
/// its load counter is removed. Returns a human-readable confirmation.
///
/// # Errors
/// * `PDRouterError::LockError` - the worker list lock is poisoned.
/// * `PDRouterError::WorkerNotFound` - no registered worker has this URL.
pub async fn remove_decode_server(&self, url: &str) -> Result<String, PDRouterError> {
    let lock_failure = |_| PDRouterError::LockError {
        operation: "decode_workers write".to_string(),
    };
    let mut registered = self.decode_workers.write().map_err(lock_failure)?;

    // Guard clause: bail out before touching the list when nothing matches.
    let is_known = registered.iter().any(|candidate| candidate.url == url);
    if !is_known {
        return Err(PDRouterError::WorkerNotFound {
            url: url.to_string(),
        });
    }

    // Keep only the workers that do not carry the requested URL.
    registered.retain(|candidate| candidate.url != url);

    // The load counter is keyed by the same URL.
    self.load_tracking.remove(url);

    info!("Removed decode server: {}", url);
    Ok(format!("Successfully removed decode server: {}", url))
}
pub fn new(
prefill_urls: Vec<(String, Option<u16>)>,

View File

@@ -3,6 +3,31 @@
use serde::{Deserialize, Serialize};
use serde_json::Value;
/// Custom error type for PD router operations.
///
/// The `#[error(...)]` attributes define the `Display` text of each variant;
/// `Router::add_pd_worker` / `Router::remove_pd_worker` turn these errors into
/// plain strings via `to_string()`.
#[derive(Debug, thiserror::Error)]
pub enum PDRouterError {
/// Returned by the `add_*_server` methods when the URL is already present
/// in the corresponding worker list.
#[error("Worker already exists: {url}")]
WorkerAlreadyExists { url: String },
/// Returned by the `remove_*_server` methods when no worker with the given
/// URL was found.
#[error("Worker not found: {url}")]
WorkerNotFound { url: String },
/// A lock could not be acquired; `operation` names the lock and access mode
/// (e.g. "prefill_workers write").
#[error("Lock acquisition failed: {operation}")]
LockError { operation: String },
/// The health wait for a newly added worker returned an error.
#[error("Health check failed for worker: {url}")]
HealthCheckFailed { url: String },
/// Not constructed in the code visible here; presumably for rejected worker
/// settings - TODO confirm against callers.
#[error("Invalid worker configuration: {reason}")]
InvalidConfiguration { reason: String },
/// Not constructed in the code visible here; presumably for transport-level
/// failures - TODO confirm against callers.
#[error("Network error: {message}")]
NetworkError { message: String },
/// Not constructed in the code visible here; presumably distinct from
/// `HealthCheckFailed` - TODO confirm against callers.
#[error("Timeout waiting for worker: {url}")]
Timeout { url: String },
}
#[derive(Debug, Clone)]
pub enum EngineType {
Prefill,

View File

@@ -1045,6 +1045,55 @@ impl Router {
}
}
/// Add a worker with PD mode support
pub async fn add_pd_worker(
&self,
worker_url: &str,
pod_type: crate::service_discovery::PodType,
bootstrap_port: Option<u16>,
) -> Result<String, String> {
match self {
Router::PrefillDecode { pd_router } => match pod_type {
crate::service_discovery::PodType::Prefill => pd_router
.add_prefill_server(worker_url.to_string(), bootstrap_port)
.await
.map_err(|e| e.to_string()),
crate::service_discovery::PodType::Decode => pd_router
.add_decode_server(worker_url.to_string())
.await
.map_err(|e| e.to_string()),
crate::service_discovery::PodType::Regular => {
Err("Regular pod type not supported in PD mode".to_string())
}
},
_ => Err("add_pd_worker only supported in PD mode".to_string()),
}
}
/// Remove a worker with PD mode support
pub async fn remove_pd_worker(
&self,
worker_url: &str,
pod_type: crate::service_discovery::PodType,
) -> Result<String, String> {
match self {
Router::PrefillDecode { pd_router } => match pod_type {
crate::service_discovery::PodType::Prefill => pd_router
.remove_prefill_server(worker_url)
.await
.map_err(|e| e.to_string()),
crate::service_discovery::PodType::Decode => pd_router
.remove_decode_server(worker_url)
.await
.map_err(|e| e.to_string()),
crate::service_discovery::PodType::Regular => {
Err("Regular pod type not supported in PD mode".to_string())
}
},
_ => Err("remove_pd_worker only supported in PD mode".to_string()),
}
}
async fn get_worker_load(&self, client: &reqwest::Client, worker_url: &str) -> Option<isize> {
match client.get(&format!("{}/get_load", worker_url)).send().await {
Ok(res) if res.status().is_success() => match res.bytes().await {
@@ -1174,3 +1223,108 @@ impl Router {
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::service_discovery::PodType;

    /// Builds a non-PD (`Router::Random`) router pre-populated with two
    /// worker URLs; no network access is involved in constructing it.
    fn create_test_regular_router() -> Router {
        let seed_urls = vec![
            "http://worker1:8080".to_string(),
            "http://worker2:8080".to_string(),
        ];
        Router::Random {
            worker_urls: Arc::new(RwLock::new(seed_urls)),
            timeout_secs: 5,
            interval_secs: 1,
        }
    }

    // NOTE: PD-mode counterparts of the tests below
    // (test_router_get_worker_urls_pd_mode,
    //  test_add_pd_worker_with_pd_router_regular_type,
    //  test_remove_pd_worker_with_pd_router_regular_type,
    //  test_select_first_worker_pd_mode)
    // are not implemented yet: PDRouter::new requires health checks which fail
    // in tests, so they would need a mock server or a different test setup.

    #[test]
    fn test_router_get_worker_urls_regular() {
        let router = create_test_regular_router();
        let shared_urls = router.get_worker_urls();
        let listed = shared_urls.read().unwrap();

        assert_eq!(listed.len(), 2);
        for expected in ["http://worker1:8080", "http://worker2:8080"] {
            assert!(listed.contains(&expected.to_string()));
        }
    }

    #[tokio::test]
    async fn test_add_pd_worker_with_regular_router() {
        let router = create_test_regular_router();

        // A regular router must refuse PD workers outright.
        let message = router
            .add_pd_worker("http://new-worker:8080", PodType::Prefill, Some(8081))
            .await
            .unwrap_err();

        assert!(message.contains("add_pd_worker only supported in PD mode"));
    }

    #[tokio::test]
    async fn test_remove_pd_worker_with_regular_router() {
        let router = create_test_regular_router();

        // A regular router has no PD workers to remove.
        let message = router
            .remove_pd_worker("http://worker:8080", PodType::Decode)
            .await
            .unwrap_err();

        assert!(message.contains("remove_pd_worker only supported in PD mode"));
    }

    #[test]
    fn test_select_first_worker_regular() {
        let router = create_test_regular_router();
        let chosen = router.select_first_worker().unwrap();
        assert_eq!(chosen, "http://worker1:8080");
    }

    #[test]
    fn test_wait_for_healthy_workers_empty_list() {
        // With no workers there is nothing to wait for.
        assert!(Router::wait_for_healthy_workers(&[], 1, 1).is_ok());
    }

    #[test]
    fn test_wait_for_healthy_workers_invalid_urls() {
        // The URL cannot be reached, so the wait gives up after the 1s timeout.
        let unreachable = ["http://nonexistent:8080".to_string()];
        let message = Router::wait_for_healthy_workers(&unreachable, 1, 1).unwrap_err();
        assert!(message.contains("Timeout"));
    }
}

File diff suppressed because it is too large Load Diff