[router] add ut for pd router (#8208)
This commit is contained in:
@@ -1393,3 +1393,515 @@ impl RouterTrait for PDRouter {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::core::{BasicWorker, WorkerType};
|
||||
use crate::policies::{CacheAwarePolicy, RandomPolicy};
|
||||
use crate::routers::pd_types::SingleOrBatch;
|
||||
use actix_web::test::TestRequest;
|
||||
|
||||
fn create_test_pd_router() -> PDRouter {
|
||||
let policy = Arc::new(RandomPolicy::new());
|
||||
|
||||
PDRouter {
|
||||
prefill_workers: Arc::new(RwLock::new(vec![])),
|
||||
decode_workers: Arc::new(RwLock::new(vec![])),
|
||||
policy,
|
||||
prefill_tree: None,
|
||||
timeout_secs: 5,
|
||||
interval_secs: 1,
|
||||
worker_loads: Arc::new(tokio::sync::watch::channel(HashMap::new()).1),
|
||||
load_monitor_handle: None,
|
||||
http_client: reqwest::Client::new(),
|
||||
_prefill_health_checker: None,
|
||||
_decode_health_checker: None,
|
||||
}
|
||||
}
|
||||
|
||||
fn create_test_worker(url: String, worker_type: WorkerType, healthy: bool) -> Box<dyn Worker> {
|
||||
let worker = BasicWorker::new(url, worker_type);
|
||||
worker.set_healthy(healthy);
|
||||
Box::new(worker)
|
||||
}
|
||||
|
||||
// ============= Worker Management Tests =============
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_add_prefill_server_already_exists() {
|
||||
let router = create_test_pd_router();
|
||||
|
||||
// Add a worker first
|
||||
let worker = create_test_worker(
|
||||
"http://localhost:8000".to_string(),
|
||||
WorkerType::Prefill {
|
||||
bootstrap_port: Some(8080),
|
||||
},
|
||||
true,
|
||||
);
|
||||
router.prefill_workers.write().unwrap().push(worker);
|
||||
|
||||
// Try to add the same URL again - this would fail during health check in real scenario
|
||||
// For unit test, we test the duplicate check logic
|
||||
let workers = router.prefill_workers.read().unwrap();
|
||||
let exists = workers.iter().any(|w| w.url() == "http://localhost:8000");
|
||||
assert!(exists);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_remove_prefill_server_success() {
|
||||
let router = create_test_pd_router();
|
||||
|
||||
// Add servers first
|
||||
let worker1 = create_test_worker(
|
||||
"http://worker1".to_string(),
|
||||
WorkerType::Prefill {
|
||||
bootstrap_port: None,
|
||||
},
|
||||
true,
|
||||
);
|
||||
let worker2 = create_test_worker(
|
||||
"http://worker2".to_string(),
|
||||
WorkerType::Prefill {
|
||||
bootstrap_port: Some(8080),
|
||||
},
|
||||
true,
|
||||
);
|
||||
|
||||
router.prefill_workers.write().unwrap().push(worker1);
|
||||
router.prefill_workers.write().unwrap().push(worker2);
|
||||
|
||||
// Remove one
|
||||
let result = router.remove_prefill_server("http://worker1").await;
|
||||
|
||||
assert!(result.is_ok());
|
||||
assert!(result.unwrap().contains("Successfully removed"));
|
||||
|
||||
let workers = router.prefill_workers.read().unwrap();
|
||||
assert_eq!(workers.len(), 1);
|
||||
assert_eq!(workers[0].url(), "http://worker2");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_remove_prefill_server_not_found() {
|
||||
let router = create_test_pd_router();
|
||||
|
||||
let result = router.remove_prefill_server("http://nonexistent").await;
|
||||
|
||||
assert!(result.is_err());
|
||||
match result.unwrap_err() {
|
||||
PDRouterError::WorkerNotFound { url } => {
|
||||
assert_eq!(url, "http://nonexistent");
|
||||
}
|
||||
_ => panic!("Expected WorkerNotFound error"),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_remove_decode_server_success() {
|
||||
let router = create_test_pd_router();
|
||||
|
||||
// Add server first
|
||||
let worker = create_test_worker("http://decode1".to_string(), WorkerType::Decode, true);
|
||||
router.decode_workers.write().unwrap().push(worker);
|
||||
|
||||
let result = router.remove_decode_server("http://decode1").await;
|
||||
|
||||
assert!(result.is_ok());
|
||||
assert!(result.unwrap().contains("Successfully removed"));
|
||||
|
||||
let workers = router.decode_workers.read().unwrap();
|
||||
assert_eq!(workers.len(), 0);
|
||||
}
|
||||
|
||||
// ============= Lock Error Handling Tests =============
|
||||
|
||||
#[test]
fn test_lock_operations() {
    let router = create_test_pd_router();

    // Each guard below is a temporary released at the end of its statement,
    // so reads and writes never overlap.

    // A fresh router has no prefill workers.
    assert_eq!(router.prefill_workers.read().unwrap().len(), 0);

    // Push one worker under the write lock.
    router.prefill_workers.write().unwrap().push(create_test_worker(
        "http://test".to_string(),
        WorkerType::Prefill {
            bootstrap_port: None,
        },
        true,
    ));

    // The write must be visible to a later reader.
    assert_eq!(router.prefill_workers.read().unwrap().len(), 1);
}
|
||||
|
||||
// ============= Cache Tree Integration Tests =============
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_cache_tree_operations() {
|
||||
let policy = Arc::new(CacheAwarePolicy::new());
|
||||
let mut router = create_test_pd_router();
|
||||
router.policy = policy;
|
||||
|
||||
// Initialize cache tree
|
||||
let tree = Arc::new(Mutex::new(Tree::new()));
|
||||
router.prefill_tree = Some(Arc::clone(&tree));
|
||||
|
||||
// Manually add worker and update tree
|
||||
let worker = create_test_worker(
|
||||
"http://worker1".to_string(),
|
||||
WorkerType::Prefill {
|
||||
bootstrap_port: None,
|
||||
},
|
||||
true,
|
||||
);
|
||||
router.prefill_workers.write().unwrap().push(worker);
|
||||
|
||||
// Update tree
|
||||
tree.lock().unwrap().insert("", "http://worker1");
|
||||
|
||||
// Verify tree contains the worker
|
||||
let tree_guard = tree.lock().unwrap();
|
||||
let (_matched_text, tenant) = tree_guard.prefix_match("");
|
||||
// Since we inserted with empty prefix, we should get a match
|
||||
assert_eq!(tenant, "http://worker1");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_cache_tree_rebuild_on_remove() {
|
||||
let policy = Arc::new(CacheAwarePolicy::new());
|
||||
let mut router = create_test_pd_router();
|
||||
router.policy = policy;
|
||||
|
||||
// Initialize cache tree
|
||||
let tree = Arc::new(Mutex::new(Tree::new()));
|
||||
router.prefill_tree = Some(Arc::clone(&tree));
|
||||
|
||||
// Add multiple workers
|
||||
let worker1 = create_test_worker(
|
||||
"http://worker1".to_string(),
|
||||
WorkerType::Prefill {
|
||||
bootstrap_port: None,
|
||||
},
|
||||
true,
|
||||
);
|
||||
let worker2 = create_test_worker(
|
||||
"http://worker2".to_string(),
|
||||
WorkerType::Prefill {
|
||||
bootstrap_port: None,
|
||||
},
|
||||
true,
|
||||
);
|
||||
|
||||
router.prefill_workers.write().unwrap().push(worker1);
|
||||
router.prefill_workers.write().unwrap().push(worker2);
|
||||
|
||||
// Initialize tree with both workers
|
||||
{
|
||||
let tree_guard = tree.lock().unwrap();
|
||||
tree_guard.insert("", "http://worker1");
|
||||
tree_guard.insert("", "http://worker2");
|
||||
}
|
||||
|
||||
// Remove one worker
|
||||
let result = router.remove_prefill_server("http://worker1").await;
|
||||
assert!(result.is_ok());
|
||||
|
||||
// Verify tree only contains remaining worker
|
||||
let tree_guard = tree.lock().unwrap();
|
||||
let (_matched_text, tenant) = tree_guard.prefix_match("");
|
||||
// After rebuild, tree should only have worker2
|
||||
assert_eq!(tenant, "http://worker2");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_no_cache_tree_operations() {
|
||||
let router = create_test_pd_router();
|
||||
assert!(router.prefill_tree.is_none());
|
||||
|
||||
// Add a worker without cache tree
|
||||
let worker = create_test_worker(
|
||||
"http://worker1".to_string(),
|
||||
WorkerType::Prefill {
|
||||
bootstrap_port: None,
|
||||
},
|
||||
true,
|
||||
);
|
||||
router.prefill_workers.write().unwrap().push(worker);
|
||||
|
||||
// Remove should work without tree
|
||||
let result = router.remove_prefill_server("http://worker1").await;
|
||||
assert!(result.is_ok());
|
||||
}
|
||||
|
||||
// ============= Bootstrap Injection Tests =============
|
||||
|
||||
#[test]
fn test_bootstrap_injection_with_existing_fields() {
    // Request that already carries bootstrap host, port and room values.
    let mut request = GenerateReqInput {
        text: Some(SingleOrBatch::Single("Test".to_string())),
        input_ids: None,
        stream: false,
        bootstrap_host: Some(SingleOrBatch::Single("existing-host".to_string())),
        bootstrap_port: Some(SingleOrBatch::Single(Some(9999))),
        bootstrap_room: Some(SingleOrBatch::Single(12345)),
        other: Value::Object(serde_json::Map::new()),
    };

    let target = create_test_worker(
        "http://new-host:8000".to_string(),
        WorkerType::Prefill {
            bootstrap_port: Some(8080),
        },
        true,
    );

    // Injection is expected to succeed regardless of the pre-populated fields.
    let outcome = request.add_bootstrap_info(target.as_ref());
    assert!(outcome.is_ok());

    // Host and port must now reflect the selected prefill worker.
    let expected_host = Some(SingleOrBatch::Single("new-host".to_string()));
    assert_eq!(request.bootstrap_host, expected_host);
    let expected_port = Some(SingleOrBatch::Single(Some(8080)));
    assert_eq!(request.bootstrap_port, expected_port);

    // The room must have been regenerated, i.e. differ from the original.
    match &request.bootstrap_room {
        Some(SingleOrBatch::Single(room)) => assert_ne!(*room, 12345),
        _ => panic!("Expected single room ID"),
    }
}
|
||||
|
||||
#[test]
fn test_bootstrap_room_generation() {
    // Both requests start out identical and without any bootstrap data.
    let blank_request = || GenerateReqInput {
        text: Some(SingleOrBatch::Single("Test".to_string())),
        input_ids: None,
        stream: false,
        bootstrap_host: None,
        bootstrap_port: None,
        bootstrap_room: None,
        other: Value::Object(serde_json::Map::new()),
    };
    let mut req1 = blank_request();
    let mut req2 = blank_request();

    let prefill_worker = create_test_worker(
        "http://host:8000".to_string(),
        WorkerType::Prefill {
            bootstrap_port: Some(8080),
        },
        true,
    );

    // Inject bootstrap info into both requests. The results are checked so
    // that an injection failure is reported as such, not as the
    // "Expected single room IDs" panic further down.
    assert!(
        req1.add_bootstrap_info(prefill_worker.as_ref()).is_ok(),
        "bootstrap injection failed for the first request"
    );
    assert!(
        req2.add_bootstrap_info(prefill_worker.as_ref()).is_ok(),
        "bootstrap injection failed for the second request"
    );

    // Each injection should have produced its own room ID.
    if let (Some(SingleOrBatch::Single(room1)), Some(SingleOrBatch::Single(room2))) =
        (req1.bootstrap_room, req2.bootstrap_room)
    {
        assert_ne!(room1, room2, "Room IDs should be unique");
    } else {
        panic!("Expected single room IDs");
    }
}
|
||||
|
||||
// ============= Worker Selection Tests =============
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_select_healthy_prefill_worker() {
|
||||
let router = create_test_pd_router();
|
||||
|
||||
// Add mix of healthy and unhealthy workers
|
||||
let healthy_worker = create_test_worker(
|
||||
"http://healthy".to_string(),
|
||||
WorkerType::Prefill {
|
||||
bootstrap_port: None,
|
||||
},
|
||||
true,
|
||||
);
|
||||
let unhealthy_worker = create_test_worker(
|
||||
"http://unhealthy".to_string(),
|
||||
WorkerType::Prefill {
|
||||
bootstrap_port: None,
|
||||
},
|
||||
false,
|
||||
);
|
||||
let decode_worker =
|
||||
create_test_worker("http://decode".to_string(), WorkerType::Decode, true);
|
||||
|
||||
router
|
||||
.prefill_workers
|
||||
.write()
|
||||
.unwrap()
|
||||
.push(unhealthy_worker);
|
||||
router.prefill_workers.write().unwrap().push(healthy_worker);
|
||||
router.decode_workers.write().unwrap().push(decode_worker);
|
||||
|
||||
let client = reqwest::Client::new();
|
||||
let result = router.select_pd_pair(&client, None).await;
|
||||
|
||||
assert!(result.is_ok());
|
||||
let (prefill, _decode) = result.unwrap();
|
||||
|
||||
// Should select the healthy worker
|
||||
assert_eq!(prefill.url(), "http://healthy");
|
||||
assert!(prefill.is_healthy());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_empty_worker_lists() {
|
||||
let router = create_test_pd_router();
|
||||
|
||||
let client = reqwest::Client::new();
|
||||
let result = router.select_pd_pair(&client, None).await;
|
||||
|
||||
assert!(result.is_err());
|
||||
assert!(result.unwrap_err().contains("No prefill workers available"));
|
||||
}
|
||||
|
||||
// ============= Health Endpoints Tests =============
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_health_endpoints() {
|
||||
let router = create_test_pd_router();
|
||||
|
||||
// Add healthy workers
|
||||
let prefill_worker = create_test_worker(
|
||||
"http://localhost:8000".to_string(),
|
||||
WorkerType::Prefill {
|
||||
bootstrap_port: None,
|
||||
},
|
||||
true,
|
||||
);
|
||||
let decode_worker = create_test_worker(
|
||||
"http://localhost:8001".to_string(),
|
||||
WorkerType::Decode,
|
||||
true,
|
||||
);
|
||||
|
||||
router.prefill_workers.write().unwrap().push(prefill_worker);
|
||||
router.decode_workers.write().unwrap().push(decode_worker);
|
||||
|
||||
// Test health endpoint
|
||||
let client = reqwest::Client::new();
|
||||
let http_req = TestRequest::default().to_http_request();
|
||||
let response = router.health(&client, &http_req).await;
|
||||
|
||||
assert_eq!(response.status(), 200);
|
||||
|
||||
// Test readiness endpoint
|
||||
let response = router.readiness();
|
||||
assert_eq!(response.status(), 200);
|
||||
}
|
||||
|
||||
// ============= Load Monitoring Tests =============
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_load_monitor_updates() {
|
||||
let policy = Arc::new(crate::policies::PowerOfTwoPolicy::new());
|
||||
let mut router = create_test_pd_router();
|
||||
router.policy = policy;
|
||||
|
||||
// Create load channel
|
||||
let (tx, rx) = tokio::sync::watch::channel(HashMap::new());
|
||||
router.worker_loads = Arc::new(rx);
|
||||
|
||||
// Simulate load updates
|
||||
let mut loads = HashMap::new();
|
||||
loads.insert("http://worker1".to_string(), 10);
|
||||
loads.insert("http://worker2".to_string(), 5);
|
||||
|
||||
let _ = tx.send(loads.clone());
|
||||
|
||||
// Router should receive updates
|
||||
let received = router.worker_loads.borrow().clone();
|
||||
assert_eq!(received.get("http://worker1"), Some(&10));
|
||||
assert_eq!(received.get("http://worker2"), Some(&5));
|
||||
}
|
||||
|
||||
// ============= Worker Load Tests =============
|
||||
|
||||
#[test]
fn test_worker_load_metrics() {
    let prefill = create_test_worker(
        "http://prefill".to_string(),
        WorkerType::Prefill {
            bootstrap_port: None,
        },
        true,
    );
    let decode = create_test_worker("http://decode".to_string(), WorkerType::Decode, true);

    // One guard covering both workers.
    let tracked = vec![prefill.as_ref(), decode.as_ref()];
    let load_guard = WorkerLoadGuard::new_multi(tracked);

    // While the guard is alive, each worker reports a load of one.
    assert_eq!(prefill.load(), 1);
    assert_eq!(decode.load(), 1);

    // Dropping the guard must bring both loads back to zero.
    drop(load_guard);
    assert_eq!(prefill.load(), 0);
    assert_eq!(decode.load(), 0);
}
|
||||
|
||||
// ============= Concurrent Operations Tests =============
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_concurrent_worker_operations() {
|
||||
let router = Arc::new(create_test_pd_router());
|
||||
|
||||
let mut handles = vec![];
|
||||
|
||||
// Spawn tasks to add workers
|
||||
for i in 0..5 {
|
||||
let router_clone = Arc::clone(&router);
|
||||
let url = format!("http://worker{}", i);
|
||||
let handle = tokio::spawn(async move {
|
||||
let worker = create_test_worker(
|
||||
url,
|
||||
WorkerType::Prefill {
|
||||
bootstrap_port: None,
|
||||
},
|
||||
true,
|
||||
);
|
||||
router_clone.prefill_workers.write().unwrap().push(worker);
|
||||
});
|
||||
handles.push(handle);
|
||||
}
|
||||
|
||||
// Wait for all tasks
|
||||
for handle in handles {
|
||||
let _ = handle.await;
|
||||
}
|
||||
|
||||
// Check final state
|
||||
let workers = router.prefill_workers.read().unwrap();
|
||||
assert_eq!(workers.len(), 5);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,16 +1,3 @@
|
||||
//! Comprehensive tests for PrefillDecode (PD) routing functionality
|
||||
//!
|
||||
//! This test suite covers:
|
||||
//! - Phase 1: Basic PD router creation and configuration
|
||||
//! - Phase 2: Bootstrap injection and request handling
|
||||
//! - Phase 3: Cache-aware selection (when implemented)
|
||||
//!
|
||||
//! Note: PD mode is enabled via the pd_disaggregation flag, not as a policy type.
|
||||
//! The policy type (Random, PowerOfTwo, CacheAware) determines the selection algorithm within PD mode.
|
||||
|
||||
// TODO: This test file needs to be updated for the new configuration structure
|
||||
// where RoutingMode and PolicyConfig are separate
|
||||
|
||||
#[cfg(test)]
|
||||
mod test_pd_routing {
|
||||
use rand::Rng;
|
||||
@@ -921,14 +908,6 @@ mod test_pd_routing {
|
||||
|
||||
#[test]
|
||||
fn test_policy_type_to_pd_selection_policy_mapping() {
|
||||
// Document the mapping from PolicyType to PDSelectionPolicy
|
||||
// This mapping happens in lib.rs when pd_disaggregation=true
|
||||
|
||||
// PolicyType::Random -> PDSelectionPolicy::Random
|
||||
// PolicyType::PowerOfTwo -> PDSelectionPolicy::PowerOfTwo
|
||||
// PolicyType::CacheAware -> PDSelectionPolicy::CacheAware { ... }
|
||||
// PolicyType::RoundRobin -> ERROR (not supported in PD mode)
|
||||
|
||||
// Test that PDSelectionPolicy doesn't include RoundRobin
|
||||
let pd_policy_count = 3; // Random, PowerOfTwo, CacheAware
|
||||
assert_eq!(
|
||||
|
||||
Reference in New Issue
Block a user