Grpc client (#9939)

This commit is contained in:
Chang Su
2025-09-02 11:47:35 -07:00
committed by GitHub
parent 4d89389c4f
commit 11dcabc545
3 changed files with 99 additions and 21 deletions

View File

@@ -1,4 +1,5 @@
use super::{CircuitBreaker, CircuitBreakerConfig, WorkerError, WorkerResult};
use crate::grpc::SglangSchedulerClient;
use crate::metrics::RouterMetrics;
use async_trait::async_trait;
use futures;
@@ -6,6 +7,7 @@ use serde_json;
use std::fmt;
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::sync::{Arc, LazyLock};
use tokio::sync::Mutex;
// Shared HTTP client for worker operations (health checks, server info, etc.)
static WORKER_CLIENT: LazyLock<reqwest::Client> = LazyLock::new(|| {
@@ -249,7 +251,7 @@ pub struct WorkerMetadata {
}
/// Basic worker implementation
#[derive(Debug, Clone)]
#[derive(Clone)]
pub struct BasicWorker {
metadata: WorkerMetadata,
load_counter: Arc<AtomicUsize>,
@@ -258,6 +260,19 @@ pub struct BasicWorker {
consecutive_failures: Arc<AtomicUsize>,
consecutive_successes: Arc<AtomicUsize>,
circuit_breaker: CircuitBreaker,
/// Optional gRPC client for gRPC workers
grpc_client: Option<Arc<Mutex<SglangSchedulerClient>>>,
}
/// Hand-written `Debug` for [`BasicWorker`].
///
/// The gRPC client is reported only as a presence flag (`has_grpc_client`)
/// instead of being formatted itself.
/// NOTE(review): presumably the client type does not implement `Debug`,
/// which is why the derive was dropped — confirm.
///
/// Atomic state is shown as a snapshot taken with `Ordering::Relaxed`, so the
/// printed values may already be stale when read concurrently. Fields that are
/// not listed are signalled with a trailing `..` via `finish_non_exhaustive`.
impl fmt::Debug for BasicWorker {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("BasicWorker")
            .field("metadata", &self.metadata)
            .field("load_counter", &self.load_counter.load(Ordering::Relaxed))
            .field("healthy", &self.healthy.load(Ordering::Relaxed))
            .field(
                "consecutive_failures",
                &self.consecutive_failures.load(Ordering::Relaxed),
            )
            .field(
                "consecutive_successes",
                &self.consecutive_successes.load(Ordering::Relaxed),
            )
            .field("circuit_breaker", &self.circuit_breaker)
            .field("has_grpc_client", &self.grpc_client.is_some())
            // Not every field is printed; make that explicit in the output.
            .finish_non_exhaustive()
    }
}
impl BasicWorker {
@@ -286,6 +301,7 @@ impl BasicWorker {
consecutive_failures: Arc::new(AtomicUsize::new(0)),
consecutive_successes: Arc::new(AtomicUsize::new(0)),
circuit_breaker: CircuitBreaker::new(),
grpc_client: None,
}
}
@@ -304,6 +320,12 @@ impl BasicWorker {
self
}
/// Attach a gRPC scheduler client to this worker (builder style).
///
/// The client is wrapped in `Arc<Mutex<_>>` and stored in `grpc_client`,
/// replacing whatever was stored there before; the updated worker is
/// returned by value so calls can be chained.
pub fn with_grpc_client(mut self, client: SglangSchedulerClient) -> Self {
    let shared_client = Arc::new(Mutex::new(client));
    self.grpc_client = Some(shared_client);
    self
}
pub fn normalised_url(&self) -> WorkerResult<&str> {
if self.url().contains("@") {
// Need to extract the URL from "http://host:port@dp_rank"
@@ -352,15 +374,46 @@ impl Worker for BasicWorker {
async fn check_health_async(&self) -> WorkerResult<()> {
use std::time::Duration;
// Perform actual HTTP health check
let url = self.normalised_url()?;
let health_url = format!("{}{}", url, self.metadata.health_config.endpoint);
let timeout = Duration::from_secs(self.metadata.health_config.timeout_secs);
let health_result = match &self.metadata.connection_mode {
ConnectionMode::Http => {
// Perform HTTP health check
let url = self.normalised_url()?;
let health_url = format!("{}{}", url, self.metadata.health_config.endpoint);
let timeout = Duration::from_secs(self.metadata.health_config.timeout_secs);
// Use the shared client with a custom timeout for this request
let health_result = match WORKER_CLIENT.get(&health_url).timeout(timeout).send().await {
Ok(response) => response.status().is_success(),
Err(_) => false,
// Use the shared client with a custom timeout for this request
match WORKER_CLIENT.get(&health_url).timeout(timeout).send().await {
Ok(response) => response.status().is_success(),
Err(_) => false,
}
}
ConnectionMode::Grpc { .. } => {
// Perform gRPC health check
if let Some(grpc_client) = &self.grpc_client {
let mut client = grpc_client.lock().await;
match client.health_check().await {
Ok(response) => {
tracing::debug!(
"gRPC health check succeeded for {}: healthy={}",
self.metadata.url,
response.healthy
);
response.healthy
}
Err(e) => {
tracing::warn!(
"gRPC health check RPC failed for {}: {:?}",
self.metadata.url,
e
);
false
}
}
} else {
tracing::error!("No gRPC client available for worker {}", self.metadata.url);
false
}
}
};
if health_result {
@@ -390,7 +443,7 @@ impl Worker for BasicWorker {
}
Err(WorkerError::HealthCheckFailed {
url: url.to_string(),
url: self.metadata.url.clone(),
reason: format!("Health check failed (consecutive failures: {})", failures),
})
}
@@ -1491,12 +1544,17 @@ mod tests {
// Clone for use inside catch_unwind
let worker_clone = Arc::clone(&worker);
// Use AssertUnwindSafe wrapper for the test
// This is safe because we're only testing the load counter behavior,
// not the grpc_client which is None for HTTP workers
use std::panic::AssertUnwindSafe;
// This will panic, but the guard should still clean up
let result = std::panic::catch_unwind(|| {
let result = std::panic::catch_unwind(AssertUnwindSafe(|| {
let _guard = WorkerLoadGuard::new(worker_clone.as_ref());
assert_eq!(worker_clone.load(), 1);
panic!("Test panic");
});
}));
// Verify panic occurred
assert!(result.is_err());