From b24b2e7ed772c80b2dd19c9264adc61736c267b2 Mon Sep 17 00:00:00 2001 From: Simo Lin Date: Tue, 23 Sep 2025 14:25:53 -0400 Subject: [PATCH] [router] use dashmap for radix tree instead of hash for multi model (#10814) --- sgl-router/src/policies/cache_aware.rs | 211 ++++++++++++------------- 1 file changed, 105 insertions(+), 106 deletions(-) diff --git a/sgl-router/src/policies/cache_aware.rs b/sgl-router/src/policies/cache_aware.rs index 75f03ed04..65244c779 100644 --- a/sgl-router/src/policies/cache_aware.rs +++ b/sgl-router/src/policies/cache_aware.rs @@ -63,8 +63,9 @@ use super::{get_healthy_worker_indices, CacheAwareConfig, LoadBalancingPolicy}; use crate::core::Worker; use crate::metrics::RouterMetrics; use crate::tree::Tree; -use std::collections::HashMap; -use std::sync::{Arc, Mutex}; +use dashmap::DashMap; +use rand::Rng; +use std::sync::Arc; use std::thread; use std::time::Duration; use tracing::debug; @@ -77,7 +78,7 @@ use tracing::debug; #[derive(Debug)] pub struct CacheAwarePolicy { config: CacheAwareConfig, - trees: Arc>>, // model_id -> Tree + trees: Arc>>, eviction_handle: Option>, } @@ -87,7 +88,7 @@ impl CacheAwarePolicy { } pub fn with_config(config: CacheAwareConfig) -> Self { - let trees = Arc::new(Mutex::new(HashMap::::new())); + let trees = Arc::new(DashMap::>::new()); // Start background eviction thread if configured let eviction_handle = if config.eviction_interval_secs > 0 { @@ -98,15 +99,15 @@ impl CacheAwarePolicy { Some(thread::spawn(move || loop { thread::sleep(Duration::from_secs(interval)); - if let Ok(mut trees_guard) = trees_clone.lock() { - // Evict for all model trees - for (model_id, tree) in trees_guard.iter_mut() { - tree.evict_tenant_by_size(max_tree_size); - debug!( - "Cache eviction completed for model {}, max_size: {}", - model_id, max_tree_size - ); - } + // Evict for all model trees + for tree_ref in trees_clone.iter() { + let model_id = tree_ref.key(); + let tree = tree_ref.value(); + tree.evict_tenant_by_size(max_tree_size); + debug!( + "Cache eviction completed for model {}, max_size: {}", + model_id, max_tree_size + ); } })) } else { @@ -122,93 +123,93 @@ impl CacheAwarePolicy { /// Initialize the tree with worker URLs (used only during initial setup) pub fn init_workers(&self, workers: &[Arc]) { - if let Ok(mut trees) = self.trees.lock() { - // Group workers by model - let mut model_workers: HashMap>> = HashMap::new(); - for worker in workers { - // Use "default" for unknown/empty model_ids for backward compatibility - let model_id = worker.model_id(); - let tree_key = if model_id.is_empty() || model_id == "unknown" { - "default" - } else { - model_id - }; - model_workers - .entry(tree_key.to_string()) - .or_default() - .push(worker); - } + // Group workers by model + let mut model_workers: std::collections::HashMap>> = + std::collections::HashMap::new(); + for worker in workers { + // Use "default" for unknown/empty model_ids for backward compatibility + let model_id = worker.model_id(); + let tree_key = if model_id.is_empty() || model_id == "unknown" { + "default" + } else { + model_id + }; + model_workers + .entry(tree_key.to_string()) + .or_default() + .push(worker); + } - // Initialize tree for each model - for (tree_key, model_workers) in model_workers { - let tree = trees.entry(tree_key).or_insert_with(Tree::new); - for worker in model_workers { - tree.insert("", worker.url()); - } + // Initialize tree for each model + for (tree_key, model_workers) in model_workers { + let tree = self + .trees + .entry(tree_key) + .or_insert_with(|| Arc::new(Tree::new())); + for worker in model_workers { + tree.insert("", worker.url()); } } } /// Add a single worker to the tree (incremental update) pub fn add_worker(&self, worker: &dyn Worker) { - if let Ok(mut trees) = self.trees.lock() { - // For backward compatibility: if model_id is "unknown" or empty, - // use a default tree. This preserves existing behavior for single-model routers. - let model_id = worker.model_id(); - let tree_key = if model_id.is_empty() || model_id == "unknown" { - "default" - } else { - model_id - }; - let tree = trees.entry(tree_key.to_string()).or_insert_with(Tree::new); - tree.insert("", worker.url()); - } + // For backward compatibility: if model_id is "unknown" or empty, + // use a default tree. This preserves existing behavior for single-model routers. + let model_id = worker.model_id(); + let tree_key = if model_id.is_empty() || model_id == "unknown" { + "default" + } else { + model_id + }; + let tree = self + .trees + .entry(tree_key.to_string()) + .or_insert_with(|| Arc::new(Tree::new())); + tree.insert("", worker.url()); } /// Add a worker by URL and model (for backward compatibility) pub fn add_worker_by_url(&self, url: &str, model_id: &str) { - if let Ok(mut trees) = self.trees.lock() { - let tree = trees.entry(model_id.to_string()).or_insert_with(Tree::new); - tree.insert("", url); - } + let tree = self + .trees + .entry(model_id.to_string()) + .or_insert_with(|| Arc::new(Tree::new())); + tree.insert("", url); } /// Remove a worker from the tree pub fn remove_worker(&self, worker: &dyn Worker) { - if let Ok(mut trees) = self.trees.lock() { - // Use same logic as add_worker for consistency - let model_id = worker.model_id(); - let tree_key = if model_id.is_empty() || model_id == "unknown" { - "default" - } else { - model_id - }; - if let Some(tree) = trees.get_mut(tree_key) { - tree.remove_tenant(worker.url()); - } + // Use same logic as add_worker for consistency + let model_id = worker.model_id(); + let tree_key = if model_id.is_empty() || model_id == "unknown" { + "default" + } else { + model_id + }; + if let Some(tree) = self.trees.get(tree_key) { + tree.remove_tenant(worker.url()); } } /// Remove a worker by URL (removes from all model trees for backward compatibility) pub fn remove_worker_by_url(&self, url: &str) { - if let Ok(mut trees) = self.trees.lock() { - // Remove from all trees since we don't know which model it belongs to - for (_model_id, tree) in trees.iter_mut() { - tree.remove_tenant(url); - } + // Remove from all trees since we don't know which model it belongs to + for tree_ref in self.trees.iter() { + tree_ref.value().remove_tenant(url); } } /// Run cache eviction to prevent unbounded growth pub fn evict_cache(&self, max_size: usize) { - if let Ok(mut trees) = self.trees.lock() { - for (model_id, tree) in trees.iter_mut() { - tree.evict_tenant_by_size(max_size); - debug!( - "Cache eviction for model {}, max_size: {}", - model_id, max_size - ); - } + for tree_ref in self.trees.iter() { + let model_id = tree_ref.key(); + let tree = tree_ref.value(); + tree.evict_tenant_by_size(max_size); + debug!( + "Cache eviction for model {}, max_size: {}", + model_id, max_size + ); } } } @@ -266,20 +267,18 @@ impl LoadBalancingPolicy for CacheAwarePolicy { // Even in imbalanced mode, update the tree to maintain cache state if let Some(text) = request_text { - if let Ok(mut trees) = self.trees.lock() { - // Avoid allocation if tree already exists - let tree = if let Some(tree) = trees.get_mut(model_id) { - tree - } else { - // Create new tree and initialize with all workers - let new_tree = Tree::new(); - // Initialize with all healthy workers like OLD version does - for &idx in &healthy_indices { - new_tree.insert("", workers[idx].url()); - } - trees.entry(model_id.to_string()).or_insert(new_tree) - }; + // Get the tree reference without locking the entire HashMap + // DashMap only locks the specific shard containing this key + let tree = self.trees.get(model_id).map(|entry| entry.value().clone()); + + if let Some(tree) = tree { + // Now we can work with the tree without holding the HashMap lock tree.insert(text, workers[min_load_idx].url()); + } else { + debug!( + "Warning: No tree found for model '{}', skipping cache update", + model_id + ); } } @@ -294,19 +293,12 @@ impl LoadBalancingPolicy for CacheAwarePolicy { // Use cache-aware routing when balanced let text = request_text.unwrap_or(""); - if let Ok(mut trees) = self.trees.lock() { - // Avoid allocation if tree already exists - let tree = if let Some(tree) = trees.get_mut(model_id) { - tree - } else { - // Create new tree and initialize with all workers - let new_tree = Tree::new(); - // Initialize with all healthy workers like OLD version does - for &idx in &healthy_indices { - new_tree.insert("", workers[idx].url()); - } - trees.entry(model_id.to_string()).or_insert(new_tree) - }; + // Get the tree reference without locking the entire HashMap + // DashMap only locks the specific shard containing this key + let tree = self.trees.get(model_id).map(|entry| entry.value().clone()); + + if let Some(tree) = tree { + // Now we work with the tree without holding the HashMap lock let (matched_text, matched_worker) = tree.prefix_match(text); let match_rate = if text.is_empty() { 0.0 @@ -324,7 +316,7 @@ impl LoadBalancingPolicy for CacheAwarePolicy { // Find the index of the selected worker if let Some(selected_idx) = workers.iter().position(|w| w.url() == selected_url) { - // Only proceed if the worker is healthy - use direct check like OLD version + // Only proceed if the worker is healthy if workers[selected_idx].is_healthy() { // Update the tree with this request tree.insert(text, &selected_url); @@ -342,11 +334,18 @@ impl LoadBalancingPolicy for CacheAwarePolicy { } // Fallback to first healthy worker - return healthy_indices.first().copied(); + healthy_indices.first().copied() + } else { + // No tree for this model, log warning and use random selection + debug!( + "Warning: No tree found for model '{}', using random worker selection", + model_id + ); + // Return a random healthy worker + let mut rng = rand::rng(); + let random_idx = rng.random_range(0..healthy_indices.len()); + Some(healthy_indices[random_idx]) } - - // Fallback to first healthy worker if tree operations fail - healthy_indices.first().copied() } fn name(&self) -> &'static str {