[router] allow one router to support different model families and serving mode (#10244)
This commit is contained in:
@@ -63,6 +63,7 @@ use super::{get_healthy_worker_indices, CacheAwareConfig, LoadBalancingPolicy};
|
||||
use crate::core::Worker;
|
||||
use crate::metrics::RouterMetrics;
|
||||
use crate::tree::Tree;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::{Arc, Mutex};
|
||||
use std::thread;
|
||||
use std::time::Duration;
|
||||
@@ -72,10 +73,11 @@ use tracing::debug;
|
||||
///
|
||||
/// Routes requests based on cache affinity when load is balanced,
|
||||
/// switches to shortest-queue routing when load is imbalanced.
|
||||
/// Maintains separate trees per model for multi-model support.
|
||||
#[derive(Debug)]
|
||||
pub struct CacheAwarePolicy {
|
||||
config: CacheAwareConfig,
|
||||
tree: Arc<Mutex<Tree>>,
|
||||
trees: Arc<Mutex<HashMap<String, Tree>>>, // model_id -> Tree
|
||||
eviction_handle: Option<thread::JoinHandle<()>>,
|
||||
}
|
||||
|
||||
@@ -85,20 +87,26 @@ impl CacheAwarePolicy {
|
||||
}
|
||||
|
||||
pub fn with_config(config: CacheAwareConfig) -> Self {
|
||||
let tree = Arc::new(Mutex::new(Tree::new()));
|
||||
let trees = Arc::new(Mutex::new(HashMap::<String, Tree>::new()));
|
||||
|
||||
// Start background eviction thread if configured
|
||||
let eviction_handle = if config.eviction_interval_secs > 0 {
|
||||
let tree_clone = Arc::clone(&tree);
|
||||
let trees_clone = Arc::clone(&trees);
|
||||
let max_tree_size = config.max_tree_size;
|
||||
let interval = config.eviction_interval_secs;
|
||||
|
||||
Some(thread::spawn(move || loop {
|
||||
thread::sleep(Duration::from_secs(interval));
|
||||
|
||||
if let Ok(tree_guard) = tree_clone.lock() {
|
||||
tree_guard.evict_tenant_by_size(max_tree_size);
|
||||
debug!("Cache eviction completed, max_size: {}", max_tree_size);
|
||||
if let Ok(mut trees_guard) = trees_clone.lock() {
|
||||
// Evict for all model trees
|
||||
for (model_id, tree) in trees_guard.iter_mut() {
|
||||
tree.evict_tenant_by_size(max_tree_size);
|
||||
debug!(
|
||||
"Cache eviction completed for model {}, max_size: {}",
|
||||
model_id, max_tree_size
|
||||
);
|
||||
}
|
||||
}
|
||||
}))
|
||||
} else {
|
||||
@@ -107,38 +115,97 @@ impl CacheAwarePolicy {
|
||||
|
||||
Self {
|
||||
config,
|
||||
tree,
|
||||
trees,
|
||||
eviction_handle,
|
||||
}
|
||||
}
|
||||
|
||||
/// Initialize the tree with worker URLs (used only during initial setup)
|
||||
pub fn init_workers(&self, workers: &[Box<dyn Worker>]) {
|
||||
if let Ok(tree) = self.tree.lock() {
|
||||
pub fn init_workers(&self, workers: &[Arc<dyn Worker>]) {
|
||||
if let Ok(mut trees) = self.trees.lock() {
|
||||
// Group workers by model
|
||||
let mut model_workers: HashMap<String, Vec<&Arc<dyn Worker>>> = HashMap::new();
|
||||
for worker in workers {
|
||||
tree.insert("", worker.url());
|
||||
// Use "default" for unknown/empty model_ids for backward compatibility
|
||||
let model_id = worker.model_id();
|
||||
let tree_key = if model_id.is_empty() || model_id == "unknown" {
|
||||
"default".to_string()
|
||||
} else {
|
||||
model_id.to_string()
|
||||
};
|
||||
model_workers.entry(tree_key).or_default().push(worker);
|
||||
}
|
||||
|
||||
// Initialize tree for each model
|
||||
for (tree_key, model_workers) in model_workers {
|
||||
let tree = trees.entry(tree_key).or_insert_with(Tree::new);
|
||||
for worker in model_workers {
|
||||
tree.insert("", worker.url());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Add a single worker to the tree (incremental update)
|
||||
pub fn add_worker(&self, url: &str) {
|
||||
if let Ok(tree) = self.tree.lock() {
|
||||
pub fn add_worker(&self, worker: &dyn Worker) {
|
||||
if let Ok(mut trees) = self.trees.lock() {
|
||||
// For backward compatibility: if model_id is "unknown" or empty,
|
||||
// use a default tree. This preserves existing behavior for single-model routers.
|
||||
let model_id = worker.model_id();
|
||||
let tree_key = if model_id.is_empty() || model_id == "unknown" {
|
||||
"default".to_string()
|
||||
} else {
|
||||
model_id.to_string()
|
||||
};
|
||||
let tree = trees.entry(tree_key).or_insert_with(Tree::new);
|
||||
tree.insert("", worker.url());
|
||||
}
|
||||
}
|
||||
|
||||
/// Add a worker by URL and model (for backward compatibility)
|
||||
pub fn add_worker_by_url(&self, url: &str, model_id: &str) {
|
||||
if let Ok(mut trees) = self.trees.lock() {
|
||||
let tree = trees.entry(model_id.to_string()).or_insert_with(Tree::new);
|
||||
tree.insert("", url);
|
||||
}
|
||||
}
|
||||
|
||||
/// Remove a worker from the tree
|
||||
pub fn remove_worker(&self, url: &str) {
|
||||
if let Ok(tree) = self.tree.lock() {
|
||||
tree.remove_tenant(url);
|
||||
pub fn remove_worker(&self, worker: &dyn Worker) {
|
||||
if let Ok(mut trees) = self.trees.lock() {
|
||||
// Use same logic as add_worker for consistency
|
||||
let model_id = worker.model_id();
|
||||
let tree_key = if model_id.is_empty() || model_id == "unknown" {
|
||||
"default".to_string()
|
||||
} else {
|
||||
model_id.to_string()
|
||||
};
|
||||
if let Some(tree) = trees.get_mut(&tree_key) {
|
||||
tree.remove_tenant(worker.url());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Remove a worker by URL (removes from all model trees for backward compatibility)
|
||||
pub fn remove_worker_by_url(&self, url: &str) {
|
||||
if let Ok(mut trees) = self.trees.lock() {
|
||||
// Remove from all trees since we don't know which model it belongs to
|
||||
for (_model_id, tree) in trees.iter_mut() {
|
||||
tree.remove_tenant(url);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Run cache eviction to prevent unbounded growth
|
||||
pub fn evict_cache(&self, max_size: usize) {
|
||||
if let Ok(tree) = self.tree.lock() {
|
||||
tree.evict_tenant_by_size(max_size);
|
||||
if let Ok(mut trees) = self.trees.lock() {
|
||||
for (model_id, tree) in trees.iter_mut() {
|
||||
tree.evict_tenant_by_size(max_size);
|
||||
debug!(
|
||||
"Cache eviction for model {}, max_size: {}",
|
||||
model_id, max_size
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -146,7 +213,7 @@ impl CacheAwarePolicy {
|
||||
impl LoadBalancingPolicy for CacheAwarePolicy {
|
||||
fn select_worker(
|
||||
&self,
|
||||
workers: &[Box<dyn Worker>],
|
||||
workers: &[Arc<dyn Worker>],
|
||||
request_text: Option<&str>,
|
||||
) -> Option<usize> {
|
||||
let healthy_indices = get_healthy_worker_indices(workers);
|
||||
@@ -155,6 +222,18 @@ impl LoadBalancingPolicy for CacheAwarePolicy {
|
||||
return None;
|
||||
}
|
||||
|
||||
// Group workers by model (using "default" for unknown/empty model_ids)
|
||||
let mut model_workers: HashMap<String, Vec<usize>> = HashMap::new();
|
||||
for idx in &healthy_indices {
|
||||
let model_id = workers[*idx].model_id();
|
||||
let tree_key = if model_id.is_empty() || model_id == "unknown" {
|
||||
"default".to_string()
|
||||
} else {
|
||||
model_id.to_string()
|
||||
};
|
||||
model_workers.entry(tree_key).or_default().push(*idx);
|
||||
}
|
||||
|
||||
// Get current load statistics
|
||||
let loads: Vec<usize> = workers.iter().map(|w| w.load()).collect();
|
||||
let max_load = *loads.iter().max().unwrap_or(&0);
|
||||
@@ -187,7 +266,14 @@ impl LoadBalancingPolicy for CacheAwarePolicy {
|
||||
|
||||
// Even in imbalanced mode, update the tree to maintain cache state
|
||||
if let Some(text) = request_text {
|
||||
if let Ok(tree) = self.tree.lock() {
|
||||
if let Ok(mut trees) = self.trees.lock() {
|
||||
let model_id = workers[min_load_idx].model_id();
|
||||
let tree_key = if model_id.is_empty() || model_id == "unknown" {
|
||||
"default".to_string()
|
||||
} else {
|
||||
model_id.to_string()
|
||||
};
|
||||
let tree = trees.entry(tree_key).or_insert_with(Tree::new);
|
||||
tree.insert(text, workers[min_load_idx].url());
|
||||
}
|
||||
}
|
||||
@@ -203,43 +289,85 @@ impl LoadBalancingPolicy for CacheAwarePolicy {
|
||||
// Use cache-aware routing when balanced
|
||||
let text = request_text.unwrap_or("");
|
||||
|
||||
if let Ok(tree) = self.tree.lock() {
|
||||
let (matched_text, matched_worker) = tree.prefix_match(text);
|
||||
let match_rate = if text.is_empty() {
|
||||
0.0
|
||||
} else {
|
||||
matched_text.chars().count() as f32 / text.chars().count() as f32
|
||||
};
|
||||
if let Ok(mut trees) = self.trees.lock() {
|
||||
let mut best_match_idx: Option<usize> = None;
|
||||
let mut best_match_rate: f32 = 0.0;
|
||||
|
||||
let selected_url = if match_rate > self.config.cache_threshold {
|
||||
RouterMetrics::record_cache_hit();
|
||||
matched_worker.to_string()
|
||||
} else {
|
||||
RouterMetrics::record_cache_miss();
|
||||
tree.get_smallest_tenant()
|
||||
};
|
||||
// Find best match across all models
|
||||
for (model_id, worker_indices) in &model_workers {
|
||||
let tree = trees.entry(model_id.clone()).or_insert_with(Tree::new);
|
||||
|
||||
// Find the index of the selected worker
|
||||
if let Some(selected_idx) = workers.iter().position(|w| w.url() == selected_url) {
|
||||
// Only proceed if the worker is healthy
|
||||
if workers[selected_idx].is_healthy() {
|
||||
// Update the tree with this request
|
||||
tree.insert(text, &selected_url);
|
||||
let (matched_text, matched_worker) = tree.prefix_match(text);
|
||||
let match_rate = if text.is_empty() {
|
||||
0.0
|
||||
} else {
|
||||
matched_text.chars().count() as f32 / text.chars().count() as f32
|
||||
};
|
||||
|
||||
// Increment processed counter
|
||||
workers[selected_idx].increment_processed();
|
||||
RouterMetrics::record_processed_request(&selected_url);
|
||||
|
||||
return Some(selected_idx);
|
||||
// Check if this model has the best match
|
||||
if match_rate > best_match_rate {
|
||||
// Find the worker index for this URL
|
||||
if let Some(idx) = worker_indices
|
||||
.iter()
|
||||
.find(|&&idx| workers[idx].url() == matched_worker)
|
||||
{
|
||||
best_match_idx = Some(*idx);
|
||||
best_match_rate = match_rate;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Selected worker no longer exists, remove it from tree
|
||||
tree.remove_tenant(&selected_url);
|
||||
debug!("Removed stale worker {} from cache tree", selected_url);
|
||||
}
|
||||
|
||||
// Fallback to first healthy worker
|
||||
return healthy_indices.first().copied();
|
||||
// Select worker based on cache threshold
|
||||
let selected_idx = if let (Some(idx), true) = (
|
||||
best_match_idx,
|
||||
best_match_rate > self.config.cache_threshold,
|
||||
) {
|
||||
RouterMetrics::record_cache_hit();
|
||||
idx
|
||||
} else {
|
||||
RouterMetrics::record_cache_miss();
|
||||
|
||||
// Find model with smallest tree (most cache capacity)
|
||||
let mut smallest_tree_model = String::new();
|
||||
let mut smallest_tree_size = usize::MAX;
|
||||
|
||||
for model_id in model_workers.keys() {
|
||||
let tree = trees.entry(model_id.clone()).or_insert_with(Tree::new);
|
||||
let size = tree.get_used_size_per_tenant().values().sum::<usize>();
|
||||
if size < smallest_tree_size {
|
||||
smallest_tree_size = size;
|
||||
smallest_tree_model = model_id.clone();
|
||||
}
|
||||
}
|
||||
|
||||
// Select least loaded worker from model with most cache capacity
|
||||
if let Some(worker_indices) = model_workers.get(&smallest_tree_model) {
|
||||
worker_indices
|
||||
.iter()
|
||||
.min_by_key(|&&idx| workers[idx].load())
|
||||
.copied()
|
||||
.unwrap_or(healthy_indices[0])
|
||||
} else {
|
||||
healthy_indices[0]
|
||||
}
|
||||
};
|
||||
|
||||
// Update the tree with this request
|
||||
let model_id = workers[selected_idx].model_id();
|
||||
let tree_key = if model_id.is_empty() || model_id == "unknown" {
|
||||
"default".to_string()
|
||||
} else {
|
||||
model_id.to_string()
|
||||
};
|
||||
let tree = trees.entry(tree_key).or_insert_with(Tree::new);
|
||||
tree.insert(text, workers[selected_idx].url());
|
||||
|
||||
// Increment processed counter
|
||||
workers[selected_idx].increment_processed();
|
||||
RouterMetrics::record_processed_request(workers[selected_idx].url());
|
||||
RouterMetrics::record_policy_decision(self.name(), workers[selected_idx].url());
|
||||
|
||||
return Some(selected_idx);
|
||||
}
|
||||
|
||||
// Fallback to first healthy worker if tree operations fail
|
||||
@@ -272,8 +400,8 @@ impl LoadBalancingPolicy for CacheAwarePolicy {
|
||||
|
||||
fn select_worker_pair(
|
||||
&self,
|
||||
prefill_workers: &[Box<dyn Worker>],
|
||||
decode_workers: &[Box<dyn Worker>],
|
||||
prefill_workers: &[Arc<dyn Worker>],
|
||||
decode_workers: &[Arc<dyn Worker>],
|
||||
request_text: Option<&str>,
|
||||
) -> Option<(usize, usize)> {
|
||||
// DEPRECATED: This method is no longer used when separate policies are configured.
|
||||
@@ -333,12 +461,12 @@ mod tests {
|
||||
..Default::default()
|
||||
};
|
||||
let policy = CacheAwarePolicy::with_config(config);
|
||||
let workers: Vec<Box<dyn Worker>> = vec![
|
||||
Box::new(BasicWorker::new(
|
||||
let workers: Vec<Arc<dyn Worker>> = vec![
|
||||
Arc::new(BasicWorker::new(
|
||||
"http://w1:8000".to_string(),
|
||||
WorkerType::Regular,
|
||||
)),
|
||||
Box::new(BasicWorker::new(
|
||||
Arc::new(BasicWorker::new(
|
||||
"http://w2:8000".to_string(),
|
||||
WorkerType::Regular,
|
||||
)),
|
||||
@@ -378,7 +506,7 @@ mod tests {
|
||||
}
|
||||
// worker2 has load 0
|
||||
|
||||
let workers: Vec<Box<dyn Worker>> = vec![Box::new(worker1), Box::new(worker2)];
|
||||
let workers: Vec<Arc<dyn Worker>> = vec![Arc::new(worker1), Arc::new(worker2)];
|
||||
policy.init_workers(&workers);
|
||||
|
||||
// Should select worker2 (lower load) despite cache affinity
|
||||
@@ -395,12 +523,12 @@ mod tests {
|
||||
..Default::default()
|
||||
};
|
||||
let policy = CacheAwarePolicy::with_config(config);
|
||||
let workers: Vec<Box<dyn Worker>> = vec![
|
||||
Box::new(BasicWorker::new(
|
||||
let workers: Vec<Arc<dyn Worker>> = vec![
|
||||
Arc::new(BasicWorker::new(
|
||||
"http://w1:8000".to_string(),
|
||||
WorkerType::Regular,
|
||||
)),
|
||||
Box::new(BasicWorker::new(
|
||||
Arc::new(BasicWorker::new(
|
||||
"http://w2:8000".to_string(),
|
||||
WorkerType::Regular,
|
||||
)),
|
||||
@@ -413,7 +541,7 @@ mod tests {
|
||||
policy.select_worker(&workers, Some("test2"));
|
||||
|
||||
// Remove a worker
|
||||
policy.remove_worker("http://w1:8000");
|
||||
policy.remove_worker_by_url("http://w1:8000");
|
||||
workers[0].set_healthy(false);
|
||||
|
||||
// All requests should now go to worker2
|
||||
|
||||
Reference in New Issue
Block a user