[misc] Add PD service discovery support in router (#7361)

This commit is contained in:
Simo Lin
2025-06-22 17:54:14 -07:00
committed by GitHub
parent bd4f581896
commit 30f2a44a96
11 changed files with 1362 additions and 120 deletions

View File

@@ -1045,6 +1045,55 @@ impl Router {
}
}
/// Registers `worker_url` with the prefill/decode (PD) router.
///
/// `pod_type` selects which PD pool the worker joins:
/// - `PodType::Prefill` forwards to `add_prefill_server`, passing
///   `bootstrap_port` along.
/// - `PodType::Decode` forwards to `add_decode_server`; `bootstrap_port`
///   is not used on this path.
///
/// On success the `String` produced by the PD router is returned unchanged.
///
/// # Errors
///
/// Returns `Err` with a human-readable message when:
/// - this router is not the `Router::PrefillDecode` variant,
/// - `pod_type` is `PodType::Regular`, or
/// - the underlying PD router call fails (its error is stringified).
pub async fn add_pd_worker(
    &self,
    worker_url: &str,
    pod_type: crate::service_discovery::PodType,
    bootstrap_port: Option<u16>,
) -> Result<String, String> {
    // Bail out early for every non-PD router so the happy path stays flat.
    let pd_router = match self {
        Router::PrefillDecode { pd_router } => pd_router,
        _ => return Err("add_pd_worker only supported in PD mode".to_string()),
    };

    match pod_type {
        crate::service_discovery::PodType::Regular => {
            Err("Regular pod type not supported in PD mode".to_string())
        }
        crate::service_discovery::PodType::Prefill => {
            let added = pd_router
                .add_prefill_server(worker_url.to_string(), bootstrap_port)
                .await;
            added.map_err(|err| err.to_string())
        }
        crate::service_discovery::PodType::Decode => {
            let added = pd_router.add_decode_server(worker_url.to_string()).await;
            added.map_err(|err| err.to_string())
        }
    }
}
/// Unregisters `worker_url` from the prefill/decode (PD) router.
///
/// `pod_type` selects which PD pool the worker is removed from:
/// - `PodType::Prefill` forwards to `remove_prefill_server`.
/// - `PodType::Decode` forwards to `remove_decode_server`.
///
/// On success the `String` produced by the PD router is returned unchanged.
///
/// # Errors
///
/// Returns `Err` with a human-readable message when:
/// - this router is not the `Router::PrefillDecode` variant,
/// - `pod_type` is `PodType::Regular`, or
/// - the underlying PD router call fails (its error is stringified).
pub async fn remove_pd_worker(
    &self,
    worker_url: &str,
    pod_type: crate::service_discovery::PodType,
) -> Result<String, String> {
    // Bail out early for every non-PD router so the happy path stays flat.
    let pd_router = match self {
        Router::PrefillDecode { pd_router } => pd_router,
        _ => return Err("remove_pd_worker only supported in PD mode".to_string()),
    };

    match pod_type {
        crate::service_discovery::PodType::Regular => {
            Err("Regular pod type not supported in PD mode".to_string())
        }
        crate::service_discovery::PodType::Prefill => {
            let removed = pd_router.remove_prefill_server(worker_url).await;
            removed.map_err(|err| err.to_string())
        }
        crate::service_discovery::PodType::Decode => {
            let removed = pd_router.remove_decode_server(worker_url).await;
            removed.map_err(|err| err.to_string())
        }
    }
}
async fn get_worker_load(&self, client: &reqwest::Client, worker_url: &str) -> Option<isize> {
match client.get(&format!("{}/get_load", worker_url)).send().await {
Ok(res) if res.status().is_success() => match res.bytes().await {
@@ -1174,3 +1223,108 @@ impl Router {
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::service_discovery::PodType;

    // TODO: PD-mode counterparts of the tests below are not covered yet:
    //   - test_router_get_worker_urls_pd_mode
    //     (original note: for PD mode, get_worker_urls returns an empty list)
    //   - test_add_pd_worker_with_pd_router_regular_type
    //   - test_remove_pd_worker_with_pd_router_regular_type
    //   - test_select_first_worker_pd_mode
    // Per the original author's notes, PDRouter::new requires health checks
    // which fail in tests, so these need a mock server or a different setup.

    /// Builds a non-PD (`Router::Random`) router over two fixed worker URLs.
    fn create_test_regular_router() -> Router {
        let workers = vec![
            "http://worker1:8080".to_string(),
            "http://worker2:8080".to_string(),
        ];

        Router::Random {
            worker_urls: Arc::new(RwLock::new(workers)),
            timeout_secs: 5,
            interval_secs: 1,
        }
    }

    /// A regular router exposes exactly the worker URLs it was built with.
    #[test]
    fn test_router_get_worker_urls_regular() {
        let router = create_test_regular_router();
        let shared_urls = router.get_worker_urls();
        let guard = shared_urls.read().unwrap();

        assert_eq!(guard.len(), 2);
        for expected in ["http://worker1:8080", "http://worker2:8080"] {
            assert!(guard.contains(&expected.to_string()));
        }
    }

    /// Adding a PD worker to a non-PD router is rejected with a clear message.
    #[tokio::test]
    async fn test_add_pd_worker_with_regular_router() {
        let router = create_test_regular_router();

        // `unwrap_err` panics on `Ok`, so it doubles as the is-error check.
        let message = router
            .add_pd_worker("http://new-worker:8080", PodType::Prefill, Some(8081))
            .await
            .unwrap_err();

        assert!(message.contains("add_pd_worker only supported in PD mode"));
    }

    /// Removing a PD worker from a non-PD router is rejected with a clear message.
    #[tokio::test]
    async fn test_remove_pd_worker_with_regular_router() {
        let router = create_test_regular_router();

        // `unwrap_err` panics on `Ok`, so it doubles as the is-error check.
        let message = router
            .remove_pd_worker("http://worker:8080", PodType::Decode)
            .await
            .unwrap_err();

        assert!(message.contains("remove_pd_worker only supported in PD mode"));
    }

    /// For the test router, `select_first_worker` yields the first configured URL.
    #[test]
    fn test_select_first_worker_regular() {
        let router = create_test_regular_router();

        // `unwrap` panics on `Err`, so it doubles as the is-ok check.
        let first = router.select_first_worker().unwrap();

        assert_eq!(first, "http://worker1:8080");
    }

    /// An empty worker list is reported as healthy.
    #[test]
    fn test_wait_for_healthy_workers_empty_list() {
        assert!(Router::wait_for_healthy_workers(&[], 1, 1).is_ok());
    }

    /// A worker that can never be reached produces a timeout error.
    #[test]
    fn test_wait_for_healthy_workers_invalid_urls() {
        // The URL is invalid, so this is expected to time out quickly.
        let unreachable = ["http://nonexistent:8080".to_string()];

        let message = Router::wait_for_healthy_workers(&unreachable, 1, 1).unwrap_err();

        assert!(message.contains("Timeout"));
    }
}