diff --git a/sgl-router/src/core/mod.rs b/sgl-router/src/core/mod.rs index 86af5511a..682d6b2f2 100644 --- a/sgl-router/src/core/mod.rs +++ b/sgl-router/src/core/mod.rs @@ -22,7 +22,7 @@ pub use error::{WorkerError, WorkerResult}; pub use retry::{is_retryable_status, BackoffCalculator, RetryError, RetryExecutor}; pub use worker::{ start_health_checker, BasicWorker, ConnectionMode, DPAwareWorker, HealthChecker, HealthConfig, - Worker, WorkerCollection, WorkerFactory, WorkerLoadGuard, WorkerType, + Worker, WorkerFactory, WorkerLoadGuard, WorkerType, }; pub use worker_builder::{BasicWorkerBuilder, DPAwareWorkerBuilder}; pub use worker_registry::{WorkerId, WorkerRegistry, WorkerRegistryStats}; diff --git a/sgl-router/src/core/worker.rs b/sgl-router/src/core/worker.rs index e5a339feb..2e7debe51 100644 --- a/sgl-router/src/core/worker.rs +++ b/sgl-router/src/core/worker.rs @@ -328,15 +328,15 @@ pub struct WorkerMetadata { /// Basic worker implementation #[derive(Clone)] pub struct BasicWorker { - metadata: WorkerMetadata, - load_counter: Arc, - processed_counter: Arc, - healthy: Arc, - consecutive_failures: Arc, - consecutive_successes: Arc, - circuit_breaker: CircuitBreaker, + pub metadata: WorkerMetadata, + pub load_counter: Arc, + pub processed_counter: Arc, + pub healthy: Arc, + pub consecutive_failures: Arc, + pub consecutive_successes: Arc, + pub circuit_breaker: CircuitBreaker, /// Optional gRPC client for gRPC workers - grpc_client: Option>>, + pub grpc_client: Option>>, } impl fmt::Debug for BasicWorker { @@ -351,56 +351,6 @@ impl fmt::Debug for BasicWorker { } impl BasicWorker { - pub fn new(url: String, worker_type: WorkerType) -> Self { - Self::with_connection_mode(url, worker_type, ConnectionMode::Http) - } - - pub fn with_connection_mode( - url: String, - worker_type: WorkerType, - connection_mode: ConnectionMode, - ) -> Self { - let metadata = WorkerMetadata { - url: url.clone(), - worker_type, - connection_mode, - labels: std::collections::HashMap::new(), - health_config: HealthConfig::default(), - }; - - Self { - metadata, - load_counter: Arc::new(AtomicUsize::new(0)), - processed_counter: Arc::new(AtomicUsize::new(0)), - healthy: Arc::new(AtomicBool::new(true)), - consecutive_failures: Arc::new(AtomicUsize::new(0)), - consecutive_successes: Arc::new(AtomicUsize::new(0)), - circuit_breaker: CircuitBreaker::new(), - grpc_client: None, - } - } - - pub fn with_labels(mut self, labels: std::collections::HashMap) -> Self { - self.metadata.labels = labels; - self - } - - pub fn with_health_config(mut self, config: HealthConfig) -> Self { - self.metadata.health_config = config; - self - } - - pub fn with_circuit_breaker_config(mut self, config: CircuitBreakerConfig) -> Self { - self.circuit_breaker = CircuitBreaker::with_config(config); - self - } - - /// Set the gRPC client for gRPC workers - pub fn with_grpc_client(mut self, client: SglangSchedulerClient) -> Self { - self.grpc_client = Some(Arc::new(Mutex::new(client))); - self - } - pub fn normalised_url(&self) -> WorkerResult<&str> { if self.url().contains("@") { // Need to extract the URL from "http://host:port@dp_rank" @@ -577,9 +527,12 @@ pub struct DPAwareWorker { impl DPAwareWorker { /// Create a new DP-aware worker of any type pub fn new(base_url: String, dp_rank: usize, dp_size: usize, worker_type: WorkerType) -> Self { + use crate::core::BasicWorkerBuilder; // Create URL with DP rank suffix for identification let worker_url = format!("{}@{}", base_url, dp_rank); - let base_worker = BasicWorker::new(worker_url, worker_type); + let base_worker = BasicWorkerBuilder::new(worker_url) + .worker_type(worker_type) + .build(); Self { base_worker, @@ -724,107 +677,30 @@ impl WorkerFactory { /// Create a regular worker pub fn create_regular(url: String) -> Box { - Box::new(BasicWorker::new(url, WorkerType::Regular)) - } - - /// Create a regular worker with custom circuit breaker configuration - pub fn create_regular_with_config( - url: String, - circuit_breaker_config: CircuitBreakerConfig, - ) -> Box { + use crate::core::BasicWorkerBuilder; Box::new( BasicWorkerBuilder::new(url) - .circuit_breaker_config(circuit_breaker_config) + .worker_type(WorkerType::Regular) .build(), ) } /// Create a prefill worker with optional bootstrap port pub fn create_prefill(url: String, bootstrap_port: Option) -> Box { - Box::new(BasicWorker::new( - url, - WorkerType::Prefill { bootstrap_port }, - )) - } - - /// Create a prefill worker with custom circuit breaker configuration - pub fn create_prefill_with_config( - url: String, - bootstrap_port: Option, - circuit_breaker_config: CircuitBreakerConfig, - ) -> Box { + use crate::core::BasicWorkerBuilder; Box::new( BasicWorkerBuilder::new(url) .worker_type(WorkerType::Prefill { bootstrap_port }) - .circuit_breaker_config(circuit_breaker_config) .build(), ) } /// Create a decode worker pub fn create_decode(url: String) -> Box { - Box::new(BasicWorker::new(url, WorkerType::Decode)) - } - - /// Create a decode worker with custom circuit breaker configuration - pub fn create_decode_with_config( - url: String, - circuit_breaker_config: CircuitBreakerConfig, - ) -> Box { + use crate::core::BasicWorkerBuilder; Box::new( BasicWorkerBuilder::new(url) .worker_type(WorkerType::Decode) - .circuit_breaker_config(circuit_breaker_config) - .build(), - ) - } - - /// Create workers from URLs with automatic type detection - #[allow(clippy::type_complexity)] - pub fn create_from_urls( - regular_urls: Vec, - prefill_urls: Vec<(String, Option)>, - decode_urls: Vec, - ) -> ( - Vec>, - Vec>, - Vec>, - ) { - let regular_workers: Vec> = - regular_urls.into_iter().map(Self::create_regular).collect(); - - let prefill_workers: Vec> = prefill_urls - .into_iter() - .map(|(url, port)| Self::create_prefill(url, port)) - .collect(); - - let decode_workers: Vec> = - decode_urls.into_iter().map(Self::create_decode).collect(); - - (regular_workers, prefill_workers, decode_workers) - } - - /// Create a gRPC worker - pub fn create_grpc(url: String, worker_type: WorkerType, port: Option) -> Box { - Box::new(BasicWorker::with_connection_mode( - url, - worker_type, - ConnectionMode::Grpc { port }, - )) - } - - /// Create a gRPC worker with custom circuit breaker configuration - pub fn create_grpc_with_config( - url: String, - worker_type: WorkerType, - port: Option, - circuit_breaker_config: CircuitBreakerConfig, - ) -> Box { - Box::new( - BasicWorkerBuilder::new(url) - .worker_type(worker_type) - .connection_mode(ConnectionMode::Grpc { port }) - .circuit_breaker_config(circuit_breaker_config) .build(), ) } @@ -1002,35 +878,6 @@ impl WorkerFactory { } } -/// Helper trait for collections of workers -pub trait WorkerCollection { - fn healthy_workers(&self) -> Vec<&dyn Worker>; - fn total_load(&self) -> usize; - fn find_worker(&self, url: &str) -> Option<&dyn Worker>; - fn find_worker_mut(&mut self, url: &str) -> Option<&mut Box>; -} - -impl WorkerCollection for Vec> { - fn healthy_workers(&self) -> Vec<&dyn Worker> { - self.iter() - .filter(|w| w.is_healthy()) - .map(|w| w.as_ref()) - .collect() - } - - fn total_load(&self) -> usize { - self.iter().map(|w| w.load()).sum() - } - - fn find_worker(&self, url: &str) -> Option<&dyn Worker> { - self.iter().find(|w| w.url() == url).map(|w| w.as_ref()) - } - - fn find_worker_mut(&mut self, url: &str) -> Option<&mut Box> { - self.iter_mut().find(|w| w.url() == url) - } -} - /// Convert a list of worker URLs to worker trait objects pub fn urls_to_workers(urls: Vec) -> Vec> { urls.into_iter() @@ -1275,7 +1122,10 @@ mod tests { // Test BasicWorker #[test] fn test_basic_worker_creation() { - let worker = BasicWorker::new("http://test:8080".to_string(), WorkerType::Regular); + use crate::core::BasicWorkerBuilder; + let worker = BasicWorkerBuilder::new("http://test:8080") + .worker_type(WorkerType::Regular) + .build(); assert_eq!(worker.url(), "http://test:8080"); assert_eq!(worker.worker_type(), WorkerType::Regular); assert!(worker.is_healthy()); @@ -1289,8 +1139,11 @@ mod tests { labels.insert("env".to_string(), "prod".to_string()); labels.insert("zone".to_string(), "us-west".to_string()); - let worker = BasicWorker::new("http://test:8080".to_string(), WorkerType::Regular) - .with_labels(labels.clone()); + use crate::core::BasicWorkerBuilder; + let worker = BasicWorkerBuilder::new("http://test:8080") + .worker_type(WorkerType::Regular) + .labels(labels.clone()) + .build(); assert_eq!(worker.metadata().labels, labels); } @@ -1305,8 +1158,11 @@ mod tests { success_threshold: 2, }; - let worker = BasicWorker::new("http://test:8080".to_string(), WorkerType::Regular) - .with_health_config(custom_config.clone()); + use crate::core::BasicWorkerBuilder; + let worker = BasicWorkerBuilder::new("http://test:8080") + .worker_type(WorkerType::Regular) + .health_config(custom_config.clone()) + .build(); assert_eq!(worker.metadata().health_config.timeout_secs, 15); assert_eq!(worker.metadata().health_config.check_interval_secs, 45); @@ -1316,21 +1172,26 @@ mod tests { // Test Worker trait implementation #[test] fn test_worker_url() { - let worker = BasicWorker::new("http://worker1:8080".to_string(), WorkerType::Regular); + use crate::core::BasicWorkerBuilder; + let worker = BasicWorkerBuilder::new("http://worker1:8080") + .worker_type(WorkerType::Regular) + .build(); assert_eq!(worker.url(), "http://worker1:8080"); } #[test] fn test_worker_type_getter() { - let regular = BasicWorker::new("http://test:8080".to_string(), WorkerType::Regular); + use crate::core::BasicWorkerBuilder; + let regular = BasicWorkerBuilder::new("http://test:8080") + .worker_type(WorkerType::Regular) + .build(); assert_eq!(regular.worker_type(), WorkerType::Regular); - let prefill = BasicWorker::new( - "http://test:8080".to_string(), - WorkerType::Prefill { + let prefill = BasicWorkerBuilder::new("http://test:8080") + .worker_type(WorkerType::Prefill { bootstrap_port: Some(9090), - }, - ); + }) + .build(); assert_eq!( prefill.worker_type(), WorkerType::Prefill { @@ -1338,13 +1199,18 @@ mod tests { } ); - let decode = BasicWorker::new("http://test:8080".to_string(), WorkerType::Decode); + let decode = BasicWorkerBuilder::new("http://test:8080") + .worker_type(WorkerType::Decode) + .build(); assert_eq!(decode.worker_type(), WorkerType::Decode); } #[test] fn test_health_status() { - let worker = BasicWorker::new("http://test:8080".to_string(), WorkerType::Regular); + use crate::core::BasicWorkerBuilder; + let worker = BasicWorkerBuilder::new("http://test:8080") + .worker_type(WorkerType::Regular) + .build(); // Initial state is healthy assert!(worker.is_healthy()); @@ -1360,7 +1226,10 @@ mod tests { #[test] fn test_load_counter_operations() { - let worker = BasicWorker::new("http://test:8080".to_string(), WorkerType::Regular); + use crate::core::BasicWorkerBuilder; + let worker = BasicWorkerBuilder::new("http://test:8080") + .worker_type(WorkerType::Regular) + .build(); // Initial load is 0 assert_eq!(worker.load(), 0); @@ -1390,7 +1259,10 @@ mod tests { #[test] fn test_processed_counter() { - let worker = BasicWorker::new("http://test:8080".to_string(), WorkerType::Regular); + use crate::core::BasicWorkerBuilder; + let worker = BasicWorkerBuilder::new("http://test:8080") + .worker_type(WorkerType::Regular) + .build(); // Initial count is 0 assert_eq!(worker.processed_requests(), 0); @@ -1405,10 +1277,12 @@ mod tests { // Test concurrent operations #[tokio::test] async fn test_concurrent_load_increments() { - let worker = Arc::new(BasicWorker::new( - "http://test:8080".to_string(), - WorkerType::Regular, - )); + use crate::core::BasicWorkerBuilder; + let worker = Arc::new( + BasicWorkerBuilder::new("http://test:8080") + .worker_type(WorkerType::Regular) + .build(), + ); let mut handles = vec![]; @@ -1432,10 +1306,12 @@ mod tests { #[tokio::test] async fn test_concurrent_load_decrements() { - let worker = Arc::new(BasicWorker::new( - "http://test:8080".to_string(), - WorkerType::Regular, - )); + use crate::core::BasicWorkerBuilder; + let worker = Arc::new( + BasicWorkerBuilder::new("http://test:8080") + .worker_type(WorkerType::Regular) + .build(), + ); // Set initial load to 100 for _ in 0..100 { @@ -1465,10 +1341,12 @@ mod tests { #[tokio::test] async fn test_concurrent_health_updates() { - let worker = Arc::new(BasicWorker::new( - "http://test:8080".to_string(), - WorkerType::Regular, - )); + use crate::core::BasicWorkerBuilder; + let worker = Arc::new( + BasicWorkerBuilder::new("http://test:8080") + .worker_type(WorkerType::Regular) + .build(), + ); let mut handles = vec![]; @@ -1525,111 +1403,13 @@ mod tests { assert_eq!(worker.worker_type(), WorkerType::Decode); } - #[test] - fn test_create_from_urls() { - let regular_urls = vec![ - "http://regular1:8080".to_string(), - "http://regular2:8080".to_string(), - ]; - let prefill_urls = vec![ - ("http://prefill1:8080".to_string(), Some(9090)), - ("http://prefill2:8080".to_string(), None), - ]; - let decode_urls = vec![ - "http://decode1:8080".to_string(), - "http://decode2:8080".to_string(), - ]; - - let (regular, prefill, decode) = - WorkerFactory::create_from_urls(regular_urls, prefill_urls, decode_urls); - - assert_eq!(regular.len(), 2); - assert_eq!(prefill.len(), 2); - assert_eq!(decode.len(), 2); - - assert_eq!(regular[0].url(), "http://regular1:8080"); - assert_eq!(prefill[0].url(), "http://prefill1:8080"); - assert_eq!(decode[0].url(), "http://decode1:8080"); - } - - // Test WorkerCollection trait - #[test] - fn test_healthy_workers_filter() { - let workers: Vec> = vec![ - WorkerFactory::create_regular("http://w1:8080".to_string()), - WorkerFactory::create_regular("http://w2:8080".to_string()), - WorkerFactory::create_regular("http://w3:8080".to_string()), - ]; - - // Set some workers unhealthy - workers[0].set_healthy(false); - workers[2].set_healthy(false); - - let healthy = workers.healthy_workers(); - assert_eq!(healthy.len(), 1); - assert_eq!(healthy[0].url(), "http://w2:8080"); - } - - #[test] - fn test_total_load_calculation() { - let workers: Vec> = vec![ - WorkerFactory::create_regular("http://w1:8080".to_string()), - WorkerFactory::create_regular("http://w2:8080".to_string()), - WorkerFactory::create_regular("http://w3:8080".to_string()), - ]; - - // Set different loads - workers[0].increment_load(); - workers[0].increment_load(); // load = 2 - - workers[1].increment_load(); - workers[1].increment_load(); - workers[1].increment_load(); // load = 3 - - workers[2].increment_load(); // load = 1 - - assert_eq!(workers.total_load(), 6); - } - - #[test] - fn test_find_worker() { - let workers: Vec> = vec![ - WorkerFactory::create_regular("http://w1:8080".to_string()), - WorkerFactory::create_regular("http://w2:8080".to_string()), - WorkerFactory::create_regular("http://w3:8080".to_string()), - ]; - - // Found case - let found = workers.find_worker("http://w2:8080"); - assert!(found.is_some()); - assert_eq!(found.unwrap().url(), "http://w2:8080"); - - // Not found case - let not_found = workers.find_worker("http://w4:8080"); - assert!(not_found.is_none()); - } - - #[test] - fn test_find_worker_mut() { - let mut workers: Vec> = vec![ - WorkerFactory::create_regular("http://w1:8080".to_string()), - WorkerFactory::create_regular("http://w2:8080".to_string()), - ]; - - // Find and modify - if let Some(worker) = workers.find_worker_mut("http://w1:8080") { - worker.set_healthy(false); - } - - // Verify modification - assert!(!workers[0].is_healthy()); - assert!(workers[1].is_healthy()); - } - // Test WorkerLoadGuard #[test] fn test_load_guard_single_worker() { - let worker = BasicWorker::new("http://test:8080".to_string(), WorkerType::Regular); + use crate::core::BasicWorkerBuilder; + let worker = BasicWorkerBuilder::new("http://test:8080") + .worker_type(WorkerType::Regular) + .build(); assert_eq!(worker.load(), 0); { @@ -1667,10 +1447,12 @@ mod tests { #[test] fn test_load_guard_panic_safety() { - let worker = Arc::new(BasicWorker::new( - "http://test:8080".to_string(), - WorkerType::Regular, - )); + use crate::core::BasicWorkerBuilder; + let worker = Arc::new( + BasicWorkerBuilder::new("http://test:8080") + .worker_type(WorkerType::Regular) + .build(), + ); assert_eq!(worker.load(), 0); // Clone for use inside catch_unwind @@ -1723,7 +1505,10 @@ mod tests { fn test_check_health_sync_wrapper() { // We can't easily test the actual HTTP call without mocking, // but we can verify the sync wrapper works - let worker = BasicWorker::new("http://test:8080".to_string(), WorkerType::Regular); + use crate::core::BasicWorkerBuilder; + let worker = BasicWorkerBuilder::new("http://test:8080") + .worker_type(WorkerType::Regular) + .build(); // This will fail because there's no server at this URL, // but it tests that the sync wrapper doesn't panic @@ -1734,9 +1519,12 @@ mod tests { // Performance test for load counter #[test] fn test_load_counter_performance() { + use crate::core::BasicWorkerBuilder; use std::time::Instant; - let worker = BasicWorker::new("http://test:8080".to_string(), WorkerType::Regular); + let worker = BasicWorkerBuilder::new("http://test:8080") + .worker_type(WorkerType::Regular) + .build(); let iterations = 1_000_000; let start = Instant::now(); @@ -1929,7 +1717,10 @@ mod tests { #[test] fn test_worker_circuit_breaker() { - let worker = BasicWorker::new("http://test:8080".to_string(), WorkerType::Regular); + use crate::core::BasicWorkerBuilder; + let worker = BasicWorkerBuilder::new("http://test:8080") + .worker_type(WorkerType::Regular) + .build(); // Initial state should be available assert!(worker.is_available()); @@ -1962,8 +1753,11 @@ mod tests { window_duration: Duration::from_secs(60), }; - let worker = BasicWorker::new("http://test:8080".to_string(), WorkerType::Regular) - .with_circuit_breaker_config(config); + use crate::core::BasicWorkerBuilder; + let worker = BasicWorkerBuilder::new("http://test:8080") + .worker_type(WorkerType::Regular) + .circuit_breaker_config(config) + .build(); // Should open after 2 failures worker.record_outcome(false); diff --git a/sgl-router/src/core/worker_builder.rs b/sgl-router/src/core/worker_builder.rs index 741326a7c..94828a870 100644 --- a/sgl-router/src/core/worker_builder.rs +++ b/sgl-router/src/core/worker_builder.rs @@ -1,5 +1,7 @@ -use super::circuit_breaker::CircuitBreakerConfig; -use super::worker::{BasicWorker, ConnectionMode, DPAwareWorker, HealthConfig, WorkerType}; +use super::circuit_breaker::{CircuitBreaker, CircuitBreakerConfig}; +use super::worker::{ + BasicWorker, ConnectionMode, DPAwareWorker, HealthConfig, WorkerMetadata, WorkerType, +}; use crate::grpc::client::SglangSchedulerClient; use std::collections::HashMap; @@ -88,23 +90,30 @@ impl BasicWorkerBuilder { /// Build the BasicWorker instance pub fn build(self) -> BasicWorker { - // Use the existing constructor methods for now - let mut worker = - BasicWorker::with_connection_mode(self.url, self.worker_type, self.connection_mode); + use std::sync::{ + atomic::{AtomicBool, AtomicUsize}, + Arc, + }; + use tokio::sync::Mutex; - // Apply optional configurations using existing methods - if !self.labels.is_empty() { - worker = worker.with_labels(self.labels); + let metadata = WorkerMetadata { + url: self.url.clone(), + worker_type: self.worker_type, + connection_mode: self.connection_mode, + labels: self.labels, + health_config: self.health_config, + }; + + BasicWorker { + metadata, + load_counter: Arc::new(AtomicUsize::new(0)), + processed_counter: Arc::new(AtomicUsize::new(0)), + healthy: Arc::new(AtomicBool::new(true)), + consecutive_failures: Arc::new(AtomicUsize::new(0)), + consecutive_successes: Arc::new(AtomicUsize::new(0)), + circuit_breaker: CircuitBreaker::with_config(self.circuit_breaker_config), + grpc_client: self.grpc_client.map(|client| Arc::new(Mutex::new(client))), } - - worker = worker.with_health_config(self.health_config); - worker = worker.with_circuit_breaker_config(self.circuit_breaker_config); - - if let Some(client) = self.grpc_client { - worker = worker.with_grpc_client(client); - } - - worker } } diff --git a/sgl-router/src/routers/http/pd_router.rs b/sgl-router/src/routers/http/pd_router.rs index 8eb55b543..df2d2e987 100644 --- a/sgl-router/src/routers/http/pd_router.rs +++ b/sgl-router/src/routers/http/pd_router.rs @@ -4,7 +4,7 @@ use super::pd_types::{api_path, PDRouterError}; use crate::config::types::RetryConfig; use crate::core::{ is_retryable_status, BasicWorkerBuilder, CircuitBreakerConfig, HealthConfig, RetryExecutor, - Worker, WorkerFactory, WorkerLoadGuard, WorkerRegistry, WorkerType, + Worker, WorkerLoadGuard, WorkerRegistry, WorkerType, }; use crate::metrics::RouterMetrics; use crate::policies::{LoadBalancingPolicy, PolicyRegistry}; @@ -220,13 +220,12 @@ impl PDRouter { // Create Worker for the new prefill server with circuit breaker configuration // TODO: In IGW mode, fetch model_id from worker's /get_model_info endpoint - let worker = WorkerFactory::create_prefill_with_config( - url.clone(), - bootstrap_port, - self.circuit_breaker_config.clone(), - ); + let worker = BasicWorkerBuilder::new(url.clone()) + .worker_type(WorkerType::Prefill { bootstrap_port }) + .circuit_breaker_config(self.circuit_breaker_config.clone()) + .build(); - let worker_arc: Arc = Arc::from(worker); + let worker_arc: Arc = Arc::new(worker); // Register the worker in the registry self.worker_registry.register(worker_arc.clone()); @@ -261,12 +260,12 @@ impl PDRouter { // Create Worker for the new decode server with circuit breaker configuration // TODO: In IGW mode, fetch model_id from worker's /get_model_info endpoint - let worker = WorkerFactory::create_decode_with_config( - url.clone(), - self.circuit_breaker_config.clone(), - ); + let worker = BasicWorkerBuilder::new(url.clone()) + .worker_type(WorkerType::Decode) + .circuit_breaker_config(self.circuit_breaker_config.clone()) + .build(); - let worker_arc: Arc = Arc::from(worker); + let worker_arc: Arc = Arc::new(worker); // Register the worker in the registry self.worker_registry.register(worker_arc.clone());