[router] refactor router and worker management 2.5/n (#10677)

This commit is contained in:
Simo Lin
2025-09-19 23:54:40 -04:00
committed by GitHub
parent 60e2a7cead
commit 1d1ce62495
8 changed files with 235 additions and 123 deletions

View File

@@ -9,6 +9,7 @@ use super::{
RoundRobinPolicy,
};
use crate::config::types::PolicyConfig;
use crate::core::Worker;
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use tracing::{debug, info, warn};
@@ -255,6 +256,81 @@ impl PolicyRegistry {
.map(Arc::clone)
.unwrap_or_else(|| self.get_default_policy())
}
/// Seed the cache-aware policy of `model_id` with its workers, when it has one.
///
/// Looks up the policy for `model_id` via `get_policy`. If that policy reports
/// the name `"cache_aware"` and downcasts to `CacheAwarePolicy`, `workers` is
/// handed to its `init_workers`; in every other case nothing happens.
///
/// Meant to be called once the workers for the model have been registered.
pub fn init_cache_aware_policy(&self, model_id: &str, workers: &[Arc<dyn Worker>]) {
    // No policy registered for this model: nothing to initialize.
    let policy = match self.get_policy(model_id) {
        Some(found) => found,
        None => return,
    };
    if policy.name() != "cache_aware" {
        return;
    }
    // The reported name is not enough; the concrete type has to match as well.
    let cache_aware = match policy.as_any().downcast_ref::<CacheAwarePolicy>() {
        Some(concrete) => concrete,
        None => return,
    };

    debug!(
        "Initializing cache-aware policy with {} workers for model {}",
        workers.len(),
        model_id
    );
    cache_aware.init_workers(workers);
}
/// Drop the worker at `worker_url` from the cache-aware policy of `model_id`.
///
/// Looks up the policy for `model_id` via `get_policy`. If that policy reports
/// the name `"cache_aware"` and downcasts to `CacheAwarePolicy`, its
/// `remove_worker_by_url` is called with `worker_url`; otherwise this is a no-op.
///
/// Meant to be called while a worker is being removed.
pub fn remove_worker_from_cache_aware(&self, model_id: &str, worker_url: &str) {
    // No policy registered for this model: nothing to clean up.
    let policy = match self.get_policy(model_id) {
        Some(found) => found,
        None => return,
    };
    if policy.name() != "cache_aware" {
        return;
    }
    // The reported name is not enough; the concrete type has to match as well.
    let cache_aware = match policy.as_any().downcast_ref::<CacheAwarePolicy>() {
        Some(concrete) => concrete,
        None => return,
    };

    // Remove first, then log, so the message describes a completed action.
    cache_aware.remove_worker_by_url(worker_url);
    debug!(
        "Removed worker {} from cache-aware policy for model {}",
        worker_url, model_id
    );
}
/// Seed the PD-mode (prefill / decode) cache-aware policies with their workers.
///
/// The prefill and decode policies are handled independently and in that order.
/// A side is initialized only when all of the following hold:
/// - a policy is currently set for that side,
/// - the policy reports the name `"cache_aware"`,
/// - it downcasts to `CacheAwarePolicy`,
/// - the matching worker slice is non-empty.
///
/// When they do, the slice is passed to that policy's `init_workers`.
///
/// # Panics
///
/// Panics if `read()` on `prefill_policy` or `decode_policy` returns an error,
/// because the result is unwrapped.
pub fn init_pd_cache_aware_policies(
    &self,
    prefill_workers: &[Arc<dyn Worker>],
    decode_workers: &[Arc<dyn Worker>],
) {
    // Prefill side. The guard lives in its own scope so the prefill lock is
    // released before the decode lock is taken.
    {
        let prefill_guard = self.prefill_policy.read().unwrap();
        let prefill_target = prefill_guard
            .as_ref()
            .filter(|policy| policy.name() == "cache_aware")
            .and_then(|policy| policy.as_any().downcast_ref::<CacheAwarePolicy>())
            .filter(|_| !prefill_workers.is_empty());
        if let Some(cache_aware) = prefill_target {
            debug!(
                "Initializing prefill cache-aware policy with {} workers",
                prefill_workers.len()
            );
            cache_aware.init_workers(prefill_workers);
        }
    }

    // Decode side, same sequence of checks against the decode policy.
    {
        let decode_guard = self.decode_policy.read().unwrap();
        let decode_target = decode_guard
            .as_ref()
            .filter(|policy| policy.name() == "cache_aware")
            .and_then(|policy| policy.as_any().downcast_ref::<CacheAwarePolicy>())
            .filter(|_| !decode_workers.is_empty());
        if let Some(cache_aware) = decode_target {
            debug!(
                "Initializing decode cache-aware policy with {} workers",
                decode_workers.len()
            );
            cache_aware.init_workers(decode_workers);
        }
    }
}
}
impl std::fmt::Debug for PolicyRegistry {