[router] fix radix tree integration issues in PD router (#8982)
This commit is contained in:
@@ -112,7 +112,7 @@ impl CacheAwarePolicy {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Initialize the tree with worker URLs
|
/// Initialize the tree with worker URLs (used only during initial setup)
|
||||||
pub fn init_workers(&self, workers: &[Box<dyn Worker>]) {
|
pub fn init_workers(&self, workers: &[Box<dyn Worker>]) {
|
||||||
if let Ok(tree) = self.tree.lock() {
|
if let Ok(tree) = self.tree.lock() {
|
||||||
for worker in workers {
|
for worker in workers {
|
||||||
@@ -121,6 +121,13 @@ impl CacheAwarePolicy {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Add a single worker to the tree (incremental update)
|
||||||
|
pub fn add_worker(&self, url: &str) {
|
||||||
|
if let Ok(tree) = self.tree.lock() {
|
||||||
|
tree.insert("", url);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Remove a worker from the tree
|
/// Remove a worker from the tree
|
||||||
pub fn remove_worker(&self, url: &str) {
|
pub fn remove_worker(&self, url: &str) {
|
||||||
if let Ok(tree) = self.tree.lock() {
|
if let Ok(tree) = self.tree.lock() {
|
||||||
@@ -178,6 +185,13 @@ impl LoadBalancingPolicy for CacheAwarePolicy {
|
|||||||
.min_by_key(|&&idx| workers[idx].load())
|
.min_by_key(|&&idx| workers[idx].load())
|
||||||
.copied()?;
|
.copied()?;
|
||||||
|
|
||||||
|
// Even in imbalanced mode, update the tree to maintain cache state
|
||||||
|
if let Some(text) = request_text {
|
||||||
|
if let Ok(tree) = self.tree.lock() {
|
||||||
|
tree.insert(text, workers[min_load_idx].url());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Increment processed counter
|
// Increment processed counter
|
||||||
workers[min_load_idx].increment_processed();
|
workers[min_load_idx].increment_processed();
|
||||||
RouterMetrics::record_processed_request(workers[min_load_idx].url());
|
RouterMetrics::record_processed_request(workers[min_load_idx].url());
|
||||||
@@ -206,21 +220,26 @@ impl LoadBalancingPolicy for CacheAwarePolicy {
|
|||||||
};
|
};
|
||||||
|
|
||||||
// Find the index of the selected worker
|
// Find the index of the selected worker
|
||||||
let selected_idx = workers.iter().position(|w| w.url() == selected_url)?;
|
if let Some(selected_idx) = workers.iter().position(|w| w.url() == selected_url) {
|
||||||
|
// Only proceed if the worker is healthy
|
||||||
|
if workers[selected_idx].is_healthy() {
|
||||||
|
// Update the tree with this request
|
||||||
|
tree.insert(text, &selected_url);
|
||||||
|
|
||||||
// Only proceed if the worker is healthy
|
// Increment processed counter
|
||||||
if !workers[selected_idx].is_healthy() {
|
workers[selected_idx].increment_processed();
|
||||||
return healthy_indices.first().copied();
|
RouterMetrics::record_processed_request(&selected_url);
|
||||||
|
|
||||||
|
return Some(selected_idx);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Selected worker no longer exists, remove it from tree
|
||||||
|
tree.remove_tenant(&selected_url);
|
||||||
|
debug!("Removed stale worker {} from cache tree", selected_url);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Update the tree with this request
|
// Fallback to first healthy worker
|
||||||
tree.insert(text, &selected_url);
|
return healthy_indices.first().copied();
|
||||||
|
|
||||||
// Increment processed counter
|
|
||||||
workers[selected_idx].increment_processed();
|
|
||||||
RouterMetrics::record_processed_request(&selected_url);
|
|
||||||
|
|
||||||
return Some(selected_idx);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Fallback to first healthy worker if tree operations fail
|
// Fallback to first healthy worker if tree operations fail
|
||||||
|
|||||||
@@ -7,7 +7,6 @@ use crate::metrics::RouterMetrics;
|
|||||||
use crate::openai_api_types::{ChatCompletionRequest, CompletionRequest, GenerateRequest};
|
use crate::openai_api_types::{ChatCompletionRequest, CompletionRequest, GenerateRequest};
|
||||||
use crate::policies::LoadBalancingPolicy;
|
use crate::policies::LoadBalancingPolicy;
|
||||||
use crate::routers::{RouterTrait, WorkerManagement};
|
use crate::routers::{RouterTrait, WorkerManagement};
|
||||||
use crate::tree::Tree;
|
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use axum::{
|
use axum::{
|
||||||
body::Body,
|
body::Body,
|
||||||
@@ -20,7 +19,7 @@ use futures_util::StreamExt;
|
|||||||
use reqwest::Client;
|
use reqwest::Client;
|
||||||
use serde_json::Value;
|
use serde_json::Value;
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::sync::{Arc, Mutex, RwLock};
|
use std::sync::{Arc, RwLock};
|
||||||
use std::time::{Duration, Instant};
|
use std::time::{Duration, Instant};
|
||||||
use tokio_stream::wrappers::UnboundedReceiverStream;
|
use tokio_stream::wrappers::UnboundedReceiverStream;
|
||||||
use tracing::{debug, error, info, warn};
|
use tracing::{debug, error, info, warn};
|
||||||
@@ -31,8 +30,6 @@ pub struct PDRouter {
|
|||||||
pub decode_workers: Arc<RwLock<Vec<Box<dyn Worker>>>>,
|
pub decode_workers: Arc<RwLock<Vec<Box<dyn Worker>>>>,
|
||||||
pub prefill_policy: Arc<dyn LoadBalancingPolicy>,
|
pub prefill_policy: Arc<dyn LoadBalancingPolicy>,
|
||||||
pub decode_policy: Arc<dyn LoadBalancingPolicy>,
|
pub decode_policy: Arc<dyn LoadBalancingPolicy>,
|
||||||
pub prefill_tree: Option<Arc<Mutex<Tree>>>,
|
|
||||||
pub decode_tree: Option<Arc<Mutex<Tree>>>,
|
|
||||||
pub timeout_secs: u64,
|
pub timeout_secs: u64,
|
||||||
pub interval_secs: u64,
|
pub interval_secs: u64,
|
||||||
pub worker_loads: Arc<tokio::sync::watch::Receiver<HashMap<String, isize>>>,
|
pub worker_loads: Arc<tokio::sync::watch::Receiver<HashMap<String, isize>>>,
|
||||||
@@ -91,9 +88,14 @@ impl PDRouter {
|
|||||||
|
|
||||||
workers.push(worker);
|
workers.push(worker);
|
||||||
|
|
||||||
// Add to cache tree if using cache-aware policy for prefill
|
// Update cache-aware policy if applicable
|
||||||
if let Some(ref tree) = self.prefill_tree {
|
drop(workers); // Release write lock
|
||||||
tree.lock().unwrap().insert("", &url);
|
if let Some(cache_policy) = self
|
||||||
|
.prefill_policy
|
||||||
|
.as_any()
|
||||||
|
.downcast_ref::<crate::policies::CacheAwarePolicy>()
|
||||||
|
{
|
||||||
|
cache_policy.add_worker(&url);
|
||||||
}
|
}
|
||||||
|
|
||||||
info!("Added prefill server: {}", url);
|
info!("Added prefill server: {}", url);
|
||||||
@@ -125,9 +127,14 @@ impl PDRouter {
|
|||||||
|
|
||||||
workers.push(worker);
|
workers.push(worker);
|
||||||
|
|
||||||
// Add to cache tree if using cache-aware policy for decode
|
// Update cache-aware policy if applicable
|
||||||
if let Some(ref tree) = self.decode_tree {
|
drop(workers); // Release write lock
|
||||||
tree.lock().unwrap().insert("", &url);
|
if let Some(cache_policy) = self
|
||||||
|
.decode_policy
|
||||||
|
.as_any()
|
||||||
|
.downcast_ref::<crate::policies::CacheAwarePolicy>()
|
||||||
|
{
|
||||||
|
cache_policy.add_worker(&url);
|
||||||
}
|
}
|
||||||
|
|
||||||
info!("Added decode server: {}", url);
|
info!("Added decode server: {}", url);
|
||||||
@@ -152,9 +159,13 @@ impl PDRouter {
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
// Remove from cache tree if using cache-aware policy
|
// Remove from cache-aware policy if applicable
|
||||||
if let Some(ref tree) = self.prefill_tree {
|
if let Some(cache_policy) = self
|
||||||
tree.lock().unwrap().remove_tenant(url);
|
.prefill_policy
|
||||||
|
.as_any()
|
||||||
|
.downcast_ref::<crate::policies::CacheAwarePolicy>()
|
||||||
|
{
|
||||||
|
cache_policy.remove_worker(url);
|
||||||
}
|
}
|
||||||
|
|
||||||
info!("Removed prefill server: {}", url);
|
info!("Removed prefill server: {}", url);
|
||||||
@@ -179,9 +190,13 @@ impl PDRouter {
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
// Remove from the cache tree if using cache-aware policy for decode
|
// Remove from cache-aware policy if applicable
|
||||||
if let Some(ref tree) = self.decode_tree {
|
if let Some(cache_policy) = self
|
||||||
tree.lock().unwrap().remove_tenant(url);
|
.decode_policy
|
||||||
|
.as_any()
|
||||||
|
.downcast_ref::<crate::policies::CacheAwarePolicy>()
|
||||||
|
{
|
||||||
|
cache_policy.remove_worker(url);
|
||||||
}
|
}
|
||||||
|
|
||||||
info!("Removed decode server: {}", url);
|
info!("Removed decode server: {}", url);
|
||||||
@@ -238,11 +253,20 @@ impl PDRouter {
|
|||||||
)?;
|
)?;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Initialize cache-aware components if needed for prefill policy
|
// Initialize cache-aware policies with workers
|
||||||
let prefill_tree = Self::initialize_radix_tree(&prefill_policy, &prefill_workers)?;
|
if let Some(cache_policy) = prefill_policy
|
||||||
|
.as_any()
|
||||||
|
.downcast_ref::<crate::policies::CacheAwarePolicy>()
|
||||||
|
{
|
||||||
|
cache_policy.init_workers(&prefill_workers);
|
||||||
|
}
|
||||||
|
|
||||||
// Initialize cache-aware components if needed for decode policy
|
if let Some(cache_policy) = decode_policy
|
||||||
let decode_tree = Self::initialize_radix_tree(&decode_policy, &decode_workers)?;
|
.as_any()
|
||||||
|
.downcast_ref::<crate::policies::CacheAwarePolicy>()
|
||||||
|
{
|
||||||
|
cache_policy.init_workers(&decode_workers);
|
||||||
|
}
|
||||||
|
|
||||||
// Set up background load monitoring for power-of-two selection
|
// Set up background load monitoring for power-of-two selection
|
||||||
let (tx, rx) = tokio::sync::watch::channel(HashMap::new());
|
let (tx, rx) = tokio::sync::watch::channel(HashMap::new());
|
||||||
@@ -294,8 +318,6 @@ impl PDRouter {
|
|||||||
decode_workers,
|
decode_workers,
|
||||||
prefill_policy,
|
prefill_policy,
|
||||||
decode_policy,
|
decode_policy,
|
||||||
prefill_tree,
|
|
||||||
decode_tree,
|
|
||||||
timeout_secs,
|
timeout_secs,
|
||||||
interval_secs,
|
interval_secs,
|
||||||
worker_loads,
|
worker_loads,
|
||||||
@@ -309,35 +331,6 @@ impl PDRouter {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// Helper function to initialize radix tree for cache-aware policies
|
|
||||||
fn initialize_radix_tree(
|
|
||||||
policy: &Arc<dyn LoadBalancingPolicy>,
|
|
||||||
workers: &[Box<dyn Worker>],
|
|
||||||
) -> Result<Option<Arc<Mutex<Tree>>>, String> {
|
|
||||||
if let Some(cache_policy) = policy
|
|
||||||
.as_any()
|
|
||||||
.downcast_ref::<crate::policies::CacheAwarePolicy>()
|
|
||||||
{
|
|
||||||
// Initialize the policy's internal tree with workers
|
|
||||||
cache_policy.init_workers(workers);
|
|
||||||
|
|
||||||
let tree = Arc::new(Mutex::new(Tree::new()));
|
|
||||||
|
|
||||||
{
|
|
||||||
let tree_guard = tree
|
|
||||||
.lock()
|
|
||||||
.map_err(|e| format!("Failed to lock tree: {}", e))?;
|
|
||||||
for worker in workers {
|
|
||||||
tree_guard.insert("", worker.url());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(Some(tree))
|
|
||||||
} else {
|
|
||||||
Ok(None)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Helper to handle server selection errors
|
// Helper to handle server selection errors
|
||||||
fn handle_server_selection_error(error: String) -> Response {
|
fn handle_server_selection_error(error: String) -> Response {
|
||||||
error!("Failed to select PD pair error={}", error);
|
error!("Failed to select PD pair error={}", error);
|
||||||
@@ -1863,8 +1856,6 @@ mod tests {
|
|||||||
decode_workers: Arc::new(RwLock::new(vec![])),
|
decode_workers: Arc::new(RwLock::new(vec![])),
|
||||||
prefill_policy,
|
prefill_policy,
|
||||||
decode_policy,
|
decode_policy,
|
||||||
prefill_tree: None,
|
|
||||||
decode_tree: None,
|
|
||||||
timeout_secs: 5,
|
timeout_secs: 5,
|
||||||
interval_secs: 1,
|
interval_secs: 1,
|
||||||
worker_loads: Arc::new(tokio::sync::watch::channel(HashMap::new()).1),
|
worker_loads: Arc::new(tokio::sync::watch::channel(HashMap::new()).1),
|
||||||
@@ -2002,105 +1993,6 @@ mod tests {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ============= Cache Tree Integration Tests =============
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn test_cache_tree_operations() {
|
|
||||||
let cache_policy = Arc::new(CacheAwarePolicy::new());
|
|
||||||
let mut router = create_test_pd_router();
|
|
||||||
router.prefill_policy = cache_policy;
|
|
||||||
|
|
||||||
// Initialize cache tree
|
|
||||||
let tree = Arc::new(Mutex::new(Tree::new()));
|
|
||||||
router.prefill_tree = Some(Arc::clone(&tree));
|
|
||||||
|
|
||||||
// Manually add worker and update tree
|
|
||||||
let worker = create_test_worker(
|
|
||||||
"http://worker1".to_string(),
|
|
||||||
WorkerType::Prefill {
|
|
||||||
bootstrap_port: None,
|
|
||||||
},
|
|
||||||
true,
|
|
||||||
);
|
|
||||||
router.prefill_workers.write().unwrap().push(worker);
|
|
||||||
|
|
||||||
// Update tree
|
|
||||||
tree.lock().unwrap().insert("", "http://worker1");
|
|
||||||
|
|
||||||
// Verify tree contains the worker
|
|
||||||
let tree_guard = tree.lock().unwrap();
|
|
||||||
let (_matched_text, tenant) = tree_guard.prefix_match("");
|
|
||||||
// Since we inserted with empty prefix, we should get a match
|
|
||||||
assert_eq!(tenant, "http://worker1");
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn test_cache_tree_rebuild_on_remove() {
|
|
||||||
let cache_policy = Arc::new(CacheAwarePolicy::new());
|
|
||||||
let mut router = create_test_pd_router();
|
|
||||||
router.prefill_policy = cache_policy;
|
|
||||||
|
|
||||||
// Initialize cache tree
|
|
||||||
let tree = Arc::new(Mutex::new(Tree::new()));
|
|
||||||
router.prefill_tree = Some(Arc::clone(&tree));
|
|
||||||
|
|
||||||
// Add multiple workers
|
|
||||||
let worker1 = create_test_worker(
|
|
||||||
"http://worker1".to_string(),
|
|
||||||
WorkerType::Prefill {
|
|
||||||
bootstrap_port: None,
|
|
||||||
},
|
|
||||||
true,
|
|
||||||
);
|
|
||||||
let worker2 = create_test_worker(
|
|
||||||
"http://worker2".to_string(),
|
|
||||||
WorkerType::Prefill {
|
|
||||||
bootstrap_port: None,
|
|
||||||
},
|
|
||||||
true,
|
|
||||||
);
|
|
||||||
|
|
||||||
router.prefill_workers.write().unwrap().push(worker1);
|
|
||||||
router.prefill_workers.write().unwrap().push(worker2);
|
|
||||||
|
|
||||||
// Initialize tree with both workers
|
|
||||||
{
|
|
||||||
let tree_guard = tree.lock().unwrap();
|
|
||||||
tree_guard.insert("", "http://worker1");
|
|
||||||
tree_guard.insert("", "http://worker2");
|
|
||||||
}
|
|
||||||
|
|
||||||
// Remove one worker
|
|
||||||
let result = router.remove_prefill_server("http://worker1").await;
|
|
||||||
assert!(result.is_ok());
|
|
||||||
|
|
||||||
// Verify tree only contains remaining worker
|
|
||||||
let tree_guard = tree.lock().unwrap();
|
|
||||||
let (_matched_text, tenant) = tree_guard.prefix_match("");
|
|
||||||
// After rebuild, tree should only have worker2
|
|
||||||
assert_eq!(tenant, "http://worker2");
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn test_no_cache_tree_operations() {
|
|
||||||
let router = create_test_pd_router();
|
|
||||||
assert!(router.prefill_tree.is_none());
|
|
||||||
|
|
||||||
// Add a worker without cache tree
|
|
||||||
let worker = create_test_worker(
|
|
||||||
"http://worker1".to_string(),
|
|
||||||
WorkerType::Prefill {
|
|
||||||
bootstrap_port: None,
|
|
||||||
},
|
|
||||||
true,
|
|
||||||
);
|
|
||||||
router.prefill_workers.write().unwrap().push(worker);
|
|
||||||
|
|
||||||
// Remove should work without tree
|
|
||||||
let result = router.remove_prefill_server("http://worker1").await;
|
|
||||||
assert!(result.is_ok());
|
|
||||||
}
|
|
||||||
|
|
||||||
// ============= Bootstrap Injection Tests =============
|
// ============= Bootstrap Injection Tests =============
|
||||||
// Note: These tests are commented out as we've moved to the optimized bootstrap injection
|
// Note: These tests are commented out as we've moved to the optimized bootstrap injection
|
||||||
// approach that doesn't use the Bootstrap trait on GenerateReqInput anymore.
|
// approach that doesn't use the Bootstrap trait on GenerateReqInput anymore.
|
||||||
|
|||||||
Reference in New Issue
Block a user