Files
sglang/sgl-router/src/core/worker_builder.rs

445 lines
14 KiB
Rust

use super::circuit_breaker::{CircuitBreaker, CircuitBreakerConfig};
use super::worker::{
BasicWorker, ConnectionMode, DPAwareWorker, HealthConfig, WorkerMetadata, WorkerType,
};
use crate::grpc_client::SglangSchedulerClient;
use std::collections::HashMap;
/// Builder for creating BasicWorker instances with fluent API
///
/// Only the URL is mandatory. Every other field starts from the value assigned in
/// `new` / `new_with_type` and can be overridden through the setter of the same name
/// before `build` is called.
pub struct BasicWorkerBuilder {
/// Worker URL; becomes `WorkerMetadata::url` in `build`.
url: String,
/// Optional API key; `None` unless `api_key` is called.
api_key: Option<String>,
/// Worker type; `WorkerType::Regular` unless set via `new_with_type` or `worker_type`.
worker_type: WorkerType,
/// Connection mode; `ConnectionMode::Http` unless `connection_mode` is called.
connection_mode: ConnectionMode,
/// Labels stored in the worker metadata; empty by default.
labels: HashMap<String, String>,
/// Health check settings; `HealthConfig::default()` unless overridden.
health_config: HealthConfig,
/// Settings handed to `CircuitBreaker::with_config` in `build`;
/// `CircuitBreakerConfig::default()` unless overridden.
circuit_breaker_config: CircuitBreakerConfig,
/// Optional gRPC client; `None` unless `grpc_client` is called.
grpc_client: Option<SglangSchedulerClient>,
}
impl BasicWorkerBuilder {
    /// Create a new builder with only the URL.
    ///
    /// The worker type defaults to `WorkerType::Regular`; all other settings take the
    /// defaults documented on [`Self::new_with_type`].
    pub fn new(url: impl Into<String>) -> Self {
        // Delegate so the default table lives in exactly one place.
        Self::new_with_type(url, WorkerType::Regular)
    }

    /// Create a new builder with URL and worker type (for backwards compatibility).
    ///
    /// Defaults: no API key, `ConnectionMode::Http`, no labels,
    /// `HealthConfig::default()`, `CircuitBreakerConfig::default()`, no gRPC client.
    pub fn new_with_type(url: impl Into<String>, worker_type: WorkerType) -> Self {
        Self {
            url: url.into(),
            api_key: None,
            worker_type,
            connection_mode: ConnectionMode::Http,
            labels: HashMap::new(),
            health_config: HealthConfig::default(),
            circuit_breaker_config: CircuitBreakerConfig::default(),
            grpc_client: None,
        }
    }

    /// Set the API key.
    pub fn api_key(mut self, api_key: impl Into<String>) -> Self {
        self.api_key = Some(api_key.into());
        self
    }

    /// Set the worker type (Regular, Prefill, or Decode).
    pub fn worker_type(mut self, worker_type: WorkerType) -> Self {
        self.worker_type = worker_type;
        self
    }

    /// Set the connection mode (HTTP or gRPC).
    pub fn connection_mode(mut self, mode: ConnectionMode) -> Self {
        self.connection_mode = mode;
        self
    }

    /// Set labels for worker identification.
    ///
    /// Replaces the whole label map, including labels added earlier with [`Self::label`].
    pub fn labels(mut self, labels: HashMap<String, String>) -> Self {
        self.labels = labels;
        self
    }

    /// Add a single label, overwriting any existing value stored under `key`.
    pub fn label(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        self.labels.insert(key.into(), value.into());
        self
    }

    /// Set health check configuration.
    pub fn health_config(mut self, config: HealthConfig) -> Self {
        self.health_config = config;
        self
    }

    /// Set circuit breaker configuration.
    pub fn circuit_breaker_config(mut self, config: CircuitBreakerConfig) -> Self {
        self.circuit_breaker_config = config;
        self
    }

