[router] refactor router and worker management 3/n (#10727)
This commit is contained in:
@@ -83,14 +83,13 @@ impl CircuitBreaker {
|
||||
|
||||
/// Check if a request can be executed
|
||||
pub fn can_execute(&self) -> bool {
|
||||
// First check if we need to transition from Open to HalfOpen
|
||||
self.check_and_update_state();
|
||||
|
||||
let state = *self.state.read().unwrap();
|
||||
match state {
|
||||
CircuitState::Closed => true,
|
||||
CircuitState::Open => false,
|
||||
CircuitState::HalfOpen => true, // Allow limited requests in half-open state
|
||||
CircuitState::HalfOpen => true,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -114,22 +113,17 @@ impl CircuitBreaker {
|
||||
self.total_successes.fetch_add(1, Ordering::Relaxed);
|
||||
self.consecutive_failures.store(0, Ordering::Release);
|
||||
let successes = self.consecutive_successes.fetch_add(1, Ordering::AcqRel) + 1;
|
||||
// Outcome-level metrics are recorded at the worker level where the worker label is known
|
||||
|
||||
let current_state = *self.state.read().unwrap();
|
||||
|
||||
match current_state {
|
||||
CircuitState::HalfOpen => {
|
||||
// Check if we've reached the success threshold to close the circuit
|
||||
if successes >= self.config.success_threshold {
|
||||
self.transition_to(CircuitState::Closed);
|
||||
}
|
||||
}
|
||||
CircuitState::Closed => {
|
||||
// Already closed, nothing to do
|
||||
}
|
||||
CircuitState::Closed => {}
|
||||
CircuitState::Open => {
|
||||
// Shouldn't happen, but if it does, stay open
|
||||
tracing::warn!("Success recorded while circuit is open");
|
||||
}
|
||||
}
|
||||
@@ -140,9 +134,7 @@ impl CircuitBreaker {
|
||||
self.total_failures.fetch_add(1, Ordering::Relaxed);
|
||||
self.consecutive_successes.store(0, Ordering::Release);
|
||||
let failures = self.consecutive_failures.fetch_add(1, Ordering::AcqRel) + 1;
|
||||
// Outcome-level metrics are recorded at the worker level where the worker label is known
|
||||
|
||||
// Update last failure time
|
||||
{
|
||||
let mut last_failure = self.last_failure_time.write().unwrap();
|
||||
*last_failure = Some(Instant::now());
|
||||
@@ -152,18 +144,14 @@ impl CircuitBreaker {
|
||||
|
||||
match current_state {
|
||||
CircuitState::Closed => {
|
||||
// Check if we've reached the failure threshold to open the circuit
|
||||
if failures >= self.config.failure_threshold {
|
||||
self.transition_to(CircuitState::Open);
|
||||
}
|
||||
}
|
||||
CircuitState::HalfOpen => {
|
||||
// Single failure in half-open state reopens the circuit
|
||||
self.transition_to(CircuitState::Open);
|
||||
}
|
||||
CircuitState::Open => {
|
||||
// Already open, nothing to do
|
||||
}
|
||||
CircuitState::Open => {}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -172,7 +160,6 @@ impl CircuitBreaker {
|
||||
let current_state = *self.state.read().unwrap();
|
||||
|
||||
if current_state == CircuitState::Open {
|
||||
// Check if timeout has expired
|
||||
let last_change = *self.last_state_change.read().unwrap();
|
||||
if last_change.elapsed() >= self.config.timeout_duration {
|
||||
self.transition_to(CircuitState::HalfOpen);
|
||||
@@ -188,11 +175,9 @@ impl CircuitBreaker {
|
||||
if old_state != new_state {
|
||||
*state = new_state;
|
||||
|
||||
// Update last state change time
|
||||
let mut last_change = self.last_state_change.write().unwrap();
|
||||
*last_change = Instant::now();
|
||||
|
||||
// Reset counters based on transition
|
||||
match new_state {
|
||||
CircuitState::Closed => {
|
||||
self.consecutive_failures.store(0, Ordering::Release);
|
||||
@@ -218,7 +203,6 @@ impl CircuitBreaker {
|
||||
CircuitState::HalfOpen => "half_open",
|
||||
};
|
||||
info!("Circuit breaker state transition: {} -> {}", from, to);
|
||||
// Transition metrics are recorded at the worker level where the worker label is known
|
||||
}
|
||||
}
|
||||
|
||||
@@ -533,7 +517,6 @@ mod tests {
|
||||
let cb = Arc::new(CircuitBreaker::new());
|
||||
let mut handles = vec![];
|
||||
|
||||
// Spawn threads that record failures
|
||||
for _ in 0..10 {
|
||||
let cb_clone = Arc::clone(&cb);
|
||||
let handle = thread::spawn(move || {
|
||||
@@ -544,12 +527,10 @@ mod tests {
|
||||
handles.push(handle);
|
||||
}
|
||||
|
||||
// Wait for all threads
|
||||
for handle in handles {
|
||||
handle.join().unwrap();
|
||||
}
|
||||
|
||||
// Should have recorded 1000 failures
|
||||
assert_eq!(cb.total_failures(), 1000);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -122,7 +122,6 @@ mod tests {
|
||||
let error = WorkerError::WorkerNotFound {
|
||||
url: "http://test".to_string(),
|
||||
};
|
||||
// Verify it implements Error trait
|
||||
let _: &dyn Error = &error;
|
||||
assert!(error.source().is_none());
|
||||
}
|
||||
@@ -135,11 +134,9 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_worker_result_type_alias() {
|
||||
// Test Ok variant
|
||||
let result: WorkerResult<i32> = Ok(42);
|
||||
assert!(matches!(result, Ok(42)));
|
||||
|
||||
// Test Err variant
|
||||
let error = WorkerError::WorkerNotFound {
|
||||
url: "test".to_string(),
|
||||
};
|
||||
@@ -149,7 +146,6 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_empty_url_handling() {
|
||||
// Test empty URLs in error variants
|
||||
let error1 = WorkerError::HealthCheckFailed {
|
||||
url: "".to_string(),
|
||||
reason: "No connection".to_string(),
|
||||
@@ -173,7 +169,6 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_special_characters_in_messages() {
|
||||
// Test with special characters
|
||||
let error = WorkerError::InvalidConfiguration {
|
||||
message: "Invalid JSON: {\"error\": \"test\"}".to_string(),
|
||||
};
|
||||
@@ -182,7 +177,6 @@ mod tests {
|
||||
"Invalid worker configuration: Invalid JSON: {\"error\": \"test\"}"
|
||||
);
|
||||
|
||||
// Test with unicode
|
||||
let error2 = WorkerError::HealthCheckFailed {
|
||||
url: "http://测试:8080".to_string(),
|
||||
reason: "连接被拒绝".to_string(),
|
||||
@@ -207,10 +201,8 @@ mod tests {
|
||||
);
|
||||
}
|
||||
|
||||
// Mock reqwest error for testing conversion
|
||||
#[test]
|
||||
fn test_reqwest_error_conversion() {
|
||||
// Test that NetworkError is the correct variant
|
||||
let network_error = WorkerError::NetworkError {
|
||||
url: "http://example.com".to_string(),
|
||||
error: "connection timeout".to_string(),
|
||||
@@ -227,8 +219,6 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_error_equality() {
|
||||
// WorkerError doesn't implement PartialEq, but we can test that
|
||||
// the same error construction produces the same display output
|
||||
let error1 = WorkerError::WorkerNotFound {
|
||||
url: "http://test".to_string(),
|
||||
};
|
||||
|
||||
@@ -12,9 +12,9 @@ pub mod retry;
|
||||
pub mod token_bucket;
|
||||
pub mod worker;
|
||||
pub mod worker_builder;
|
||||
pub mod worker_manager;
|
||||
pub mod worker_registry;
|
||||
|
||||
// Re-export commonly used types at the module level
|
||||
pub use circuit_breaker::{
|
||||
CircuitBreaker, CircuitBreakerConfig, CircuitBreakerStats, CircuitState,
|
||||
};
|
||||
@@ -25,4 +25,5 @@ pub use worker::{
|
||||
Worker, WorkerFactory, WorkerLoadGuard, WorkerType,
|
||||
};
|
||||
pub use worker_builder::{BasicWorkerBuilder, DPAwareWorkerBuilder};
|
||||
pub use worker_manager::{DpInfo, ServerInfo, WorkerManager};
|
||||
pub use worker_registry::{WorkerId, WorkerRegistry, WorkerRegistryStats};
|
||||
|
||||
@@ -25,14 +25,12 @@ pub struct BackoffCalculator;
|
||||
impl BackoffCalculator {
|
||||
/// Calculate backoff delay for a given attempt index (0-based).
|
||||
pub fn calculate_delay(config: &RetryConfig, attempt: u32) -> Duration {
|
||||
// Base exponential backoff
|
||||
let pow = config.backoff_multiplier.powi(attempt as i32);
|
||||
let mut delay_ms = (config.initial_backoff_ms as f32 * pow) as u64;
|
||||
if delay_ms > config.max_backoff_ms {
|
||||
delay_ms = config.max_backoff_ms;
|
||||
}
|
||||
|
||||
// Apply jitter in range [-j, +j]
|
||||
let jitter = config.jitter_factor.clamp(0.0, 1.0);
|
||||
if jitter > 0.0 {
|
||||
let mut rng = rand::rng();
|
||||
@@ -77,14 +75,12 @@ impl RetryExecutor {
|
||||
match operation(attempt).await {
|
||||
Ok(val) => return Ok(val),
|
||||
Err(_) => {
|
||||
// Use the number of failures so far (0-indexed) to compute delay,
|
||||
// so the first retry uses `initial_backoff_ms`.
|
||||
let is_last = attempt + 1 >= max;
|
||||
if is_last {
|
||||
return Err(RetryError::MaxRetriesExceeded);
|
||||
}
|
||||
let delay = BackoffCalculator::calculate_delay(config, attempt);
|
||||
attempt += 1; // advance to the next attempt after computing delay
|
||||
attempt += 1;
|
||||
tokio::time::sleep(delay).await;
|
||||
}
|
||||
}
|
||||
@@ -144,14 +140,11 @@ impl RetryExecutor {
|
||||
}
|
||||
|
||||
if is_last {
|
||||
// Exhausted retries
|
||||
on_exhausted();
|
||||
return response;
|
||||
}
|
||||
|
||||
// Backoff before next attempt
|
||||
let next_attempt = attempt + 1;
|
||||
// Compute delay based on the number of failures so far (0-indexed)
|
||||
let delay = BackoffCalculator::calculate_delay(config, attempt);
|
||||
debug!(
|
||||
attempt = attempt,
|
||||
@@ -194,22 +187,18 @@ mod tests {
|
||||
backoff_multiplier: 2.0,
|
||||
jitter_factor: 0.0,
|
||||
};
|
||||
// attempt=0 => 100ms
|
||||
assert_eq!(
|
||||
BackoffCalculator::calculate_delay(&cfg, 0),
|
||||
Duration::from_millis(100)
|
||||
);
|
||||
// attempt=1 => 200ms
|
||||
assert_eq!(
|
||||
BackoffCalculator::calculate_delay(&cfg, 1),
|
||||
Duration::from_millis(200)
|
||||
);
|
||||
// attempt=2 => 400ms -> capped to 250ms
|
||||
assert_eq!(
|
||||
BackoffCalculator::calculate_delay(&cfg, 2),
|
||||
Duration::from_millis(250)
|
||||
);
|
||||
// large attempt still capped
|
||||
assert_eq!(
|
||||
BackoffCalculator::calculate_delay(&cfg, 10),
|
||||
Duration::from_millis(250)
|
||||
@@ -225,7 +214,6 @@ mod tests {
|
||||
backoff_multiplier: 2.0,
|
||||
jitter_factor: 0.5,
|
||||
};
|
||||
// attempt=2 => base 400ms, jitter in [0.5x, 1.5x]
|
||||
let base = 400.0;
|
||||
for _ in 0..50 {
|
||||
let d = BackoffCalculator::calculate_delay(&cfg, 2).as_millis() as f32;
|
||||
@@ -261,7 +249,7 @@ mod tests {
|
||||
|
||||
assert!(res.is_ok());
|
||||
assert_eq!(res.unwrap(), 42);
|
||||
assert_eq!(calls.load(Ordering::Relaxed), 3); // 2 fails + 1 success
|
||||
assert_eq!(calls.load(Ordering::Relaxed), 3);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
@@ -309,7 +297,7 @@ mod tests {
|
||||
}
|
||||
}
|
||||
},
|
||||
|res, _attempt| !res.status().is_success(), // retry until success
|
||||
|res, _attempt| !res.status().is_success(),
|
||||
{
|
||||
let backoffs = backoffs.clone();
|
||||
move |_delay, _next_attempt| {
|
||||
@@ -326,7 +314,7 @@ mod tests {
|
||||
.await;
|
||||
|
||||
assert_eq!(response.status(), StatusCode::OK);
|
||||
assert_eq!(calls.load(Ordering::Relaxed), 3); // 2 fails + 1 success
|
||||
assert_eq!(calls.load(Ordering::Relaxed), 3);
|
||||
assert_eq!(backoffs.load(Ordering::Relaxed), 2);
|
||||
assert_eq!(exhausted.load(Ordering::Relaxed), 0);
|
||||
}
|
||||
@@ -347,7 +335,7 @@ mod tests {
|
||||
async move { (StatusCode::BAD_REQUEST, "bad").into_response() }
|
||||
}
|
||||
},
|
||||
|_res, _attempt| false, // never retry
|
||||
|_res, _attempt| false,
|
||||
{
|
||||
let backoffs = backoffs.clone();
|
||||
move |_delay, _next_attempt| {
|
||||
@@ -385,7 +373,7 @@ mod tests {
|
||||
async move { (StatusCode::SERVICE_UNAVAILABLE, "fail").into_response() }
|
||||
}
|
||||
},
|
||||
|_res, _attempt| true, // keep retrying
|
||||
|_res, _attempt| true,
|
||||
{
|
||||
let backoffs = backoffs.clone();
|
||||
move |_delay, _next_attempt| {
|
||||
|
||||
@@ -32,16 +32,11 @@ impl TokenBucket {
|
||||
let capacity = capacity as f64;
|
||||
let refill_rate = refill_rate as f64;
|
||||
|
||||
// Ensure refill_rate is not zero to prevent division by zero
|
||||
let refill_rate = if refill_rate > 0.0 {
|
||||
refill_rate
|
||||
} else {
|
||||
1.0 // Default to 1 token per second if zero
|
||||
};
|
||||
let refill_rate = if refill_rate > 0.0 { refill_rate } else { 1.0 };
|
||||
|
||||
Self {
|
||||
inner: Arc::new(Mutex::new(TokenBucketInner {
|
||||
tokens: capacity, // Start full
|
||||
tokens: capacity,
|
||||
last_refill: Instant::now(),
|
||||
})),
|
||||
notify: Arc::new(Notify::new()),
|
||||
@@ -54,7 +49,6 @@ impl TokenBucket {
|
||||
pub async fn try_acquire(&self, tokens: f64) -> Result<(), ()> {
|
||||
let mut inner = self.inner.lock().await;
|
||||
|
||||
// Refill tokens based on elapsed time
|
||||
let now = Instant::now();
|
||||
let elapsed = now.duration_since(inner.last_refill).as_secs_f64();
|
||||
let refill_amount = elapsed * self.refill_rate;
|
||||
@@ -82,12 +76,10 @@ impl TokenBucket {
|
||||
|
||||
/// Acquire tokens, waiting if necessary
|
||||
pub async fn acquire(&self, tokens: f64) -> Result<(), tokio::time::error::Elapsed> {
|
||||
// First try to acquire immediately
|
||||
if self.try_acquire(tokens).await.is_ok() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Calculate wait time
|
||||
let wait_time = {
|
||||
let inner = self.inner.lock().await;
|
||||
let tokens_needed = tokens - inner.tokens;
|
||||
@@ -100,15 +92,12 @@ impl TokenBucket {
|
||||
wait_time, tokens
|
||||
);
|
||||
|
||||
// Wait for tokens to be available
|
||||
tokio::time::timeout(wait_time, async {
|
||||
loop {
|
||||
// Check if we can acquire now
|
||||
if self.try_acquire(tokens).await.is_ok() {
|
||||
return;
|
||||
}
|
||||
|
||||
// Wait for notification or small interval
|
||||
tokio::select! {
|
||||
_ = self.notify.notified() => {},
|
||||
_ = tokio::time::sleep(Duration::from_millis(10)) => {},
|
||||
@@ -144,7 +133,6 @@ impl TokenBucket {
|
||||
pub async fn available_tokens(&self) -> f64 {
|
||||
let mut inner = self.inner.lock().await;
|
||||
|
||||
// Refill before checking
|
||||
let now = Instant::now();
|
||||
let elapsed = now.duration_since(inner.last_refill).as_secs_f64();
|
||||
let refill_amount = elapsed * self.refill_rate;
|
||||
@@ -162,33 +150,26 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_token_bucket_basic() {
|
||||
let bucket = TokenBucket::new(10, 5); // 10 capacity, 5 per second
|
||||
let bucket = TokenBucket::new(10, 5);
|
||||
|
||||
// Should succeed - bucket starts full
|
||||
assert!(bucket.try_acquire(5.0).await.is_ok());
|
||||
assert!(bucket.try_acquire(5.0).await.is_ok());
|
||||
|
||||
// Should fail - no tokens left
|
||||
assert!(bucket.try_acquire(1.0).await.is_err());
|
||||
|
||||
// Wait for refill
|
||||
tokio::time::sleep(Duration::from_millis(300)).await;
|
||||
|
||||
// Should have ~1.5 tokens now
|
||||
assert!(bucket.try_acquire(1.0).await.is_ok());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_token_bucket_refill() {
|
||||
let bucket = TokenBucket::new(10, 10); // 10 capacity, 10 per second
|
||||
let bucket = TokenBucket::new(10, 10);
|
||||
|
||||
// Use all tokens
|
||||
assert!(bucket.try_acquire(10.0).await.is_ok());
|
||||
|
||||
// Wait for partial refill
|
||||
tokio::time::sleep(Duration::from_millis(500)).await;
|
||||
|
||||
// Should have ~5 tokens
|
||||
let available = bucket.available_tokens().await;
|
||||
assert!((4.0..=6.0).contains(&available));
|
||||
}
|
||||
|
||||
@@ -11,10 +11,9 @@ use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
|
||||
use std::sync::{Arc, LazyLock};
|
||||
use tokio::sync::Mutex;
|
||||
|
||||
// Shared HTTP client for worker operations (health checks, server info, etc.)
|
||||
static WORKER_CLIENT: LazyLock<reqwest::Client> = LazyLock::new(|| {
|
||||
reqwest::Client::builder()
|
||||
.timeout(std::time::Duration::from_secs(30)) // Default timeout, overridden per request
|
||||
.timeout(std::time::Duration::from_secs(30))
|
||||
.build()
|
||||
.expect("Failed to create worker HTTP client")
|
||||
});
|
||||
@@ -43,7 +42,6 @@ pub trait Worker: Send + Sync + fmt::Debug {
|
||||
|
||||
/// Synchronous health check wrapper (for compatibility)
|
||||
fn check_health(&self) -> WorkerResult<()> {
|
||||
// Use a small runtime for synchronous contexts
|
||||
tokio::runtime::Builder::new_current_thread()
|
||||
.enable_all()
|
||||
.build()
|
||||
@@ -64,10 +62,7 @@ pub trait Worker: Send + Sync + fmt::Debug {
|
||||
fn decrement_load(&self);
|
||||
|
||||
/// Reset the load counter to 0 (for sync/recovery)
|
||||
fn reset_load(&self) {
|
||||
// Default implementation - does nothing
|
||||
// Workers that track load should override this
|
||||
}
|
||||
fn reset_load(&self) {}
|
||||
|
||||
/// Get the number of processed requests
|
||||
fn processed_requests(&self) -> usize;
|
||||
@@ -88,11 +83,9 @@ pub trait Worker: Send + Sync + fmt::Debug {
|
||||
|
||||
/// Record the outcome of a request to this worker
|
||||
fn record_outcome(&self, success: bool) {
|
||||
// Record outcome-level metric with worker label
|
||||
let outcome_str = if success { "success" } else { "failure" };
|
||||
RouterMetrics::record_cb_outcome(self.url(), outcome_str);
|
||||
|
||||
// Record into circuit breaker and infer state change for metrics
|
||||
let before = self.circuit_breaker().state();
|
||||
self.circuit_breaker().record_outcome(success);
|
||||
let after = self.circuit_breaker().state();
|
||||
@@ -119,8 +112,6 @@ pub trait Worker: Send + Sync + fmt::Debug {
|
||||
RouterMetrics::set_cb_state(self.url(), state_code);
|
||||
}
|
||||
|
||||
// === DP-aware methods ===
|
||||
|
||||
/// Check if this worker is DP-aware
|
||||
fn is_dp_aware(&self) -> bool {
|
||||
false
|
||||
@@ -156,8 +147,6 @@ pub trait Worker: Send + Sync + fmt::Debug {
|
||||
true
|
||||
}
|
||||
|
||||
// === Multi-router support ===
|
||||
|
||||
// TODO: - Enhanced Worker Discovery
|
||||
// The Worker trait should handle async discovery of metadata from the worker itself
|
||||
// rather than having service discovery or other components query /get_server_info.
|
||||
@@ -356,14 +345,12 @@ impl fmt::Debug for BasicWorker {
|
||||
impl BasicWorker {
|
||||
pub fn normalised_url(&self) -> WorkerResult<&str> {
|
||||
if self.url().contains("@") {
|
||||
// Need to extract the URL from "http://host:port@dp_rank"
|
||||
let parts: Vec<&str> = self.url().split('@').collect();
|
||||
if parts.len() != 2 {
|
||||
return Err(WorkerError::InvalidUrl {
|
||||
url: self.url().to_string(),
|
||||
});
|
||||
}
|
||||
// Ensure the second part (the dp_rank) can be parsed as an integer
|
||||
match parts[1].parse::<usize>() {
|
||||
Ok(_) => Ok(parts[0]),
|
||||
Err(_) => Err(WorkerError::InvalidUrl {
|
||||
@@ -408,19 +395,22 @@ impl Worker for BasicWorker {
|
||||
|
||||
let health_result = match &self.metadata.connection_mode {
|
||||
ConnectionMode::Http => {
|
||||
// Perform HTTP health check
|
||||
let url = self.normalised_url()?;
|
||||
let health_url = format!("{}{}", url, self.metadata.health_config.endpoint);
|
||||
let timeout = Duration::from_secs(self.metadata.health_config.timeout_secs);
|
||||
|
||||
// Use the shared client with a custom timeout for this request
|
||||
match WORKER_CLIENT.get(&health_url).timeout(timeout).send().await {
|
||||
let mut request = WORKER_CLIENT.get(&health_url).timeout(timeout);
|
||||
|
||||
if let Some(ref api_key) = self.metadata.api_key {
|
||||
request = request.header("Authorization", format!("Bearer {}", api_key));
|
||||
}
|
||||
|
||||
match request.send().await {
|
||||
Ok(response) => response.status().is_success(),
|
||||
Err(_) => false,
|
||||
}
|
||||
}
|
||||
ConnectionMode::Grpc { .. } => {
|
||||
// Perform gRPC health check
|
||||
if let Some(grpc_client) = &self.grpc_client {
|
||||
let mut client = grpc_client.lock().await;
|
||||
match client.health_check().await {
|
||||
@@ -449,11 +439,9 @@ impl Worker for BasicWorker {
|
||||
};
|
||||
|
||||
if health_result {
|
||||
// Health check succeeded
|
||||
self.consecutive_failures.store(0, Ordering::Release);
|
||||
let successes = self.consecutive_successes.fetch_add(1, Ordering::AcqRel) + 1;
|
||||
|
||||
// Mark healthy if we've reached the success threshold
|
||||
if !self.is_healthy()
|
||||
&& successes >= self.metadata.health_config.success_threshold as usize
|
||||
{
|
||||
@@ -462,11 +450,9 @@ impl Worker for BasicWorker {
|
||||
}
|
||||
Ok(())
|
||||
} else {
|
||||
// Health check failed
|
||||
self.consecutive_successes.store(0, Ordering::Release);
|
||||
let failures = self.consecutive_failures.fetch_add(1, Ordering::AcqRel) + 1;
|
||||
|
||||
// Mark unhealthy if we've reached the failure threshold
|
||||
if self.is_healthy()
|
||||
&& failures >= self.metadata.health_config.failure_threshold as usize
|
||||
{
|
||||
@@ -576,7 +562,6 @@ impl Worker for DPAwareWorker {
|
||||
}
|
||||
|
||||
async fn check_health_async(&self) -> WorkerResult<()> {
|
||||
// Delegate to the base worker's health check logic
|
||||
self.base_worker.check_health_async().await
|
||||
}
|
||||
|
||||
@@ -612,8 +597,6 @@ impl Worker for DPAwareWorker {
|
||||
self.base_worker.circuit_breaker()
|
||||
}
|
||||
|
||||
// DP-aware specific implementations
|
||||
|
||||
fn is_dp_aware(&self) -> bool {
|
||||
true
|
||||
}
|
||||
@@ -631,7 +614,6 @@ impl Worker for DPAwareWorker {
|
||||
}
|
||||
|
||||
async fn prepare_request(&self, mut req: serde_json::Value) -> WorkerResult<serde_json::Value> {
|
||||
// Inject data_parallel_rank into the request
|
||||
if let Some(map) = req.as_object_mut() {
|
||||
map.insert(
|
||||
"data_parallel_rank".to_string(),
|
||||
@@ -646,7 +628,6 @@ impl Worker for DPAwareWorker {
|
||||
}
|
||||
|
||||
fn endpoint_url(&self, route: &str) -> String {
|
||||
// Use base URL for actual requests
|
||||
format!("{}{}", self.base_url, route)
|
||||
}
|
||||
}
|
||||
@@ -670,53 +651,52 @@ impl WorkerFactory {
|
||||
}
|
||||
Box::new(builder.build())
|
||||
}
|
||||
#[allow(dead_code)]
|
||||
/// Get DP size from a worker
|
||||
async fn get_worker_dp_size(url: &str, api_key: &Option<String>) -> WorkerResult<usize> {
|
||||
let mut req_builder = WORKER_CLIENT.get(format!("{}/get_server_info", url));
|
||||
|
||||
if let Some(key) = &api_key {
|
||||
req_builder = req_builder.bearer_auth(key);
|
||||
}
|
||||
/// Static health validation before creating a worker
|
||||
/// This replaces wait_for_worker_health in handlers
|
||||
pub async fn validate_health(url: &str, timeout_secs: u64) -> WorkerResult<()> {
|
||||
use std::time::Instant;
|
||||
|
||||
let response = req_builder
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| WorkerError::NetworkError {
|
||||
url: url.to_string(),
|
||||
error: e.to_string(),
|
||||
})?;
|
||||
let start_time = Instant::now();
|
||||
let timeout = std::time::Duration::from_secs(timeout_secs);
|
||||
|
||||
if !response.status().is_success() {
|
||||
return Err(WorkerError::NetworkError {
|
||||
url: url.to_string(),
|
||||
error: format!("Server returned: {}", response.status()),
|
||||
});
|
||||
}
|
||||
|
||||
let info: serde_json::Value =
|
||||
response
|
||||
.json()
|
||||
.await
|
||||
.map_err(|e| WorkerError::NetworkError {
|
||||
loop {
|
||||
if start_time.elapsed() > timeout {
|
||||
return Err(WorkerError::HealthCheckFailed {
|
||||
url: url.to_string(),
|
||||
error: format!("Failed to parse JSON: {}", e),
|
||||
})?;
|
||||
reason: format!(
|
||||
"Timeout {}s waiting for worker to become healthy",
|
||||
timeout_secs
|
||||
),
|
||||
});
|
||||
}
|
||||
|
||||
let dp_size = info
|
||||
.get("dp_size")
|
||||
.and_then(|v| v.as_u64())
|
||||
.ok_or_else(|| WorkerError::InvalidConfiguration {
|
||||
message: "dp_size not found in server info".to_string(),
|
||||
})?;
|
||||
// Note: This static function doesn't have access to worker's API key
|
||||
// API key authentication is handled in the worker instance's check_health_async method
|
||||
match WORKER_CLIENT
|
||||
.get(format!("{}/health", url))
|
||||
.timeout(std::time::Duration::from_secs(5))
|
||||
.send()
|
||||
.await
|
||||
{
|
||||
Ok(res) if res.status().is_success() => {
|
||||
tracing::info!("Worker {} is healthy", url);
|
||||
return Ok(());
|
||||
}
|
||||
Ok(res) => {
|
||||
tracing::warn!(
|
||||
"Worker {} health check failed with status: {}",
|
||||
url,
|
||||
res.status()
|
||||
);
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!("Failed to contact worker {}: {}", url, e);
|
||||
}
|
||||
}
|
||||
|
||||
if dp_size > usize::MAX as u64 {
|
||||
return Err(WorkerError::InvalidConfiguration {
|
||||
message: format!("dp_size is too large: {}", dp_size),
|
||||
});
|
||||
tokio::time::sleep(std::time::Duration::from_secs(1)).await;
|
||||
}
|
||||
|
||||
Ok(dp_size as usize)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -893,7 +873,6 @@ mod tests {
|
||||
use std::thread;
|
||||
use std::time::Duration;
|
||||
|
||||
// Test WorkerType
|
||||
#[test]
|
||||
fn test_worker_type_display() {
|
||||
assert_eq!(WorkerType::Regular.to_string(), "Regular");
|
||||
@@ -945,7 +924,6 @@ mod tests {
|
||||
assert_eq!(original, cloned);
|
||||
}
|
||||
|
||||
// Test HealthConfig
|
||||
#[test]
|
||||
fn test_health_config_default() {
|
||||
let config = HealthConfig::default();
|
||||
@@ -972,13 +950,11 @@ mod tests {
|
||||
assert_eq!(config.success_threshold, 3);
|
||||
}
|
||||
|
||||
// Test BasicWorker
|
||||
#[test]
|
||||
fn test_basic_worker_creation() {
|
||||
use crate::core::BasicWorkerBuilder;
|
||||
let worker = BasicWorkerBuilder::new("http://test:8080")
|
||||
.worker_type(WorkerType::Regular)
|
||||
.api_key("test_api_key")
|
||||
.build();
|
||||
assert_eq!(worker.url(), "http://test:8080");
|
||||
assert_eq!(worker.worker_type(), WorkerType::Regular);
|
||||
@@ -1016,7 +992,6 @@ mod tests {
|
||||
let worker = BasicWorkerBuilder::new("http://test:8080")
|
||||
.worker_type(WorkerType::Regular)
|
||||
.health_config(custom_config.clone())
|
||||
.api_key("test_api_key")
|
||||
.build();
|
||||
|
||||
assert_eq!(worker.metadata().health_config.timeout_secs, 15);
|
||||
@@ -1024,13 +999,11 @@ mod tests {
|
||||
assert_eq!(worker.metadata().health_config.endpoint, "/custom-health");
|
||||
}
|
||||
|
||||
// Test Worker trait implementation
|
||||
#[test]
|
||||
fn test_worker_url() {
|
||||
use crate::core::BasicWorkerBuilder;
|
||||
let worker = BasicWorkerBuilder::new("http://worker1:8080")
|
||||
.worker_type(WorkerType::Regular)
|
||||
.api_key("test_api_key")
|
||||
.build();
|
||||
assert_eq!(worker.url(), "http://worker1:8080");
|
||||
}
|
||||
@@ -1040,7 +1013,6 @@ mod tests {
|
||||
use crate::core::BasicWorkerBuilder;
|
||||
let regular = BasicWorkerBuilder::new("http://test:8080")
|
||||
.worker_type(WorkerType::Regular)
|
||||
.api_key("test_api_key")
|
||||
.build();
|
||||
assert_eq!(regular.worker_type(), WorkerType::Regular);
|
||||
|
||||
@@ -1048,7 +1020,6 @@ mod tests {
|
||||
.worker_type(WorkerType::Prefill {
|
||||
bootstrap_port: Some(9090),
|
||||
})
|
||||
.api_key("test_api_key")
|
||||
.build();
|
||||
assert_eq!(
|
||||
prefill.worker_type(),
|
||||
@@ -1059,7 +1030,6 @@ mod tests {
|
||||
|
||||
let decode = BasicWorkerBuilder::new("http://test:8080")
|
||||
.worker_type(WorkerType::Decode)
|
||||
.api_key("test_api_key")
|
||||
.build();
|
||||
assert_eq!(decode.worker_type(), WorkerType::Decode);
|
||||
}
|
||||
@@ -1071,14 +1041,11 @@ mod tests {
|
||||
.worker_type(WorkerType::Regular)
|
||||
.build();
|
||||
|
||||
// Initial state is healthy
|
||||
assert!(worker.is_healthy());
|
||||
|
||||
// Set unhealthy
|
||||
worker.set_healthy(false);
|
||||
assert!(!worker.is_healthy());
|
||||
|
||||
// Set healthy again
|
||||
worker.set_healthy(true);
|
||||
assert!(worker.is_healthy());
|
||||
}
|
||||
@@ -1088,31 +1055,24 @@ mod tests {
|
||||
use crate::core::BasicWorkerBuilder;
|
||||
let worker = BasicWorkerBuilder::new("http://test:8080")
|
||||
.worker_type(WorkerType::Regular)
|
||||
.api_key("test_api_key")
|
||||
.build();
|
||||
|
||||
// Initial load is 0
|
||||
assert_eq!(worker.load(), 0);
|
||||
|
||||
// Increment once
|
||||
worker.increment_load();
|
||||
assert_eq!(worker.load(), 1);
|
||||
|
||||
// Increment twice more
|
||||
worker.increment_load();
|
||||
worker.increment_load();
|
||||
assert_eq!(worker.load(), 3);
|
||||
|
||||
// Decrement once
|
||||
worker.decrement_load();
|
||||
assert_eq!(worker.load(), 2);
|
||||
|
||||
// Decrement to 0
|
||||
worker.decrement_load();
|
||||
worker.decrement_load();
|
||||
assert_eq!(worker.load(), 0);
|
||||
|
||||
// Decrement below 0 should stay at 0
|
||||
worker.decrement_load();
|
||||
assert_eq!(worker.load(), 0);
|
||||
}
|
||||
@@ -1124,17 +1084,14 @@ mod tests {
|
||||
.worker_type(WorkerType::Regular)
|
||||
.build();
|
||||
|
||||
// Initial count is 0
|
||||
assert_eq!(worker.processed_requests(), 0);
|
||||
|
||||
// Increment multiple times
|
||||
for i in 1..=100 {
|
||||
worker.increment_processed();
|
||||
assert_eq!(worker.processed_requests(), i);
|
||||
}
|
||||
}
|
||||
|
||||
// Test concurrent operations
|
||||
#[tokio::test]
|
||||
async fn test_concurrent_load_increments() {
|
||||
use crate::core::BasicWorkerBuilder;
|
||||
@@ -1146,7 +1103,6 @@ mod tests {
|
||||
|
||||
let mut handles = vec![];
|
||||
|
||||
// Spawn 100 tasks incrementing load
|
||||
for _ in 0..100 {
|
||||
let worker_clone = Arc::clone(&worker);
|
||||
let handle = tokio::spawn(async move {
|
||||
@@ -1155,12 +1111,10 @@ mod tests {
|
||||
handles.push(handle);
|
||||
}
|
||||
|
||||
// Wait for all tasks
|
||||
for handle in handles {
|
||||
handle.await.unwrap();
|
||||
}
|
||||
|
||||
// Final count should be 100
|
||||
assert_eq!(worker.load(), 100);
|
||||
}
|
||||
|
||||
@@ -1173,7 +1127,6 @@ mod tests {
|
||||
.build(),
|
||||
);
|
||||
|
||||
// Set initial load to 100
|
||||
for _ in 0..100 {
|
||||
worker.increment_load();
|
||||
}
|
||||
@@ -1181,7 +1134,6 @@ mod tests {
|
||||
|
||||
let mut handles = vec![];
|
||||
|
||||
// Spawn 100 tasks decrementing load
|
||||
for _ in 0..100 {
|
||||
let worker_clone = Arc::clone(&worker);
|
||||
let handle = tokio::spawn(async move {
|
||||
@@ -1190,12 +1142,10 @@ mod tests {
|
||||
handles.push(handle);
|
||||
}
|
||||
|
||||
// Wait for all tasks
|
||||
for handle in handles {
|
||||
handle.await.unwrap();
|
||||
}
|
||||
|
||||
// Final count should be 0
|
||||
assert_eq!(worker.load(), 0);
|
||||
}
|
||||
|
||||
@@ -1210,7 +1160,6 @@ mod tests {
|
||||
|
||||
let mut handles = vec![];
|
||||
|
||||
// Spawn threads randomly setting health status
|
||||
for i in 0..100 {
|
||||
let worker_clone = Arc::clone(&worker);
|
||||
let handle = tokio::spawn(async move {
|
||||
@@ -1220,13 +1169,11 @@ mod tests {
|
||||
handles.push(handle);
|
||||
}
|
||||
|
||||
// Wait for all tasks
|
||||
for handle in handles {
|
||||
handle.await.unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
// Test WorkerFactory
|
||||
#[test]
|
||||
fn test_create_regular_worker() {
|
||||
let worker: Box<dyn Worker> = Box::new(
|
||||
@@ -1240,7 +1187,6 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_create_prefill_worker() {
|
||||
// With bootstrap port
|
||||
let worker1: Box<dyn Worker> = Box::new(
|
||||
BasicWorkerBuilder::new("http://prefill:8080")
|
||||
.worker_type(WorkerType::Prefill {
|
||||
@@ -1256,7 +1202,6 @@ mod tests {
|
||||
}
|
||||
);
|
||||
|
||||
// Without bootstrap port
|
||||
let worker2: Box<dyn Worker> = Box::new(
|
||||
BasicWorkerBuilder::new("http://prefill:8080")
|
||||
.worker_type(WorkerType::Prefill {
|
||||
@@ -1283,7 +1228,6 @@ mod tests {
|
||||
assert_eq!(worker.worker_type(), WorkerType::Decode);
|
||||
}
|
||||
|
||||
// Test WorkerLoadGuard
|
||||
#[test]
|
||||
fn test_load_guard_single_worker() {
|
||||
use crate::core::BasicWorkerBuilder;
|
||||
@@ -1297,7 +1241,6 @@ mod tests {
|
||||
assert_eq!(worker.load(), 1);
|
||||
}
|
||||
|
||||
// Guard dropped, load decremented
|
||||
assert_eq!(worker.load(), 0);
|
||||
}
|
||||
|
||||
@@ -1325,13 +1268,11 @@ mod tests {
|
||||
|
||||
{
|
||||
let _guard = WorkerLoadGuard::new_multi(worker_refs);
|
||||
// All loads incremented
|
||||
assert_eq!(workers[0].load(), 1);
|
||||
assert_eq!(workers[1].load(), 1);
|
||||
assert_eq!(workers[2].load(), 1);
|
||||
}
|
||||
|
||||
// All loads decremented
|
||||
assert_eq!(workers[0].load(), 0);
|
||||
assert_eq!(workers[1].load(), 0);
|
||||
assert_eq!(workers[2].load(), 0);
|
||||
@@ -1347,29 +1288,21 @@ mod tests {
|
||||
);
|
||||
assert_eq!(worker.load(), 0);
|
||||
|
||||
// Clone for use inside catch_unwind
|
||||
let worker_clone = Arc::clone(&worker);
|
||||
|
||||
// Use AssertUnwindSafe wrapper for the test
|
||||
// This is safe because we're only testing the load counter behavior,
|
||||
// not the grpc_client which is None for HTTP workers
|
||||
use std::panic::AssertUnwindSafe;
|
||||
|
||||
// This will panic, but the guard should still clean up
|
||||
let result = std::panic::catch_unwind(AssertUnwindSafe(|| {
|
||||
let _guard = WorkerLoadGuard::new(worker_clone.as_ref());
|
||||
assert_eq!(worker_clone.load(), 1);
|
||||
panic!("Test panic");
|
||||
}));
|
||||
|
||||
// Verify panic occurred
|
||||
assert!(result.is_err());
|
||||
|
||||
// Load should be decremented even after panic
|
||||
assert_eq!(worker.load(), 0);
|
||||
}
|
||||
|
||||
// Test helper functions
|
||||
#[test]
|
||||
fn test_urls_to_workers() {
|
||||
let urls = vec!["http://w1:8080".to_string(), "http://w2:8080".to_string()];
|
||||
@@ -1400,23 +1333,17 @@ mod tests {
|
||||
assert_eq!(urls, vec!["http://w1:8080", "http://w2:8080"]);
|
||||
}
|
||||
|
||||
// Test synchronous health check wrapper
|
||||
#[test]
|
||||
fn test_check_health_sync_wrapper() {
|
||||
// We can't easily test the actual HTTP call without mocking,
|
||||
// but we can verify the sync wrapper works
|
||||
use crate::core::BasicWorkerBuilder;
|
||||
let worker = BasicWorkerBuilder::new("http://test:8080")
|
||||
.worker_type(WorkerType::Regular)
|
||||
.build();
|
||||
|
||||
// This will fail because there's no server at this URL,
|
||||
// but it tests that the sync wrapper doesn't panic
|
||||
let result = worker.check_health();
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
// Performance test for load counter
|
||||
#[test]
|
||||
fn test_load_counter_performance() {
|
||||
use crate::core::BasicWorkerBuilder;
|
||||
@@ -1436,12 +1363,9 @@ mod tests {
|
||||
let ops_per_sec = iterations as f64 / duration.as_secs_f64();
|
||||
println!("Load counter operations per second: {:.0}", ops_per_sec);
|
||||
|
||||
// Should be well over 1M ops/sec
|
||||
assert!(ops_per_sec > 1_000_000.0);
|
||||
}
|
||||
|
||||
// ===== Tests for DPAwareWorker =====
|
||||
|
||||
#[test]
|
||||
fn test_dp_aware_worker_creation() {
|
||||
let dp_worker = DPAwareWorkerBuilder::new("http://worker1:8080", 2, 4)
|
||||
@@ -1562,8 +1486,6 @@ mod tests {
|
||||
assert_eq!(dp_worker.processed_requests(), 1);
|
||||
}
|
||||
|
||||
// ===== Tests for WorkerFactory async methods =====
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_factory_create_dp_aware() {
|
||||
let worker = WorkerFactory::create_dp_aware(
|
||||
@@ -1610,26 +1532,21 @@ mod tests {
|
||||
.worker_type(WorkerType::Regular)
|
||||
.build();
|
||||
|
||||
// Initial state should be available
|
||||
assert!(worker.is_available());
|
||||
assert_eq!(worker.circuit_breaker().state(), CircuitState::Closed);
|
||||
|
||||
// Record some failures
|
||||
worker.record_outcome(false);
|
||||
worker.record_outcome(false);
|
||||
|
||||
// Still available (default threshold is 5)
|
||||
assert!(worker.is_available());
|
||||
|
||||
// Record more failures to open circuit
|
||||
worker.record_outcome(false);
|
||||
worker.record_outcome(false);
|
||||
worker.record_outcome(false);
|
||||
|
||||
// Circuit should be open, worker not available
|
||||
assert!(!worker.is_available());
|
||||
assert!(worker.is_healthy()); // Still healthy
|
||||
assert!(!worker.circuit_breaker().can_execute()); // But circuit is open
|
||||
assert!(worker.is_healthy());
|
||||
assert!(!worker.circuit_breaker().can_execute());
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -1647,20 +1564,16 @@ mod tests {
|
||||
.circuit_breaker_config(config)
|
||||
.build();
|
||||
|
||||
// Should open after 2 failures
|
||||
worker.record_outcome(false);
|
||||
assert!(worker.is_available());
|
||||
worker.record_outcome(false);
|
||||
assert!(!worker.is_available());
|
||||
|
||||
// Wait for timeout
|
||||
thread::sleep(Duration::from_millis(150));
|
||||
|
||||
// Should be half-open
|
||||
assert!(worker.is_available());
|
||||
assert_eq!(worker.circuit_breaker().state(), CircuitState::HalfOpen);
|
||||
|
||||
// Success should close it
|
||||
worker.record_outcome(true);
|
||||
assert_eq!(worker.circuit_breaker().state(), CircuitState::Closed);
|
||||
}
|
||||
@@ -1671,24 +1584,18 @@ mod tests {
|
||||
.worker_type(WorkerType::Regular)
|
||||
.build();
|
||||
|
||||
// Should have circuit breaker
|
||||
assert!(dp_worker.is_available());
|
||||
|
||||
// Record failures
|
||||
for _ in 0..5 {
|
||||
dp_worker.record_outcome(false);
|
||||
}
|
||||
|
||||
// Should not be available
|
||||
assert!(!dp_worker.is_available());
|
||||
assert_eq!(dp_worker.circuit_breaker().state(), CircuitState::Open);
|
||||
}
|
||||
|
||||
// ===== Integration tests =====
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_mixed_worker_types() {
|
||||
// Create a mix of worker types
|
||||
let regular: Box<dyn Worker> = Box::new(
|
||||
BasicWorkerBuilder::new("http://regular:8080")
|
||||
.worker_type(WorkerType::Regular)
|
||||
@@ -1739,22 +1646,19 @@ mod tests {
|
||||
dp_aware_decode,
|
||||
];
|
||||
|
||||
// Test that they all implement Worker trait properly
|
||||
for worker in &workers {
|
||||
assert!(worker.is_healthy());
|
||||
assert_eq!(worker.load(), 0);
|
||||
assert_eq!(worker.processed_requests(), 0);
|
||||
}
|
||||
|
||||
// Test specific behaviors
|
||||
assert!(!workers[0].is_dp_aware()); // regular
|
||||
assert!(!workers[1].is_dp_aware()); // prefill
|
||||
assert!(!workers[2].is_dp_aware()); // decode
|
||||
assert!(workers[3].is_dp_aware()); // dp_aware_regular
|
||||
assert!(workers[4].is_dp_aware()); // dp_aware_prefill
|
||||
assert!(workers[5].is_dp_aware()); // dp_aware_decode
|
||||
assert!(!workers[0].is_dp_aware());
|
||||
assert!(!workers[1].is_dp_aware());
|
||||
assert!(!workers[2].is_dp_aware());
|
||||
assert!(workers[3].is_dp_aware());
|
||||
assert!(workers[4].is_dp_aware());
|
||||
assert!(workers[5].is_dp_aware());
|
||||
|
||||
// Test worker types
|
||||
assert_eq!(workers[0].worker_type(), WorkerType::Regular);
|
||||
assert_eq!(
|
||||
workers[1].worker_type(),
|
||||
|
||||
@@ -7,10 +7,7 @@ use std::collections::HashMap;
|
||||
|
||||
/// Builder for creating BasicWorker instances with fluent API
|
||||
pub struct BasicWorkerBuilder {
|
||||
// Required fields
|
||||
url: String,
|
||||
|
||||
// Optional fields with defaults
|
||||
api_key: Option<String>,
|
||||
worker_type: WorkerType,
|
||||
connection_mode: ConnectionMode,
|
||||
@@ -21,7 +18,7 @@ pub struct BasicWorkerBuilder {
|
||||
}
|
||||
|
||||
impl BasicWorkerBuilder {
|
||||
/// Create a new builder with only the URL (defaults to Regular worker type)
|
||||
/// Create a new builder with only the URL
|
||||
pub fn new(url: impl Into<String>) -> Self {
|
||||
Self {
|
||||
url: url.into(),
|
||||
@@ -129,13 +126,10 @@ impl BasicWorkerBuilder {
|
||||
|
||||
/// Builder for creating DPAwareWorker instances with fluent API
|
||||
pub struct DPAwareWorkerBuilder {
|
||||
// Required fields
|
||||
base_url: String,
|
||||
api_key: Option<String>,
|
||||
dp_rank: usize,
|
||||
dp_size: usize,
|
||||
|
||||
// Optional fields with defaults
|
||||
worker_type: WorkerType,
|
||||
connection_mode: ConnectionMode,
|
||||
labels: HashMap<String, String>,
|
||||
@@ -145,7 +139,7 @@ pub struct DPAwareWorkerBuilder {
|
||||
}
|
||||
|
||||
impl DPAwareWorkerBuilder {
|
||||
/// Create a new DP-aware worker builder (defaults to Regular worker type)
|
||||
/// Create a new DP-aware worker builder
|
||||
pub fn new(base_url: impl Into<String>, dp_rank: usize, dp_size: usize) -> Self {
|
||||
Self {
|
||||
base_url: base_url.into(),
|
||||
@@ -232,10 +226,7 @@ impl DPAwareWorkerBuilder {
|
||||
|
||||
/// Build the DPAwareWorker instance
|
||||
pub fn build(self) -> DPAwareWorker {
|
||||
// Create URL with DP rank suffix for identification
|
||||
let worker_url = format!("{}@{}", self.base_url, self.dp_rank);
|
||||
|
||||
// Use BasicWorkerBuilder to create a properly configured base worker
|
||||
let mut builder = BasicWorkerBuilder::new(worker_url)
|
||||
.worker_type(self.worker_type)
|
||||
.connection_mode(self.connection_mode)
|
||||
@@ -243,18 +234,14 @@ impl DPAwareWorkerBuilder {
|
||||
.health_config(self.health_config)
|
||||
.circuit_breaker_config(self.circuit_breaker_config);
|
||||
|
||||
// Add gRPC client if provided
|
||||
if let Some(client) = self.grpc_client {
|
||||
builder = builder.grpc_client(client);
|
||||
}
|
||||
// Add API key if provided
|
||||
if let Some(api_key) = self.api_key {
|
||||
builder = builder.api_key(api_key);
|
||||
}
|
||||
|
||||
let base_worker = builder.build();
|
||||
|
||||
// Create the DPAwareWorker with the configured base worker
|
||||
DPAwareWorker::with_base_worker(base_worker, self.base_url, self.dp_rank, self.dp_size)
|
||||
}
|
||||
}
|
||||
@@ -267,7 +254,6 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_basic_worker_builder_minimal() {
|
||||
// Using new API - defaults to Regular type
|
||||
let worker = BasicWorkerBuilder::new("http://localhost:8080").build();
|
||||
|
||||
assert_eq!(worker.url(), "http://localhost:8080");
|
||||
@@ -278,7 +264,6 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_basic_worker_builder_with_type() {
|
||||
// Test setting worker type explicitly
|
||||
let worker = BasicWorkerBuilder::new("http://localhost:8080")
|
||||
.worker_type(WorkerType::Decode)
|
||||
.build();
|
||||
@@ -332,7 +317,6 @@ mod tests {
|
||||
ConnectionMode::Grpc { port: Some(50051) }
|
||||
);
|
||||
assert_eq!(worker.metadata().labels, labels);
|
||||
// Can't directly compare HealthConfig without PartialEq, so check individual fields
|
||||
assert_eq!(
|
||||
worker.metadata().health_config.endpoint,
|
||||
health_config.endpoint
|
||||
@@ -375,13 +359,11 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_dp_aware_worker_builder_minimal() {
|
||||
// Using new API - defaults to Regular type
|
||||
let worker = DPAwareWorkerBuilder::new("http://localhost:8080", 2, 8).build();
|
||||
|
||||
assert_eq!(worker.url(), "http://localhost:8080@2");
|
||||
assert_eq!(worker.dp_rank(), Some(2));
|
||||
assert_eq!(worker.dp_size(), Some(8));
|
||||
// Note: base_url is a private field, we can only test through the url() method
|
||||
assert_eq!(worker.worker_type(), WorkerType::Regular);
|
||||
}
|
||||
|
||||
@@ -412,7 +394,6 @@ mod tests {
|
||||
assert_eq!(worker.dp_rank(), Some(3));
|
||||
assert_eq!(worker.dp_size(), Some(16));
|
||||
assert_eq!(worker.metadata().labels, labels);
|
||||
// Can't directly compare HealthConfig without PartialEq, so check individual fields
|
||||
assert_eq!(
|
||||
worker.metadata().health_config.endpoint,
|
||||
health_config.endpoint
|
||||
@@ -437,7 +418,6 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_dp_aware_worker_with_grpc() {
|
||||
// Test that DPAwareWorkerBuilder can set a gRPC client
|
||||
let worker = DPAwareWorkerBuilder::new("grpc://cluster.local", 1, 4)
|
||||
.worker_type(WorkerType::Decode)
|
||||
.connection_mode(ConnectionMode::Grpc { port: Some(50051) })
|
||||
@@ -456,9 +436,5 @@ mod tests {
|
||||
worker.metadata().labels.get("transport"),
|
||||
Some(&"grpc".to_string())
|
||||
);
|
||||
|
||||
// Note: We can't directly test the grpc_client as it's private,
|
||||
// but the fact that the worker builds successfully with grpc connection mode
|
||||
// validates that the configuration is properly passed through
|
||||
}
|
||||
}
|
||||
|
||||
1024
sgl-router/src/core/worker_manager.rs
Normal file
1024
sgl-router/src/core/worker_manager.rs
Normal file
File diff suppressed because it is too large
Load Diff
@@ -34,7 +34,6 @@ impl Default for WorkerId {
|
||||
}
|
||||
}
|
||||
|
||||
/// Type alias for the model index to reduce complexity
|
||||
type ModelIndex = Arc<DashMap<String, Arc<RwLock<Vec<Arc<dyn Worker>>>>>>;
|
||||
|
||||
/// Worker registry with model-based indexing
|
||||
@@ -54,8 +53,7 @@ pub struct WorkerRegistry {
|
||||
|
||||
/// Workers indexed by connection mode
|
||||
connection_workers: Arc<DashMap<ConnectionMode, Vec<WorkerId>>>,
|
||||
|
||||
/// URL to worker ID mapping (for backward compatibility)
|
||||
/// URL to worker ID mapping
|
||||
url_to_id: Arc<DashMap<String, WorkerId>>,
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user