[router] add grpc pd and regular router init (#9893)
This commit is contained in:
@@ -4,7 +4,7 @@ use super::{
|
||||
http::{pd_router::PDRouter, router::Router},
|
||||
RouterTrait,
|
||||
};
|
||||
use crate::config::{PolicyConfig, RoutingMode};
|
||||
use crate::config::{ConnectionMode, PolicyConfig, RoutingMode};
|
||||
use crate::policies::PolicyFactory;
|
||||
use crate::server::AppContext;
|
||||
use std::sync::Arc;
|
||||
@@ -20,28 +20,56 @@ impl RouterFactory {
|
||||
return Self::create_igw_router(ctx).await;
|
||||
}
|
||||
|
||||
// TODO: Add gRPC mode check here when implementing gRPC support
|
||||
|
||||
// Default to HTTP proxy mode
|
||||
match &ctx.router_config.mode {
|
||||
RoutingMode::Regular { worker_urls } => {
|
||||
Self::create_regular_router(worker_urls, &ctx.router_config.policy, ctx).await
|
||||
// Check connection mode and route to appropriate implementation
|
||||
match ctx.router_config.connection_mode {
|
||||
ConnectionMode::Grpc => {
|
||||
// Route to gRPC implementation based on routing mode
|
||||
match &ctx.router_config.mode {
|
||||
RoutingMode::Regular { worker_urls } => {
|
||||
Self::create_grpc_router(worker_urls, &ctx.router_config.policy, ctx).await
|
||||
}
|
||||
RoutingMode::PrefillDecode {
|
||||
prefill_urls,
|
||||
decode_urls,
|
||||
prefill_policy,
|
||||
decode_policy,
|
||||
} => {
|
||||
Self::create_grpc_pd_router(
|
||||
prefill_urls,
|
||||
decode_urls,
|
||||
prefill_policy.as_ref(),
|
||||
decode_policy.as_ref(),
|
||||
&ctx.router_config.policy,
|
||||
ctx,
|
||||
)
|
||||
.await
|
||||
}
|
||||
}
|
||||
}
|
||||
RoutingMode::PrefillDecode {
|
||||
prefill_urls,
|
||||
decode_urls,
|
||||
prefill_policy,
|
||||
decode_policy,
|
||||
} => {
|
||||
Self::create_pd_router(
|
||||
prefill_urls,
|
||||
decode_urls,
|
||||
prefill_policy.as_ref(),
|
||||
decode_policy.as_ref(),
|
||||
&ctx.router_config.policy,
|
||||
ctx,
|
||||
)
|
||||
.await
|
||||
ConnectionMode::Http => {
|
||||
// Route to HTTP implementation based on routing mode
|
||||
match &ctx.router_config.mode {
|
||||
RoutingMode::Regular { worker_urls } => {
|
||||
Self::create_regular_router(worker_urls, &ctx.router_config.policy, ctx)
|
||||
.await
|
||||
}
|
||||
RoutingMode::PrefillDecode {
|
||||
prefill_urls,
|
||||
decode_urls,
|
||||
prefill_policy,
|
||||
decode_policy,
|
||||
} => {
|
||||
Self::create_pd_router(
|
||||
prefill_urls,
|
||||
decode_urls,
|
||||
prefill_policy.as_ref(),
|
||||
decode_policy.as_ref(),
|
||||
&ctx.router_config.policy,
|
||||
ctx,
|
||||
)
|
||||
.await
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -109,25 +137,92 @@ impl RouterFactory {
|
||||
|
||||
/// Create a gRPC router with injected policy
|
||||
pub async fn create_grpc_router(
|
||||
_worker_urls: &[String],
|
||||
_policy_config: &PolicyConfig,
|
||||
_ctx: &Arc<AppContext>,
|
||||
worker_urls: &[String],
|
||||
policy_config: &PolicyConfig,
|
||||
ctx: &Arc<AppContext>,
|
||||
) -> Result<Box<dyn RouterTrait>, String> {
|
||||
// For now, return an error as gRPC router is not yet implemented
|
||||
Err("gRPC router is not yet implemented".to_string())
|
||||
use super::grpc::router::GrpcRouter;
|
||||
|
||||
// Create policy
|
||||
let policy = PolicyFactory::create_from_config(policy_config);
|
||||
|
||||
// Determine which tokenizer path to use
|
||||
// Priority: tokenizer_path > model_path
|
||||
let tokenizer_path = ctx
|
||||
.router_config
|
||||
.tokenizer_path
|
||||
.clone()
|
||||
.or_else(|| ctx.router_config.model_path.clone())
|
||||
.ok_or_else(|| {
|
||||
"gRPC router requires either --tokenizer-path or --model-path to be specified"
|
||||
.to_string()
|
||||
})?;
|
||||
|
||||
// Create gRPC router
|
||||
let router = GrpcRouter::new(
|
||||
worker_urls.to_vec(),
|
||||
policy,
|
||||
ctx.router_config.worker_startup_timeout_secs,
|
||||
ctx.router_config.worker_startup_check_interval_secs,
|
||||
ctx.router_config.dp_aware,
|
||||
ctx.router_config.api_key.clone(),
|
||||
ctx.router_config.effective_retry_config(),
|
||||
ctx.router_config.effective_circuit_breaker_config(),
|
||||
ctx.router_config.health_check.clone(),
|
||||
tokenizer_path,
|
||||
)
|
||||
.await?;
|
||||
|
||||
Ok(Box::new(router))
|
||||
}
|
||||
|
||||
/// Create a gRPC PD router (placeholder for now)
|
||||
/// Create a gRPC PD router with tokenizer and worker configuration
|
||||
pub async fn create_grpc_pd_router(
|
||||
_prefill_urls: &[(String, Option<u16>)],
|
||||
_decode_urls: &[String],
|
||||
_prefill_policy_config: Option<&PolicyConfig>,
|
||||
_decode_policy_config: Option<&PolicyConfig>,
|
||||
_main_policy_config: &PolicyConfig,
|
||||
_ctx: &Arc<AppContext>,
|
||||
prefill_urls: &[(String, Option<u16>)],
|
||||
decode_urls: &[String],
|
||||
prefill_policy_config: Option<&PolicyConfig>,
|
||||
decode_policy_config: Option<&PolicyConfig>,
|
||||
main_policy_config: &PolicyConfig,
|
||||
ctx: &Arc<AppContext>,
|
||||
) -> Result<Box<dyn RouterTrait>, String> {
|
||||
// For now, return an error as gRPC PD router is not yet implemented
|
||||
Err("gRPC PD router is not yet implemented".to_string())
|
||||
use super::grpc::pd_router::GrpcPDRouter;
|
||||
|
||||
// Create policies - use specific policies if provided, otherwise fall back to main policy
|
||||
let prefill_policy =
|
||||
PolicyFactory::create_from_config(prefill_policy_config.unwrap_or(main_policy_config));
|
||||
let decode_policy =
|
||||
PolicyFactory::create_from_config(decode_policy_config.unwrap_or(main_policy_config));
|
||||
|
||||
// Determine which tokenizer path to use
|
||||
// Priority: tokenizer_path > model_path
|
||||
let tokenizer_path = ctx
|
||||
.router_config
|
||||
.tokenizer_path
|
||||
.clone()
|
||||
.or_else(|| ctx.router_config.model_path.clone())
|
||||
.ok_or_else(|| {
|
||||
"gRPC PD router requires either --tokenizer-path or --model-path to be specified"
|
||||
.to_string()
|
||||
})?;
|
||||
|
||||
// Create gRPC PD router
|
||||
let router = GrpcPDRouter::new(
|
||||
prefill_urls.to_vec(),
|
||||
decode_urls.to_vec(),
|
||||
prefill_policy,
|
||||
decode_policy,
|
||||
ctx.router_config.worker_startup_timeout_secs,
|
||||
ctx.router_config.worker_startup_check_interval_secs,
|
||||
ctx.router_config.dp_aware,
|
||||
ctx.router_config.api_key.clone(),
|
||||
ctx.router_config.effective_retry_config(),
|
||||
ctx.router_config.effective_circuit_breaker_config(),
|
||||
ctx.router_config.health_check.clone(),
|
||||
tokenizer_path,
|
||||
)
|
||||
.await?;
|
||||
|
||||
Ok(Box::new(router))
|
||||
}
|
||||
|
||||
/// Create an IGW router (placeholder for future implementation)
|
||||
|
||||
@@ -1,7 +1,19 @@
|
||||
// PD (Prefill-Decode) gRPC Router Implementation
|
||||
// TODO: Implement gRPC-based PD router for disaggregated prefill-decode systems
|
||||
|
||||
use crate::config::types::{
|
||||
CircuitBreakerConfig as ConfigCircuitBreakerConfig,
|
||||
HealthCheckConfig as ConfigHealthCheckConfig, RetryConfig,
|
||||
};
|
||||
use crate::core::{
|
||||
BasicWorker, CircuitBreakerConfig, HealthChecker, HealthConfig, Worker, WorkerType,
|
||||
};
|
||||
use crate::grpc::SglangSchedulerClient;
|
||||
use crate::metrics::RouterMetrics;
|
||||
use crate::policies::LoadBalancingPolicy;
|
||||
use crate::reasoning_parser::ParserFactory;
|
||||
use crate::routers::{RouterTrait, WorkerManagement};
|
||||
use crate::tokenizer::{factory, traits::Tokenizer};
|
||||
use crate::tool_parser::ParserRegistry;
|
||||
use async_trait::async_trait;
|
||||
use axum::{
|
||||
body::Body,
|
||||
@@ -9,15 +21,222 @@ use axum::{
|
||||
http::{HeaderMap, StatusCode},
|
||||
response::{IntoResponse, Response},
|
||||
};
|
||||
use std::collections::HashMap;
|
||||
use std::sync::{Arc, RwLock};
|
||||
use std::time::Duration;
|
||||
use tracing::{info, warn};
|
||||
|
||||
/// Placeholder for gRPC PD router
|
||||
#[derive(Debug)]
|
||||
pub struct GrpcPDRouter;
|
||||
/// gRPC PD (Prefill-Decode) router implementation for SGLang
|
||||
#[allow(dead_code)] // Fields will be used once implementation is complete
|
||||
pub struct GrpcPDRouter {
|
||||
/// Prefill worker connections
|
||||
prefill_workers: Arc<RwLock<Vec<Box<dyn Worker>>>>,
|
||||
/// Decode worker connections
|
||||
decode_workers: Arc<RwLock<Vec<Box<dyn Worker>>>>,
|
||||
/// gRPC clients for prefill workers
|
||||
prefill_grpc_clients: Arc<RwLock<HashMap<String, SglangSchedulerClient>>>,
|
||||
/// gRPC clients for decode workers
|
||||
decode_grpc_clients: Arc<RwLock<HashMap<String, SglangSchedulerClient>>>,
|
||||
/// Load balancing policy for prefill
|
||||
prefill_policy: Arc<dyn LoadBalancingPolicy>,
|
||||
/// Load balancing policy for decode
|
||||
decode_policy: Arc<dyn LoadBalancingPolicy>,
|
||||
/// Tokenizer for handling text encoding/decoding
|
||||
tokenizer: Arc<dyn Tokenizer>,
|
||||
/// Reasoning parser factory for structured reasoning outputs
|
||||
reasoning_parser_factory: ParserFactory,
|
||||
/// Tool parser registry for function/tool calls
|
||||
tool_parser_registry: &'static ParserRegistry,
|
||||
/// Worker health checkers
|
||||
_prefill_health_checker: Option<HealthChecker>,
|
||||
_decode_health_checker: Option<HealthChecker>,
|
||||
/// Configuration
|
||||
timeout_secs: u64,
|
||||
interval_secs: u64,
|
||||
dp_aware: bool,
|
||||
api_key: Option<String>,
|
||||
retry_config: RetryConfig,
|
||||
circuit_breaker_config: CircuitBreakerConfig,
|
||||
}
|
||||
|
||||
impl GrpcPDRouter {
|
||||
pub async fn new() -> Result<Self, String> {
|
||||
// TODO: Implement gRPC PD router initialization
|
||||
Err("gRPC PD router not yet implemented".to_string())
|
||||
/// Create a new gRPC PD router
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub async fn new(
|
||||
prefill_urls: Vec<(String, Option<u16>)>,
|
||||
decode_urls: Vec<String>,
|
||||
prefill_policy: Arc<dyn LoadBalancingPolicy>,
|
||||
decode_policy: Arc<dyn LoadBalancingPolicy>,
|
||||
timeout_secs: u64,
|
||||
interval_secs: u64,
|
||||
dp_aware: bool,
|
||||
api_key: Option<String>,
|
||||
retry_config: RetryConfig,
|
||||
circuit_breaker_config: ConfigCircuitBreakerConfig,
|
||||
health_check_config: ConfigHealthCheckConfig,
|
||||
tokenizer_path_or_model: String,
|
||||
) -> Result<Self, String> {
|
||||
// Update metrics
|
||||
RouterMetrics::set_active_workers(prefill_urls.len() + decode_urls.len());
|
||||
|
||||
// Initialize tokenizer
|
||||
let tokenizer = factory::create_tokenizer(&tokenizer_path_or_model)
|
||||
.map_err(|e| format!("Failed to create tokenizer: {}", e))?;
|
||||
|
||||
// Initialize reasoning parser factory
|
||||
let reasoning_parser_factory = ParserFactory::new();
|
||||
|
||||
// Get tool parser registry
|
||||
let tool_parser_registry = ParserRegistry::new();
|
||||
|
||||
// Convert config CircuitBreakerConfig to core CircuitBreakerConfig
|
||||
let core_cb_config = CircuitBreakerConfig {
|
||||
failure_threshold: circuit_breaker_config.failure_threshold,
|
||||
success_threshold: circuit_breaker_config.success_threshold,
|
||||
timeout_duration: Duration::from_secs(circuit_breaker_config.timeout_duration_secs),
|
||||
window_duration: Duration::from_secs(circuit_breaker_config.window_duration_secs),
|
||||
};
|
||||
|
||||
// Create gRPC clients for prefill workers
|
||||
let mut prefill_grpc_clients = HashMap::new();
|
||||
for (url, _bootstrap_port) in &prefill_urls {
|
||||
match SglangSchedulerClient::connect(url).await {
|
||||
Ok(client) => {
|
||||
prefill_grpc_clients.insert(url.clone(), client);
|
||||
info!("Connected to gRPC prefill worker at {}", url);
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Failed to connect to gRPC prefill worker at {}: {}", url, e);
|
||||
// Continue with other workers
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Create gRPC clients for decode workers
|
||||
let mut decode_grpc_clients = HashMap::new();
|
||||
for url in &decode_urls {
|
||||
match SglangSchedulerClient::connect(url).await {
|
||||
Ok(client) => {
|
||||
decode_grpc_clients.insert(url.clone(), client);
|
||||
info!("Connected to gRPC decode worker at {}", url);
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Failed to connect to gRPC decode worker at {}: {}", url, e);
|
||||
// Continue with other workers
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if prefill_grpc_clients.is_empty() && decode_grpc_clients.is_empty() {
|
||||
return Err("Failed to connect to any gRPC workers".to_string());
|
||||
}
|
||||
|
||||
// Create Prefill Worker trait objects with gRPC connection mode
|
||||
let prefill_workers: Vec<Box<dyn Worker>> = prefill_urls
|
||||
.iter()
|
||||
.map(|(url, bootstrap_port)| {
|
||||
let worker = BasicWorker::with_connection_mode(
|
||||
url.clone(),
|
||||
WorkerType::Prefill {
|
||||
bootstrap_port: *bootstrap_port,
|
||||
},
|
||||
crate::core::ConnectionMode::Grpc {
|
||||
port: *bootstrap_port,
|
||||
},
|
||||
)
|
||||
.with_circuit_breaker_config(core_cb_config.clone())
|
||||
.with_health_config(HealthConfig {
|
||||
timeout_secs: health_check_config.timeout_secs,
|
||||
check_interval_secs: health_check_config.check_interval_secs,
|
||||
endpoint: health_check_config.endpoint.clone(),
|
||||
failure_threshold: health_check_config.failure_threshold,
|
||||
success_threshold: health_check_config.success_threshold,
|
||||
});
|
||||
Box::new(worker) as Box<dyn Worker>
|
||||
})
|
||||
.collect();
|
||||
|
||||
// Create Decode Worker trait objects with gRPC connection mode
|
||||
let decode_workers: Vec<Box<dyn Worker>> = decode_urls
|
||||
.iter()
|
||||
.map(|url| {
|
||||
let worker = BasicWorker::with_connection_mode(
|
||||
url.clone(),
|
||||
WorkerType::Decode,
|
||||
crate::core::ConnectionMode::Grpc { port: None },
|
||||
)
|
||||
.with_circuit_breaker_config(core_cb_config.clone())
|
||||
.with_health_config(HealthConfig {
|
||||
timeout_secs: health_check_config.timeout_secs,
|
||||
check_interval_secs: health_check_config.check_interval_secs,
|
||||
endpoint: health_check_config.endpoint.clone(),
|
||||
failure_threshold: health_check_config.failure_threshold,
|
||||
success_threshold: health_check_config.success_threshold,
|
||||
});
|
||||
Box::new(worker) as Box<dyn Worker>
|
||||
})
|
||||
.collect();
|
||||
|
||||
// Initialize policies with workers if needed
|
||||
if let Some(cache_aware) = prefill_policy
|
||||
.as_any()
|
||||
.downcast_ref::<crate::policies::CacheAwarePolicy>()
|
||||
{
|
||||
cache_aware.init_workers(&prefill_workers);
|
||||
}
|
||||
|
||||
if let Some(cache_aware) = decode_policy
|
||||
.as_any()
|
||||
.downcast_ref::<crate::policies::CacheAwarePolicy>()
|
||||
{
|
||||
cache_aware.init_workers(&decode_workers);
|
||||
}
|
||||
|
||||
let prefill_workers = Arc::new(RwLock::new(prefill_workers));
|
||||
let decode_workers = Arc::new(RwLock::new(decode_workers));
|
||||
|
||||
let prefill_health_checker =
|
||||
crate::core::start_health_checker(Arc::clone(&prefill_workers), interval_secs);
|
||||
let decode_health_checker =
|
||||
crate::core::start_health_checker(Arc::clone(&decode_workers), interval_secs);
|
||||
|
||||
Ok(GrpcPDRouter {
|
||||
prefill_workers,
|
||||
decode_workers,
|
||||
prefill_grpc_clients: Arc::new(RwLock::new(prefill_grpc_clients)),
|
||||
decode_grpc_clients: Arc::new(RwLock::new(decode_grpc_clients)),
|
||||
prefill_policy,
|
||||
decode_policy,
|
||||
tokenizer,
|
||||
reasoning_parser_factory,
|
||||
tool_parser_registry,
|
||||
_prefill_health_checker: Some(prefill_health_checker),
|
||||
_decode_health_checker: Some(decode_health_checker),
|
||||
timeout_secs,
|
||||
interval_secs,
|
||||
dp_aware,
|
||||
api_key,
|
||||
retry_config,
|
||||
circuit_breaker_config: core_cb_config,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for GrpcPDRouter {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("GrpcPDRouter")
|
||||
.field(
|
||||
"prefill_workers_count",
|
||||
&self.prefill_workers.read().unwrap().len(),
|
||||
)
|
||||
.field(
|
||||
"decode_workers_count",
|
||||
&self.decode_workers.read().unwrap().len(),
|
||||
)
|
||||
.field("timeout_secs", &self.timeout_secs)
|
||||
.field("interval_secs", &self.interval_secs)
|
||||
.field("dp_aware", &self.dp_aware)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,7 +1,19 @@
|
||||
// gRPC Router Implementation
|
||||
// TODO: Implement gRPC-based router
|
||||
|
||||
use crate::config::types::{
|
||||
CircuitBreakerConfig as ConfigCircuitBreakerConfig,
|
||||
HealthCheckConfig as ConfigHealthCheckConfig, RetryConfig,
|
||||
};
|
||||
use crate::core::{
|
||||
BasicWorker, CircuitBreakerConfig, HealthChecker, HealthConfig, Worker, WorkerType,
|
||||
};
|
||||
use crate::grpc::SglangSchedulerClient;
|
||||
use crate::metrics::RouterMetrics;
|
||||
use crate::policies::LoadBalancingPolicy;
|
||||
use crate::reasoning_parser::ParserFactory;
|
||||
use crate::routers::{RouterTrait, WorkerManagement};
|
||||
use crate::tokenizer::{factory, traits::Tokenizer};
|
||||
use crate::tool_parser::ParserRegistry;
|
||||
use async_trait::async_trait;
|
||||
use axum::{
|
||||
body::Body,
|
||||
@@ -9,15 +21,150 @@ use axum::{
|
||||
http::{HeaderMap, StatusCode},
|
||||
response::{IntoResponse, Response},
|
||||
};
|
||||
use std::collections::HashMap;
|
||||
use std::sync::{Arc, RwLock};
|
||||
use std::time::Duration;
|
||||
use tracing::{info, warn};
|
||||
|
||||
/// Placeholder for gRPC router
|
||||
#[derive(Debug)]
|
||||
pub struct GrpcRouter;
|
||||
/// gRPC router implementation for SGLang
|
||||
#[allow(dead_code)] // Fields will be used once implementation is complete
|
||||
pub struct GrpcRouter {
|
||||
/// Worker connections
|
||||
workers: Arc<RwLock<Vec<Box<dyn Worker>>>>,
|
||||
/// gRPC clients for each worker
|
||||
grpc_clients: Arc<RwLock<HashMap<String, SglangSchedulerClient>>>,
|
||||
/// Load balancing policy
|
||||
policy: Arc<dyn LoadBalancingPolicy>,
|
||||
/// Tokenizer for handling text encoding/decoding
|
||||
tokenizer: Arc<dyn Tokenizer>,
|
||||
/// Reasoning parser factory for structured reasoning outputs
|
||||
reasoning_parser_factory: ParserFactory,
|
||||
/// Tool parser registry for function/tool calls
|
||||
tool_parser_registry: &'static ParserRegistry,
|
||||
/// Worker health checker
|
||||
_health_checker: Option<HealthChecker>,
|
||||
/// Configuration
|
||||
timeout_secs: u64,
|
||||
interval_secs: u64,
|
||||
dp_aware: bool,
|
||||
api_key: Option<String>,
|
||||
retry_config: RetryConfig,
|
||||
circuit_breaker_config: CircuitBreakerConfig,
|
||||
}
|
||||
|
||||
impl GrpcRouter {
|
||||
pub async fn new() -> Result<Self, String> {
|
||||
// TODO: Implement gRPC router initialization
|
||||
Err("gRPC router not yet implemented".to_string())
|
||||
/// Create a new gRPC router
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub async fn new(
|
||||
worker_urls: Vec<String>,
|
||||
policy: Arc<dyn LoadBalancingPolicy>,
|
||||
timeout_secs: u64,
|
||||
interval_secs: u64,
|
||||
dp_aware: bool,
|
||||
api_key: Option<String>,
|
||||
retry_config: RetryConfig,
|
||||
circuit_breaker_config: ConfigCircuitBreakerConfig,
|
||||
health_check_config: ConfigHealthCheckConfig,
|
||||
tokenizer_path_or_model: String,
|
||||
) -> Result<Self, String> {
|
||||
// Update metrics
|
||||
RouterMetrics::set_active_workers(worker_urls.len());
|
||||
|
||||
// Initialize tokenizer
|
||||
let tokenizer = factory::create_tokenizer(&tokenizer_path_or_model)
|
||||
.map_err(|e| format!("Failed to create tokenizer: {}", e))?;
|
||||
|
||||
// Initialize reasoning parser factory
|
||||
let reasoning_parser_factory = ParserFactory::new();
|
||||
|
||||
// Get tool parser registry
|
||||
let tool_parser_registry = ParserRegistry::new();
|
||||
|
||||
// Convert config CircuitBreakerConfig to core CircuitBreakerConfig
|
||||
let core_cb_config = CircuitBreakerConfig {
|
||||
failure_threshold: circuit_breaker_config.failure_threshold,
|
||||
success_threshold: circuit_breaker_config.success_threshold,
|
||||
timeout_duration: Duration::from_secs(circuit_breaker_config.timeout_duration_secs),
|
||||
window_duration: Duration::from_secs(circuit_breaker_config.window_duration_secs),
|
||||
};
|
||||
|
||||
// Create gRPC clients for each worker
|
||||
let mut grpc_clients = HashMap::new();
|
||||
for url in &worker_urls {
|
||||
match SglangSchedulerClient::connect(url).await {
|
||||
Ok(client) => {
|
||||
grpc_clients.insert(url.clone(), client);
|
||||
info!("Connected to gRPC worker at {}", url);
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Failed to connect to gRPC worker at {}: {}", url, e);
|
||||
// Continue with other workers
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if grpc_clients.is_empty() {
|
||||
return Err("Failed to connect to any gRPC workers".to_string());
|
||||
}
|
||||
|
||||
// Create Worker trait objects with gRPC connection mode
|
||||
let workers: Vec<Box<dyn Worker>> = worker_urls
|
||||
.iter()
|
||||
.map(|url| {
|
||||
let worker = BasicWorker::with_connection_mode(
|
||||
url.clone(),
|
||||
WorkerType::Regular,
|
||||
crate::core::ConnectionMode::Grpc { port: None },
|
||||
)
|
||||
.with_circuit_breaker_config(core_cb_config.clone())
|
||||
.with_health_config(HealthConfig {
|
||||
timeout_secs: health_check_config.timeout_secs,
|
||||
check_interval_secs: health_check_config.check_interval_secs,
|
||||
endpoint: health_check_config.endpoint.clone(),
|
||||
failure_threshold: health_check_config.failure_threshold,
|
||||
success_threshold: health_check_config.success_threshold,
|
||||
});
|
||||
Box::new(worker) as Box<dyn Worker>
|
||||
})
|
||||
.collect();
|
||||
|
||||
// Initialize policy with workers if needed
|
||||
if let Some(cache_aware) = policy
|
||||
.as_any()
|
||||
.downcast_ref::<crate::policies::CacheAwarePolicy>()
|
||||
{
|
||||
cache_aware.init_workers(&workers);
|
||||
}
|
||||
|
||||
let workers = Arc::new(RwLock::new(workers));
|
||||
let health_checker = crate::core::start_health_checker(Arc::clone(&workers), interval_secs);
|
||||
|
||||
Ok(GrpcRouter {
|
||||
workers,
|
||||
grpc_clients: Arc::new(RwLock::new(grpc_clients)),
|
||||
policy,
|
||||
tokenizer,
|
||||
reasoning_parser_factory,
|
||||
tool_parser_registry,
|
||||
_health_checker: Some(health_checker),
|
||||
timeout_secs,
|
||||
interval_secs,
|
||||
dp_aware,
|
||||
api_key,
|
||||
retry_config,
|
||||
circuit_breaker_config: core_cb_config,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for GrpcRouter {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("GrpcRouter")
|
||||
.field("workers_count", &self.workers.read().unwrap().len())
|
||||
.field("timeout_secs", &self.timeout_secs)
|
||||
.field("interval_secs", &self.interval_secs)
|
||||
.field("dp_aware", &self.dp_aware)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user