[router] move grpc client from router to worker and builder (#10958)
This commit is contained in:
@@ -9,7 +9,7 @@ use serde_json;
|
||||
use std::fmt;
|
||||
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
|
||||
use std::sync::{Arc, LazyLock};
|
||||
use tokio::sync::Mutex;
|
||||
use tokio::sync::{Mutex, RwLock};
|
||||
|
||||
static WORKER_CLIENT: LazyLock<reqwest::Client> = LazyLock::new(|| {
|
||||
reqwest::Client::builder()
|
||||
@@ -337,8 +337,8 @@ pub struct BasicWorker {
|
||||
pub consecutive_failures: Arc<AtomicUsize>,
|
||||
pub consecutive_successes: Arc<AtomicUsize>,
|
||||
pub circuit_breaker: CircuitBreaker,
|
||||
/// Optional gRPC client for gRPC workers
|
||||
pub grpc_client: Option<Arc<Mutex<SglangSchedulerClient>>>,
|
||||
/// Lazily initialized gRPC client for gRPC workers
|
||||
pub grpc_client: Arc<RwLock<Option<Arc<Mutex<SglangSchedulerClient>>>>>,
|
||||
}
|
||||
|
||||
impl fmt::Debug for BasicWorker {
|
||||
@@ -347,7 +347,7 @@ impl fmt::Debug for BasicWorker {
|
||||
.field("metadata", &self.metadata)
|
||||
.field("healthy", &self.healthy.load(Ordering::Relaxed))
|
||||
.field("circuit_breaker", &self.circuit_breaker)
|
||||
.field("has_grpc_client", &self.grpc_client.is_some())
|
||||
.field("grpc_client", &"<RwLock>")
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
@@ -421,7 +421,7 @@ impl Worker for BasicWorker {
|
||||
}
|
||||
}
|
||||
ConnectionMode::Grpc { .. } => {
|
||||
// Use the new get_grpc_client() method
|
||||
// Use the new get_grpc_client() method for lazy initialization
|
||||
match self.get_grpc_client().await {
|
||||
Ok(Some(grpc_client)) => {
|
||||
let mut client = grpc_client.lock().await;
|
||||
@@ -532,19 +532,45 @@ impl Worker for BasicWorker {
|
||||
match self.metadata.connection_mode {
|
||||
ConnectionMode::Http => Ok(None),
|
||||
ConnectionMode::Grpc { .. } => {
|
||||
// If we already have a client, return it
|
||||
if let Some(ref client) = self.grpc_client {
|
||||
{
|
||||
let client_guard = self.grpc_client.read().await;
|
||||
if let Some(ref client) = *client_guard {
|
||||
return Ok(Some(client.clone()));
|
||||
}
|
||||
}
|
||||
|
||||
let mut client_guard = self.grpc_client.write().await;
|
||||
|
||||
if let Some(ref client) = *client_guard {
|
||||
return Ok(Some(client.clone()));
|
||||
}
|
||||
|
||||
// For lazy initialization, we would need to change grpc_client to be mutable
|
||||
// For now, return error if no client exists (will be initialized during worker creation)
|
||||
Err(WorkerError::ConnectionFailed {
|
||||
url: self.metadata.url.clone(),
|
||||
reason:
|
||||
"gRPC client not initialized. Client should be set during worker creation"
|
||||
.to_string(),
|
||||
})
|
||||
tracing::info!(
|
||||
"Lazily initializing gRPC client for worker: {}",
|
||||
self.metadata.url
|
||||
);
|
||||
match SglangSchedulerClient::connect(&self.metadata.url).await {
|
||||
Ok(client) => {
|
||||
let client_arc = Arc::new(Mutex::new(client));
|
||||
*client_guard = Some(client_arc.clone());
|
||||
tracing::info!(
|
||||
"Successfully connected gRPC client for worker: {}",
|
||||
self.metadata.url
|
||||
);
|
||||
Ok(Some(client_arc))
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!(
|
||||
"Failed to connect gRPC client for worker {}: {}",
|
||||
self.metadata.url,
|
||||
e
|
||||
);
|
||||
Err(WorkerError::ConnectionFailed {
|
||||
url: self.metadata.url.clone(),
|
||||
reason: format!("Failed to connect to gRPC server: {}", e),
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -553,12 +579,11 @@ impl Worker for BasicWorker {
|
||||
match self.metadata.connection_mode {
|
||||
ConnectionMode::Http => Ok(()),
|
||||
ConnectionMode::Grpc { .. } => {
|
||||
// For now, we can't reset the client since it's not mutable
|
||||
// This would require changing the grpc_client field to use RwLock or OnceCell
|
||||
// which we'll do in a future iteration
|
||||
tracing::warn!(
|
||||
"gRPC client reset not yet implemented - requires mutable client storage"
|
||||
);
|
||||
let mut client_guard = self.grpc_client.write().await;
|
||||
if client_guard.is_some() {
|
||||
tracing::info!("Resetting gRPC client for worker: {}", self.metadata.url);
|
||||
*client_guard = None;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -100,7 +100,7 @@ impl BasicWorkerBuilder {
|
||||
atomic::{AtomicBool, AtomicUsize},
|
||||
Arc,
|
||||
};
|
||||
use tokio::sync::Mutex;
|
||||
use tokio::sync::{Mutex, RwLock};
|
||||
|
||||
let metadata = WorkerMetadata {
|
||||
url: self.url.clone(),
|
||||
@@ -111,6 +111,10 @@ impl BasicWorkerBuilder {
|
||||
health_config: self.health_config,
|
||||
};
|
||||
|
||||
let grpc_client = Arc::new(RwLock::new(
|
||||
self.grpc_client.map(|client| Arc::new(Mutex::new(client))),
|
||||
));
|
||||
|
||||
BasicWorker {
|
||||
metadata,
|
||||
load_counter: Arc::new(AtomicUsize::new(0)),
|
||||
@@ -119,7 +123,7 @@ impl BasicWorkerBuilder {
|
||||
consecutive_failures: Arc::new(AtomicUsize::new(0)),
|
||||
consecutive_successes: Arc::new(AtomicUsize::new(0)),
|
||||
circuit_breaker: CircuitBreaker::with_config(self.circuit_breaker_config),
|
||||
grpc_client: self.grpc_client.map(|client| Arc::new(Mutex::new(client))),
|
||||
grpc_client,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,7 +4,6 @@ use crate::config::types::RetryConfig;
|
||||
use crate::core::{
|
||||
BasicWorkerBuilder, CircuitBreakerConfig, HealthConfig, WorkerRegistry, WorkerType,
|
||||
};
|
||||
use crate::grpc_client::SglangSchedulerClient;
|
||||
use crate::metrics::RouterMetrics;
|
||||
use crate::policies::{LoadBalancingPolicy, PolicyRegistry};
|
||||
use crate::reasoning_parser::ParserFactory;
|
||||
@@ -18,10 +17,9 @@ use axum::{
|
||||
http::{HeaderMap, StatusCode},
|
||||
response::{IntoResponse, Response},
|
||||
};
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use tracing::{info, warn};
|
||||
use tracing::info;
|
||||
|
||||
/// gRPC PD (Prefill-Decode) router implementation for SGLang
|
||||
#[allow(dead_code)] // Fields will be used once implementation is complete
|
||||
@@ -89,86 +87,55 @@ impl GrpcPDRouter {
|
||||
window_duration: Duration::from_secs(circuit_breaker_config.window_duration_secs),
|
||||
};
|
||||
|
||||
// Create gRPC clients for prefill workers
|
||||
let mut prefill_grpc_clients = HashMap::new();
|
||||
for (url, _bootstrap_port) in &prefill_urls {
|
||||
match SglangSchedulerClient::connect(url).await {
|
||||
Ok(client) => {
|
||||
prefill_grpc_clients.insert(url.clone(), client);
|
||||
info!("Connected to gRPC prefill worker at {}", url);
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Failed to connect to gRPC prefill worker at {}: {}", url, e);
|
||||
// Continue with other workers
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Create gRPC clients for decode workers
|
||||
let mut decode_grpc_clients = HashMap::new();
|
||||
for url in &decode_urls {
|
||||
match SglangSchedulerClient::connect(url).await {
|
||||
Ok(client) => {
|
||||
decode_grpc_clients.insert(url.clone(), client);
|
||||
info!("Connected to gRPC decode worker at {}", url);
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Failed to connect to gRPC decode worker at {}: {}", url, e);
|
||||
// Continue with other workers
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if prefill_grpc_clients.is_empty() && decode_grpc_clients.is_empty() {
|
||||
return Err("Failed to connect to any gRPC workers".to_string());
|
||||
}
|
||||
|
||||
// Create Prefill Worker trait objects with gRPC connection mode and register them
|
||||
for (url, bootstrap_port) in &prefill_urls {
|
||||
if let Some(client) = prefill_grpc_clients.remove(url) {
|
||||
let worker = BasicWorkerBuilder::new(url.clone())
|
||||
.worker_type(WorkerType::Prefill {
|
||||
bootstrap_port: *bootstrap_port,
|
||||
})
|
||||
.connection_mode(crate::core::ConnectionMode::Grpc {
|
||||
port: *bootstrap_port,
|
||||
})
|
||||
.circuit_breaker_config(core_cb_config.clone())
|
||||
.health_config(HealthConfig {
|
||||
timeout_secs: ctx.router_config.health_check.timeout_secs,
|
||||
check_interval_secs: ctx.router_config.health_check.check_interval_secs,
|
||||
endpoint: ctx.router_config.health_check.endpoint.clone(),
|
||||
failure_threshold: ctx.router_config.health_check.failure_threshold,
|
||||
success_threshold: ctx.router_config.health_check.success_threshold,
|
||||
})
|
||||
.grpc_client(client)
|
||||
.build();
|
||||
let worker = BasicWorkerBuilder::new(url.clone())
|
||||
.worker_type(WorkerType::Prefill {
|
||||
bootstrap_port: *bootstrap_port,
|
||||
})
|
||||
.connection_mode(crate::core::ConnectionMode::Grpc {
|
||||
port: *bootstrap_port,
|
||||
})
|
||||
.circuit_breaker_config(core_cb_config.clone())
|
||||
.health_config(HealthConfig {
|
||||
timeout_secs: ctx.router_config.health_check.timeout_secs,
|
||||
check_interval_secs: ctx.router_config.health_check.check_interval_secs,
|
||||
endpoint: ctx.router_config.health_check.endpoint.clone(),
|
||||
failure_threshold: ctx.router_config.health_check.failure_threshold,
|
||||
success_threshold: ctx.router_config.health_check.success_threshold,
|
||||
})
|
||||
// No longer passing pre-initialized client - will be created lazily
|
||||
.build();
|
||||
|
||||
// Register worker in the centralized registry
|
||||
worker_registry.register(Arc::new(worker));
|
||||
}
|
||||
worker_registry.register(Arc::new(worker));
|
||||
info!(
|
||||
"Registered gRPC prefill worker at {} (will connect on first use)",
|
||||
url
|
||||
);
|
||||
}
|
||||
|
||||
// Create Decode Worker trait objects with gRPC connection mode and register them
|
||||
for url in &decode_urls {
|
||||
if let Some(client) = decode_grpc_clients.remove(url) {
|
||||
let worker = BasicWorkerBuilder::new(url.clone())
|
||||
.worker_type(WorkerType::Decode)
|
||||
.connection_mode(crate::core::ConnectionMode::Grpc { port: None })
|
||||
.circuit_breaker_config(core_cb_config.clone())
|
||||
.health_config(HealthConfig {
|
||||
timeout_secs: ctx.router_config.health_check.timeout_secs,
|
||||
check_interval_secs: ctx.router_config.health_check.check_interval_secs,
|
||||
endpoint: ctx.router_config.health_check.endpoint.clone(),
|
||||
failure_threshold: ctx.router_config.health_check.failure_threshold,
|
||||
success_threshold: ctx.router_config.health_check.success_threshold,
|
||||
})
|
||||
.grpc_client(client)
|
||||
.build();
|
||||
let worker = BasicWorkerBuilder::new(url.clone())
|
||||
.worker_type(WorkerType::Decode)
|
||||
.connection_mode(crate::core::ConnectionMode::Grpc { port: None })
|
||||
.circuit_breaker_config(core_cb_config.clone())
|
||||
.health_config(HealthConfig {
|
||||
timeout_secs: ctx.router_config.health_check.timeout_secs,
|
||||
check_interval_secs: ctx.router_config.health_check.check_interval_secs,
|
||||
endpoint: ctx.router_config.health_check.endpoint.clone(),
|
||||
failure_threshold: ctx.router_config.health_check.failure_threshold,
|
||||
success_threshold: ctx.router_config.health_check.success_threshold,
|
||||
})
|
||||
.build();
|
||||
|
||||
// Register worker in the centralized registry
|
||||
worker_registry.register(Arc::new(worker));
|
||||
}
|
||||
worker_registry.register(Arc::new(worker));
|
||||
info!(
|
||||
"Registered gRPC decode worker at {} (will connect on first use)",
|
||||
url
|
||||
);
|
||||
}
|
||||
|
||||
if prefill_urls.is_empty() && decode_urls.is_empty() {
|
||||
return Err("No gRPC workers configured".to_string());
|
||||
}
|
||||
|
||||
// Initialize policies with workers if needed - filter for gRPC workers only
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
// gRPC Router Implementation
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
@@ -96,51 +95,35 @@ impl GrpcRouter {
|
||||
window_duration: Duration::from_secs(circuit_breaker_config.window_duration_secs),
|
||||
};
|
||||
|
||||
// Create gRPC clients for each worker
|
||||
let mut grpc_clients = HashMap::new();
|
||||
for url in &worker_urls {
|
||||
match SglangSchedulerClient::connect(url).await {
|
||||
Ok(client) => {
|
||||
grpc_clients.insert(url.clone(), client);
|
||||
info!("Connected to gRPC worker at {}", url);
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Failed to connect to gRPC worker at {}: {}", url, e);
|
||||
// Continue with other workers
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if grpc_clients.is_empty() {
|
||||
return Err("Failed to connect to any gRPC workers".to_string());
|
||||
}
|
||||
|
||||
// Get registries from context
|
||||
let worker_registry = ctx.worker_registry.clone();
|
||||
let policy_registry = ctx.policy_registry.clone();
|
||||
|
||||
// Create Worker trait objects with gRPC connection mode and register them
|
||||
// Workers will lazily initialize their gRPC clients on first use
|
||||
for url in &worker_urls {
|
||||
if let Some(client) = grpc_clients.remove(url) {
|
||||
let worker = BasicWorkerBuilder::new(url.clone())
|
||||
.worker_type(WorkerType::Regular)
|
||||
.connection_mode(crate::core::ConnectionMode::Grpc { port: None })
|
||||
.circuit_breaker_config(core_cb_config.clone())
|
||||
.health_config(HealthConfig {
|
||||
timeout_secs: ctx.router_config.health_check.timeout_secs,
|
||||
check_interval_secs: ctx.router_config.health_check.check_interval_secs,
|
||||
endpoint: ctx.router_config.health_check.endpoint.clone(),
|
||||
failure_threshold: ctx.router_config.health_check.failure_threshold,
|
||||
success_threshold: ctx.router_config.health_check.success_threshold,
|
||||
})
|
||||
.grpc_client(client)
|
||||
.build();
|
||||
let worker = BasicWorkerBuilder::new(url.clone())
|
||||
.worker_type(WorkerType::Regular)
|
||||
.connection_mode(crate::core::ConnectionMode::Grpc { port: None })
|
||||
.circuit_breaker_config(core_cb_config.clone())
|
||||
.health_config(HealthConfig {
|
||||
timeout_secs: ctx.router_config.health_check.timeout_secs,
|
||||
check_interval_secs: ctx.router_config.health_check.check_interval_secs,
|
||||
endpoint: ctx.router_config.health_check.endpoint.clone(),
|
||||
failure_threshold: ctx.router_config.health_check.failure_threshold,
|
||||
success_threshold: ctx.router_config.health_check.success_threshold,
|
||||
})
|
||||
.build();
|
||||
|
||||
// Register worker in the centralized registry
|
||||
worker_registry.register(Arc::new(worker));
|
||||
} else {
|
||||
warn!("No gRPC client for worker {}, skipping", url);
|
||||
}
|
||||
worker_registry.register(Arc::new(worker));
|
||||
info!(
|
||||
"Registered gRPC worker at {} (will connect on first use)",
|
||||
url
|
||||
);
|
||||
}
|
||||
|
||||
if worker_urls.is_empty() {
|
||||
return Err("No gRPC workers configured".to_string());
|
||||
}
|
||||
|
||||
// Get only gRPC workers from registry for policy initialization
|
||||
|
||||
Reference in New Issue
Block a user