diff --git a/sgl-router/src/core/worker.rs b/sgl-router/src/core/worker.rs index bf61e83a5..3970346c5 100644 --- a/sgl-router/src/core/worker.rs +++ b/sgl-router/src/core/worker.rs @@ -9,7 +9,7 @@ use serde_json; use std::fmt; use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; use std::sync::{Arc, LazyLock}; -use tokio::sync::Mutex; +use tokio::sync::{Mutex, RwLock}; static WORKER_CLIENT: LazyLock = LazyLock::new(|| { reqwest::Client::builder() @@ -337,8 +337,8 @@ pub struct BasicWorker { pub consecutive_failures: Arc, pub consecutive_successes: Arc, pub circuit_breaker: CircuitBreaker, - /// Optional gRPC client for gRPC workers - pub grpc_client: Option>>, + /// Lazily initialized gRPC client for gRPC workers + pub grpc_client: Arc>>>>, } impl fmt::Debug for BasicWorker { @@ -347,7 +347,7 @@ impl fmt::Debug for BasicWorker { .field("metadata", &self.metadata) .field("healthy", &self.healthy.load(Ordering::Relaxed)) .field("circuit_breaker", &self.circuit_breaker) - .field("has_grpc_client", &self.grpc_client.is_some()) + .field("grpc_client", &"") .finish() } } @@ -421,7 +421,7 @@ impl Worker for BasicWorker { } } ConnectionMode::Grpc { .. } => { - // Use the new get_grpc_client() method + // Use the new get_grpc_client() method for lazy initialization match self.get_grpc_client().await { Ok(Some(grpc_client)) => { let mut client = grpc_client.lock().await; @@ -532,19 +532,45 @@ impl Worker for BasicWorker { match self.metadata.connection_mode { ConnectionMode::Http => Ok(None), ConnectionMode::Grpc { .. } => { - // If we already have a client, return it - if let Some(ref client) = self.grpc_client { + { + let client_guard = self.grpc_client.read().await; + if let Some(ref client) = *client_guard { + return Ok(Some(client.clone())); + } + } + + let mut client_guard = self.grpc_client.write().await; + + if let Some(ref client) = *client_guard { return Ok(Some(client.clone())); } - // For lazy initialization, we would need to change grpc_client to be mutable - // For now, return error if no client exists (will be initialized during worker creation) - Err(WorkerError::ConnectionFailed { - url: self.metadata.url.clone(), - reason: - "gRPC client not initialized. Client should be set during worker creation" - .to_string(), - }) + tracing::info!( + "Lazily initializing gRPC client for worker: {}", + self.metadata.url + ); + match SglangSchedulerClient::connect(&self.metadata.url).await { + Ok(client) => { + let client_arc = Arc::new(Mutex::new(client)); + *client_guard = Some(client_arc.clone()); + tracing::info!( + "Successfully connected gRPC client for worker: {}", + self.metadata.url + ); + Ok(Some(client_arc)) + } + Err(e) => { + tracing::error!( + "Failed to connect gRPC client for worker {}: {}", + self.metadata.url, + e + ); + Err(WorkerError::ConnectionFailed { + url: self.metadata.url.clone(), + reason: format!("Failed to connect to gRPC server: {}", e), + }) + } + } } } } @@ -553,12 +579,11 @@ impl Worker for BasicWorker { match self.metadata.connection_mode { ConnectionMode::Http => Ok(()), ConnectionMode::Grpc { .. } => { - // For now, we can't reset the client since it's not mutable - // This would require changing the grpc_client field to use RwLock or OnceCell - // which we'll do in a future iteration - tracing::warn!( - "gRPC client reset not yet implemented - requires mutable client storage" - ); + let mut client_guard = self.grpc_client.write().await; + if client_guard.is_some() { + tracing::info!("Resetting gRPC client for worker: {}", self.metadata.url); + *client_guard = None; + } Ok(()) } } diff --git a/sgl-router/src/core/worker_builder.rs b/sgl-router/src/core/worker_builder.rs index 9dd03b30a..4e156bb42 100644 --- a/sgl-router/src/core/worker_builder.rs +++ b/sgl-router/src/core/worker_builder.rs @@ -100,7 +100,7 @@ impl BasicWorkerBuilder { atomic::{AtomicBool, AtomicUsize}, Arc, }; - use tokio::sync::Mutex; + use tokio::sync::{Mutex, RwLock}; let metadata = WorkerMetadata { url: self.url.clone(), @@ -111,6 +111,10 @@ impl BasicWorkerBuilder { health_config: self.health_config, }; + let grpc_client = Arc::new(RwLock::new( + self.grpc_client.map(|client| Arc::new(Mutex::new(client))), + )); + BasicWorker { metadata, load_counter: Arc::new(AtomicUsize::new(0)), @@ -119,7 +123,7 @@ impl BasicWorkerBuilder { consecutive_failures: Arc::new(AtomicUsize::new(0)), consecutive_successes: Arc::new(AtomicUsize::new(0)), circuit_breaker: CircuitBreaker::with_config(self.circuit_breaker_config), - grpc_client: self.grpc_client.map(|client| Arc::new(Mutex::new(client))), + grpc_client, } } } diff --git a/sgl-router/src/routers/grpc/pd_router.rs b/sgl-router/src/routers/grpc/pd_router.rs index d28560dbf..6c259456a 100644 --- a/sgl-router/src/routers/grpc/pd_router.rs +++ b/sgl-router/src/routers/grpc/pd_router.rs @@ -4,7 +4,6 @@ use crate::config::types::RetryConfig; use crate::core::{ BasicWorkerBuilder, CircuitBreakerConfig, HealthConfig, WorkerRegistry, WorkerType, }; -use crate::grpc_client::SglangSchedulerClient; use crate::metrics::RouterMetrics; use crate::policies::{LoadBalancingPolicy, PolicyRegistry}; use crate::reasoning_parser::ParserFactory; @@ -18,10 +17,9 @@ use axum::{ http::{HeaderMap, StatusCode}, response::{IntoResponse, Response}, }; -use std::collections::HashMap; use std::sync::Arc; use std::time::Duration; -use tracing::{info, warn}; +use tracing::info; /// gRPC PD (Prefill-Decode) router implementation for SGLang #[allow(dead_code)] // Fields will be used once implementation is complete @@ -89,86 +87,55 @@ impl GrpcPDRouter { window_duration: Duration::from_secs(circuit_breaker_config.window_duration_secs), }; - // Create gRPC clients for prefill workers - let mut prefill_grpc_clients = HashMap::new(); - for (url, _bootstrap_port) in &prefill_urls { - match SglangSchedulerClient::connect(url).await { - Ok(client) => { - prefill_grpc_clients.insert(url.clone(), client); - info!("Connected to gRPC prefill worker at {}", url); - } - Err(e) => { - warn!("Failed to connect to gRPC prefill worker at {}: {}", url, e); - // Continue with other workers - } - } - } - - // Create gRPC clients for decode workers - let mut decode_grpc_clients = HashMap::new(); - for url in &decode_urls { - match SglangSchedulerClient::connect(url).await { - Ok(client) => { - decode_grpc_clients.insert(url.clone(), client); - info!("Connected to gRPC decode worker at {}", url); - } - Err(e) => { - warn!("Failed to connect to gRPC decode worker at {}: {}", url, e); - // Continue with other workers - } - } - } - - if prefill_grpc_clients.is_empty() && decode_grpc_clients.is_empty() { - return Err("Failed to connect to any gRPC workers".to_string()); - } - - // Create Prefill Worker trait objects with gRPC connection mode and register them for (url, bootstrap_port) in &prefill_urls { - if let Some(client) = prefill_grpc_clients.remove(url) { - let worker = BasicWorkerBuilder::new(url.clone()) - .worker_type(WorkerType::Prefill { - bootstrap_port: *bootstrap_port, - }) - .connection_mode(crate::core::ConnectionMode::Grpc { - port: *bootstrap_port, - }) - .circuit_breaker_config(core_cb_config.clone()) - .health_config(HealthConfig { - timeout_secs: ctx.router_config.health_check.timeout_secs, - check_interval_secs: ctx.router_config.health_check.check_interval_secs, - endpoint: ctx.router_config.health_check.endpoint.clone(), - failure_threshold: ctx.router_config.health_check.failure_threshold, - success_threshold: ctx.router_config.health_check.success_threshold, - }) - .grpc_client(client) - .build(); + let worker = BasicWorkerBuilder::new(url.clone()) + .worker_type(WorkerType::Prefill { + bootstrap_port: *bootstrap_port, + }) + .connection_mode(crate::core::ConnectionMode::Grpc { + port: *bootstrap_port, + }) + .circuit_breaker_config(core_cb_config.clone()) + .health_config(HealthConfig { + timeout_secs: ctx.router_config.health_check.timeout_secs, + check_interval_secs: ctx.router_config.health_check.check_interval_secs, + endpoint: ctx.router_config.health_check.endpoint.clone(), + failure_threshold: ctx.router_config.health_check.failure_threshold, + success_threshold: ctx.router_config.health_check.success_threshold, + }) + // No longer passing pre-initialized client - will be created lazily + .build(); - // Register worker in the centralized registry - worker_registry.register(Arc::new(worker)); - } + worker_registry.register(Arc::new(worker)); + info!( + "Registered gRPC prefill worker at {} (will connect on first use)", + url + ); } - // Create Decode Worker trait objects with gRPC connection mode and register them for url in &decode_urls { - if let Some(client) = decode_grpc_clients.remove(url) { - let worker = BasicWorkerBuilder::new(url.clone()) - .worker_type(WorkerType::Decode) - .connection_mode(crate::core::ConnectionMode::Grpc { port: None }) - .circuit_breaker_config(core_cb_config.clone()) - .health_config(HealthConfig { - timeout_secs: ctx.router_config.health_check.timeout_secs, - check_interval_secs: ctx.router_config.health_check.check_interval_secs, - endpoint: ctx.router_config.health_check.endpoint.clone(), - failure_threshold: ctx.router_config.health_check.failure_threshold, - success_threshold: ctx.router_config.health_check.success_threshold, - }) - .grpc_client(client) - .build(); + let worker = BasicWorkerBuilder::new(url.clone()) + .worker_type(WorkerType::Decode) + .connection_mode(crate::core::ConnectionMode::Grpc { port: None }) + .circuit_breaker_config(core_cb_config.clone()) + .health_config(HealthConfig { + timeout_secs: ctx.router_config.health_check.timeout_secs, + check_interval_secs: ctx.router_config.health_check.check_interval_secs, + endpoint: ctx.router_config.health_check.endpoint.clone(), + failure_threshold: ctx.router_config.health_check.failure_threshold, + success_threshold: ctx.router_config.health_check.success_threshold, + }) + .build(); - // Register worker in the centralized registry - worker_registry.register(Arc::new(worker)); - } + worker_registry.register(Arc::new(worker)); + info!( + "Registered gRPC decode worker at {} (will connect on first use)", + url + ); + } + + if prefill_urls.is_empty() && decode_urls.is_empty() { + return Err("No gRPC workers configured".to_string()); } // Initialize policies with workers if needed - filter for gRPC workers only diff --git a/sgl-router/src/routers/grpc/router.rs b/sgl-router/src/routers/grpc/router.rs index 505fee145..de8e04f86 100644 --- a/sgl-router/src/routers/grpc/router.rs +++ b/sgl-router/src/routers/grpc/router.rs @@ -1,6 +1,5 @@ // gRPC Router Implementation -use std::collections::HashMap; use std::sync::Arc; use std::time::Duration; @@ -96,51 +95,35 @@ impl GrpcRouter { window_duration: Duration::from_secs(circuit_breaker_config.window_duration_secs), }; - // Create gRPC clients for each worker - let mut grpc_clients = HashMap::new(); - for url in &worker_urls { - match SglangSchedulerClient::connect(url).await { - Ok(client) => { - grpc_clients.insert(url.clone(), client); - info!("Connected to gRPC worker at {}", url); - } - Err(e) => { - warn!("Failed to connect to gRPC worker at {}: {}", url, e); - // Continue with other workers - } - } - } - - if grpc_clients.is_empty() { - return Err("Failed to connect to any gRPC workers".to_string()); - } - // Get registries from context let worker_registry = ctx.worker_registry.clone(); let policy_registry = ctx.policy_registry.clone(); // Create Worker trait objects with gRPC connection mode and register them + // Workers will lazily initialize their gRPC clients on first use for url in &worker_urls { - if let Some(client) = grpc_clients.remove(url) { - let worker = BasicWorkerBuilder::new(url.clone()) - .worker_type(WorkerType::Regular) - .connection_mode(crate::core::ConnectionMode::Grpc { port: None }) - .circuit_breaker_config(core_cb_config.clone()) - .health_config(HealthConfig { - timeout_secs: ctx.router_config.health_check.timeout_secs, - check_interval_secs: ctx.router_config.health_check.check_interval_secs, - endpoint: ctx.router_config.health_check.endpoint.clone(), - failure_threshold: ctx.router_config.health_check.failure_threshold, - success_threshold: ctx.router_config.health_check.success_threshold, - }) - .grpc_client(client) - .build(); + let worker = BasicWorkerBuilder::new(url.clone()) + .worker_type(WorkerType::Regular) + .connection_mode(crate::core::ConnectionMode::Grpc { port: None }) + .circuit_breaker_config(core_cb_config.clone()) + .health_config(HealthConfig { + timeout_secs: ctx.router_config.health_check.timeout_secs, + check_interval_secs: ctx.router_config.health_check.check_interval_secs, + endpoint: ctx.router_config.health_check.endpoint.clone(), + failure_threshold: ctx.router_config.health_check.failure_threshold, + success_threshold: ctx.router_config.health_check.success_threshold, + }) + .build(); - // Register worker in the centralized registry - worker_registry.register(Arc::new(worker)); - } else { - warn!("No gRPC client for worker {}, skipping", url); - } + worker_registry.register(Arc::new(worker)); + info!( + "Registered gRPC worker at {} (will connect on first use)", + url + ); + } + + if worker_urls.is_empty() { + return Err("No gRPC workers configured".to_string()); } // Get only gRPC workers from registry for policy initialization