[router] fix radix tree integration issues in PD router (#8982)
This commit is contained in:
@@ -112,7 +112,7 @@ impl CacheAwarePolicy {
|
||||
}
|
||||
}
|
||||
|
||||
/// Initialize the tree with worker URLs
|
||||
/// Initialize the tree with worker URLs (used only during initial setup)
|
||||
pub fn init_workers(&self, workers: &[Box<dyn Worker>]) {
|
||||
if let Ok(tree) = self.tree.lock() {
|
||||
for worker in workers {
|
||||
@@ -121,6 +121,13 @@ impl CacheAwarePolicy {
|
||||
}
|
||||
}
|
||||
|
||||
/// Add a single worker to the tree (incremental update)
///
/// Inserts an entry with empty text (`""`) for `url` into `self.tree`.
/// If `self.tree.lock()` returns an error, the call does nothing and
/// reports nothing to the caller.
|
||||
pub fn add_worker(&self, url: &str) {
|
||||
// A failed lock is silently skipped: there is no else branch.
if let Ok(tree) = self.tree.lock() {
|
||||
tree.insert("", url);
|
||||
}
|
||||
}
|
||||
|
||||
/// Remove a worker from the tree
|
||||
pub fn remove_worker(&self, url: &str) {
|
||||
if let Ok(tree) = self.tree.lock() {
|
||||
@@ -178,6 +185,13 @@ impl LoadBalancingPolicy for CacheAwarePolicy {
|
||||
.min_by_key(|&&idx| workers[idx].load())
|
||||
.copied()?;
|
||||
|
||||
// Even in imbalanced mode, update the tree to maintain cache state
|
||||
if let Some(text) = request_text {
|
||||
if let Ok(tree) = self.tree.lock() {
|
||||
tree.insert(text, workers[min_load_idx].url());
|
||||
}
|
||||
}
|
||||
|
||||
// Increment processed counter
|
||||
workers[min_load_idx].increment_processed();
|
||||
RouterMetrics::record_processed_request(workers[min_load_idx].url());
|
||||
@@ -206,21 +220,26 @@ impl LoadBalancingPolicy for CacheAwarePolicy {
|
||||
};
|
||||
|
||||
// Find the index of the selected worker
|
||||
let selected_idx = workers.iter().position(|w| w.url() == selected_url)?;
|
||||
if let Some(selected_idx) = workers.iter().position(|w| w.url() == selected_url) {
|
||||
// Only proceed if the worker is healthy
|
||||
if workers[selected_idx].is_healthy() {
|
||||
// Update the tree with this request
|
||||
tree.insert(text, &selected_url);
|
||||
|
||||
// Only proceed if the worker is healthy
|
||||
if !workers[selected_idx].is_healthy() {
|
||||
return healthy_indices.first().copied();
|
||||
// Increment processed counter
|
||||
workers[selected_idx].increment_processed();
|
||||
RouterMetrics::record_processed_request(&selected_url);
|
||||
|
||||
return Some(selected_idx);
|
||||
}
|
||||
} else {
|
||||
// Selected worker no longer exists, remove it from tree
|
||||
tree.remove_tenant(&selected_url);
|
||||
debug!("Removed stale worker {} from cache tree", selected_url);
|
||||
}
|
||||
|
||||
// Update the tree with this request
|
||||
tree.insert(text, &selected_url);
|
||||
|
||||
// Increment processed counter
|
||||
workers[selected_idx].increment_processed();
|
||||
RouterMetrics::record_processed_request(&selected_url);
|
||||
|
||||
return Some(selected_idx);
|
||||
// Fallback to first healthy worker
|
||||
return healthy_indices.first().copied();
|
||||
}
|
||||
|
||||
// Fallback to first healthy worker if tree operations fail
|
||||
|
||||
@@ -7,7 +7,6 @@ use crate::metrics::RouterMetrics;
|
||||
use crate::openai_api_types::{ChatCompletionRequest, CompletionRequest, GenerateRequest};
|
||||
use crate::policies::LoadBalancingPolicy;
|
||||
use crate::routers::{RouterTrait, WorkerManagement};
|
||||
use crate::tree::Tree;
|
||||
use async_trait::async_trait;
|
||||
use axum::{
|
||||
body::Body,
|
||||
@@ -20,7 +19,7 @@ use futures_util::StreamExt;
|
||||
use reqwest::Client;
|
||||
use serde_json::Value;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::{Arc, Mutex, RwLock};
|
||||
use std::sync::{Arc, RwLock};
|
||||
use std::time::{Duration, Instant};
|
||||
use tokio_stream::wrappers::UnboundedReceiverStream;
|
||||
use tracing::{debug, error, info, warn};
|
||||
@@ -31,8 +30,6 @@ pub struct PDRouter {
|
||||
pub decode_workers: Arc<RwLock<Vec<Box<dyn Worker>>>>,
|
||||
pub prefill_policy: Arc<dyn LoadBalancingPolicy>,
|
||||
pub decode_policy: Arc<dyn LoadBalancingPolicy>,
|
||||
pub prefill_tree: Option<Arc<Mutex<Tree>>>,
|
||||
pub decode_tree: Option<Arc<Mutex<Tree>>>,
|
||||
pub timeout_secs: u64,
|
||||
pub interval_secs: u64,
|
||||
pub worker_loads: Arc<tokio::sync::watch::Receiver<HashMap<String, isize>>>,
|
||||
@@ -91,9 +88,14 @@ impl PDRouter {
|
||||
|
||||
workers.push(worker);
|
||||
|
||||
// Add to cache tree if using cache-aware policy for prefill
|
||||
if let Some(ref tree) = self.prefill_tree {
|
||||
tree.lock().unwrap().insert("", &url);
|
||||
// Update cache-aware policy if applicable
|
||||
drop(workers); // Release write lock
|
||||
if let Some(cache_policy) = self
|
||||
.prefill_policy
|
||||
.as_any()
|
||||
.downcast_ref::<crate::policies::CacheAwarePolicy>()
|
||||
{
|
||||
cache_policy.add_worker(&url);
|
||||
}
|
||||
|
||||
info!("Added prefill server: {}", url);
|
||||
@@ -125,9 +127,14 @@ impl PDRouter {
|
||||
|
||||
workers.push(worker);
|
||||
|
||||
// Add to cache tree if using cache-aware policy for decode
|
||||
if let Some(ref tree) = self.decode_tree {
|
||||
tree.lock().unwrap().insert("", &url);
|
||||
// Update cache-aware policy if applicable
|
||||
drop(workers); // Release write lock
|
||||
if let Some(cache_policy) = self
|
||||
.decode_policy
|
||||
.as_any()
|
||||
.downcast_ref::<crate::policies::CacheAwarePolicy>()
|
||||
{
|
||||
cache_policy.add_worker(&url);
|
||||
}
|
||||
|
||||
info!("Added decode server: {}", url);
|
||||
@@ -152,9 +159,13 @@ impl PDRouter {
|
||||
});
|
||||
}
|
||||
|
||||
// Remove from cache tree if using cache-aware policy
|
||||
if let Some(ref tree) = self.prefill_tree {
|
||||
tree.lock().unwrap().remove_tenant(url);
|
||||
// Remove from cache-aware policy if applicable
|
||||
if let Some(cache_policy) = self
|
||||
.prefill_policy
|
||||
.as_any()
|
||||
.downcast_ref::<crate::policies::CacheAwarePolicy>()
|
||||
{
|
||||
cache_policy.remove_worker(url);
|
||||
}
|
||||
|
||||
info!("Removed prefill server: {}", url);
|
||||
@@ -179,9 +190,13 @@ impl PDRouter {
|
||||
});
|
||||
}
|
||||
|
||||
// Remove from the cache tree if using cache-aware policy for decode
|
||||
if let Some(ref tree) = self.decode_tree {
|
||||
tree.lock().unwrap().remove_tenant(url);
|
||||
// Remove from cache-aware policy if applicable
|
||||
if let Some(cache_policy) = self
|
||||
.decode_policy
|
||||
.as_any()
|
||||
.downcast_ref::<crate::policies::CacheAwarePolicy>()
|
||||
{
|
||||
cache_policy.remove_worker(url);
|
||||
}
|
||||
|
||||
info!("Removed decode server: {}", url);
|
||||
@@ -238,11 +253,20 @@ impl PDRouter {
|
||||
)?;
|
||||
}
|
||||
|
||||
// Initialize cache-aware components if needed for prefill policy
|
||||
let prefill_tree = Self::initialize_radix_tree(&prefill_policy, &prefill_workers)?;
|
||||
// Initialize cache-aware policies with workers
|
||||
if let Some(cache_policy) = prefill_policy
|
||||
.as_any()
|
||||
.downcast_ref::<crate::policies::CacheAwarePolicy>()
|
||||
{
|
||||
cache_policy.init_workers(&prefill_workers);
|
||||
}
|
||||
|
||||
// Initialize cache-aware components if needed for decode policy
|
||||
let decode_tree = Self::initialize_radix_tree(&decode_policy, &decode_workers)?;
|
||||
if let Some(cache_policy) = decode_policy
|
||||
.as_any()
|
||||
.downcast_ref::<crate::policies::CacheAwarePolicy>()
|
||||
{
|
||||
cache_policy.init_workers(&decode_workers);
|
||||
}
|
||||
|
||||
// Set up background load monitoring for power-of-two selection
|
||||
let (tx, rx) = tokio::sync::watch::channel(HashMap::new());
|
||||
@@ -294,8 +318,6 @@ impl PDRouter {
|
||||
decode_workers,
|
||||
prefill_policy,
|
||||
decode_policy,
|
||||
prefill_tree,
|
||||
decode_tree,
|
||||
timeout_secs,
|
||||
interval_secs,
|
||||
worker_loads,
|
||||
@@ -309,35 +331,6 @@ impl PDRouter {
|
||||
})
|
||||
}
|
||||
|
||||
// Helper function to initialize radix tree for cache-aware policies
//
// Returns `Ok(Some(tree))` when `policy` downcasts to `CacheAwarePolicy`,
// where `tree` is a newly created `Tree` holding one empty-text entry per
// worker URL. Returns `Ok(None)` for any other policy type. Returns `Err`
// with a "Failed to lock tree: ..." message if the new tree cannot be locked.
//
// NOTE(review): the returned tree is built here with `Tree::new()` and is a
// separate object from whatever `init_workers` populates inside the policy;
// nothing visible in this function links the two. Confirm which one callers
// are expected to keep up to date.
|
||||
fn initialize_radix_tree(
|
||||
policy: &Arc<dyn LoadBalancingPolicy>,
|
||||
workers: &[Box<dyn Worker>],
|
||||
) -> Result<Option<Arc<Mutex<Tree>>>, String> {
|
||||
// Only the concrete `CacheAwarePolicy` type takes this branch.
if let Some(cache_policy) = policy
|
||||
.as_any()
|
||||
.downcast_ref::<crate::policies::CacheAwarePolicy>()
|
||||
{
|
||||
// Initialize the policy's internal tree with workers
|
||||
cache_policy.init_workers(workers);
|
||||
|
||||
let tree = Arc::new(Mutex::new(Tree::new()));
|
||||
|
||||
// Inner scope so the guard is dropped before `tree` is returned.
{
|
||||
let tree_guard = tree
|
||||
.lock()
|
||||
.map_err(|e| format!("Failed to lock tree: {}", e))?;
|
||||
// Register every worker URL with empty text.
for worker in workers {
|
||||
tree_guard.insert("", worker.url());
|
||||
}
|
||||
}
|
||||
|
||||
Ok(Some(tree))
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
|
||||
// Helper to handle server selection errors
|
||||
fn handle_server_selection_error(error: String) -> Response {
|
||||
error!("Failed to select PD pair error={}", error);
|
||||
@@ -1863,8 +1856,6 @@ mod tests {
|
||||
decode_workers: Arc::new(RwLock::new(vec![])),
|
||||
prefill_policy,
|
||||
decode_policy,
|
||||
prefill_tree: None,
|
||||
decode_tree: None,
|
||||
timeout_secs: 5,
|
||||
interval_secs: 1,
|
||||
worker_loads: Arc::new(tokio::sync::watch::channel(HashMap::new()).1),
|
||||
@@ -2002,105 +1993,6 @@ mod tests {
|
||||
}
|
||||
}
|
||||
|
||||
// ============= Cache Tree Integration Tests =============
|
||||
|
||||
// Checks that after inserting an empty-text entry for "http://worker1" into a
// standalone `Tree`, `prefix_match("")` reports that URL as the tenant.
// The final assertion reads only the local `tree`; the policy assigned to
// `router.prefill_policy` and the worker pushed into `prefill_workers` are set
// up but not consulted by the assertion.
#[tokio::test]
|
||||
async fn test_cache_tree_operations() {
|
||||
let cache_policy = Arc::new(CacheAwarePolicy::new());
|
||||
let mut router = create_test_pd_router();
|
||||
router.prefill_policy = cache_policy;
|
||||
|
||||
// Initialize cache tree
|
||||
let tree = Arc::new(Mutex::new(Tree::new()));
|
||||
// The router gets a clone of the same Arc, so both handles share one tree.
router.prefill_tree = Some(Arc::clone(&tree));
|
||||
|
||||
// Manually add worker and update tree
|
||||
let worker = create_test_worker(
|
||||
"http://worker1".to_string(),
|
||||
WorkerType::Prefill {
|
||||
bootstrap_port: None,
|
||||
},
|
||||
true,
|
||||
);
|
||||
router.prefill_workers.write().unwrap().push(worker);
|
||||
|
||||
// Update tree
|
||||
tree.lock().unwrap().insert("", "http://worker1");
|
||||
|
||||
// Verify tree contains the worker
|
||||
let tree_guard = tree.lock().unwrap();
|
||||
let (_matched_text, tenant) = tree_guard.prefix_match("");
|
||||
// Since we inserted with empty prefix, we should get a match
|
||||
assert_eq!(tenant, "http://worker1");
|
||||
}
|
||||
|
||||
// Seeds a shared `Tree` with empty-text entries for worker1 and worker2,
// removes worker1 through `router.remove_prefill_server`, and expects
// `prefix_match("")` on the tree to report worker2 afterwards.
// NOTE(review): this relies on `remove_prefill_server` updating the tree held
// in `router.prefill_tree`; its full body is not visible here, confirm.
#[tokio::test]
|
||||
async fn test_cache_tree_rebuild_on_remove() {
|
||||
let cache_policy = Arc::new(CacheAwarePolicy::new());
|
||||
let mut router = create_test_pd_router();
|
||||
router.prefill_policy = cache_policy;
|
||||
|
||||
// Initialize cache tree
|
||||
let tree = Arc::new(Mutex::new(Tree::new()));
|
||||
// The router gets a clone of the same Arc, so both handles share one tree.
router.prefill_tree = Some(Arc::clone(&tree));
|
||||
|
||||
// Add multiple workers
|
||||
let worker1 = create_test_worker(
|
||||
"http://worker1".to_string(),
|
||||
WorkerType::Prefill {
|
||||
bootstrap_port: None,
|
||||
},
|
||||
true,
|
||||
);
|
||||
let worker2 = create_test_worker(
|
||||
"http://worker2".to_string(),
|
||||
WorkerType::Prefill {
|
||||
bootstrap_port: None,
|
||||
},
|
||||
true,
|
||||
);
|
||||
|
||||
router.prefill_workers.write().unwrap().push(worker1);
|
||||
router.prefill_workers.write().unwrap().push(worker2);
|
||||
|
||||
// Initialize tree with both workers
|
||||
// Scoped so the guard is released before the router call below.
{
|
||||
let tree_guard = tree.lock().unwrap();
|
||||
tree_guard.insert("", "http://worker1");
|
||||
tree_guard.insert("", "http://worker2");
|
||||
}
|
||||
|
||||
// Remove one worker
|
||||
let result = router.remove_prefill_server("http://worker1").await;
|
||||
assert!(result.is_ok());
|
||||
|
||||
// Verify tree only contains remaining worker
|
||||
let tree_guard = tree.lock().unwrap();
|
||||
let (_matched_text, tenant) = tree_guard.prefix_match("");
|
||||
// After rebuild, tree should only have worker2
|
||||
assert_eq!(tenant, "http://worker2");
|
||||
}
|
||||
|
||||
// Confirms that a router whose `prefill_tree` is `None` can still remove a
// prefill worker: the worker is pushed directly into `prefill_workers` and
// `remove_prefill_server` is expected to return `Ok`.
#[tokio::test]
|
||||
async fn test_no_cache_tree_operations() {
|
||||
let router = create_test_pd_router();
|
||||
// Precondition: the test router is built without a prefill tree.
assert!(router.prefill_tree.is_none());
|
||||
|
||||
// Add a worker without cache tree
|
||||
let worker = create_test_worker(
|
||||
"http://worker1".to_string(),
|
||||
WorkerType::Prefill {
|
||||
bootstrap_port: None,
|
||||
},
|
||||
true,
|
||||
);
|
||||
router.prefill_workers.write().unwrap().push(worker);
|
||||
|
||||
// Remove should work without tree
|
||||
let result = router.remove_prefill_server("http://worker1").await;
|
||||
assert!(result.is_ok());
|
||||
}
|
||||
|
||||
// ============= Bootstrap Injection Tests =============
|
||||
// Note: These tests are commented out as we've moved to the optimized bootstrap injection
|
||||
// approach that doesn't use the Bootstrap trait on GenerateReqInput anymore.
|
||||
|
||||
Reference in New Issue
Block a user