diff --git a/sgl-router/src/core/error.rs b/sgl-router/src/core/error.rs
index 04fa40c90..fbe033590 100644
--- a/sgl-router/src/core/error.rs
+++ b/sgl-router/src/core/error.rs
@@ -19,6 +19,8 @@ pub enum WorkerError {
     WorkerAtCapacity { url: String },
     /// Invalid URL format
     InvalidUrl { url: String },
+    /// Connection failed
+    ConnectionFailed { url: String, reason: String },
 }
 
 impl fmt::Display for WorkerError {
@@ -42,6 +44,9 @@ impl fmt::Display for WorkerError {
             WorkerError::InvalidUrl { url } => {
                 write!(f, "Invalid URL format: {}", url)
             }
+            WorkerError::ConnectionFailed { url, reason } => {
+                write!(f, "Connection failed for worker {}: {}", url, reason)
+            }
         }
     }
 }
diff --git a/sgl-router/src/core/worker.rs b/sgl-router/src/core/worker.rs
index 722510fc2..bf61e83a5 100644
--- a/sgl-router/src/core/worker.rs
+++ b/sgl-router/src/core/worker.rs
@@ -220,6 +220,16 @@ pub trait Worker: Send + Sync + fmt::Debug {
             .get("chat_template")
             .map(|s| s.as_str())
     }
+
+    /// Get or create a gRPC client for this worker
+    /// Returns None for HTTP workers, Some(client) for gRPC workers
+    async fn get_grpc_client(&self) -> WorkerResult<Option<Arc<Mutex<SglangSchedulerClient>>>>;
+
+    /// Reset the gRPC client connection (for reconnection scenarios)
+    /// No-op for HTTP workers
+    async fn reset_grpc_client(&self) -> WorkerResult<()> {
+        Ok(())
+    }
 }
 
 /// Connection mode for worker communication
