Grpc client (#9939)
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
use super::{CircuitBreaker, CircuitBreakerConfig, WorkerError, WorkerResult};
|
||||
use crate::grpc::SglangSchedulerClient;
|
||||
use crate::metrics::RouterMetrics;
|
||||
use async_trait::async_trait;
|
||||
use futures;
|
||||
@@ -6,6 +7,7 @@ use serde_json;
|
||||
use std::fmt;
|
||||
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
|
||||
use std::sync::{Arc, LazyLock};
|
||||
use tokio::sync::Mutex;
|
||||
|
||||
// Shared HTTP client for worker operations (health checks, server info, etc.)
|
||||
static WORKER_CLIENT: LazyLock<reqwest::Client> = LazyLock::new(|| {
|
||||
@@ -249,7 +251,7 @@ pub struct WorkerMetadata {
|
||||
}
|
||||
|
||||
/// Basic worker implementation
|
||||
#[derive(Debug, Clone)]
|
||||
#[derive(Clone)]
|
||||
pub struct BasicWorker {
|
||||
metadata: WorkerMetadata,
|
||||
load_counter: Arc<AtomicUsize>,
|
||||
@@ -258,6 +260,19 @@ pub struct BasicWorker {
|
||||
consecutive_failures: Arc<AtomicUsize>,
|
||||
consecutive_successes: Arc<AtomicUsize>,
|
||||
circuit_breaker: CircuitBreaker,
|
||||
/// Optional gRPC client for gRPC workers
|
||||
grpc_client: Option<Arc<Mutex<SglangSchedulerClient>>>,
|
||||
}
|
||||
|
||||
impl fmt::Debug for BasicWorker {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
f.debug_struct("BasicWorker")
|
||||
.field("metadata", &self.metadata)
|
||||
.field("healthy", &self.healthy.load(Ordering::Relaxed))
|
||||
.field("circuit_breaker", &self.circuit_breaker)
|
||||
.field("has_grpc_client", &self.grpc_client.is_some())
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl BasicWorker {
|
||||
@@ -286,6 +301,7 @@ impl BasicWorker {
|
||||
consecutive_failures: Arc::new(AtomicUsize::new(0)),
|
||||
consecutive_successes: Arc::new(AtomicUsize::new(0)),
|
||||
circuit_breaker: CircuitBreaker::new(),
|
||||
grpc_client: None,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -304,6 +320,12 @@ impl BasicWorker {
|
||||
self
|
||||
}
|
||||
|
||||
/// Set the gRPC client for gRPC workers
|
||||
pub fn with_grpc_client(mut self, client: SglangSchedulerClient) -> Self {
|
||||
self.grpc_client = Some(Arc::new(Mutex::new(client)));
|
||||
self
|
||||
}
|
||||
|
||||
pub fn normalised_url(&self) -> WorkerResult<&str> {
|
||||
if self.url().contains("@") {
|
||||
// Need to extract the URL from "http://host:port@dp_rank"
|
||||
@@ -352,15 +374,46 @@ impl Worker for BasicWorker {
|
||||
async fn check_health_async(&self) -> WorkerResult<()> {
|
||||
use std::time::Duration;
|
||||
|
||||
// Perform actual HTTP health check
|
||||
let url = self.normalised_url()?;
|
||||
let health_url = format!("{}{}", url, self.metadata.health_config.endpoint);
|
||||
let timeout = Duration::from_secs(self.metadata.health_config.timeout_secs);
|
||||
let health_result = match &self.metadata.connection_mode {
|
||||
ConnectionMode::Http => {
|
||||
// Perform HTTP health check
|
||||
let url = self.normalised_url()?;
|
||||
let health_url = format!("{}{}", url, self.metadata.health_config.endpoint);
|
||||
let timeout = Duration::from_secs(self.metadata.health_config.timeout_secs);
|
||||
|
||||
// Use the shared client with a custom timeout for this request
|
||||
let health_result = match WORKER_CLIENT.get(&health_url).timeout(timeout).send().await {
|
||||
Ok(response) => response.status().is_success(),
|
||||
Err(_) => false,
|
||||
// Use the shared client with a custom timeout for this request
|
||||
match WORKER_CLIENT.get(&health_url).timeout(timeout).send().await {
|
||||
Ok(response) => response.status().is_success(),
|
||||
Err(_) => false,
|
||||
}
|
||||
}
|
||||
ConnectionMode::Grpc { .. } => {
|
||||
// Perform gRPC health check
|
||||
if let Some(grpc_client) = &self.grpc_client {
|
||||
let mut client = grpc_client.lock().await;
|
||||
match client.health_check().await {
|
||||
Ok(response) => {
|
||||
tracing::debug!(
|
||||
"gRPC health check succeeded for {}: healthy={}",
|
||||
self.metadata.url,
|
||||
response.healthy
|
||||
);
|
||||
response.healthy
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!(
|
||||
"gRPC health check RPC failed for {}: {:?}",
|
||||
self.metadata.url,
|
||||
e
|
||||
);
|
||||
false
|
||||
}
|
||||
}
|
||||
} else {
|
||||
tracing::error!("No gRPC client available for worker {}", self.metadata.url);
|
||||
false
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
if health_result {
|
||||
@@ -390,7 +443,7 @@ impl Worker for BasicWorker {
|
||||
}
|
||||
|
||||
Err(WorkerError::HealthCheckFailed {
|
||||
url: url.to_string(),
|
||||
url: self.metadata.url.clone(),
|
||||
reason: format!("Health check failed (consecutive failures: {})", failures),
|
||||
})
|
||||
}
|
||||
@@ -1491,12 +1544,17 @@ mod tests {
|
||||
// Clone for use inside catch_unwind
|
||||
let worker_clone = Arc::clone(&worker);
|
||||
|
||||
// Use AssertUnwindSafe wrapper for the test
|
||||
// This is safe because we're only testing the load counter behavior,
|
||||
// not the grpc_client which is None for HTTP workers
|
||||
use std::panic::AssertUnwindSafe;
|
||||
|
||||
// This will panic, but the guard should still clean up
|
||||
let result = std::panic::catch_unwind(|| {
|
||||
let result = std::panic::catch_unwind(AssertUnwindSafe(|| {
|
||||
let _guard = WorkerLoadGuard::new(worker_clone.as_ref());
|
||||
assert_eq!(worker_clone.load(), 1);
|
||||
panic!("Test panic");
|
||||
});
|
||||
}));
|
||||
|
||||
// Verify panic occurred
|
||||
assert!(result.is_err());
|
||||
|
||||
@@ -20,7 +20,14 @@ impl SglangSchedulerClient {
|
||||
pub async fn connect(endpoint: &str) -> Result<Self, Box<dyn std::error::Error>> {
|
||||
debug!("Connecting to SGLang scheduler at {}", endpoint);
|
||||
|
||||
let channel = Channel::from_shared(endpoint.to_string())?
|
||||
// Convert grpc:// to http:// for tonic
|
||||
let http_endpoint = if endpoint.starts_with("grpc://") {
|
||||
endpoint.replace("grpc://", "http://")
|
||||
} else {
|
||||
endpoint.to_string()
|
||||
};
|
||||
|
||||
let channel = Channel::from_shared(http_endpoint)?
|
||||
.timeout(Duration::from_secs(30))
|
||||
.connect()
|
||||
.await?;
|
||||
@@ -59,11 +66,13 @@ impl SglangSchedulerClient {
|
||||
pub async fn health_check(
|
||||
&mut self,
|
||||
) -> Result<proto::HealthCheckResponse, Box<dyn std::error::Error>> {
|
||||
debug!("Sending health check request");
|
||||
let request = Request::new(proto::HealthCheckRequest {
|
||||
include_detailed_metrics: false,
|
||||
});
|
||||
|
||||
let response = self.client.health_check(request).await?;
|
||||
debug!("Health check response received");
|
||||
Ok(response.into_inner())
|
||||
}
|
||||
|
||||
|
||||
@@ -108,9 +108,11 @@ impl GrpcRouter {
|
||||
}
|
||||
|
||||
// Create Worker trait objects with gRPC connection mode
|
||||
let workers: Vec<Box<dyn Worker>> = worker_urls
|
||||
.iter()
|
||||
.map(|url| {
|
||||
let mut workers: Vec<Box<dyn Worker>> = Vec::new();
|
||||
|
||||
// Move clients from the HashMap to the workers
|
||||
for url in &worker_urls {
|
||||
if let Some(client) = grpc_clients.remove(url) {
|
||||
let worker = BasicWorker::with_connection_mode(
|
||||
url.clone(),
|
||||
WorkerType::Regular,
|
||||
@@ -123,10 +125,14 @@ impl GrpcRouter {
|
||||
endpoint: health_check_config.endpoint.clone(),
|
||||
failure_threshold: health_check_config.failure_threshold,
|
||||
success_threshold: health_check_config.success_threshold,
|
||||
});
|
||||
Box::new(worker) as Box<dyn Worker>
|
||||
})
|
||||
.collect();
|
||||
})
|
||||
.with_grpc_client(client);
|
||||
|
||||
workers.push(Box::new(worker) as Box<dyn Worker>);
|
||||
} else {
|
||||
warn!("No gRPC client for worker {}, skipping", url);
|
||||
}
|
||||
}
|
||||
|
||||
// Initialize policy with workers if needed
|
||||
if let Some(cache_aware) = policy
|
||||
@@ -252,6 +258,11 @@ impl WorkerManagement for GrpcRouter {
|
||||
fn remove_worker(&self, _worker_url: &str) {}
|
||||
|
||||
fn get_worker_urls(&self) -> Vec<String> {
|
||||
vec![]
|
||||
self.workers
|
||||
.read()
|
||||
.unwrap()
|
||||
.iter()
|
||||
.map(|w| w.url().to_string())
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user