[router] refactor worker to builder pattern 5/n (#10653)
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
use super::{CircuitBreaker, CircuitBreakerConfig, WorkerError, WorkerResult};
|
||||
use super::{CircuitBreaker, WorkerError, WorkerResult};
|
||||
use crate::core::CircuitState;
|
||||
use crate::core::{BasicWorkerBuilder, DPAwareWorkerBuilder};
|
||||
use crate::grpc::SglangSchedulerClient;
|
||||
@@ -525,23 +525,6 @@ pub struct DPAwareWorker {
|
||||
}
|
||||
|
||||
impl DPAwareWorker {
|
||||
/// Create a new DP-aware worker of any type
|
||||
pub fn new(base_url: String, dp_rank: usize, dp_size: usize, worker_type: WorkerType) -> Self {
|
||||
use crate::core::BasicWorkerBuilder;
|
||||
// Create URL with DP rank suffix for identification
|
||||
let worker_url = format!("{}@{}", base_url, dp_rank);
|
||||
let base_worker = BasicWorkerBuilder::new(worker_url)
|
||||
.worker_type(worker_type)
|
||||
.build();
|
||||
|
||||
Self {
|
||||
base_worker,
|
||||
dp_rank,
|
||||
dp_size,
|
||||
base_url,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new DP-aware worker with a pre-configured base worker
|
||||
/// This is primarily used by the builder pattern
|
||||
pub fn with_base_worker(
|
||||
@@ -661,95 +644,6 @@ impl Worker for DPAwareWorker {
|
||||
pub struct WorkerFactory;
|
||||
|
||||
impl WorkerFactory {
|
||||
/// Create a BasicWorkerBuilder for customizable worker creation
|
||||
pub fn builder(url: impl Into<String>) -> BasicWorkerBuilder {
|
||||
BasicWorkerBuilder::new(url)
|
||||
}
|
||||
|
||||
/// Create a DPAwareWorkerBuilder for customizable DP-aware worker creation
|
||||
pub fn dp_builder(
|
||||
base_url: impl Into<String>,
|
||||
dp_rank: usize,
|
||||
dp_size: usize,
|
||||
) -> DPAwareWorkerBuilder {
|
||||
DPAwareWorkerBuilder::new(base_url, dp_rank, dp_size)
|
||||
}
|
||||
|
||||
/// Create a regular worker
|
||||
pub fn create_regular(url: String) -> Box<dyn Worker> {
|
||||
use crate::core::BasicWorkerBuilder;
|
||||
Box::new(
|
||||
BasicWorkerBuilder::new(url)
|
||||
.worker_type(WorkerType::Regular)
|
||||
.build(),
|
||||
)
|
||||
}
|
||||
|
||||
/// Create a prefill worker with optional bootstrap port
|
||||
pub fn create_prefill(url: String, bootstrap_port: Option<u16>) -> Box<dyn Worker> {
|
||||
use crate::core::BasicWorkerBuilder;
|
||||
Box::new(
|
||||
BasicWorkerBuilder::new(url)
|
||||
.worker_type(WorkerType::Prefill { bootstrap_port })
|
||||
.build(),
|
||||
)
|
||||
}
|
||||
|
||||
/// Create a decode worker
|
||||
pub fn create_decode(url: String) -> Box<dyn Worker> {
|
||||
use crate::core::BasicWorkerBuilder;
|
||||
Box::new(
|
||||
BasicWorkerBuilder::new(url)
|
||||
.worker_type(WorkerType::Decode)
|
||||
.build(),
|
||||
)
|
||||
}
|
||||
|
||||
/// Create a regular worker with custom labels (for multi-router support)
|
||||
pub fn create_regular_with_labels(
|
||||
url: String,
|
||||
labels: std::collections::HashMap<String, String>,
|
||||
circuit_breaker_config: CircuitBreakerConfig,
|
||||
) -> Box<dyn Worker> {
|
||||
Box::new(
|
||||
BasicWorkerBuilder::new(url)
|
||||
.labels(labels)
|
||||
.circuit_breaker_config(circuit_breaker_config)
|
||||
.build(),
|
||||
)
|
||||
}
|
||||
|
||||
/// Create a prefill worker with labels
|
||||
pub fn create_prefill_with_labels(
|
||||
url: String,
|
||||
bootstrap_port: Option<u16>,
|
||||
labels: std::collections::HashMap<String, String>,
|
||||
circuit_breaker_config: CircuitBreakerConfig,
|
||||
) -> Box<dyn Worker> {
|
||||
Box::new(
|
||||
BasicWorkerBuilder::new(url)
|
||||
.worker_type(WorkerType::Prefill { bootstrap_port })
|
||||
.labels(labels)
|
||||
.circuit_breaker_config(circuit_breaker_config)
|
||||
.build(),
|
||||
)
|
||||
}
|
||||
|
||||
/// Create a decode worker with labels
|
||||
pub fn create_decode_with_labels(
|
||||
url: String,
|
||||
labels: std::collections::HashMap<String, String>,
|
||||
circuit_breaker_config: CircuitBreakerConfig,
|
||||
) -> Box<dyn Worker> {
|
||||
Box::new(
|
||||
BasicWorkerBuilder::new(url)
|
||||
.worker_type(WorkerType::Decode)
|
||||
.labels(labels)
|
||||
.circuit_breaker_config(circuit_breaker_config)
|
||||
.build(),
|
||||
)
|
||||
}
|
||||
|
||||
/// Create a DP-aware worker of specified type
|
||||
pub fn create_dp_aware(
|
||||
base_url: String,
|
||||
@@ -757,9 +651,13 @@ impl WorkerFactory {
|
||||
dp_size: usize,
|
||||
worker_type: WorkerType,
|
||||
) -> Box<dyn Worker> {
|
||||
Box::new(DPAwareWorker::new(base_url, dp_rank, dp_size, worker_type))
|
||||
Box::new(
|
||||
DPAwareWorkerBuilder::new(base_url, dp_rank, dp_size)
|
||||
.worker_type(worker_type)
|
||||
.build(),
|
||||
)
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
/// Get DP size from a worker
|
||||
async fn get_worker_dp_size(url: &str, api_key: &Option<String>) -> WorkerResult<usize> {
|
||||
let mut req_builder = WORKER_CLIENT.get(format!("{}/get_server_info", url));
|
||||
@@ -807,81 +705,18 @@ impl WorkerFactory {
|
||||
|
||||
Ok(dp_size as usize)
|
||||
}
|
||||
|
||||
/// Private helper to create DP-aware workers of any type
|
||||
async fn create_dp_aware_workers_of_type(
|
||||
url: &str,
|
||||
api_key: &Option<String>,
|
||||
worker_type: WorkerType,
|
||||
) -> WorkerResult<Vec<Box<dyn Worker>>> {
|
||||
let dp_size = Self::get_worker_dp_size(url, api_key).await?;
|
||||
|
||||
let workers = (0..dp_size)
|
||||
.map(|rank| Self::create_dp_aware(url.to_string(), rank, dp_size, worker_type.clone()))
|
||||
.collect();
|
||||
|
||||
Ok(workers)
|
||||
}
|
||||
|
||||
/// Create DP-aware regular workers from a single URL
|
||||
pub async fn create_dp_aware_regular_workers(
|
||||
url: &str,
|
||||
api_key: &Option<String>,
|
||||
) -> WorkerResult<Vec<Box<dyn Worker>>> {
|
||||
Self::create_dp_aware_workers_of_type(url, api_key, WorkerType::Regular).await
|
||||
}
|
||||
|
||||
/// Create DP-aware prefill workers from a single URL
|
||||
pub async fn create_dp_aware_prefill_workers(
|
||||
url: &str,
|
||||
bootstrap_port: Option<u16>,
|
||||
api_key: &Option<String>,
|
||||
) -> WorkerResult<Vec<Box<dyn Worker>>> {
|
||||
Self::create_dp_aware_workers_of_type(url, api_key, WorkerType::Prefill { bootstrap_port })
|
||||
.await
|
||||
}
|
||||
|
||||
/// Create DP-aware decode workers from a single URL
|
||||
pub async fn create_dp_aware_decode_workers(
|
||||
url: &str,
|
||||
api_key: &Option<String>,
|
||||
) -> WorkerResult<Vec<Box<dyn Worker>>> {
|
||||
Self::create_dp_aware_workers_of_type(url, api_key, WorkerType::Decode).await
|
||||
}
|
||||
|
||||
/// Create workers based on configuration (for regular router)
|
||||
pub async fn create_workers(
|
||||
urls: Vec<String>,
|
||||
dp_aware: bool,
|
||||
api_key: &Option<String>,
|
||||
) -> WorkerResult<Vec<Box<dyn Worker>>> {
|
||||
if dp_aware {
|
||||
// Create futures for all worker creations
|
||||
let worker_futs = urls
|
||||
.iter()
|
||||
.map(|url| Self::create_dp_aware_regular_workers(url, api_key));
|
||||
|
||||
// Execute all futures concurrently and flatten results
|
||||
let all_workers = futures::future::try_join_all(worker_futs)
|
||||
.await?
|
||||
.into_iter()
|
||||
.flatten()
|
||||
.collect();
|
||||
|
||||
Ok(all_workers)
|
||||
} else {
|
||||
Ok(urls
|
||||
.into_iter()
|
||||
.map(|url| Self::create_regular(url))
|
||||
.collect())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert a list of worker URLs to worker trait objects
|
||||
pub fn urls_to_workers(urls: Vec<String>) -> Vec<Box<dyn Worker>> {
|
||||
urls.into_iter()
|
||||
.map(WorkerFactory::create_regular)
|
||||
.map(|url| {
|
||||
Box::new(
|
||||
BasicWorkerBuilder::new(url)
|
||||
.worker_type(WorkerType::Regular)
|
||||
.build(),
|
||||
) as Box<dyn Worker>
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
@@ -952,7 +787,7 @@ impl HealthChecker {
|
||||
|
||||
/// Start an async background health checker for a collection of workers
|
||||
pub fn start_health_checker(
|
||||
workers: std::sync::Arc<std::sync::RwLock<Vec<std::sync::Arc<dyn Worker>>>>,
|
||||
workers: Arc<std::sync::RwLock<Vec<Arc<dyn Worker>>>>,
|
||||
check_interval_secs: u64,
|
||||
) -> HealthChecker {
|
||||
let shutdown = Arc::new(AtomicBool::new(false));
|
||||
@@ -1037,6 +872,7 @@ pub fn start_health_checker(
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::core::CircuitBreakerConfig;
|
||||
use std::thread;
|
||||
use std::time::Duration;
|
||||
|
||||
@@ -1369,7 +1205,11 @@ mod tests {
|
||||
// Test WorkerFactory
|
||||
#[test]
fn test_create_regular_worker() {
    // A builder-produced Regular worker must report the URL and type it was configured
    // with.
    let worker: Box<dyn Worker> = Box::new(
        BasicWorkerBuilder::new("http://regular:8080")
            .worker_type(WorkerType::Regular)
            .build(),
    );
    assert_eq!(worker.url(), "http://regular:8080");
    assert_eq!(worker.worker_type(), WorkerType::Regular);
}
|
||||
@@ -1377,7 +1217,13 @@ mod tests {
|
||||
#[test]
|
||||
fn test_create_prefill_worker() {
|
||||
// With bootstrap port
|
||||
let worker1 = WorkerFactory::create_prefill("http://prefill:8080".to_string(), Some(9090));
|
||||
let worker1: Box<dyn Worker> = Box::new(
|
||||
BasicWorkerBuilder::new("http://prefill:8080")
|
||||
.worker_type(WorkerType::Prefill {
|
||||
bootstrap_port: Some(9090),
|
||||
})
|
||||
.build(),
|
||||
);
|
||||
assert_eq!(worker1.url(), "http://prefill:8080");
|
||||
assert_eq!(
|
||||
worker1.worker_type(),
|
||||
@@ -1387,7 +1233,13 @@ mod tests {
|
||||
);
|
||||
|
||||
// Without bootstrap port
|
||||
let worker2 = WorkerFactory::create_prefill("http://prefill:8080".to_string(), None);
|
||||
let worker2: Box<dyn Worker> = Box::new(
|
||||
BasicWorkerBuilder::new("http://prefill:8080")
|
||||
.worker_type(WorkerType::Prefill {
|
||||
bootstrap_port: None,
|
||||
})
|
||||
.build(),
|
||||
);
|
||||
assert_eq!(
|
||||
worker2.worker_type(),
|
||||
WorkerType::Prefill {
|
||||
@@ -1398,7 +1250,11 @@ mod tests {
|
||||
|
||||
#[test]
fn test_create_decode_worker() {
    // A builder-produced Decode worker must report the URL and type it was configured
    // with.
    let worker: Box<dyn Worker> = Box::new(
        BasicWorkerBuilder::new("http://decode:8080")
            .worker_type(WorkerType::Decode)
            .build(),
    );
    assert_eq!(worker.url(), "http://decode:8080");
    assert_eq!(worker.worker_type(), WorkerType::Decode);
}
|
||||
@@ -1424,9 +1280,21 @@ mod tests {
|
||||
#[test]
|
||||
fn test_load_guard_multiple_workers() {
|
||||
let workers: Vec<Box<dyn Worker>> = vec![
|
||||
WorkerFactory::create_regular("http://w1:8080".to_string()),
|
||||
WorkerFactory::create_regular("http://w2:8080".to_string()),
|
||||
WorkerFactory::create_regular("http://w3:8080".to_string()),
|
||||
Box::new(
|
||||
BasicWorkerBuilder::new("http://w1:8080")
|
||||
.worker_type(WorkerType::Regular)
|
||||
.build(),
|
||||
),
|
||||
Box::new(
|
||||
BasicWorkerBuilder::new("http://w2:8080")
|
||||
.worker_type(WorkerType::Regular)
|
||||
.build(),
|
||||
),
|
||||
Box::new(
|
||||
BasicWorkerBuilder::new("http://w3:8080")
|
||||
.worker_type(WorkerType::Regular)
|
||||
.build(),
|
||||
),
|
||||
];
|
||||
|
||||
let worker_refs: Vec<&dyn Worker> = workers.iter().map(|w| w.as_ref()).collect();
|
||||
@@ -1492,8 +1360,16 @@ mod tests {
|
||||
#[test]
|
||||
fn test_workers_to_urls() {
|
||||
let workers: Vec<Box<dyn Worker>> = vec![
|
||||
WorkerFactory::create_regular("http://w1:8080".to_string()),
|
||||
WorkerFactory::create_regular("http://w2:8080".to_string()),
|
||||
Box::new(
|
||||
BasicWorkerBuilder::new("http://w1:8080")
|
||||
.worker_type(WorkerType::Regular)
|
||||
.build(),
|
||||
),
|
||||
Box::new(
|
||||
BasicWorkerBuilder::new("http://w2:8080")
|
||||
.worker_type(WorkerType::Regular)
|
||||
.build(),
|
||||
),
|
||||
];
|
||||
|
||||
let urls = workers_to_urls(&workers);
|
||||
@@ -1544,8 +1420,9 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_dp_aware_worker_creation() {
|
||||
let dp_worker =
|
||||
DPAwareWorker::new("http://worker1:8080".to_string(), 2, 4, WorkerType::Regular);
|
||||
let dp_worker = DPAwareWorkerBuilder::new("http://worker1:8080", 2, 4)
|
||||
.worker_type(WorkerType::Regular)
|
||||
.build();
|
||||
|
||||
assert_eq!(dp_worker.url(), "http://worker1:8080@2");
|
||||
assert_eq!(dp_worker.base_url(), "http://worker1:8080");
|
||||
@@ -1557,14 +1434,11 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_dp_aware_worker_creation_prefill() {
|
||||
let dp_worker = DPAwareWorker::new(
|
||||
"http://worker1:8080".to_string(),
|
||||
1,
|
||||
2,
|
||||
WorkerType::Prefill {
|
||||
let dp_worker = DPAwareWorkerBuilder::new("http://worker1:8080", 1, 2)
|
||||
.worker_type(WorkerType::Prefill {
|
||||
bootstrap_port: Some(9090),
|
||||
},
|
||||
);
|
||||
})
|
||||
.build();
|
||||
|
||||
assert_eq!(dp_worker.url(), "http://worker1:8080@1");
|
||||
assert!(dp_worker.is_dp_aware());
|
||||
@@ -1578,8 +1452,9 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_dp_aware_worker_creation_decode() {
|
||||
let dp_worker =
|
||||
DPAwareWorker::new("http://worker1:8080".to_string(), 0, 4, WorkerType::Decode);
|
||||
let dp_worker = DPAwareWorkerBuilder::new("http://worker1:8080", 0, 4)
|
||||
.worker_type(WorkerType::Decode)
|
||||
.build();
|
||||
|
||||
assert_eq!(dp_worker.url(), "http://worker1:8080@0");
|
||||
assert!(dp_worker.is_dp_aware());
|
||||
@@ -1588,8 +1463,9 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_dp_aware_prepare_request() {
|
||||
let dp_worker =
|
||||
DPAwareWorker::new("http://worker1:8080".to_string(), 3, 8, WorkerType::Regular);
|
||||
let dp_worker = DPAwareWorkerBuilder::new("http://worker1:8080", 3, 8)
|
||||
.worker_type(WorkerType::Regular)
|
||||
.build();
|
||||
|
||||
let original_req = serde_json::json!({
|
||||
"prompt": "Hello",
|
||||
@@ -1605,8 +1481,9 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_dp_aware_prepare_request_invalid() {
|
||||
let dp_worker =
|
||||
DPAwareWorker::new("http://worker1:8080".to_string(), 0, 4, WorkerType::Regular);
|
||||
let dp_worker = DPAwareWorkerBuilder::new("http://worker1:8080", 0, 4)
|
||||
.worker_type(WorkerType::Regular)
|
||||
.build();
|
||||
|
||||
// Non-object JSON should fail
|
||||
let invalid_req = serde_json::json!("not an object");
|
||||
@@ -1623,8 +1500,9 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_dp_aware_endpoint_url() {
|
||||
let dp_worker =
|
||||
DPAwareWorker::new("http://worker1:8080".to_string(), 1, 4, WorkerType::Regular);
|
||||
let dp_worker = DPAwareWorkerBuilder::new("http://worker1:8080", 1, 4)
|
||||
.worker_type(WorkerType::Regular)
|
||||
.build();
|
||||
|
||||
assert_eq!(
|
||||
dp_worker.endpoint_url("/generate"),
|
||||
@@ -1638,8 +1516,9 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_dp_aware_worker_delegated_methods() {
|
||||
let dp_worker =
|
||||
DPAwareWorker::new("http://worker1:8080".to_string(), 0, 2, WorkerType::Regular);
|
||||
let dp_worker = DPAwareWorkerBuilder::new("http://worker1:8080", 0, 2)
|
||||
.worker_type(WorkerType::Regular)
|
||||
.build();
|
||||
|
||||
// Test health status
|
||||
assert!(dp_worker.is_healthy());
|
||||
@@ -1698,23 +1577,6 @@ mod tests {
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_factory_create_workers_regular() {
|
||||
let urls = vec!["http://w1:8080".to_string(), "http://w2:8080".to_string()];
|
||||
|
||||
let workers = WorkerFactory::create_workers(urls, false, &None)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(workers.len(), 2);
|
||||
assert!(!workers[0].is_dp_aware());
|
||||
assert!(!workers[1].is_dp_aware());
|
||||
assert_eq!(workers[0].url(), "http://w1:8080");
|
||||
assert_eq!(workers[1].url(), "http://w2:8080");
|
||||
}
|
||||
|
||||
// ===== Circuit Breaker Integration Tests =====
|
||||
|
||||
#[test]
|
||||
fn test_worker_circuit_breaker() {
|
||||
use crate::core::BasicWorkerBuilder;
|
||||
@@ -1779,8 +1641,9 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_dp_aware_worker_circuit_breaker() {
|
||||
let dp_worker =
|
||||
DPAwareWorker::new("http://worker:8080".to_string(), 0, 2, WorkerType::Regular);
|
||||
let dp_worker = DPAwareWorkerBuilder::new("http://worker:8080", 0, 2)
|
||||
.worker_type(WorkerType::Regular)
|
||||
.build();
|
||||
|
||||
// Should have circuit breaker
|
||||
assert!(dp_worker.is_available());
|
||||
@@ -1800,9 +1663,23 @@ mod tests {
|
||||
#[tokio::test]
|
||||
async fn test_mixed_worker_types() {
|
||||
// Create a mix of worker types
|
||||
let regular = WorkerFactory::create_regular("http://regular:8080".to_string());
|
||||
let prefill = WorkerFactory::create_prefill("http://prefill:8080".to_string(), Some(9090));
|
||||
let decode = WorkerFactory::create_decode("http://decode:8080".to_string());
|
||||
let regular: Box<dyn Worker> = Box::new(
|
||||
BasicWorkerBuilder::new("http://regular:8080")
|
||||
.worker_type(WorkerType::Regular)
|
||||
.build(),
|
||||
);
|
||||
let prefill: Box<dyn Worker> = Box::new(
|
||||
BasicWorkerBuilder::new("http://prefill:8080")
|
||||
.worker_type(WorkerType::Prefill {
|
||||
bootstrap_port: Some(9090),
|
||||
})
|
||||
.build(),
|
||||
);
|
||||
let decode: Box<dyn Worker> = Box::new(
|
||||
BasicWorkerBuilder::new("http://decode:8080")
|
||||
.worker_type(WorkerType::Decode)
|
||||
.build(),
|
||||
);
|
||||
let dp_aware_regular =
|
||||
WorkerFactory::create_dp_aware("http://dp:8080".to_string(), 0, 2, WorkerType::Regular);
|
||||
let dp_aware_prefill = WorkerFactory::create_dp_aware(
|
||||
|
||||
Reference in New Issue
Block a user