From 873d858b286f09f478eab3a9589c6a35bae60968 Mon Sep 17 00:00:00 2001 From: Simo Lin Date: Fri, 19 Sep 2025 05:43:23 -0400 Subject: [PATCH] [router] refactor worker to builder pattern 5/n (#10653) --- sgl-router/src/core/worker.rs | 337 +++++++---------------- sgl-router/src/core/worker_registry.rs | 42 +-- sgl-router/src/routers/router_manager.rs | 41 +-- sgl-router/tests/test_pd_routing.rs | 54 +++- 4 files changed, 197 insertions(+), 277 deletions(-) diff --git a/sgl-router/src/core/worker.rs b/sgl-router/src/core/worker.rs index 2e7debe51..191c2150e 100644 --- a/sgl-router/src/core/worker.rs +++ b/sgl-router/src/core/worker.rs @@ -1,4 +1,4 @@ -use super::{CircuitBreaker, CircuitBreakerConfig, WorkerError, WorkerResult}; +use super::{CircuitBreaker, WorkerError, WorkerResult}; use crate::core::CircuitState; use crate::core::{BasicWorkerBuilder, DPAwareWorkerBuilder}; use crate::grpc::SglangSchedulerClient; @@ -525,23 +525,6 @@ pub struct DPAwareWorker { } impl DPAwareWorker { - /// Create a new DP-aware worker of any type - pub fn new(base_url: String, dp_rank: usize, dp_size: usize, worker_type: WorkerType) -> Self { - use crate::core::BasicWorkerBuilder; - // Create URL with DP rank suffix for identification - let worker_url = format!("{}@{}", base_url, dp_rank); - let base_worker = BasicWorkerBuilder::new(worker_url) - .worker_type(worker_type) - .build(); - - Self { - base_worker, - dp_rank, - dp_size, - base_url, - } - } - /// Create a new DP-aware worker with a pre-configured base worker /// This is primarily used by the builder pattern pub fn with_base_worker( @@ -661,95 +644,6 @@ impl Worker for DPAwareWorker { pub struct WorkerFactory; impl WorkerFactory { - /// Create a BasicWorkerBuilder for customizable worker creation - pub fn builder(url: impl Into) -> BasicWorkerBuilder { - BasicWorkerBuilder::new(url) - } - - /// Create a DPAwareWorkerBuilder for customizable DP-aware worker creation - pub fn dp_builder( - base_url: impl Into, - dp_rank: usize, - dp_size: usize, - ) -> DPAwareWorkerBuilder { - DPAwareWorkerBuilder::new(base_url, dp_rank, dp_size) - } - - /// Create a regular worker - pub fn create_regular(url: String) -> Box { - use crate::core::BasicWorkerBuilder; - Box::new( - BasicWorkerBuilder::new(url) - .worker_type(WorkerType::Regular) - .build(), - ) - } - - /// Create a prefill worker with optional bootstrap port - pub fn create_prefill(url: String, bootstrap_port: Option) -> Box { - use crate::core::BasicWorkerBuilder; - Box::new( - BasicWorkerBuilder::new(url) - .worker_type(WorkerType::Prefill { bootstrap_port }) - .build(), - ) - } - - /// Create a decode worker - pub fn create_decode(url: String) -> Box { - use crate::core::BasicWorkerBuilder; - Box::new( - BasicWorkerBuilder::new(url) - .worker_type(WorkerType::Decode) - .build(), - ) - } - - /// Create a regular worker with custom labels (for multi-router support) - pub fn create_regular_with_labels( - url: String, - labels: std::collections::HashMap, - circuit_breaker_config: CircuitBreakerConfig, - ) -> Box { - Box::new( - BasicWorkerBuilder::new(url) - .labels(labels) - .circuit_breaker_config(circuit_breaker_config) - .build(), - ) - } - - /// Create a prefill worker with labels - pub fn create_prefill_with_labels( - url: String, - bootstrap_port: Option, - labels: std::collections::HashMap, - circuit_breaker_config: CircuitBreakerConfig, - ) -> Box { - Box::new( - BasicWorkerBuilder::new(url) - .worker_type(WorkerType::Prefill { bootstrap_port }) - .labels(labels) - .circuit_breaker_config(circuit_breaker_config) - .build(), - ) - } - - /// Create a decode worker with labels - pub fn create_decode_with_labels( - url: String, - labels: std::collections::HashMap, - circuit_breaker_config: CircuitBreakerConfig, - ) -> Box { - Box::new( - BasicWorkerBuilder::new(url) - .worker_type(WorkerType::Decode) - .labels(labels) - .circuit_breaker_config(circuit_breaker_config) - .build(), - ) - } - /// Create a DP-aware worker of specified type pub fn create_dp_aware( base_url: String, @@ -757,9 +651,13 @@ impl WorkerFactory { dp_size: usize, worker_type: WorkerType, ) -> Box { - Box::new(DPAwareWorker::new(base_url, dp_rank, dp_size, worker_type)) + Box::new( + DPAwareWorkerBuilder::new(base_url, dp_rank, dp_size) + .worker_type(worker_type) + .build(), + ) } - + #[allow(dead_code)] /// Get DP size from a worker async fn get_worker_dp_size(url: &str, api_key: &Option) -> WorkerResult { let mut req_builder = WORKER_CLIENT.get(format!("{}/get_server_info", url)); @@ -807,81 +705,18 @@ impl WorkerFactory { Ok(dp_size as usize) } - - /// Private helper to create DP-aware workers of any type - async fn create_dp_aware_workers_of_type( - url: &str, - api_key: &Option, - worker_type: WorkerType, - ) -> WorkerResult>> { - let dp_size = Self::get_worker_dp_size(url, api_key).await?; - - let workers = (0..dp_size) - .map(|rank| Self::create_dp_aware(url.to_string(), rank, dp_size, worker_type.clone())) - .collect(); - - Ok(workers) - } - - /// Create DP-aware regular workers from a single URL - pub async fn create_dp_aware_regular_workers( - url: &str, - api_key: &Option, - ) -> WorkerResult>> { - Self::create_dp_aware_workers_of_type(url, api_key, WorkerType::Regular).await - } - - /// Create DP-aware prefill workers from a single URL - pub async fn create_dp_aware_prefill_workers( - url: &str, - bootstrap_port: Option, - api_key: &Option, - ) -> WorkerResult>> { - Self::create_dp_aware_workers_of_type(url, api_key, WorkerType::Prefill { bootstrap_port }) - .await - } - - /// Create DP-aware decode workers from a single URL - pub async fn create_dp_aware_decode_workers( - url: &str, - api_key: &Option, - ) -> WorkerResult>> { - Self::create_dp_aware_workers_of_type(url, api_key, WorkerType::Decode).await - } - - /// Create workers based on configuration (for regular router) - pub async fn create_workers( - urls: Vec, - dp_aware: bool, - api_key: &Option, - ) -> WorkerResult>> { - if dp_aware { - // Create futures for all worker creations - let worker_futs = urls - .iter() - .map(|url| Self::create_dp_aware_regular_workers(url, api_key)); - - // Execute all futures concurrently and flatten results - let all_workers = futures::future::try_join_all(worker_futs) - .await? - .into_iter() - .flatten() - .collect(); - - Ok(all_workers) - } else { - Ok(urls - .into_iter() - .map(|url| Self::create_regular(url)) - .collect()) - } - } } /// Convert a list of worker URLs to worker trait objects pub fn urls_to_workers(urls: Vec) -> Vec> { urls.into_iter() - .map(WorkerFactory::create_regular) + .map(|url| { + Box::new( + BasicWorkerBuilder::new(url) + .worker_type(WorkerType::Regular) + .build(), + ) as Box + }) .collect() } @@ -952,7 +787,7 @@ impl HealthChecker { /// Start an async background health checker for a collection of workers pub fn start_health_checker( - workers: std::sync::Arc>>>, + workers: Arc>>>, check_interval_secs: u64, ) -> HealthChecker { let shutdown = Arc::new(AtomicBool::new(false)); @@ -1037,6 +872,7 @@ pub fn start_health_checker( #[cfg(test)] mod tests { use super::*; + use crate::core::CircuitBreakerConfig; use std::thread; use std::time::Duration; @@ -1369,7 +1205,11 @@ mod tests { // Test WorkerFactory #[test] fn test_create_regular_worker() { - let worker = WorkerFactory::create_regular("http://regular:8080".to_string()); + let worker: Box = Box::new( + BasicWorkerBuilder::new("http://regular:8080") + .worker_type(WorkerType::Regular) + .build(), + ); assert_eq!(worker.url(), "http://regular:8080"); assert_eq!(worker.worker_type(), WorkerType::Regular); } @@ -1377,7 +1217,13 @@ mod tests { #[test] fn test_create_prefill_worker() { // With bootstrap port - let worker1 = WorkerFactory::create_prefill("http://prefill:8080".to_string(), Some(9090)); + let worker1: Box = Box::new( + BasicWorkerBuilder::new("http://prefill:8080") + .worker_type(WorkerType::Prefill { + bootstrap_port: Some(9090), + }) + .build(), + ); assert_eq!(worker1.url(), "http://prefill:8080"); assert_eq!( worker1.worker_type(), @@ -1387,7 +1233,13 @@ mod tests { ); // Without bootstrap port - let worker2 = WorkerFactory::create_prefill("http://prefill:8080".to_string(), None); + let worker2: Box = Box::new( + BasicWorkerBuilder::new("http://prefill:8080") + .worker_type(WorkerType::Prefill { + bootstrap_port: None, + }) + .build(), + ); assert_eq!( worker2.worker_type(), WorkerType::Prefill { @@ -1398,7 +1250,11 @@ mod tests { #[test] fn test_create_decode_worker() { - let worker = WorkerFactory::create_decode("http://decode:8080".to_string()); + let worker: Box = Box::new( + BasicWorkerBuilder::new("http://decode:8080") + .worker_type(WorkerType::Decode) + .build(), + ); assert_eq!(worker.url(), "http://decode:8080"); assert_eq!(worker.worker_type(), WorkerType::Decode); } @@ -1424,9 +1280,21 @@ mod tests { #[test] fn test_load_guard_multiple_workers() { let workers: Vec> = vec![ - WorkerFactory::create_regular("http://w1:8080".to_string()), - WorkerFactory::create_regular("http://w2:8080".to_string()), - WorkerFactory::create_regular("http://w3:8080".to_string()), + Box::new( + BasicWorkerBuilder::new("http://w1:8080") + .worker_type(WorkerType::Regular) + .build(), + ), + Box::new( + BasicWorkerBuilder::new("http://w2:8080") + .worker_type(WorkerType::Regular) + .build(), + ), + Box::new( + BasicWorkerBuilder::new("http://w3:8080") + .worker_type(WorkerType::Regular) + .build(), + ), ]; let worker_refs: Vec<&dyn Worker> = workers.iter().map(|w| w.as_ref()).collect(); @@ -1492,8 +1360,16 @@ mod tests { #[test] fn test_workers_to_urls() { let workers: Vec> = vec![ - WorkerFactory::create_regular("http://w1:8080".to_string()), - WorkerFactory::create_regular("http://w2:8080".to_string()), + Box::new( + BasicWorkerBuilder::new("http://w1:8080") + .worker_type(WorkerType::Regular) + .build(), + ), + Box::new( + BasicWorkerBuilder::new("http://w2:8080") + .worker_type(WorkerType::Regular) + .build(), + ), ]; let urls = workers_to_urls(&workers); @@ -1544,8 +1420,9 @@ mod tests { #[test] fn test_dp_aware_worker_creation() { - let dp_worker = - DPAwareWorker::new("http://worker1:8080".to_string(), 2, 4, WorkerType::Regular); + let dp_worker = DPAwareWorkerBuilder::new("http://worker1:8080", 2, 4) + .worker_type(WorkerType::Regular) + .build(); assert_eq!(dp_worker.url(), "http://worker1:8080@2"); assert_eq!(dp_worker.base_url(), "http://worker1:8080"); @@ -1557,14 +1434,11 @@ mod tests { #[test] fn test_dp_aware_worker_creation_prefill() { - let dp_worker = DPAwareWorker::new( - "http://worker1:8080".to_string(), - 1, - 2, - WorkerType::Prefill { + let dp_worker = DPAwareWorkerBuilder::new("http://worker1:8080", 1, 2) + .worker_type(WorkerType::Prefill { bootstrap_port: Some(9090), - }, - ); + }) + .build(); assert_eq!(dp_worker.url(), "http://worker1:8080@1"); assert!(dp_worker.is_dp_aware()); @@ -1578,8 +1452,9 @@ mod tests { #[test] fn test_dp_aware_worker_creation_decode() { - let dp_worker = - DPAwareWorker::new("http://worker1:8080".to_string(), 0, 4, WorkerType::Decode); + let dp_worker = DPAwareWorkerBuilder::new("http://worker1:8080", 0, 4) + .worker_type(WorkerType::Decode) + .build(); assert_eq!(dp_worker.url(), "http://worker1:8080@0"); assert!(dp_worker.is_dp_aware()); @@ -1588,8 +1463,9 @@ mod tests { #[tokio::test] async fn test_dp_aware_prepare_request() { - let dp_worker = - DPAwareWorker::new("http://worker1:8080".to_string(), 3, 8, WorkerType::Regular); + let dp_worker = DPAwareWorkerBuilder::new("http://worker1:8080", 3, 8) + .worker_type(WorkerType::Regular) + .build(); let original_req = serde_json::json!({ "prompt": "Hello", @@ -1605,8 +1481,9 @@ mod tests { #[tokio::test] async fn test_dp_aware_prepare_request_invalid() { - let dp_worker = - DPAwareWorker::new("http://worker1:8080".to_string(), 0, 4, WorkerType::Regular); + let dp_worker = DPAwareWorkerBuilder::new("http://worker1:8080", 0, 4) + .worker_type(WorkerType::Regular) + .build(); // Non-object JSON should fail let invalid_req = serde_json::json!("not an object"); @@ -1623,8 +1500,9 @@ mod tests { #[test] fn test_dp_aware_endpoint_url() { - let dp_worker = - DPAwareWorker::new("http://worker1:8080".to_string(), 1, 4, WorkerType::Regular); + let dp_worker = DPAwareWorkerBuilder::new("http://worker1:8080", 1, 4) + .worker_type(WorkerType::Regular) + .build(); assert_eq!( dp_worker.endpoint_url("/generate"), @@ -1638,8 +1516,9 @@ mod tests { #[test] fn test_dp_aware_worker_delegated_methods() { - let dp_worker = - DPAwareWorker::new("http://worker1:8080".to_string(), 0, 2, WorkerType::Regular); + let dp_worker = DPAwareWorkerBuilder::new("http://worker1:8080", 0, 2) + .worker_type(WorkerType::Regular) + .build(); // Test health status assert!(dp_worker.is_healthy()); @@ -1698,23 +1577,6 @@ mod tests { ); } - #[tokio::test] - async fn test_factory_create_workers_regular() { - let urls = vec!["http://w1:8080".to_string(), "http://w2:8080".to_string()]; - - let workers = WorkerFactory::create_workers(urls, false, &None) - .await - .unwrap(); - - assert_eq!(workers.len(), 2); - assert!(!workers[0].is_dp_aware()); - assert!(!workers[1].is_dp_aware()); - assert_eq!(workers[0].url(), "http://w1:8080"); - assert_eq!(workers[1].url(), "http://w2:8080"); - } - - // ===== Circuit Breaker Integration Tests ===== - #[test] fn test_worker_circuit_breaker() { use crate::core::BasicWorkerBuilder; @@ -1779,8 +1641,9 @@ mod tests { #[test] fn test_dp_aware_worker_circuit_breaker() { - let dp_worker = - DPAwareWorker::new("http://worker:8080".to_string(), 0, 2, WorkerType::Regular); + let dp_worker = DPAwareWorkerBuilder::new("http://worker:8080", 0, 2) + .worker_type(WorkerType::Regular) + .build(); // Should have circuit breaker assert!(dp_worker.is_available()); @@ -1800,9 +1663,23 @@ mod tests { #[tokio::test] async fn test_mixed_worker_types() { // Create a mix of worker types - let regular = WorkerFactory::create_regular("http://regular:8080".to_string()); - let prefill = WorkerFactory::create_prefill("http://prefill:8080".to_string(), Some(9090)); - let decode = WorkerFactory::create_decode("http://decode:8080".to_string()); + let regular: Box = Box::new( + BasicWorkerBuilder::new("http://regular:8080") + .worker_type(WorkerType::Regular) + .build(), + ); + let prefill: Box = Box::new( + BasicWorkerBuilder::new("http://prefill:8080") + .worker_type(WorkerType::Prefill { + bootstrap_port: Some(9090), + }) + .build(), + ); + let decode: Box = Box::new( + BasicWorkerBuilder::new("http://decode:8080") + .worker_type(WorkerType::Decode) + .build(), + ); let dp_aware_regular = WorkerFactory::create_dp_aware("http://dp:8080".to_string(), 0, 2, WorkerType::Regular); let dp_aware_prefill = WorkerFactory::create_dp_aware( diff --git a/sgl-router/src/core/worker_registry.rs b/sgl-router/src/core/worker_registry.rs index 845ebc223..1ba2a27e5 100644 --- a/sgl-router/src/core/worker_registry.rs +++ b/sgl-router/src/core/worker_registry.rs @@ -424,7 +424,7 @@ pub struct WorkerRegistryStats { #[cfg(test)] mod tests { use super::*; - use crate::core::{CircuitBreakerConfig, WorkerFactory}; + use crate::core::{BasicWorkerBuilder, CircuitBreakerConfig}; use std::collections::HashMap; #[test] @@ -437,10 +437,12 @@ mod tests { labels.insert("priority".to_string(), "50".to_string()); labels.insert("cost".to_string(), "0.8".to_string()); - let worker = WorkerFactory::create_regular_with_labels( - "http://worker1:8080".to_string(), - labels, - CircuitBreakerConfig::default(), + let worker: Box = Box::new( + BasicWorkerBuilder::new("http://worker1:8080") + .worker_type(WorkerType::Regular) + .labels(labels) + .circuit_breaker_config(CircuitBreakerConfig::default()) + .build(), ); // Register worker (WorkerFactory returns Box, convert to Arc) @@ -470,26 +472,32 @@ mod tests { // Create workers for different models let mut labels1 = HashMap::new(); labels1.insert("model_id".to_string(), "llama-3".to_string()); - let worker1 = WorkerFactory::create_regular_with_labels( - "http://worker1:8080".to_string(), - labels1, - CircuitBreakerConfig::default(), + let worker1: Box = Box::new( + BasicWorkerBuilder::new("http://worker1:8080") + .worker_type(WorkerType::Regular) + .labels(labels1) + .circuit_breaker_config(CircuitBreakerConfig::default()) + .build(), ); let mut labels2 = HashMap::new(); labels2.insert("model_id".to_string(), "llama-3".to_string()); - let worker2 = WorkerFactory::create_regular_with_labels( - "http://worker2:8080".to_string(), - labels2, - CircuitBreakerConfig::default(), + let worker2: Box = Box::new( + BasicWorkerBuilder::new("http://worker2:8080") + .worker_type(WorkerType::Regular) + .labels(labels2) + .circuit_breaker_config(CircuitBreakerConfig::default()) + .build(), ); let mut labels3 = HashMap::new(); labels3.insert("model_id".to_string(), "gpt-4".to_string()); - let worker3 = WorkerFactory::create_regular_with_labels( - "http://worker3:8080".to_string(), - labels3, - CircuitBreakerConfig::default(), + let worker3: Box = Box::new( + BasicWorkerBuilder::new("http://worker3:8080") + .worker_type(WorkerType::Regular) + .labels(labels3) + .circuit_breaker_config(CircuitBreakerConfig::default()) + .build(), ); // Register workers diff --git a/sgl-router/src/routers/router_manager.rs b/sgl-router/src/routers/router_manager.rs index a6c204e39..f44b9a8c2 100644 --- a/sgl-router/src/routers/router_manager.rs +++ b/sgl-router/src/routers/router_manager.rs @@ -5,7 +5,7 @@ //! - Multi-Router Mode (enable_igw=true): RouterManager coordinates everything use crate::config::RouterConfig; -use crate::core::{CircuitBreakerConfig, Worker, WorkerFactory, WorkerRegistry, WorkerType}; +use crate::core::{BasicWorkerBuilder, CircuitBreakerConfig, Worker, WorkerRegistry, WorkerType}; use crate::protocols::spec::{ ChatCompletionRequest, CompletionRequest, EmbeddingRequest, GenerateRequest, RerankRequest, ResponsesRequest, @@ -208,22 +208,29 @@ impl RouterManager { } let worker = match config.worker_type.as_deref() { - Some("prefill") => WorkerFactory::create_prefill_with_labels( - config.url.clone(), - config.bootstrap_port, - labels.clone(), - CircuitBreakerConfig::default(), - ), - Some("decode") => WorkerFactory::create_decode_with_labels( - config.url.clone(), - labels.clone(), - CircuitBreakerConfig::default(), - ), - _ => WorkerFactory::create_regular_with_labels( - config.url.clone(), - labels.clone(), - CircuitBreakerConfig::default(), - ), + Some("prefill") => Box::new( + BasicWorkerBuilder::new(config.url.clone()) + .worker_type(WorkerType::Prefill { + bootstrap_port: config.bootstrap_port, + }) + .labels(labels.clone()) + .circuit_breaker_config(CircuitBreakerConfig::default()) + .build(), + ) as Box, + Some("decode") => Box::new( + BasicWorkerBuilder::new(config.url.clone()) + .worker_type(WorkerType::Decode) + .labels(labels.clone()) + .circuit_breaker_config(CircuitBreakerConfig::default()) + .build(), + ) as Box, + _ => Box::new( + BasicWorkerBuilder::new(config.url.clone()) + .worker_type(WorkerType::Regular) + .labels(labels.clone()) + .circuit_breaker_config(CircuitBreakerConfig::default()) + .build(), + ) as Box, }; // Register worker diff --git a/sgl-router/tests/test_pd_routing.rs b/sgl-router/tests/test_pd_routing.rs index 45651478a..5b0f9dd96 100644 --- a/sgl-router/tests/test_pd_routing.rs +++ b/sgl-router/tests/test_pd_routing.rs @@ -4,7 +4,7 @@ mod test_pd_routing { use sglang_router_rs::config::{ CircuitBreakerConfig, ConnectionMode, PolicyConfig, RetryConfig, RouterConfig, RoutingMode, }; - use sglang_router_rs::core::{WorkerFactory, WorkerType}; + use sglang_router_rs::core::{BasicWorkerBuilder, Worker, WorkerType}; use sglang_router_rs::routers::http::pd_types::get_hostname; use sglang_router_rs::routers::http::pd_types::PDSelectionPolicy; use sglang_router_rs::routers::RouterFactory; @@ -46,11 +46,16 @@ mod test_pd_routing { #[test] fn test_worker_types() { - use sglang_router_rs::core::{WorkerFactory, WorkerType}; + use sglang_router_rs::core::{BasicWorkerBuilder, Worker, WorkerType}; // Test worker creation for prefill servers - let prefill_worker = - WorkerFactory::create_prefill("http://prefill:8080".to_string(), Some(9000)); + let prefill_worker: Box = Box::new( + BasicWorkerBuilder::new("http://prefill:8080") + .worker_type(WorkerType::Prefill { + bootstrap_port: Some(9000), + }) + .build(), + ); assert_eq!(prefill_worker.url(), "http://prefill:8080"); match prefill_worker.worker_type() { WorkerType::Prefill { bootstrap_port } => { @@ -60,7 +65,11 @@ mod test_pd_routing { } // Test worker creation for decode servers - let decode_worker = WorkerFactory::create_decode("http://decode:8080".to_string()); + let decode_worker: Box = Box::new( + BasicWorkerBuilder::new("http://decode:8080") + .worker_type(WorkerType::Decode) + .build(), + ); assert_eq!(decode_worker.url(), "http://decode:8080"); match decode_worker.worker_type() { WorkerType::Decode => (), @@ -68,7 +77,11 @@ mod test_pd_routing { } // Test regular worker creation - let regular_worker = WorkerFactory::create_regular("http://regular:8080".to_string()); + let regular_worker: Box = Box::new( + BasicWorkerBuilder::new("http://regular:8080") + .worker_type(WorkerType::Regular) + .build(), + ); assert_eq!(regular_worker.url(), "http://regular:8080"); match regular_worker.worker_type() { WorkerType::Regular => (), @@ -277,8 +290,13 @@ mod test_pd_routing { }); // Create a prefill worker to simulate injection - let prefill_worker = - WorkerFactory::create_prefill("http://prefill1:8080".to_string(), Some(9000)); + let prefill_worker: Box = Box::new( + BasicWorkerBuilder::new("http://prefill1:8080") + .worker_type(WorkerType::Prefill { + bootstrap_port: Some(9000), + }) + .build(), + ); // Extract bootstrap port from worker type let bootstrap_port = match prefill_worker.worker_type() { @@ -660,7 +678,7 @@ mod test_pd_routing { #[test] fn test_bootstrap_injection_with_benchmark_requests() { - use sglang_router_rs::core::{WorkerFactory, WorkerType}; + use sglang_router_rs::core::{BasicWorkerBuilder, Worker, WorkerType}; // Test bootstrap injection with actual benchmark request patterns let mut benchmark_request = json!({ @@ -675,8 +693,13 @@ mod test_pd_routing { }); // Create a prefill worker to simulate injection - let prefill_worker = - WorkerFactory::create_prefill("http://prefill:8080".to_string(), Some(9000)); + let prefill_worker: Box = Box::new( + BasicWorkerBuilder::new("http://prefill:8080") + .worker_type(WorkerType::Prefill { + bootstrap_port: Some(9000), + }) + .build(), + ); // Extract bootstrap port from worker type let bootstrap_port = match prefill_worker.worker_type() { @@ -806,8 +829,13 @@ mod test_pd_routing { }); // Create a prefill worker to simulate injection - let prefill_worker = - WorkerFactory::create_prefill("http://prefill:8080".to_string(), Some(9000)); + let prefill_worker: Box = Box::new( + BasicWorkerBuilder::new("http://prefill:8080") + .worker_type(WorkerType::Prefill { + bootstrap_port: Some(9000), + }) + .build(), + ); // Extract bootstrap port from worker type let bootstrap_port = match prefill_worker.worker_type() {