From 41d33e4736707cea54aa731055cf88f367befefc Mon Sep 17 00:00:00 2001
From: Simo Lin
Date: Sat, 19 Jul 2025 14:38:33 -0700
Subject: [PATCH] [router] add ut for worker and errors (#8170)

---
 sgl-router/src/core/error.rs  | 179 ++++++++++
 sgl-router/src/core/worker.rs | 610 ++++++++++++++++++++++++++++++++++
 2 files changed, 789 insertions(+)

diff --git a/sgl-router/src/core/error.rs b/sgl-router/src/core/error.rs
index 02a87dbbc..4d50ccee0 100644
--- a/sgl-router/src/core/error.rs
+++ b/sgl-router/src/core/error.rs
@@ -55,3 +55,182 @@ impl From<reqwest::Error> for WorkerError {
         }
     }
 }
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use std::error::Error;
+
+    #[test]
+    fn test_health_check_failed_display() {
+        let error = WorkerError::HealthCheckFailed {
+            url: "http://worker1:8080".to_string(),
+            reason: "Connection refused".to_string(),
+        };
+        assert_eq!(
+            error.to_string(),
+            "Health check failed for worker http://worker1:8080: Connection refused"
+        );
+    }
+
+    #[test]
+    fn test_worker_not_found_display() {
+        let error = WorkerError::WorkerNotFound {
+            url: "http://worker2:8080".to_string(),
+        };
+        assert_eq!(error.to_string(), "Worker not found: http://worker2:8080");
+    }
+
+    #[test]
+    fn test_invalid_configuration_display() {
+        let error = WorkerError::InvalidConfiguration {
+            message: "Missing port number".to_string(),
+        };
+        assert_eq!(
+            error.to_string(),
+            "Invalid worker configuration: Missing port number"
+        );
+    }
+
+    #[test]
+    fn test_network_error_display() {
+        let error = WorkerError::NetworkError {
+            url: "http://worker3:8080".to_string(),
+            error: "Timeout after 30s".to_string(),
+        };
+        assert_eq!(
+            error.to_string(),
+            "Network error for worker http://worker3:8080: Timeout after 30s"
+        );
+    }
+
+    #[test]
+    fn test_worker_at_capacity_display() {
+        let error = WorkerError::WorkerAtCapacity {
+            url: "http://worker4:8080".to_string(),
+        };
+        assert_eq!(error.to_string(), "Worker at capacity: http://worker4:8080");
+    }
+
+    #[test]
+    fn test_worker_error_implements_std_error() {
+        let error = WorkerError::WorkerNotFound {
+            url: "http://test".to_string(),
+        };
+        // Verify it implements Error trait
+        let _: &dyn Error = &error;
+        assert!(error.source().is_none());
+    }
+
+    #[test]
+    fn test_error_send_sync() {
+        fn assert_send_sync<T: Send + Sync>() {}
+        assert_send_sync::<WorkerError>();
+    }
+
+    #[test]
+    fn test_worker_result_type_alias() {
+        // Test Ok variant
+        let result: WorkerResult<i32> = Ok(42);
+        assert!(result.is_ok());
+        assert_eq!(result.unwrap(), 42);
+
+        // Test Err variant
+        let error = WorkerError::WorkerNotFound {
+            url: "test".to_string(),
+        };
+        let result: WorkerResult<i32> = Err(error);
+        assert!(result.is_err());
+    }
+
+    #[test]
+    fn test_empty_url_handling() {
+        // Test empty URLs in error variants
+        let error1 = WorkerError::HealthCheckFailed {
+            url: "".to_string(),
+            reason: "No connection".to_string(),
+        };
+        assert_eq!(
+            error1.to_string(),
+            "Health check failed for worker : No connection"
+        );
+
+        let error2 = WorkerError::NetworkError {
+            url: "".to_string(),
+            error: "DNS failure".to_string(),
+        };
+        assert_eq!(error2.to_string(), "Network error for worker : DNS failure");
+
+        let error3 = WorkerError::WorkerNotFound {
+            url: "".to_string(),
+        };
+        assert_eq!(error3.to_string(), "Worker not found: ");
+    }
+
+    #[test]
+    fn test_special_characters_in_messages() {
+        // Test with special characters
+        let error = WorkerError::InvalidConfiguration {
+            message: "Invalid JSON: {\"error\": \"test\"}".to_string(),
+        };
+        assert_eq!(
+            error.to_string(),
+            "Invalid worker configuration: Invalid JSON: {\"error\": \"test\"}"
+        );
+
+        // Test with unicode
+        let error2 = WorkerError::HealthCheckFailed {
+            url: "http://测试:8080".to_string(),
+            reason: "连接被拒绝".to_string(),
+        };
+        assert_eq!(
+            error2.to_string(),
+            "Health check failed for worker http://测试:8080: 连接被拒绝"
+        );
+    }
+
+    #[test]
+    fn test_very_long_error_messages() {
+        let long_message = "A".repeat(10000);
+        let error = WorkerError::InvalidConfiguration {
+            message: long_message.clone(),
+        };
+        let display = error.to_string();
+        assert!(display.contains(&long_message));
+        assert_eq!(
+            display.len(),
+            "Invalid worker configuration: ".len() + long_message.len()
+        );
+    }
+
+    // Only checks NetworkError's fields; the From conversion itself is not exercised here
+    #[test]
+    fn test_reqwest_error_conversion() {
+        // Test that NetworkError is the correct variant
+        let network_error = WorkerError::NetworkError {
+            url: "http://example.com".to_string(),
+            error: "connection timeout".to_string(),
+        };
+
+        match network_error {
+            WorkerError::NetworkError { url, error } => {
+                assert_eq!(url, "http://example.com");
+                assert_eq!(error, "connection timeout");
+            }
+            _ => panic!("Expected NetworkError variant"),
+        }
+    }
+
+    #[test]
+    fn test_error_equality() {
+        // WorkerError doesn't implement PartialEq, but we can test that
+        // the same error construction produces the same display output
+        let error1 = WorkerError::WorkerNotFound {
+            url: "http://test".to_string(),
+        };
+        let error2 = WorkerError::WorkerNotFound {
+            url: "http://test".to_string(),
+        };
+        assert_eq!(error1.to_string(), error2.to_string());
+    }
+}
diff --git a/sgl-router/src/core/worker.rs b/sgl-router/src/core/worker.rs
index ae88bdd1c..1aa6766c1 100644
--- a/sgl-router/src/core/worker.rs
+++ b/sgl-router/src/core/worker.rs
@@ -452,3 +452,613 @@ pub fn start_health_checker(
 
     HealthChecker { handle, shutdown }
 }
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use std::sync::RwLock;
+    use std::time::Duration;
+    use tokio::time::timeout;
+
+    // Test WorkerType
+    #[test]
+    fn test_worker_type_display() {
+        assert_eq!(WorkerType::Regular.to_string(), "Regular");
+        assert_eq!(
+            WorkerType::Prefill {
+                bootstrap_port: Some(8080)
+            }
+            .to_string(),
+            "Prefill(bootstrap:8080)"
+        );
+        assert_eq!(
+            WorkerType::Prefill {
+                bootstrap_port: None
+            }
+            .to_string(),
+            "Prefill"
+        );
+        assert_eq!(WorkerType::Decode.to_string(), "Decode");
+    }
+
+    #[test]
+    fn test_worker_type_equality() {
+        assert_eq!(WorkerType::Regular, WorkerType::Regular);
+        assert_ne!(WorkerType::Regular, WorkerType::Decode);
+        assert_eq!(
+            WorkerType::Prefill {
+                bootstrap_port: Some(8080)
+            },
+            WorkerType::Prefill {
+                bootstrap_port: Some(8080)
+            }
+        );
+        assert_ne!(
+            WorkerType::Prefill {
+                bootstrap_port: Some(8080)
+            },
+            WorkerType::Prefill {
+                bootstrap_port: Some(8081)
+            }
+        );
+    }
+
+    #[test]
+    fn test_worker_type_clone() {
+        let original = WorkerType::Prefill {
+            bootstrap_port: Some(8080),
+        };
+        let cloned = original.clone();
+        assert_eq!(original, cloned);
+    }
+
+    // Test HealthConfig
+    #[test]
+    fn test_health_config_default() {
+        let config = HealthConfig::default();
+        assert_eq!(config.timeout_secs, 5);
+        assert_eq!(config.check_interval_secs, 30);
+        assert_eq!(config.endpoint, "/health");
+    }
+
+    #[test]
+    fn test_health_config_custom() {
+        let config = HealthConfig {
+            timeout_secs: 10,
+            check_interval_secs: 60,
+            endpoint: "/healthz".to_string(),
+        };
+        assert_eq!(config.timeout_secs, 10);
+        assert_eq!(config.check_interval_secs, 60);
+        assert_eq!(config.endpoint, "/healthz");
+    }
+
+    // Test BasicWorker
+    #[test]
+    fn test_basic_worker_creation() {
+        let worker = BasicWorker::new("http://test:8080".to_string(), WorkerType::Regular);
+        assert_eq!(worker.url(), "http://test:8080");
+        assert_eq!(worker.worker_type(), WorkerType::Regular);
+        assert!(worker.is_healthy());
+        assert_eq!(worker.load(), 0);
+        assert_eq!(worker.processed_requests(), 0);
+    }
+
+    #[test]
+    fn test_worker_with_labels() {
+        let mut labels = std::collections::HashMap::new();
+        labels.insert("env".to_string(), "prod".to_string());
+        labels.insert("zone".to_string(), "us-west".to_string());
+
+        let worker = BasicWorker::new("http://test:8080".to_string(), WorkerType::Regular)
+            .with_labels(labels.clone());
+
+        assert_eq!(worker.metadata().labels, labels);
+    }
+
+    #[test]
+    fn test_worker_with_health_config() {
+        let custom_config = HealthConfig {
+            timeout_secs: 15,
+            check_interval_secs: 45,
+            endpoint: "/custom-health".to_string(),
+        };
+
+        let worker = BasicWorker::new("http://test:8080".to_string(), WorkerType::Regular)
+            .with_health_config(custom_config);
+
+        assert_eq!(worker.metadata().health_config.timeout_secs, 15);
+        assert_eq!(worker.metadata().health_config.check_interval_secs, 45);
+        assert_eq!(worker.metadata().health_config.endpoint, "/custom-health");
+    }
+
+    // Test Worker trait implementation
+    #[test]
+    fn test_worker_url() {
+        let worker = BasicWorker::new("http://worker1:8080".to_string(), WorkerType::Regular);
+        assert_eq!(worker.url(), "http://worker1:8080");
+    }
+
+    #[test]
+    fn test_worker_type_getter() {
+        let regular = BasicWorker::new("http://test:8080".to_string(), WorkerType::Regular);
+        assert_eq!(regular.worker_type(), WorkerType::Regular);
+
+        let prefill = BasicWorker::new(
+            "http://test:8080".to_string(),
+            WorkerType::Prefill {
+                bootstrap_port: Some(9090),
+            },
+        );
+        assert_eq!(
+            prefill.worker_type(),
+            WorkerType::Prefill {
+                bootstrap_port: Some(9090)
+            }
+        );
+
+        let decode = BasicWorker::new("http://test:8080".to_string(), WorkerType::Decode);
+        assert_eq!(decode.worker_type(), WorkerType::Decode);
+    }
+
+    #[test]
+    fn test_health_status() {
+        let worker = BasicWorker::new("http://test:8080".to_string(), WorkerType::Regular);
+
+        // Initial state is healthy
+        assert!(worker.is_healthy());
+
+        // Set unhealthy
+        worker.set_healthy(false);
+        assert!(!worker.is_healthy());
+
+        // Set healthy again
+        worker.set_healthy(true);
+        assert!(worker.is_healthy());
+    }
+
+    #[test]
+    fn test_load_counter_operations() {
+        let worker = BasicWorker::new("http://test:8080".to_string(), WorkerType::Regular);
+
+        // Initial load is 0
+        assert_eq!(worker.load(), 0);
+
+        // Increment once
+        worker.increment_load();
+        assert_eq!(worker.load(), 1);
+
+        // Increment twice more
+        worker.increment_load();
+        worker.increment_load();
+        assert_eq!(worker.load(), 3);
+
+        // Decrement once
+        worker.decrement_load();
+        assert_eq!(worker.load(), 2);
+
+        // Decrement to 0
+        worker.decrement_load();
+        worker.decrement_load();
+        assert_eq!(worker.load(), 0);
+
+        // Decrement below 0 should stay at 0
+        worker.decrement_load();
+        assert_eq!(worker.load(), 0);
+    }
+
+    #[test]
+    fn test_processed_counter() {
+        let worker = BasicWorker::new("http://test:8080".to_string(), WorkerType::Regular);
+
+        // Initial count is 0
+        assert_eq!(worker.processed_requests(), 0);
+
+        // Increment multiple times
+        for i in 1..=100 {
+            worker.increment_processed();
+            assert_eq!(worker.processed_requests(), i);
+        }
+    }
+
+    #[test]
+    fn test_clone_worker() {
+        let original = BasicWorker::new("http://test:8080".to_string(), WorkerType::Regular);
+        original.increment_load();
+        original.increment_processed();
+        original.set_healthy(false);
+
+        let cloned = original.clone_worker();
+
+        // Verify cloned worker has same URL and type
+        assert_eq!(cloned.url(), original.url());
+        assert_eq!(cloned.worker_type(), original.worker_type());
+
+        // Load counter is shared with the original (the clone shares the Arc)
+        assert_eq!(cloned.load(), original.load());
+
+        // Modify original and verify clone is affected (shared state)
+        original.increment_load();
+        assert_eq!(cloned.load(), original.load());
+    }
+
+    // Test concurrent operations
+    #[tokio::test]
+    async fn test_concurrent_load_increments() {
+        let worker = Arc::new(BasicWorker::new(
+            "http://test:8080".to_string(),
+            WorkerType::Regular,
+        ));
+
+        let mut handles = vec![];
+
+        // Spawn 100 tasks incrementing load
+        for _ in 0..100 {
+            let worker_clone = Arc::clone(&worker);
+            let handle = tokio::spawn(async move {
+                worker_clone.increment_load();
+            });
+            handles.push(handle);
+        }
+
+        // Wait for all tasks
+        for handle in handles {
+            handle.await.unwrap();
+        }
+
+        // Final count should be 100
+        assert_eq!(worker.load(), 100);
+    }
+
+    #[tokio::test]
+    async fn test_concurrent_load_decrements() {
+        let worker = Arc::new(BasicWorker::new(
+            "http://test:8080".to_string(),
+            WorkerType::Regular,
+        ));
+
+        // Set initial load to 100
+        for _ in 0..100 {
+            worker.increment_load();
+        }
+        assert_eq!(worker.load(), 100);
+
+        let mut handles = vec![];
+
+        // Spawn 100 tasks decrementing load
+        for _ in 0..100 {
+            let worker_clone = Arc::clone(&worker);
+            let handle = tokio::spawn(async move {
+                worker_clone.decrement_load();
+            });
+            handles.push(handle);
+        }
+
+        // Wait for all tasks
+        for handle in handles {
+            handle.await.unwrap();
+        }
+
+        // Final count should be 0
+        assert_eq!(worker.load(), 0);
+    }
+
+    #[tokio::test]
+    async fn test_concurrent_health_updates() {
+        let worker = Arc::new(BasicWorker::new(
+            "http://test:8080".to_string(),
+            WorkerType::Regular,
+        ));
+
+        let mut handles = vec![];
+
+        // Spawn tasks that alternate the health status (even i => healthy)
+        for i in 0..100 {
+            let worker_clone = Arc::clone(&worker);
+            let handle = tokio::spawn(async move {
+                worker_clone.set_healthy(i % 2 == 0);
+                tokio::time::sleep(Duration::from_micros(10)).await;
+            });
+            handles.push(handle);
+        }
+
+        // Wait for all tasks
+        for handle in handles {
+            handle.await.unwrap();
+        }
+
+        // Tasks run in an unspecified order, so the last value written is not
+        // predictable here. Write a known value afterwards and check that it
+        // is what a reader observes.
+        worker.set_healthy(true);
+        assert!(worker.is_healthy());
+    }
+
+    // Test WorkerFactory
+    #[test]
+    fn test_create_regular_worker() {
+        let worker = WorkerFactory::create_regular("http://regular:8080".to_string());
+        assert_eq!(worker.url(), "http://regular:8080");
+        assert_eq!(worker.worker_type(), WorkerType::Regular);
+    }
+
+    #[test]
+    fn test_create_prefill_worker() {
+        // With bootstrap port
+        let worker1 = WorkerFactory::create_prefill("http://prefill:8080".to_string(), Some(9090));
+        assert_eq!(worker1.url(), "http://prefill:8080");
+        assert_eq!(
+            worker1.worker_type(),
+            WorkerType::Prefill {
+                bootstrap_port: Some(9090)
+            }
+        );
+
+        // Without bootstrap port
+        let worker2 = WorkerFactory::create_prefill("http://prefill:8080".to_string(), None);
+        assert_eq!(
+            worker2.worker_type(),
+            WorkerType::Prefill {
+                bootstrap_port: None
+            }
+        );
+    }
+
+    #[test]
+    fn test_create_decode_worker() {
+        let worker = WorkerFactory::create_decode("http://decode:8080".to_string());
+        assert_eq!(worker.url(), "http://decode:8080");
+        assert_eq!(worker.worker_type(), WorkerType::Decode);
+    }
+
+    #[test]
+    fn test_create_from_urls() {
+        let regular_urls = vec![
+            "http://regular1:8080".to_string(),
+            "http://regular2:8080".to_string(),
+        ];
+        let prefill_urls = vec![
+            ("http://prefill1:8080".to_string(), Some(9090)),
+            ("http://prefill2:8080".to_string(), None),
+        ];
+        let decode_urls = vec![
+            "http://decode1:8080".to_string(),
+            "http://decode2:8080".to_string(),
+        ];
+
+        let (regular, prefill, decode) =
+            WorkerFactory::create_from_urls(regular_urls, prefill_urls, decode_urls);
+
+        assert_eq!(regular.len(), 2);
+        assert_eq!(prefill.len(), 2);
+        assert_eq!(decode.len(), 2);
+
+        assert_eq!(regular[0].url(), "http://regular1:8080");
+        assert_eq!(prefill[0].url(), "http://prefill1:8080");
+        assert_eq!(decode[0].url(), "http://decode1:8080");
+    }
+
+    // Test WorkerCollection trait
+    #[test]
+    fn test_healthy_workers_filter() {
+        let workers: Vec<Box<dyn Worker>> = vec![
+            WorkerFactory::create_regular("http://w1:8080".to_string()),
+            WorkerFactory::create_regular("http://w2:8080".to_string()),
+            WorkerFactory::create_regular("http://w3:8080".to_string()),
+        ];
+
+        // Set some workers unhealthy
+        workers[0].set_healthy(false);
+        workers[2].set_healthy(false);
+
+        let healthy = workers.healthy_workers();
+        assert_eq!(healthy.len(), 1);
+        assert_eq!(healthy[0].url(), "http://w2:8080");
+    }
+
+    #[test]
+    fn test_total_load_calculation() {
+        let workers: Vec<Box<dyn Worker>> = vec![
+            WorkerFactory::create_regular("http://w1:8080".to_string()),
+            WorkerFactory::create_regular("http://w2:8080".to_string()),
+            WorkerFactory::create_regular("http://w3:8080".to_string()),
+        ];
+
+        // Set different loads
+        workers[0].increment_load();
+        workers[0].increment_load(); // load = 2
+
+        workers[1].increment_load();
+        workers[1].increment_load();
+        workers[1].increment_load(); // load = 3
+
+        workers[2].increment_load(); // load = 1
+
+        assert_eq!(workers.total_load(), 6);
+    }
+
+    #[test]
+    fn test_find_worker() {
+        let workers: Vec<Box<dyn Worker>> = vec![
+            WorkerFactory::create_regular("http://w1:8080".to_string()),
+            WorkerFactory::create_regular("http://w2:8080".to_string()),
+            WorkerFactory::create_regular("http://w3:8080".to_string()),
+        ];
+
+        // Found case
+        let found = workers.find_worker("http://w2:8080");
+        assert!(found.is_some());
+        assert_eq!(found.unwrap().url(), "http://w2:8080");
+
+        // Not found case
+        let not_found = workers.find_worker("http://w4:8080");
+        assert!(not_found.is_none());
+    }
+
+    #[test]
+    fn test_find_worker_mut() {
+        let mut workers: Vec<Box<dyn Worker>> = vec![
+            WorkerFactory::create_regular("http://w1:8080".to_string()),
+            WorkerFactory::create_regular("http://w2:8080".to_string()),
+        ];
+
+        // Find and modify
+        if let Some(worker) = workers.find_worker_mut("http://w1:8080") {
+            worker.set_healthy(false);
+        }
+
+        // Verify modification
+        assert!(!workers[0].is_healthy());
+        assert!(workers[1].is_healthy());
+    }
+
+    // Test WorkerLoadGuard
+    #[test]
+    fn test_load_guard_single_worker() {
+        let worker = BasicWorker::new("http://test:8080".to_string(), WorkerType::Regular);
+        assert_eq!(worker.load(), 0);
+
+        {
+            let _guard = WorkerLoadGuard::new(&worker);
+            assert_eq!(worker.load(), 1);
+        }
+
+        // Guard dropped, load decremented
+        assert_eq!(worker.load(), 0);
+    }
+
+    #[test]
+    fn test_load_guard_multiple_workers() {
+        let workers: Vec<Box<dyn Worker>> = vec![
+            WorkerFactory::create_regular("http://w1:8080".to_string()),
+            WorkerFactory::create_regular("http://w2:8080".to_string()),
+            WorkerFactory::create_regular("http://w3:8080".to_string()),
+        ];
+
+        let worker_refs: Vec<&dyn Worker> = workers.iter().map(|w| w.as_ref()).collect();
+
+        {
+            let _guard = WorkerLoadGuard::new_multi(worker_refs);
+            // All loads incremented
+            assert_eq!(workers[0].load(), 1);
+            assert_eq!(workers[1].load(), 1);
+            assert_eq!(workers[2].load(), 1);
+        }
+
+        // All loads decremented
+        assert_eq!(workers[0].load(), 0);
+        assert_eq!(workers[1].load(), 0);
+        assert_eq!(workers[2].load(), 0);
+    }
+
+    #[test]
+    fn test_load_guard_panic_safety() {
+        let worker = Arc::new(BasicWorker::new(
+            "http://test:8080".to_string(),
+            WorkerType::Regular,
+        ));
+        assert_eq!(worker.load(), 0);
+
+        // Clone for use inside catch_unwind
+        let worker_clone = Arc::clone(&worker);
+
+        // This will panic, but the guard should still clean up
+        let result = std::panic::catch_unwind(|| {
+            let _guard = WorkerLoadGuard::new(worker_clone.as_ref());
+            assert_eq!(worker_clone.load(), 1);
+            panic!("Test panic");
+        });
+
+        // Verify panic occurred
+        assert!(result.is_err());
+
+        // Load should be decremented even after panic
+        assert_eq!(worker.load(), 0);
+    }
+
+    // Test helper functions
+    #[test]
+    fn test_urls_to_workers() {
+        let urls = vec!["http://w1:8080".to_string(), "http://w2:8080".to_string()];
+
+        let workers = urls_to_workers(urls);
+        assert_eq!(workers.len(), 2);
+        assert_eq!(workers[0].url(), "http://w1:8080");
+        assert_eq!(workers[1].url(), "http://w2:8080");
+        assert_eq!(workers[0].worker_type(), WorkerType::Regular);
+    }
+
+    #[test]
+    fn test_workers_to_urls() {
+        let workers: Vec<Box<dyn Worker>> = vec![
+            WorkerFactory::create_regular("http://w1:8080".to_string()),
+            WorkerFactory::create_regular("http://w2:8080".to_string()),
+        ];
+
+        let urls = workers_to_urls(&workers);
+        assert_eq!(urls, vec!["http://w1:8080", "http://w2:8080"]);
+    }
+
+    // Test synchronous health check wrapper
+    #[test]
+    fn test_check_health_sync_wrapper() {
+        // We can't easily test the actual HTTP call without mocking,
+        // but we can verify the sync wrapper works
+        let worker = BasicWorker::new("http://test:8080".to_string(), WorkerType::Regular);
+
+        // This will fail because there's no server at this URL,
+        // but it tests that the sync wrapper doesn't panic
+        let result = worker.check_health();
+        assert!(result.is_err());
+    }
+
+    // Test HealthChecker background task
+    #[tokio::test]
+    async fn test_health_checker_startup() {
+        let workers = Arc::new(RwLock::new(vec![WorkerFactory::create_regular(
+            "http://w1:8080".to_string(),
+        )]));
+
+        let checker = start_health_checker(workers.clone(), 60);
+
+        // Verify it starts without panic
+        tokio::time::sleep(Duration::from_millis(100)).await;
+
+        // Shutdown
+        checker.shutdown().await;
+    }
+
+    #[tokio::test]
+    async fn test_health_checker_shutdown() {
+        let workers = Arc::new(RwLock::new(vec![WorkerFactory::create_regular(
+            "http://w1:8080".to_string(),
+        )]));
+
+        let checker = start_health_checker(workers.clone(), 60);
+
+        // Shutdown should complete quickly
+        let shutdown_result = timeout(Duration::from_secs(1), checker.shutdown()).await;
+        assert!(shutdown_result.is_ok());
+    }
+
+    // Performance test for load counter
+    #[test]
+    fn test_load_counter_performance() {
+        use std::time::Instant;
+
+        let worker = BasicWorker::new("http://test:8080".to_string(), WorkerType::Regular);
+        let iterations = 1_000_000;
+
+        let start = Instant::now();
+        for _ in 0..iterations {
+            worker.increment_load();
+        }
+        let duration = start.elapsed();
+
+        let ops_per_sec = iterations as f64 / duration.as_secs_f64();
+        println!("Load counter operations per second: {:.0}", ops_per_sec);
+
+        // Throughput is informational (host/profile dependent); assert the exact count instead
+        assert_eq!(worker.load(), iterations);
+    }
+}