[router] refactor router and worker management 2/n (#10666)
This commit is contained in:
@@ -3,7 +3,7 @@
|
||||
use super::pd_types::{api_path, PDRouterError};
|
||||
use crate::config::types::RetryConfig;
|
||||
use crate::core::{
|
||||
is_retryable_status, BasicWorkerBuilder, CircuitBreakerConfig, HealthConfig, RetryExecutor,
|
||||
is_retryable_status, BasicWorkerBuilder, CircuitBreakerConfig, ConnectionMode, RetryExecutor,
|
||||
Worker, WorkerLoadGuard, WorkerRegistry, WorkerType,
|
||||
};
|
||||
use crate::metrics::RouterMetrics;
|
||||
@@ -371,12 +371,30 @@ impl PDRouter {
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub async fn new(
|
||||
prefill_urls: Vec<(String, Option<u16>)>,
|
||||
decode_urls: Vec<String>,
|
||||
ctx: &Arc<crate::server::AppContext>,
|
||||
) -> Result<Self, String> {
|
||||
pub async fn new(ctx: &Arc<crate::server::AppContext>) -> Result<Self, String> {
|
||||
let prefill_workers = ctx.worker_registry.get_workers_filtered(
|
||||
None, // any model
|
||||
Some(WorkerType::Prefill {
|
||||
bootstrap_port: None,
|
||||
}),
|
||||
Some(ConnectionMode::Http),
|
||||
false, // include all workers
|
||||
);
|
||||
|
||||
let decode_workers = ctx.worker_registry.get_workers_filtered(
|
||||
None, // any model
|
||||
Some(WorkerType::Decode),
|
||||
Some(ConnectionMode::Http),
|
||||
false, // include all workers
|
||||
);
|
||||
|
||||
// Get all worker URLs for monitoring
|
||||
let all_urls: Vec<String> = prefill_workers
|
||||
.iter()
|
||||
.chain(decode_workers.iter())
|
||||
.map(|w| w.url().to_string())
|
||||
.collect();
|
||||
|
||||
// Convert config CircuitBreakerConfig to core CircuitBreakerConfig
|
||||
let circuit_breaker_config = ctx.router_config.effective_circuit_breaker_config();
|
||||
let core_cb_config = CircuitBreakerConfig {
|
||||
@@ -386,60 +404,6 @@ impl PDRouter {
|
||||
window_duration: Duration::from_secs(circuit_breaker_config.window_duration_secs),
|
||||
};
|
||||
|
||||
// Register prefill workers in the registry
|
||||
for (url, port) in prefill_urls {
|
||||
let worker = BasicWorkerBuilder::new(url)
|
||||
.worker_type(WorkerType::Prefill {
|
||||
bootstrap_port: port,
|
||||
})
|
||||
.circuit_breaker_config(core_cb_config.clone())
|
||||
.health_config(HealthConfig {
|
||||
timeout_secs: ctx.router_config.health_check.timeout_secs,
|
||||
check_interval_secs: ctx.router_config.health_check.check_interval_secs,
|
||||
endpoint: ctx.router_config.health_check.endpoint.clone(),
|
||||
failure_threshold: ctx.router_config.health_check.failure_threshold,
|
||||
success_threshold: ctx.router_config.health_check.success_threshold,
|
||||
})
|
||||
.build();
|
||||
ctx.worker_registry.register(Arc::new(worker));
|
||||
}
|
||||
|
||||
// Register decode workers in the registry
|
||||
for url in decode_urls {
|
||||
let worker = BasicWorkerBuilder::new(url)
|
||||
.worker_type(WorkerType::Decode)
|
||||
.circuit_breaker_config(core_cb_config.clone())
|
||||
.health_config(HealthConfig {
|
||||
timeout_secs: ctx.router_config.health_check.timeout_secs,
|
||||
check_interval_secs: ctx.router_config.health_check.check_interval_secs,
|
||||
endpoint: ctx.router_config.health_check.endpoint.clone(),
|
||||
failure_threshold: ctx.router_config.health_check.failure_threshold,
|
||||
success_threshold: ctx.router_config.health_check.success_threshold,
|
||||
})
|
||||
.build();
|
||||
ctx.worker_registry.register(Arc::new(worker));
|
||||
}
|
||||
|
||||
// Get all workers from registry for health check
|
||||
let all_workers = ctx.worker_registry.get_all();
|
||||
let all_urls: Vec<String> = all_workers
|
||||
.iter()
|
||||
.map(|worker| worker.url().to_string())
|
||||
.collect();
|
||||
if !all_urls.is_empty() {
|
||||
crate::routers::http::router::Router::wait_for_healthy_workers(
|
||||
&all_urls,
|
||||
ctx.router_config.worker_startup_timeout_secs,
|
||||
ctx.router_config.worker_startup_check_interval_secs,
|
||||
)
|
||||
.await?;
|
||||
}
|
||||
|
||||
// Initialize cache-aware policies with workers from registry
|
||||
// Note: We need to get workers by type and convert to Box<dyn Worker> for CacheAwarePolicy
|
||||
// This is a temporary workaround until CacheAwarePolicy is updated to work with Arc<dyn Worker>
|
||||
// TODO: Update CacheAwarePolicy to accept Arc<dyn Worker> instead of Box<dyn Worker>
|
||||
|
||||
// Set up background load monitoring for power-of-two selection
|
||||
let (tx, rx) = tokio::sync::watch::channel(HashMap::new());
|
||||
let worker_loads = Arc::new(rx);
|
||||
@@ -471,11 +435,8 @@ impl PDRouter {
|
||||
None
|
||||
};
|
||||
|
||||
// Note: Health checking is now handled centrally by RouterManager
|
||||
// Individual routers no longer need to manage health checkers
|
||||
|
||||
// Build a dedicated prefill client for fire-and-forget semantics
|
||||
let prefill_client = reqwest::Client::builder()
|
||||
let prefill_client = Client::builder()
|
||||
.pool_max_idle_per_host(0)
|
||||
.http1_only()
|
||||
.connect_timeout(Duration::from_millis(300))
|
||||
@@ -489,6 +450,7 @@ impl PDRouter {
|
||||
|
||||
// Spawn a coordinator with limited concurrent drain tasks
|
||||
// This prevents unbounded task spawning under extreme load
|
||||
// TODO reevaluate a simpler approach (e.g. do we really need to deal with fire and forget)
|
||||
tokio::spawn(async move {
|
||||
info!("Prefill drain coordinator started");
|
||||
|
||||
@@ -513,7 +475,7 @@ impl PDRouter {
|
||||
|
||||
// Drain the response body efficiently
|
||||
// Use streaming to avoid loading entire body into memory
|
||||
let start = std::time::Instant::now();
|
||||
let start = Instant::now();
|
||||
let mut stream = response.bytes_stream();
|
||||
let mut bytes_drained = 0;
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
use crate::config::types::RetryConfig;
|
||||
use crate::core::{
|
||||
is_retryable_status, BasicWorkerBuilder, CircuitBreakerConfig, HealthConfig, RetryExecutor,
|
||||
is_retryable_status, BasicWorkerBuilder, CircuitBreakerConfig, ConnectionMode, RetryExecutor,
|
||||
Worker, WorkerRegistry, WorkerType,
|
||||
};
|
||||
use crate::metrics::RouterMetrics;
|
||||
@@ -47,31 +47,19 @@ pub struct Router {
|
||||
|
||||
impl Router {
|
||||
/// Create a new router with injected policy and client
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub async fn new(
|
||||
worker_urls: Vec<String>,
|
||||
ctx: &Arc<crate::server::AppContext>,
|
||||
) -> Result<Self, String> {
|
||||
pub async fn new(ctx: &Arc<crate::server::AppContext>) -> Result<Self, String> {
|
||||
let workers = ctx.worker_registry.get_workers_filtered(
|
||||
None, // any model
|
||||
Some(WorkerType::Regular),
|
||||
Some(ConnectionMode::Http),
|
||||
false, // include all workers
|
||||
);
|
||||
|
||||
// Update active workers gauge
|
||||
RouterMetrics::set_active_workers(worker_urls.len());
|
||||
RouterMetrics::set_active_workers(workers.len());
|
||||
|
||||
// Wait for workers to be healthy (skip if empty - for service discovery mode)
|
||||
if !worker_urls.is_empty() {
|
||||
Self::wait_for_healthy_workers(
|
||||
&worker_urls,
|
||||
ctx.router_config.worker_startup_timeout_secs,
|
||||
ctx.router_config.worker_startup_check_interval_secs,
|
||||
)
|
||||
.await?;
|
||||
}
|
||||
|
||||
let worker_urls = if ctx.router_config.dp_aware {
|
||||
// worker address now in the format of "http://host:port@dp_rank"
|
||||
Self::get_dp_aware_workers(&worker_urls, &ctx.router_config.api_key)
|
||||
.map_err(|e| format!("Failed to get dp-aware workers: {}", e))?
|
||||
} else {
|
||||
worker_urls
|
||||
};
|
||||
// Get worker URLs for monitoring
|
||||
let worker_urls: Vec<String> = workers.iter().map(|w| w.url().to_string()).collect();
|
||||
|
||||
// Convert config CircuitBreakerConfig to core CircuitBreakerConfig
|
||||
let circuit_breaker_config = ctx.router_config.effective_circuit_breaker_config();
|
||||
@@ -82,40 +70,14 @@ impl Router {
|
||||
window_duration: Duration::from_secs(circuit_breaker_config.window_duration_secs),
|
||||
};
|
||||
|
||||
// Register workers in the registry
|
||||
// In IGW mode, we need to fetch model info from workers
|
||||
for url in &worker_urls {
|
||||
// TODO: In IGW mode, fetch model_id from worker's /get_model_info endpoint
|
||||
// For now, create worker without model_id
|
||||
let worker = BasicWorkerBuilder::new(url.clone())
|
||||
.worker_type(WorkerType::Regular)
|
||||
.circuit_breaker_config(core_cb_config.clone())
|
||||
.health_config(HealthConfig {
|
||||
timeout_secs: ctx.router_config.health_check.timeout_secs,
|
||||
check_interval_secs: ctx.router_config.health_check.check_interval_secs,
|
||||
endpoint: ctx.router_config.health_check.endpoint.clone(),
|
||||
failure_threshold: ctx.router_config.health_check.failure_threshold,
|
||||
success_threshold: ctx.router_config.health_check.success_threshold,
|
||||
})
|
||||
.build();
|
||||
|
||||
let worker_arc = Arc::new(worker);
|
||||
ctx.worker_registry.register(worker_arc.clone());
|
||||
|
||||
// Notify PolicyRegistry about the new worker
|
||||
let model_id = worker_arc.model_id();
|
||||
let policy = ctx.policy_registry.on_worker_added(model_id, None);
|
||||
|
||||
// If this is a cache-aware policy and it's the first worker for this model,
|
||||
// initialize it with the worker
|
||||
if policy.name() == "cache_aware" {
|
||||
if let Some(cache_aware) = policy
|
||||
.as_any()
|
||||
.downcast_ref::<crate::policies::CacheAwarePolicy>()
|
||||
{
|
||||
let worker_dyn: Arc<dyn Worker> = worker_arc.clone();
|
||||
cache_aware.init_workers(std::slice::from_ref(&worker_dyn));
|
||||
}
|
||||
// Initialize cache-aware policy with workers if needed
|
||||
let default_policy = ctx.policy_registry.get_default_policy();
|
||||
if default_policy.name() == "cache_aware" {
|
||||
if let Some(cache_aware) = default_policy
|
||||
.as_any()
|
||||
.downcast_ref::<crate::policies::CacheAwarePolicy>()
|
||||
{
|
||||
cache_aware.init_workers(&workers);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -124,7 +86,6 @@ impl Router {
|
||||
let worker_loads = Arc::new(rx);
|
||||
|
||||
// Check if default policy is power_of_two for load monitoring
|
||||
let default_policy = ctx.policy_registry.get_default_policy();
|
||||
let load_monitor_handle = if default_policy.name() == "power_of_two" {
|
||||
let monitor_urls = worker_urls.clone();
|
||||
let monitor_interval = ctx.router_config.worker_startup_check_interval_secs;
|
||||
|
||||
Reference in New Issue
Block a user