[router] refactor router and worker management 2/n (#10666)

This commit is contained in:
Simo Lin
2025-09-19 15:37:57 -04:00
committed by GitHub
parent dab4663b4e
commit 00eb5eb721
11 changed files with 483 additions and 157 deletions

View File

@@ -3,7 +3,7 @@
use super::pd_types::{api_path, PDRouterError};
use crate::config::types::RetryConfig;
use crate::core::{
is_retryable_status, BasicWorkerBuilder, CircuitBreakerConfig, HealthConfig, RetryExecutor,
is_retryable_status, BasicWorkerBuilder, CircuitBreakerConfig, ConnectionMode, RetryExecutor,
Worker, WorkerLoadGuard, WorkerRegistry, WorkerType,
};
use crate::metrics::RouterMetrics;
@@ -371,12 +371,30 @@ impl PDRouter {
}
}
#[allow(clippy::too_many_arguments)]
pub async fn new(
prefill_urls: Vec<(String, Option<u16>)>,
decode_urls: Vec<String>,
ctx: &Arc<crate::server::AppContext>,
) -> Result<Self, String> {
pub async fn new(ctx: &Arc<crate::server::AppContext>) -> Result<Self, String> {
let prefill_workers = ctx.worker_registry.get_workers_filtered(
None, // any model
Some(WorkerType::Prefill {
bootstrap_port: None,
}),
Some(ConnectionMode::Http),
false, // include all workers
);
let decode_workers = ctx.worker_registry.get_workers_filtered(
None, // any model
Some(WorkerType::Decode),
Some(ConnectionMode::Http),
false, // include all workers
);
// Get all worker URLs for monitoring
let all_urls: Vec<String> = prefill_workers
.iter()
.chain(decode_workers.iter())
.map(|w| w.url().to_string())
.collect();
// Convert config CircuitBreakerConfig to core CircuitBreakerConfig
let circuit_breaker_config = ctx.router_config.effective_circuit_breaker_config();
let core_cb_config = CircuitBreakerConfig {
@@ -386,60 +404,6 @@ impl PDRouter {
window_duration: Duration::from_secs(circuit_breaker_config.window_duration_secs),
};
// Register prefill workers in the registry
for (url, port) in prefill_urls {
let worker = BasicWorkerBuilder::new(url)
.worker_type(WorkerType::Prefill {
bootstrap_port: port,
})
.circuit_breaker_config(core_cb_config.clone())
.health_config(HealthConfig {
timeout_secs: ctx.router_config.health_check.timeout_secs,
check_interval_secs: ctx.router_config.health_check.check_interval_secs,
endpoint: ctx.router_config.health_check.endpoint.clone(),
failure_threshold: ctx.router_config.health_check.failure_threshold,
success_threshold: ctx.router_config.health_check.success_threshold,
})
.build();
ctx.worker_registry.register(Arc::new(worker));
}
// Register decode workers in the registry
for url in decode_urls {
let worker = BasicWorkerBuilder::new(url)
.worker_type(WorkerType::Decode)
.circuit_breaker_config(core_cb_config.clone())
.health_config(HealthConfig {
timeout_secs: ctx.router_config.health_check.timeout_secs,
check_interval_secs: ctx.router_config.health_check.check_interval_secs,
endpoint: ctx.router_config.health_check.endpoint.clone(),
failure_threshold: ctx.router_config.health_check.failure_threshold,
success_threshold: ctx.router_config.health_check.success_threshold,
})
.build();
ctx.worker_registry.register(Arc::new(worker));
}
// Get all workers from registry for health check
let all_workers = ctx.worker_registry.get_all();
let all_urls: Vec<String> = all_workers
.iter()
.map(|worker| worker.url().to_string())
.collect();
if !all_urls.is_empty() {
crate::routers::http::router::Router::wait_for_healthy_workers(
&all_urls,
ctx.router_config.worker_startup_timeout_secs,
ctx.router_config.worker_startup_check_interval_secs,
)
.await?;
}
// Initialize cache-aware policies with workers from registry
// Note: We need to get workers by type and convert to Box<dyn Worker> for CacheAwarePolicy
// This is a temporary workaround until CacheAwarePolicy is updated to work with Arc<dyn Worker>
// TODO: Update CacheAwarePolicy to accept Arc<dyn Worker> instead of Box<dyn Worker>
// Set up background load monitoring for power-of-two selection
let (tx, rx) = tokio::sync::watch::channel(HashMap::new());
let worker_loads = Arc::new(rx);
@@ -471,11 +435,8 @@ impl PDRouter {
None
};
// Note: Health checking is now handled centrally by RouterManager
// Individual routers no longer need to manage health checkers
// Build a dedicated prefill client for fire-and-forget semantics
let prefill_client = reqwest::Client::builder()
let prefill_client = Client::builder()
.pool_max_idle_per_host(0)
.http1_only()
.connect_timeout(Duration::from_millis(300))
@@ -489,6 +450,7 @@ impl PDRouter {
// Spawn a coordinator with limited concurrent drain tasks
// This prevents unbounded task spawning under extreme load
// TODO: reevaluate whether a simpler approach would work (e.g. do we really need fire-and-forget semantics here?)
tokio::spawn(async move {
info!("Prefill drain coordinator started");
@@ -513,7 +475,7 @@ impl PDRouter {
// Drain the response body efficiently
// Use streaming to avoid loading entire body into memory
let start = std::time::Instant::now();
let start = Instant::now();
let mut stream = response.bytes_stream();
let mut bytes_drained = 0;

View File

@@ -1,6 +1,6 @@
use crate::config::types::RetryConfig;
use crate::core::{
is_retryable_status, BasicWorkerBuilder, CircuitBreakerConfig, HealthConfig, RetryExecutor,
is_retryable_status, BasicWorkerBuilder, CircuitBreakerConfig, ConnectionMode, RetryExecutor,
Worker, WorkerRegistry, WorkerType,
};
use crate::metrics::RouterMetrics;
@@ -47,31 +47,19 @@ pub struct Router {
impl Router {
/// Create a new router with injected policy and client
#[allow(clippy::too_many_arguments)]
pub async fn new(
worker_urls: Vec<String>,
ctx: &Arc<crate::server::AppContext>,
) -> Result<Self, String> {
pub async fn new(ctx: &Arc<crate::server::AppContext>) -> Result<Self, String> {
let workers = ctx.worker_registry.get_workers_filtered(
None, // any model
Some(WorkerType::Regular),
Some(ConnectionMode::Http),
false, // include all workers
);
// Update active workers gauge
RouterMetrics::set_active_workers(worker_urls.len());
RouterMetrics::set_active_workers(workers.len());
// Wait for workers to be healthy (skip if empty - for service discovery mode)
if !worker_urls.is_empty() {
Self::wait_for_healthy_workers(
&worker_urls,
ctx.router_config.worker_startup_timeout_secs,
ctx.router_config.worker_startup_check_interval_secs,
)
.await?;
}
let worker_urls = if ctx.router_config.dp_aware {
// worker address now in the format of "http://host:port@dp_rank"
Self::get_dp_aware_workers(&worker_urls, &ctx.router_config.api_key)
.map_err(|e| format!("Failed to get dp-aware workers: {}", e))?
} else {
worker_urls
};
// Get worker URLs for monitoring
let worker_urls: Vec<String> = workers.iter().map(|w| w.url().to_string()).collect();
// Convert config CircuitBreakerConfig to core CircuitBreakerConfig
let circuit_breaker_config = ctx.router_config.effective_circuit_breaker_config();
@@ -82,40 +70,14 @@ impl Router {
window_duration: Duration::from_secs(circuit_breaker_config.window_duration_secs),
};
// Register workers in the registry
// In IGW mode, we need to fetch model info from workers
for url in &worker_urls {
// TODO: In IGW mode, fetch model_id from worker's /get_model_info endpoint
// For now, create worker without model_id
let worker = BasicWorkerBuilder::new(url.clone())
.worker_type(WorkerType::Regular)
.circuit_breaker_config(core_cb_config.clone())
.health_config(HealthConfig {
timeout_secs: ctx.router_config.health_check.timeout_secs,
check_interval_secs: ctx.router_config.health_check.check_interval_secs,
endpoint: ctx.router_config.health_check.endpoint.clone(),
failure_threshold: ctx.router_config.health_check.failure_threshold,
success_threshold: ctx.router_config.health_check.success_threshold,
})
.build();
let worker_arc = Arc::new(worker);
ctx.worker_registry.register(worker_arc.clone());
// Notify PolicyRegistry about the new worker
let model_id = worker_arc.model_id();
let policy = ctx.policy_registry.on_worker_added(model_id, None);
// If this is a cache-aware policy and it's the first worker for this model,
// initialize it with the worker
if policy.name() == "cache_aware" {
if let Some(cache_aware) = policy
.as_any()
.downcast_ref::<crate::policies::CacheAwarePolicy>()
{
let worker_dyn: Arc<dyn Worker> = worker_arc.clone();
cache_aware.init_workers(std::slice::from_ref(&worker_dyn));
}
// Initialize cache-aware policy with workers if needed
let default_policy = ctx.policy_registry.get_default_policy();
if default_policy.name() == "cache_aware" {
if let Some(cache_aware) = default_policy
.as_any()
.downcast_ref::<crate::policies::CacheAwarePolicy>()
{
cache_aware.init_workers(&workers);
}
}
@@ -124,7 +86,6 @@ impl Router {
let worker_loads = Arc::new(rx);
// Check if default policy is power_of_two for load monitoring
let default_policy = ctx.policy_registry.get_default_policy();
let load_monitor_handle = if default_policy.name() == "power_of_two" {
let monitor_urls = worker_urls.clone();
let monitor_interval = ctx.router_config.worker_startup_check_interval_secs;