[misc] Add PD service discovery support in router (#7361)
This commit is contained in:
@@ -1045,6 +1045,55 @@ impl Router {
|
||||
}
|
||||
}
|
||||
|
||||
/// Add a worker with PD mode support
|
||||
pub async fn add_pd_worker(
|
||||
&self,
|
||||
worker_url: &str,
|
||||
pod_type: crate::service_discovery::PodType,
|
||||
bootstrap_port: Option<u16>,
|
||||
) -> Result<String, String> {
|
||||
match self {
|
||||
Router::PrefillDecode { pd_router } => match pod_type {
|
||||
crate::service_discovery::PodType::Prefill => pd_router
|
||||
.add_prefill_server(worker_url.to_string(), bootstrap_port)
|
||||
.await
|
||||
.map_err(|e| e.to_string()),
|
||||
crate::service_discovery::PodType::Decode => pd_router
|
||||
.add_decode_server(worker_url.to_string())
|
||||
.await
|
||||
.map_err(|e| e.to_string()),
|
||||
crate::service_discovery::PodType::Regular => {
|
||||
Err("Regular pod type not supported in PD mode".to_string())
|
||||
}
|
||||
},
|
||||
_ => Err("add_pd_worker only supported in PD mode".to_string()),
|
||||
}
|
||||
}
|
||||
|
||||
/// Remove a worker with PD mode support
|
||||
pub async fn remove_pd_worker(
|
||||
&self,
|
||||
worker_url: &str,
|
||||
pod_type: crate::service_discovery::PodType,
|
||||
) -> Result<String, String> {
|
||||
match self {
|
||||
Router::PrefillDecode { pd_router } => match pod_type {
|
||||
crate::service_discovery::PodType::Prefill => pd_router
|
||||
.remove_prefill_server(worker_url)
|
||||
.await
|
||||
.map_err(|e| e.to_string()),
|
||||
crate::service_discovery::PodType::Decode => pd_router
|
||||
.remove_decode_server(worker_url)
|
||||
.await
|
||||
.map_err(|e| e.to_string()),
|
||||
crate::service_discovery::PodType::Regular => {
|
||||
Err("Regular pod type not supported in PD mode".to_string())
|
||||
}
|
||||
},
|
||||
_ => Err("remove_pd_worker only supported in PD mode".to_string()),
|
||||
}
|
||||
}
|
||||
|
||||
async fn get_worker_load(&self, client: &reqwest::Client, worker_url: &str) -> Option<isize> {
|
||||
match client.get(&format!("{}/get_load", worker_url)).send().await {
|
||||
Ok(res) if res.status().is_success() => match res.bytes().await {
|
||||
@@ -1174,3 +1223,108 @@ impl Router {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::service_discovery::PodType;
|
||||
|
||||
fn create_test_regular_router() -> Router {
|
||||
Router::Random {
|
||||
worker_urls: Arc::new(RwLock::new(vec![
|
||||
"http://worker1:8080".to_string(),
|
||||
"http://worker2:8080".to_string(),
|
||||
])),
|
||||
timeout_secs: 5,
|
||||
interval_secs: 1,
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_router_get_worker_urls_regular() {
|
||||
let router = create_test_regular_router();
|
||||
let worker_urls = router.get_worker_urls();
|
||||
let urls = worker_urls.read().unwrap();
|
||||
|
||||
assert_eq!(urls.len(), 2);
|
||||
assert!(urls.contains(&"http://worker1:8080".to_string()));
|
||||
assert!(urls.contains(&"http://worker2:8080".to_string()));
|
||||
}
|
||||
|
||||
// #[test]
|
||||
// fn test_router_get_worker_urls_pd_mode() {
|
||||
// // For PD mode, get_worker_urls returns empty list
|
||||
// // Note: PDRouter::new requires health checks which fail in tests
|
||||
// // This test would need a mock server or different test setup
|
||||
// }
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_add_pd_worker_with_regular_router() {
|
||||
let router = create_test_regular_router();
|
||||
|
||||
let result = router
|
||||
.add_pd_worker("http://new-worker:8080", PodType::Prefill, Some(8081))
|
||||
.await;
|
||||
|
||||
assert!(result.is_err());
|
||||
assert!(result
|
||||
.unwrap_err()
|
||||
.contains("add_pd_worker only supported in PD mode"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_remove_pd_worker_with_regular_router() {
|
||||
let router = create_test_regular_router();
|
||||
|
||||
let result = router
|
||||
.remove_pd_worker("http://worker:8080", PodType::Decode)
|
||||
.await;
|
||||
|
||||
assert!(result.is_err());
|
||||
assert!(result
|
||||
.unwrap_err()
|
||||
.contains("remove_pd_worker only supported in PD mode"));
|
||||
}
|
||||
|
||||
// #[tokio::test]
|
||||
// async fn test_add_pd_worker_with_pd_router_regular_type() {
|
||||
// // Note: PDRouter::new requires health checks which fail in tests
|
||||
// // This test would need a mock server or different test setup
|
||||
// }
|
||||
|
||||
// #[tokio::test]
|
||||
// async fn test_remove_pd_worker_with_pd_router_regular_type() {
|
||||
// // Note: PDRouter::new requires health checks which fail in tests
|
||||
// // This test would need a mock server or different test setup
|
||||
// }
|
||||
|
||||
#[test]
|
||||
fn test_select_first_worker_regular() {
|
||||
let router = create_test_regular_router();
|
||||
let result = router.select_first_worker();
|
||||
|
||||
assert!(result.is_ok());
|
||||
assert_eq!(result.unwrap(), "http://worker1:8080");
|
||||
}
|
||||
|
||||
// #[test]
|
||||
// fn test_select_first_worker_pd_mode() {
|
||||
// // Note: PDRouter::new requires health checks which fail in tests
|
||||
// // This test would need a mock server or different test setup
|
||||
// }
|
||||
|
||||
#[test]
|
||||
fn test_wait_for_healthy_workers_empty_list() {
|
||||
let result = Router::wait_for_healthy_workers(&[], 1, 1);
|
||||
assert!(result.is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_wait_for_healthy_workers_invalid_urls() {
|
||||
// This test will timeout quickly since the URLs are invalid
|
||||
let result =
|
||||
Router::wait_for_healthy_workers(&["http://nonexistent:8080".to_string()], 1, 1);
|
||||
assert!(result.is_err());
|
||||
assert!(result.unwrap_err().contains("Timeout"));
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user