[router] refactor worker to builder pattern 3/n (#10647)
This commit is contained in:
@@ -2,7 +2,7 @@ use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criteri
|
|||||||
use serde_json::{from_str, to_string, to_value, to_vec};
|
use serde_json::{from_str, to_string, to_value, to_vec};
|
||||||
use std::time::Instant;
|
use std::time::Instant;
|
||||||
|
|
||||||
use sglang_router_rs::core::{BasicWorker, Worker, WorkerType};
|
use sglang_router_rs::core::{BasicWorker, BasicWorkerBuilder, Worker, WorkerType};
|
||||||
use sglang_router_rs::protocols::spec::{
|
use sglang_router_rs::protocols::spec::{
|
||||||
ChatCompletionRequest, ChatMessage, CompletionRequest, GenerateParameters, GenerateRequest,
|
ChatCompletionRequest, ChatMessage, CompletionRequest, GenerateParameters, GenerateRequest,
|
||||||
SamplingParams, StringOrArray, UserMessageContent,
|
SamplingParams, StringOrArray, UserMessageContent,
|
||||||
@@ -12,12 +12,11 @@ use sglang_router_rs::routers::http::pd_types::{
|
|||||||
};
|
};
|
||||||
|
|
||||||
fn create_test_worker() -> BasicWorker {
|
fn create_test_worker() -> BasicWorker {
|
||||||
BasicWorker::new(
|
BasicWorkerBuilder::new("http://test-server:8000")
|
||||||
"http://test-server:8000".to_string(),
|
.worker_type(WorkerType::Prefill {
|
||||||
WorkerType::Prefill {
|
|
||||||
bootstrap_port: Some(5678),
|
bootstrap_port: Some(5678),
|
||||||
},
|
})
|
||||||
)
|
.build()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Helper function to get bootstrap info from worker
|
// Helper function to get bootstrap info from worker
|
||||||
|
|||||||
@@ -451,7 +451,7 @@ impl Drop for CacheAwarePolicy {
|
|||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
use crate::core::{BasicWorker, WorkerType};
|
use crate::core::{BasicWorkerBuilder, WorkerType};
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_cache_aware_with_balanced_load() {
|
fn test_cache_aware_with_balanced_load() {
|
||||||
@@ -462,14 +462,16 @@ mod tests {
|
|||||||
};
|
};
|
||||||
let policy = CacheAwarePolicy::with_config(config);
|
let policy = CacheAwarePolicy::with_config(config);
|
||||||
let workers: Vec<Arc<dyn Worker>> = vec![
|
let workers: Vec<Arc<dyn Worker>> = vec![
|
||||||
Arc::new(BasicWorker::new(
|
Arc::new(
|
||||||
"http://w1:8000".to_string(),
|
BasicWorkerBuilder::new("http://w1:8000")
|
||||||
WorkerType::Regular,
|
.worker_type(WorkerType::Regular)
|
||||||
)),
|
.build(),
|
||||||
Arc::new(BasicWorker::new(
|
),
|
||||||
"http://w2:8000".to_string(),
|
Arc::new(
|
||||||
WorkerType::Regular,
|
BasicWorkerBuilder::new("http://w2:8000")
|
||||||
)),
|
.worker_type(WorkerType::Regular)
|
||||||
|
.build(),
|
||||||
|
),
|
||||||
];
|
];
|
||||||
|
|
||||||
// Initialize the policy with workers
|
// Initialize the policy with workers
|
||||||
@@ -497,8 +499,12 @@ mod tests {
|
|||||||
max_tree_size: 10000,
|
max_tree_size: 10000,
|
||||||
});
|
});
|
||||||
|
|
||||||
let worker1 = BasicWorker::new("http://w1:8000".to_string(), WorkerType::Regular);
|
let worker1 = BasicWorkerBuilder::new("http://w1:8000")
|
||||||
let worker2 = BasicWorker::new("http://w2:8000".to_string(), WorkerType::Regular);
|
.worker_type(WorkerType::Regular)
|
||||||
|
.build();
|
||||||
|
let worker2 = BasicWorkerBuilder::new("http://w2:8000")
|
||||||
|
.worker_type(WorkerType::Regular)
|
||||||
|
.build();
|
||||||
|
|
||||||
// Create significant load imbalance
|
// Create significant load imbalance
|
||||||
for _ in 0..20 {
|
for _ in 0..20 {
|
||||||
@@ -524,14 +530,16 @@ mod tests {
|
|||||||
};
|
};
|
||||||
let policy = CacheAwarePolicy::with_config(config);
|
let policy = CacheAwarePolicy::with_config(config);
|
||||||
let workers: Vec<Arc<dyn Worker>> = vec![
|
let workers: Vec<Arc<dyn Worker>> = vec![
|
||||||
Arc::new(BasicWorker::new(
|
Arc::new(
|
||||||
"http://w1:8000".to_string(),
|
BasicWorkerBuilder::new("http://w1:8000")
|
||||||
WorkerType::Regular,
|
.worker_type(WorkerType::Regular)
|
||||||
)),
|
.build(),
|
||||||
Arc::new(BasicWorker::new(
|
),
|
||||||
"http://w2:8000".to_string(),
|
Arc::new(
|
||||||
WorkerType::Regular,
|
BasicWorkerBuilder::new("http://w2:8000")
|
||||||
)),
|
.worker_type(WorkerType::Regular)
|
||||||
|
.build(),
|
||||||
|
),
|
||||||
];
|
];
|
||||||
|
|
||||||
policy.init_workers(&workers);
|
policy.init_workers(&workers);
|
||||||
|
|||||||
@@ -121,23 +121,26 @@ pub(crate) fn get_healthy_worker_indices(workers: &[Arc<dyn Worker>]) -> Vec<usi
|
|||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
use crate::core::{BasicWorker, WorkerType};
|
use crate::core::{BasicWorkerBuilder, WorkerType};
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_get_healthy_worker_indices() {
|
fn test_get_healthy_worker_indices() {
|
||||||
let workers: Vec<Arc<dyn Worker>> = vec![
|
let workers: Vec<Arc<dyn Worker>> = vec![
|
||||||
Arc::new(BasicWorker::new(
|
Arc::new(
|
||||||
"http://w1:8000".to_string(),
|
BasicWorkerBuilder::new("http://w1:8000")
|
||||||
WorkerType::Regular,
|
.worker_type(WorkerType::Regular)
|
||||||
)),
|
.build(),
|
||||||
Arc::new(BasicWorker::new(
|
),
|
||||||
"http://w2:8000".to_string(),
|
Arc::new(
|
||||||
WorkerType::Regular,
|
BasicWorkerBuilder::new("http://w2:8000")
|
||||||
)),
|
.worker_type(WorkerType::Regular)
|
||||||
Arc::new(BasicWorker::new(
|
.build(),
|
||||||
"http://w3:8000".to_string(),
|
),
|
||||||
WorkerType::Regular,
|
Arc::new(
|
||||||
)),
|
BasicWorkerBuilder::new("http://w3:8000")
|
||||||
|
.worker_type(WorkerType::Regular)
|
||||||
|
.build(),
|
||||||
|
),
|
||||||
];
|
];
|
||||||
|
|
||||||
// All healthy initially
|
// All healthy initially
|
||||||
|
|||||||
@@ -119,14 +119,20 @@ impl Default for PowerOfTwoPolicy {
|
|||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
use crate::core::{BasicWorker, WorkerType};
|
use crate::core::{BasicWorkerBuilder, WorkerType};
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_power_of_two_selection() {
|
fn test_power_of_two_selection() {
|
||||||
let policy = PowerOfTwoPolicy::new();
|
let policy = PowerOfTwoPolicy::new();
|
||||||
let worker1 = BasicWorker::new("http://w1:8000".to_string(), WorkerType::Regular);
|
let worker1 = BasicWorkerBuilder::new("http://w1:8000")
|
||||||
let worker2 = BasicWorker::new("http://w2:8000".to_string(), WorkerType::Regular);
|
.worker_type(WorkerType::Regular)
|
||||||
let worker3 = BasicWorker::new("http://w3:8000".to_string(), WorkerType::Regular);
|
.build();
|
||||||
|
let worker2 = BasicWorkerBuilder::new("http://w2:8000")
|
||||||
|
.worker_type(WorkerType::Regular)
|
||||||
|
.build();
|
||||||
|
let worker3 = BasicWorkerBuilder::new("http://w3:8000")
|
||||||
|
.worker_type(WorkerType::Regular)
|
||||||
|
.build();
|
||||||
|
|
||||||
// Set different loads
|
// Set different loads
|
||||||
for _ in 0..10 {
|
for _ in 0..10 {
|
||||||
@@ -157,14 +163,16 @@ mod tests {
|
|||||||
fn test_power_of_two_with_cached_loads() {
|
fn test_power_of_two_with_cached_loads() {
|
||||||
let policy = PowerOfTwoPolicy::new();
|
let policy = PowerOfTwoPolicy::new();
|
||||||
let workers: Vec<Arc<dyn Worker>> = vec![
|
let workers: Vec<Arc<dyn Worker>> = vec![
|
||||||
Arc::new(BasicWorker::new(
|
Arc::new(
|
||||||
"http://w1:8000".to_string(),
|
BasicWorkerBuilder::new("http://w1:8000")
|
||||||
WorkerType::Regular,
|
.worker_type(WorkerType::Regular)
|
||||||
)),
|
.build(),
|
||||||
Arc::new(BasicWorker::new(
|
),
|
||||||
"http://w2:8000".to_string(),
|
Arc::new(
|
||||||
WorkerType::Regular,
|
BasicWorkerBuilder::new("http://w2:8000")
|
||||||
)),
|
.worker_type(WorkerType::Regular)
|
||||||
|
.build(),
|
||||||
|
),
|
||||||
];
|
];
|
||||||
|
|
||||||
// Update cached loads
|
// Update cached loads
|
||||||
@@ -190,10 +198,11 @@ mod tests {
|
|||||||
#[test]
|
#[test]
|
||||||
fn test_power_of_two_single_worker() {
|
fn test_power_of_two_single_worker() {
|
||||||
let policy = PowerOfTwoPolicy::new();
|
let policy = PowerOfTwoPolicy::new();
|
||||||
let workers: Vec<Arc<dyn Worker>> = vec![Arc::new(BasicWorker::new(
|
let workers: Vec<Arc<dyn Worker>> = vec![Arc::new(
|
||||||
"http://w1:8000".to_string(),
|
BasicWorkerBuilder::new("http://w1:8000")
|
||||||
WorkerType::Regular,
|
.worker_type(WorkerType::Regular)
|
||||||
))];
|
.build(),
|
||||||
|
)];
|
||||||
|
|
||||||
// With single worker, should always select it
|
// With single worker, should always select it
|
||||||
assert_eq!(policy.select_worker(&workers, None), Some(0));
|
assert_eq!(policy.select_worker(&workers, None), Some(0));
|
||||||
|
|||||||
@@ -51,25 +51,28 @@ impl LoadBalancingPolicy for RandomPolicy {
|
|||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
use crate::core::{BasicWorker, WorkerType};
|
use crate::core::{BasicWorkerBuilder, WorkerType};
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_random_selection() {
|
fn test_random_selection() {
|
||||||
let policy = RandomPolicy::new();
|
let policy = RandomPolicy::new();
|
||||||
let workers: Vec<Arc<dyn Worker>> = vec![
|
let workers: Vec<Arc<dyn Worker>> = vec![
|
||||||
Arc::new(BasicWorker::new(
|
Arc::new(
|
||||||
"http://w1:8000".to_string(),
|
BasicWorkerBuilder::new("http://w1:8000")
|
||||||
WorkerType::Regular,
|
.worker_type(WorkerType::Regular)
|
||||||
)),
|
.build(),
|
||||||
Arc::new(BasicWorker::new(
|
),
|
||||||
"http://w2:8000".to_string(),
|
Arc::new(
|
||||||
WorkerType::Regular,
|
BasicWorkerBuilder::new("http://w2:8000")
|
||||||
)),
|
.worker_type(WorkerType::Regular)
|
||||||
Arc::new(BasicWorker::new(
|
.build(),
|
||||||
"http://w3:8000".to_string(),
|
),
|
||||||
WorkerType::Regular,
|
Arc::new(
|
||||||
)),
|
BasicWorkerBuilder::new("http://w3:8000")
|
||||||
|
.worker_type(WorkerType::Regular)
|
||||||
|
.build(),
|
||||||
|
),
|
||||||
];
|
];
|
||||||
|
|
||||||
// Test multiple selections to ensure randomness
|
// Test multiple selections to ensure randomness
|
||||||
@@ -89,14 +92,16 @@ mod tests {
|
|||||||
fn test_random_with_unhealthy_workers() {
|
fn test_random_with_unhealthy_workers() {
|
||||||
let policy = RandomPolicy::new();
|
let policy = RandomPolicy::new();
|
||||||
let workers: Vec<Arc<dyn Worker>> = vec![
|
let workers: Vec<Arc<dyn Worker>> = vec![
|
||||||
Arc::new(BasicWorker::new(
|
Arc::new(
|
||||||
"http://w1:8000".to_string(),
|
BasicWorkerBuilder::new("http://w1:8000")
|
||||||
WorkerType::Regular,
|
.worker_type(WorkerType::Regular)
|
||||||
)),
|
.build(),
|
||||||
Arc::new(BasicWorker::new(
|
),
|
||||||
"http://w2:8000".to_string(),
|
Arc::new(
|
||||||
WorkerType::Regular,
|
BasicWorkerBuilder::new("http://w2:8000")
|
||||||
)),
|
.worker_type(WorkerType::Regular)
|
||||||
|
.build(),
|
||||||
|
),
|
||||||
];
|
];
|
||||||
|
|
||||||
// Mark first worker as unhealthy
|
// Mark first worker as unhealthy
|
||||||
@@ -111,10 +116,11 @@ mod tests {
|
|||||||
#[test]
|
#[test]
|
||||||
fn test_random_no_healthy_workers() {
|
fn test_random_no_healthy_workers() {
|
||||||
let policy = RandomPolicy::new();
|
let policy = RandomPolicy::new();
|
||||||
let workers: Vec<Arc<dyn Worker>> = vec![Arc::new(BasicWorker::new(
|
let workers: Vec<Arc<dyn Worker>> = vec![Arc::new(
|
||||||
"http://w1:8000".to_string(),
|
BasicWorkerBuilder::new("http://w1:8000")
|
||||||
WorkerType::Regular,
|
.worker_type(WorkerType::Regular)
|
||||||
))];
|
.build(),
|
||||||
|
)];
|
||||||
|
|
||||||
workers[0].set_healthy(false);
|
workers[0].set_healthy(false);
|
||||||
assert_eq!(policy.select_worker(&workers, None), None);
|
assert_eq!(policy.select_worker(&workers, None), None);
|
||||||
|
|||||||
@@ -60,24 +60,27 @@ impl LoadBalancingPolicy for RoundRobinPolicy {
|
|||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
use crate::core::{BasicWorker, WorkerType};
|
use crate::core::{BasicWorkerBuilder, WorkerType};
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_round_robin_selection() {
|
fn test_round_robin_selection() {
|
||||||
let policy = RoundRobinPolicy::new();
|
let policy = RoundRobinPolicy::new();
|
||||||
let workers: Vec<Arc<dyn Worker>> = vec![
|
let workers: Vec<Arc<dyn Worker>> = vec![
|
||||||
Arc::new(BasicWorker::new(
|
Arc::new(
|
||||||
"http://w1:8000".to_string(),
|
BasicWorkerBuilder::new("http://w1:8000")
|
||||||
WorkerType::Regular,
|
.worker_type(WorkerType::Regular)
|
||||||
)),
|
.build(),
|
||||||
Arc::new(BasicWorker::new(
|
),
|
||||||
"http://w2:8000".to_string(),
|
Arc::new(
|
||||||
WorkerType::Regular,
|
BasicWorkerBuilder::new("http://w2:8000")
|
||||||
)),
|
.worker_type(WorkerType::Regular)
|
||||||
Arc::new(BasicWorker::new(
|
.build(),
|
||||||
"http://w3:8000".to_string(),
|
),
|
||||||
WorkerType::Regular,
|
Arc::new(
|
||||||
)),
|
BasicWorkerBuilder::new("http://w3:8000")
|
||||||
|
.worker_type(WorkerType::Regular)
|
||||||
|
.build(),
|
||||||
|
),
|
||||||
];
|
];
|
||||||
|
|
||||||
// Should select workers in order: 0, 1, 2, 0, 1, 2, ...
|
// Should select workers in order: 0, 1, 2, 0, 1, 2, ...
|
||||||
@@ -92,18 +95,21 @@ mod tests {
|
|||||||
fn test_round_robin_with_unhealthy_workers() {
|
fn test_round_robin_with_unhealthy_workers() {
|
||||||
let policy = RoundRobinPolicy::new();
|
let policy = RoundRobinPolicy::new();
|
||||||
let workers: Vec<Arc<dyn Worker>> = vec![
|
let workers: Vec<Arc<dyn Worker>> = vec![
|
||||||
Arc::new(BasicWorker::new(
|
Arc::new(
|
||||||
"http://w1:8000".to_string(),
|
BasicWorkerBuilder::new("http://w1:8000")
|
||||||
WorkerType::Regular,
|
.worker_type(WorkerType::Regular)
|
||||||
)),
|
.build(),
|
||||||
Arc::new(BasicWorker::new(
|
),
|
||||||
"http://w2:8000".to_string(),
|
Arc::new(
|
||||||
WorkerType::Regular,
|
BasicWorkerBuilder::new("http://w2:8000")
|
||||||
)),
|
.worker_type(WorkerType::Regular)
|
||||||
Arc::new(BasicWorker::new(
|
.build(),
|
||||||
"http://w3:8000".to_string(),
|
),
|
||||||
WorkerType::Regular,
|
Arc::new(
|
||||||
)),
|
BasicWorkerBuilder::new("http://w3:8000")
|
||||||
|
.worker_type(WorkerType::Regular)
|
||||||
|
.build(),
|
||||||
|
),
|
||||||
];
|
];
|
||||||
|
|
||||||
// Mark middle worker as unhealthy
|
// Mark middle worker as unhealthy
|
||||||
@@ -120,14 +126,16 @@ mod tests {
|
|||||||
fn test_round_robin_reset() {
|
fn test_round_robin_reset() {
|
||||||
let policy = RoundRobinPolicy::new();
|
let policy = RoundRobinPolicy::new();
|
||||||
let workers: Vec<Arc<dyn Worker>> = vec![
|
let workers: Vec<Arc<dyn Worker>> = vec![
|
||||||
Arc::new(BasicWorker::new(
|
Arc::new(
|
||||||
"http://w1:8000".to_string(),
|
BasicWorkerBuilder::new("http://w1:8000")
|
||||||
WorkerType::Regular,
|
.worker_type(WorkerType::Regular)
|
||||||
)),
|
.build(),
|
||||||
Arc::new(BasicWorker::new(
|
),
|
||||||
"http://w2:8000".to_string(),
|
Arc::new(
|
||||||
WorkerType::Regular,
|
BasicWorkerBuilder::new("http://w2:8000")
|
||||||
)),
|
.worker_type(WorkerType::Regular)
|
||||||
|
.build(),
|
||||||
|
),
|
||||||
];
|
];
|
||||||
|
|
||||||
// Advance the counter
|
// Advance the counter
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
use crate::config::types::RetryConfig;
|
use crate::config::types::RetryConfig;
|
||||||
use crate::core::{
|
use crate::core::{
|
||||||
BasicWorker, CircuitBreakerConfig, HealthChecker, HealthConfig, Worker, WorkerType,
|
BasicWorkerBuilder, CircuitBreakerConfig, HealthChecker, HealthConfig, Worker, WorkerType,
|
||||||
};
|
};
|
||||||
use crate::grpc::SglangSchedulerClient;
|
use crate::grpc::SglangSchedulerClient;
|
||||||
use crate::metrics::RouterMetrics;
|
use crate::metrics::RouterMetrics;
|
||||||
@@ -130,23 +130,22 @@ impl GrpcPDRouter {
|
|||||||
let prefill_workers: Vec<Arc<dyn Worker>> = prefill_urls
|
let prefill_workers: Vec<Arc<dyn Worker>> = prefill_urls
|
||||||
.iter()
|
.iter()
|
||||||
.map(|(url, bootstrap_port)| {
|
.map(|(url, bootstrap_port)| {
|
||||||
let worker = BasicWorker::with_connection_mode(
|
let worker = BasicWorkerBuilder::new(url.clone())
|
||||||
url.clone(),
|
.worker_type(WorkerType::Prefill {
|
||||||
WorkerType::Prefill {
|
|
||||||
bootstrap_port: *bootstrap_port,
|
bootstrap_port: *bootstrap_port,
|
||||||
},
|
})
|
||||||
crate::core::ConnectionMode::Grpc {
|
.connection_mode(crate::core::ConnectionMode::Grpc {
|
||||||
port: *bootstrap_port,
|
port: *bootstrap_port,
|
||||||
},
|
})
|
||||||
)
|
.circuit_breaker_config(core_cb_config.clone())
|
||||||
.with_circuit_breaker_config(core_cb_config.clone())
|
.health_config(HealthConfig {
|
||||||
.with_health_config(HealthConfig {
|
|
||||||
timeout_secs: ctx.router_config.health_check.timeout_secs,
|
timeout_secs: ctx.router_config.health_check.timeout_secs,
|
||||||
check_interval_secs: ctx.router_config.health_check.check_interval_secs,
|
check_interval_secs: ctx.router_config.health_check.check_interval_secs,
|
||||||
endpoint: ctx.router_config.health_check.endpoint.clone(),
|
endpoint: ctx.router_config.health_check.endpoint.clone(),
|
||||||
failure_threshold: ctx.router_config.health_check.failure_threshold,
|
failure_threshold: ctx.router_config.health_check.failure_threshold,
|
||||||
success_threshold: ctx.router_config.health_check.success_threshold,
|
success_threshold: ctx.router_config.health_check.success_threshold,
|
||||||
});
|
})
|
||||||
|
.build();
|
||||||
Arc::new(worker) as Arc<dyn Worker>
|
Arc::new(worker) as Arc<dyn Worker>
|
||||||
})
|
})
|
||||||
.collect();
|
.collect();
|
||||||
@@ -155,19 +154,18 @@ impl GrpcPDRouter {
|
|||||||
let decode_workers: Vec<Arc<dyn Worker>> = decode_urls
|
let decode_workers: Vec<Arc<dyn Worker>> = decode_urls
|
||||||
.iter()
|
.iter()
|
||||||
.map(|url| {
|
.map(|url| {
|
||||||
let worker = BasicWorker::with_connection_mode(
|
let worker = BasicWorkerBuilder::new(url.clone())
|
||||||
url.clone(),
|
.worker_type(WorkerType::Decode)
|
||||||
WorkerType::Decode,
|
.connection_mode(crate::core::ConnectionMode::Grpc { port: None })
|
||||||
crate::core::ConnectionMode::Grpc { port: None },
|
.circuit_breaker_config(core_cb_config.clone())
|
||||||
)
|
.health_config(HealthConfig {
|
||||||
.with_circuit_breaker_config(core_cb_config.clone())
|
|
||||||
.with_health_config(HealthConfig {
|
|
||||||
timeout_secs: ctx.router_config.health_check.timeout_secs,
|
timeout_secs: ctx.router_config.health_check.timeout_secs,
|
||||||
check_interval_secs: ctx.router_config.health_check.check_interval_secs,
|
check_interval_secs: ctx.router_config.health_check.check_interval_secs,
|
||||||
endpoint: ctx.router_config.health_check.endpoint.clone(),
|
endpoint: ctx.router_config.health_check.endpoint.clone(),
|
||||||
failure_threshold: ctx.router_config.health_check.failure_threshold,
|
failure_threshold: ctx.router_config.health_check.failure_threshold,
|
||||||
success_threshold: ctx.router_config.health_check.success_threshold,
|
success_threshold: ctx.router_config.health_check.success_threshold,
|
||||||
});
|
})
|
||||||
|
.build();
|
||||||
Arc::new(worker) as Arc<dyn Worker>
|
Arc::new(worker) as Arc<dyn Worker>
|
||||||
})
|
})
|
||||||
.collect();
|
.collect();
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
use crate::config::types::RetryConfig;
|
use crate::config::types::RetryConfig;
|
||||||
use crate::core::{
|
use crate::core::{
|
||||||
BasicWorker, CircuitBreakerConfig, HealthChecker, HealthConfig, Worker, WorkerType,
|
BasicWorkerBuilder, CircuitBreakerConfig, HealthChecker, HealthConfig, Worker, WorkerType,
|
||||||
};
|
};
|
||||||
use crate::grpc::SglangSchedulerClient;
|
use crate::grpc::SglangSchedulerClient;
|
||||||
use crate::metrics::RouterMetrics;
|
use crate::metrics::RouterMetrics;
|
||||||
@@ -108,20 +108,19 @@ impl GrpcRouter {
|
|||||||
// Move clients from the HashMap to the workers
|
// Move clients from the HashMap to the workers
|
||||||
for url in &worker_urls {
|
for url in &worker_urls {
|
||||||
if let Some(client) = grpc_clients.remove(url) {
|
if let Some(client) = grpc_clients.remove(url) {
|
||||||
let worker = BasicWorker::with_connection_mode(
|
let worker = BasicWorkerBuilder::new(url.clone())
|
||||||
url.clone(),
|
.worker_type(WorkerType::Regular)
|
||||||
WorkerType::Regular,
|
.connection_mode(crate::core::ConnectionMode::Grpc { port: None })
|
||||||
crate::core::ConnectionMode::Grpc { port: None },
|
.circuit_breaker_config(core_cb_config.clone())
|
||||||
)
|
.health_config(HealthConfig {
|
||||||
.with_circuit_breaker_config(core_cb_config.clone())
|
|
||||||
.with_health_config(HealthConfig {
|
|
||||||
timeout_secs: ctx.router_config.health_check.timeout_secs,
|
timeout_secs: ctx.router_config.health_check.timeout_secs,
|
||||||
check_interval_secs: ctx.router_config.health_check.check_interval_secs,
|
check_interval_secs: ctx.router_config.health_check.check_interval_secs,
|
||||||
endpoint: ctx.router_config.health_check.endpoint.clone(),
|
endpoint: ctx.router_config.health_check.endpoint.clone(),
|
||||||
failure_threshold: ctx.router_config.health_check.failure_threshold,
|
failure_threshold: ctx.router_config.health_check.failure_threshold,
|
||||||
success_threshold: ctx.router_config.health_check.success_threshold,
|
success_threshold: ctx.router_config.health_check.success_threshold,
|
||||||
})
|
})
|
||||||
.with_grpc_client(client);
|
.grpc_client(client)
|
||||||
|
.build();
|
||||||
|
|
||||||
workers.push(Arc::new(worker) as Arc<dyn Worker>);
|
workers.push(Arc::new(worker) as Arc<dyn Worker>);
|
||||||
} else {
|
} else {
|
||||||
|
|||||||
@@ -3,8 +3,8 @@
|
|||||||
use super::pd_types::{api_path, PDRouterError};
|
use super::pd_types::{api_path, PDRouterError};
|
||||||
use crate::config::types::RetryConfig;
|
use crate::config::types::RetryConfig;
|
||||||
use crate::core::{
|
use crate::core::{
|
||||||
is_retryable_status, BasicWorker, CircuitBreakerConfig, HealthConfig, RetryExecutor, Worker,
|
is_retryable_status, BasicWorkerBuilder, CircuitBreakerConfig, HealthConfig, RetryExecutor,
|
||||||
WorkerFactory, WorkerLoadGuard, WorkerRegistry, WorkerType,
|
Worker, WorkerFactory, WorkerLoadGuard, WorkerRegistry, WorkerType,
|
||||||
};
|
};
|
||||||
use crate::metrics::RouterMetrics;
|
use crate::metrics::RouterMetrics;
|
||||||
use crate::policies::{LoadBalancingPolicy, PolicyRegistry};
|
use crate::policies::{LoadBalancingPolicy, PolicyRegistry};
|
||||||
@@ -389,34 +389,35 @@ impl PDRouter {
|
|||||||
|
|
||||||
// Register prefill workers in the registry
|
// Register prefill workers in the registry
|
||||||
for (url, port) in prefill_urls {
|
for (url, port) in prefill_urls {
|
||||||
let worker = BasicWorker::new(
|
let worker = BasicWorkerBuilder::new(url)
|
||||||
url,
|
.worker_type(WorkerType::Prefill {
|
||||||
WorkerType::Prefill {
|
|
||||||
bootstrap_port: port,
|
bootstrap_port: port,
|
||||||
},
|
})
|
||||||
)
|
.circuit_breaker_config(core_cb_config.clone())
|
||||||
.with_circuit_breaker_config(core_cb_config.clone())
|
.health_config(HealthConfig {
|
||||||
.with_health_config(HealthConfig {
|
|
||||||
timeout_secs: ctx.router_config.health_check.timeout_secs,
|
timeout_secs: ctx.router_config.health_check.timeout_secs,
|
||||||
check_interval_secs: ctx.router_config.health_check.check_interval_secs,
|
check_interval_secs: ctx.router_config.health_check.check_interval_secs,
|
||||||
endpoint: ctx.router_config.health_check.endpoint.clone(),
|
endpoint: ctx.router_config.health_check.endpoint.clone(),
|
||||||
failure_threshold: ctx.router_config.health_check.failure_threshold,
|
failure_threshold: ctx.router_config.health_check.failure_threshold,
|
||||||
success_threshold: ctx.router_config.health_check.success_threshold,
|
success_threshold: ctx.router_config.health_check.success_threshold,
|
||||||
});
|
})
|
||||||
|
.build();
|
||||||
ctx.worker_registry.register(Arc::new(worker));
|
ctx.worker_registry.register(Arc::new(worker));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Register decode workers in the registry
|
// Register decode workers in the registry
|
||||||
for url in decode_urls {
|
for url in decode_urls {
|
||||||
let worker = BasicWorker::new(url, WorkerType::Decode)
|
let worker = BasicWorkerBuilder::new(url)
|
||||||
.with_circuit_breaker_config(core_cb_config.clone())
|
.worker_type(WorkerType::Decode)
|
||||||
.with_health_config(HealthConfig {
|
.circuit_breaker_config(core_cb_config.clone())
|
||||||
|
.health_config(HealthConfig {
|
||||||
timeout_secs: ctx.router_config.health_check.timeout_secs,
|
timeout_secs: ctx.router_config.health_check.timeout_secs,
|
||||||
check_interval_secs: ctx.router_config.health_check.check_interval_secs,
|
check_interval_secs: ctx.router_config.health_check.check_interval_secs,
|
||||||
endpoint: ctx.router_config.health_check.endpoint.clone(),
|
endpoint: ctx.router_config.health_check.endpoint.clone(),
|
||||||
failure_threshold: ctx.router_config.health_check.failure_threshold,
|
failure_threshold: ctx.router_config.health_check.failure_threshold,
|
||||||
success_threshold: ctx.router_config.health_check.success_threshold,
|
success_threshold: ctx.router_config.health_check.success_threshold,
|
||||||
});
|
})
|
||||||
|
.build();
|
||||||
ctx.worker_registry.register(Arc::new(worker));
|
ctx.worker_registry.register(Arc::new(worker));
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -2116,7 +2117,7 @@ impl RouterTrait for PDRouter {
|
|||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
use crate::core::{BasicWorker, WorkerType};
|
use crate::core::WorkerType;
|
||||||
|
|
||||||
fn create_test_pd_router() -> PDRouter {
|
fn create_test_pd_router() -> PDRouter {
|
||||||
let worker_registry = Arc::new(WorkerRegistry::new());
|
let worker_registry = Arc::new(WorkerRegistry::new());
|
||||||
@@ -2139,7 +2140,9 @@ mod tests {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn create_test_worker(url: String, worker_type: WorkerType, healthy: bool) -> Box<dyn Worker> {
|
fn create_test_worker(url: String, worker_type: WorkerType, healthy: bool) -> Box<dyn Worker> {
|
||||||
let worker = BasicWorker::new(url, worker_type);
|
let worker = BasicWorkerBuilder::new(url)
|
||||||
|
.worker_type(worker_type)
|
||||||
|
.build();
|
||||||
worker.set_healthy(healthy);
|
worker.set_healthy(healthy);
|
||||||
Box::new(worker)
|
Box::new(worker)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
use crate::config::types::RetryConfig;
|
use crate::config::types::RetryConfig;
|
||||||
use crate::core::{
|
use crate::core::{
|
||||||
is_retryable_status, BasicWorker, CircuitBreakerConfig, HealthConfig, RetryExecutor, Worker,
|
is_retryable_status, BasicWorkerBuilder, CircuitBreakerConfig, HealthConfig, RetryExecutor,
|
||||||
WorkerRegistry, WorkerType,
|
Worker, WorkerRegistry, WorkerType,
|
||||||
};
|
};
|
||||||
use crate::metrics::RouterMetrics;
|
use crate::metrics::RouterMetrics;
|
||||||
use crate::policies::{LoadBalancingPolicy, PolicyRegistry};
|
use crate::policies::{LoadBalancingPolicy, PolicyRegistry};
|
||||||
@@ -87,15 +87,17 @@ impl Router {
|
|||||||
for url in &worker_urls {
|
for url in &worker_urls {
|
||||||
// TODO: In IGW mode, fetch model_id from worker's /get_model_info endpoint
|
// TODO: In IGW mode, fetch model_id from worker's /get_model_info endpoint
|
||||||
// For now, create worker without model_id
|
// For now, create worker without model_id
|
||||||
let worker = BasicWorker::new(url.clone(), WorkerType::Regular)
|
let worker = BasicWorkerBuilder::new(url.clone())
|
||||||
.with_circuit_breaker_config(core_cb_config.clone())
|
.worker_type(WorkerType::Regular)
|
||||||
.with_health_config(HealthConfig {
|
.circuit_breaker_config(core_cb_config.clone())
|
||||||
|
.health_config(HealthConfig {
|
||||||
timeout_secs: ctx.router_config.health_check.timeout_secs,
|
timeout_secs: ctx.router_config.health_check.timeout_secs,
|
||||||
check_interval_secs: ctx.router_config.health_check.check_interval_secs,
|
check_interval_secs: ctx.router_config.health_check.check_interval_secs,
|
||||||
endpoint: ctx.router_config.health_check.endpoint.clone(),
|
endpoint: ctx.router_config.health_check.endpoint.clone(),
|
||||||
failure_threshold: ctx.router_config.health_check.failure_threshold,
|
failure_threshold: ctx.router_config.health_check.failure_threshold,
|
||||||
success_threshold: ctx.router_config.health_check.success_threshold,
|
success_threshold: ctx.router_config.health_check.success_threshold,
|
||||||
});
|
})
|
||||||
|
.build();
|
||||||
|
|
||||||
let worker_arc = Arc::new(worker);
|
let worker_arc = Arc::new(worker);
|
||||||
ctx.worker_registry.register(worker_arc.clone());
|
ctx.worker_registry.register(worker_arc.clone());
|
||||||
@@ -991,11 +993,10 @@ impl Router {
|
|||||||
}
|
}
|
||||||
info!("Added worker: {}", dp_url);
|
info!("Added worker: {}", dp_url);
|
||||||
// TODO: In IGW mode, fetch model_id from worker's /get_model_info endpoint
|
// TODO: In IGW mode, fetch model_id from worker's /get_model_info endpoint
|
||||||
let new_worker =
|
let new_worker = BasicWorkerBuilder::new(dp_url.to_string())
|
||||||
BasicWorker::new(dp_url.to_string(), WorkerType::Regular)
|
.worker_type(WorkerType::Regular)
|
||||||
.with_circuit_breaker_config(
|
.circuit_breaker_config(self.circuit_breaker_config.clone())
|
||||||
self.circuit_breaker_config.clone(),
|
.build();
|
||||||
);
|
|
||||||
|
|
||||||
let worker_arc = Arc::new(new_worker);
|
let worker_arc = Arc::new(new_worker);
|
||||||
self.worker_registry.register(worker_arc.clone());
|
self.worker_registry.register(worker_arc.clone());
|
||||||
@@ -1028,11 +1029,10 @@ impl Router {
|
|||||||
info!("Added worker: {}", worker_url);
|
info!("Added worker: {}", worker_url);
|
||||||
|
|
||||||
// TODO: In IGW mode, fetch model_id from worker's /get_model_info endpoint
|
// TODO: In IGW mode, fetch model_id from worker's /get_model_info endpoint
|
||||||
let new_worker =
|
let new_worker = BasicWorkerBuilder::new(worker_url.to_string())
|
||||||
BasicWorker::new(worker_url.to_string(), WorkerType::Regular)
|
.worker_type(WorkerType::Regular)
|
||||||
.with_circuit_breaker_config(
|
.circuit_breaker_config(self.circuit_breaker_config.clone())
|
||||||
self.circuit_breaker_config.clone(),
|
.build();
|
||||||
);
|
|
||||||
|
|
||||||
let worker_arc = Arc::new(new_worker);
|
let worker_arc = Arc::new(new_worker);
|
||||||
self.worker_registry.register(worker_arc.clone());
|
self.worker_registry.register(worker_arc.clone());
|
||||||
@@ -1595,8 +1595,12 @@ mod tests {
|
|||||||
));
|
));
|
||||||
|
|
||||||
// Register test workers
|
// Register test workers
|
||||||
let worker1 = BasicWorker::new("http://worker1:8080".to_string(), WorkerType::Regular);
|
let worker1 = BasicWorkerBuilder::new("http://worker1:8080")
|
||||||
let worker2 = BasicWorker::new("http://worker2:8080".to_string(), WorkerType::Regular);
|
.worker_type(WorkerType::Regular)
|
||||||
|
.build();
|
||||||
|
let worker2 = BasicWorkerBuilder::new("http://worker2:8080")
|
||||||
|
.worker_type(WorkerType::Regular)
|
||||||
|
.build();
|
||||||
worker_registry.register(Arc::new(worker1));
|
worker_registry.register(Arc::new(worker1));
|
||||||
worker_registry.register(Arc::new(worker2));
|
worker_registry.register(Arc::new(worker2));
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
use sglang_router_rs::core::{BasicWorker, Worker, WorkerType};
|
use sglang_router_rs::core::{BasicWorkerBuilder, Worker, WorkerType};
|
||||||
use sglang_router_rs::policies::{CacheAwareConfig, CacheAwarePolicy, LoadBalancingPolicy};
|
use sglang_router_rs::policies::{CacheAwareConfig, CacheAwarePolicy, LoadBalancingPolicy};
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
@@ -16,13 +16,17 @@ fn test_backward_compatibility_with_empty_model_id() {
|
|||||||
let policy = CacheAwarePolicy::with_config(config);
|
let policy = CacheAwarePolicy::with_config(config);
|
||||||
|
|
||||||
// Create workers with empty model_id (simulating existing routers)
|
// Create workers with empty model_id (simulating existing routers)
|
||||||
let worker1 = BasicWorker::new("http://worker1:8080".to_string(), WorkerType::Regular);
|
let worker1 = BasicWorkerBuilder::new("http://worker1:8080")
|
||||||
|
.worker_type(WorkerType::Regular)
|
||||||
|
.build();
|
||||||
// No model_id label - should default to "unknown"
|
// No model_id label - should default to "unknown"
|
||||||
|
|
||||||
let mut labels2 = HashMap::new();
|
let mut labels2 = HashMap::new();
|
||||||
labels2.insert("model_id".to_string(), "unknown".to_string());
|
labels2.insert("model_id".to_string(), "unknown".to_string());
|
||||||
let worker2 = BasicWorker::new("http://worker2:8080".to_string(), WorkerType::Regular)
|
let worker2 = BasicWorkerBuilder::new("http://worker2:8080")
|
||||||
.with_labels(labels2);
|
.worker_type(WorkerType::Regular)
|
||||||
|
.labels(labels2)
|
||||||
|
.build();
|
||||||
|
|
||||||
// Add workers - should both go to "default" tree
|
// Add workers - should both go to "default" tree
|
||||||
policy.add_worker(&worker1);
|
policy.add_worker(&worker1);
|
||||||
@@ -53,23 +57,31 @@ fn test_mixed_model_ids() {
|
|||||||
let policy = CacheAwarePolicy::with_config(config);
|
let policy = CacheAwarePolicy::with_config(config);
|
||||||
|
|
||||||
// Create workers with different model_id scenarios
|
// Create workers with different model_id scenarios
|
||||||
let worker1 = BasicWorker::new("http://worker1:8080".to_string(), WorkerType::Regular);
|
let worker1 = BasicWorkerBuilder::new("http://worker1:8080")
|
||||||
|
.worker_type(WorkerType::Regular)
|
||||||
|
.build();
|
||||||
// No model_id label - defaults to "unknown" which goes to "default" tree
|
// No model_id label - defaults to "unknown" which goes to "default" tree
|
||||||
|
|
||||||
let mut labels2 = HashMap::new();
|
let mut labels2 = HashMap::new();
|
||||||
labels2.insert("model_id".to_string(), "llama-3".to_string());
|
labels2.insert("model_id".to_string(), "llama-3".to_string());
|
||||||
let worker2 = BasicWorker::new("http://worker2:8080".to_string(), WorkerType::Regular)
|
let worker2 = BasicWorkerBuilder::new("http://worker2:8080")
|
||||||
.with_labels(labels2);
|
.worker_type(WorkerType::Regular)
|
||||||
|
.labels(labels2)
|
||||||
|
.build();
|
||||||
|
|
||||||
let mut labels3 = HashMap::new();
|
let mut labels3 = HashMap::new();
|
||||||
labels3.insert("model_id".to_string(), "unknown".to_string());
|
labels3.insert("model_id".to_string(), "unknown".to_string());
|
||||||
let worker3 = BasicWorker::new("http://worker3:8080".to_string(), WorkerType::Regular)
|
let worker3 = BasicWorkerBuilder::new("http://worker3:8080")
|
||||||
.with_labels(labels3);
|
.worker_type(WorkerType::Regular)
|
||||||
|
.labels(labels3)
|
||||||
|
.build();
|
||||||
|
|
||||||
let mut labels4 = HashMap::new();
|
let mut labels4 = HashMap::new();
|
||||||
labels4.insert("model_id".to_string(), "llama-3".to_string());
|
labels4.insert("model_id".to_string(), "llama-3".to_string());
|
||||||
let worker4 = BasicWorker::new("http://worker4:8080".to_string(), WorkerType::Regular)
|
let worker4 = BasicWorkerBuilder::new("http://worker4:8080")
|
||||||
.with_labels(labels4);
|
.worker_type(WorkerType::Regular)
|
||||||
|
.labels(labels4)
|
||||||
|
.build();
|
||||||
|
|
||||||
// Add all workers
|
// Add all workers
|
||||||
policy.add_worker(&worker1);
|
policy.add_worker(&worker1);
|
||||||
@@ -108,10 +120,14 @@ fn test_remove_worker_by_url_backward_compat() {
|
|||||||
// Create workers with different model_ids
|
// Create workers with different model_ids
|
||||||
let mut labels1 = HashMap::new();
|
let mut labels1 = HashMap::new();
|
||||||
labels1.insert("model_id".to_string(), "llama-3".to_string());
|
labels1.insert("model_id".to_string(), "llama-3".to_string());
|
||||||
let worker1 = BasicWorker::new("http://worker1:8080".to_string(), WorkerType::Regular)
|
let worker1 = BasicWorkerBuilder::new("http://worker1:8080")
|
||||||
.with_labels(labels1);
|
.worker_type(WorkerType::Regular)
|
||||||
|
.labels(labels1)
|
||||||
|
.build();
|
||||||
|
|
||||||
let worker2 = BasicWorker::new("http://worker2:8080".to_string(), WorkerType::Regular);
|
let worker2 = BasicWorkerBuilder::new("http://worker2:8080")
|
||||||
|
.worker_type(WorkerType::Regular)
|
||||||
|
.build();
|
||||||
// No model_id label - defaults to "unknown"
|
// No model_id label - defaults to "unknown"
|
||||||
|
|
||||||
// Add workers
|
// Add workers
|
||||||
|
|||||||
Reference in New Issue
Block a user