[router] refactor worker to builder pattern 5/n (#10653)
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
use super::{CircuitBreaker, CircuitBreakerConfig, WorkerError, WorkerResult};
|
||||
use super::{CircuitBreaker, WorkerError, WorkerResult};
|
||||
use crate::core::CircuitState;
|
||||
use crate::core::{BasicWorkerBuilder, DPAwareWorkerBuilder};
|
||||
use crate::grpc::SglangSchedulerClient;
|
||||
@@ -525,23 +525,6 @@ pub struct DPAwareWorker {
|
||||
}
|
||||
|
||||
impl DPAwareWorker {
|
||||
/// Create a new DP-aware worker of any type
|
||||
pub fn new(base_url: String, dp_rank: usize, dp_size: usize, worker_type: WorkerType) -> Self {
|
||||
use crate::core::BasicWorkerBuilder;
|
||||
// Create URL with DP rank suffix for identification
|
||||
let worker_url = format!("{}@{}", base_url, dp_rank);
|
||||
let base_worker = BasicWorkerBuilder::new(worker_url)
|
||||
.worker_type(worker_type)
|
||||
.build();
|
||||
|
||||
Self {
|
||||
base_worker,
|
||||
dp_rank,
|
||||
dp_size,
|
||||
base_url,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new DP-aware worker with a pre-configured base worker
|
||||
/// This is primarily used by the builder pattern
|
||||
pub fn with_base_worker(
|
||||
@@ -661,95 +644,6 @@ impl Worker for DPAwareWorker {
|
||||
pub struct WorkerFactory;
|
||||
|
||||
impl WorkerFactory {
|
||||
/// Create a BasicWorkerBuilder for customizable worker creation
|
||||
pub fn builder(url: impl Into<String>) -> BasicWorkerBuilder {
|
||||
BasicWorkerBuilder::new(url)
|
||||
}
|
||||
|
||||
/// Create a DPAwareWorkerBuilder for customizable DP-aware worker creation
|
||||
pub fn dp_builder(
|
||||
base_url: impl Into<String>,
|
||||
dp_rank: usize,
|
||||
dp_size: usize,
|
||||
) -> DPAwareWorkerBuilder {
|
||||
DPAwareWorkerBuilder::new(base_url, dp_rank, dp_size)
|
||||
}
|
||||
|
||||
/// Create a regular worker
|
||||
pub fn create_regular(url: String) -> Box<dyn Worker> {
|
||||
use crate::core::BasicWorkerBuilder;
|
||||
Box::new(
|
||||
BasicWorkerBuilder::new(url)
|
||||
.worker_type(WorkerType::Regular)
|
||||
.build(),
|
||||
)
|
||||
}
|
||||
|
||||
/// Create a prefill worker with optional bootstrap port
|
||||
pub fn create_prefill(url: String, bootstrap_port: Option<u16>) -> Box<dyn Worker> {
|
||||
use crate::core::BasicWorkerBuilder;
|
||||
Box::new(
|
||||
BasicWorkerBuilder::new(url)
|
||||
.worker_type(WorkerType::Prefill { bootstrap_port })
|
||||
.build(),
|
||||
)
|
||||
}
|
||||
|
||||
/// Create a decode worker
|
||||
pub fn create_decode(url: String) -> Box<dyn Worker> {
|
||||
use crate::core::BasicWorkerBuilder;
|
||||
Box::new(
|
||||
BasicWorkerBuilder::new(url)
|
||||
.worker_type(WorkerType::Decode)
|
||||
.build(),
|
||||
)
|
||||
}
|
||||
|
||||
/// Create a regular worker with custom labels (for multi-router support)
|
||||
pub fn create_regular_with_labels(
|
||||
url: String,
|
||||
labels: std::collections::HashMap<String, String>,
|
||||
circuit_breaker_config: CircuitBreakerConfig,
|
||||
) -> Box<dyn Worker> {
|
||||
Box::new(
|
||||
BasicWorkerBuilder::new(url)
|
||||
.labels(labels)
|
||||
.circuit_breaker_config(circuit_breaker_config)
|
||||
.build(),
|
||||
)
|
||||
}
|
||||
|
||||
/// Create a prefill worker with labels
|
||||
pub fn create_prefill_with_labels(
|
||||
url: String,
|
||||
bootstrap_port: Option<u16>,
|
||||
labels: std::collections::HashMap<String, String>,
|
||||
circuit_breaker_config: CircuitBreakerConfig,
|
||||
) -> Box<dyn Worker> {
|
||||
Box::new(
|
||||
BasicWorkerBuilder::new(url)
|
||||
.worker_type(WorkerType::Prefill { bootstrap_port })
|
||||
.labels(labels)
|
||||
.circuit_breaker_config(circuit_breaker_config)
|
||||
.build(),
|
||||
)
|
||||
}
|
||||
|
||||
/// Create a decode worker with labels
|
||||
pub fn create_decode_with_labels(
|
||||
url: String,
|
||||
labels: std::collections::HashMap<String, String>,
|
||||
circuit_breaker_config: CircuitBreakerConfig,
|
||||
) -> Box<dyn Worker> {
|
||||
Box::new(
|
||||
BasicWorkerBuilder::new(url)
|
||||
.worker_type(WorkerType::Decode)
|
||||
.labels(labels)
|
||||
.circuit_breaker_config(circuit_breaker_config)
|
||||
.build(),
|
||||
)
|
||||
}
|
||||
|
||||
/// Create a DP-aware worker of specified type
|
||||
pub fn create_dp_aware(
|
||||
base_url: String,
|
||||
@@ -757,9 +651,13 @@ impl WorkerFactory {
|
||||
dp_size: usize,
|
||||
worker_type: WorkerType,
|
||||
) -> Box<dyn Worker> {
|
||||
Box::new(DPAwareWorker::new(base_url, dp_rank, dp_size, worker_type))
|
||||
Box::new(
|
||||
DPAwareWorkerBuilder::new(base_url, dp_rank, dp_size)
|
||||
.worker_type(worker_type)
|
||||
.build(),
|
||||
)
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
/// Get DP size from a worker
|
||||
async fn get_worker_dp_size(url: &str, api_key: &Option<String>) -> WorkerResult<usize> {
|
||||
let mut req_builder = WORKER_CLIENT.get(format!("{}/get_server_info", url));
|
||||
@@ -807,81 +705,18 @@ impl WorkerFactory {
|
||||
|
||||
Ok(dp_size as usize)
|
||||
}
|
||||
|
||||
/// Private helper to create DP-aware workers of any type
|
||||
async fn create_dp_aware_workers_of_type(
|
||||
url: &str,
|
||||
api_key: &Option<String>,
|
||||
worker_type: WorkerType,
|
||||
) -> WorkerResult<Vec<Box<dyn Worker>>> {
|
||||
let dp_size = Self::get_worker_dp_size(url, api_key).await?;
|
||||
|
||||
let workers = (0..dp_size)
|
||||
.map(|rank| Self::create_dp_aware(url.to_string(), rank, dp_size, worker_type.clone()))
|
||||
.collect();
|
||||
|
||||
Ok(workers)
|
||||
}
|
||||
|
||||
/// Create DP-aware regular workers from a single URL
|
||||
pub async fn create_dp_aware_regular_workers(
|
||||
url: &str,
|
||||
api_key: &Option<String>,
|
||||
) -> WorkerResult<Vec<Box<dyn Worker>>> {
|
||||
Self::create_dp_aware_workers_of_type(url, api_key, WorkerType::Regular).await
|
||||
}
|
||||
|
||||
/// Create DP-aware prefill workers from a single URL
|
||||
pub async fn create_dp_aware_prefill_workers(
|
||||
url: &str,
|
||||
bootstrap_port: Option<u16>,
|
||||
api_key: &Option<String>,
|
||||
) -> WorkerResult<Vec<Box<dyn Worker>>> {
|
||||
Self::create_dp_aware_workers_of_type(url, api_key, WorkerType::Prefill { bootstrap_port })
|
||||
.await
|
||||
}
|
||||
|
||||
/// Create DP-aware decode workers from a single URL
|
||||
pub async fn create_dp_aware_decode_workers(
|
||||
url: &str,
|
||||
api_key: &Option<String>,
|
||||
) -> WorkerResult<Vec<Box<dyn Worker>>> {
|
||||
Self::create_dp_aware_workers_of_type(url, api_key, WorkerType::Decode).await
|
||||
}
|
||||
|
||||
/// Create workers based on configuration (for regular router)
|
||||
pub async fn create_workers(
|
||||
urls: Vec<String>,
|
||||
dp_aware: bool,
|
||||
api_key: &Option<String>,
|
||||
) -> WorkerResult<Vec<Box<dyn Worker>>> {
|
||||
if dp_aware {
|
||||
// Create futures for all worker creations
|
||||
let worker_futs = urls
|
||||
.iter()
|
||||
.map(|url| Self::create_dp_aware_regular_workers(url, api_key));
|
||||
|
||||
// Execute all futures concurrently and flatten results
|
||||
let all_workers = futures::future::try_join_all(worker_futs)
|
||||
.await?
|
||||
.into_iter()
|
||||
.flatten()
|
||||
.collect();
|
||||
|
||||
Ok(all_workers)
|
||||
} else {
|
||||
Ok(urls
|
||||
.into_iter()
|
||||
.map(|url| Self::create_regular(url))
|
||||
.collect())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert a list of worker URLs to worker trait objects
|
||||
pub fn urls_to_workers(urls: Vec<String>) -> Vec<Box<dyn Worker>> {
|
||||
urls.into_iter()
|
||||
.map(WorkerFactory::create_regular)
|
||||
.map(|url| {
|
||||
Box::new(
|
||||
BasicWorkerBuilder::new(url)
|
||||
.worker_type(WorkerType::Regular)
|
||||
.build(),
|
||||
) as Box<dyn Worker>
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
@@ -952,7 +787,7 @@ impl HealthChecker {
|
||||
|
||||
/// Start an async background health checker for a collection of workers
|
||||
pub fn start_health_checker(
|
||||
workers: std::sync::Arc<std::sync::RwLock<Vec<std::sync::Arc<dyn Worker>>>>,
|
||||
workers: Arc<std::sync::RwLock<Vec<Arc<dyn Worker>>>>,
|
||||
check_interval_secs: u64,
|
||||
) -> HealthChecker {
|
||||
let shutdown = Arc::new(AtomicBool::new(false));
|
||||
@@ -1037,6 +872,7 @@ pub fn start_health_checker(
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::core::CircuitBreakerConfig;
|
||||
use std::thread;
|
||||
use std::time::Duration;
|
||||
|
||||
@@ -1369,7 +1205,11 @@ mod tests {
|
||||
// Test WorkerFactory
|
||||
#[test]
|
||||
fn test_create_regular_worker() {
|
||||
let worker = WorkerFactory::create_regular("http://regular:8080".to_string());
|
||||
let worker: Box<dyn Worker> = Box::new(
|
||||
BasicWorkerBuilder::new("http://regular:8080")
|
||||
.worker_type(WorkerType::Regular)
|
||||
.build(),
|
||||
);
|
||||
assert_eq!(worker.url(), "http://regular:8080");
|
||||
assert_eq!(worker.worker_type(), WorkerType::Regular);
|
||||
}
|
||||
@@ -1377,7 +1217,13 @@ mod tests {
|
||||
#[test]
|
||||
fn test_create_prefill_worker() {
|
||||
// With bootstrap port
|
||||
let worker1 = WorkerFactory::create_prefill("http://prefill:8080".to_string(), Some(9090));
|
||||
let worker1: Box<dyn Worker> = Box::new(
|
||||
BasicWorkerBuilder::new("http://prefill:8080")
|
||||
.worker_type(WorkerType::Prefill {
|
||||
bootstrap_port: Some(9090),
|
||||
})
|
||||
.build(),
|
||||
);
|
||||
assert_eq!(worker1.url(), "http://prefill:8080");
|
||||
assert_eq!(
|
||||
worker1.worker_type(),
|
||||
@@ -1387,7 +1233,13 @@ mod tests {
|
||||
);
|
||||
|
||||
// Without bootstrap port
|
||||
let worker2 = WorkerFactory::create_prefill("http://prefill:8080".to_string(), None);
|
||||
let worker2: Box<dyn Worker> = Box::new(
|
||||
BasicWorkerBuilder::new("http://prefill:8080")
|
||||
.worker_type(WorkerType::Prefill {
|
||||
bootstrap_port: None,
|
||||
})
|
||||
.build(),
|
||||
);
|
||||
assert_eq!(
|
||||
worker2.worker_type(),
|
||||
WorkerType::Prefill {
|
||||
@@ -1398,7 +1250,11 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_create_decode_worker() {
|
||||
let worker = WorkerFactory::create_decode("http://decode:8080".to_string());
|
||||
let worker: Box<dyn Worker> = Box::new(
|
||||
BasicWorkerBuilder::new("http://decode:8080")
|
||||
.worker_type(WorkerType::Decode)
|
||||
.build(),
|
||||
);
|
||||
assert_eq!(worker.url(), "http://decode:8080");
|
||||
assert_eq!(worker.worker_type(), WorkerType::Decode);
|
||||
}
|
||||
@@ -1424,9 +1280,21 @@ mod tests {
|
||||
#[test]
|
||||
fn test_load_guard_multiple_workers() {
|
||||
let workers: Vec<Box<dyn Worker>> = vec![
|
||||
WorkerFactory::create_regular("http://w1:8080".to_string()),
|
||||
WorkerFactory::create_regular("http://w2:8080".to_string()),
|
||||
WorkerFactory::create_regular("http://w3:8080".to_string()),
|
||||
Box::new(
|
||||
BasicWorkerBuilder::new("http://w1:8080")
|
||||
.worker_type(WorkerType::Regular)
|
||||
.build(),
|
||||
),
|
||||
Box::new(
|
||||
BasicWorkerBuilder::new("http://w2:8080")
|
||||
.worker_type(WorkerType::Regular)
|
||||
.build(),
|
||||
),
|
||||
Box::new(
|
||||
BasicWorkerBuilder::new("http://w3:8080")
|
||||
.worker_type(WorkerType::Regular)
|
||||
.build(),
|
||||
),
|
||||
];
|
||||
|
||||
let worker_refs: Vec<&dyn Worker> = workers.iter().map(|w| w.as_ref()).collect();
|
||||
@@ -1492,8 +1360,16 @@ mod tests {
|
||||
#[test]
|
||||
fn test_workers_to_urls() {
|
||||
let workers: Vec<Box<dyn Worker>> = vec![
|
||||
WorkerFactory::create_regular("http://w1:8080".to_string()),
|
||||
WorkerFactory::create_regular("http://w2:8080".to_string()),
|
||||
Box::new(
|
||||
BasicWorkerBuilder::new("http://w1:8080")
|
||||
.worker_type(WorkerType::Regular)
|
||||
.build(),
|
||||
),
|
||||
Box::new(
|
||||
BasicWorkerBuilder::new("http://w2:8080")
|
||||
.worker_type(WorkerType::Regular)
|
||||
.build(),
|
||||
),
|
||||
];
|
||||
|
||||
let urls = workers_to_urls(&workers);
|
||||
@@ -1544,8 +1420,9 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_dp_aware_worker_creation() {
|
||||
let dp_worker =
|
||||
DPAwareWorker::new("http://worker1:8080".to_string(), 2, 4, WorkerType::Regular);
|
||||
let dp_worker = DPAwareWorkerBuilder::new("http://worker1:8080", 2, 4)
|
||||
.worker_type(WorkerType::Regular)
|
||||
.build();
|
||||
|
||||
assert_eq!(dp_worker.url(), "http://worker1:8080@2");
|
||||
assert_eq!(dp_worker.base_url(), "http://worker1:8080");
|
||||
@@ -1557,14 +1434,11 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_dp_aware_worker_creation_prefill() {
|
||||
let dp_worker = DPAwareWorker::new(
|
||||
"http://worker1:8080".to_string(),
|
||||
1,
|
||||
2,
|
||||
WorkerType::Prefill {
|
||||
let dp_worker = DPAwareWorkerBuilder::new("http://worker1:8080", 1, 2)
|
||||
.worker_type(WorkerType::Prefill {
|
||||
bootstrap_port: Some(9090),
|
||||
},
|
||||
);
|
||||
})
|
||||
.build();
|
||||
|
||||
assert_eq!(dp_worker.url(), "http://worker1:8080@1");
|
||||
assert!(dp_worker.is_dp_aware());
|
||||
@@ -1578,8 +1452,9 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_dp_aware_worker_creation_decode() {
|
||||
let dp_worker =
|
||||
DPAwareWorker::new("http://worker1:8080".to_string(), 0, 4, WorkerType::Decode);
|
||||
let dp_worker = DPAwareWorkerBuilder::new("http://worker1:8080", 0, 4)
|
||||
.worker_type(WorkerType::Decode)
|
||||
.build();
|
||||
|
||||
assert_eq!(dp_worker.url(), "http://worker1:8080@0");
|
||||
assert!(dp_worker.is_dp_aware());
|
||||
@@ -1588,8 +1463,9 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_dp_aware_prepare_request() {
|
||||
let dp_worker =
|
||||
DPAwareWorker::new("http://worker1:8080".to_string(), 3, 8, WorkerType::Regular);
|
||||
let dp_worker = DPAwareWorkerBuilder::new("http://worker1:8080", 3, 8)
|
||||
.worker_type(WorkerType::Regular)
|
||||
.build();
|
||||
|
||||
let original_req = serde_json::json!({
|
||||
"prompt": "Hello",
|
||||
@@ -1605,8 +1481,9 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_dp_aware_prepare_request_invalid() {
|
||||
let dp_worker =
|
||||
DPAwareWorker::new("http://worker1:8080".to_string(), 0, 4, WorkerType::Regular);
|
||||
let dp_worker = DPAwareWorkerBuilder::new("http://worker1:8080", 0, 4)
|
||||
.worker_type(WorkerType::Regular)
|
||||
.build();
|
||||
|
||||
// Non-object JSON should fail
|
||||
let invalid_req = serde_json::json!("not an object");
|
||||
@@ -1623,8 +1500,9 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_dp_aware_endpoint_url() {
|
||||
let dp_worker =
|
||||
DPAwareWorker::new("http://worker1:8080".to_string(), 1, 4, WorkerType::Regular);
|
||||
let dp_worker = DPAwareWorkerBuilder::new("http://worker1:8080", 1, 4)
|
||||
.worker_type(WorkerType::Regular)
|
||||
.build();
|
||||
|
||||
assert_eq!(
|
||||
dp_worker.endpoint_url("/generate"),
|
||||
@@ -1638,8 +1516,9 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_dp_aware_worker_delegated_methods() {
|
||||
let dp_worker =
|
||||
DPAwareWorker::new("http://worker1:8080".to_string(), 0, 2, WorkerType::Regular);
|
||||
let dp_worker = DPAwareWorkerBuilder::new("http://worker1:8080", 0, 2)
|
||||
.worker_type(WorkerType::Regular)
|
||||
.build();
|
||||
|
||||
// Test health status
|
||||
assert!(dp_worker.is_healthy());
|
||||
@@ -1698,23 +1577,6 @@ mod tests {
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_factory_create_workers_regular() {
|
||||
let urls = vec!["http://w1:8080".to_string(), "http://w2:8080".to_string()];
|
||||
|
||||
let workers = WorkerFactory::create_workers(urls, false, &None)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(workers.len(), 2);
|
||||
assert!(!workers[0].is_dp_aware());
|
||||
assert!(!workers[1].is_dp_aware());
|
||||
assert_eq!(workers[0].url(), "http://w1:8080");
|
||||
assert_eq!(workers[1].url(), "http://w2:8080");
|
||||
}
|
||||
|
||||
// ===== Circuit Breaker Integration Tests =====
|
||||
|
||||
#[test]
|
||||
fn test_worker_circuit_breaker() {
|
||||
use crate::core::BasicWorkerBuilder;
|
||||
@@ -1779,8 +1641,9 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_dp_aware_worker_circuit_breaker() {
|
||||
let dp_worker =
|
||||
DPAwareWorker::new("http://worker:8080".to_string(), 0, 2, WorkerType::Regular);
|
||||
let dp_worker = DPAwareWorkerBuilder::new("http://worker:8080", 0, 2)
|
||||
.worker_type(WorkerType::Regular)
|
||||
.build();
|
||||
|
||||
// Should have circuit breaker
|
||||
assert!(dp_worker.is_available());
|
||||
@@ -1800,9 +1663,23 @@ mod tests {
|
||||
#[tokio::test]
|
||||
async fn test_mixed_worker_types() {
|
||||
// Create a mix of worker types
|
||||
let regular = WorkerFactory::create_regular("http://regular:8080".to_string());
|
||||
let prefill = WorkerFactory::create_prefill("http://prefill:8080".to_string(), Some(9090));
|
||||
let decode = WorkerFactory::create_decode("http://decode:8080".to_string());
|
||||
let regular: Box<dyn Worker> = Box::new(
|
||||
BasicWorkerBuilder::new("http://regular:8080")
|
||||
.worker_type(WorkerType::Regular)
|
||||
.build(),
|
||||
);
|
||||
let prefill: Box<dyn Worker> = Box::new(
|
||||
BasicWorkerBuilder::new("http://prefill:8080")
|
||||
.worker_type(WorkerType::Prefill {
|
||||
bootstrap_port: Some(9090),
|
||||
})
|
||||
.build(),
|
||||
);
|
||||
let decode: Box<dyn Worker> = Box::new(
|
||||
BasicWorkerBuilder::new("http://decode:8080")
|
||||
.worker_type(WorkerType::Decode)
|
||||
.build(),
|
||||
);
|
||||
let dp_aware_regular =
|
||||
WorkerFactory::create_dp_aware("http://dp:8080".to_string(), 0, 2, WorkerType::Regular);
|
||||
let dp_aware_prefill = WorkerFactory::create_dp_aware(
|
||||
|
||||
@@ -424,7 +424,7 @@ pub struct WorkerRegistryStats {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::core::{CircuitBreakerConfig, WorkerFactory};
|
||||
use crate::core::{BasicWorkerBuilder, CircuitBreakerConfig};
|
||||
use std::collections::HashMap;
|
||||
|
||||
#[test]
|
||||
@@ -437,10 +437,12 @@ mod tests {
|
||||
labels.insert("priority".to_string(), "50".to_string());
|
||||
labels.insert("cost".to_string(), "0.8".to_string());
|
||||
|
||||
let worker = WorkerFactory::create_regular_with_labels(
|
||||
"http://worker1:8080".to_string(),
|
||||
labels,
|
||||
CircuitBreakerConfig::default(),
|
||||
let worker: Box<dyn Worker> = Box::new(
|
||||
BasicWorkerBuilder::new("http://worker1:8080")
|
||||
.worker_type(WorkerType::Regular)
|
||||
.labels(labels)
|
||||
.circuit_breaker_config(CircuitBreakerConfig::default())
|
||||
.build(),
|
||||
);
|
||||
|
||||
// Register worker (built via BasicWorkerBuilder and boxed as Box<dyn Worker> above; convert to Arc)
|
||||
@@ -470,26 +472,32 @@ mod tests {
|
||||
// Create workers for different models
|
||||
let mut labels1 = HashMap::new();
|
||||
labels1.insert("model_id".to_string(), "llama-3".to_string());
|
||||
let worker1 = WorkerFactory::create_regular_with_labels(
|
||||
"http://worker1:8080".to_string(),
|
||||
labels1,
|
||||
CircuitBreakerConfig::default(),
|
||||
let worker1: Box<dyn Worker> = Box::new(
|
||||
BasicWorkerBuilder::new("http://worker1:8080")
|
||||
.worker_type(WorkerType::Regular)
|
||||
.labels(labels1)
|
||||
.circuit_breaker_config(CircuitBreakerConfig::default())
|
||||
.build(),
|
||||
);
|
||||
|
||||
let mut labels2 = HashMap::new();
|
||||
labels2.insert("model_id".to_string(), "llama-3".to_string());
|
||||
let worker2 = WorkerFactory::create_regular_with_labels(
|
||||
"http://worker2:8080".to_string(),
|
||||
labels2,
|
||||
CircuitBreakerConfig::default(),
|
||||
let worker2: Box<dyn Worker> = Box::new(
|
||||
BasicWorkerBuilder::new("http://worker2:8080")
|
||||
.worker_type(WorkerType::Regular)
|
||||
.labels(labels2)
|
||||
.circuit_breaker_config(CircuitBreakerConfig::default())
|
||||
.build(),
|
||||
);
|
||||
|
||||
let mut labels3 = HashMap::new();
|
||||
labels3.insert("model_id".to_string(), "gpt-4".to_string());
|
||||
let worker3 = WorkerFactory::create_regular_with_labels(
|
||||
"http://worker3:8080".to_string(),
|
||||
labels3,
|
||||
CircuitBreakerConfig::default(),
|
||||
let worker3: Box<dyn Worker> = Box::new(
|
||||
BasicWorkerBuilder::new("http://worker3:8080")
|
||||
.worker_type(WorkerType::Regular)
|
||||
.labels(labels3)
|
||||
.circuit_breaker_config(CircuitBreakerConfig::default())
|
||||
.build(),
|
||||
);
|
||||
|
||||
// Register workers
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
//! - Multi-Router Mode (enable_igw=true): RouterManager coordinates everything
|
||||
|
||||
use crate::config::RouterConfig;
|
||||
use crate::core::{CircuitBreakerConfig, Worker, WorkerFactory, WorkerRegistry, WorkerType};
|
||||
use crate::core::{BasicWorkerBuilder, CircuitBreakerConfig, Worker, WorkerRegistry, WorkerType};
|
||||
use crate::protocols::spec::{
|
||||
ChatCompletionRequest, CompletionRequest, EmbeddingRequest, GenerateRequest, RerankRequest,
|
||||
ResponsesRequest,
|
||||
@@ -208,22 +208,29 @@ impl RouterManager {
|
||||
}
|
||||
|
||||
let worker = match config.worker_type.as_deref() {
|
||||
Some("prefill") => WorkerFactory::create_prefill_with_labels(
|
||||
config.url.clone(),
|
||||
config.bootstrap_port,
|
||||
labels.clone(),
|
||||
CircuitBreakerConfig::default(),
|
||||
),
|
||||
Some("decode") => WorkerFactory::create_decode_with_labels(
|
||||
config.url.clone(),
|
||||
labels.clone(),
|
||||
CircuitBreakerConfig::default(),
|
||||
),
|
||||
_ => WorkerFactory::create_regular_with_labels(
|
||||
config.url.clone(),
|
||||
labels.clone(),
|
||||
CircuitBreakerConfig::default(),
|
||||
),
|
||||
Some("prefill") => Box::new(
|
||||
BasicWorkerBuilder::new(config.url.clone())
|
||||
.worker_type(WorkerType::Prefill {
|
||||
bootstrap_port: config.bootstrap_port,
|
||||
})
|
||||
.labels(labels.clone())
|
||||
.circuit_breaker_config(CircuitBreakerConfig::default())
|
||||
.build(),
|
||||
) as Box<dyn Worker>,
|
||||
Some("decode") => Box::new(
|
||||
BasicWorkerBuilder::new(config.url.clone())
|
||||
.worker_type(WorkerType::Decode)
|
||||
.labels(labels.clone())
|
||||
.circuit_breaker_config(CircuitBreakerConfig::default())
|
||||
.build(),
|
||||
) as Box<dyn Worker>,
|
||||
_ => Box::new(
|
||||
BasicWorkerBuilder::new(config.url.clone())
|
||||
.worker_type(WorkerType::Regular)
|
||||
.labels(labels.clone())
|
||||
.circuit_breaker_config(CircuitBreakerConfig::default())
|
||||
.build(),
|
||||
) as Box<dyn Worker>,
|
||||
};
|
||||
|
||||
// Register worker
|
||||
|
||||
@@ -4,7 +4,7 @@ mod test_pd_routing {
|
||||
use sglang_router_rs::config::{
|
||||
CircuitBreakerConfig, ConnectionMode, PolicyConfig, RetryConfig, RouterConfig, RoutingMode,
|
||||
};
|
||||
use sglang_router_rs::core::{WorkerFactory, WorkerType};
|
||||
use sglang_router_rs::core::{BasicWorkerBuilder, Worker, WorkerType};
|
||||
use sglang_router_rs::routers::http::pd_types::get_hostname;
|
||||
use sglang_router_rs::routers::http::pd_types::PDSelectionPolicy;
|
||||
use sglang_router_rs::routers::RouterFactory;
|
||||
@@ -46,11 +46,16 @@ mod test_pd_routing {
|
||||
|
||||
#[test]
|
||||
fn test_worker_types() {
|
||||
use sglang_router_rs::core::{WorkerFactory, WorkerType};
|
||||
use sglang_router_rs::core::{BasicWorkerBuilder, Worker, WorkerType};
|
||||
|
||||
// Test worker creation for prefill servers
|
||||
let prefill_worker =
|
||||
WorkerFactory::create_prefill("http://prefill:8080".to_string(), Some(9000));
|
||||
let prefill_worker: Box<dyn Worker> = Box::new(
|
||||
BasicWorkerBuilder::new("http://prefill:8080")
|
||||
.worker_type(WorkerType::Prefill {
|
||||
bootstrap_port: Some(9000),
|
||||
})
|
||||
.build(),
|
||||
);
|
||||
assert_eq!(prefill_worker.url(), "http://prefill:8080");
|
||||
match prefill_worker.worker_type() {
|
||||
WorkerType::Prefill { bootstrap_port } => {
|
||||
@@ -60,7 +65,11 @@ mod test_pd_routing {
|
||||
}
|
||||
|
||||
// Test worker creation for decode servers
|
||||
let decode_worker = WorkerFactory::create_decode("http://decode:8080".to_string());
|
||||
let decode_worker: Box<dyn Worker> = Box::new(
|
||||
BasicWorkerBuilder::new("http://decode:8080")
|
||||
.worker_type(WorkerType::Decode)
|
||||
.build(),
|
||||
);
|
||||
assert_eq!(decode_worker.url(), "http://decode:8080");
|
||||
match decode_worker.worker_type() {
|
||||
WorkerType::Decode => (),
|
||||
@@ -68,7 +77,11 @@ mod test_pd_routing {
|
||||
}
|
||||
|
||||
// Test regular worker creation
|
||||
let regular_worker = WorkerFactory::create_regular("http://regular:8080".to_string());
|
||||
let regular_worker: Box<dyn Worker> = Box::new(
|
||||
BasicWorkerBuilder::new("http://regular:8080")
|
||||
.worker_type(WorkerType::Regular)
|
||||
.build(),
|
||||
);
|
||||
assert_eq!(regular_worker.url(), "http://regular:8080");
|
||||
match regular_worker.worker_type() {
|
||||
WorkerType::Regular => (),
|
||||
@@ -277,8 +290,13 @@ mod test_pd_routing {
|
||||
});
|
||||
|
||||
// Create a prefill worker to simulate injection
|
||||
let prefill_worker =
|
||||
WorkerFactory::create_prefill("http://prefill1:8080".to_string(), Some(9000));
|
||||
let prefill_worker: Box<dyn Worker> = Box::new(
|
||||
BasicWorkerBuilder::new("http://prefill1:8080")
|
||||
.worker_type(WorkerType::Prefill {
|
||||
bootstrap_port: Some(9000),
|
||||
})
|
||||
.build(),
|
||||
);
|
||||
|
||||
// Extract bootstrap port from worker type
|
||||
let bootstrap_port = match prefill_worker.worker_type() {
|
||||
@@ -660,7 +678,7 @@ mod test_pd_routing {
|
||||
|
||||
#[test]
|
||||
fn test_bootstrap_injection_with_benchmark_requests() {
|
||||
use sglang_router_rs::core::{WorkerFactory, WorkerType};
|
||||
use sglang_router_rs::core::{BasicWorkerBuilder, Worker, WorkerType};
|
||||
|
||||
// Test bootstrap injection with actual benchmark request patterns
|
||||
let mut benchmark_request = json!({
|
||||
@@ -675,8 +693,13 @@ mod test_pd_routing {
|
||||
});
|
||||
|
||||
// Create a prefill worker to simulate injection
|
||||
let prefill_worker =
|
||||
WorkerFactory::create_prefill("http://prefill:8080".to_string(), Some(9000));
|
||||
let prefill_worker: Box<dyn Worker> = Box::new(
|
||||
BasicWorkerBuilder::new("http://prefill:8080")
|
||||
.worker_type(WorkerType::Prefill {
|
||||
bootstrap_port: Some(9000),
|
||||
})
|
||||
.build(),
|
||||
);
|
||||
|
||||
// Extract bootstrap port from worker type
|
||||
let bootstrap_port = match prefill_worker.worker_type() {
|
||||
@@ -806,8 +829,13 @@ mod test_pd_routing {
|
||||
});
|
||||
|
||||
// Create a prefill worker to simulate injection
|
||||
let prefill_worker =
|
||||
WorkerFactory::create_prefill("http://prefill:8080".to_string(), Some(9000));
|
||||
let prefill_worker: Box<dyn Worker> = Box::new(
|
||||
BasicWorkerBuilder::new("http://prefill:8080")
|
||||
.worker_type(WorkerType::Prefill {
|
||||
bootstrap_port: Some(9000),
|
||||
})
|
||||
.build(),
|
||||
);
|
||||
|
||||
// Extract bootstrap port from worker type
|
||||
let bootstrap_port = match prefill_worker.worker_type() {
|
||||
|
||||
Reference in New Issue
Block a user