[router] refactor router and worker management 2.5/n (#10677)
This commit is contained in:
@@ -9,6 +9,7 @@ use super::{
|
||||
RoundRobinPolicy,
|
||||
};
|
||||
use crate::config::types::PolicyConfig;
|
||||
use crate::core::Worker;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::{Arc, RwLock};
|
||||
use tracing::{debug, info, warn};
|
||||
@@ -255,6 +256,81 @@ impl PolicyRegistry {
|
||||
.map(Arc::clone)
|
||||
.unwrap_or_else(|| self.get_default_policy())
|
||||
}
|
||||
|
||||
/// Initialize cache-aware policy with workers if applicable
|
||||
/// This should be called after workers are registered for a model
|
||||
pub fn init_cache_aware_policy(&self, model_id: &str, workers: &[Arc<dyn Worker>]) {
|
||||
// Get the policy for this model
|
||||
if let Some(policy) = self.get_policy(model_id) {
|
||||
if policy.name() == "cache_aware" {
|
||||
if let Some(cache_aware) = policy.as_any().downcast_ref::<CacheAwarePolicy>() {
|
||||
debug!(
|
||||
"Initializing cache-aware policy with {} workers for model {}",
|
||||
workers.len(),
|
||||
model_id
|
||||
);
|
||||
cache_aware.init_workers(workers);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Remove a worker from cache-aware policy if applicable
|
||||
/// This should be called when a worker is being removed
|
||||
pub fn remove_worker_from_cache_aware(&self, model_id: &str, worker_url: &str) {
|
||||
// Get the policy for this model
|
||||
if let Some(policy) = self.get_policy(model_id) {
|
||||
if policy.name() == "cache_aware" {
|
||||
if let Some(cache_aware) = policy.as_any().downcast_ref::<CacheAwarePolicy>() {
|
||||
cache_aware.remove_worker_by_url(worker_url);
|
||||
debug!(
|
||||
"Removed worker {} from cache-aware policy for model {}",
|
||||
worker_url, model_id
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Initialize cache-aware policies for PD mode (prefill and decode)
|
||||
pub fn init_pd_cache_aware_policies(
|
||||
&self,
|
||||
prefill_workers: &[Arc<dyn Worker>],
|
||||
decode_workers: &[Arc<dyn Worker>],
|
||||
) {
|
||||
// Initialize prefill policy if it's cache-aware
|
||||
if let Some(prefill_policy) = self.prefill_policy.read().unwrap().as_ref() {
|
||||
if prefill_policy.name() == "cache_aware" {
|
||||
if let Some(cache_aware) =
|
||||
prefill_policy.as_any().downcast_ref::<CacheAwarePolicy>()
|
||||
{
|
||||
if !prefill_workers.is_empty() {
|
||||
debug!(
|
||||
"Initializing prefill cache-aware policy with {} workers",
|
||||
prefill_workers.len()
|
||||
);
|
||||
cache_aware.init_workers(prefill_workers);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Initialize decode policy if it's cache-aware
|
||||
if let Some(decode_policy) = self.decode_policy.read().unwrap().as_ref() {
|
||||
if decode_policy.name() == "cache_aware" {
|
||||
if let Some(cache_aware) = decode_policy.as_any().downcast_ref::<CacheAwarePolicy>()
|
||||
{
|
||||
if !decode_workers.is_empty() {
|
||||
debug!(
|
||||
"Initializing decode cache-aware policy with {} workers",
|
||||
decode_workers.len()
|
||||
);
|
||||
cache_aware.init_workers(decode_workers);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for PolicyRegistry {
|
||||
|
||||
Reference in New Issue
Block a user