[router] refactor router and worker management 3/n (#10727)

This commit is contained in:
Simo Lin
2025-09-22 15:17:50 -04:00
committed by GitHub
parent 60dbbd086a
commit 97c3823931
25 changed files with 1427 additions and 2540 deletions

View File

@@ -83,14 +83,13 @@ impl CircuitBreaker {
/// Check if a request can be executed.
///
/// Runs `check_and_update_state` first, then reads the current circuit state.
///
/// Returns `true` when the state is `Closed` or `HalfOpen`, and `false` when
/// it is `Open`.
///
/// # Panics
/// Panics if `self.state.read()` returns an error, because the result is
/// unwrapped (presumably a poisoned `std::sync::RwLock` — confirm the lock type).
pub fn can_execute(&self) -> bool {
// First check if we need to transition from Open to HalfOpen
self.check_and_update_state();
// Copy the state out so the read guard is released before matching.
let state = *self.state.read().unwrap();
match state {
CircuitState::Closed => true,
CircuitState::Open => false,
CircuitState::HalfOpen => true, // Allow limited requests in half-open state
// NOTE(review): this arm repeats the `HalfOpen` pattern directly above, so it
// can never be reached; confirm which of the two lines is meant to remain.
CircuitState::HalfOpen => true,
}
}
@@ -114,22 +113,17 @@ impl CircuitBreaker {
self.total_successes.fetch_add(1, Ordering::Relaxed);
self.consecutive_failures.store(0, Ordering::Release);
let successes = self.consecutive_successes.fetch_add(1, Ordering::AcqRel) + 1;
// Outcome-level metrics are recorded at the worker level where the worker label is known
let current_state = *self.state.read().unwrap();
match current_state {
CircuitState::HalfOpen => {
// Check if we've reached the success threshold to close the circuit
if successes >= self.config.success_threshold {
self.transition_to(CircuitState::Closed);
}
}
CircuitState::Closed => {
// Already closed, nothing to do
}
CircuitState::Closed => {}
CircuitState::Open => {
// Shouldn't happen, but if it does, stay open
tracing::warn!("Success recorded while circuit is open");
}
}
@@ -140,9 +134,7 @@ impl CircuitBreaker {
self.total_failures.fetch_add(1, Ordering::Relaxed);
self.consecutive_successes.store(0, Ordering::Release);
let failures = self.consecutive_failures.fetch_add(1, Ordering::AcqRel) + 1;
// Outcome-level metrics are recorded at the worker level where the worker label is known
// Update last failure time
{
let mut last_failure = self.last_failure_time.write().unwrap();
*last_failure = Some(Instant::now());
@@ -152,18 +144,14 @@ impl CircuitBreaker {
match current_state {
CircuitState::Closed => {
// Check if we've reached the failure threshold to open the circuit
if failures >= self.config.failure_threshold {
self.transition_to(CircuitState::Open);
}
}
CircuitState::HalfOpen => {
// Single failure in half-open state reopens the circuit
self.transition_to(CircuitState::Open);
}
CircuitState::Open => {
// Already open, nothing to do
}
CircuitState::Open => {}
}
}
@@ -172,7 +160,6 @@ impl CircuitBreaker {
let current_state = *self.state.read().unwrap();
if current_state == CircuitState::Open {
// Check if timeout has expired
let last_change = *self.last_state_change.read().unwrap();
if last_change.elapsed() >= self.config.timeout_duration {
self.transition_to(CircuitState::HalfOpen);
@@ -188,11 +175,9 @@ impl CircuitBreaker {
if old_state != new_state {
*state = new_state;
// Update last state change time
let mut last_change = self.last_state_change.write().unwrap();
*last_change = Instant::now();
// Reset counters based on transition
match new_state {
CircuitState::Closed => {
self.consecutive_failures.store(0, Ordering::Release);
@@ -218,7 +203,6 @@ impl CircuitBreaker {
CircuitState::HalfOpen => "half_open",
};
info!("Circuit breaker state transition: {} -> {}", from, to);
// Transition metrics are recorded at the worker level where the worker label is known
}
}
@@ -533,7 +517,6 @@ mod tests {
let cb = Arc::new(CircuitBreaker::new());
let mut handles = vec![];
// Spawn threads that record failures
for _ in 0..10 {
let cb_clone = Arc::clone(&cb);
let handle = thread::spawn(move || {
@@ -544,12 +527,10 @@ mod tests {
handles.push(handle);
}
// Wait for all threads
for handle in handles {
handle.join().unwrap();
}
// Should have recorded 1000 failures
assert_eq!(cb.total_failures(), 1000);
}
}

View File

@@ -122,7 +122,6 @@ mod tests {
let error = WorkerError::WorkerNotFound {
url: "http://test".to_string(),
};
// Verify it implements Error trait
let _: &dyn Error = &error;
assert!(error.source().is_none());
}
@@ -135,11 +134,9 @@ mod tests {
#[test]
fn test_worker_result_type_alias() {
// Test Ok variant
let result: WorkerResult<i32> = Ok(42);
assert!(matches!(result, Ok(42)));
// Test Err variant
let error = WorkerError::WorkerNotFound {
url: "test".to_string(),
};
@@ -149,7 +146,6 @@ mod tests {
#[test]
fn test_empty_url_handling() {
// Test empty URLs in error variants
let error1 = WorkerError::HealthCheckFailed {
url: "".to_string(),
reason: "No connection".to_string(),
@@ -173,7 +169,6 @@ mod tests {
#[test]
fn test_special_characters_in_messages() {
// Test with special characters
let error = WorkerError::InvalidConfiguration {
message: "Invalid JSON: {\"error\": \"test\"}".to_string(),
};
@@ -182,7 +177,6 @@ mod tests {
"Invalid worker configuration: Invalid JSON: {\"error\": \"test\"}"
);
// Test with unicode
let error2 = WorkerError::HealthCheckFailed {
url: "http://测试:8080".to_string(),
reason: "连接被拒绝".to_string(),
@@ -207,10 +201,8 @@ mod tests {
);
}
// Mock reqwest error for testing conversion
#[test]
fn test_reqwest_error_conversion() {
// Test that NetworkError is the correct variant
let network_error = WorkerError::NetworkError {
url: "http://example.com".to_string(),
error: "connection timeout".to_string(),
@@ -227,8 +219,6 @@ mod tests {
#[test]
fn test_error_equality() {
// WorkerError doesn't implement PartialEq, but we can test that
// the same error construction produces the same display output
let error1 = WorkerError::WorkerNotFound {
url: "http://test".to_string(),
};

View File

@@ -12,9 +12,9 @@ pub mod retry;
pub mod token_bucket;
pub mod worker;
pub mod worker_builder;
pub mod worker_manager;
pub mod worker_registry;
// Re-export commonly used types at the module level
pub use circuit_breaker::{
CircuitBreaker, CircuitBreakerConfig, CircuitBreakerStats, CircuitState,
};
@@ -25,4 +25,5 @@ pub use worker::{
Worker, WorkerFactory, WorkerLoadGuard, WorkerType,
};
pub use worker_builder::{BasicWorkerBuilder, DPAwareWorkerBuilder};
pub use worker_manager::{DpInfo, ServerInfo, WorkerManager};
pub use worker_registry::{WorkerId, WorkerRegistry, WorkerRegistryStats};

View File

@@ -25,14 +25,12 @@ pub struct BackoffCalculator;
impl BackoffCalculator {
/// Calculate backoff delay for a given attempt index (0-based).
pub fn calculate_delay(config: &RetryConfig, attempt: u32) -> Duration {
// Base exponential backoff
let pow = config.backoff_multiplier.powi(attempt as i32);
let mut delay_ms = (config.initial_backoff_ms as f32 * pow) as u64;
if delay_ms > config.max_backoff_ms {
delay_ms = config.max_backoff_ms;
}
// Apply jitter in range [-j, +j]
let jitter = config.jitter_factor.clamp(0.0, 1.0);
if jitter > 0.0 {
let mut rng = rand::rng();
@@ -77,14 +75,12 @@ impl RetryExecutor {
match operation(attempt).await {
Ok(val) => return Ok(val),
Err(_) => {
// Use the number of failures so far (0-indexed) to compute delay,
// so the first retry uses `initial_backoff_ms`.
let is_last = attempt + 1 >= max;
if is_last {
return Err(RetryError::MaxRetriesExceeded);
}
let delay = BackoffCalculator::calculate_delay(config, attempt);
attempt += 1; // advance to the next attempt after computing delay
attempt += 1;
tokio::time::sleep(delay).await;
}
}
@@ -144,14 +140,11 @@ impl RetryExecutor {
}
if is_last {
// Exhausted retries
on_exhausted();
return response;
}
// Backoff before next attempt
let next_attempt = attempt + 1;
// Compute delay based on the number of failures so far (0-indexed)
let delay = BackoffCalculator::calculate_delay(config, attempt);
debug!(
attempt = attempt,
@@ -194,22 +187,18 @@ mod tests {
backoff_multiplier: 2.0,
jitter_factor: 0.0,
};
// attempt=0 => 100ms
assert_eq!(
BackoffCalculator::calculate_delay(&cfg, 0),
Duration::from_millis(100)
);
// attempt=1 => 200ms
assert_eq!(
BackoffCalculator::calculate_delay(&cfg, 1),
Duration::from_millis(200)
);
// attempt=2 => 400ms -> capped to 250ms
assert_eq!(
BackoffCalculator::calculate_delay(&cfg, 2),
Duration::from_millis(250)
);
// large attempt still capped
assert_eq!(
BackoffCalculator::calculate_delay(&cfg, 10),
Duration::from_millis(250)
@@ -225,7 +214,6 @@ mod tests {
backoff_multiplier: 2.0,
jitter_factor: 0.5,
};
// attempt=2 => base 400ms, jitter in [0.5x, 1.5x]
let base = 400.0;
for _ in 0..50 {
let d = BackoffCalculator::calculate_delay(&cfg, 2).as_millis() as f32;
@@ -261,7 +249,7 @@ mod tests {
assert!(res.is_ok());
assert_eq!(res.unwrap(), 42);
assert_eq!(calls.load(Ordering::Relaxed), 3); // 2 fails + 1 success
assert_eq!(calls.load(Ordering::Relaxed), 3);
}
#[tokio::test]
@@ -309,7 +297,7 @@ mod tests {
}
}
},
|res, _attempt| !res.status().is_success(), // retry until success
|res, _attempt| !res.status().is_success(),
{
let backoffs = backoffs.clone();
move |_delay, _next_attempt| {
@@ -326,7 +314,7 @@ mod tests {
.await;
assert_eq!(response.status(), StatusCode::OK);
assert_eq!(calls.load(Ordering::Relaxed), 3); // 2 fails + 1 success
assert_eq!(calls.load(Ordering::Relaxed), 3);
assert_eq!(backoffs.load(Ordering::Relaxed), 2);
assert_eq!(exhausted.load(Ordering::Relaxed), 0);
}
@@ -347,7 +335,7 @@ mod tests {
async move { (StatusCode::BAD_REQUEST, "bad").into_response() }
}
},
|_res, _attempt| false, // never retry
|_res, _attempt| false,
{
let backoffs = backoffs.clone();
move |_delay, _next_attempt| {
@@ -385,7 +373,7 @@ mod tests {
async move { (StatusCode::SERVICE_UNAVAILABLE, "fail").into_response() }
}
},
|_res, _attempt| true, // keep retrying
|_res, _attempt| true,
{
let backoffs = backoffs.clone();
move |_delay, _next_attempt| {

View File

@@ -32,16 +32,11 @@ impl TokenBucket {
let capacity = capacity as f64;
let refill_rate = refill_rate as f64;
// Ensure refill_rate is not zero to prevent division by zero
let refill_rate = if refill_rate > 0.0 {
refill_rate
} else {
1.0 // Default to 1 token per second if zero
};
let refill_rate = if refill_rate > 0.0 { refill_rate } else { 1.0 };
Self {
inner: Arc::new(Mutex::new(TokenBucketInner {
tokens: capacity, // Start full
tokens: capacity,
last_refill: Instant::now(),
})),
notify: Arc::new(Notify::new()),
@@ -54,7 +49,6 @@ impl TokenBucket {
pub async fn try_acquire(&self, tokens: f64) -> Result<(), ()> {
let mut inner = self.inner.lock().await;
// Refill tokens based on elapsed time
let now = Instant::now();
let elapsed = now.duration_since(inner.last_refill).as_secs_f64();
let refill_amount = elapsed * self.refill_rate;
@@ -82,12 +76,10 @@ impl TokenBucket {
/// Acquire tokens, waiting if necessary
pub async fn acquire(&self, tokens: f64) -> Result<(), tokio::time::error::Elapsed> {
// First try to acquire immediately
if self.try_acquire(tokens).await.is_ok() {
return Ok(());
}
// Calculate wait time
let wait_time = {
let inner = self.inner.lock().await;
let tokens_needed = tokens - inner.tokens;
@@ -100,15 +92,12 @@ impl TokenBucket {
wait_time, tokens
);
// Wait for tokens to be available
tokio::time::timeout(wait_time, async {
loop {
// Check if we can acquire now
if self.try_acquire(tokens).await.is_ok() {
return;
}
// Wait for notification or small interval
tokio::select! {
_ = self.notify.notified() => {},
_ = tokio::time::sleep(Duration::from_millis(10)) => {},
@@ -144,7 +133,6 @@ impl TokenBucket {
pub async fn available_tokens(&self) -> f64 {
let mut inner = self.inner.lock().await;
// Refill before checking
let now = Instant::now();
let elapsed = now.duration_since(inner.last_refill).as_secs_f64();
let refill_amount = elapsed * self.refill_rate;
@@ -162,33 +150,26 @@ mod tests {
/// Exercises the basic acquire / exhaust / refill cycle of `TokenBucket`:
/// two 5-token acquisitions succeed, a third acquisition fails, and after a
/// 300 ms sleep a 1-token acquisition succeeds again.
#[tokio::test]
async fn test_token_bucket_basic() {
let bucket = TokenBucket::new(10, 5); // 10 capacity, 5 per second
// NOTE(review): `bucket` is bound a second time with identical arguments; this
// binding shadows the one above, which is then never used — confirm one can go.
let bucket = TokenBucket::new(10, 5);
// Should succeed - bucket starts full
assert!(bucket.try_acquire(5.0).await.is_ok());
assert!(bucket.try_acquire(5.0).await.is_ok());
// Should fail - no tokens left
assert!(bucket.try_acquire(1.0).await.is_err());
// Wait for refill
tokio::time::sleep(Duration::from_millis(300)).await;
// Should have ~1.5 tokens now
assert!(bucket.try_acquire(1.0).await.is_ok());
}
/// Drains a `TokenBucket::new(10, 10)` completely, sleeps 500 ms, and checks
/// that `available_tokens()` then reports a value within `4.0..=6.0`.
#[tokio::test]
async fn test_token_bucket_refill() {
let bucket = TokenBucket::new(10, 10); // 10 capacity, 10 per second
// NOTE(review): `bucket` is bound a second time with identical arguments; this
// binding shadows the one above, which is then never used — confirm one can go.
let bucket = TokenBucket::new(10, 10);
// Use all tokens
assert!(bucket.try_acquire(10.0).await.is_ok());
// Wait for partial refill
tokio::time::sleep(Duration::from_millis(500)).await;
// Should have ~5 tokens
let available = bucket.available_tokens().await;
// The range is a tolerance band around the expected ~5 tokens, since the
// result depends on real elapsed time.
assert!((4.0..=6.0).contains(&available));
}

View File

@@ -11,10 +11,9 @@ use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::sync::{Arc, LazyLock};
use tokio::sync::Mutex;
// Shared HTTP client for worker operations (health checks, server info, etc.)
//
// Wrapped in `LazyLock`, so the closure that builds the `reqwest::Client` runs
// on first access. The builder sets a 30-second timeout; individual requests
// may still set their own timeout on the request builder.
//
// Panics with "Failed to create worker HTTP client" if `build()` returns an error.
static WORKER_CLIENT: LazyLock<reqwest::Client> = LazyLock::new(|| {
reqwest::Client::builder()
.timeout(std::time::Duration::from_secs(30)) // Default timeout, overridden per request
// NOTE(review): `.timeout(...)` is applied a second time with the same 30 s
// value; one of the two calls looks redundant — confirm which should remain.
.timeout(std::time::Duration::from_secs(30))
.build()
.expect("Failed to create worker HTTP client")
});
@@ -43,7 +42,6 @@ pub trait Worker: Send + Sync + fmt::Debug {
/// Synchronous health check wrapper (for compatibility)
fn check_health(&self) -> WorkerResult<()> {
// Use a small runtime for synchronous contexts
tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
@@ -64,10 +62,7 @@ pub trait Worker: Send + Sync + fmt::Debug {
fn decrement_load(&self);
/// Reset the load counter to 0 (for sync/recovery)
fn reset_load(&self) {
// Default implementation - does nothing
// Workers that track load should override this
}
fn reset_load(&self) {}
/// Get the number of processed requests
fn processed_requests(&self) -> usize;
@@ -88,11 +83,9 @@ pub trait Worker: Send + Sync + fmt::Debug {
/// Record the outcome of a request to this worker
fn record_outcome(&self, success: bool) {
// Record outcome-level metric with worker label
let outcome_str = if success { "success" } else { "failure" };
RouterMetrics::record_cb_outcome(self.url(), outcome_str);
// Record into circuit breaker and infer state change for metrics
let before = self.circuit_breaker().state();
self.circuit_breaker().record_outcome(success);
let after = self.circuit_breaker().state();
@@ -119,8 +112,6 @@ pub trait Worker: Send + Sync + fmt::Debug {
RouterMetrics::set_cb_state(self.url(), state_code);
}
// === DP-aware methods ===
/// Check if this worker is DP-aware
fn is_dp_aware(&self) -> bool {
false
@@ -156,8 +147,6 @@ pub trait Worker: Send + Sync + fmt::Debug {
true
}
// === Multi-router support ===
// TODO: - Enhanced Worker Discovery
// The Worker trait should handle async discovery of metadata from the worker itself
// rather than having service discovery or other components query /get_server_info.
@@ -356,14 +345,12 @@ impl fmt::Debug for BasicWorker {
impl BasicWorker {
pub fn normalised_url(&self) -> WorkerResult<&str> {
if self.url().contains("@") {
// Need to extract the URL from "http://host:port@dp_rank"
let parts: Vec<&str> = self.url().split('@').collect();
if parts.len() != 2 {
return Err(WorkerError::InvalidUrl {
url: self.url().to_string(),
});
}
// Ensure the second part (the dp_rank) can be parsed as an integer
match parts[1].parse::<usize>() {
Ok(_) => Ok(parts[0]),
Err(_) => Err(WorkerError::InvalidUrl {
@@ -408,19 +395,22 @@ impl Worker for BasicWorker {
let health_result = match &self.metadata.connection_mode {
ConnectionMode::Http => {
// Perform HTTP health check
let url = self.normalised_url()?;
let health_url = format!("{}{}", url, self.metadata.health_config.endpoint);
let timeout = Duration::from_secs(self.metadata.health_config.timeout_secs);
// Use the shared client with a custom timeout for this request
match WORKER_CLIENT.get(&health_url).timeout(timeout).send().await {
let mut request = WORKER_CLIENT.get(&health_url).timeout(timeout);
if let Some(ref api_key) = self.metadata.api_key {
request = request.header("Authorization", format!("Bearer {}", api_key));
}
match request.send().await {
Ok(response) => response.status().is_success(),
Err(_) => false,
}
}
ConnectionMode::Grpc { .. } => {
// Perform gRPC health check
if let Some(grpc_client) = &self.grpc_client {
let mut client = grpc_client.lock().await;
match client.health_check().await {
@@ -449,11 +439,9 @@ impl Worker for BasicWorker {
};
if health_result {
// Health check succeeded
self.consecutive_failures.store(0, Ordering::Release);
let successes = self.consecutive_successes.fetch_add(1, Ordering::AcqRel) + 1;
// Mark healthy if we've reached the success threshold
if !self.is_healthy()
&& successes >= self.metadata.health_config.success_threshold as usize
{
@@ -462,11 +450,9 @@ impl Worker for BasicWorker {
}
Ok(())
} else {
// Health check failed
self.consecutive_successes.store(0, Ordering::Release);
let failures = self.consecutive_failures.fetch_add(1, Ordering::AcqRel) + 1;
// Mark unhealthy if we've reached the failure threshold
if self.is_healthy()
&& failures >= self.metadata.health_config.failure_threshold as usize
{
@@ -576,7 +562,6 @@ impl Worker for DPAwareWorker {
}
/// Runs the asynchronous health check for this worker.
///
/// No DP-specific logic is applied here: the call is forwarded to
/// `self.base_worker` and its `WorkerResult<()>` is returned unchanged.
async fn check_health_async(&self) -> WorkerResult<()> {
// Delegate to the base worker's health check logic
self.base_worker.check_health_async().await
}
@@ -612,8 +597,6 @@ impl Worker for DPAwareWorker {
self.base_worker.circuit_breaker()
}
// DP-aware specific implementations
/// Reports that this worker is DP-aware; always returns `true`.
fn is_dp_aware(&self) -> bool {
true
}
@@ -631,7 +614,6 @@ impl Worker for DPAwareWorker {
}
async fn prepare_request(&self, mut req: serde_json::Value) -> WorkerResult<serde_json::Value> {
// Inject data_parallel_rank into the request
if let Some(map) = req.as_object_mut() {
map.insert(
"data_parallel_rank".to_string(),
@@ -646,7 +628,6 @@ impl Worker for DPAwareWorker {
}
/// Builds the URL for `route` by appending it directly to `self.base_url`.
///
/// No separator is inserted between the two parts, so `route` presumably
/// carries its own leading `/` — confirm against callers.
fn endpoint_url(&self, route: &str) -> String {
// Use base URL for actual requests
format!("{}{}", self.base_url, route)
}
}
@@ -670,53 +651,52 @@ impl WorkerFactory {
}
Box::new(builder.build())
}
#[allow(dead_code)]
/// Get DP size from a worker
async fn get_worker_dp_size(url: &str, api_key: &Option<String>) -> WorkerResult<usize> {
let mut req_builder = WORKER_CLIENT.get(format!("{}/get_server_info", url));
if let Some(key) = &api_key {
req_builder = req_builder.bearer_auth(key);
}
/// Static health validation before creating a worker
/// This replaces wait_for_worker_health in handlers
pub async fn validate_health(url: &str, timeout_secs: u64) -> WorkerResult<()> {
use std::time::Instant;
let response = req_builder
.send()
.await
.map_err(|e| WorkerError::NetworkError {
url: url.to_string(),
error: e.to_string(),
})?;
let start_time = Instant::now();
let timeout = std::time::Duration::from_secs(timeout_secs);
if !response.status().is_success() {
return Err(WorkerError::NetworkError {
url: url.to_string(),
error: format!("Server returned: {}", response.status()),
});
}
let info: serde_json::Value =
response
.json()
.await
.map_err(|e| WorkerError::NetworkError {
loop {
if start_time.elapsed() > timeout {
return Err(WorkerError::HealthCheckFailed {
url: url.to_string(),
error: format!("Failed to parse JSON: {}", e),
})?;
reason: format!(
"Timeout {}s waiting for worker to become healthy",
timeout_secs
),
});
}
let dp_size = info
.get("dp_size")
.and_then(|v| v.as_u64())
.ok_or_else(|| WorkerError::InvalidConfiguration {
message: "dp_size not found in server info".to_string(),
})?;
// Note: This static function doesn't have access to worker's API key
// API key authentication is handled in the worker instance's check_health_async method
match WORKER_CLIENT
.get(format!("{}/health", url))
.timeout(std::time::Duration::from_secs(5))
.send()
.await
{
Ok(res) if res.status().is_success() => {
tracing::info!("Worker {} is healthy", url);
return Ok(());
}
Ok(res) => {
tracing::warn!(
"Worker {} health check failed with status: {}",
url,
res.status()
);
}
Err(e) => {
tracing::warn!("Failed to contact worker {}: {}", url, e);
}
}
if dp_size > usize::MAX as u64 {
return Err(WorkerError::InvalidConfiguration {
message: format!("dp_size is too large: {}", dp_size),
});
tokio::time::sleep(std::time::Duration::from_secs(1)).await;
}
Ok(dp_size as usize)
}
}
@@ -893,7 +873,6 @@ mod tests {
use std::thread;
use std::time::Duration;
// Test WorkerType
#[test]
fn test_worker_type_display() {
assert_eq!(WorkerType::Regular.to_string(), "Regular");
@@ -945,7 +924,6 @@ mod tests {
assert_eq!(original, cloned);
}
// Test HealthConfig
#[test]
fn test_health_config_default() {
let config = HealthConfig::default();
@@ -972,13 +950,11 @@ mod tests {
assert_eq!(config.success_threshold, 3);
}
// Test BasicWorker
#[test]
fn test_basic_worker_creation() {
use crate::core::BasicWorkerBuilder;
let worker = BasicWorkerBuilder::new("http://test:8080")
.worker_type(WorkerType::Regular)
.api_key("test_api_key")
.build();
assert_eq!(worker.url(), "http://test:8080");
assert_eq!(worker.worker_type(), WorkerType::Regular);
@@ -1016,7 +992,6 @@ mod tests {
let worker = BasicWorkerBuilder::new("http://test:8080")
.worker_type(WorkerType::Regular)
.health_config(custom_config.clone())
.api_key("test_api_key")
.build();
assert_eq!(worker.metadata().health_config.timeout_secs, 15);
@@ -1024,13 +999,11 @@ mod tests {
assert_eq!(worker.metadata().health_config.endpoint, "/custom-health");
}
// Test Worker trait implementation
/// Checks that `url()` returns exactly the string the builder was created with.
#[test]
fn test_worker_url() {
use crate::core::BasicWorkerBuilder;
let worker = BasicWorkerBuilder::new("http://worker1:8080")
.worker_type(WorkerType::Regular)
// The API key is set here but is not inspected by this test.
.api_key("test_api_key")
.build();
assert_eq!(worker.url(), "http://worker1:8080");
}
@@ -1040,7 +1013,6 @@ mod tests {
use crate::core::BasicWorkerBuilder;
let regular = BasicWorkerBuilder::new("http://test:8080")
.worker_type(WorkerType::Regular)
.api_key("test_api_key")
.build();
assert_eq!(regular.worker_type(), WorkerType::Regular);
@@ -1048,7 +1020,6 @@ mod tests {
.worker_type(WorkerType::Prefill {
bootstrap_port: Some(9090),
})
.api_key("test_api_key")
.build();
assert_eq!(
prefill.worker_type(),
@@ -1059,7 +1030,6 @@ mod tests {
let decode = BasicWorkerBuilder::new("http://test:8080")
.worker_type(WorkerType::Decode)
.api_key("test_api_key")
.build();
assert_eq!(decode.worker_type(), WorkerType::Decode);
}
@@ -1071,14 +1041,11 @@ mod tests {
.worker_type(WorkerType::Regular)
.build();
// Initial state is healthy
assert!(worker.is_healthy());
// Set unhealthy
worker.set_healthy(false);
assert!(!worker.is_healthy());
// Set healthy again
worker.set_healthy(true);
assert!(worker.is_healthy());
}
@@ -1088,31 +1055,24 @@ mod tests {
use crate::core::BasicWorkerBuilder;
let worker = BasicWorkerBuilder::new("http://test:8080")
.worker_type(WorkerType::Regular)
.api_key("test_api_key")
.build();
// Initial load is 0
assert_eq!(worker.load(), 0);
// Increment once
worker.increment_load();
assert_eq!(worker.load(), 1);
// Increment twice more
worker.increment_load();
worker.increment_load();
assert_eq!(worker.load(), 3);
// Decrement once
worker.decrement_load();
assert_eq!(worker.load(), 2);
// Decrement to 0
worker.decrement_load();
worker.decrement_load();
assert_eq!(worker.load(), 0);
// Decrement below 0 should stay at 0
worker.decrement_load();
assert_eq!(worker.load(), 0);
}
@@ -1124,17 +1084,14 @@ mod tests {
.worker_type(WorkerType::Regular)
.build();
// Initial count is 0
assert_eq!(worker.processed_requests(), 0);
// Increment multiple times
for i in 1..=100 {
worker.increment_processed();
assert_eq!(worker.processed_requests(), i);
}
}
// Test concurrent operations
#[tokio::test]
async fn test_concurrent_load_increments() {
use crate::core::BasicWorkerBuilder;
@@ -1146,7 +1103,6 @@ mod tests {
let mut handles = vec![];
// Spawn 100 tasks incrementing load
for _ in 0..100 {
let worker_clone = Arc::clone(&worker);
let handle = tokio::spawn(async move {
@@ -1155,12 +1111,10 @@ mod tests {
handles.push(handle);
}
// Wait for all tasks
for handle in handles {
handle.await.unwrap();
}
// Final count should be 100
assert_eq!(worker.load(), 100);
}
@@ -1173,7 +1127,6 @@ mod tests {
.build(),
);
// Set initial load to 100
for _ in 0..100 {
worker.increment_load();
}
@@ -1181,7 +1134,6 @@ mod tests {
let mut handles = vec![];
// Spawn 100 tasks decrementing load
for _ in 0..100 {
let worker_clone = Arc::clone(&worker);
let handle = tokio::spawn(async move {
@@ -1190,12 +1142,10 @@ mod tests {
handles.push(handle);
}
// Wait for all tasks
for handle in handles {
handle.await.unwrap();
}
// Final count should be 0
assert_eq!(worker.load(), 0);
}
@@ -1210,7 +1160,6 @@ mod tests {
let mut handles = vec![];
// Spawn threads randomly setting health status
for i in 0..100 {
let worker_clone = Arc::clone(&worker);
let handle = tokio::spawn(async move {
@@ -1220,13 +1169,11 @@ mod tests {
handles.push(handle);
}
// Wait for all tasks
for handle in handles {
handle.await.unwrap();
}
}
// Test WorkerFactory
#[test]
fn test_create_regular_worker() {
let worker: Box<dyn Worker> = Box::new(
@@ -1240,7 +1187,6 @@ mod tests {
#[test]
fn test_create_prefill_worker() {
// With bootstrap port
let worker1: Box<dyn Worker> = Box::new(
BasicWorkerBuilder::new("http://prefill:8080")
.worker_type(WorkerType::Prefill {
@@ -1256,7 +1202,6 @@ mod tests {
}
);
// Without bootstrap port
let worker2: Box<dyn Worker> = Box::new(
BasicWorkerBuilder::new("http://prefill:8080")
.worker_type(WorkerType::Prefill {
@@ -1283,7 +1228,6 @@ mod tests {
assert_eq!(worker.worker_type(), WorkerType::Decode);
}
// Test WorkerLoadGuard
#[test]
fn test_load_guard_single_worker() {
use crate::core::BasicWorkerBuilder;
@@ -1297,7 +1241,6 @@ mod tests {
assert_eq!(worker.load(), 1);
}
// Guard dropped, load decremented
assert_eq!(worker.load(), 0);
}
@@ -1325,13 +1268,11 @@ mod tests {
{
let _guard = WorkerLoadGuard::new_multi(worker_refs);
// All loads incremented
assert_eq!(workers[0].load(), 1);
assert_eq!(workers[1].load(), 1);
assert_eq!(workers[2].load(), 1);
}
// All loads decremented
assert_eq!(workers[0].load(), 0);
assert_eq!(workers[1].load(), 0);
assert_eq!(workers[2].load(), 0);
@@ -1347,29 +1288,21 @@ mod tests {
);
assert_eq!(worker.load(), 0);
// Clone for use inside catch_unwind
let worker_clone = Arc::clone(&worker);
// Use AssertUnwindSafe wrapper for the test
// This is safe because we're only testing the load counter behavior,
// not the grpc_client which is None for HTTP workers
use std::panic::AssertUnwindSafe;
// This will panic, but the guard should still clean up
let result = std::panic::catch_unwind(AssertUnwindSafe(|| {
let _guard = WorkerLoadGuard::new(worker_clone.as_ref());
assert_eq!(worker_clone.load(), 1);
panic!("Test panic");
}));
// Verify panic occurred
assert!(result.is_err());
// Load should be decremented even after panic
assert_eq!(worker.load(), 0);
}
// Test helper functions
#[test]
fn test_urls_to_workers() {
let urls = vec!["http://w1:8080".to_string(), "http://w2:8080".to_string()];
@@ -1400,23 +1333,17 @@ mod tests {
assert_eq!(urls, vec!["http://w1:8080", "http://w2:8080"]);
}
// Test synchronous health check wrapper
#[test]
fn test_check_health_sync_wrapper() {
// We can't easily test the actual HTTP call without mocking,
// but we can verify the sync wrapper works
use crate::core::BasicWorkerBuilder;
let worker = BasicWorkerBuilder::new("http://test:8080")
.worker_type(WorkerType::Regular)
.build();
// This will fail because there's no server at this URL,
// but it tests that the sync wrapper doesn't panic
let result = worker.check_health();
assert!(result.is_err());
}
// Performance test for load counter
#[test]
fn test_load_counter_performance() {
use crate::core::BasicWorkerBuilder;
@@ -1436,12 +1363,9 @@ mod tests {
let ops_per_sec = iterations as f64 / duration.as_secs_f64();
println!("Load counter operations per second: {:.0}", ops_per_sec);
// Should be well over 1M ops/sec
assert!(ops_per_sec > 1_000_000.0);
}
// ===== Tests for DPAwareWorker =====
#[test]
fn test_dp_aware_worker_creation() {
let dp_worker = DPAwareWorkerBuilder::new("http://worker1:8080", 2, 4)
@@ -1562,8 +1486,6 @@ mod tests {
assert_eq!(dp_worker.processed_requests(), 1);
}
// ===== Tests for WorkerFactory async methods =====
#[tokio::test]
async fn test_factory_create_dp_aware() {
let worker = WorkerFactory::create_dp_aware(
@@ -1610,26 +1532,21 @@ mod tests {
.worker_type(WorkerType::Regular)
.build();
// Initial state should be available
assert!(worker.is_available());
assert_eq!(worker.circuit_breaker().state(), CircuitState::Closed);
// Record some failures
worker.record_outcome(false);
worker.record_outcome(false);
// Still available (default threshold is 5)
assert!(worker.is_available());
// Record more failures to open circuit
worker.record_outcome(false);
worker.record_outcome(false);
worker.record_outcome(false);
// Circuit should be open, worker not available
assert!(!worker.is_available());
assert!(worker.is_healthy()); // Still healthy
assert!(!worker.circuit_breaker().can_execute()); // But circuit is open
assert!(worker.is_healthy());
assert!(!worker.circuit_breaker().can_execute());
}
#[test]
@@ -1647,20 +1564,16 @@ mod tests {
.circuit_breaker_config(config)
.build();
// Should open after 2 failures
worker.record_outcome(false);
assert!(worker.is_available());
worker.record_outcome(false);
assert!(!worker.is_available());
// Wait for timeout
thread::sleep(Duration::from_millis(150));
// Should be half-open
assert!(worker.is_available());
assert_eq!(worker.circuit_breaker().state(), CircuitState::HalfOpen);
// Success should close it
worker.record_outcome(true);
assert_eq!(worker.circuit_breaker().state(), CircuitState::Closed);
}
@@ -1671,24 +1584,18 @@ mod tests {
.worker_type(WorkerType::Regular)
.build();
// Should have circuit breaker
assert!(dp_worker.is_available());
// Record failures
for _ in 0..5 {
dp_worker.record_outcome(false);
}
// Should not be available
assert!(!dp_worker.is_available());
assert_eq!(dp_worker.circuit_breaker().state(), CircuitState::Open);
}
// ===== Integration tests =====
#[tokio::test]
async fn test_mixed_worker_types() {
// Create a mix of worker types
let regular: Box<dyn Worker> = Box::new(
BasicWorkerBuilder::new("http://regular:8080")
.worker_type(WorkerType::Regular)
@@ -1739,22 +1646,19 @@ mod tests {
dp_aware_decode,
];
// Test that they all implement Worker trait properly
for worker in &workers {
assert!(worker.is_healthy());
assert_eq!(worker.load(), 0);
assert_eq!(worker.processed_requests(), 0);
}
// Test specific behaviors
assert!(!workers[0].is_dp_aware()); // regular
assert!(!workers[1].is_dp_aware()); // prefill
assert!(!workers[2].is_dp_aware()); // decode
assert!(workers[3].is_dp_aware()); // dp_aware_regular
assert!(workers[4].is_dp_aware()); // dp_aware_prefill
assert!(workers[5].is_dp_aware()); // dp_aware_decode
assert!(!workers[0].is_dp_aware());
assert!(!workers[1].is_dp_aware());
assert!(!workers[2].is_dp_aware());
assert!(workers[3].is_dp_aware());
assert!(workers[4].is_dp_aware());
assert!(workers[5].is_dp_aware());
// Test worker types
assert_eq!(workers[0].worker_type(), WorkerType::Regular);
assert_eq!(
workers[1].worker_type(),

View File

@@ -7,10 +7,7 @@ use std::collections::HashMap;
/// Builder for creating BasicWorker instances with fluent API
pub struct BasicWorkerBuilder {
// Required fields
url: String,
// Optional fields with defaults
api_key: Option<String>,
worker_type: WorkerType,
connection_mode: ConnectionMode,
@@ -21,7 +18,7 @@ pub struct BasicWorkerBuilder {
}
impl BasicWorkerBuilder {
/// Create a new builder with only the URL (defaults to Regular worker type)
/// Create a new builder with only the URL
pub fn new(url: impl Into<String>) -> Self {
Self {
url: url.into(),
@@ -129,13 +126,10 @@ impl BasicWorkerBuilder {
/// Builder for creating DPAwareWorker instances with fluent API
pub struct DPAwareWorkerBuilder {
// Required fields
base_url: String,
api_key: Option<String>,
dp_rank: usize,
dp_size: usize,
// Optional fields with defaults
worker_type: WorkerType,
connection_mode: ConnectionMode,
labels: HashMap<String, String>,
@@ -145,7 +139,7 @@ pub struct DPAwareWorkerBuilder {
}
impl DPAwareWorkerBuilder {
/// Create a new DP-aware worker builder (defaults to Regular worker type)
/// Create a new DP-aware worker builder
pub fn new(base_url: impl Into<String>, dp_rank: usize, dp_size: usize) -> Self {
Self {
base_url: base_url.into(),
@@ -232,10 +226,7 @@ impl DPAwareWorkerBuilder {
/// Build the DPAwareWorker instance
pub fn build(self) -> DPAwareWorker {
// Create URL with DP rank suffix for identification
let worker_url = format!("{}@{}", self.base_url, self.dp_rank);
// Use BasicWorkerBuilder to create a properly configured base worker
let mut builder = BasicWorkerBuilder::new(worker_url)
.worker_type(self.worker_type)
.connection_mode(self.connection_mode)
@@ -243,18 +234,14 @@ impl DPAwareWorkerBuilder {
.health_config(self.health_config)
.circuit_breaker_config(self.circuit_breaker_config);
// Add gRPC client if provided
if let Some(client) = self.grpc_client {
builder = builder.grpc_client(client);
}
// Add API key if provided
if let Some(api_key) = self.api_key {
builder = builder.api_key(api_key);
}
let base_worker = builder.build();
// Create the DPAwareWorker with the configured base worker
DPAwareWorker::with_base_worker(base_worker, self.base_url, self.dp_rank, self.dp_size)
}
}
@@ -267,7 +254,6 @@ mod tests {
#[test]
fn test_basic_worker_builder_minimal() {
// Using new API - defaults to Regular type
let worker = BasicWorkerBuilder::new("http://localhost:8080").build();
assert_eq!(worker.url(), "http://localhost:8080");
@@ -278,7 +264,6 @@ mod tests {
#[test]
fn test_basic_worker_builder_with_type() {
// Test setting worker type explicitly
let worker = BasicWorkerBuilder::new("http://localhost:8080")
.worker_type(WorkerType::Decode)
.build();
@@ -332,7 +317,6 @@ mod tests {
ConnectionMode::Grpc { port: Some(50051) }
);
assert_eq!(worker.metadata().labels, labels);
// Can't directly compare HealthConfig without PartialEq, so check individual fields
assert_eq!(
worker.metadata().health_config.endpoint,
health_config.endpoint
@@ -375,13 +359,11 @@ mod tests {
/// Builds a `DPAwareWorker` from only the three constructor arguments
/// (base URL, rank 2, size 8) and checks the resulting URL, rank, size and
/// worker type.
#[test]
fn test_dp_aware_worker_builder_minimal() {
// Using new API - defaults to Regular type
let worker = DPAwareWorkerBuilder::new("http://localhost:8080", 2, 8).build();
// The worker URL is the base URL followed by `@` and the DP rank.
assert_eq!(worker.url(), "http://localhost:8080@2");
assert_eq!(worker.dp_rank(), Some(2));
assert_eq!(worker.dp_size(), Some(8));
// Note: base_url is a private field, we can only test through the url() method
assert_eq!(worker.worker_type(), WorkerType::Regular);
}
@@ -412,7 +394,6 @@ mod tests {
assert_eq!(worker.dp_rank(), Some(3));
assert_eq!(worker.dp_size(), Some(16));
assert_eq!(worker.metadata().labels, labels);
// Can't directly compare HealthConfig without PartialEq, so check individual fields
assert_eq!(
worker.metadata().health_config.endpoint,
health_config.endpoint
@@ -437,7 +418,6 @@ mod tests {
#[test]
fn test_dp_aware_worker_with_grpc() {
// Test that DPAwareWorkerBuilder can set a gRPC client
let worker = DPAwareWorkerBuilder::new("grpc://cluster.local", 1, 4)
.worker_type(WorkerType::Decode)
.connection_mode(ConnectionMode::Grpc { port: Some(50051) })
@@ -456,9 +436,5 @@ mod tests {
worker.metadata().labels.get("transport"),
Some(&"grpc".to_string())
);
// Note: We can't directly test the grpc_client as it's private,
// but the fact that the worker builds successfully with grpc connection mode
// validates that the configuration is properly passed through
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -34,7 +34,6 @@ impl Default for WorkerId {
}
}
/// Type alias for the model index to reduce complexity.
///
/// A shared `DashMap` from a `String` key to a shared, `RwLock`-guarded list
/// of `Arc<dyn Worker>` handles. The key is presumably a model identifier —
/// confirm against the code that populates the map.
type ModelIndex = Arc<DashMap<String, Arc<RwLock<Vec<Arc<dyn Worker>>>>>>;
/// Worker registry with model-based indexing
@@ -54,8 +53,7 @@ pub struct WorkerRegistry {
/// Workers indexed by connection mode
connection_workers: Arc<DashMap<ConnectionMode, Vec<WorkerId>>>,
/// URL to worker ID mapping (for backward compatibility)
/// URL to worker ID mapping
url_to_id: Arc<DashMap<String, WorkerId>>,
}