[router] refactor router and worker management 2/n (#10666)
This commit is contained in:
@@ -47,18 +47,17 @@ impl RouterFactory {
|
|||||||
ConnectionMode::Http => {
|
ConnectionMode::Http => {
|
||||||
// Route to HTTP implementation based on routing mode
|
// Route to HTTP implementation based on routing mode
|
||||||
match &ctx.router_config.mode {
|
match &ctx.router_config.mode {
|
||||||
RoutingMode::Regular { worker_urls } => {
|
RoutingMode::Regular { .. } => {
|
||||||
Self::create_regular_router(worker_urls, ctx).await
|
// Workers already initialized in registry
|
||||||
|
Self::create_regular_router(ctx).await
|
||||||
}
|
}
|
||||||
RoutingMode::PrefillDecode {
|
RoutingMode::PrefillDecode {
|
||||||
prefill_urls,
|
|
||||||
decode_urls,
|
|
||||||
prefill_policy,
|
prefill_policy,
|
||||||
decode_policy,
|
decode_policy,
|
||||||
|
..
|
||||||
} => {
|
} => {
|
||||||
|
// Workers already initialized in registry
|
||||||
Self::create_pd_router(
|
Self::create_pd_router(
|
||||||
prefill_urls,
|
|
||||||
decode_urls,
|
|
||||||
prefill_policy.as_ref(),
|
prefill_policy.as_ref(),
|
||||||
decode_policy.as_ref(),
|
decode_policy.as_ref(),
|
||||||
&ctx.router_config.policy,
|
&ctx.router_config.policy,
|
||||||
@@ -76,19 +75,17 @@ impl RouterFactory {
|
|||||||
|
|
||||||
/// Create a regular router
|
/// Create a regular router
|
||||||
pub async fn create_regular_router(
|
pub async fn create_regular_router(
|
||||||
worker_urls: &[String],
|
|
||||||
ctx: &Arc<AppContext>,
|
ctx: &Arc<AppContext>,
|
||||||
) -> Result<Box<dyn RouterTrait>, String> {
|
) -> Result<Box<dyn RouterTrait>, String> {
|
||||||
// Create regular router with context
|
// Create regular router with context
|
||||||
let router = Router::new(worker_urls.to_vec(), ctx).await?;
|
// Workers should already be initialized in the registry
|
||||||
|
let router = Router::new(ctx).await?;
|
||||||
|
|
||||||
Ok(Box::new(router))
|
Ok(Box::new(router))
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Create a PD router with injected policy
|
/// Create a PD router with injected policy
|
||||||
pub async fn create_pd_router(
|
pub async fn create_pd_router(
|
||||||
prefill_urls: &[(String, Option<u16>)],
|
|
||||||
decode_urls: &[String],
|
|
||||||
prefill_policy_config: Option<&PolicyConfig>,
|
prefill_policy_config: Option<&PolicyConfig>,
|
||||||
decode_policy_config: Option<&PolicyConfig>,
|
decode_policy_config: Option<&PolicyConfig>,
|
||||||
main_policy_config: &PolicyConfig,
|
main_policy_config: &PolicyConfig,
|
||||||
@@ -105,7 +102,8 @@ impl RouterFactory {
|
|||||||
ctx.policy_registry.set_decode_policy(decode_policy);
|
ctx.policy_registry.set_decode_policy(decode_policy);
|
||||||
|
|
||||||
// Create PD router with context (policies are in PolicyRegistry)
|
// Create PD router with context (policies are in PolicyRegistry)
|
||||||
let router = PDRouter::new(prefill_urls.to_vec(), decode_urls.to_vec(), ctx).await?;
|
// Workers should already be initialized in the registry
|
||||||
|
let router = PDRouter::new(ctx).await?;
|
||||||
|
|
||||||
Ok(Box::new(router))
|
Ok(Box::new(router))
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,7 +3,7 @@
|
|||||||
use super::pd_types::{api_path, PDRouterError};
|
use super::pd_types::{api_path, PDRouterError};
|
||||||
use crate::config::types::RetryConfig;
|
use crate::config::types::RetryConfig;
|
||||||
use crate::core::{
|
use crate::core::{
|
||||||
is_retryable_status, BasicWorkerBuilder, CircuitBreakerConfig, HealthConfig, RetryExecutor,
|
is_retryable_status, BasicWorkerBuilder, CircuitBreakerConfig, ConnectionMode, RetryExecutor,
|
||||||
Worker, WorkerLoadGuard, WorkerRegistry, WorkerType,
|
Worker, WorkerLoadGuard, WorkerRegistry, WorkerType,
|
||||||
};
|
};
|
||||||
use crate::metrics::RouterMetrics;
|
use crate::metrics::RouterMetrics;
|
||||||
@@ -371,12 +371,30 @@ impl PDRouter {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[allow(clippy::too_many_arguments)]
|
pub async fn new(ctx: &Arc<crate::server::AppContext>) -> Result<Self, String> {
|
||||||
pub async fn new(
|
let prefill_workers = ctx.worker_registry.get_workers_filtered(
|
||||||
prefill_urls: Vec<(String, Option<u16>)>,
|
None, // any model
|
||||||
decode_urls: Vec<String>,
|
Some(WorkerType::Prefill {
|
||||||
ctx: &Arc<crate::server::AppContext>,
|
bootstrap_port: None,
|
||||||
) -> Result<Self, String> {
|
}),
|
||||||
|
Some(ConnectionMode::Http),
|
||||||
|
false, // include all workers
|
||||||
|
);
|
||||||
|
|
||||||
|
let decode_workers = ctx.worker_registry.get_workers_filtered(
|
||||||
|
None, // any model
|
||||||
|
Some(WorkerType::Decode),
|
||||||
|
Some(ConnectionMode::Http),
|
||||||
|
false, // include all workers
|
||||||
|
);
|
||||||
|
|
||||||
|
// Get all worker URLs for monitoring
|
||||||
|
let all_urls: Vec<String> = prefill_workers
|
||||||
|
.iter()
|
||||||
|
.chain(decode_workers.iter())
|
||||||
|
.map(|w| w.url().to_string())
|
||||||
|
.collect();
|
||||||
|
|
||||||
// Convert config CircuitBreakerConfig to core CircuitBreakerConfig
|
// Convert config CircuitBreakerConfig to core CircuitBreakerConfig
|
||||||
let circuit_breaker_config = ctx.router_config.effective_circuit_breaker_config();
|
let circuit_breaker_config = ctx.router_config.effective_circuit_breaker_config();
|
||||||
let core_cb_config = CircuitBreakerConfig {
|
let core_cb_config = CircuitBreakerConfig {
|
||||||
@@ -386,60 +404,6 @@ impl PDRouter {
|
|||||||
window_duration: Duration::from_secs(circuit_breaker_config.window_duration_secs),
|
window_duration: Duration::from_secs(circuit_breaker_config.window_duration_secs),
|
||||||
};
|
};
|
||||||
|
|
||||||
// Register prefill workers in the registry
|
|
||||||
for (url, port) in prefill_urls {
|
|
||||||
let worker = BasicWorkerBuilder::new(url)
|
|
||||||
.worker_type(WorkerType::Prefill {
|
|
||||||
bootstrap_port: port,
|
|
||||||
})
|
|
||||||
.circuit_breaker_config(core_cb_config.clone())
|
|
||||||
.health_config(HealthConfig {
|
|
||||||
timeout_secs: ctx.router_config.health_check.timeout_secs,
|
|
||||||
check_interval_secs: ctx.router_config.health_check.check_interval_secs,
|
|
||||||
endpoint: ctx.router_config.health_check.endpoint.clone(),
|
|
||||||
failure_threshold: ctx.router_config.health_check.failure_threshold,
|
|
||||||
success_threshold: ctx.router_config.health_check.success_threshold,
|
|
||||||
})
|
|
||||||
.build();
|
|
||||||
ctx.worker_registry.register(Arc::new(worker));
|
|
||||||
}
|
|
||||||
|
|
||||||
// Register decode workers in the registry
|
|
||||||
for url in decode_urls {
|
|
||||||
let worker = BasicWorkerBuilder::new(url)
|
|
||||||
.worker_type(WorkerType::Decode)
|
|
||||||
.circuit_breaker_config(core_cb_config.clone())
|
|
||||||
.health_config(HealthConfig {
|
|
||||||
timeout_secs: ctx.router_config.health_check.timeout_secs,
|
|
||||||
check_interval_secs: ctx.router_config.health_check.check_interval_secs,
|
|
||||||
endpoint: ctx.router_config.health_check.endpoint.clone(),
|
|
||||||
failure_threshold: ctx.router_config.health_check.failure_threshold,
|
|
||||||
success_threshold: ctx.router_config.health_check.success_threshold,
|
|
||||||
})
|
|
||||||
.build();
|
|
||||||
ctx.worker_registry.register(Arc::new(worker));
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get all workers from registry for health check
|
|
||||||
let all_workers = ctx.worker_registry.get_all();
|
|
||||||
let all_urls: Vec<String> = all_workers
|
|
||||||
.iter()
|
|
||||||
.map(|worker| worker.url().to_string())
|
|
||||||
.collect();
|
|
||||||
if !all_urls.is_empty() {
|
|
||||||
crate::routers::http::router::Router::wait_for_healthy_workers(
|
|
||||||
&all_urls,
|
|
||||||
ctx.router_config.worker_startup_timeout_secs,
|
|
||||||
ctx.router_config.worker_startup_check_interval_secs,
|
|
||||||
)
|
|
||||||
.await?;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Initialize cache-aware policies with workers from registry
|
|
||||||
// Note: We need to get workers by type and convert to Box<dyn Worker> for CacheAwarePolicy
|
|
||||||
// This is a temporary workaround until CacheAwarePolicy is updated to work with Arc<dyn Worker>
|
|
||||||
// TODO: Update CacheAwarePolicy to accept Arc<dyn Worker> instead of Box<dyn Worker>
|
|
||||||
|
|
||||||
// Set up background load monitoring for power-of-two selection
|
// Set up background load monitoring for power-of-two selection
|
||||||
let (tx, rx) = tokio::sync::watch::channel(HashMap::new());
|
let (tx, rx) = tokio::sync::watch::channel(HashMap::new());
|
||||||
let worker_loads = Arc::new(rx);
|
let worker_loads = Arc::new(rx);
|
||||||
@@ -471,11 +435,8 @@ impl PDRouter {
|
|||||||
None
|
None
|
||||||
};
|
};
|
||||||
|
|
||||||
// Note: Health checking is now handled centrally by RouterManager
|
|
||||||
// Individual routers no longer need to manage health checkers
|
|
||||||
|
|
||||||
// Build a dedicated prefill client for fire-and-forget semantics
|
// Build a dedicated prefill client for fire-and-forget semantics
|
||||||
let prefill_client = reqwest::Client::builder()
|
let prefill_client = Client::builder()
|
||||||
.pool_max_idle_per_host(0)
|
.pool_max_idle_per_host(0)
|
||||||
.http1_only()
|
.http1_only()
|
||||||
.connect_timeout(Duration::from_millis(300))
|
.connect_timeout(Duration::from_millis(300))
|
||||||
@@ -489,6 +450,7 @@ impl PDRouter {
|
|||||||
|
|
||||||
// Spawn a coordinator with limited concurrent drain tasks
|
// Spawn a coordinator with limited concurrent drain tasks
|
||||||
// This prevents unbounded task spawning under extreme load
|
// This prevents unbounded task spawning under extreme load
|
||||||
|
// TODO reevaluate a simpler approach (e.g. do we really need to deal with fire and forget)
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
info!("Prefill drain coordinator started");
|
info!("Prefill drain coordinator started");
|
||||||
|
|
||||||
@@ -513,7 +475,7 @@ impl PDRouter {
|
|||||||
|
|
||||||
// Drain the response body efficiently
|
// Drain the response body efficiently
|
||||||
// Use streaming to avoid loading entire body into memory
|
// Use streaming to avoid loading entire body into memory
|
||||||
let start = std::time::Instant::now();
|
let start = Instant::now();
|
||||||
let mut stream = response.bytes_stream();
|
let mut stream = response.bytes_stream();
|
||||||
let mut bytes_drained = 0;
|
let mut bytes_drained = 0;
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
use crate::config::types::RetryConfig;
|
use crate::config::types::RetryConfig;
|
||||||
use crate::core::{
|
use crate::core::{
|
||||||
is_retryable_status, BasicWorkerBuilder, CircuitBreakerConfig, HealthConfig, RetryExecutor,
|
is_retryable_status, BasicWorkerBuilder, CircuitBreakerConfig, ConnectionMode, RetryExecutor,
|
||||||
Worker, WorkerRegistry, WorkerType,
|
Worker, WorkerRegistry, WorkerType,
|
||||||
};
|
};
|
||||||
use crate::metrics::RouterMetrics;
|
use crate::metrics::RouterMetrics;
|
||||||
@@ -47,31 +47,19 @@ pub struct Router {
|
|||||||
|
|
||||||
impl Router {
|
impl Router {
|
||||||
/// Create a new router with injected policy and client
|
/// Create a new router with injected policy and client
|
||||||
#[allow(clippy::too_many_arguments)]
|
pub async fn new(ctx: &Arc<crate::server::AppContext>) -> Result<Self, String> {
|
||||||
pub async fn new(
|
let workers = ctx.worker_registry.get_workers_filtered(
|
||||||
worker_urls: Vec<String>,
|
None, // any model
|
||||||
ctx: &Arc<crate::server::AppContext>,
|
Some(WorkerType::Regular),
|
||||||
) -> Result<Self, String> {
|
Some(ConnectionMode::Http),
|
||||||
|
false, // include all workers
|
||||||
|
);
|
||||||
|
|
||||||
// Update active workers gauge
|
// Update active workers gauge
|
||||||
RouterMetrics::set_active_workers(worker_urls.len());
|
RouterMetrics::set_active_workers(workers.len());
|
||||||
|
|
||||||
// Wait for workers to be healthy (skip if empty - for service discovery mode)
|
// Get worker URLs for monitoring
|
||||||
if !worker_urls.is_empty() {
|
let worker_urls: Vec<String> = workers.iter().map(|w| w.url().to_string()).collect();
|
||||||
Self::wait_for_healthy_workers(
|
|
||||||
&worker_urls,
|
|
||||||
ctx.router_config.worker_startup_timeout_secs,
|
|
||||||
ctx.router_config.worker_startup_check_interval_secs,
|
|
||||||
)
|
|
||||||
.await?;
|
|
||||||
}
|
|
||||||
|
|
||||||
let worker_urls = if ctx.router_config.dp_aware {
|
|
||||||
// worker address now in the format of "http://host:port@dp_rank"
|
|
||||||
Self::get_dp_aware_workers(&worker_urls, &ctx.router_config.api_key)
|
|
||||||
.map_err(|e| format!("Failed to get dp-aware workers: {}", e))?
|
|
||||||
} else {
|
|
||||||
worker_urls
|
|
||||||
};
|
|
||||||
|
|
||||||
// Convert config CircuitBreakerConfig to core CircuitBreakerConfig
|
// Convert config CircuitBreakerConfig to core CircuitBreakerConfig
|
||||||
let circuit_breaker_config = ctx.router_config.effective_circuit_breaker_config();
|
let circuit_breaker_config = ctx.router_config.effective_circuit_breaker_config();
|
||||||
@@ -82,40 +70,14 @@ impl Router {
|
|||||||
window_duration: Duration::from_secs(circuit_breaker_config.window_duration_secs),
|
window_duration: Duration::from_secs(circuit_breaker_config.window_duration_secs),
|
||||||
};
|
};
|
||||||
|
|
||||||
// Register workers in the registry
|
// Initialize cache-aware policy with workers if needed
|
||||||
// In IGW mode, we need to fetch model info from workers
|
let default_policy = ctx.policy_registry.get_default_policy();
|
||||||
for url in &worker_urls {
|
if default_policy.name() == "cache_aware" {
|
||||||
// TODO: In IGW mode, fetch model_id from worker's /get_model_info endpoint
|
if let Some(cache_aware) = default_policy
|
||||||
// For now, create worker without model_id
|
.as_any()
|
||||||
let worker = BasicWorkerBuilder::new(url.clone())
|
.downcast_ref::<crate::policies::CacheAwarePolicy>()
|
||||||
.worker_type(WorkerType::Regular)
|
{
|
||||||
.circuit_breaker_config(core_cb_config.clone())
|
cache_aware.init_workers(&workers);
|
||||||
.health_config(HealthConfig {
|
|
||||||
timeout_secs: ctx.router_config.health_check.timeout_secs,
|
|
||||||
check_interval_secs: ctx.router_config.health_check.check_interval_secs,
|
|
||||||
endpoint: ctx.router_config.health_check.endpoint.clone(),
|
|
||||||
failure_threshold: ctx.router_config.health_check.failure_threshold,
|
|
||||||
success_threshold: ctx.router_config.health_check.success_threshold,
|
|
||||||
})
|
|
||||||
.build();
|
|
||||||
|
|
||||||
let worker_arc = Arc::new(worker);
|
|
||||||
ctx.worker_registry.register(worker_arc.clone());
|
|
||||||
|
|
||||||
// Notify PolicyRegistry about the new worker
|
|
||||||
let model_id = worker_arc.model_id();
|
|
||||||
let policy = ctx.policy_registry.on_worker_added(model_id, None);
|
|
||||||
|
|
||||||
// If this is a cache-aware policy and it's the first worker for this model,
|
|
||||||
// initialize it with the worker
|
|
||||||
if policy.name() == "cache_aware" {
|
|
||||||
if let Some(cache_aware) = policy
|
|
||||||
.as_any()
|
|
||||||
.downcast_ref::<crate::policies::CacheAwarePolicy>()
|
|
||||||
{
|
|
||||||
let worker_dyn: Arc<dyn Worker> = worker_arc.clone();
|
|
||||||
cache_aware.init_workers(std::slice::from_ref(&worker_dyn));
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -124,7 +86,6 @@ impl Router {
|
|||||||
let worker_loads = Arc::new(rx);
|
let worker_loads = Arc::new(rx);
|
||||||
|
|
||||||
// Check if default policy is power_of_two for load monitoring
|
// Check if default policy is power_of_two for load monitoring
|
||||||
let default_policy = ctx.policy_registry.get_default_policy();
|
|
||||||
let load_monitor_handle = if default_policy.name() == "power_of_two" {
|
let load_monitor_handle = if default_policy.name() == "power_of_two" {
|
||||||
let monitor_urls = worker_urls.clone();
|
let monitor_urls = worker_urls.clone();
|
||||||
let monitor_interval = ctx.router_config.worker_startup_check_interval_secs;
|
let monitor_interval = ctx.router_config.worker_startup_check_interval_secs;
|
||||||
|
|||||||
@@ -19,8 +19,10 @@ pub mod grpc;
|
|||||||
pub mod header_utils;
|
pub mod header_utils;
|
||||||
pub mod http;
|
pub mod http;
|
||||||
pub mod router_manager;
|
pub mod router_manager;
|
||||||
|
pub mod worker_initializer;
|
||||||
|
|
||||||
pub use factory::RouterFactory;
|
pub use factory::RouterFactory;
|
||||||
|
pub use worker_initializer::WorkerInitializer;
|
||||||
// Re-export HTTP routers for convenience (keeps routers::openai_router path working)
|
// Re-export HTTP routers for convenience (keeps routers::openai_router path working)
|
||||||
pub use http::{openai_router, pd_router, pd_types, router};
|
pub use http::{openai_router, pd_router, pd_types, router};
|
||||||
|
|
||||||
|
|||||||
361
sgl-router/src/routers/worker_initializer.rs
Normal file
361
sgl-router/src/routers/worker_initializer.rs
Normal file
@@ -0,0 +1,361 @@
|
|||||||
|
// Worker Initialization Module
|
||||||
|
// Separates worker lifecycle management from router construction
|
||||||
|
|
||||||
|
use crate::config::types::{ConnectionMode as ConfigConnectionMode, RouterConfig, RoutingMode};
|
||||||
|
use crate::core::{
|
||||||
|
BasicWorkerBuilder, CircuitBreakerConfig, ConnectionMode, HealthConfig, WorkerRegistry,
|
||||||
|
WorkerType,
|
||||||
|
};
|
||||||
|
use std::sync::Arc;
|
||||||
|
use std::time::Duration;
|
||||||
|
use tracing::{info, warn};
|
||||||
|
|
||||||
|
/// WorkerInitializer handles the creation and registration of workers
|
||||||
|
/// based on routing configuration, separating this concern from router constructors
|
||||||
|
pub struct WorkerInitializer;
|
||||||
|
|
||||||
|
impl WorkerInitializer {
|
||||||
|
/// Initialize workers based on configuration and register them in the WorkerRegistry
|
||||||
|
pub async fn initialize_workers(
|
||||||
|
config: &RouterConfig,
|
||||||
|
worker_registry: &Arc<WorkerRegistry>,
|
||||||
|
) -> Result<(), String> {
|
||||||
|
info!("Initializing workers for routing mode: {:?}", config.mode);
|
||||||
|
|
||||||
|
match &config.mode {
|
||||||
|
RoutingMode::Regular { worker_urls } => {
|
||||||
|
Self::create_regular_workers(
|
||||||
|
worker_urls,
|
||||||
|
&config.connection_mode,
|
||||||
|
config,
|
||||||
|
worker_registry,
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
}
|
||||||
|
RoutingMode::PrefillDecode {
|
||||||
|
prefill_urls,
|
||||||
|
decode_urls,
|
||||||
|
..
|
||||||
|
} => {
|
||||||
|
Self::create_prefill_workers(
|
||||||
|
prefill_urls,
|
||||||
|
&config.connection_mode,
|
||||||
|
config,
|
||||||
|
worker_registry,
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
Self::create_decode_workers(
|
||||||
|
decode_urls,
|
||||||
|
&config.connection_mode,
|
||||||
|
config,
|
||||||
|
worker_registry,
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
}
|
||||||
|
RoutingMode::OpenAI { .. } => {
|
||||||
|
info!("OpenAI routing mode - no local workers to initialize");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait for workers to be healthy if any were registered
|
||||||
|
if worker_registry.stats().total_workers > 0 {
|
||||||
|
Self::wait_for_healthy_workers(
|
||||||
|
worker_registry,
|
||||||
|
config.worker_startup_timeout_secs,
|
||||||
|
config.worker_startup_check_interval_secs,
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create regular workers for standard routing mode
|
||||||
|
async fn create_regular_workers(
|
||||||
|
urls: &[String],
|
||||||
|
config_connection_mode: &ConfigConnectionMode,
|
||||||
|
config: &RouterConfig,
|
||||||
|
registry: &Arc<WorkerRegistry>,
|
||||||
|
) -> Result<(), String> {
|
||||||
|
info!("Creating {} regular workers", urls.len());
|
||||||
|
|
||||||
|
// Convert config connection mode to core connection mode
|
||||||
|
let connection_mode = Self::convert_connection_mode(config_connection_mode, urls.first());
|
||||||
|
|
||||||
|
// Convert circuit breaker config
|
||||||
|
let circuit_breaker_config = config.effective_circuit_breaker_config();
|
||||||
|
let core_cb_config = CircuitBreakerConfig {
|
||||||
|
failure_threshold: circuit_breaker_config.failure_threshold,
|
||||||
|
success_threshold: circuit_breaker_config.success_threshold,
|
||||||
|
timeout_duration: Duration::from_secs(circuit_breaker_config.timeout_duration_secs),
|
||||||
|
window_duration: Duration::from_secs(circuit_breaker_config.window_duration_secs),
|
||||||
|
};
|
||||||
|
|
||||||
|
// Convert health check config
|
||||||
|
let health_config = HealthConfig {
|
||||||
|
timeout_secs: config.health_check.timeout_secs,
|
||||||
|
check_interval_secs: config.health_check.check_interval_secs,
|
||||||
|
endpoint: config.health_check.endpoint.clone(),
|
||||||
|
failure_threshold: config.health_check.failure_threshold,
|
||||||
|
success_threshold: config.health_check.success_threshold,
|
||||||
|
};
|
||||||
|
|
||||||
|
for url in urls {
|
||||||
|
// TODO: Add DP-aware support when we have dp_rank/dp_size info
|
||||||
|
let worker = BasicWorkerBuilder::new(url.clone())
|
||||||
|
.worker_type(WorkerType::Regular)
|
||||||
|
.connection_mode(connection_mode.clone())
|
||||||
|
.circuit_breaker_config(core_cb_config.clone())
|
||||||
|
.health_config(health_config.clone())
|
||||||
|
.build();
|
||||||
|
|
||||||
|
let worker_id = registry.register(Arc::new(worker));
|
||||||
|
info!("Registered regular worker {} with ID {:?}", url, worker_id);
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create prefill workers for disaggregated routing mode
|
||||||
|
async fn create_prefill_workers(
|
||||||
|
prefill_entries: &[(String, Option<u16>)],
|
||||||
|
config_connection_mode: &ConfigConnectionMode,
|
||||||
|
config: &RouterConfig,
|
||||||
|
registry: &Arc<WorkerRegistry>,
|
||||||
|
) -> Result<(), String> {
|
||||||
|
info!("Creating {} prefill workers", prefill_entries.len());
|
||||||
|
|
||||||
|
// Convert config connection mode to core connection mode
|
||||||
|
let connection_mode = Self::convert_connection_mode(
|
||||||
|
config_connection_mode,
|
||||||
|
prefill_entries.first().map(|(url, _)| url),
|
||||||
|
);
|
||||||
|
|
||||||
|
// Convert circuit breaker config
|
||||||
|
let circuit_breaker_config = config.effective_circuit_breaker_config();
|
||||||
|
let core_cb_config = CircuitBreakerConfig {
|
||||||
|
failure_threshold: circuit_breaker_config.failure_threshold,
|
||||||
|
success_threshold: circuit_breaker_config.success_threshold,
|
||||||
|
timeout_duration: Duration::from_secs(circuit_breaker_config.timeout_duration_secs),
|
||||||
|
window_duration: Duration::from_secs(circuit_breaker_config.window_duration_secs),
|
||||||
|
};
|
||||||
|
|
||||||
|
// Convert health check config
|
||||||
|
let health_config = HealthConfig {
|
||||||
|
timeout_secs: config.health_check.timeout_secs,
|
||||||
|
check_interval_secs: config.health_check.check_interval_secs,
|
||||||
|
endpoint: config.health_check.endpoint.clone(),
|
||||||
|
failure_threshold: config.health_check.failure_threshold,
|
||||||
|
success_threshold: config.health_check.success_threshold,
|
||||||
|
};
|
||||||
|
|
||||||
|
for (url, bootstrap_port) in prefill_entries {
|
||||||
|
// TODO: Add DP-aware support when we have dp_rank/dp_size info
|
||||||
|
let worker = BasicWorkerBuilder::new(url.clone())
|
||||||
|
.worker_type(WorkerType::Prefill {
|
||||||
|
bootstrap_port: *bootstrap_port,
|
||||||
|
})
|
||||||
|
.connection_mode(connection_mode.clone())
|
||||||
|
.circuit_breaker_config(core_cb_config.clone())
|
||||||
|
.health_config(health_config.clone())
|
||||||
|
.build();
|
||||||
|
|
||||||
|
let worker_id = registry.register(Arc::new(worker));
|
||||||
|
info!("Registered prefill worker {} with ID {:?}", url, worker_id);
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create decode workers for disaggregated routing mode
|
||||||
|
async fn create_decode_workers(
|
||||||
|
urls: &[String],
|
||||||
|
config_connection_mode: &ConfigConnectionMode,
|
||||||
|
config: &RouterConfig,
|
||||||
|
registry: &Arc<WorkerRegistry>,
|
||||||
|
) -> Result<(), String> {
|
||||||
|
info!("Creating {} decode workers", urls.len());
|
||||||
|
|
||||||
|
// Convert config connection mode to core connection mode
|
||||||
|
let connection_mode = Self::convert_connection_mode(config_connection_mode, urls.first());
|
||||||
|
|
||||||
|
// Convert circuit breaker config
|
||||||
|
let circuit_breaker_config = config.effective_circuit_breaker_config();
|
||||||
|
let core_cb_config = CircuitBreakerConfig {
|
||||||
|
failure_threshold: circuit_breaker_config.failure_threshold,
|
||||||
|
success_threshold: circuit_breaker_config.success_threshold,
|
||||||
|
timeout_duration: Duration::from_secs(circuit_breaker_config.timeout_duration_secs),
|
||||||
|
window_duration: Duration::from_secs(circuit_breaker_config.window_duration_secs),
|
||||||
|
};
|
||||||
|
|
||||||
|
// Convert health check config
|
||||||
|
let health_config = HealthConfig {
|
||||||
|
timeout_secs: config.health_check.timeout_secs,
|
||||||
|
check_interval_secs: config.health_check.check_interval_secs,
|
||||||
|
endpoint: config.health_check.endpoint.clone(),
|
||||||
|
failure_threshold: config.health_check.failure_threshold,
|
||||||
|
success_threshold: config.health_check.success_threshold,
|
||||||
|
};
|
||||||
|
|
||||||
|
for url in urls {
|
||||||
|
// TODO: Add DP-aware support when we have dp_rank/dp_size info
|
||||||
|
let worker = BasicWorkerBuilder::new(url.clone())
|
||||||
|
.worker_type(WorkerType::Decode)
|
||||||
|
.connection_mode(connection_mode.clone())
|
||||||
|
.circuit_breaker_config(core_cb_config.clone())
|
||||||
|
.health_config(health_config.clone())
|
||||||
|
.build();
|
||||||
|
|
||||||
|
let worker_id = registry.register(Arc::new(worker));
|
||||||
|
info!("Registered decode worker {} with ID {:?}", url, worker_id);
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Convert config connection mode to core connection mode
|
||||||
|
fn convert_connection_mode(
|
||||||
|
config_mode: &ConfigConnectionMode,
|
||||||
|
_sample_url: Option<&String>,
|
||||||
|
) -> ConnectionMode {
|
||||||
|
match config_mode {
|
||||||
|
ConfigConnectionMode::Http => ConnectionMode::Http,
|
||||||
|
ConfigConnectionMode::Grpc => ConnectionMode::Grpc { port: None },
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Wait for workers to become healthy
|
||||||
|
async fn wait_for_healthy_workers(
|
||||||
|
registry: &Arc<WorkerRegistry>,
|
||||||
|
timeout_secs: u64,
|
||||||
|
check_interval_secs: u64,
|
||||||
|
) -> Result<(), String> {
|
||||||
|
let timeout = Duration::from_secs(timeout_secs);
|
||||||
|
let check_interval = Duration::from_secs(check_interval_secs);
|
||||||
|
let start_time = std::time::Instant::now();
|
||||||
|
|
||||||
|
info!(
|
||||||
|
"Waiting for workers to become healthy (timeout: {}s)",
|
||||||
|
timeout_secs
|
||||||
|
);
|
||||||
|
|
||||||
|
loop {
|
||||||
|
let stats = registry.stats();
|
||||||
|
|
||||||
|
if stats.healthy_workers > 0 {
|
||||||
|
info!(
|
||||||
|
"Workers healthy: {}/{} workers are ready",
|
||||||
|
stats.healthy_workers, stats.total_workers
|
||||||
|
);
|
||||||
|
|
||||||
|
// If we have at least one healthy worker, we can proceed
|
||||||
|
// This allows partial degradation rather than total failure
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
|
||||||
|
if start_time.elapsed() > timeout {
|
||||||
|
let error_msg = format!(
|
||||||
|
"Timeout waiting for workers to become healthy after {}s. Total workers: {}, Healthy: {}",
|
||||||
|
timeout_secs, stats.total_workers, stats.healthy_workers
|
||||||
|
);
|
||||||
|
warn!("{}", error_msg);
|
||||||
|
|
||||||
|
// If we have workers but none are healthy, it's still a failure
|
||||||
|
if stats.total_workers > 0 {
|
||||||
|
return Err(error_msg);
|
||||||
|
} else {
|
||||||
|
// No workers at all might be OK for some configurations
|
||||||
|
warn!("No workers registered, proceeding anyway");
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
tokio::time::sleep(check_interval).await;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Initialize workers for gRPC connections specifically
|
||||||
|
/// This is used when gRPC clients are pre-connected
|
||||||
|
pub async fn initialize_grpc_workers(
|
||||||
|
worker_urls: &[String],
|
||||||
|
worker_type: WorkerType,
|
||||||
|
config: &RouterConfig,
|
||||||
|
registry: &Arc<WorkerRegistry>,
|
||||||
|
grpc_clients: &mut std::collections::HashMap<String, crate::grpc::SglangSchedulerClient>,
|
||||||
|
) -> Result<(), String> {
|
||||||
|
info!(
|
||||||
|
"Creating {} gRPC workers of type {:?}",
|
||||||
|
worker_urls.len(),
|
||||||
|
worker_type
|
||||||
|
);
|
||||||
|
|
||||||
|
// Convert circuit breaker config
|
||||||
|
let circuit_breaker_config = config.effective_circuit_breaker_config();
|
||||||
|
let core_cb_config = CircuitBreakerConfig {
|
||||||
|
failure_threshold: circuit_breaker_config.failure_threshold,
|
||||||
|
success_threshold: circuit_breaker_config.success_threshold,
|
||||||
|
timeout_duration: Duration::from_secs(circuit_breaker_config.timeout_duration_secs),
|
||||||
|
window_duration: Duration::from_secs(circuit_breaker_config.window_duration_secs),
|
||||||
|
};
|
||||||
|
|
||||||
|
// Convert health check config
|
||||||
|
let health_config = HealthConfig {
|
||||||
|
timeout_secs: config.health_check.timeout_secs,
|
||||||
|
check_interval_secs: config.health_check.check_interval_secs,
|
||||||
|
endpoint: config.health_check.endpoint.clone(),
|
||||||
|
failure_threshold: config.health_check.failure_threshold,
|
||||||
|
success_threshold: config.health_check.success_threshold,
|
||||||
|
};
|
||||||
|
|
||||||
|
for url in worker_urls {
|
||||||
|
if let Some(client) = grpc_clients.remove(url) {
|
||||||
|
let worker = BasicWorkerBuilder::new(url.clone())
|
||||||
|
.worker_type(worker_type.clone())
|
||||||
|
.connection_mode(ConnectionMode::Grpc { port: None })
|
||||||
|
.circuit_breaker_config(core_cb_config.clone())
|
||||||
|
.health_config(health_config.clone())
|
||||||
|
.grpc_client(client)
|
||||||
|
.build();
|
||||||
|
|
||||||
|
let worker_id = registry.register(Arc::new(worker));
|
||||||
|
info!("Registered gRPC worker {} with ID {:?}", url, worker_id);
|
||||||
|
} else {
|
||||||
|
warn!("No gRPC client available for worker {}, skipping", url);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_convert_connection_mode() {
|
||||||
|
// HTTP mode
|
||||||
|
assert!(matches!(
|
||||||
|
WorkerInitializer::convert_connection_mode(
|
||||||
|
&ConfigConnectionMode::Http,
|
||||||
|
Some(&"http://localhost:8080".to_string())
|
||||||
|
),
|
||||||
|
ConnectionMode::Http
|
||||||
|
));
|
||||||
|
|
||||||
|
// gRPC mode
|
||||||
|
assert!(matches!(
|
||||||
|
WorkerInitializer::convert_connection_mode(
|
||||||
|
&ConfigConnectionMode::Grpc,
|
||||||
|
Some(&"grpc://localhost:50051".to_string())
|
||||||
|
),
|
||||||
|
ConnectionMode::Grpc { .. }
|
||||||
|
));
|
||||||
|
|
||||||
|
// No URL provided
|
||||||
|
assert!(matches!(
|
||||||
|
WorkerInitializer::convert_connection_mode(&ConfigConnectionMode::Http, None),
|
||||||
|
ConnectionMode::Http
|
||||||
|
));
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -14,6 +14,7 @@ use crate::{
|
|||||||
worker_spec::{WorkerApiResponse, WorkerConfigRequest, WorkerErrorResponse},
|
worker_spec::{WorkerApiResponse, WorkerConfigRequest, WorkerErrorResponse},
|
||||||
},
|
},
|
||||||
reasoning_parser::ParserFactory,
|
reasoning_parser::ParserFactory,
|
||||||
|
routers::WorkerInitializer,
|
||||||
routers::{
|
routers::{
|
||||||
router_manager::{RouterId, RouterManager},
|
router_manager::{RouterId, RouterManager},
|
||||||
RouterFactory, RouterTrait,
|
RouterFactory, RouterTrait,
|
||||||
@@ -594,6 +595,22 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box<dyn std::error::Err
|
|||||||
|
|
||||||
let app_context = Arc::new(app_context);
|
let app_context = Arc::new(app_context);
|
||||||
|
|
||||||
|
// Initialize workers before creating routers
|
||||||
|
// This separates worker lifecycle from router lifecycle
|
||||||
|
info!(
|
||||||
|
"Initializing workers for routing mode: {:?}",
|
||||||
|
config.router_config.mode
|
||||||
|
);
|
||||||
|
WorkerInitializer::initialize_workers(&config.router_config, &app_context.worker_registry)
|
||||||
|
.await
|
||||||
|
.map_err(|e| format!("Failed to initialize workers: {}", e))?;
|
||||||
|
|
||||||
|
let worker_stats = app_context.worker_registry.stats();
|
||||||
|
info!(
|
||||||
|
"Workers initialized: {} total, {} healthy",
|
||||||
|
worker_stats.total_workers, worker_stats.healthy_workers
|
||||||
|
);
|
||||||
|
|
||||||
// Create the appropriate router based on enable_igw flag
|
// Create the appropriate router based on enable_igw flag
|
||||||
let (router, router_manager): (Arc<dyn RouterTrait>, Option<Arc<RouterManager>>) =
|
let (router, router_manager): (Arc<dyn RouterTrait>, Option<Arc<RouterManager>>) =
|
||||||
if config.router_config.enable_igw {
|
if config.router_config.enable_igw {
|
||||||
@@ -608,12 +625,7 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box<dyn std::error::Err
|
|||||||
));
|
));
|
||||||
|
|
||||||
// 1. HTTP Regular Router
|
// 1. HTTP Regular Router
|
||||||
match RouterFactory::create_regular_router(
|
match RouterFactory::create_regular_router(&app_context).await {
|
||||||
&[], // Empty worker list - workers added later
|
|
||||||
&app_context,
|
|
||||||
)
|
|
||||||
.await
|
|
||||||
{
|
|
||||||
Ok(http_regular) => {
|
Ok(http_regular) => {
|
||||||
info!("Created HTTP Regular router");
|
info!("Created HTTP Regular router");
|
||||||
router_manager.register_router(
|
router_manager.register_router(
|
||||||
@@ -628,8 +640,6 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box<dyn std::error::Err
|
|||||||
|
|
||||||
// 2. HTTP PD Router
|
// 2. HTTP PD Router
|
||||||
match RouterFactory::create_pd_router(
|
match RouterFactory::create_pd_router(
|
||||||
&[],
|
|
||||||
&[],
|
|
||||||
None,
|
None,
|
||||||
None,
|
None,
|
||||||
&config.router_config.policy,
|
&config.router_config.policy,
|
||||||
@@ -684,7 +694,7 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box<dyn std::error::Err
|
|||||||
|
|
||||||
// Start queue processor if enabled
|
// Start queue processor if enabled
|
||||||
if let Some(processor) = processor {
|
if let Some(processor) = processor {
|
||||||
tokio::spawn(processor.run());
|
spawn(processor.run());
|
||||||
info!(
|
info!(
|
||||||
"Started request queue with size: {}, timeout: {}s",
|
"Started request queue with size: {}, timeout: {}s",
|
||||||
config.router_config.queue_size, config.router_config.queue_timeout_secs
|
config.router_config.queue_size, config.router_config.queue_timeout_secs
|
||||||
|
|||||||
@@ -606,7 +606,7 @@ mod tests {
|
|||||||
response_storage: Arc::new(crate::data_connector::MemoryResponseStorage::new()),
|
response_storage: Arc::new(crate::data_connector::MemoryResponseStorage::new()),
|
||||||
});
|
});
|
||||||
|
|
||||||
let router = Router::new(vec![], &app_context).await.unwrap();
|
let router = Router::new(&app_context).await.unwrap();
|
||||||
Arc::new(router) as Arc<dyn RouterTrait>
|
Arc::new(router) as Arc<dyn RouterTrait>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -101,6 +101,14 @@ impl TestContext {
|
|||||||
// Create app context
|
// Create app context
|
||||||
let app_context = common::create_test_context(config.clone());
|
let app_context = common::create_test_context(config.clone());
|
||||||
|
|
||||||
|
// Initialize workers in the registry before creating router
|
||||||
|
if !worker_urls.is_empty() {
|
||||||
|
use sglang_router_rs::routers::WorkerInitializer;
|
||||||
|
WorkerInitializer::initialize_workers(&config, &app_context.worker_registry)
|
||||||
|
.await
|
||||||
|
.expect("Failed to initialize workers");
|
||||||
|
}
|
||||||
|
|
||||||
// Create router
|
// Create router
|
||||||
let router = RouterFactory::create_router(&app_context).await.unwrap();
|
let router = RouterFactory::create_router(&app_context).await.unwrap();
|
||||||
let router = Arc::from(router);
|
let router = Arc::from(router);
|
||||||
|
|||||||
@@ -39,9 +39,20 @@ impl TestContext {
|
|||||||
tokio::time::sleep(tokio::time::Duration::from_millis(200)).await;
|
tokio::time::sleep(tokio::time::Duration::from_millis(200)).await;
|
||||||
}
|
}
|
||||||
|
|
||||||
config.mode = RoutingMode::Regular { worker_urls };
|
config.mode = RoutingMode::Regular {
|
||||||
|
worker_urls: worker_urls.clone(),
|
||||||
|
};
|
||||||
|
|
||||||
|
let app_context = common::create_test_context(config.clone());
|
||||||
|
|
||||||
|
// Initialize workers in the registry before creating router
|
||||||
|
if !worker_urls.is_empty() {
|
||||||
|
use sglang_router_rs::routers::WorkerInitializer;
|
||||||
|
WorkerInitializer::initialize_workers(&config, &app_context.worker_registry)
|
||||||
|
.await
|
||||||
|
.expect("Failed to initialize workers");
|
||||||
|
}
|
||||||
|
|
||||||
let app_context = common::create_test_context(config);
|
|
||||||
let router = RouterFactory::create_router(&app_context).await.unwrap();
|
let router = RouterFactory::create_router(&app_context).await.unwrap();
|
||||||
let router = Arc::from(router);
|
let router = Arc::from(router);
|
||||||
|
|
||||||
|
|||||||
@@ -40,9 +40,20 @@ impl TestContext {
|
|||||||
tokio::time::sleep(tokio::time::Duration::from_millis(200)).await;
|
tokio::time::sleep(tokio::time::Duration::from_millis(200)).await;
|
||||||
}
|
}
|
||||||
|
|
||||||
config.mode = RoutingMode::Regular { worker_urls };
|
config.mode = RoutingMode::Regular {
|
||||||
|
worker_urls: worker_urls.clone(),
|
||||||
|
};
|
||||||
|
|
||||||
|
let app_context = common::create_test_context(config.clone());
|
||||||
|
|
||||||
|
// Initialize workers in the registry before creating router
|
||||||
|
if !worker_urls.is_empty() {
|
||||||
|
use sglang_router_rs::routers::WorkerInitializer;
|
||||||
|
WorkerInitializer::initialize_workers(&config, &app_context.worker_registry)
|
||||||
|
.await
|
||||||
|
.expect("Failed to initialize workers");
|
||||||
|
}
|
||||||
|
|
||||||
let app_context = common::create_test_context(config);
|
|
||||||
let router = RouterFactory::create_router(&app_context).await.unwrap();
|
let router = RouterFactory::create_router(&app_context).await.unwrap();
|
||||||
let router = Arc::from(router);
|
let router = Arc::from(router);
|
||||||
|
|
||||||
|
|||||||
@@ -207,19 +207,21 @@ mod test_pd_routing {
|
|||||||
history_backend: sglang_router_rs::config::HistoryBackend::Memory,
|
history_backend: sglang_router_rs::config::HistoryBackend::Memory,
|
||||||
};
|
};
|
||||||
|
|
||||||
// Router creation will fail due to health checks, but config should be valid
|
|
||||||
let app_context =
|
let app_context =
|
||||||
sglang_router_rs::server::AppContext::new(config, reqwest::Client::new(), 64, None)
|
sglang_router_rs::server::AppContext::new(config, reqwest::Client::new(), 64, None)
|
||||||
.expect("Failed to create AppContext");
|
.expect("Failed to create AppContext");
|
||||||
let app_context = std::sync::Arc::new(app_context);
|
let app_context = std::sync::Arc::new(app_context);
|
||||||
let result = RouterFactory::create_router(&app_context).await;
|
let result = RouterFactory::create_router(&app_context).await;
|
||||||
assert!(result.is_err());
|
|
||||||
let error_msg = result.unwrap_err();
|
|
||||||
// Error should be about health/timeout, not configuration
|
|
||||||
assert!(
|
assert!(
|
||||||
error_msg.contains("healthy") || error_msg.contains("timeout"),
|
result.is_ok(),
|
||||||
"Unexpected error: {}",
|
"Router creation should succeed with empty worker"
|
||||||
error_msg
|
);
|
||||||
|
|
||||||
|
// Verify that no workers are registered since we didn't initialize them
|
||||||
|
let stats = app_context.worker_registry.stats();
|
||||||
|
assert_eq!(
|
||||||
|
stats.total_workers, 0,
|
||||||
|
"No workers should be registered without initialization"
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user