diff --git a/sgl-router/src/core/mod.rs b/sgl-router/src/core/mod.rs index aefbc2000..e344190b2 100644 --- a/sgl-router/src/core/mod.rs +++ b/sgl-router/src/core/mod.rs @@ -11,6 +11,6 @@ pub mod worker; // Re-export commonly used types at the module level pub use error::{WorkerError, WorkerResult}; pub use worker::{ - start_health_checker, BasicWorker, HealthChecker, Worker, WorkerCollection, WorkerFactory, - WorkerLoadGuard, WorkerType, + start_health_checker, BasicWorker, DPAwareWorker, HealthChecker, Worker, WorkerCollection, + WorkerFactory, WorkerLoadGuard, WorkerType, }; diff --git a/sgl-router/src/core/worker.rs b/sgl-router/src/core/worker.rs index 58db15991..12cf3b751 100644 --- a/sgl-router/src/core/worker.rs +++ b/sgl-router/src/core/worker.rs @@ -1,16 +1,18 @@ use super::{WorkerError, WorkerResult}; use async_trait::async_trait; +use futures; use once_cell::sync::Lazy; +use serde_json; use std::fmt; use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; use std::sync::Arc; -// Shared HTTP client for health checks -static HEALTH_CHECK_CLIENT: Lazy = Lazy::new(|| { +// Shared HTTP client for worker operations (health checks, server info, etc.) +static WORKER_CLIENT: Lazy = Lazy::new(|| { reqwest::Client::builder() .timeout(std::time::Duration::from_secs(30)) // Default timeout, overridden per request .build() - .expect("Failed to create health check HTTP client") + .expect("Failed to create worker HTTP client") }); /// Core worker abstraction that represents a backend service @@ -64,6 +66,43 @@ pub trait Worker: Send + Sync + fmt::Debug { /// Clone the worker (for trait objects) fn clone_worker(&self) -> Box; + + // === DP-aware methods === + + /// Check if this worker is DP-aware + fn is_dp_aware(&self) -> bool { + false + } + + /// Get the base URL without any DP rank suffix + fn base_url(&self) -> &str { + self.url() + } + + /// Get DP rank if this is a DP-aware worker + fn dp_rank(&self) -> Option { + None + } + + /// Get DP size if this worker is part of a DP group + fn dp_size(&self) -> Option { + None + } + + /// Transform a request for DP-aware routing + async fn prepare_request(&self, req: serde_json::Value) -> WorkerResult { + Ok(req) + } + + /// Get the actual endpoint URL for requests + fn endpoint_url(&self, route: &str) -> String { + format!("{}{}", self.base_url(), route) + } + + /// Check if this worker can handle a specific request + fn can_handle(&self, _req: &serde_json::Value) -> bool { + true + } } /// Worker type classification @@ -212,12 +251,7 @@ impl Worker for BasicWorker { let timeout = Duration::from_secs(self.metadata.health_config.timeout_secs); // Use the shared client with a custom timeout for this request - match HEALTH_CHECK_CLIENT - .get(&health_url) - .timeout(timeout) - .send() - .await - { + match WORKER_CLIENT.get(&health_url).timeout(timeout).send().await { Ok(response) => { if response.status().is_success() { self.set_healthy(true); @@ -273,6 +307,160 @@ impl Worker for BasicWorker { } } +/// A DP-aware worker that handles data-parallel routing +#[derive(Debug, Clone)] +pub struct DPAwareWorker { + /// The underlying basic worker + base_worker: BasicWorker, + /// DP rank for this worker + dp_rank: usize, + /// Total DP size + dp_size: usize, + /// Base URL without DP suffix + base_url: String, +} + +impl DPAwareWorker { + /// Create a new DP-aware worker of any type + pub fn new(base_url: String, dp_rank: usize, dp_size: usize, worker_type: WorkerType) -> Self { + // Create URL with DP rank suffix for identification + let worker_url = format!("{}@{}", base_url, dp_rank); + let base_worker = BasicWorker::new(worker_url, worker_type); + + Self { + base_worker, + dp_rank, + dp_size, + base_url, + } + } +} + +#[async_trait] +impl Worker for DPAwareWorker { + fn url(&self) -> &str { + self.base_worker.url() + } + + fn worker_type(&self) -> WorkerType { + self.base_worker.worker_type() + } + + fn is_healthy(&self) -> bool { + self.base_worker.is_healthy() + } + + fn set_healthy(&self, healthy: bool) { + self.base_worker.set_healthy(healthy); + } + + async fn check_health_async(&self) -> WorkerResult<()> { + // Use base URL for health checks + let health_url = format!("{}/health", self.base_url); + let timeout = + std::time::Duration::from_secs(self.base_worker.metadata.health_config.timeout_secs); + + let health_result = async { + let response = WORKER_CLIENT + .get(&health_url) + .timeout(timeout) + .send() + .await + .map_err(|e| format!("Health check request failed: {}", e))?; + + if response.status().is_success() { + Ok(()) + } else { + Err(format!( + "Health check returned status: {}", + response.status() + )) + } + } + .await; + + match health_result { + Ok(()) => { + self.set_healthy(true); + Ok(()) + } + Err(reason) => { + self.set_healthy(false); + Err(WorkerError::HealthCheckFailed { + url: self.base_url.clone(), + reason, + }) + } + } + } + + fn load(&self) -> usize { + self.base_worker.load() + } + + fn increment_load(&self) { + self.base_worker.increment_load(); + } + + fn decrement_load(&self) { + self.base_worker.decrement_load(); + } + + fn processed_requests(&self) -> usize { + self.base_worker.processed_requests() + } + + fn increment_processed(&self) { + self.base_worker.increment_processed(); + } + + fn metadata(&self) -> &WorkerMetadata { + self.base_worker.metadata() + } + + fn clone_worker(&self) -> Box { + Box::new(self.clone()) + } + + // DP-aware specific implementations + + fn is_dp_aware(&self) -> bool { + true + } + + fn base_url(&self) -> &str { + &self.base_url + } + + fn dp_rank(&self) -> Option { + Some(self.dp_rank) + } + + fn dp_size(&self) -> Option { + Some(self.dp_size) + } + + async fn prepare_request(&self, mut req: serde_json::Value) -> WorkerResult { + // Inject data_parallel_rank into the request + if let Some(map) = req.as_object_mut() { + map.insert( + "data_parallel_rank".to_string(), + serde_json::json!(self.dp_rank), + ); + Ok(req) + } else { + Err(WorkerError::InvalidConfiguration { + message: "Request must be a JSON object for DP-aware routing".to_string(), + }) + } + } + + fn endpoint_url(&self, route: &str) -> String { + // Use base URL for actual requests + format!("{}{}", self.base_url, route) + } +} + /// Worker factory for creating workers of different types pub struct WorkerFactory; @@ -318,6 +506,133 @@ impl WorkerFactory { (regular_workers, prefill_workers, decode_workers) } + + /// Create a DP-aware worker of specified type + pub fn create_dp_aware( + base_url: String, + dp_rank: usize, + dp_size: usize, + worker_type: WorkerType, + ) -> Box { + Box::new(DPAwareWorker::new(base_url, dp_rank, dp_size, worker_type)) + } + + /// Get DP size from a worker + async fn get_worker_dp_size(url: &str, api_key: &Option) -> WorkerResult { + let mut req_builder = WORKER_CLIENT.get(&format!("{}/get_server_info", url)); + + if let Some(key) = api_key { + req_builder = req_builder.bearer_auth(key); + } + + let response = req_builder + .send() + .await + .map_err(|e| WorkerError::NetworkError { + url: url.to_string(), + error: e.to_string(), + })?; + + if !response.status().is_success() { + return Err(WorkerError::NetworkError { + url: url.to_string(), + error: format!("Server returned: {}", response.status()), + }); + } + + let info: serde_json::Value = + response + .json() + .await + .map_err(|e| WorkerError::NetworkError { + url: url.to_string(), + error: format!("Failed to parse JSON: {}", e), + })?; + + let dp_size = info + .get("dp_size") + .and_then(|v| v.as_u64()) + .ok_or_else(|| WorkerError::InvalidConfiguration { + message: "dp_size not found in server info".to_string(), + })?; + + if dp_size > usize::MAX as u64 { + return Err(WorkerError::InvalidConfiguration { + message: format!("dp_size is too large: {}", dp_size), + }); + } + + Ok(dp_size as usize) + } + + /// Private helper to create DP-aware workers of any type + async fn create_dp_aware_workers_of_type( + url: &str, + api_key: &Option, + worker_type: WorkerType, + ) -> WorkerResult>> { + let dp_size = Self::get_worker_dp_size(url, api_key).await?; + + let workers = (0..dp_size) + .map(|rank| Self::create_dp_aware(url.to_string(), rank, dp_size, worker_type.clone())) + .collect(); + + Ok(workers) + } + + /// Create DP-aware regular workers from a single URL + pub async fn create_dp_aware_regular_workers( + url: &str, + api_key: &Option, + ) -> WorkerResult>> { + Self::create_dp_aware_workers_of_type(url, api_key, WorkerType::Regular).await + } + + /// Create DP-aware prefill workers from a single URL + pub async fn create_dp_aware_prefill_workers( + url: &str, + bootstrap_port: Option, + api_key: &Option, + ) -> WorkerResult>> { + Self::create_dp_aware_workers_of_type(url, api_key, WorkerType::Prefill { bootstrap_port }) + .await + } + + /// Create DP-aware decode workers from a single URL + pub async fn create_dp_aware_decode_workers( + url: &str, + api_key: &Option, + ) -> WorkerResult>> { + Self::create_dp_aware_workers_of_type(url, api_key, WorkerType::Decode).await + } + + /// Create workers based on configuration (for regular router) + pub async fn create_workers( + urls: Vec, + dp_aware: bool, + api_key: &Option, + ) -> WorkerResult>> { + if dp_aware { + // Create futures for all worker creations + let worker_futs = urls + .iter() + .map(|url| Self::create_dp_aware_regular_workers(url, api_key)); + + // Execute all futures concurrently and flatten results + let all_workers = futures::future::try_join_all(worker_futs) + .await? + .into_iter() + .flatten() + .collect(); + + Ok(all_workers) + } else { + Ok(urls + .into_iter() + .map(|url| Self::create_regular(url)) + .collect()) + } + } } /// Helper trait for collections of workers @@ -1086,4 +1401,245 @@ mod tests { // Should be well over 1M ops/sec assert!(ops_per_sec > 1_000_000.0); } + + // ===== Tests for DPAwareWorker ===== + + #[test] + fn test_dp_aware_worker_creation() { + let dp_worker = + DPAwareWorker::new("http://worker1:8080".to_string(), 2, 4, WorkerType::Regular); + + assert_eq!(dp_worker.url(), "http://worker1:8080@2"); + assert_eq!(dp_worker.base_url(), "http://worker1:8080"); + assert!(dp_worker.is_dp_aware()); + assert_eq!(dp_worker.dp_rank(), Some(2)); + assert_eq!(dp_worker.dp_size(), Some(4)); + assert_eq!(dp_worker.worker_type(), WorkerType::Regular); + } + + #[test] + fn test_dp_aware_worker_creation_prefill() { + let dp_worker = DPAwareWorker::new( + "http://worker1:8080".to_string(), + 1, + 2, + WorkerType::Prefill { + bootstrap_port: Some(9090), + }, + ); + + assert_eq!(dp_worker.url(), "http://worker1:8080@1"); + assert!(dp_worker.is_dp_aware()); + assert_eq!( + dp_worker.worker_type(), + WorkerType::Prefill { + bootstrap_port: Some(9090) + } + ); + } + + #[test] + fn test_dp_aware_worker_creation_decode() { + let dp_worker = + DPAwareWorker::new("http://worker1:8080".to_string(), 0, 4, WorkerType::Decode); + + assert_eq!(dp_worker.url(), "http://worker1:8080@0"); + assert!(dp_worker.is_dp_aware()); + assert_eq!(dp_worker.worker_type(), WorkerType::Decode); + } + + #[tokio::test] + async fn test_dp_aware_prepare_request() { + let dp_worker = + DPAwareWorker::new("http://worker1:8080".to_string(), 3, 8, WorkerType::Regular); + + let original_req = serde_json::json!({ + "prompt": "Hello", + "max_tokens": 100 + }); + + let prepared_req = dp_worker.prepare_request(original_req).await.unwrap(); + + assert_eq!(prepared_req["prompt"], "Hello"); + assert_eq!(prepared_req["max_tokens"], 100); + assert_eq!(prepared_req["data_parallel_rank"], 3); + } + + #[tokio::test] + async fn test_dp_aware_prepare_request_invalid() { + let dp_worker = + DPAwareWorker::new("http://worker1:8080".to_string(), 0, 4, WorkerType::Regular); + + // Non-object JSON should fail + let invalid_req = serde_json::json!("not an object"); + let result = dp_worker.prepare_request(invalid_req).await; + + assert!(result.is_err()); + match result.unwrap_err() { + WorkerError::InvalidConfiguration { message } => { + assert!(message.contains("JSON object")); + } + _ => panic!("Expected InvalidConfiguration error"), + } + } + + #[test] + fn test_dp_aware_endpoint_url() { + let dp_worker = + DPAwareWorker::new("http://worker1:8080".to_string(), 1, 4, WorkerType::Regular); + + assert_eq!( + dp_worker.endpoint_url("/generate"), + "http://worker1:8080/generate" + ); + assert_eq!( + dp_worker.endpoint_url("/health"), + "http://worker1:8080/health" + ); + } + + #[test] + fn test_dp_aware_worker_delegated_methods() { + let dp_worker = + DPAwareWorker::new("http://worker1:8080".to_string(), 0, 2, WorkerType::Regular); + + // Test health status + assert!(dp_worker.is_healthy()); + dp_worker.set_healthy(false); + assert!(!dp_worker.is_healthy()); + + // Test load tracking + assert_eq!(dp_worker.load(), 0); + dp_worker.increment_load(); + assert_eq!(dp_worker.load(), 1); + dp_worker.decrement_load(); + assert_eq!(dp_worker.load(), 0); + + // Test processed tracking + assert_eq!(dp_worker.processed_requests(), 0); + dp_worker.increment_processed(); + assert_eq!(dp_worker.processed_requests(), 1); + } + + // ===== Tests for WorkerFactory async methods ===== + + #[tokio::test] + async fn test_factory_create_dp_aware() { + let worker = WorkerFactory::create_dp_aware( + "http://worker1:8080".to_string(), + 1, + 4, + WorkerType::Regular, + ); + + assert_eq!(worker.url(), "http://worker1:8080@1"); + assert!(worker.is_dp_aware()); + assert_eq!(worker.dp_rank(), Some(1)); + assert_eq!(worker.dp_size(), Some(4)); + assert_eq!(worker.worker_type(), WorkerType::Regular); + } + + #[tokio::test] + async fn test_factory_create_dp_aware_prefill() { + let worker = WorkerFactory::create_dp_aware( + "http://worker1:8080".to_string(), + 0, + 2, + WorkerType::Prefill { + bootstrap_port: Some(8090), + }, + ); + + assert_eq!(worker.url(), "http://worker1:8080@0"); + assert!(worker.is_dp_aware()); + assert_eq!( + worker.worker_type(), + WorkerType::Prefill { + bootstrap_port: Some(8090) + } + ); + } + + #[tokio::test] + async fn test_factory_create_workers_regular() { + let urls = vec!["http://w1:8080".to_string(), "http://w2:8080".to_string()]; + + let workers = WorkerFactory::create_workers(urls, false, &None) + .await + .unwrap(); + + assert_eq!(workers.len(), 2); + assert!(!workers[0].is_dp_aware()); + assert!(!workers[1].is_dp_aware()); + assert_eq!(workers[0].url(), "http://w1:8080"); + assert_eq!(workers[1].url(), "http://w2:8080"); + } + + // ===== Integration tests ===== + + #[tokio::test] + async fn test_mixed_worker_types() { + // Create a mix of worker types + let regular = WorkerFactory::create_regular("http://regular:8080".to_string()); + let prefill = WorkerFactory::create_prefill("http://prefill:8080".to_string(), Some(9090)); + let decode = WorkerFactory::create_decode("http://decode:8080".to_string()); + let dp_aware_regular = + WorkerFactory::create_dp_aware("http://dp:8080".to_string(), 0, 2, WorkerType::Regular); + let dp_aware_prefill = WorkerFactory::create_dp_aware( + "http://dp-prefill:8080".to_string(), + 1, + 2, + WorkerType::Prefill { + bootstrap_port: None, + }, + ); + let dp_aware_decode = WorkerFactory::create_dp_aware( + "http://dp-decode:8080".to_string(), + 0, + 4, + WorkerType::Decode, + ); + + let workers: Vec> = vec![ + regular, + prefill, + decode, + dp_aware_regular, + dp_aware_prefill, + dp_aware_decode, + ]; + + // Test that they all implement Worker trait properly + for worker in &workers { + assert!(worker.is_healthy()); + assert_eq!(worker.load(), 0); + assert_eq!(worker.processed_requests(), 0); + } + + // Test specific behaviors + assert!(!workers[0].is_dp_aware()); // regular + assert!(!workers[1].is_dp_aware()); // prefill + assert!(!workers[2].is_dp_aware()); // decode + assert!(workers[3].is_dp_aware()); // dp_aware_regular + assert!(workers[4].is_dp_aware()); // dp_aware_prefill + assert!(workers[5].is_dp_aware()); // dp_aware_decode + + // Test worker types + assert_eq!(workers[0].worker_type(), WorkerType::Regular); + assert_eq!( + workers[1].worker_type(), + WorkerType::Prefill { + bootstrap_port: Some(9090) + } + ); + assert_eq!(workers[2].worker_type(), WorkerType::Decode); + assert_eq!(workers[3].worker_type(), WorkerType::Regular); + assert_eq!( + workers[4].worker_type(), + WorkerType::Prefill { + bootstrap_port: None + } + ); + assert_eq!(workers[5].worker_type(), WorkerType::Decode); + } }