[router] add different policies for p node and d node (#8395)

This commit is contained in:
Simo Lin
2025-07-27 00:39:20 -07:00
committed by GitHub
parent 0bcc195f4e
commit 2ab97023e3
10 changed files with 536 additions and 81 deletions

View File

@@ -22,8 +22,10 @@ use uuid::Uuid;
pub struct PDRouter {
pub prefill_workers: Arc<RwLock<Vec<Box<dyn Worker>>>>,
pub decode_workers: Arc<RwLock<Vec<Box<dyn Worker>>>>,
pub policy: Arc<dyn LoadBalancingPolicy>,
pub prefill_policy: Arc<dyn LoadBalancingPolicy>,
pub decode_policy: Arc<dyn LoadBalancingPolicy>,
pub prefill_tree: Option<Arc<Mutex<Tree>>>,
pub decode_tree: Option<Arc<Mutex<Tree>>>,
pub timeout_secs: u64,
pub interval_secs: u64,
pub worker_loads: Arc<tokio::sync::watch::Receiver<HashMap<String, isize>>>,
@@ -66,7 +68,7 @@ impl PDRouter {
workers.push(worker);
// Add to cache tree if using cache-aware policy
// Add to cache tree if using cache-aware policy for prefill
if let Some(ref tree) = self.prefill_tree {
tree.lock().unwrap().insert("", &url);
}
@@ -102,6 +104,11 @@ impl PDRouter {
workers.push(worker);
// Add to cache tree if using cache-aware policy for decode
if let Some(ref tree) = self.decode_tree {
tree.lock().unwrap().insert("", &url);
}
info!("Added decode server: {}", url);
Ok(format!("Successfully added decode server: {}", url))
}
@@ -126,12 +133,7 @@ impl PDRouter {
// Remove from cache tree if using cache-aware policy
if let Some(ref tree) = self.prefill_tree {
// Note: Tree doesn't have a remove method, so we rebuild it
let mut tree_guard = tree.lock().unwrap();
*tree_guard = Tree::new();
for worker in workers.iter() {
tree_guard.insert("", worker.url());
}
tree.lock().unwrap().remove_tenant(url);
}
info!("Removed prefill server: {}", url);
@@ -156,6 +158,11 @@ impl PDRouter {
});
}
// Remove from the cache tree if using cache-aware policy for decode
if let Some(ref tree) = self.decode_tree {
tree.lock().unwrap().remove_tenant(url);
}
info!("Removed decode server: {}", url);
Ok(format!("Successfully removed decode server: {}", url))
}
@@ -163,7 +170,8 @@ impl PDRouter {
pub fn new(
prefill_urls: Vec<(String, Option<u16>)>,
decode_urls: Vec<String>,
policy: Arc<dyn LoadBalancingPolicy>,
prefill_policy: Arc<dyn LoadBalancingPolicy>,
decode_policy: Arc<dyn LoadBalancingPolicy>,
timeout_secs: u64,
interval_secs: u64,
) -> Result<Self, String> {
@@ -192,10 +200,10 @@ impl PDRouter {
)?;
}
// Initialize cache-aware components if needed
let prefill_tree = if policy.name() == "cache_aware" {
// Initialize cache-aware components if needed for prefill policy
let prefill_tree = if prefill_policy.name() == "cache_aware" {
// Initialize the policy's internal tree with prefill workers
if let Some(cache_policy) = policy
if let Some(cache_policy) = prefill_policy
.as_any()
.downcast_ref::<crate::policies::CacheAwarePolicy>()
{
@@ -212,6 +220,26 @@ impl PDRouter {
None
};
// Initialize cache-aware components if needed for decode policy
let decode_tree = if decode_policy.name() == "cache_aware" {
// Initialize the policy's internal tree with decode workers
if let Some(cache_policy) = decode_policy
.as_any()
.downcast_ref::<crate::policies::CacheAwarePolicy>()
{
cache_policy.init_workers(&decode_workers);
}
let tree = Arc::new(Mutex::new(Tree::new()));
// Initialize tree with decode workers
for worker in &decode_workers {
tree.lock().unwrap().insert("", worker.url());
}
Some(tree)
} else {
None
};
// Set up background load monitoring for power-of-two selection
let (tx, rx) = tokio::sync::watch::channel(HashMap::new());
let worker_loads = Arc::new(rx);
@@ -222,25 +250,28 @@ impl PDRouter {
.build()
.map_err(|e| format!("Failed to create HTTP client: {}", e))?;
let load_monitor_handle = if policy.name() == "power_of_two" {
let monitor_urls = all_urls.clone();
let monitor_interval = interval_secs;
let monitor_client = http_client.clone();
let policy_clone = Arc::clone(&policy);
let load_monitor_handle =
if prefill_policy.name() == "power_of_two" || decode_policy.name() == "power_of_two" {
let monitor_urls = all_urls.clone();
let monitor_interval = interval_secs;
let monitor_client = http_client.clone();
let prefill_policy_clone = Arc::clone(&prefill_policy);
let decode_policy_clone = Arc::clone(&decode_policy);
Some(Arc::new(tokio::spawn(async move {
Self::monitor_worker_loads_with_client(
monitor_urls,
tx,
monitor_interval,
monitor_client,
policy_clone,
)
.await;
})))
} else {
None
};
Some(Arc::new(tokio::spawn(async move {
Self::monitor_worker_loads_with_client(
monitor_urls,
tx,
monitor_interval,
monitor_client,
prefill_policy_clone,
decode_policy_clone,
)
.await;
})))
} else {
None
};
let prefill_workers = Arc::new(RwLock::new(prefill_workers));
let decode_workers = Arc::new(RwLock::new(decode_workers));
@@ -254,8 +285,10 @@ impl PDRouter {
Ok(PDRouter {
prefill_workers,
decode_workers,
policy,
prefill_policy,
decode_policy,
prefill_tree,
decode_tree,
timeout_secs,
interval_secs,
worker_loads,
@@ -736,18 +769,21 @@ impl PDRouter {
return Err("No decode workers available. Please check if decode servers are configured and healthy.".to_string());
}
// Use the policy to select worker pair
match self
.policy
.select_worker_pair(&prefill_workers, &decode_workers, request_text)
{
Some((prefill_idx, decode_idx)) => {
let prefill = prefill_workers[prefill_idx].clone_worker();
let decode = decode_workers[decode_idx].clone_worker();
Ok((prefill, decode))
}
None => Err("Failed to select worker pair".to_string()),
}
// Select prefill worker using prefill policy
let prefill_idx = self
.prefill_policy
.select_worker(&prefill_workers, request_text)
.ok_or("Failed to select prefill worker")?;
// Select decode worker using decode policy
let decode_idx = self
.decode_policy
.select_worker(&decode_workers, request_text)
.ok_or("Failed to select decode worker")?;
let prefill = prefill_workers[prefill_idx].clone_worker();
let decode = decode_workers[decode_idx].clone_worker();
Ok((prefill, decode))
}
// Background task to monitor worker loads with shared client
@@ -756,7 +792,8 @@ impl PDRouter {
tx: tokio::sync::watch::Sender<HashMap<String, isize>>,
interval_secs: u64,
client: reqwest::Client,
policy: Arc<dyn LoadBalancingPolicy>,
prefill_policy: Arc<dyn LoadBalancingPolicy>,
decode_policy: Arc<dyn LoadBalancingPolicy>,
) {
loop {
let mut loads = HashMap::new();
@@ -781,8 +818,9 @@ impl PDRouter {
debug!("Worker loads updated: {:?}", loads);
// Update the policy with current loads
policy.update_loads(&loads);
// Update both policies with current loads
prefill_policy.update_loads(&loads);
decode_policy.update_loads(&loads);
// Check if receiver is still active
if tx.send(loads).is_err() {
@@ -1463,13 +1501,16 @@ mod tests {
use actix_web::test::TestRequest;
fn create_test_pd_router() -> PDRouter {
let policy = Arc::new(RandomPolicy::new());
let prefill_policy = Arc::new(RandomPolicy::new());
let decode_policy = Arc::new(RandomPolicy::new());
PDRouter {
prefill_workers: Arc::new(RwLock::new(vec![])),
decode_workers: Arc::new(RwLock::new(vec![])),
policy,
prefill_policy,
decode_policy,
prefill_tree: None,
decode_tree: None,
timeout_secs: 5,
interval_secs: 1,
worker_loads: Arc::new(tokio::sync::watch::channel(HashMap::new()).1),
@@ -1608,9 +1649,9 @@ mod tests {
#[tokio::test]
async fn test_cache_tree_operations() {
let policy = Arc::new(CacheAwarePolicy::new());
let cache_policy = Arc::new(CacheAwarePolicy::new());
let mut router = create_test_pd_router();
router.policy = policy;
router.prefill_policy = cache_policy;
// Initialize cache tree
let tree = Arc::new(Mutex::new(Tree::new()));
@@ -1638,9 +1679,9 @@ mod tests {
#[tokio::test]
async fn test_cache_tree_rebuild_on_remove() {
let policy = Arc::new(CacheAwarePolicy::new());
let cache_policy = Arc::new(CacheAwarePolicy::new());
let mut router = create_test_pd_router();
router.policy = policy;
router.prefill_policy = cache_policy;
// Initialize cache tree
let tree = Arc::new(Mutex::new(Tree::new()));
@@ -1880,9 +1921,10 @@ mod tests {
#[tokio::test]
async fn test_load_monitor_updates() {
let policy = Arc::new(crate::policies::PowerOfTwoPolicy::new());
let power_of_two_policy = Arc::new(crate::policies::PowerOfTwoPolicy::new());
let mut router = create_test_pd_router();
router.policy = policy;
router.prefill_policy = power_of_two_policy.clone();
router.decode_policy = power_of_two_policy;
// Create load channel
let (tx, rx) = tokio::sync::watch::channel(HashMap::new());