diff --git a/sgl-router/benches/request_processing.rs b/sgl-router/benches/request_processing.rs index eba5680aa..3579d9c67 100644 --- a/sgl-router/benches/request_processing.rs +++ b/sgl-router/benches/request_processing.rs @@ -2,7 +2,7 @@ use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criteri use serde_json::{from_str, to_string, to_value, to_vec}; use std::time::Instant; -use sglang_router_rs::core::{BasicWorker, Worker, WorkerType}; +use sglang_router_rs::core::{BasicWorker, BasicWorkerBuilder, Worker, WorkerType}; use sglang_router_rs::protocols::spec::{ ChatCompletionRequest, ChatMessage, CompletionRequest, GenerateParameters, GenerateRequest, SamplingParams, StringOrArray, UserMessageContent, @@ -12,12 +12,11 @@ use sglang_router_rs::routers::http::pd_types::{ }; fn create_test_worker() -> BasicWorker { - BasicWorker::new( - "http://test-server:8000".to_string(), - WorkerType::Prefill { + BasicWorkerBuilder::new("http://test-server:8000") + .worker_type(WorkerType::Prefill { bootstrap_port: Some(5678), - }, - ) + }) + .build() } // Helper function to get bootstrap info from worker diff --git a/sgl-router/src/policies/cache_aware.rs b/sgl-router/src/policies/cache_aware.rs index cf59c5e07..361bf3be5 100644 --- a/sgl-router/src/policies/cache_aware.rs +++ b/sgl-router/src/policies/cache_aware.rs @@ -451,7 +451,7 @@ impl Drop for CacheAwarePolicy { #[cfg(test)] mod tests { use super::*; - use crate::core::{BasicWorker, WorkerType}; + use crate::core::{BasicWorkerBuilder, WorkerType}; #[test] fn test_cache_aware_with_balanced_load() { @@ -462,14 +462,16 @@ mod tests { }; let policy = CacheAwarePolicy::with_config(config); let workers: Vec> = vec![ - Arc::new(BasicWorker::new( - "http://w1:8000".to_string(), - WorkerType::Regular, - )), - Arc::new(BasicWorker::new( - "http://w2:8000".to_string(), - WorkerType::Regular, - )), + Arc::new( + BasicWorkerBuilder::new("http://w1:8000") + .worker_type(WorkerType::Regular) + .build(), + ), + Arc::new( + BasicWorkerBuilder::new("http://w2:8000") + .worker_type(WorkerType::Regular) + .build(), + ), ]; // Initialize the policy with workers @@ -497,8 +499,12 @@ mod tests { max_tree_size: 10000, }); - let worker1 = BasicWorker::new("http://w1:8000".to_string(), WorkerType::Regular); - let worker2 = BasicWorker::new("http://w2:8000".to_string(), WorkerType::Regular); + let worker1 = BasicWorkerBuilder::new("http://w1:8000") + .worker_type(WorkerType::Regular) + .build(); + let worker2 = BasicWorkerBuilder::new("http://w2:8000") + .worker_type(WorkerType::Regular) + .build(); // Create significant load imbalance for _ in 0..20 { @@ -524,14 +530,16 @@ mod tests { }; let policy = CacheAwarePolicy::with_config(config); let workers: Vec> = vec![ - Arc::new(BasicWorker::new( - "http://w1:8000".to_string(), - WorkerType::Regular, - )), - Arc::new(BasicWorker::new( - "http://w2:8000".to_string(), - WorkerType::Regular, - )), + Arc::new( + BasicWorkerBuilder::new("http://w1:8000") + .worker_type(WorkerType::Regular) + .build(), + ), + Arc::new( + BasicWorkerBuilder::new("http://w2:8000") + .worker_type(WorkerType::Regular) + .build(), + ), ]; policy.init_workers(&workers); diff --git a/sgl-router/src/policies/mod.rs b/sgl-router/src/policies/mod.rs index 7fdf03ba3..34c71bf8e 100644 --- a/sgl-router/src/policies/mod.rs +++ b/sgl-router/src/policies/mod.rs @@ -121,23 +121,26 @@ pub(crate) fn get_healthy_worker_indices(workers: &[Arc]) -> Vec> = vec![ - Arc::new(BasicWorker::new( - "http://w1:8000".to_string(), - WorkerType::Regular, - )), - Arc::new(BasicWorker::new( - "http://w2:8000".to_string(), - WorkerType::Regular, - )), - Arc::new(BasicWorker::new( - "http://w3:8000".to_string(), - WorkerType::Regular, - )), + Arc::new( + BasicWorkerBuilder::new("http://w1:8000") + .worker_type(WorkerType::Regular) + .build(), + ), + Arc::new( + BasicWorkerBuilder::new("http://w2:8000") + .worker_type(WorkerType::Regular) + .build(), + ), + Arc::new( + BasicWorkerBuilder::new("http://w3:8000") + .worker_type(WorkerType::Regular) + .build(), + ), ]; // All healthy initially diff --git a/sgl-router/src/policies/power_of_two.rs b/sgl-router/src/policies/power_of_two.rs index 6452cdc6f..d21f42a46 100644 --- a/sgl-router/src/policies/power_of_two.rs +++ b/sgl-router/src/policies/power_of_two.rs @@ -119,14 +119,20 @@ impl Default for PowerOfTwoPolicy { #[cfg(test)] mod tests { use super::*; - use crate::core::{BasicWorker, WorkerType}; + use crate::core::{BasicWorkerBuilder, WorkerType}; #[test] fn test_power_of_two_selection() { let policy = PowerOfTwoPolicy::new(); - let worker1 = BasicWorker::new("http://w1:8000".to_string(), WorkerType::Regular); - let worker2 = BasicWorker::new("http://w2:8000".to_string(), WorkerType::Regular); - let worker3 = BasicWorker::new("http://w3:8000".to_string(), WorkerType::Regular); + let worker1 = BasicWorkerBuilder::new("http://w1:8000") + .worker_type(WorkerType::Regular) + .build(); + let worker2 = BasicWorkerBuilder::new("http://w2:8000") + .worker_type(WorkerType::Regular) + .build(); + let worker3 = BasicWorkerBuilder::new("http://w3:8000") + .worker_type(WorkerType::Regular) + .build(); // Set different loads for _ in 0..10 { @@ -157,14 +163,16 @@ mod tests { fn test_power_of_two_with_cached_loads() { let policy = PowerOfTwoPolicy::new(); let workers: Vec> = vec![ - Arc::new(BasicWorker::new( - "http://w1:8000".to_string(), - WorkerType::Regular, - )), - Arc::new(BasicWorker::new( - "http://w2:8000".to_string(), - WorkerType::Regular, - )), + Arc::new( + BasicWorkerBuilder::new("http://w1:8000") + .worker_type(WorkerType::Regular) + .build(), + ), + Arc::new( + BasicWorkerBuilder::new("http://w2:8000") + .worker_type(WorkerType::Regular) + .build(), + ), ]; // Update cached loads @@ -190,10 +198,11 @@ mod tests { #[test] fn test_power_of_two_single_worker() { let policy = PowerOfTwoPolicy::new(); - let workers: Vec> = vec![Arc::new(BasicWorker::new( - "http://w1:8000".to_string(), - WorkerType::Regular, - ))]; + let workers: Vec> = vec![Arc::new( + BasicWorkerBuilder::new("http://w1:8000") + .worker_type(WorkerType::Regular) + .build(), + )]; // With single worker, should always select it assert_eq!(policy.select_worker(&workers, None), Some(0)); diff --git a/sgl-router/src/policies/random.rs b/sgl-router/src/policies/random.rs index 11636c045..492848467 100644 --- a/sgl-router/src/policies/random.rs +++ b/sgl-router/src/policies/random.rs @@ -51,25 +51,28 @@ impl LoadBalancingPolicy for RandomPolicy { #[cfg(test)] mod tests { use super::*; - use crate::core::{BasicWorker, WorkerType}; + use crate::core::{BasicWorkerBuilder, WorkerType}; use std::collections::HashMap; #[test] fn test_random_selection() { let policy = RandomPolicy::new(); let workers: Vec> = vec![ - Arc::new(BasicWorker::new( - "http://w1:8000".to_string(), - WorkerType::Regular, - )), - Arc::new(BasicWorker::new( - "http://w2:8000".to_string(), - WorkerType::Regular, - )), - Arc::new(BasicWorker::new( - "http://w3:8000".to_string(), - WorkerType::Regular, - )), + Arc::new( + BasicWorkerBuilder::new("http://w1:8000") + .worker_type(WorkerType::Regular) + .build(), + ), + Arc::new( + BasicWorkerBuilder::new("http://w2:8000") + .worker_type(WorkerType::Regular) + .build(), + ), + Arc::new( + BasicWorkerBuilder::new("http://w3:8000") + .worker_type(WorkerType::Regular) + .build(), + ), ]; // Test multiple selections to ensure randomness @@ -89,14 +92,16 @@ mod tests { fn test_random_with_unhealthy_workers() { let policy = RandomPolicy::new(); let workers: Vec> = vec![ - Arc::new(BasicWorker::new( - "http://w1:8000".to_string(), - WorkerType::Regular, - )), - Arc::new(BasicWorker::new( - "http://w2:8000".to_string(), - WorkerType::Regular, - )), + Arc::new( + BasicWorkerBuilder::new("http://w1:8000") + .worker_type(WorkerType::Regular) + .build(), + ), + Arc::new( + BasicWorkerBuilder::new("http://w2:8000") + .worker_type(WorkerType::Regular) + .build(), + ), ]; // Mark first worker as unhealthy @@ -111,10 +116,11 @@ mod tests { #[test] fn test_random_no_healthy_workers() { let policy = RandomPolicy::new(); - let workers: Vec> = vec![Arc::new(BasicWorker::new( - "http://w1:8000".to_string(), - WorkerType::Regular, - ))]; + let workers: Vec> = vec![Arc::new( + BasicWorkerBuilder::new("http://w1:8000") + .worker_type(WorkerType::Regular) + .build(), + )]; workers[0].set_healthy(false); assert_eq!(policy.select_worker(&workers, None), None); diff --git a/sgl-router/src/policies/round_robin.rs b/sgl-router/src/policies/round_robin.rs index 1b4087224..47e3c6e92 100644 --- a/sgl-router/src/policies/round_robin.rs +++ b/sgl-router/src/policies/round_robin.rs @@ -60,24 +60,27 @@ impl LoadBalancingPolicy for RoundRobinPolicy { #[cfg(test)] mod tests { use super::*; - use crate::core::{BasicWorker, WorkerType}; + use crate::core::{BasicWorkerBuilder, WorkerType}; #[test] fn test_round_robin_selection() { let policy = RoundRobinPolicy::new(); let workers: Vec> = vec![ - Arc::new(BasicWorker::new( - "http://w1:8000".to_string(), - WorkerType::Regular, - )), - Arc::new(BasicWorker::new( - "http://w2:8000".to_string(), - WorkerType::Regular, - )), - Arc::new(BasicWorker::new( - "http://w3:8000".to_string(), - WorkerType::Regular, - )), + Arc::new( + BasicWorkerBuilder::new("http://w1:8000") + .worker_type(WorkerType::Regular) + .build(), + ), + Arc::new( + BasicWorkerBuilder::new("http://w2:8000") + .worker_type(WorkerType::Regular) + .build(), + ), + Arc::new( + BasicWorkerBuilder::new("http://w3:8000") + .worker_type(WorkerType::Regular) + .build(), + ), ]; // Should select workers in order: 0, 1, 2, 0, 1, 2, ... @@ -92,18 +95,21 @@ mod tests { fn test_round_robin_with_unhealthy_workers() { let policy = RoundRobinPolicy::new(); let workers: Vec> = vec![ - Arc::new(BasicWorker::new( - "http://w1:8000".to_string(), - WorkerType::Regular, - )), - Arc::new(BasicWorker::new( - "http://w2:8000".to_string(), - WorkerType::Regular, - )), - Arc::new(BasicWorker::new( - "http://w3:8000".to_string(), - WorkerType::Regular, - )), + Arc::new( + BasicWorkerBuilder::new("http://w1:8000") + .worker_type(WorkerType::Regular) + .build(), + ), + Arc::new( + BasicWorkerBuilder::new("http://w2:8000") + .worker_type(WorkerType::Regular) + .build(), + ), + Arc::new( + BasicWorkerBuilder::new("http://w3:8000") + .worker_type(WorkerType::Regular) + .build(), + ), ]; // Mark middle worker as unhealthy @@ -120,14 +126,16 @@ mod tests { fn test_round_robin_reset() { let policy = RoundRobinPolicy::new(); let workers: Vec> = vec![ - Arc::new(BasicWorker::new( - "http://w1:8000".to_string(), - WorkerType::Regular, - )), - Arc::new(BasicWorker::new( - "http://w2:8000".to_string(), - WorkerType::Regular, - )), + Arc::new( + BasicWorkerBuilder::new("http://w1:8000") + .worker_type(WorkerType::Regular) + .build(), + ), + Arc::new( + BasicWorkerBuilder::new("http://w2:8000") + .worker_type(WorkerType::Regular) + .build(), + ), ]; // Advance the counter diff --git a/sgl-router/src/routers/grpc/pd_router.rs b/sgl-router/src/routers/grpc/pd_router.rs index 4bd9d024d..86f7acb5e 100644 --- a/sgl-router/src/routers/grpc/pd_router.rs +++ b/sgl-router/src/routers/grpc/pd_router.rs @@ -2,7 +2,7 @@ use crate::config::types::RetryConfig; use crate::core::{ - BasicWorker, CircuitBreakerConfig, HealthChecker, HealthConfig, Worker, WorkerType, + BasicWorkerBuilder, CircuitBreakerConfig, HealthChecker, HealthConfig, Worker, WorkerType, }; use crate::grpc::SglangSchedulerClient; use crate::metrics::RouterMetrics; @@ -130,23 +130,22 @@ impl GrpcPDRouter { let prefill_workers: Vec> = prefill_urls .iter() .map(|(url, bootstrap_port)| { - let worker = BasicWorker::with_connection_mode( - url.clone(), - WorkerType::Prefill { + let worker = BasicWorkerBuilder::new(url.clone()) + .worker_type(WorkerType::Prefill { bootstrap_port: *bootstrap_port, - }, - crate::core::ConnectionMode::Grpc { + }) + .connection_mode(crate::core::ConnectionMode::Grpc { port: *bootstrap_port, - }, - ) - .with_circuit_breaker_config(core_cb_config.clone()) - .with_health_config(HealthConfig { - timeout_secs: ctx.router_config.health_check.timeout_secs, - check_interval_secs: ctx.router_config.health_check.check_interval_secs, - endpoint: ctx.router_config.health_check.endpoint.clone(), - failure_threshold: ctx.router_config.health_check.failure_threshold, - success_threshold: ctx.router_config.health_check.success_threshold, - }); + }) + .circuit_breaker_config(core_cb_config.clone()) + .health_config(HealthConfig { + timeout_secs: ctx.router_config.health_check.timeout_secs, + check_interval_secs: ctx.router_config.health_check.check_interval_secs, + endpoint: ctx.router_config.health_check.endpoint.clone(), + failure_threshold: ctx.router_config.health_check.failure_threshold, + success_threshold: ctx.router_config.health_check.success_threshold, + }) + .build(); Arc::new(worker) as Arc }) .collect(); @@ -155,19 +154,18 @@ impl GrpcPDRouter { let decode_workers: Vec> = decode_urls .iter() .map(|url| { - let worker = BasicWorker::with_connection_mode( - url.clone(), - WorkerType::Decode, - crate::core::ConnectionMode::Grpc { port: None }, - ) - .with_circuit_breaker_config(core_cb_config.clone()) - .with_health_config(HealthConfig { - timeout_secs: ctx.router_config.health_check.timeout_secs, - check_interval_secs: ctx.router_config.health_check.check_interval_secs, - endpoint: ctx.router_config.health_check.endpoint.clone(), - failure_threshold: ctx.router_config.health_check.failure_threshold, - success_threshold: ctx.router_config.health_check.success_threshold, - }); + let worker = BasicWorkerBuilder::new(url.clone()) + .worker_type(WorkerType::Decode) + .connection_mode(crate::core::ConnectionMode::Grpc { port: None }) + .circuit_breaker_config(core_cb_config.clone()) + .health_config(HealthConfig { + timeout_secs: ctx.router_config.health_check.timeout_secs, + check_interval_secs: ctx.router_config.health_check.check_interval_secs, + endpoint: ctx.router_config.health_check.endpoint.clone(), + failure_threshold: ctx.router_config.health_check.failure_threshold, + success_threshold: ctx.router_config.health_check.success_threshold, + }) + .build(); Arc::new(worker) as Arc }) .collect(); diff --git a/sgl-router/src/routers/grpc/router.rs b/sgl-router/src/routers/grpc/router.rs index ff38e3469..f88cf9ed2 100644 --- a/sgl-router/src/routers/grpc/router.rs +++ b/sgl-router/src/routers/grpc/router.rs @@ -2,7 +2,7 @@ use crate::config::types::RetryConfig; use crate::core::{ - BasicWorker, CircuitBreakerConfig, HealthChecker, HealthConfig, Worker, WorkerType, + BasicWorkerBuilder, CircuitBreakerConfig, HealthChecker, HealthConfig, Worker, WorkerType, }; use crate::grpc::SglangSchedulerClient; use crate::metrics::RouterMetrics; @@ -108,20 +108,19 @@ impl GrpcRouter { // Move clients from the HashMap to the workers for url in &worker_urls { if let Some(client) = grpc_clients.remove(url) { - let worker = BasicWorker::with_connection_mode( - url.clone(), - WorkerType::Regular, - crate::core::ConnectionMode::Grpc { port: None }, - ) - .with_circuit_breaker_config(core_cb_config.clone()) - .with_health_config(HealthConfig { - timeout_secs: ctx.router_config.health_check.timeout_secs, - check_interval_secs: ctx.router_config.health_check.check_interval_secs, - endpoint: ctx.router_config.health_check.endpoint.clone(), - failure_threshold: ctx.router_config.health_check.failure_threshold, - success_threshold: ctx.router_config.health_check.success_threshold, - }) - .with_grpc_client(client); + let worker = BasicWorkerBuilder::new(url.clone()) + .worker_type(WorkerType::Regular) + .connection_mode(crate::core::ConnectionMode::Grpc { port: None }) + .circuit_breaker_config(core_cb_config.clone()) + .health_config(HealthConfig { + timeout_secs: ctx.router_config.health_check.timeout_secs, + check_interval_secs: ctx.router_config.health_check.check_interval_secs, + endpoint: ctx.router_config.health_check.endpoint.clone(), + failure_threshold: ctx.router_config.health_check.failure_threshold, + success_threshold: ctx.router_config.health_check.success_threshold, + }) + .grpc_client(client) + .build(); workers.push(Arc::new(worker) as Arc); } else { diff --git a/sgl-router/src/routers/http/pd_router.rs b/sgl-router/src/routers/http/pd_router.rs index a31186177..8eb55b543 100644 --- a/sgl-router/src/routers/http/pd_router.rs +++ b/sgl-router/src/routers/http/pd_router.rs @@ -3,8 +3,8 @@ use super::pd_types::{api_path, PDRouterError}; use crate::config::types::RetryConfig; use crate::core::{ - is_retryable_status, BasicWorker, CircuitBreakerConfig, HealthConfig, RetryExecutor, Worker, - WorkerFactory, WorkerLoadGuard, WorkerRegistry, WorkerType, + is_retryable_status, BasicWorkerBuilder, CircuitBreakerConfig, HealthConfig, RetryExecutor, + Worker, WorkerFactory, WorkerLoadGuard, WorkerRegistry, WorkerType, }; use crate::metrics::RouterMetrics; use crate::policies::{LoadBalancingPolicy, PolicyRegistry}; @@ -389,34 +389,35 @@ impl PDRouter { // Register prefill workers in the registry for (url, port) in prefill_urls { - let worker = BasicWorker::new( - url, - WorkerType::Prefill { + let worker = BasicWorkerBuilder::new(url) + .worker_type(WorkerType::Prefill { bootstrap_port: port, - }, - ) - .with_circuit_breaker_config(core_cb_config.clone()) - .with_health_config(HealthConfig { - timeout_secs: ctx.router_config.health_check.timeout_secs, - check_interval_secs: ctx.router_config.health_check.check_interval_secs, - endpoint: ctx.router_config.health_check.endpoint.clone(), - failure_threshold: ctx.router_config.health_check.failure_threshold, - success_threshold: ctx.router_config.health_check.success_threshold, - }); - ctx.worker_registry.register(Arc::new(worker)); - } - - // Register decode workers in the registry - for url in decode_urls { - let worker = BasicWorker::new(url, WorkerType::Decode) - .with_circuit_breaker_config(core_cb_config.clone()) - .with_health_config(HealthConfig { + }) + .circuit_breaker_config(core_cb_config.clone()) + .health_config(HealthConfig { timeout_secs: ctx.router_config.health_check.timeout_secs, check_interval_secs: ctx.router_config.health_check.check_interval_secs, endpoint: ctx.router_config.health_check.endpoint.clone(), failure_threshold: ctx.router_config.health_check.failure_threshold, success_threshold: ctx.router_config.health_check.success_threshold, - }); + }) + .build(); + ctx.worker_registry.register(Arc::new(worker)); + } + + // Register decode workers in the registry + for url in decode_urls { + let worker = BasicWorkerBuilder::new(url) + .worker_type(WorkerType::Decode) + .circuit_breaker_config(core_cb_config.clone()) + .health_config(HealthConfig { + timeout_secs: ctx.router_config.health_check.timeout_secs, + check_interval_secs: ctx.router_config.health_check.check_interval_secs, + endpoint: ctx.router_config.health_check.endpoint.clone(), + failure_threshold: ctx.router_config.health_check.failure_threshold, + success_threshold: ctx.router_config.health_check.success_threshold, + }) + .build(); ctx.worker_registry.register(Arc::new(worker)); } @@ -2116,7 +2117,7 @@ impl RouterTrait for PDRouter { #[cfg(test)] mod tests { use super::*; - use crate::core::{BasicWorker, WorkerType}; + use crate::core::WorkerType; fn create_test_pd_router() -> PDRouter { let worker_registry = Arc::new(WorkerRegistry::new()); @@ -2139,7 +2140,9 @@ mod tests { } fn create_test_worker(url: String, worker_type: WorkerType, healthy: bool) -> Box { - let worker = BasicWorker::new(url, worker_type); + let worker = BasicWorkerBuilder::new(url) + .worker_type(worker_type) + .build(); worker.set_healthy(healthy); Box::new(worker) } diff --git a/sgl-router/src/routers/http/router.rs b/sgl-router/src/routers/http/router.rs index 8fff96bcf..91ad1d948 100644 --- a/sgl-router/src/routers/http/router.rs +++ b/sgl-router/src/routers/http/router.rs @@ -1,7 +1,7 @@ use crate::config::types::RetryConfig; use crate::core::{ - is_retryable_status, BasicWorker, CircuitBreakerConfig, HealthConfig, RetryExecutor, Worker, - WorkerRegistry, WorkerType, + is_retryable_status, BasicWorkerBuilder, CircuitBreakerConfig, HealthConfig, RetryExecutor, + Worker, WorkerRegistry, WorkerType, }; use crate::metrics::RouterMetrics; use crate::policies::{LoadBalancingPolicy, PolicyRegistry}; @@ -87,15 +87,17 @@ impl Router { for url in &worker_urls { // TODO: In IGW mode, fetch model_id from worker's /get_model_info endpoint // For now, create worker without model_id - let worker = BasicWorker::new(url.clone(), WorkerType::Regular) - .with_circuit_breaker_config(core_cb_config.clone()) - .with_health_config(HealthConfig { + let worker = BasicWorkerBuilder::new(url.clone()) + .worker_type(WorkerType::Regular) + .circuit_breaker_config(core_cb_config.clone()) + .health_config(HealthConfig { timeout_secs: ctx.router_config.health_check.timeout_secs, check_interval_secs: ctx.router_config.health_check.check_interval_secs, endpoint: ctx.router_config.health_check.endpoint.clone(), failure_threshold: ctx.router_config.health_check.failure_threshold, success_threshold: ctx.router_config.health_check.success_threshold, - }); + }) + .build(); let worker_arc = Arc::new(worker); ctx.worker_registry.register(worker_arc.clone()); @@ -991,11 +993,10 @@ impl Router { } info!("Added worker: {}", dp_url); // TODO: In IGW mode, fetch model_id from worker's /get_model_info endpoint - let new_worker = - BasicWorker::new(dp_url.to_string(), WorkerType::Regular) - .with_circuit_breaker_config( - self.circuit_breaker_config.clone(), - ); + let new_worker = BasicWorkerBuilder::new(dp_url.to_string()) + .worker_type(WorkerType::Regular) + .circuit_breaker_config(self.circuit_breaker_config.clone()) + .build(); let worker_arc = Arc::new(new_worker); self.worker_registry.register(worker_arc.clone()); @@ -1028,11 +1029,10 @@ impl Router { info!("Added worker: {}", worker_url); // TODO: In IGW mode, fetch model_id from worker's /get_model_info endpoint - let new_worker = - BasicWorker::new(worker_url.to_string(), WorkerType::Regular) - .with_circuit_breaker_config( - self.circuit_breaker_config.clone(), - ); + let new_worker = BasicWorkerBuilder::new(worker_url.to_string()) + .worker_type(WorkerType::Regular) + .circuit_breaker_config(self.circuit_breaker_config.clone()) + .build(); let worker_arc = Arc::new(new_worker); self.worker_registry.register(worker_arc.clone()); @@ -1595,8 +1595,12 @@ mod tests { )); // Register test workers - let worker1 = BasicWorker::new("http://worker1:8080".to_string(), WorkerType::Regular); - let worker2 = BasicWorker::new("http://worker2:8080".to_string(), WorkerType::Regular); + let worker1 = BasicWorkerBuilder::new("http://worker1:8080") + .worker_type(WorkerType::Regular) + .build(); + let worker2 = BasicWorkerBuilder::new("http://worker2:8080") + .worker_type(WorkerType::Regular) + .build(); worker_registry.register(Arc::new(worker1)); worker_registry.register(Arc::new(worker2)); diff --git a/sgl-router/tests/cache_aware_backward_compat_test.rs b/sgl-router/tests/cache_aware_backward_compat_test.rs index 07baa9648..071f0ab11 100644 --- a/sgl-router/tests/cache_aware_backward_compat_test.rs +++ b/sgl-router/tests/cache_aware_backward_compat_test.rs @@ -1,4 +1,4 @@ -use sglang_router_rs::core::{BasicWorker, Worker, WorkerType}; +use sglang_router_rs::core::{BasicWorkerBuilder, Worker, WorkerType}; use sglang_router_rs::policies::{CacheAwareConfig, CacheAwarePolicy, LoadBalancingPolicy}; use std::collections::HashMap; use std::sync::Arc; @@ -16,13 +16,17 @@ fn test_backward_compatibility_with_empty_model_id() { let policy = CacheAwarePolicy::with_config(config); // Create workers with empty model_id (simulating existing routers) - let worker1 = BasicWorker::new("http://worker1:8080".to_string(), WorkerType::Regular); + let worker1 = BasicWorkerBuilder::new("http://worker1:8080") + .worker_type(WorkerType::Regular) + .build(); // No model_id label - should default to "unknown" let mut labels2 = HashMap::new(); labels2.insert("model_id".to_string(), "unknown".to_string()); - let worker2 = BasicWorker::new("http://worker2:8080".to_string(), WorkerType::Regular) - .with_labels(labels2); + let worker2 = BasicWorkerBuilder::new("http://worker2:8080") + .worker_type(WorkerType::Regular) + .labels(labels2) + .build(); // Add workers - should both go to "default" tree policy.add_worker(&worker1); @@ -53,23 +57,31 @@ fn test_mixed_model_ids() { let policy = CacheAwarePolicy::with_config(config); // Create workers with different model_id scenarios - let worker1 = BasicWorker::new("http://worker1:8080".to_string(), WorkerType::Regular); + let worker1 = BasicWorkerBuilder::new("http://worker1:8080") + .worker_type(WorkerType::Regular) + .build(); // No model_id label - defaults to "unknown" which goes to "default" tree let mut labels2 = HashMap::new(); labels2.insert("model_id".to_string(), "llama-3".to_string()); - let worker2 = BasicWorker::new("http://worker2:8080".to_string(), WorkerType::Regular) - .with_labels(labels2); + let worker2 = BasicWorkerBuilder::new("http://worker2:8080") + .worker_type(WorkerType::Regular) + .labels(labels2) + .build(); let mut labels3 = HashMap::new(); labels3.insert("model_id".to_string(), "unknown".to_string()); - let worker3 = BasicWorker::new("http://worker3:8080".to_string(), WorkerType::Regular) - .with_labels(labels3); + let worker3 = BasicWorkerBuilder::new("http://worker3:8080") + .worker_type(WorkerType::Regular) + .labels(labels3) + .build(); let mut labels4 = HashMap::new(); labels4.insert("model_id".to_string(), "llama-3".to_string()); - let worker4 = BasicWorker::new("http://worker4:8080".to_string(), WorkerType::Regular) - .with_labels(labels4); + let worker4 = BasicWorkerBuilder::new("http://worker4:8080") + .worker_type(WorkerType::Regular) + .labels(labels4) + .build(); // Add all workers policy.add_worker(&worker1); @@ -108,10 +120,14 @@ fn test_remove_worker_by_url_backward_compat() { // Create workers with different model_ids let mut labels1 = HashMap::new(); labels1.insert("model_id".to_string(), "llama-3".to_string()); - let worker1 = BasicWorker::new("http://worker1:8080".to_string(), WorkerType::Regular) - .with_labels(labels1); + let worker1 = BasicWorkerBuilder::new("http://worker1:8080") + .worker_type(WorkerType::Regular) + .labels(labels1) + .build(); - let worker2 = BasicWorker::new("http://worker2:8080".to_string(), WorkerType::Regular); + let worker2 = BasicWorkerBuilder::new("http://worker2:8080") + .worker_type(WorkerType::Regular) + .build(); // No model_id label - defaults to "unknown" // Add workers