From 11dcabc5459180e273acee16e20aac2d2f2ec056 Mon Sep 17 00:00:00 2001 From: Chang Su Date: Tue, 2 Sep 2025 11:47:35 -0700 Subject: [PATCH] Grpc client (#9939) --- sgl-router/src/core/worker.rs | 82 +++++++++++++++++++++++---- sgl-router/src/grpc/client.rs | 11 +++- sgl-router/src/routers/grpc/router.rs | 27 ++++++--- 3 files changed, 99 insertions(+), 21 deletions(-) diff --git a/sgl-router/src/core/worker.rs b/sgl-router/src/core/worker.rs index b054355f0..f25fc6eea 100644 --- a/sgl-router/src/core/worker.rs +++ b/sgl-router/src/core/worker.rs @@ -1,4 +1,5 @@ use super::{CircuitBreaker, CircuitBreakerConfig, WorkerError, WorkerResult}; +use crate::grpc::SglangSchedulerClient; use crate::metrics::RouterMetrics; use async_trait::async_trait; use futures; @@ -6,6 +7,7 @@ use serde_json; use std::fmt; use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; use std::sync::{Arc, LazyLock}; +use tokio::sync::Mutex; // Shared HTTP client for worker operations (health checks, server info, etc.) static WORKER_CLIENT: LazyLock = LazyLock::new(|| { @@ -249,7 +251,7 @@ pub struct WorkerMetadata { } /// Basic worker implementation -#[derive(Debug, Clone)] +#[derive(Clone)] pub struct BasicWorker { metadata: WorkerMetadata, load_counter: Arc, @@ -258,6 +260,19 @@ pub struct BasicWorker { consecutive_failures: Arc, consecutive_successes: Arc, circuit_breaker: CircuitBreaker, + /// Optional gRPC client for gRPC workers + grpc_client: Option>>, +} + +impl fmt::Debug for BasicWorker { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("BasicWorker") + .field("metadata", &self.metadata) + .field("healthy", &self.healthy.load(Ordering::Relaxed)) + .field("circuit_breaker", &self.circuit_breaker) + .field("has_grpc_client", &self.grpc_client.is_some()) + .finish() + } } impl BasicWorker { @@ -286,6 +301,7 @@ impl BasicWorker { consecutive_failures: Arc::new(AtomicUsize::new(0)), consecutive_successes: Arc::new(AtomicUsize::new(0)), circuit_breaker: CircuitBreaker::new(), + grpc_client: None, } } @@ -304,6 +320,12 @@ impl BasicWorker { self } + /// Set the gRPC client for gRPC workers + pub fn with_grpc_client(mut self, client: SglangSchedulerClient) -> Self { + self.grpc_client = Some(Arc::new(Mutex::new(client))); + self + } + pub fn normalised_url(&self) -> WorkerResult<&str> { if self.url().contains("@") { // Need to extract the URL from "http://host:port@dp_rank" @@ -352,15 +374,46 @@ impl Worker for BasicWorker { async fn check_health_async(&self) -> WorkerResult<()> { use std::time::Duration; - // Perform actual HTTP health check - let url = self.normalised_url()?; - let health_url = format!("{}{}", url, self.metadata.health_config.endpoint); - let timeout = Duration::from_secs(self.metadata.health_config.timeout_secs); + let health_result = match &self.metadata.connection_mode { + ConnectionMode::Http => { + // Perform HTTP health check + let url = self.normalised_url()?; + let health_url = format!("{}{}", url, self.metadata.health_config.endpoint); + let timeout = Duration::from_secs(self.metadata.health_config.timeout_secs); - // Use the shared client with a custom timeout for this request - let health_result = match WORKER_CLIENT.get(&health_url).timeout(timeout).send().await { - Ok(response) => response.status().is_success(), - Err(_) => false, + // Use the shared client with a custom timeout for this request + match WORKER_CLIENT.get(&health_url).timeout(timeout).send().await { + Ok(response) => response.status().is_success(), + Err(_) => false, + } + } + ConnectionMode::Grpc { .. } => { + // Perform gRPC health check + if let Some(grpc_client) = &self.grpc_client { + let mut client = grpc_client.lock().await; + match client.health_check().await { + Ok(response) => { + tracing::debug!( + "gRPC health check succeeded for {}: healthy={}", + self.metadata.url, + response.healthy + ); + response.healthy + } + Err(e) => { + tracing::warn!( + "gRPC health check RPC failed for {}: {:?}", + self.metadata.url, + e + ); + false + } + } + } else { + tracing::error!("No gRPC client available for worker {}", self.metadata.url); + false + } + } }; if health_result { @@ -390,7 +443,7 @@ impl Worker for BasicWorker { } Err(WorkerError::HealthCheckFailed { - url: url.to_string(), + url: self.metadata.url.clone(), reason: format!("Health check failed (consecutive failures: {})", failures), }) } @@ -1491,12 +1544,17 @@ mod tests { // Clone for use inside catch_unwind let worker_clone = Arc::clone(&worker); + // Use AssertUnwindSafe wrapper for the test + // This is safe because we're only testing the load counter behavior, + // not the grpc_client which is None for HTTP workers + use std::panic::AssertUnwindSafe; + // This will panic, but the guard should still clean up - let result = std::panic::catch_unwind(|| { + let result = std::panic::catch_unwind(AssertUnwindSafe(|| { let _guard = WorkerLoadGuard::new(worker_clone.as_ref()); assert_eq!(worker_clone.load(), 1); panic!("Test panic"); - }); + })); // Verify panic occurred assert!(result.is_err()); diff --git a/sgl-router/src/grpc/client.rs b/sgl-router/src/grpc/client.rs index f31227bb1..8561b79db 100644 --- a/sgl-router/src/grpc/client.rs +++ b/sgl-router/src/grpc/client.rs @@ -20,7 +20,14 @@ impl SglangSchedulerClient { pub async fn connect(endpoint: &str) -> Result> { debug!("Connecting to SGLang scheduler at {}", endpoint); - let channel = Channel::from_shared(endpoint.to_string())? + // Convert grpc:// to http:// for tonic + let http_endpoint = if endpoint.starts_with("grpc://") { + endpoint.replace("grpc://", "http://") + } else { + endpoint.to_string() + }; + + let channel = Channel::from_shared(http_endpoint)? .timeout(Duration::from_secs(30)) .connect() .await?; @@ -59,11 +66,13 @@ impl SglangSchedulerClient { pub async fn health_check( &mut self, ) -> Result> { + debug!("Sending health check request"); let request = Request::new(proto::HealthCheckRequest { include_detailed_metrics: false, }); let response = self.client.health_check(request).await?; + debug!("Health check response received"); Ok(response.into_inner()) } diff --git a/sgl-router/src/routers/grpc/router.rs b/sgl-router/src/routers/grpc/router.rs index e7a0bd162..f81a25917 100644 --- a/sgl-router/src/routers/grpc/router.rs +++ b/sgl-router/src/routers/grpc/router.rs @@ -108,9 +108,11 @@ impl GrpcRouter { } // Create Worker trait objects with gRPC connection mode - let workers: Vec> = worker_urls - .iter() - .map(|url| { + let mut workers: Vec> = Vec::new(); + + // Move clients from the HashMap to the workers + for url in &worker_urls { + if let Some(client) = grpc_clients.remove(url) { let worker = BasicWorker::with_connection_mode( url.clone(), WorkerType::Regular, @@ -123,10 +125,14 @@ impl GrpcRouter { endpoint: health_check_config.endpoint.clone(), failure_threshold: health_check_config.failure_threshold, success_threshold: health_check_config.success_threshold, - }); - Box::new(worker) as Box - }) - .collect(); + }) + .with_grpc_client(client); + + workers.push(Box::new(worker) as Box); + } else { + warn!("No gRPC client for worker {}, skipping", url); + } + } // Initialize policy with workers if needed if let Some(cache_aware) = policy @@ -252,6 +258,11 @@ impl WorkerManagement for GrpcRouter { fn remove_worker(&self, _worker_url: &str) {} fn get_worker_urls(&self) -> Vec { - vec![] + self.workers + .read() + .unwrap() + .iter() + .map(|w| w.url().to_string()) + .collect() } }