    /// Set gRPC client for gRPC workers.
    pub fn grpc_client(mut self, client: SglangSchedulerClient) -> Self {
        self.grpc_client = Some(client);
        self
    }

    /// Build the BasicWorker instance.
    ///
    /// The worker starts out marked healthy, with its load, processed, failure and
    /// success counters at zero and a circuit breaker created from the configured
    /// `CircuitBreakerConfig`.
    pub fn build(self) -> BasicWorker {
        use std::sync::{
            atomic::{AtomicBool, AtomicUsize},
            Arc,
        };
        use tokio::sync::{Mutex, RwLock};

        let metadata = WorkerMetadata {
            // `self` is consumed here, so the URL is moved instead of cloned.
            url: self.url,
            api_key: self.api_key,
            worker_type: self.worker_type,
            connection_mode: self.connection_mode,
            labels: self.labels,
            health_config: self.health_config,
        };

        // The client slot always exists; it simply holds `None` when no client was supplied.
        let grpc_client = Arc::new(RwLock::new(
            self.grpc_client.map(|client| Arc::new(Mutex::new(client))),
        ));

        BasicWorker {
            metadata,
            load_counter: Arc::new(AtomicUsize::new(0)),
            processed_counter: Arc::new(AtomicUsize::new(0)),
            healthy: Arc::new(AtomicBool::new(true)),
            consecutive_failures: Arc::new(AtomicUsize::new(0)),
            consecutive_successes: Arc::new(AtomicUsize::new(0)),
            circuit_breaker: CircuitBreaker::with_config(self.circuit_breaker_config),
            grpc_client,
        }
    }
}
/// Builder for creating DPAwareWorker instances with fluent API
///
/// `base_url`, `dp_rank` and `dp_size` are mandatory. Every other field starts from
/// the value assigned in `new` / `new_with_type` and can be overridden through the
/// setter of the same name before `build` is called.
pub struct DPAwareWorkerBuilder {
/// Base URL; `build` derives the worker URL as `"{base_url}@{dp_rank}"` and also
/// passes the base URL to `DPAwareWorker::with_base_worker`.
base_url: String,
/// Optional API key; `None` unless `api_key` is called.
api_key: Option<String>,
/// Rank passed to `DPAwareWorker::with_base_worker` and appended to the worker URL.
dp_rank: usize,
/// Size passed to `DPAwareWorker::with_base_worker`.
dp_size: usize,
/// Worker type; `WorkerType::Regular` unless set via `new_with_type` or `worker_type`.
worker_type: WorkerType,
/// Connection mode; `ConnectionMode::Http` unless `connection_mode` is called.
connection_mode: ConnectionMode,
/// Labels forwarded to the underlying `BasicWorkerBuilder`; empty by default.
labels: HashMap<String, String>,
/// Health check settings; `HealthConfig::default()` unless overridden.
health_config: HealthConfig,
/// Circuit breaker settings; `CircuitBreakerConfig::default()` unless overridden.
circuit_breaker_config: CircuitBreakerConfig,
/// Optional gRPC client; `None` unless `grpc_client` is called.
grpc_client: Option<SglangSchedulerClient>,
}
impl DPAwareWorkerBuilder {
    /// Start a builder for rank `dp_rank` of `dp_size` behind `base_url`, using
    /// `WorkerType::Regular`.
    pub fn new(base_url: impl Into<String>, dp_rank: usize, dp_size: usize) -> Self {
        Self::new_with_type(base_url, dp_rank, dp_size, WorkerType::Regular)
    }

    /// Start a builder with an explicit worker type (kept for backwards compatibility).
    ///
    /// Everything else begins at its default: no API key, `ConnectionMode::Http`, no
    /// labels, default health-check and circuit-breaker configs, no gRPC client.
    pub fn new_with_type(
        base_url: impl Into<String>,
        dp_rank: usize,
        dp_size: usize,
        worker_type: WorkerType,
    ) -> Self {
        let base_url = base_url.into();
        Self {
            base_url,
            api_key: None,
            dp_rank,
            dp_size,
            worker_type,
            connection_mode: ConnectionMode::Http,
            labels: HashMap::new(),
            health_config: HealthConfig::default(),
            circuit_breaker_config: CircuitBreakerConfig::default(),
            grpc_client: None,
        }
    }