@@ -411,29 +421,44 @@ impl Worker for BasicWorker {
                 }
             }
             ConnectionMode::Grpc { .. } => {
-                if let Some(grpc_client) = &self.grpc_client {
-                    let mut client = grpc_client.lock().await;
-                    match client.health_check().await {
-                        Ok(response) => {
-                            tracing::debug!(
-                                "gRPC health check succeeded for {}: healthy={}",
-                                self.metadata.url,
+                // Use the new get_grpc_client() method
+                match self.get_grpc_client().await {
+                    Ok(Some(grpc_client)) => {
+                        let mut client = grpc_client.lock().await;
+                        match client.health_check().await {
+                            Ok(response) => {
+                                tracing::debug!(
+                                    "gRPC health check succeeded for {}: healthy={}",
+                                    self.metadata.url,
+                                    response.healthy
+                                );
                                 response.healthy
-                            );
-                            response.healthy
-                        }
-                        Err(e) => {
-                            tracing::warn!(
-                                "gRPC health check RPC failed for {}: {:?}",
-                                self.metadata.url,
-                                e
-                            );
-                            false
+                            }
+                            Err(e) => {
+                                tracing::warn!(
+                                    "gRPC health check RPC failed for {}: {:?}",
+                                    self.metadata.url,
+                                    e
+                                );
+                                false
+                            }
                         }
                     }
-                } else {
-                    tracing::error!("No gRPC client available for worker {}", self.metadata.url);
-                    false
+                    Ok(None) => {
+                        tracing::error!(
+                            "Worker {} is not a gRPC worker but has gRPC connection mode",
+                            self.metadata.url
+                        );
+                        false
+                    }
+                    Err(e) => {
+                        tracing::error!(
+                            "Failed to get gRPC client for worker {}: {:?}",
+                            self.metadata.url,
+                            e
+                        );
+                        false
+                    }
                 }
             }
         };
@@ -502,6 +527,42 @@ impl Worker for BasicWorker {
     fn circuit_breaker(&self) -> &CircuitBreaker {
         &self.circuit_breaker
     }
+
+    async fn get_grpc_client(&self) -> WorkerResult<Option<Arc<Mutex<SglangSchedulerClient>>>> {
+        match self.metadata.connection_mode {
+            ConnectionMode::Http => Ok(None),
+            ConnectionMode::Grpc { .. } => {
+                // If we already have a client, return it
+                if let Some(ref client) = self.grpc_client {
+                    return Ok(Some(client.clone()));
+                }
+
+                // For lazy initialization, we would need to change grpc_client to be mutable
+                // For now, return error if no client exists (will be initialized during worker creation)
+                Err(WorkerError::ConnectionFailed {
+                    url: self.metadata.url.clone(),
+                    reason:
+                        "gRPC client not initialized. Client should be set during worker creation"
+                            .to_string(),
+                })
+            }
+        }
+    }
+
+    async fn reset_grpc_client(&self) -> WorkerResult<()> {
+        match self.metadata.connection_mode {
+            ConnectionMode::Http => Ok(()),
+            ConnectionMode::Grpc { .. } => {
+                // For now, we can't reset the client since it's not mutable
+                // This would require changing the grpc_client field to use RwLock or OnceCell
+                // which we'll do in a future iteration
+                tracing::warn!(
+                    "gRPC client reset not yet implemented - requires mutable client storage"
+                );
+                Ok(())
+            }
+        }
+    }
 }
 
 /// A DP-aware worker that handles data-parallel routing
@@ -630,6 +691,14 @@ impl Worker for DPAwareWorker {
     fn endpoint_url(&self, route: &str) -> String {
         format!("{}{}", self.base_url, route)
     }
+
+    async fn get_grpc_client(&self) -> WorkerResult<Option<Arc<Mutex<SglangSchedulerClient>>>> {
+        self.base_worker.get_grpc_client().await
+    }
+
+    async fn reset_grpc_client(&self) -> WorkerResult<()> {
+        self.base_worker.reset_grpc_client().await
+    }
 }
 
 /// Worker factory for creating workers of different types
diff --git a/sgl-router/src/grpc_client/sglang_scheduler.rs b/sgl-router/src/grpc_client/sglang_scheduler.rs
index d0f3c4c26..e6269d9dc 100644
--- a/sgl-router/src/grpc_client/sglang_scheduler.rs
+++ b/sgl-router/src/grpc_client/sglang_scheduler.rs
@@ -13,6 +13,7 @@ pub mod proto {
 // package sglang.grpc.scheduler; generates a nested module structure
 
 /// gRPC client for SGLang scheduler
+#[derive(Clone)]
 pub struct SglangSchedulerClient {
     client: proto::sglang_scheduler_client::SglangSchedulerClient<Channel>,
 }
diff --git a/sgl-router/src/routers/grpc/router.rs b/sgl-router/src/routers/grpc/router.rs
index f630100c6..505fee145 100644
--- a/sgl-router/src/routers/grpc/router.rs
+++ b/sgl-router/src/routers/grpc/router.rs
@@ -202,12 +202,23 @@ impl GrpcRouter {
 
         debug!("Selected worker: {}", worker.url());
 
-        // Step 2: Get gRPC client for worker (fail fast if can't connect)
-        // TODO(CahterineSue): manage grpc connection in worker. (it should be simpler here)
-        let client = match self.get_or_create_grpc_client(worker.url()).await {
-            Ok(c) => c,
+        // Step 2: Get gRPC client from worker
+        let client = match worker.get_grpc_client().await {
+            Ok(Some(client_arc)) => {
+                // Clone the client from inside the Arc<Mutex<>>
+                let client = client_arc.lock().await.clone();
+                client
+            }
+            Ok(None) => {
+                error!("Selected worker is not a gRPC worker");
+                return (
+                    StatusCode::INTERNAL_SERVER_ERROR,
+                    "Selected worker is not configured for gRPC",
+                )
+                    .into_response();
+            }
             Err(e) => {
-                error!("Failed to get gRPC client: {}", e);
+                error!("Failed to get gRPC client from worker: {}", e);
                 return (
                     StatusCode::INTERNAL_SERVER_ERROR,
                     format!("Failed to get gRPC client: {}", e),
@@ -552,18 +563,6 @@ impl GrpcRouter {
         None
     }
 
-    /// Get or create a gRPC client for the worker
-    async fn get_or_create_grpc_client(
-        &self,
-        worker_url: &str,
-    ) -> Result<SglangSchedulerClient, String> {
-        // TODO: move to worker
-        debug!("Creating new gRPC client for worker: {}", worker_url);
-        SglangSchedulerClient::connect(worker_url)
-            .await
-            .map_err(|e| format!("Failed to connect to gRPC server: {}", e))
-    }
-
     /// Placeholder for streaming handler (to be implemented in Phase 2)
     async fn handle_streaming_chat(
         &self,