From 5c8365a0516ae908c1733054afb6852f3bee91dd Mon Sep 17 00:00:00 2001 From: Simo Lin Date: Sun, 20 Jul 2025 23:12:52 -0700 Subject: [PATCH] [router] add ut for pd router (#8208) --- sgl-router/src/routers/pd_router.rs | 512 ++++++++++++++++++++++++++++ sgl-router/tests/test_pd_routing.rs | 21 -- 2 files changed, 512 insertions(+), 21 deletions(-) diff --git a/sgl-router/src/routers/pd_router.rs b/sgl-router/src/routers/pd_router.rs index d156c9f34..7c70a3873 100644 --- a/sgl-router/src/routers/pd_router.rs +++ b/sgl-router/src/routers/pd_router.rs @@ -1393,3 +1393,515 @@ impl RouterTrait for PDRouter { } } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::core::{BasicWorker, WorkerType}; + use crate::policies::{CacheAwarePolicy, RandomPolicy}; + use crate::routers::pd_types::SingleOrBatch; + use actix_web::test::TestRequest; + + fn create_test_pd_router() -> PDRouter { + let policy = Arc::new(RandomPolicy::new()); + + PDRouter { + prefill_workers: Arc::new(RwLock::new(vec![])), + decode_workers: Arc::new(RwLock::new(vec![])), + policy, + prefill_tree: None, + timeout_secs: 5, + interval_secs: 1, + worker_loads: Arc::new(tokio::sync::watch::channel(HashMap::new()).1), + load_monitor_handle: None, + http_client: reqwest::Client::new(), + _prefill_health_checker: None, + _decode_health_checker: None, + } + } + + fn create_test_worker(url: String, worker_type: WorkerType, healthy: bool) -> Box { + let worker = BasicWorker::new(url, worker_type); + worker.set_healthy(healthy); + Box::new(worker) + } + + // ============= Worker Management Tests ============= + + #[tokio::test] + async fn test_add_prefill_server_already_exists() { + let router = create_test_pd_router(); + + // Add a worker first + let worker = create_test_worker( + "http://localhost:8000".to_string(), + WorkerType::Prefill { + bootstrap_port: Some(8080), + }, + true, + ); + router.prefill_workers.write().unwrap().push(worker); + + // Try to add the same URL again - this would fail during health check in real scenario + // For unit test, we test the duplicate check logic + let workers = router.prefill_workers.read().unwrap(); + let exists = workers.iter().any(|w| w.url() == "http://localhost:8000"); + assert!(exists); + } + + #[tokio::test] + async fn test_remove_prefill_server_success() { + let router = create_test_pd_router(); + + // Add servers first + let worker1 = create_test_worker( + "http://worker1".to_string(), + WorkerType::Prefill { + bootstrap_port: None, + }, + true, + ); + let worker2 = create_test_worker( + "http://worker2".to_string(), + WorkerType::Prefill { + bootstrap_port: Some(8080), + }, + true, + ); + + router.prefill_workers.write().unwrap().push(worker1); + router.prefill_workers.write().unwrap().push(worker2); + + // Remove one + let result = router.remove_prefill_server("http://worker1").await; + + assert!(result.is_ok()); + assert!(result.unwrap().contains("Successfully removed")); + + let workers = router.prefill_workers.read().unwrap(); + assert_eq!(workers.len(), 1); + assert_eq!(workers[0].url(), "http://worker2"); + } + + #[tokio::test] + async fn test_remove_prefill_server_not_found() { + let router = create_test_pd_router(); + + let result = router.remove_prefill_server("http://nonexistent").await; + + assert!(result.is_err()); + match result.unwrap_err() { + PDRouterError::WorkerNotFound { url } => { + assert_eq!(url, "http://nonexistent"); + } + _ => panic!("Expected WorkerNotFound error"), + } + } + + #[tokio::test] + async fn test_remove_decode_server_success() { + let router = create_test_pd_router(); + + // Add server first + let worker = create_test_worker("http://decode1".to_string(), WorkerType::Decode, true); + router.decode_workers.write().unwrap().push(worker); + + let result = router.remove_decode_server("http://decode1").await; + + assert!(result.is_ok()); + assert!(result.unwrap().contains("Successfully removed")); + + let workers = router.decode_workers.read().unwrap(); + assert_eq!(workers.len(), 0); + } + + // ============= Lock Error Handling Tests ============= + + #[test] + fn test_lock_operations() { + let router = create_test_pd_router(); + + // Test read/write locks work correctly + { + let read_guard = router.prefill_workers.read().unwrap(); + assert_eq!(read_guard.len(), 0); + } + + { + let mut write_guard = router.prefill_workers.write().unwrap(); + write_guard.push(create_test_worker( + "http://test".to_string(), + WorkerType::Prefill { + bootstrap_port: None, + }, + true, + )); + } + + { + let read_guard = router.prefill_workers.read().unwrap(); + assert_eq!(read_guard.len(), 1); + } + } + + // ============= Cache Tree Integration Tests ============= + + #[tokio::test] + async fn test_cache_tree_operations() { + let policy = Arc::new(CacheAwarePolicy::new()); + let mut router = create_test_pd_router(); + router.policy = policy; + + // Initialize cache tree + let tree = Arc::new(Mutex::new(Tree::new())); + router.prefill_tree = Some(Arc::clone(&tree)); + + // Manually add worker and update tree + let worker = create_test_worker( + "http://worker1".to_string(), + WorkerType::Prefill { + bootstrap_port: None, + }, + true, + ); + router.prefill_workers.write().unwrap().push(worker); + + // Update tree + tree.lock().unwrap().insert("", "http://worker1"); + + // Verify tree contains the worker + let tree_guard = tree.lock().unwrap(); + let (_matched_text, tenant) = tree_guard.prefix_match(""); + // Since we inserted with empty prefix, we should get a match + assert_eq!(tenant, "http://worker1"); + } + + #[tokio::test] + async fn test_cache_tree_rebuild_on_remove() { + let policy = Arc::new(CacheAwarePolicy::new()); + let mut router = create_test_pd_router(); + router.policy = policy; + + // Initialize cache tree + let tree = Arc::new(Mutex::new(Tree::new())); + router.prefill_tree = Some(Arc::clone(&tree)); + + // Add multiple workers + let worker1 = create_test_worker( + "http://worker1".to_string(), + WorkerType::Prefill { + bootstrap_port: None, + }, + true, + ); + let worker2 = create_test_worker( + "http://worker2".to_string(), + WorkerType::Prefill { + bootstrap_port: None, + }, + true, + ); + + router.prefill_workers.write().unwrap().push(worker1); + router.prefill_workers.write().unwrap().push(worker2); + + // Initialize tree with both workers + { + let tree_guard = tree.lock().unwrap(); + tree_guard.insert("", "http://worker1"); + tree_guard.insert("", "http://worker2"); + } + + // Remove one worker + let result = router.remove_prefill_server("http://worker1").await; + assert!(result.is_ok()); + + // Verify tree only contains remaining worker + let tree_guard = tree.lock().unwrap(); + let (_matched_text, tenant) = tree_guard.prefix_match(""); + // After rebuild, tree should only have worker2 + assert_eq!(tenant, "http://worker2"); + } + + #[tokio::test] + async fn test_no_cache_tree_operations() { + let router = create_test_pd_router(); + assert!(router.prefill_tree.is_none()); + + // Add a worker without cache tree + let worker = create_test_worker( + "http://worker1".to_string(), + WorkerType::Prefill { + bootstrap_port: None, + }, + true, + ); + router.prefill_workers.write().unwrap().push(worker); + + // Remove should work without tree + let result = router.remove_prefill_server("http://worker1").await; + assert!(result.is_ok()); + } + + // ============= Bootstrap Injection Tests ============= + + #[test] + fn test_bootstrap_injection_with_existing_fields() { + let mut req = GenerateReqInput { + text: Some(SingleOrBatch::Single("Test".to_string())), + input_ids: None, + stream: false, + bootstrap_host: Some(SingleOrBatch::Single("existing-host".to_string())), + bootstrap_port: Some(SingleOrBatch::Single(Some(9999))), + bootstrap_room: Some(SingleOrBatch::Single(12345)), + other: Value::Object(serde_json::Map::new()), + }; + + let prefill_worker = create_test_worker( + "http://new-host:8000".to_string(), + WorkerType::Prefill { + bootstrap_port: Some(8080), + }, + true, + ); + + // Bootstrap info is added regardless of existing fields + let result = req.add_bootstrap_info(prefill_worker.as_ref()); + assert!(result.is_ok()); + + // Bootstrap info should be updated with new values + assert_eq!( + req.bootstrap_host, + Some(SingleOrBatch::Single("new-host".to_string())) + ); + assert_eq!(req.bootstrap_port, Some(SingleOrBatch::Single(Some(8080)))); + // Room should be regenerated (different from original) + if let Some(SingleOrBatch::Single(room)) = req.bootstrap_room { + assert_ne!(room, 12345); + } else { + panic!("Expected single room ID"); + } + } + + #[test] + fn test_bootstrap_room_generation() { + let mut req1 = GenerateReqInput { + text: Some(SingleOrBatch::Single("Test".to_string())), + input_ids: None, + stream: false, + bootstrap_host: None, + bootstrap_port: None, + bootstrap_room: None, + other: Value::Object(serde_json::Map::new()), + }; + + let mut req2 = GenerateReqInput { + text: Some(SingleOrBatch::Single("Test".to_string())), + input_ids: None, + stream: false, + bootstrap_host: None, + bootstrap_port: None, + bootstrap_room: None, + other: Value::Object(serde_json::Map::new()), + }; + + let prefill_worker = create_test_worker( + "http://host:8000".to_string(), + WorkerType::Prefill { + bootstrap_port: Some(8080), + }, + true, + ); + + // Add bootstrap info to both requests + let _ = req1.add_bootstrap_info(prefill_worker.as_ref()); + let _ = req2.add_bootstrap_info(prefill_worker.as_ref()); + + // Room IDs should be different + if let (Some(SingleOrBatch::Single(room1)), Some(SingleOrBatch::Single(room2))) = + (req1.bootstrap_room, req2.bootstrap_room) + { + assert_ne!(room1, room2, "Room IDs should be unique"); + } else { + panic!("Expected single room IDs"); + } + } + + // ============= Worker Selection Tests ============= + + #[tokio::test] + async fn test_select_healthy_prefill_worker() { + let router = create_test_pd_router(); + + // Add mix of healthy and unhealthy workers + let healthy_worker = create_test_worker( + "http://healthy".to_string(), + WorkerType::Prefill { + bootstrap_port: None, + }, + true, + ); + let unhealthy_worker = create_test_worker( + "http://unhealthy".to_string(), + WorkerType::Prefill { + bootstrap_port: None, + }, + false, + ); + let decode_worker = + create_test_worker("http://decode".to_string(), WorkerType::Decode, true); + + router + .prefill_workers + .write() + .unwrap() + .push(unhealthy_worker); + router.prefill_workers.write().unwrap().push(healthy_worker); + router.decode_workers.write().unwrap().push(decode_worker); + + let client = reqwest::Client::new(); + let result = router.select_pd_pair(&client, None).await; + + assert!(result.is_ok()); + let (prefill, _decode) = result.unwrap(); + + // Should select the healthy worker + assert_eq!(prefill.url(), "http://healthy"); + assert!(prefill.is_healthy()); + } + + #[tokio::test] + async fn test_empty_worker_lists() { + let router = create_test_pd_router(); + + let client = reqwest::Client::new(); + let result = router.select_pd_pair(&client, None).await; + + assert!(result.is_err()); + assert!(result.unwrap_err().contains("No prefill workers available")); + } + + // ============= Health Endpoints Tests ============= + + #[tokio::test] + async fn test_health_endpoints() { + let router = create_test_pd_router(); + + // Add healthy workers + let prefill_worker = create_test_worker( + "http://localhost:8000".to_string(), + WorkerType::Prefill { + bootstrap_port: None, + }, + true, + ); + let decode_worker = create_test_worker( + "http://localhost:8001".to_string(), + WorkerType::Decode, + true, + ); + + router.prefill_workers.write().unwrap().push(prefill_worker); + router.decode_workers.write().unwrap().push(decode_worker); + + // Test health endpoint + let client = reqwest::Client::new(); + let http_req = TestRequest::default().to_http_request(); + let response = router.health(&client, &http_req).await; + + assert_eq!(response.status(), 200); + + // Test readiness endpoint + let response = router.readiness(); + assert_eq!(response.status(), 200); + } + + // ============= Load Monitoring Tests ============= + + #[tokio::test] + async fn test_load_monitor_updates() { + let policy = Arc::new(crate::policies::PowerOfTwoPolicy::new()); + let mut router = create_test_pd_router(); + router.policy = policy; + + // Create load channel + let (tx, rx) = tokio::sync::watch::channel(HashMap::new()); + router.worker_loads = Arc::new(rx); + + // Simulate load updates + let mut loads = HashMap::new(); + loads.insert("http://worker1".to_string(), 10); + loads.insert("http://worker2".to_string(), 5); + + let _ = tx.send(loads.clone()); + + // Router should receive updates + let received = router.worker_loads.borrow().clone(); + assert_eq!(received.get("http://worker1"), Some(&10)); + assert_eq!(received.get("http://worker2"), Some(&5)); + } + + // ============= Worker Load Tests ============= + + #[test] + fn test_worker_load_metrics() { + let prefill_worker = create_test_worker( + "http://prefill".to_string(), + WorkerType::Prefill { + bootstrap_port: None, + }, + true, + ); + let decode_worker = + create_test_worker("http://decode".to_string(), WorkerType::Decode, true); + + // Create load guard for both workers + let _guard = + WorkerLoadGuard::new_multi(vec![prefill_worker.as_ref(), decode_worker.as_ref()]); + + // Load should be incremented + assert_eq!(prefill_worker.load(), 1); + assert_eq!(decode_worker.load(), 1); + + // Drop guard - load should decrement + drop(_guard); + + assert_eq!(prefill_worker.load(), 0); + assert_eq!(decode_worker.load(), 0); + } + + // ============= Concurrent Operations Tests ============= + + #[tokio::test] + async fn test_concurrent_worker_operations() { + let router = Arc::new(create_test_pd_router()); + + let mut handles = vec![]; + + // Spawn tasks to add workers + for i in 0..5 { + let router_clone = Arc::clone(&router); + let url = format!("http://worker{}", i); + let handle = tokio::spawn(async move { + let worker = create_test_worker( + url, + WorkerType::Prefill { + bootstrap_port: None, + }, + true, + ); + router_clone.prefill_workers.write().unwrap().push(worker); + }); + handles.push(handle); + } + + // Wait for all tasks + for handle in handles { + let _ = handle.await; + } + + // Check final state + let workers = router.prefill_workers.read().unwrap(); + assert_eq!(workers.len(), 5); + } +} diff --git a/sgl-router/tests/test_pd_routing.rs b/sgl-router/tests/test_pd_routing.rs index ceb5fe9e6..a2c0d7e31 100644 --- a/sgl-router/tests/test_pd_routing.rs +++ b/sgl-router/tests/test_pd_routing.rs @@ -1,16 +1,3 @@ -//! Comprehensive tests for PrefillDecode (PD) routing functionality -//! -//! This test suite covers: -//! - Phase 1: Basic PD router creation and configuration -//! - Phase 2: Bootstrap injection and request handling -//! - Phase 3: Cache-aware selection (when implemented) -//! -//! Note: PD mode is enabled via the pd_disaggregation flag, not as a policy type. -//! The policy type (Random, PowerOfTwo, CacheAware) determines the selection algorithm within PD mode. - -// TODO: This test file needs to be updated for the new configuration structure -// where RoutingMode and PolicyConfig are separate - #[cfg(test)] mod test_pd_routing { use rand::Rng; @@ -921,14 +908,6 @@ mod test_pd_routing { #[test] fn test_policy_type_to_pd_selection_policy_mapping() { - // Document the mapping from PolicyType to PDSelectionPolicy - // This mapping happens in lib.rs when pd_disaggregation=true - - // PolicyType::Random -> PDSelectionPolicy::Random - // PolicyType::PowerOfTwo -> PDSelectionPolicy::PowerOfTwo - // PolicyType::CacheAware -> PDSelectionPolicy::CacheAware { ... } - // PolicyType::RoundRobin -> ERROR (not supported in PD mode) - // Test that PDSelectionPolicy doesn't include RoundRobin let pd_policy_count = 3; // Random, PowerOfTwo, CacheAware assert_eq!(