    /// Attach an API key.
    pub fn api_key(self, api_key: impl Into<String>) -> Self {
        Self {
            api_key: Some(api_key.into()),
            ..self
        }
    }

    /// Choose the worker type (Regular, Prefill, or Decode).
    pub fn worker_type(self, worker_type: WorkerType) -> Self {
        Self {
            worker_type,
            ..self
        }
    }

    /// Choose the connection mode (HTTP or gRPC).
    pub fn connection_mode(self, mode: ConnectionMode) -> Self {
        Self {
            connection_mode: mode,
            ..self
        }
    }

    /// Replace the full label map used for worker identification.
    pub fn labels(self, labels: HashMap<String, String>) -> Self {
        Self { labels, ..self }
    }

    /// Insert one label, overwriting any value already stored under `key`.
    pub fn label(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        let (name, content) = (key.into(), value.into());
        self.labels.insert(name, content);
        self
    }

    /// Replace the health check configuration.
    pub fn health_config(self, config: HealthConfig) -> Self {
        Self {
            health_config: config,
            ..self
        }
    }

    /// Replace the circuit breaker configuration.
    pub fn circuit_breaker_config(self, config: CircuitBreakerConfig) -> Self {
        Self {
            circuit_breaker_config: config,
            ..self
        }
    }

    /// Provide the gRPC client used by gRPC workers.
    pub fn grpc_client(self, client: SglangSchedulerClient) -> Self {
        Self {
            grpc_client: Some(client),
            ..self
        }
    }

