[router] add grpc pd and regular router init (#9893)

This commit is contained in:
Chang Su
2025-09-01 20:06:15 -07:00
committed by GitHub
parent b5245064f6
commit 9a0cac1be0
14 changed files with 783 additions and 58 deletions

View File

@@ -1,7 +1,19 @@
// PD (Prefill-Decode) gRPC Router Implementation
// TODO: Implement gRPC-based PD router for disaggregated prefill-decode systems
use crate::config::types::{
CircuitBreakerConfig as ConfigCircuitBreakerConfig,
HealthCheckConfig as ConfigHealthCheckConfig, RetryConfig,
};
use crate::core::{
BasicWorker, CircuitBreakerConfig, HealthChecker, HealthConfig, Worker, WorkerType,
};
use crate::grpc::SglangSchedulerClient;
use crate::metrics::RouterMetrics;
use crate::policies::LoadBalancingPolicy;
use crate::reasoning_parser::ParserFactory;
use crate::routers::{RouterTrait, WorkerManagement};
use crate::tokenizer::{factory, traits::Tokenizer};
use crate::tool_parser::ParserRegistry;
use async_trait::async_trait;
use axum::{
body::Body,
@@ -9,15 +21,222 @@ use axum::{
http::{HeaderMap, StatusCode},
response::{IntoResponse, Response},
};
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use std::time::Duration;
use tracing::{info, warn};
/// Placeholder for gRPC PD router
#[derive(Debug)]
pub struct GrpcPDRouter;
/// gRPC PD (Prefill-Decode) router implementation for SGLang
#[allow(dead_code)] // Fields will be used once implementation is complete
pub struct GrpcPDRouter {
/// Prefill worker connections
prefill_workers: Arc<RwLock<Vec<Box<dyn Worker>>>>,
/// Decode worker connections
decode_workers: Arc<RwLock<Vec<Box<dyn Worker>>>>,
/// gRPC clients for prefill workers
prefill_grpc_clients: Arc<RwLock<HashMap<String, SglangSchedulerClient>>>,
/// gRPC clients for decode workers
decode_grpc_clients: Arc<RwLock<HashMap<String, SglangSchedulerClient>>>,
/// Load balancing policy for prefill
prefill_policy: Arc<dyn LoadBalancingPolicy>,
/// Load balancing policy for decode
decode_policy: Arc<dyn LoadBalancingPolicy>,
/// Tokenizer for handling text encoding/decoding
tokenizer: Arc<dyn Tokenizer>,
/// Reasoning parser factory for structured reasoning outputs
reasoning_parser_factory: ParserFactory,
/// Tool parser registry for function/tool calls
tool_parser_registry: &'static ParserRegistry,
/// Worker health checkers
_prefill_health_checker: Option<HealthChecker>,
_decode_health_checker: Option<HealthChecker>,
/// Configuration
timeout_secs: u64,
interval_secs: u64,
dp_aware: bool,
api_key: Option<String>,
retry_config: RetryConfig,
circuit_breaker_config: CircuitBreakerConfig,
}
impl GrpcPDRouter {
pub async fn new() -> Result<Self, String> {
// TODO: Implement gRPC PD router initialization
Err("gRPC PD router not yet implemented".to_string())
/// Create a new gRPC PD router
#[allow(clippy::too_many_arguments)]
pub async fn new(
prefill_urls: Vec<(String, Option<u16>)>,
decode_urls: Vec<String>,
prefill_policy: Arc<dyn LoadBalancingPolicy>,
decode_policy: Arc<dyn LoadBalancingPolicy>,
timeout_secs: u64,
interval_secs: u64,
dp_aware: bool,
api_key: Option<String>,
retry_config: RetryConfig,
circuit_breaker_config: ConfigCircuitBreakerConfig,
health_check_config: ConfigHealthCheckConfig,
tokenizer_path_or_model: String,
) -> Result<Self, String> {
// Update metrics
RouterMetrics::set_active_workers(prefill_urls.len() + decode_urls.len());
// Initialize tokenizer
let tokenizer = factory::create_tokenizer(&tokenizer_path_or_model)
.map_err(|e| format!("Failed to create tokenizer: {}", e))?;
// Initialize reasoning parser factory
let reasoning_parser_factory = ParserFactory::new();
// Get tool parser registry
let tool_parser_registry = ParserRegistry::new();
// Convert config CircuitBreakerConfig to core CircuitBreakerConfig
let core_cb_config = CircuitBreakerConfig {
failure_threshold: circuit_breaker_config.failure_threshold,
success_threshold: circuit_breaker_config.success_threshold,
timeout_duration: Duration::from_secs(circuit_breaker_config.timeout_duration_secs),
window_duration: Duration::from_secs(circuit_breaker_config.window_duration_secs),
};
// Create gRPC clients for prefill workers
let mut prefill_grpc_clients = HashMap::new();
for (url, _bootstrap_port) in &prefill_urls {
match SglangSchedulerClient::connect(url).await {
Ok(client) => {
prefill_grpc_clients.insert(url.clone(), client);
info!("Connected to gRPC prefill worker at {}", url);
}
Err(e) => {
warn!("Failed to connect to gRPC prefill worker at {}: {}", url, e);
// Continue with other workers
}
}
}
// Create gRPC clients for decode workers
let mut decode_grpc_clients = HashMap::new();
for url in &decode_urls {
match SglangSchedulerClient::connect(url).await {
Ok(client) => {
decode_grpc_clients.insert(url.clone(), client);
info!("Connected to gRPC decode worker at {}", url);
}
Err(e) => {
warn!("Failed to connect to gRPC decode worker at {}: {}", url, e);
// Continue with other workers
}
}
}
if prefill_grpc_clients.is_empty() && decode_grpc_clients.is_empty() {
return Err("Failed to connect to any gRPC workers".to_string());
}
// Create Prefill Worker trait objects with gRPC connection mode
let prefill_workers: Vec<Box<dyn Worker>> = prefill_urls
.iter()
.map(|(url, bootstrap_port)| {
let worker = BasicWorker::with_connection_mode(
url.clone(),
WorkerType::Prefill {
bootstrap_port: *bootstrap_port,
},
crate::core::ConnectionMode::Grpc {
port: *bootstrap_port,
},
)
.with_circuit_breaker_config(core_cb_config.clone())
.with_health_config(HealthConfig {
timeout_secs: health_check_config.timeout_secs,
check_interval_secs: health_check_config.check_interval_secs,
endpoint: health_check_config.endpoint.clone(),
failure_threshold: health_check_config.failure_threshold,
success_threshold: health_check_config.success_threshold,
});
Box::new(worker) as Box<dyn Worker>
})
.collect();
// Create Decode Worker trait objects with gRPC connection mode
let decode_workers: Vec<Box<dyn Worker>> = decode_urls
.iter()
.map(|url| {
let worker = BasicWorker::with_connection_mode(
url.clone(),
WorkerType::Decode,
crate::core::ConnectionMode::Grpc { port: None },
)
.with_circuit_breaker_config(core_cb_config.clone())
.with_health_config(HealthConfig {
timeout_secs: health_check_config.timeout_secs,
check_interval_secs: health_check_config.check_interval_secs,
endpoint: health_check_config.endpoint.clone(),
failure_threshold: health_check_config.failure_threshold,
success_threshold: health_check_config.success_threshold,
});
Box::new(worker) as Box<dyn Worker>
})
.collect();
// Initialize policies with workers if needed
if let Some(cache_aware) = prefill_policy
.as_any()
.downcast_ref::<crate::policies::CacheAwarePolicy>()
{
cache_aware.init_workers(&prefill_workers);
}
if let Some(cache_aware) = decode_policy
.as_any()
.downcast_ref::<crate::policies::CacheAwarePolicy>()
{
cache_aware.init_workers(&decode_workers);
}
let prefill_workers = Arc::new(RwLock::new(prefill_workers));
let decode_workers = Arc::new(RwLock::new(decode_workers));
let prefill_health_checker =
crate::core::start_health_checker(Arc::clone(&prefill_workers), interval_secs);
let decode_health_checker =
crate::core::start_health_checker(Arc::clone(&decode_workers), interval_secs);
Ok(GrpcPDRouter {
prefill_workers,
decode_workers,
prefill_grpc_clients: Arc::new(RwLock::new(prefill_grpc_clients)),
decode_grpc_clients: Arc::new(RwLock::new(decode_grpc_clients)),
prefill_policy,
decode_policy,
tokenizer,
reasoning_parser_factory,
tool_parser_registry,
_prefill_health_checker: Some(prefill_health_checker),
_decode_health_checker: Some(decode_health_checker),
timeout_secs,
interval_secs,
dp_aware,
api_key,
retry_config,
circuit_breaker_config: core_cb_config,
})
}
}
impl std::fmt::Debug for GrpcPDRouter {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("GrpcPDRouter")
.field(
"prefill_workers_count",
&self.prefill_workers.read().unwrap().len(),
)
.field(
"decode_workers_count",
&self.decode_workers.read().unwrap().len(),
)
.field("timeout_secs", &self.timeout_secs)
.field("interval_secs", &self.interval_secs)
.field("dp_aware", &self.dp_aware)
.finish()
}
}

View File

@@ -1,7 +1,19 @@
// gRPC Router Implementation
// TODO: Implement gRPC-based router
use crate::config::types::{
CircuitBreakerConfig as ConfigCircuitBreakerConfig,
HealthCheckConfig as ConfigHealthCheckConfig, RetryConfig,
};
use crate::core::{
BasicWorker, CircuitBreakerConfig, HealthChecker, HealthConfig, Worker, WorkerType,
};
use crate::grpc::SglangSchedulerClient;
use crate::metrics::RouterMetrics;
use crate::policies::LoadBalancingPolicy;
use crate::reasoning_parser::ParserFactory;
use crate::routers::{RouterTrait, WorkerManagement};
use crate::tokenizer::{factory, traits::Tokenizer};
use crate::tool_parser::ParserRegistry;
use async_trait::async_trait;
use axum::{
body::Body,
@@ -9,15 +21,150 @@ use axum::{
http::{HeaderMap, StatusCode},
response::{IntoResponse, Response},
};
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use std::time::Duration;
use tracing::{info, warn};
/// Placeholder for gRPC router
#[derive(Debug)]
pub struct GrpcRouter;
/// gRPC router implementation for SGLang
#[allow(dead_code)] // Fields will be used once implementation is complete
pub struct GrpcRouter {
/// Worker connections
workers: Arc<RwLock<Vec<Box<dyn Worker>>>>,
/// gRPC clients for each worker
grpc_clients: Arc<RwLock<HashMap<String, SglangSchedulerClient>>>,
/// Load balancing policy
policy: Arc<dyn LoadBalancingPolicy>,
/// Tokenizer for handling text encoding/decoding
tokenizer: Arc<dyn Tokenizer>,
/// Reasoning parser factory for structured reasoning outputs
reasoning_parser_factory: ParserFactory,
/// Tool parser registry for function/tool calls
tool_parser_registry: &'static ParserRegistry,
/// Worker health checker
_health_checker: Option<HealthChecker>,
/// Configuration
timeout_secs: u64,
interval_secs: u64,
dp_aware: bool,
api_key: Option<String>,
retry_config: RetryConfig,
circuit_breaker_config: CircuitBreakerConfig,
}
impl GrpcRouter {
pub async fn new() -> Result<Self, String> {
// TODO: Implement gRPC router initialization
Err("gRPC router not yet implemented".to_string())
/// Create a new gRPC router
#[allow(clippy::too_many_arguments)]
pub async fn new(
worker_urls: Vec<String>,
policy: Arc<dyn LoadBalancingPolicy>,
timeout_secs: u64,
interval_secs: u64,
dp_aware: bool,
api_key: Option<String>,
retry_config: RetryConfig,
circuit_breaker_config: ConfigCircuitBreakerConfig,
health_check_config: ConfigHealthCheckConfig,
tokenizer_path_or_model: String,
) -> Result<Self, String> {
// Update metrics
RouterMetrics::set_active_workers(worker_urls.len());
// Initialize tokenizer
let tokenizer = factory::create_tokenizer(&tokenizer_path_or_model)
.map_err(|e| format!("Failed to create tokenizer: {}", e))?;
// Initialize reasoning parser factory
let reasoning_parser_factory = ParserFactory::new();
// Get tool parser registry
let tool_parser_registry = ParserRegistry::new();
// Convert config CircuitBreakerConfig to core CircuitBreakerConfig
let core_cb_config = CircuitBreakerConfig {
failure_threshold: circuit_breaker_config.failure_threshold,
success_threshold: circuit_breaker_config.success_threshold,
timeout_duration: Duration::from_secs(circuit_breaker_config.timeout_duration_secs),
window_duration: Duration::from_secs(circuit_breaker_config.window_duration_secs),
};
// Create gRPC clients for each worker
let mut grpc_clients = HashMap::new();
for url in &worker_urls {
match SglangSchedulerClient::connect(url).await {
Ok(client) => {
grpc_clients.insert(url.clone(), client);
info!("Connected to gRPC worker at {}", url);
}
Err(e) => {
warn!("Failed to connect to gRPC worker at {}: {}", url, e);
// Continue with other workers
}
}
}
if grpc_clients.is_empty() {
return Err("Failed to connect to any gRPC workers".to_string());
}
// Create Worker trait objects with gRPC connection mode
let workers: Vec<Box<dyn Worker>> = worker_urls
.iter()
.map(|url| {
let worker = BasicWorker::with_connection_mode(
url.clone(),
WorkerType::Regular,
crate::core::ConnectionMode::Grpc { port: None },
)
.with_circuit_breaker_config(core_cb_config.clone())
.with_health_config(HealthConfig {
timeout_secs: health_check_config.timeout_secs,
check_interval_secs: health_check_config.check_interval_secs,
endpoint: health_check_config.endpoint.clone(),
failure_threshold: health_check_config.failure_threshold,
success_threshold: health_check_config.success_threshold,
});
Box::new(worker) as Box<dyn Worker>
})
.collect();
// Initialize policy with workers if needed
if let Some(cache_aware) = policy
.as_any()
.downcast_ref::<crate::policies::CacheAwarePolicy>()
{
cache_aware.init_workers(&workers);
}
let workers = Arc::new(RwLock::new(workers));
let health_checker = crate::core::start_health_checker(Arc::clone(&workers), interval_secs);
Ok(GrpcRouter {
workers,
grpc_clients: Arc::new(RwLock::new(grpc_clients)),
policy,
tokenizer,
reasoning_parser_factory,
tool_parser_registry,
_health_checker: Some(health_checker),
timeout_secs,
interval_secs,
dp_aware,
api_key,
retry_config,
circuit_breaker_config: core_cb_config,
})
}
}
impl std::fmt::Debug for GrpcRouter {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("GrpcRouter")
.field("workers_count", &self.workers.read().unwrap().len())
.field("timeout_secs", &self.timeout_secs)
.field("interval_secs", &self.interval_secs)
.field("dp_aware", &self.dp_aware)
.finish()
}
}