[router] refactor router and worker management 2.5/n (#10677)

This commit is contained in:
Simo Lin
2025-09-19 23:54:40 -04:00
committed by GitHub
parent 60e2a7cead
commit 1d1ce62495
8 changed files with 235 additions and 123 deletions

View File

@@ -232,18 +232,12 @@ impl PDRouter {
// Notify PolicyRegistry about the new worker
let model_id = worker_arc.model_id();
let policy = self.policy_registry.on_worker_added(model_id, None);
self.policy_registry.on_worker_added(model_id, None);
// If this is a cache-aware policy, update it with all workers for this model
if policy.name() == "cache_aware" {
if let Some(cache_aware) = policy
.as_any()
.downcast_ref::<crate::policies::CacheAwarePolicy>()
{
let model_workers = self.worker_registry.get_by_model_fast(model_id);
cache_aware.init_workers(&model_workers);
}
}
// Initialize cache-aware policy if applicable
let model_workers = self.worker_registry.get_by_model_fast(model_id);
self.policy_registry
.init_cache_aware_policy(model_id, &model_workers);
info!("Added prefill server: {}", url);
Ok(format!("Successfully added prefill server: {}", url))
@@ -272,18 +266,12 @@ impl PDRouter {
// Notify PolicyRegistry about the new worker
let model_id = worker_arc.model_id();
let policy = self.policy_registry.on_worker_added(model_id, None);
self.policy_registry.on_worker_added(model_id, None);
// If this is a cache-aware policy, update it with all workers for this model
if policy.name() == "cache_aware" {
if let Some(cache_aware) = policy
.as_any()
.downcast_ref::<crate::policies::CacheAwarePolicy>()
{
let model_workers = self.worker_registry.get_by_model_fast(model_id);
cache_aware.init_workers(&model_workers);
}
}
// Initialize cache-aware policy if applicable
let model_workers = self.worker_registry.get_by_model_fast(model_id);
self.policy_registry
.init_cache_aware_policy(model_id, &model_workers);
info!("Added decode server: {}", url);
Ok(format!("Successfully added decode server: {}", url))
@@ -307,17 +295,9 @@ impl PDRouter {
// Notify PolicyRegistry about the removed worker
self.policy_registry.on_worker_removed(&model_id);
// Get the policy for this model to update cache-aware if needed
if let Some(policy) = self.policy_registry.get_policy(&model_id) {
if policy.name() == "cache_aware" {
if let Some(cache_aware) = policy
.as_any()
.downcast_ref::<crate::policies::CacheAwarePolicy>()
{
cache_aware.remove_worker_by_url(url);
}
}
}
// Remove from cache-aware policy if applicable
self.policy_registry
.remove_worker_from_cache_aware(&model_id, url);
}
if removed.is_some() {
@@ -348,17 +328,9 @@ impl PDRouter {
// Notify PolicyRegistry about the removed worker
self.policy_registry.on_worker_removed(&model_id);
// Get the policy for this model to update cache-aware if needed
if let Some(policy) = self.policy_registry.get_policy(&model_id) {
if policy.name() == "cache_aware" {
if let Some(cache_aware) = policy
.as_any()
.downcast_ref::<crate::policies::CacheAwarePolicy>()
{
cache_aware.remove_worker_by_url(url);
}
}
}
// Remove from cache-aware policy if applicable
self.policy_registry
.remove_worker_from_cache_aware(&model_id, url);
}
if removed.is_some() {
@@ -2226,15 +2198,6 @@ mod tests {
assert_eq!(prefill_workers.len(), 1);
}
// ============= Bootstrap Injection Tests =============
// Note: These tests are commented out as we've moved to the optimized bootstrap injection
// approach that doesn't use the Bootstrap trait on GenerateReqInput anymore.
// TODO: Add new tests for the optimized bootstrap injection approach using
// RequestWithBootstrap and BatchRequestWithBootstrap wrappers
// ============= Worker Selection Tests =============
#[tokio::test]
async fn test_select_healthy_prefill_worker() {
let router = create_test_pd_router();

View File

@@ -70,21 +70,15 @@ impl Router {
window_duration: Duration::from_secs(circuit_breaker_config.window_duration_secs),
};
// Initialize cache-aware policy with workers if needed
let default_policy = ctx.policy_registry.get_default_policy();
if default_policy.name() == "cache_aware" {
if let Some(cache_aware) = default_policy
.as_any()
.downcast_ref::<crate::policies::CacheAwarePolicy>()
{
cache_aware.init_workers(&workers);
}
}
// Cache-aware policies are initialized in WorkerInitializer
// Setup load monitoring for PowerOfTwo policy
let (tx, rx) = tokio::sync::watch::channel(HashMap::new());
let worker_loads = Arc::new(rx);
// Get default policy to check if we need load monitoring
let default_policy = ctx.policy_registry.get_default_policy();
// Check if default policy is power_of_two for load monitoring
let load_monitor_handle = if default_policy.name() == "power_of_two" {
let monitor_urls = worker_urls.clone();
@@ -964,19 +958,13 @@ impl Router {
// Notify PolicyRegistry about the new worker
let model_id = worker_arc.model_id();
let policy = self.policy_registry.on_worker_added(model_id, None);
self.policy_registry.on_worker_added(model_id, None);
// If this is a cache-aware policy, update it with all workers for this model
if policy.name() == "cache_aware" {
if let Some(cache_aware) = policy
.as_any()
.downcast_ref::<crate::policies::CacheAwarePolicy>(
) {
let model_workers =
self.worker_registry.get_by_model_fast(model_id);
cache_aware.init_workers(&model_workers);
}
}
// Initialize cache-aware policy if applicable
let model_workers =
self.worker_registry.get_by_model_fast(model_id);
self.policy_registry
.init_cache_aware_policy(model_id, &model_workers);
worker_added = true;
}
@@ -1000,20 +988,12 @@ impl Router {
// Notify PolicyRegistry about the new worker
let model_id = worker_arc.model_id();
let policy = self.policy_registry.on_worker_added(model_id, None);
self.policy_registry.on_worker_added(model_id, None);
// If this is a cache-aware policy, add this worker to it
if policy.name() == "cache_aware" {
if let Some(cache_aware) = policy
.as_any()
.downcast_ref::<crate::policies::CacheAwarePolicy>(
) {
// Get all workers for this model
let model_workers =
self.worker_registry.get_by_model_fast(model_id);
cache_aware.init_workers(&model_workers);
}
}
// Initialize cache-aware policy if applicable
let model_workers = self.worker_registry.get_by_model_fast(model_id);
self.policy_registry
.init_cache_aware_policy(model_id, &model_workers);
}
RouterMetrics::set_active_workers(self.worker_registry.get_all().len());
@@ -1084,20 +1064,11 @@ impl Router {
RouterMetrics::set_active_workers(self.worker_registry.get_all().len());
// If any models are using cache aware policy, remove the workers from the tree
// Check each removed worker's model and get its policy
for dp_url in removed_workers.iter() {
if let Some(worker) = self.worker_registry.get_by_url(dp_url) {
let model_id = worker.model_id();
if let Some(policy) = self.policy_registry.get_policy(model_id) {
if let Some(cache_aware) = policy
.as_any()
.downcast_ref::<crate::policies::CacheAwarePolicy>()
{
cache_aware.remove_worker_by_url(dp_url);
info!("Removed worker from cache-aware tree: {}", dp_url);
}
}
self.policy_registry
.remove_worker_from_cache_aware(model_id, dp_url);
}
}
} else {
@@ -1118,16 +1089,8 @@ impl Router {
RouterMetrics::set_active_workers(self.worker_registry.get_all().len());
}
// If the model is using cache aware policy, remove the worker from the tree
if let Some(policy) = self.policy_registry.get_policy(&model_id) {
if let Some(cache_aware) = policy
.as_any()
.downcast_ref::<crate::policies::CacheAwarePolicy>()
{
cache_aware.remove_worker_by_url(worker_url);
info!("Removed worker from cache-aware tree: {}", worker_url);
}
}
self.policy_registry
.remove_worker_from_cache_aware(&model_id, worker_url);
}
}