Grpc client (#9939)
This commit is contained in:
@@ -1,4 +1,5 @@
|
|||||||
use super::{CircuitBreaker, CircuitBreakerConfig, WorkerError, WorkerResult};
|
use super::{CircuitBreaker, CircuitBreakerConfig, WorkerError, WorkerResult};
|
||||||
|
use crate::grpc::SglangSchedulerClient;
|
||||||
use crate::metrics::RouterMetrics;
|
use crate::metrics::RouterMetrics;
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use futures;
|
use futures;
|
||||||
@@ -6,6 +7,7 @@ use serde_json;
|
|||||||
use std::fmt;
|
use std::fmt;
|
||||||
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
|
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
|
||||||
use std::sync::{Arc, LazyLock};
|
use std::sync::{Arc, LazyLock};
|
||||||
|
use tokio::sync::Mutex;
|
||||||
|
|
||||||
// Shared HTTP client for worker operations (health checks, server info, etc.)
|
// Shared HTTP client for worker operations (health checks, server info, etc.)
|
||||||
static WORKER_CLIENT: LazyLock<reqwest::Client> = LazyLock::new(|| {
|
static WORKER_CLIENT: LazyLock<reqwest::Client> = LazyLock::new(|| {
|
||||||
@@ -249,7 +251,7 @@ pub struct WorkerMetadata {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Basic worker implementation
|
/// Basic worker implementation
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Clone)]
|
||||||
pub struct BasicWorker {
|
pub struct BasicWorker {
|
||||||
metadata: WorkerMetadata,
|
metadata: WorkerMetadata,
|
||||||
load_counter: Arc<AtomicUsize>,
|
load_counter: Arc<AtomicUsize>,
|
||||||
@@ -258,6 +260,19 @@ pub struct BasicWorker {
|
|||||||
consecutive_failures: Arc<AtomicUsize>,
|
consecutive_failures: Arc<AtomicUsize>,
|
||||||
consecutive_successes: Arc<AtomicUsize>,
|
consecutive_successes: Arc<AtomicUsize>,
|
||||||
circuit_breaker: CircuitBreaker,
|
circuit_breaker: CircuitBreaker,
|
||||||
|
/// Optional gRPC client for gRPC workers
|
||||||
|
grpc_client: Option<Arc<Mutex<SglangSchedulerClient>>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl fmt::Debug for BasicWorker {
|
||||||
|
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||||
|
f.debug_struct("BasicWorker")
|
||||||
|
.field("metadata", &self.metadata)
|
||||||
|
.field("healthy", &self.healthy.load(Ordering::Relaxed))
|
||||||
|
.field("circuit_breaker", &self.circuit_breaker)
|
||||||
|
.field("has_grpc_client", &self.grpc_client.is_some())
|
||||||
|
.finish()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl BasicWorker {
|
impl BasicWorker {
|
||||||
@@ -286,6 +301,7 @@ impl BasicWorker {
|
|||||||
consecutive_failures: Arc::new(AtomicUsize::new(0)),
|
consecutive_failures: Arc::new(AtomicUsize::new(0)),
|
||||||
consecutive_successes: Arc::new(AtomicUsize::new(0)),
|
consecutive_successes: Arc::new(AtomicUsize::new(0)),
|
||||||
circuit_breaker: CircuitBreaker::new(),
|
circuit_breaker: CircuitBreaker::new(),
|
||||||
|
grpc_client: None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -304,6 +320,12 @@ impl BasicWorker {
|
|||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Set the gRPC client for gRPC workers
|
||||||
|
pub fn with_grpc_client(mut self, client: SglangSchedulerClient) -> Self {
|
||||||
|
self.grpc_client = Some(Arc::new(Mutex::new(client)));
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
pub fn normalised_url(&self) -> WorkerResult<&str> {
|
pub fn normalised_url(&self) -> WorkerResult<&str> {
|
||||||
if self.url().contains("@") {
|
if self.url().contains("@") {
|
||||||
// Need to extract the URL from "http://host:port@dp_rank"
|
// Need to extract the URL from "http://host:port@dp_rank"
|
||||||
@@ -352,15 +374,46 @@ impl Worker for BasicWorker {
|
|||||||
async fn check_health_async(&self) -> WorkerResult<()> {
|
async fn check_health_async(&self) -> WorkerResult<()> {
|
||||||
use std::time::Duration;
|
use std::time::Duration;
|
||||||
|
|
||||||
// Perform actual HTTP health check
|
let health_result = match &self.metadata.connection_mode {
|
||||||
let url = self.normalised_url()?;
|
ConnectionMode::Http => {
|
||||||
let health_url = format!("{}{}", url, self.metadata.health_config.endpoint);
|
// Perform HTTP health check
|
||||||
let timeout = Duration::from_secs(self.metadata.health_config.timeout_secs);
|
let url = self.normalised_url()?;
|
||||||
|
let health_url = format!("{}{}", url, self.metadata.health_config.endpoint);
|
||||||
|
let timeout = Duration::from_secs(self.metadata.health_config.timeout_secs);
|
||||||
|
|
||||||
// Use the shared client with a custom timeout for this request
|
// Use the shared client with a custom timeout for this request
|
||||||
let health_result = match WORKER_CLIENT.get(&health_url).timeout(timeout).send().await {
|
match WORKER_CLIENT.get(&health_url).timeout(timeout).send().await {
|
||||||
Ok(response) => response.status().is_success(),
|
Ok(response) => response.status().is_success(),
|
||||||
Err(_) => false,
|
Err(_) => false,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ConnectionMode::Grpc { .. } => {
|
||||||
|
// Perform gRPC health check
|
||||||
|
if let Some(grpc_client) = &self.grpc_client {
|
||||||
|
let mut client = grpc_client.lock().await;
|
||||||
|
match client.health_check().await {
|
||||||
|
Ok(response) => {
|
||||||
|
tracing::debug!(
|
||||||
|
"gRPC health check succeeded for {}: healthy={}",
|
||||||
|
self.metadata.url,
|
||||||
|
response.healthy
|
||||||
|
);
|
||||||
|
response.healthy
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
tracing::warn!(
|
||||||
|
"gRPC health check RPC failed for {}: {:?}",
|
||||||
|
self.metadata.url,
|
||||||
|
e
|
||||||
|
);
|
||||||
|
false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
tracing::error!("No gRPC client available for worker {}", self.metadata.url);
|
||||||
|
false
|
||||||
|
}
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
if health_result {
|
if health_result {
|
||||||
@@ -390,7 +443,7 @@ impl Worker for BasicWorker {
|
|||||||
}
|
}
|
||||||
|
|
||||||
Err(WorkerError::HealthCheckFailed {
|
Err(WorkerError::HealthCheckFailed {
|
||||||
url: url.to_string(),
|
url: self.metadata.url.clone(),
|
||||||
reason: format!("Health check failed (consecutive failures: {})", failures),
|
reason: format!("Health check failed (consecutive failures: {})", failures),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -1491,12 +1544,17 @@ mod tests {
|
|||||||
// Clone for use inside catch_unwind
|
// Clone for use inside catch_unwind
|
||||||
let worker_clone = Arc::clone(&worker);
|
let worker_clone = Arc::clone(&worker);
|
||||||
|
|
||||||
|
// Use AssertUnwindSafe wrapper for the test
|
||||||
|
// This is safe because we're only testing the load counter behavior,
|
||||||
|
// not the grpc_client which is None for HTTP workers
|
||||||
|
use std::panic::AssertUnwindSafe;
|
||||||
|
|
||||||
// This will panic, but the guard should still clean up
|
// This will panic, but the guard should still clean up
|
||||||
let result = std::panic::catch_unwind(|| {
|
let result = std::panic::catch_unwind(AssertUnwindSafe(|| {
|
||||||
let _guard = WorkerLoadGuard::new(worker_clone.as_ref());
|
let _guard = WorkerLoadGuard::new(worker_clone.as_ref());
|
||||||
assert_eq!(worker_clone.load(), 1);
|
assert_eq!(worker_clone.load(), 1);
|
||||||
panic!("Test panic");
|
panic!("Test panic");
|
||||||
});
|
}));
|
||||||
|
|
||||||
// Verify panic occurred
|
// Verify panic occurred
|
||||||
assert!(result.is_err());
|
assert!(result.is_err());
|
||||||
|
|||||||
@@ -20,7 +20,14 @@ impl SglangSchedulerClient {
|
|||||||
pub async fn connect(endpoint: &str) -> Result<Self, Box<dyn std::error::Error>> {
|
pub async fn connect(endpoint: &str) -> Result<Self, Box<dyn std::error::Error>> {
|
||||||
debug!("Connecting to SGLang scheduler at {}", endpoint);
|
debug!("Connecting to SGLang scheduler at {}", endpoint);
|
||||||
|
|
||||||
let channel = Channel::from_shared(endpoint.to_string())?
|
// Convert grpc:// to http:// for tonic
|
||||||
|
let http_endpoint = if endpoint.starts_with("grpc://") {
|
||||||
|
endpoint.replace("grpc://", "http://")
|
||||||
|
} else {
|
||||||
|
endpoint.to_string()
|
||||||
|
};
|
||||||
|
|
||||||
|
let channel = Channel::from_shared(http_endpoint)?
|
||||||
.timeout(Duration::from_secs(30))
|
.timeout(Duration::from_secs(30))
|
||||||
.connect()
|
.connect()
|
||||||
.await?;
|
.await?;
|
||||||
@@ -59,11 +66,13 @@ impl SglangSchedulerClient {
|
|||||||
pub async fn health_check(
|
pub async fn health_check(
|
||||||
&mut self,
|
&mut self,
|
||||||
) -> Result<proto::HealthCheckResponse, Box<dyn std::error::Error>> {
|
) -> Result<proto::HealthCheckResponse, Box<dyn std::error::Error>> {
|
||||||
|
debug!("Sending health check request");
|
||||||
let request = Request::new(proto::HealthCheckRequest {
|
let request = Request::new(proto::HealthCheckRequest {
|
||||||
include_detailed_metrics: false,
|
include_detailed_metrics: false,
|
||||||
});
|
});
|
||||||
|
|
||||||
let response = self.client.health_check(request).await?;
|
let response = self.client.health_check(request).await?;
|
||||||
|
debug!("Health check response received");
|
||||||
Ok(response.into_inner())
|
Ok(response.into_inner())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -108,9 +108,11 @@ impl GrpcRouter {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Create Worker trait objects with gRPC connection mode
|
// Create Worker trait objects with gRPC connection mode
|
||||||
let workers: Vec<Box<dyn Worker>> = worker_urls
|
let mut workers: Vec<Box<dyn Worker>> = Vec::new();
|
||||||
.iter()
|
|
||||||
.map(|url| {
|
// Move clients from the HashMap to the workers
|
||||||
|
for url in &worker_urls {
|
||||||
|
if let Some(client) = grpc_clients.remove(url) {
|
||||||
let worker = BasicWorker::with_connection_mode(
|
let worker = BasicWorker::with_connection_mode(
|
||||||
url.clone(),
|
url.clone(),
|
||||||
WorkerType::Regular,
|
WorkerType::Regular,
|
||||||
@@ -123,10 +125,14 @@ impl GrpcRouter {
|
|||||||
endpoint: health_check_config.endpoint.clone(),
|
endpoint: health_check_config.endpoint.clone(),
|
||||||
failure_threshold: health_check_config.failure_threshold,
|
failure_threshold: health_check_config.failure_threshold,
|
||||||
success_threshold: health_check_config.success_threshold,
|
success_threshold: health_check_config.success_threshold,
|
||||||
});
|
})
|
||||||
Box::new(worker) as Box<dyn Worker>
|
.with_grpc_client(client);
|
||||||
})
|
|
||||||
.collect();
|
workers.push(Box::new(worker) as Box<dyn Worker>);
|
||||||
|
} else {
|
||||||
|
warn!("No gRPC client for worker {}, skipping", url);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Initialize policy with workers if needed
|
// Initialize policy with workers if needed
|
||||||
if let Some(cache_aware) = policy
|
if let Some(cache_aware) = policy
|
||||||
@@ -252,6 +258,11 @@ impl WorkerManagement for GrpcRouter {
|
|||||||
fn remove_worker(&self, _worker_url: &str) {}
|
fn remove_worker(&self, _worker_url: &str) {}
|
||||||
|
|
||||||
fn get_worker_urls(&self) -> Vec<String> {
|
fn get_worker_urls(&self) -> Vec<String> {
|
||||||
vec![]
|
self.workers
|
||||||
|
.read()
|
||||||
|
.unwrap()
|
||||||
|
.iter()
|
||||||
|
.map(|w| w.url().to_string())
|
||||||
|
.collect()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user