    /// Assemble the DPAwareWorker.
    ///
    /// A `BasicWorker` is built first under the URL `"<base_url>@<dp_rank>"` with all
    /// configured settings, then handed to `DPAwareWorker::with_base_worker` together
    /// with the base URL, rank and size.
    pub fn build(self) -> DPAwareWorker {
        let Self {
            base_url,
            api_key,
            dp_rank,
            dp_size,
            worker_type,
            connection_mode,
            labels,
            health_config,
            circuit_breaker_config,
            grpc_client,
        } = self;

        let rank_url = format!("{}@{}", base_url, dp_rank);
        let inner = BasicWorkerBuilder::new(rank_url)
            .worker_type(worker_type)
            .connection_mode(connection_mode)
            .labels(labels)
            .health_config(health_config)
            .circuit_breaker_config(circuit_breaker_config);

        // Optional settings are forwarded only when they were supplied.
        let inner = match grpc_client {
            Some(client) => inner.grpc_client(client),
            None => inner,
        };
        let inner = match api_key {
            Some(key) => inner.api_key(key),
            None => inner,
        };

        DPAwareWorker::with_base_worker(inner.build(), base_url, dp_rank, dp_size)
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests for the worker builders: defaults, each setter, and both
    //! `new_with_type` constructors.

    use super::*;
    use crate::core::worker::Worker;
    use std::time::Duration;

    /// A builder given only a URL yields a healthy regular HTTP worker without an API key.
    #[test]
    fn test_basic_worker_builder_minimal() {
        let worker = BasicWorkerBuilder::new("http://localhost:8080").build();
        assert_eq!(worker.url(), "http://localhost:8080");
        assert_eq!(worker.worker_type(), WorkerType::Regular);
        assert_eq!(worker.connection_mode(), ConnectionMode::Http);
        assert!(worker.is_healthy());
        assert!(worker.metadata().api_key.is_none());
    }

    /// The `worker_type` setter overrides the default type.
    #[test]
    fn test_basic_worker_builder_with_type() {
        let worker = BasicWorkerBuilder::new("http://localhost:8080")
            .worker_type(WorkerType::Decode)
            .build();
        assert_eq!(worker.url(), "http://localhost:8080");
        assert_eq!(worker.worker_type(), WorkerType::Decode);
        assert_eq!(worker.connection_mode(), ConnectionMode::Http);
        assert!(worker.is_healthy());
    }

    /// `new_with_type` sets the worker type and keeps the remaining defaults.
    #[test]
    fn test_basic_worker_builder_new_with_type() {
        let worker =
            BasicWorkerBuilder::new_with_type("http://localhost:8080", WorkerType::Decode).build();
        assert_eq!(worker.url(), "http://localhost:8080");
        assert_eq!(worker.worker_type(), WorkerType::Decode);
        assert_eq!(worker.connection_mode(), ConnectionMode::Http);
        assert!(worker.metadata().labels.is_empty());
        assert!(worker.metadata().api_key.is_none());
    }

    /// The API key passed to the builder ends up in the worker metadata.
    #[test]
    fn test_basic_worker_builder_api_key() {
        let worker = BasicWorkerBuilder::new("http://localhost:8080")
            .api_key("secret")
            .build();
        assert_eq!(worker.metadata().api_key.as_deref(), Some("secret"));
    }

    /// Every setter on the basic builder is reflected in the built worker.
    #[test]
    fn test_basic_worker_builder_full() {
        let mut labels = HashMap::new();
        labels.insert("env".to_string(), "prod".to_string());
        labels.insert("region".to_string(), "us-east".to_string());
        let health_config = HealthConfig {
            endpoint: "/health".to_string(),
            timeout_secs: 30,
            check_interval_secs: 60,
            failure_threshold: 3,
            success_threshold: 2,
        };
        let cb_config = CircuitBreakerConfig {
            failure_threshold: 10,
            success_threshold: 5,
            timeout_duration: Duration::from_millis(2000),
            window_duration: Duration::from_millis(30000),
        };
        let worker = BasicWorkerBuilder::new("http://localhost:8080")
            .worker_type(WorkerType::Prefill {
                bootstrap_port: None,
            })
            .connection_mode(ConnectionMode::Grpc { port: Some(50051) })
            .labels(labels.clone())
            .health_config(health_config.clone())
            .circuit_breaker_config(cb_config)
            .build();
        assert_eq!(worker.url(), "http://localhost:8080");
        assert_eq!(
            worker.worker_type(),
            WorkerType::Prefill {
                bootstrap_port: None
            }
        );
        assert_eq!(
            worker.connection_mode(),
            ConnectionMode::Grpc { port: Some(50051) }
        );
        assert_eq!(worker.metadata().labels, labels);
        assert_eq!(
            worker.metadata().health_config.endpoint,
            health_config.endpoint
        );
        assert_eq!(
            worker.metadata().health_config.timeout_secs,
            health_config.timeout_secs
        );
        assert_eq!(
            worker.metadata().health_config.check_interval_secs,
            health_config.check_interval_secs
        );
        assert_eq!(
            worker.metadata().health_config.failure_threshold,
            health_config.failure_threshold
        );
        assert_eq!(
            worker.metadata().health_config.success_threshold,
            health_config.success_threshold
        );
    }

    /// Labels added one at a time accumulate in the metadata.
    #[test]
    fn test_basic_worker_builder_with_single_label() {
        let worker = BasicWorkerBuilder::new("http://localhost:8080")
            .worker_type(WorkerType::Decode)
            .label("env", "staging")
            .label("version", "v1.2.3")
            .build();
        assert_eq!(
            worker.metadata().labels.get("env"),
            Some(&"staging".to_string())
        );
        assert_eq!(
            worker.metadata().labels.get("version"),
            Some(&"v1.2.3".to_string())
        );
    }

    /// The DP-aware worker URL is the base URL suffixed with `@<rank>`.
    #[test]
    fn test_dp_aware_worker_builder_minimal() {
        let worker = DPAwareWorkerBuilder::new("http://localhost:8080", 2, 8).build();
        assert_eq!(worker.url(), "http://localhost:8080@2");
        assert_eq!(worker.dp_rank(), Some(2));
        assert_eq!(worker.dp_size(), Some(8));
        assert_eq!(worker.worker_type(), WorkerType::Regular);
    }

    /// `new_with_type` on the DP-aware builder sets the worker type up front.
    #[test]
    fn test_dp_aware_worker_builder_new_with_type() {
        let worker =
            DPAwareWorkerBuilder::new_with_type("http://localhost:8080", 0, 2, WorkerType::Decode)
                .build();
        assert_eq!(worker.url(), "http://localhost:8080@0");
        assert_eq!(worker.dp_rank(), Some(0));
        assert_eq!(worker.dp_size(), Some(2));
        assert_eq!(worker.worker_type(), WorkerType::Decode);
        assert_eq!(worker.connection_mode(), ConnectionMode::Http);
    }

    /// Every setter on the DP-aware builder is forwarded to the underlying worker.
    #[test]
    fn test_dp_aware_worker_builder_full() {
        let mut labels = HashMap::new();
        labels.insert("cluster".to_string(), "main".to_string());
        let health_config = HealthConfig {
            endpoint: "/status".to_string(),
            timeout_secs: 20,
            check_interval_secs: 45,
            failure_threshold: 5,
            success_threshold: 3,
        };
        let worker = DPAwareWorkerBuilder::new("http://localhost:8080", 3, 16)
            .worker_type(WorkerType::Prefill {
                bootstrap_port: Some(9090),
            })
            .connection_mode(ConnectionMode::Http)
            .labels(labels.clone())
            .health_config(health_config.clone())
            .api_key("test_api_key")
            .build();
        assert_eq!(worker.url(), "http://localhost:8080@3");
        assert_eq!(worker.dp_rank(), Some(3));
        assert_eq!(worker.dp_size(), Some(16));
        // Settings passed to the builder above must actually reach the worker.
        assert_eq!(
            worker.worker_type(),
            WorkerType::Prefill {
                bootstrap_port: Some(9090)
            }
        );
        assert_eq!(worker.connection_mode(), ConnectionMode::Http);
        assert_eq!(worker.metadata().api_key.as_deref(), Some("test_api_key"));
        assert_eq!(worker.metadata().labels, labels);
        assert_eq!(
            worker.metadata().health_config.endpoint,
            health_config.endpoint
        );
        assert_eq!(
            worker.metadata().health_config.timeout_secs,
            health_config.timeout_secs
        );
        assert_eq!(
            worker.metadata().health_config.check_interval_secs,
            health_config.check_interval_secs
        );
        assert_eq!(
            worker.metadata().health_config.failure_threshold,
            health_config.failure_threshold
        );
        assert_eq!(
            worker.metadata().health_config.success_threshold,
            health_config.success_threshold
        );
    }

    /// A gRPC connection mode and labels survive the DP-aware build.
    #[test]
    fn test_dp_aware_worker_with_grpc() {
        let worker = DPAwareWorkerBuilder::new("grpc://cluster.local", 1, 4)
            .worker_type(WorkerType::Decode)
            .connection_mode(ConnectionMode::Grpc { port: Some(50051) })
            .label("transport", "grpc")
            .build();
        assert_eq!(worker.url(), "grpc://cluster.local@1");
        assert_eq!(worker.dp_rank(), Some(1));
        assert_eq!(worker.dp_size(), Some(4));
        assert_eq!(worker.worker_type(), WorkerType::Decode);
        assert_eq!(
            worker.connection_mode(),
            ConnectionMode::Grpc { port: Some(50051) }
        );
        assert_eq!(
            worker.metadata().labels.get("transport"),
            Some(&"grpc".to_string())
        );
    }
}