[router] refactor router and worker management 2.5/n (#10677)
This commit is contained in:
@@ -232,18 +232,12 @@ impl PDRouter {
|
||||
|
||||
// Notify PolicyRegistry about the new worker
|
||||
let model_id = worker_arc.model_id();
|
||||
let policy = self.policy_registry.on_worker_added(model_id, None);
|
||||
self.policy_registry.on_worker_added(model_id, None);
|
||||
|
||||
// If this is a cache-aware policy, update it with all workers for this model
|
||||
if policy.name() == "cache_aware" {
|
||||
if let Some(cache_aware) = policy
|
||||
.as_any()
|
||||
.downcast_ref::<crate::policies::CacheAwarePolicy>()
|
||||
{
|
||||
let model_workers = self.worker_registry.get_by_model_fast(model_id);
|
||||
cache_aware.init_workers(&model_workers);
|
||||
}
|
||||
}
|
||||
// Initialize cache-aware policy if applicable
|
||||
let model_workers = self.worker_registry.get_by_model_fast(model_id);
|
||||
self.policy_registry
|
||||
.init_cache_aware_policy(model_id, &model_workers);
|
||||
|
||||
info!("Added prefill server: {}", url);
|
||||
Ok(format!("Successfully added prefill server: {}", url))
|
||||
@@ -272,18 +266,12 @@ impl PDRouter {
|
||||
|
||||
// Notify PolicyRegistry about the new worker
|
||||
let model_id = worker_arc.model_id();
|
||||
let policy = self.policy_registry.on_worker_added(model_id, None);
|
||||
self.policy_registry.on_worker_added(model_id, None);
|
||||
|
||||
// If this is a cache-aware policy, update it with all workers for this model
|
||||
if policy.name() == "cache_aware" {
|
||||
if let Some(cache_aware) = policy
|
||||
.as_any()
|
||||
.downcast_ref::<crate::policies::CacheAwarePolicy>()
|
||||
{
|
||||
let model_workers = self.worker_registry.get_by_model_fast(model_id);
|
||||
cache_aware.init_workers(&model_workers);
|
||||
}
|
||||
}
|
||||
// Initialize cache-aware policy if applicable
|
||||
let model_workers = self.worker_registry.get_by_model_fast(model_id);
|
||||
self.policy_registry
|
||||
.init_cache_aware_policy(model_id, &model_workers);
|
||||
|
||||
info!("Added decode server: {}", url);
|
||||
Ok(format!("Successfully added decode server: {}", url))
|
||||
@@ -307,17 +295,9 @@ impl PDRouter {
|
||||
// Notify PolicyRegistry about the removed worker
|
||||
self.policy_registry.on_worker_removed(&model_id);
|
||||
|
||||
// Get the policy for this model to update cache-aware if needed
|
||||
if let Some(policy) = self.policy_registry.get_policy(&model_id) {
|
||||
if policy.name() == "cache_aware" {
|
||||
if let Some(cache_aware) = policy
|
||||
.as_any()
|
||||
.downcast_ref::<crate::policies::CacheAwarePolicy>()
|
||||
{
|
||||
cache_aware.remove_worker_by_url(url);
|
||||
}
|
||||
}
|
||||
}
|
||||
// Remove from cache-aware policy if applicable
|
||||
self.policy_registry
|
||||
.remove_worker_from_cache_aware(&model_id, url);
|
||||
}
|
||||
|
||||
if removed.is_some() {
|
||||
@@ -348,17 +328,9 @@ impl PDRouter {
|
||||
// Notify PolicyRegistry about the removed worker
|
||||
self.policy_registry.on_worker_removed(&model_id);
|
||||
|
||||
// Get the policy for this model to update cache-aware if needed
|
||||
if let Some(policy) = self.policy_registry.get_policy(&model_id) {
|
||||
if policy.name() == "cache_aware" {
|
||||
if let Some(cache_aware) = policy
|
||||
.as_any()
|
||||
.downcast_ref::<crate::policies::CacheAwarePolicy>()
|
||||
{
|
||||
cache_aware.remove_worker_by_url(url);
|
||||
}
|
||||
}
|
||||
}
|
||||
// Remove from cache-aware policy if applicable
|
||||
self.policy_registry
|
||||
.remove_worker_from_cache_aware(&model_id, url);
|
||||
}
|
||||
|
||||
if removed.is_some() {
|
||||
@@ -2226,15 +2198,6 @@ mod tests {
|
||||
assert_eq!(prefill_workers.len(), 1);
|
||||
}
|
||||
|
||||
// ============= Bootstrap Injection Tests =============
|
||||
// Note: These tests are commented out as we've moved to the optimized bootstrap injection
|
||||
// approach that doesn't use the Bootstrap trait on GenerateReqInput anymore.
|
||||
|
||||
// TODO: Add new tests for the optimized bootstrap injection approach using
|
||||
// RequestWithBootstrap and BatchRequestWithBootstrap wrappers
|
||||
|
||||
// ============= Worker Selection Tests =============
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_select_healthy_prefill_worker() {
|
||||
let router = create_test_pd_router();
|
||||
|
||||
@@ -70,21 +70,15 @@ impl Router {
|
||||
window_duration: Duration::from_secs(circuit_breaker_config.window_duration_secs),
|
||||
};
|
||||
|
||||
// Initialize cache-aware policy with workers if needed
|
||||
let default_policy = ctx.policy_registry.get_default_policy();
|
||||
if default_policy.name() == "cache_aware" {
|
||||
if let Some(cache_aware) = default_policy
|
||||
.as_any()
|
||||
.downcast_ref::<crate::policies::CacheAwarePolicy>()
|
||||
{
|
||||
cache_aware.init_workers(&workers);
|
||||
}
|
||||
}
|
||||
// Cache-aware policies are initialized in WorkerInitializer
|
||||
|
||||
// Setup load monitoring for PowerOfTwo policy
|
||||
let (tx, rx) = tokio::sync::watch::channel(HashMap::new());
|
||||
let worker_loads = Arc::new(rx);
|
||||
|
||||
// Get default policy to check if we need load monitoring
|
||||
let default_policy = ctx.policy_registry.get_default_policy();
|
||||
|
||||
// Check if default policy is power_of_two for load monitoring
|
||||
let load_monitor_handle = if default_policy.name() == "power_of_two" {
|
||||
let monitor_urls = worker_urls.clone();
|
||||
@@ -964,19 +958,13 @@ impl Router {
|
||||
|
||||
// Notify PolicyRegistry about the new worker
|
||||
let model_id = worker_arc.model_id();
|
||||
let policy = self.policy_registry.on_worker_added(model_id, None);
|
||||
self.policy_registry.on_worker_added(model_id, None);
|
||||
|
||||
// If this is a cache-aware policy, update it with all workers for this model
|
||||
if policy.name() == "cache_aware" {
|
||||
if let Some(cache_aware) = policy
|
||||
.as_any()
|
||||
.downcast_ref::<crate::policies::CacheAwarePolicy>(
|
||||
) {
|
||||
let model_workers =
|
||||
self.worker_registry.get_by_model_fast(model_id);
|
||||
cache_aware.init_workers(&model_workers);
|
||||
}
|
||||
}
|
||||
// Initialize cache-aware policy if applicable
|
||||
let model_workers =
|
||||
self.worker_registry.get_by_model_fast(model_id);
|
||||
self.policy_registry
|
||||
.init_cache_aware_policy(model_id, &model_workers);
|
||||
|
||||
worker_added = true;
|
||||
}
|
||||
@@ -1000,20 +988,12 @@ impl Router {
|
||||
|
||||
// Notify PolicyRegistry about the new worker
|
||||
let model_id = worker_arc.model_id();
|
||||
let policy = self.policy_registry.on_worker_added(model_id, None);
|
||||
self.policy_registry.on_worker_added(model_id, None);
|
||||
|
||||
// If this is a cache-aware policy, add this worker to it
|
||||
if policy.name() == "cache_aware" {
|
||||
if let Some(cache_aware) = policy
|
||||
.as_any()
|
||||
.downcast_ref::<crate::policies::CacheAwarePolicy>(
|
||||
) {
|
||||
// Get all workers for this model
|
||||
let model_workers =
|
||||
self.worker_registry.get_by_model_fast(model_id);
|
||||
cache_aware.init_workers(&model_workers);
|
||||
}
|
||||
}
|
||||
// Initialize cache-aware policy if applicable
|
||||
let model_workers = self.worker_registry.get_by_model_fast(model_id);
|
||||
self.policy_registry
|
||||
.init_cache_aware_policy(model_id, &model_workers);
|
||||
}
|
||||
|
||||
RouterMetrics::set_active_workers(self.worker_registry.get_all().len());
|
||||
@@ -1084,20 +1064,11 @@ impl Router {
|
||||
|
||||
RouterMetrics::set_active_workers(self.worker_registry.get_all().len());
|
||||
|
||||
// If any models are using cache aware policy, remove the workers from the tree
|
||||
// Check each removed worker's model and get its policy
|
||||
for dp_url in removed_workers.iter() {
|
||||
if let Some(worker) = self.worker_registry.get_by_url(dp_url) {
|
||||
let model_id = worker.model_id();
|
||||
if let Some(policy) = self.policy_registry.get_policy(model_id) {
|
||||
if let Some(cache_aware) = policy
|
||||
.as_any()
|
||||
.downcast_ref::<crate::policies::CacheAwarePolicy>()
|
||||
{
|
||||
cache_aware.remove_worker_by_url(dp_url);
|
||||
info!("Removed worker from cache-aware tree: {}", dp_url);
|
||||
}
|
||||
}
|
||||
self.policy_registry
|
||||
.remove_worker_from_cache_aware(model_id, dp_url);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
@@ -1118,16 +1089,8 @@ impl Router {
|
||||
RouterMetrics::set_active_workers(self.worker_registry.get_all().len());
|
||||
}
|
||||
|
||||
// If the model is using cache aware policy, remove the worker from the tree
|
||||
if let Some(policy) = self.policy_registry.get_policy(&model_id) {
|
||||
if let Some(cache_aware) = policy
|
||||
.as_any()
|
||||
.downcast_ref::<crate::policies::CacheAwarePolicy>()
|
||||
{
|
||||
cache_aware.remove_worker_by_url(worker_url);
|
||||
info!("Removed worker from cache-aware tree: {}", worker_url);
|
||||
}
|
||||
}
|
||||
self.policy_registry
|
||||
.remove_worker_from_cache_aware(&model_id, worker_url);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user