[router] add different policies for p node and d node (#8395)
This commit is contained in:
@@ -22,8 +22,10 @@ use uuid::Uuid;
|
||||
pub struct PDRouter {
|
||||
pub prefill_workers: Arc<RwLock<Vec<Box<dyn Worker>>>>,
|
||||
pub decode_workers: Arc<RwLock<Vec<Box<dyn Worker>>>>,
|
||||
pub policy: Arc<dyn LoadBalancingPolicy>,
|
||||
pub prefill_policy: Arc<dyn LoadBalancingPolicy>,
|
||||
pub decode_policy: Arc<dyn LoadBalancingPolicy>,
|
||||
pub prefill_tree: Option<Arc<Mutex<Tree>>>,
|
||||
pub decode_tree: Option<Arc<Mutex<Tree>>>,
|
||||
pub timeout_secs: u64,
|
||||
pub interval_secs: u64,
|
||||
pub worker_loads: Arc<tokio::sync::watch::Receiver<HashMap<String, isize>>>,
|
||||
@@ -66,7 +68,7 @@ impl PDRouter {
|
||||
|
||||
workers.push(worker);
|
||||
|
||||
// Add to cache tree if using cache-aware policy
|
||||
// Add to cache tree if using cache-aware policy for prefill
|
||||
if let Some(ref tree) = self.prefill_tree {
|
||||
tree.lock().unwrap().insert("", &url);
|
||||
}
|
||||
@@ -102,6 +104,11 @@ impl PDRouter {
|
||||
|
||||
workers.push(worker);
|
||||
|
||||
// Add to cache tree if using cache-aware policy for decode
|
||||
if let Some(ref tree) = self.decode_tree {
|
||||
tree.lock().unwrap().insert("", &url);
|
||||
}
|
||||
|
||||
info!("Added decode server: {}", url);
|
||||
Ok(format!("Successfully added decode server: {}", url))
|
||||
}
|
||||
@@ -126,12 +133,7 @@ impl PDRouter {
|
||||
|
||||
// Remove from cache tree if using cache-aware policy
|
||||
if let Some(ref tree) = self.prefill_tree {
|
||||
// Note: Tree doesn't have a remove method, so we rebuild it
|
||||
let mut tree_guard = tree.lock().unwrap();
|
||||
*tree_guard = Tree::new();
|
||||
for worker in workers.iter() {
|
||||
tree_guard.insert("", worker.url());
|
||||
}
|
||||
tree.lock().unwrap().remove_tenant(url);
|
||||
}
|
||||
|
||||
info!("Removed prefill server: {}", url);
|
||||
@@ -156,6 +158,11 @@ impl PDRouter {
|
||||
});
|
||||
}
|
||||
|
||||
// Remove from the cache tree if using cache-aware policy for decode
|
||||
if let Some(ref tree) = self.decode_tree {
|
||||
tree.lock().unwrap().remove_tenant(url);
|
||||
}
|
||||
|
||||
info!("Removed decode server: {}", url);
|
||||
Ok(format!("Successfully removed decode server: {}", url))
|
||||
}
|
||||
@@ -163,7 +170,8 @@ impl PDRouter {
|
||||
pub fn new(
|
||||
prefill_urls: Vec<(String, Option<u16>)>,
|
||||
decode_urls: Vec<String>,
|
||||
policy: Arc<dyn LoadBalancingPolicy>,
|
||||
prefill_policy: Arc<dyn LoadBalancingPolicy>,
|
||||
decode_policy: Arc<dyn LoadBalancingPolicy>,
|
||||
timeout_secs: u64,
|
||||
interval_secs: u64,
|
||||
) -> Result<Self, String> {
|
||||
@@ -192,10 +200,10 @@ impl PDRouter {
|
||||
)?;
|
||||
}
|
||||
|
||||
// Initialize cache-aware components if needed
|
||||
let prefill_tree = if policy.name() == "cache_aware" {
|
||||
// Initialize cache-aware components if needed for prefill policy
|
||||
let prefill_tree = if prefill_policy.name() == "cache_aware" {
|
||||
// Initialize the policy's internal tree with prefill workers
|
||||
if let Some(cache_policy) = policy
|
||||
if let Some(cache_policy) = prefill_policy
|
||||
.as_any()
|
||||
.downcast_ref::<crate::policies::CacheAwarePolicy>()
|
||||
{
|
||||
@@ -212,6 +220,26 @@ impl PDRouter {
|
||||
None
|
||||
};
|
||||
|
||||
// Initialize cache-aware components if needed for decode policy
|
||||
let decode_tree = if decode_policy.name() == "cache_aware" {
|
||||
// Initialize the policy's internal tree with decode workers
|
||||
if let Some(cache_policy) = decode_policy
|
||||
.as_any()
|
||||
.downcast_ref::<crate::policies::CacheAwarePolicy>()
|
||||
{
|
||||
cache_policy.init_workers(&decode_workers);
|
||||
}
|
||||
|
||||
let tree = Arc::new(Mutex::new(Tree::new()));
|
||||
// Initialize tree with decode workers
|
||||
for worker in &decode_workers {
|
||||
tree.lock().unwrap().insert("", worker.url());
|
||||
}
|
||||
Some(tree)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
// Set up background load monitoring for power-of-two selection
|
||||
let (tx, rx) = tokio::sync::watch::channel(HashMap::new());
|
||||
let worker_loads = Arc::new(rx);
|
||||
@@ -222,25 +250,28 @@ impl PDRouter {
|
||||
.build()
|
||||
.map_err(|e| format!("Failed to create HTTP client: {}", e))?;
|
||||
|
||||
let load_monitor_handle = if policy.name() == "power_of_two" {
|
||||
let monitor_urls = all_urls.clone();
|
||||
let monitor_interval = interval_secs;
|
||||
let monitor_client = http_client.clone();
|
||||
let policy_clone = Arc::clone(&policy);
|
||||
let load_monitor_handle =
|
||||
if prefill_policy.name() == "power_of_two" || decode_policy.name() == "power_of_two" {
|
||||
let monitor_urls = all_urls.clone();
|
||||
let monitor_interval = interval_secs;
|
||||
let monitor_client = http_client.clone();
|
||||
let prefill_policy_clone = Arc::clone(&prefill_policy);
|
||||
let decode_policy_clone = Arc::clone(&decode_policy);
|
||||
|
||||
Some(Arc::new(tokio::spawn(async move {
|
||||
Self::monitor_worker_loads_with_client(
|
||||
monitor_urls,
|
||||
tx,
|
||||
monitor_interval,
|
||||
monitor_client,
|
||||
policy_clone,
|
||||
)
|
||||
.await;
|
||||
})))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
Some(Arc::new(tokio::spawn(async move {
|
||||
Self::monitor_worker_loads_with_client(
|
||||
monitor_urls,
|
||||
tx,
|
||||
monitor_interval,
|
||||
monitor_client,
|
||||
prefill_policy_clone,
|
||||
decode_policy_clone,
|
||||
)
|
||||
.await;
|
||||
})))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let prefill_workers = Arc::new(RwLock::new(prefill_workers));
|
||||
let decode_workers = Arc::new(RwLock::new(decode_workers));
|
||||
@@ -254,8 +285,10 @@ impl PDRouter {
|
||||
Ok(PDRouter {
|
||||
prefill_workers,
|
||||
decode_workers,
|
||||
policy,
|
||||
prefill_policy,
|
||||
decode_policy,
|
||||
prefill_tree,
|
||||
decode_tree,
|
||||
timeout_secs,
|
||||
interval_secs,
|
||||
worker_loads,
|
||||
@@ -736,18 +769,21 @@ impl PDRouter {
|
||||
return Err("No decode workers available. Please check if decode servers are configured and healthy.".to_string());
|
||||
}
|
||||
|
||||
// Use the policy to select worker pair
|
||||
match self
|
||||
.policy
|
||||
.select_worker_pair(&prefill_workers, &decode_workers, request_text)
|
||||
{
|
||||
Some((prefill_idx, decode_idx)) => {
|
||||
let prefill = prefill_workers[prefill_idx].clone_worker();
|
||||
let decode = decode_workers[decode_idx].clone_worker();
|
||||
Ok((prefill, decode))
|
||||
}
|
||||
None => Err("Failed to select worker pair".to_string()),
|
||||
}
|
||||
// Select prefill worker using prefill policy
|
||||
let prefill_idx = self
|
||||
.prefill_policy
|
||||
.select_worker(&prefill_workers, request_text)
|
||||
.ok_or("Failed to select prefill worker")?;
|
||||
|
||||
// Select decode worker using decode policy
|
||||
let decode_idx = self
|
||||
.decode_policy
|
||||
.select_worker(&decode_workers, request_text)
|
||||
.ok_or("Failed to select decode worker")?;
|
||||
|
||||
let prefill = prefill_workers[prefill_idx].clone_worker();
|
||||
let decode = decode_workers[decode_idx].clone_worker();
|
||||
Ok((prefill, decode))
|
||||
}
|
||||
|
||||
// Background task to monitor worker loads with shared client
|
||||
@@ -756,7 +792,8 @@ impl PDRouter {
|
||||
tx: tokio::sync::watch::Sender<HashMap<String, isize>>,
|
||||
interval_secs: u64,
|
||||
client: reqwest::Client,
|
||||
policy: Arc<dyn LoadBalancingPolicy>,
|
||||
prefill_policy: Arc<dyn LoadBalancingPolicy>,
|
||||
decode_policy: Arc<dyn LoadBalancingPolicy>,
|
||||
) {
|
||||
loop {
|
||||
let mut loads = HashMap::new();
|
||||
@@ -781,8 +818,9 @@ impl PDRouter {
|
||||
|
||||
debug!("Worker loads updated: {:?}", loads);
|
||||
|
||||
// Update the policy with current loads
|
||||
policy.update_loads(&loads);
|
||||
// Update both policies with current loads
|
||||
prefill_policy.update_loads(&loads);
|
||||
decode_policy.update_loads(&loads);
|
||||
|
||||
// Check if receiver is still active
|
||||
if tx.send(loads).is_err() {
|
||||
@@ -1463,13 +1501,16 @@ mod tests {
|
||||
use actix_web::test::TestRequest;
|
||||
|
||||
fn create_test_pd_router() -> PDRouter {
|
||||
let policy = Arc::new(RandomPolicy::new());
|
||||
let prefill_policy = Arc::new(RandomPolicy::new());
|
||||
let decode_policy = Arc::new(RandomPolicy::new());
|
||||
|
||||
PDRouter {
|
||||
prefill_workers: Arc::new(RwLock::new(vec![])),
|
||||
decode_workers: Arc::new(RwLock::new(vec![])),
|
||||
policy,
|
||||
prefill_policy,
|
||||
decode_policy,
|
||||
prefill_tree: None,
|
||||
decode_tree: None,
|
||||
timeout_secs: 5,
|
||||
interval_secs: 1,
|
||||
worker_loads: Arc::new(tokio::sync::watch::channel(HashMap::new()).1),
|
||||
@@ -1608,9 +1649,9 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_cache_tree_operations() {
|
||||
let policy = Arc::new(CacheAwarePolicy::new());
|
||||
let cache_policy = Arc::new(CacheAwarePolicy::new());
|
||||
let mut router = create_test_pd_router();
|
||||
router.policy = policy;
|
||||
router.prefill_policy = cache_policy;
|
||||
|
||||
// Initialize cache tree
|
||||
let tree = Arc::new(Mutex::new(Tree::new()));
|
||||
@@ -1638,9 +1679,9 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_cache_tree_rebuild_on_remove() {
|
||||
let policy = Arc::new(CacheAwarePolicy::new());
|
||||
let cache_policy = Arc::new(CacheAwarePolicy::new());
|
||||
let mut router = create_test_pd_router();
|
||||
router.policy = policy;
|
||||
router.prefill_policy = cache_policy;
|
||||
|
||||
// Initialize cache tree
|
||||
let tree = Arc::new(Mutex::new(Tree::new()));
|
||||
@@ -1880,9 +1921,10 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_load_monitor_updates() {
|
||||
let policy = Arc::new(crate::policies::PowerOfTwoPolicy::new());
|
||||
let power_of_two_policy = Arc::new(crate::policies::PowerOfTwoPolicy::new());
|
||||
let mut router = create_test_pd_router();
|
||||
router.policy = policy;
|
||||
router.prefill_policy = power_of_two_policy.clone();
|
||||
router.decode_policy = power_of_two_policy;
|
||||
|
||||
// Create load channel
|
||||
let (tx, rx) = tokio::sync::watch::channel(HashMap::new());
|
||||
|
||||
Reference in New Issue
Block a user