diff --git a/sgl-router/src/policies/cache_aware.rs b/sgl-router/src/policies/cache_aware.rs index 922ba85e0..47d95c835 100644 --- a/sgl-router/src/policies/cache_aware.rs +++ b/sgl-router/src/policies/cache_aware.rs @@ -112,7 +112,7 @@ impl CacheAwarePolicy { } } - /// Initialize the tree with worker URLs + /// Initialize the tree with worker URLs (used only during initial setup) pub fn init_workers(&self, workers: &[Box]) { if let Ok(tree) = self.tree.lock() { for worker in workers { @@ -121,6 +121,13 @@ impl CacheAwarePolicy { } } + /// Add a single worker to the tree (incremental update) + pub fn add_worker(&self, url: &str) { + if let Ok(tree) = self.tree.lock() { + tree.insert("", url); + } + } + /// Remove a worker from the tree pub fn remove_worker(&self, url: &str) { if let Ok(tree) = self.tree.lock() { @@ -178,6 +185,13 @@ impl LoadBalancingPolicy for CacheAwarePolicy { .min_by_key(|&&idx| workers[idx].load()) .copied()?; + // Even in imbalanced mode, update the tree to maintain cache state + if let Some(text) = request_text { + if let Ok(tree) = self.tree.lock() { + tree.insert(text, workers[min_load_idx].url()); + } + } + // Increment processed counter workers[min_load_idx].increment_processed(); RouterMetrics::record_processed_request(workers[min_load_idx].url()); @@ -206,21 +220,26 @@ impl LoadBalancingPolicy for CacheAwarePolicy { }; // Find the index of the selected worker - let selected_idx = workers.iter().position(|w| w.url() == selected_url)?; + if let Some(selected_idx) = workers.iter().position(|w| w.url() == selected_url) { + // Only proceed if the worker is healthy + if workers[selected_idx].is_healthy() { + // Update the tree with this request + tree.insert(text, &selected_url); - // Only proceed if the worker is healthy - if !workers[selected_idx].is_healthy() { - return healthy_indices.first().copied(); + // Increment processed counter + workers[selected_idx].increment_processed(); + RouterMetrics::record_processed_request(&selected_url); + + return Some(selected_idx); + } + } else { + // Selected worker no longer exists, remove it from tree + tree.remove_tenant(&selected_url); + debug!("Removed stale worker {} from cache tree", selected_url); } - // Update the tree with this request - tree.insert(text, &selected_url); - - // Increment processed counter - workers[selected_idx].increment_processed(); - RouterMetrics::record_processed_request(&selected_url); - - return Some(selected_idx); + // Fallback to first healthy worker + return healthy_indices.first().copied(); } // Fallback to first healthy worker if tree operations fail diff --git a/sgl-router/src/routers/pd_router.rs b/sgl-router/src/routers/pd_router.rs index 1815f1bfa..ab82c2872 100644 --- a/sgl-router/src/routers/pd_router.rs +++ b/sgl-router/src/routers/pd_router.rs @@ -7,7 +7,6 @@ use crate::metrics::RouterMetrics; use crate::openai_api_types::{ChatCompletionRequest, CompletionRequest, GenerateRequest}; use crate::policies::LoadBalancingPolicy; use crate::routers::{RouterTrait, WorkerManagement}; -use crate::tree::Tree; use async_trait::async_trait; use axum::{ body::Body, @@ -20,7 +19,7 @@ use futures_util::StreamExt; use reqwest::Client; use serde_json::Value; use std::collections::HashMap; -use std::sync::{Arc, Mutex, RwLock}; +use std::sync::{Arc, RwLock}; use std::time::{Duration, Instant}; use tokio_stream::wrappers::UnboundedReceiverStream; use tracing::{debug, error, info, warn}; @@ -31,8 +30,6 @@ pub struct PDRouter { pub decode_workers: Arc>>>, pub prefill_policy: Arc, pub decode_policy: Arc, - pub prefill_tree: Option>>, - pub decode_tree: Option>>, pub timeout_secs: u64, pub interval_secs: u64, pub worker_loads: Arc>>, @@ -91,9 +88,14 @@ impl PDRouter { workers.push(worker); - // Add to cache tree if using cache-aware policy for prefill - if let Some(ref tree) = self.prefill_tree { - tree.lock().unwrap().insert("", &url); + // Update cache-aware policy if applicable + drop(workers); // Release write lock + if let Some(cache_policy) = self + .prefill_policy + .as_any() + .downcast_ref::() + { + cache_policy.add_worker(&url); } info!("Added prefill server: {}", url); @@ -125,9 +127,14 @@ impl PDRouter { workers.push(worker); - // Add to cache tree if using cache-aware policy for decode - if let Some(ref tree) = self.decode_tree { - tree.lock().unwrap().insert("", &url); + // Update cache-aware policy if applicable + drop(workers); // Release write lock + if let Some(cache_policy) = self + .decode_policy + .as_any() + .downcast_ref::() + { + cache_policy.add_worker(&url); } info!("Added decode server: {}", url); @@ -152,9 +159,13 @@ impl PDRouter { }); } - // Remove from cache tree if using cache-aware policy - if let Some(ref tree) = self.prefill_tree { - tree.lock().unwrap().remove_tenant(url); + // Remove from cache-aware policy if applicable + if let Some(cache_policy) = self + .prefill_policy + .as_any() + .downcast_ref::() + { + cache_policy.remove_worker(url); } info!("Removed prefill server: {}", url); @@ -179,9 +190,13 @@ impl PDRouter { }); } - // Remove from the cache tree if using cache-aware policy for decode - if let Some(ref tree) = self.decode_tree { - tree.lock().unwrap().remove_tenant(url); + // Remove from cache-aware policy if applicable + if let Some(cache_policy) = self + .decode_policy + .as_any() + .downcast_ref::() + { + cache_policy.remove_worker(url); } info!("Removed decode server: {}", url); @@ -238,11 +253,20 @@ impl PDRouter { )?; } - // Initialize cache-aware components if needed for prefill policy - let prefill_tree = Self::initialize_radix_tree(&prefill_policy, &prefill_workers)?; + // Initialize cache-aware policies with workers + if let Some(cache_policy) = prefill_policy + .as_any() + .downcast_ref::() + { + cache_policy.init_workers(&prefill_workers); + } - // Initialize cache-aware components if needed for decode policy - let decode_tree = Self::initialize_radix_tree(&decode_policy, &decode_workers)?; + if let Some(cache_policy) = decode_policy + .as_any() + .downcast_ref::() + { + cache_policy.init_workers(&decode_workers); + } // Set up background load monitoring for power-of-two selection let (tx, rx) = tokio::sync::watch::channel(HashMap::new()); @@ -294,8 +318,6 @@ impl PDRouter { decode_workers, prefill_policy, decode_policy, - prefill_tree, - decode_tree, timeout_secs, interval_secs, worker_loads, @@ -309,35 +331,6 @@ impl PDRouter { }) } - // Helper function to initialize radix tree for cache-aware policies - fn initialize_radix_tree( - policy: &Arc, - workers: &[Box], - ) -> Result>>, String> { - if let Some(cache_policy) = policy - .as_any() - .downcast_ref::() - { - // Initialize the policy's internal tree with workers - cache_policy.init_workers(workers); - - let tree = Arc::new(Mutex::new(Tree::new())); - - { - let tree_guard = tree - .lock() - .map_err(|e| format!("Failed to lock tree: {}", e))?; - for worker in workers { - tree_guard.insert("", worker.url()); - } - } - - Ok(Some(tree)) - } else { - Ok(None) - } - } - // Helper to handle server selection errors fn handle_server_selection_error(error: String) -> Response { error!("Failed to select PD pair error={}", error); @@ -1863,8 +1856,6 @@ mod tests { decode_workers: Arc::new(RwLock::new(vec![])), prefill_policy, decode_policy, - prefill_tree: None, - decode_tree: None, timeout_secs: 5, interval_secs: 1, worker_loads: Arc::new(tokio::sync::watch::channel(HashMap::new()).1), @@ -2002,105 +1993,6 @@ mod tests { } } - // ============= Cache Tree Integration Tests ============= - - #[tokio::test] - async fn test_cache_tree_operations() { - let cache_policy = Arc::new(CacheAwarePolicy::new()); - let mut router = create_test_pd_router(); - router.prefill_policy = cache_policy; - - // Initialize cache tree - let tree = Arc::new(Mutex::new(Tree::new())); - router.prefill_tree = Some(Arc::clone(&tree)); - - // Manually add worker and update tree - let worker = create_test_worker( - "http://worker1".to_string(), - WorkerType::Prefill { - bootstrap_port: None, - }, - true, - ); - router.prefill_workers.write().unwrap().push(worker); - - // Update tree - tree.lock().unwrap().insert("", "http://worker1"); - - // Verify tree contains the worker - let tree_guard = tree.lock().unwrap(); - let (_matched_text, tenant) = tree_guard.prefix_match(""); - // Since we inserted with empty prefix, we should get a match - assert_eq!(tenant, "http://worker1"); - } - - #[tokio::test] - async fn test_cache_tree_rebuild_on_remove() { - let cache_policy = Arc::new(CacheAwarePolicy::new()); - let mut router = create_test_pd_router(); - router.prefill_policy = cache_policy; - - // Initialize cache tree - let tree = Arc::new(Mutex::new(Tree::new())); - router.prefill_tree = Some(Arc::clone(&tree)); - - // Add multiple workers - let worker1 = create_test_worker( - "http://worker1".to_string(), - WorkerType::Prefill { - bootstrap_port: None, - }, - true, - ); - let worker2 = create_test_worker( - "http://worker2".to_string(), - WorkerType::Prefill { - bootstrap_port: None, - }, - true, - ); - - router.prefill_workers.write().unwrap().push(worker1); - router.prefill_workers.write().unwrap().push(worker2); - - // Initialize tree with both workers - { - let tree_guard = tree.lock().unwrap(); - tree_guard.insert("", "http://worker1"); - tree_guard.insert("", "http://worker2"); - } - - // Remove one worker - let result = router.remove_prefill_server("http://worker1").await; - assert!(result.is_ok()); - - // Verify tree only contains remaining worker - let tree_guard = tree.lock().unwrap(); - let (_matched_text, tenant) = tree_guard.prefix_match(""); - // After rebuild, tree should only have worker2 - assert_eq!(tenant, "http://worker2"); - } - - #[tokio::test] - async fn test_no_cache_tree_operations() { - let router = create_test_pd_router(); - assert!(router.prefill_tree.is_none()); - - // Add a worker without cache tree - let worker = create_test_worker( - "http://worker1".to_string(), - WorkerType::Prefill { - bootstrap_port: None, - }, - true, - ); - router.prefill_workers.write().unwrap().push(worker); - - // Remove should work without tree - let result = router.remove_prefill_server("http://worker1").await; - assert!(result.is_ok()); - } - // ============= Bootstrap Injection Tests ============= // Note: These tests are commented out as we've moved to the optimized bootstrap injection // approach that doesn't use the Bootstrap trait on GenerateReqInput anymore.