[router] allow one router to support different model families and serving mode (#10244)

This commit is contained in:
Simo Lin
2025-09-12 19:18:27 -04:00
committed by GitHub
parent 321fecab74
commit 2f173ea074
28 changed files with 3528 additions and 837 deletions

View File

@@ -63,6 +63,7 @@ use super::{get_healthy_worker_indices, CacheAwareConfig, LoadBalancingPolicy};
use crate::core::Worker;
use crate::metrics::RouterMetrics;
use crate::tree::Tree;
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use std::thread;
use std::time::Duration;
@@ -72,10 +73,11 @@ use tracing::debug;
///
/// Routes requests based on cache affinity when load is balanced,
/// switches to shortest-queue routing when load is imbalanced.
/// Maintains separate trees per model for multi-model support.
#[derive(Debug)]
pub struct CacheAwarePolicy {
config: CacheAwareConfig,
tree: Arc<Mutex<Tree>>,
trees: Arc<Mutex<HashMap<String, Tree>>>, // model_id -> Tree
eviction_handle: Option<thread::JoinHandle<()>>,
}
@@ -85,20 +87,26 @@ impl CacheAwarePolicy {
}
pub fn with_config(config: CacheAwareConfig) -> Self {
let tree = Arc::new(Mutex::new(Tree::new()));
let trees = Arc::new(Mutex::new(HashMap::<String, Tree>::new()));
// Start background eviction thread if configured
let eviction_handle = if config.eviction_interval_secs > 0 {
let tree_clone = Arc::clone(&tree);
let trees_clone = Arc::clone(&trees);
let max_tree_size = config.max_tree_size;
let interval = config.eviction_interval_secs;
Some(thread::spawn(move || loop {
thread::sleep(Duration::from_secs(interval));
if let Ok(tree_guard) = tree_clone.lock() {
tree_guard.evict_tenant_by_size(max_tree_size);
debug!("Cache eviction completed, max_size: {}", max_tree_size);
if let Ok(mut trees_guard) = trees_clone.lock() {
// Evict for all model trees
for (model_id, tree) in trees_guard.iter_mut() {
tree.evict_tenant_by_size(max_tree_size);
debug!(
"Cache eviction completed for model {}, max_size: {}",
model_id, max_tree_size
);
}
}
}))
} else {
@@ -107,38 +115,97 @@ impl CacheAwarePolicy {
Self {
config,
tree,
trees,
eviction_handle,
}
}
/// Initialize the tree with worker URLs (used only during initial setup)
pub fn init_workers(&self, workers: &[Box<dyn Worker>]) {
if let Ok(tree) = self.tree.lock() {
pub fn init_workers(&self, workers: &[Arc<dyn Worker>]) {
if let Ok(mut trees) = self.trees.lock() {
// Group workers by model
let mut model_workers: HashMap<String, Vec<&Arc<dyn Worker>>> = HashMap::new();
for worker in workers {
tree.insert("", worker.url());
// Use "default" for unknown/empty model_ids for backward compatibility
let model_id = worker.model_id();
let tree_key = if model_id.is_empty() || model_id == "unknown" {
"default".to_string()
} else {
model_id.to_string()
};
model_workers.entry(tree_key).or_default().push(worker);
}
// Initialize tree for each model
for (tree_key, model_workers) in model_workers {
let tree = trees.entry(tree_key).or_insert_with(Tree::new);
for worker in model_workers {
tree.insert("", worker.url());
}
}
}
}
/// Add a single worker to the tree (incremental update)
pub fn add_worker(&self, url: &str) {
if let Ok(tree) = self.tree.lock() {
pub fn add_worker(&self, worker: &dyn Worker) {
if let Ok(mut trees) = self.trees.lock() {
// For backward compatibility: if model_id is "unknown" or empty,
// use a default tree. This preserves existing behavior for single-model routers.
let model_id = worker.model_id();
let tree_key = if model_id.is_empty() || model_id == "unknown" {
"default".to_string()
} else {
model_id.to_string()
};
let tree = trees.entry(tree_key).or_insert_with(Tree::new);
tree.insert("", worker.url());
}
}
/// Add a worker by URL and model (for backward compatibility)
pub fn add_worker_by_url(&self, url: &str, model_id: &str) {
if let Ok(mut trees) = self.trees.lock() {
let tree = trees.entry(model_id.to_string()).or_insert_with(Tree::new);
tree.insert("", url);
}
}
/// Remove a worker from the tree
pub fn remove_worker(&self, url: &str) {
if let Ok(tree) = self.tree.lock() {
tree.remove_tenant(url);
pub fn remove_worker(&self, worker: &dyn Worker) {
if let Ok(mut trees) = self.trees.lock() {
// Use same logic as add_worker for consistency
let model_id = worker.model_id();
let tree_key = if model_id.is_empty() || model_id == "unknown" {
"default".to_string()
} else {
model_id.to_string()
};
if let Some(tree) = trees.get_mut(&tree_key) {
tree.remove_tenant(worker.url());
}
}
}
/// Remove a worker by URL (removes from all model trees for backward compatibility)
pub fn remove_worker_by_url(&self, url: &str) {
if let Ok(mut trees) = self.trees.lock() {
// Remove from all trees since we don't know which model it belongs to
for (_model_id, tree) in trees.iter_mut() {
tree.remove_tenant(url);
}
}
}
/// Run cache eviction to prevent unbounded growth
pub fn evict_cache(&self, max_size: usize) {
if let Ok(tree) = self.tree.lock() {
tree.evict_tenant_by_size(max_size);
if let Ok(mut trees) = self.trees.lock() {
for (model_id, tree) in trees.iter_mut() {
tree.evict_tenant_by_size(max_size);
debug!(
"Cache eviction for model {}, max_size: {}",
model_id, max_size
);
}
}
}
}
@@ -146,7 +213,7 @@ impl CacheAwarePolicy {
impl LoadBalancingPolicy for CacheAwarePolicy {
fn select_worker(
&self,
workers: &[Box<dyn Worker>],
workers: &[Arc<dyn Worker>],
request_text: Option<&str>,
) -> Option<usize> {
let healthy_indices = get_healthy_worker_indices(workers);
@@ -155,6 +222,18 @@ impl LoadBalancingPolicy for CacheAwarePolicy {
return None;
}
// Group workers by model (using "default" for unknown/empty model_ids)
let mut model_workers: HashMap<String, Vec<usize>> = HashMap::new();
for idx in &healthy_indices {
let model_id = workers[*idx].model_id();
let tree_key = if model_id.is_empty() || model_id == "unknown" {
"default".to_string()
} else {
model_id.to_string()
};
model_workers.entry(tree_key).or_default().push(*idx);
}
// Get current load statistics
let loads: Vec<usize> = workers.iter().map(|w| w.load()).collect();
let max_load = *loads.iter().max().unwrap_or(&0);
@@ -187,7 +266,14 @@ impl LoadBalancingPolicy for CacheAwarePolicy {
// Even in imbalanced mode, update the tree to maintain cache state
if let Some(text) = request_text {
if let Ok(tree) = self.tree.lock() {
if let Ok(mut trees) = self.trees.lock() {
let model_id = workers[min_load_idx].model_id();
let tree_key = if model_id.is_empty() || model_id == "unknown" {
"default".to_string()
} else {
model_id.to_string()
};
let tree = trees.entry(tree_key).or_insert_with(Tree::new);
tree.insert(text, workers[min_load_idx].url());
}
}
@@ -203,43 +289,85 @@ impl LoadBalancingPolicy for CacheAwarePolicy {
// Use cache-aware routing when balanced
let text = request_text.unwrap_or("");
if let Ok(tree) = self.tree.lock() {
let (matched_text, matched_worker) = tree.prefix_match(text);
let match_rate = if text.is_empty() {
0.0
} else {
matched_text.chars().count() as f32 / text.chars().count() as f32
};
if let Ok(mut trees) = self.trees.lock() {
let mut best_match_idx: Option<usize> = None;
let mut best_match_rate: f32 = 0.0;
let selected_url = if match_rate > self.config.cache_threshold {
RouterMetrics::record_cache_hit();
matched_worker.to_string()
} else {
RouterMetrics::record_cache_miss();
tree.get_smallest_tenant()
};
// Find best match across all models
for (model_id, worker_indices) in &model_workers {
let tree = trees.entry(model_id.clone()).or_insert_with(Tree::new);
// Find the index of the selected worker
if let Some(selected_idx) = workers.iter().position(|w| w.url() == selected_url) {
// Only proceed if the worker is healthy
if workers[selected_idx].is_healthy() {
// Update the tree with this request
tree.insert(text, &selected_url);
let (matched_text, matched_worker) = tree.prefix_match(text);
let match_rate = if text.is_empty() {
0.0
} else {
matched_text.chars().count() as f32 / text.chars().count() as f32
};
// Increment processed counter
workers[selected_idx].increment_processed();
RouterMetrics::record_processed_request(&selected_url);
return Some(selected_idx);
// Check if this model has the best match
if match_rate > best_match_rate {
// Find the worker index for this URL
if let Some(idx) = worker_indices
.iter()
.find(|&&idx| workers[idx].url() == matched_worker)
{
best_match_idx = Some(*idx);
best_match_rate = match_rate;
}
}
} else {
// Selected worker no longer exists, remove it from tree
tree.remove_tenant(&selected_url);
debug!("Removed stale worker {} from cache tree", selected_url);
}
// Fallback to first healthy worker
return healthy_indices.first().copied();
// Select worker based on cache threshold
let selected_idx = if let (Some(idx), true) = (
best_match_idx,
best_match_rate > self.config.cache_threshold,
) {
RouterMetrics::record_cache_hit();
idx
} else {
RouterMetrics::record_cache_miss();
// Find model with smallest tree (most cache capacity)
let mut smallest_tree_model = String::new();
let mut smallest_tree_size = usize::MAX;
for model_id in model_workers.keys() {
let tree = trees.entry(model_id.clone()).or_insert_with(Tree::new);
let size = tree.get_used_size_per_tenant().values().sum::<usize>();
if size < smallest_tree_size {
smallest_tree_size = size;
smallest_tree_model = model_id.clone();
}
}
// Select least loaded worker from model with most cache capacity
if let Some(worker_indices) = model_workers.get(&smallest_tree_model) {
worker_indices
.iter()
.min_by_key(|&&idx| workers[idx].load())
.copied()
.unwrap_or(healthy_indices[0])
} else {
healthy_indices[0]
}
};
// Update the tree with this request
let model_id = workers[selected_idx].model_id();
let tree_key = if model_id.is_empty() || model_id == "unknown" {
"default".to_string()
} else {
model_id.to_string()
};
let tree = trees.entry(tree_key).or_insert_with(Tree::new);
tree.insert(text, workers[selected_idx].url());
// Increment processed counter
workers[selected_idx].increment_processed();
RouterMetrics::record_processed_request(workers[selected_idx].url());
RouterMetrics::record_policy_decision(self.name(), workers[selected_idx].url());
return Some(selected_idx);
}
// Fallback to first healthy worker if tree operations fail
@@ -272,8 +400,8 @@ impl LoadBalancingPolicy for CacheAwarePolicy {
fn select_worker_pair(
&self,
prefill_workers: &[Box<dyn Worker>],
decode_workers: &[Box<dyn Worker>],
prefill_workers: &[Arc<dyn Worker>],
decode_workers: &[Arc<dyn Worker>],
request_text: Option<&str>,
) -> Option<(usize, usize)> {
// DEPRECATED: This method is no longer used when separate policies are configured.
@@ -333,12 +461,12 @@ mod tests {
..Default::default()
};
let policy = CacheAwarePolicy::with_config(config);
let workers: Vec<Box<dyn Worker>> = vec![
Box::new(BasicWorker::new(
let workers: Vec<Arc<dyn Worker>> = vec![
Arc::new(BasicWorker::new(
"http://w1:8000".to_string(),
WorkerType::Regular,
)),
Box::new(BasicWorker::new(
Arc::new(BasicWorker::new(
"http://w2:8000".to_string(),
WorkerType::Regular,
)),
@@ -378,7 +506,7 @@ mod tests {
}
// worker2 has load 0
let workers: Vec<Box<dyn Worker>> = vec![Box::new(worker1), Box::new(worker2)];
let workers: Vec<Arc<dyn Worker>> = vec![Arc::new(worker1), Arc::new(worker2)];
policy.init_workers(&workers);
// Should select worker2 (lower load) despite cache affinity
@@ -395,12 +523,12 @@ mod tests {
..Default::default()
};
let policy = CacheAwarePolicy::with_config(config);
let workers: Vec<Box<dyn Worker>> = vec![
Box::new(BasicWorker::new(
let workers: Vec<Arc<dyn Worker>> = vec![
Arc::new(BasicWorker::new(
"http://w1:8000".to_string(),
WorkerType::Regular,
)),
Box::new(BasicWorker::new(
Arc::new(BasicWorker::new(
"http://w2:8000".to_string(),
WorkerType::Regular,
)),
@@ -413,7 +541,7 @@ mod tests {
policy.select_worker(&workers, Some("test2"));
// Remove a worker
policy.remove_worker("http://w1:8000");
policy.remove_worker_by_url("http://w1:8000");
workers[0].set_healthy(false);
// All requests should now go to worker2

View File

@@ -5,17 +5,20 @@
use crate::core::Worker;
use std::fmt::Debug;
use std::sync::Arc;
mod cache_aware;
mod factory;
mod power_of_two;
mod random;
mod registry;
mod round_robin;
pub use cache_aware::CacheAwarePolicy;
pub use factory::PolicyFactory;
pub use power_of_two::PowerOfTwoPolicy;
pub use random::RandomPolicy;
pub use registry::PolicyRegistry;
pub use round_robin::RoundRobinPolicy;
/// Core trait for load balancing policies
@@ -26,9 +29,10 @@ pub trait LoadBalancingPolicy: Send + Sync + Debug {
/// Select a single worker from the available workers
///
/// This is used for regular routing mode where requests go to a single worker.
/// Now uses Arc<dyn Worker> for better performance and to avoid unnecessary cloning.
fn select_worker(
&self,
workers: &[Box<dyn Worker>],
workers: &[Arc<dyn Worker>],
request_text: Option<&str>,
) -> Option<usize>;
@@ -38,8 +42,8 @@ pub trait LoadBalancingPolicy: Send + Sync + Debug {
/// Default implementation uses select_worker for each array independently.
fn select_worker_pair(
&self,
prefill_workers: &[Box<dyn Worker>],
decode_workers: &[Box<dyn Worker>],
prefill_workers: &[Arc<dyn Worker>],
decode_workers: &[Arc<dyn Worker>],
request_text: Option<&str>,
) -> Option<(usize, usize)> {
// Default implementation: independently select from each pool
@@ -105,7 +109,7 @@ impl Default for CacheAwareConfig {
}
/// Helper function to filter healthy workers and return their indices
pub(crate) fn get_healthy_worker_indices(workers: &[Box<dyn Worker>]) -> Vec<usize> {
pub(crate) fn get_healthy_worker_indices(workers: &[Arc<dyn Worker>]) -> Vec<usize> {
workers
.iter()
.enumerate()
@@ -121,16 +125,16 @@ mod tests {
#[test]
fn test_get_healthy_worker_indices() {
let workers: Vec<Box<dyn Worker>> = vec![
Box::new(BasicWorker::new(
let workers: Vec<Arc<dyn Worker>> = vec![
Arc::new(BasicWorker::new(
"http://w1:8000".to_string(),
WorkerType::Regular,
)),
Box::new(BasicWorker::new(
Arc::new(BasicWorker::new(
"http://w2:8000".to_string(),
WorkerType::Regular,
)),
Box::new(BasicWorker::new(
Arc::new(BasicWorker::new(
"http://w3:8000".to_string(),
WorkerType::Regular,
)),

View File

@@ -5,7 +5,7 @@ use crate::core::Worker;
use crate::metrics::RouterMetrics;
use rand::Rng;
use std::collections::HashMap;
use std::sync::RwLock;
use std::sync::{Arc, RwLock};
use tracing::info;
/// Power-of-two choices policy
@@ -41,7 +41,7 @@ impl PowerOfTwoPolicy {
impl LoadBalancingPolicy for PowerOfTwoPolicy {
fn select_worker(
&self,
workers: &[Box<dyn Worker>],
workers: &[Arc<dyn Worker>],
_request_text: Option<&str>,
) -> Option<usize> {
let healthy_indices = get_healthy_worker_indices(workers);
@@ -137,8 +137,8 @@ mod tests {
}
// worker3 has load 0
let workers: Vec<Box<dyn Worker>> =
vec![Box::new(worker1), Box::new(worker2), Box::new(worker3)];
let workers: Vec<Arc<dyn Worker>> =
vec![Arc::new(worker1), Arc::new(worker2), Arc::new(worker3)];
// Run multiple selections
let mut selected_counts = [0; 3];
@@ -156,12 +156,12 @@ mod tests {
#[test]
fn test_power_of_two_with_cached_loads() {
let policy = PowerOfTwoPolicy::new();
let workers: Vec<Box<dyn Worker>> = vec![
Box::new(BasicWorker::new(
let workers: Vec<Arc<dyn Worker>> = vec![
Arc::new(BasicWorker::new(
"http://w1:8000".to_string(),
WorkerType::Regular,
)),
Box::new(BasicWorker::new(
Arc::new(BasicWorker::new(
"http://w2:8000".to_string(),
WorkerType::Regular,
)),
@@ -190,7 +190,7 @@ mod tests {
#[test]
fn test_power_of_two_single_worker() {
let policy = PowerOfTwoPolicy::new();
let workers: Vec<Box<dyn Worker>> = vec![Box::new(BasicWorker::new(
let workers: Vec<Arc<dyn Worker>> = vec![Arc::new(BasicWorker::new(
"http://w1:8000".to_string(),
WorkerType::Regular,
))];

View File

@@ -4,6 +4,7 @@ use super::{get_healthy_worker_indices, LoadBalancingPolicy};
use crate::core::Worker;
use crate::metrics::RouterMetrics;
use rand::Rng;
use std::sync::Arc;
/// Random selection policy
///
@@ -20,7 +21,7 @@ impl RandomPolicy {
impl LoadBalancingPolicy for RandomPolicy {
fn select_worker(
&self,
workers: &[Box<dyn Worker>],
workers: &[Arc<dyn Worker>],
_request_text: Option<&str>,
) -> Option<usize> {
let healthy_indices = get_healthy_worker_indices(workers);
@@ -56,16 +57,16 @@ mod tests {
#[test]
fn test_random_selection() {
let policy = RandomPolicy::new();
let workers: Vec<Box<dyn Worker>> = vec![
Box::new(BasicWorker::new(
let workers: Vec<Arc<dyn Worker>> = vec![
Arc::new(BasicWorker::new(
"http://w1:8000".to_string(),
WorkerType::Regular,
)),
Box::new(BasicWorker::new(
Arc::new(BasicWorker::new(
"http://w2:8000".to_string(),
WorkerType::Regular,
)),
Box::new(BasicWorker::new(
Arc::new(BasicWorker::new(
"http://w3:8000".to_string(),
WorkerType::Regular,
)),
@@ -87,12 +88,12 @@ mod tests {
#[test]
fn test_random_with_unhealthy_workers() {
let policy = RandomPolicy::new();
let workers: Vec<Box<dyn Worker>> = vec![
Box::new(BasicWorker::new(
let workers: Vec<Arc<dyn Worker>> = vec![
Arc::new(BasicWorker::new(
"http://w1:8000".to_string(),
WorkerType::Regular,
)),
Box::new(BasicWorker::new(
Arc::new(BasicWorker::new(
"http://w2:8000".to_string(),
WorkerType::Regular,
)),
@@ -110,7 +111,7 @@ mod tests {
#[test]
fn test_random_no_healthy_workers() {
let policy = RandomPolicy::new();
let workers: Vec<Box<dyn Worker>> = vec![Box::new(BasicWorker::new(
let workers: Vec<Arc<dyn Worker>> = vec![Arc::new(BasicWorker::new(
"http://w1:8000".to_string(),
WorkerType::Regular,
))];

View File

@@ -0,0 +1,333 @@
/// Policy Registry for managing model-to-policy mappings
///
/// This registry manages the dynamic assignment of load balancing policies to models.
/// When the first worker of a new model is added, it determines the policy for that model.
/// All subsequent workers of the same model use the established policy.
/// When the last worker of a model is removed, the policy mapping is cleaned up.
use super::{
CacheAwareConfig, CacheAwarePolicy, LoadBalancingPolicy, PowerOfTwoPolicy, RandomPolicy,
RoundRobinPolicy,
};
use crate::config::types::PolicyConfig;
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use tracing::{debug, info, warn};
/// Registry for managing model-to-policy mappings
#[derive(Clone)]
pub struct PolicyRegistry {
/// Model ID -> Policy instance mapping
model_policies: Arc<RwLock<HashMap<String, Arc<dyn LoadBalancingPolicy>>>>,
/// Model ID -> Worker count for cleanup tracking
model_worker_counts: Arc<RwLock<HashMap<String, usize>>>,
/// Default policy instance (cached)
default_policy: Arc<dyn LoadBalancingPolicy>,
/// Prefill policy for PD mode
prefill_policy: Arc<RwLock<Option<Arc<dyn LoadBalancingPolicy>>>>,
/// Decode policy for PD mode
decode_policy: Arc<RwLock<Option<Arc<dyn LoadBalancingPolicy>>>>,
}
impl PolicyRegistry {
/// Create a new PolicyRegistry with a default policy
pub fn new(default_policy_config: PolicyConfig) -> Self {
let default_policy = Self::create_policy_from_config(&default_policy_config);
Self {
model_policies: Arc::new(RwLock::new(HashMap::new())),
model_worker_counts: Arc::new(RwLock::new(HashMap::new())),
default_policy,
prefill_policy: Arc::new(RwLock::new(None)),
decode_policy: Arc::new(RwLock::new(None)),
}
}
/// Called when a worker is added
/// Returns the policy that should be used for this worker's model
pub fn on_worker_added(
&self,
model_id: &str,
policy_hint: Option<&str>,
) -> Arc<dyn LoadBalancingPolicy> {
// Increment worker count
{
let mut counts = self.model_worker_counts.write().unwrap();
*counts.entry(model_id.to_string()).or_insert(0) += 1;
debug!(
"Worker added for model {}, count: {}",
model_id,
counts.get(model_id).unwrap()
);
}
// Check if model already has a policy
{
let policies = self.model_policies.read().unwrap();
if let Some(existing_policy) = policies.get(model_id) {
debug!(
"Model {} already has policy: {}",
model_id,
existing_policy.name()
);
return Arc::clone(existing_policy);
}
}
// New model - determine policy
let policy = self.determine_policy_for_model(model_id, policy_hint);
info!(
"Assigning policy {} to new model {}",
policy.name(),
model_id
);
// Store policy for this model
{
let mut policies = self.model_policies.write().unwrap();
policies.insert(model_id.to_string(), Arc::clone(&policy));
}
policy
}
/// Called when a worker is removed
pub fn on_worker_removed(&self, model_id: &str) {
let should_cleanup = {
let mut counts = self.model_worker_counts.write().unwrap();
if let Some(count) = counts.get_mut(model_id) {
*count = count.saturating_sub(1);
debug!("Worker removed for model {}, count: {}", model_id, *count);
if *count == 0 {
counts.remove(model_id);
true
} else {
false
}
} else {
warn!(
"Attempted to remove worker for model {} with no registered workers",
model_id
);
false
}
};
// Clean up policy if this was the last worker
if should_cleanup {
let mut policies = self.model_policies.write().unwrap();
if let Some(policy) = policies.remove(model_id) {
info!(
"Removed policy {} for model {} (last worker removed)",
policy.name(),
model_id
);
// Policy will be dropped here, cleaning up any resources
drop(policy);
}
}
}
/// Get the policy for a model
pub fn get_policy(&self, model_id: &str) -> Option<Arc<dyn LoadBalancingPolicy>> {
self.model_policies.read().unwrap().get(model_id).cloned()
}
/// Get the default policy
pub fn get_default_policy(&self) -> Arc<dyn LoadBalancingPolicy> {
Arc::clone(&self.default_policy)
}
/// Get policy for a model, or default if not found
pub fn get_policy_or_default(&self, model_id: &str) -> Arc<dyn LoadBalancingPolicy> {
self.get_policy(model_id)
.unwrap_or_else(|| self.get_default_policy())
}
/// Determine policy for a new model
fn determine_policy_for_model(
&self,
model_id: &str,
policy_hint: Option<&str>,
) -> Arc<dyn LoadBalancingPolicy> {
// 1. Check policy hint from worker
if let Some(policy_type) = policy_hint {
debug!("Using policy hint '{}' for model {}", policy_type, model_id);
return self.create_policy_from_type(policy_type);
}
// 2. Use default policy
debug!("Using default policy for model {}", model_id);
Arc::clone(&self.default_policy)
}
/// Create a policy from a type string
fn create_policy_from_type(&self, policy_type: &str) -> Arc<dyn LoadBalancingPolicy> {
match policy_type {
"round_robin" => Arc::new(RoundRobinPolicy::new()),
"random" => Arc::new(RandomPolicy::new()),
"cache_aware" => Arc::new(CacheAwarePolicy::new()),
"power_of_two" => Arc::new(PowerOfTwoPolicy::new()),
_ => {
warn!("Unknown policy type '{}', using default", policy_type);
Arc::clone(&self.default_policy)
}
}
}
/// Create a policy from a PolicyConfig
fn create_policy_from_config(config: &PolicyConfig) -> Arc<dyn LoadBalancingPolicy> {
match config {
PolicyConfig::RoundRobin => Arc::new(RoundRobinPolicy::new()),
PolicyConfig::Random => Arc::new(RandomPolicy::new()),
PolicyConfig::CacheAware {
cache_threshold,
balance_abs_threshold,
balance_rel_threshold,
eviction_interval_secs,
max_tree_size,
} => {
let cache_config = CacheAwareConfig {
cache_threshold: *cache_threshold,
balance_abs_threshold: *balance_abs_threshold,
balance_rel_threshold: *balance_rel_threshold,
eviction_interval_secs: *eviction_interval_secs,
max_tree_size: *max_tree_size,
};
Arc::new(CacheAwarePolicy::with_config(cache_config))
}
PolicyConfig::PowerOfTwo { .. } => Arc::new(PowerOfTwoPolicy::new()),
}
}
/// Get current model->policy mappings (for debugging/monitoring)
pub fn get_all_mappings(&self) -> HashMap<String, String> {
let policies = self.model_policies.read().unwrap();
policies
.iter()
.map(|(model, policy)| (model.clone(), policy.name().to_string()))
.collect()
}
/// Get worker counts per model
pub fn get_worker_counts(&self) -> HashMap<String, usize> {
self.model_worker_counts.read().unwrap().clone()
}
/// Clear all policies (useful for testing)
pub fn clear(&self) {
let mut policies = self.model_policies.write().unwrap();
policies.clear();
let mut counts = self.model_worker_counts.write().unwrap();
counts.clear();
}
/// Set the prefill policy for PD mode
pub fn set_prefill_policy(&self, policy: Arc<dyn LoadBalancingPolicy>) {
let mut prefill_policy = self.prefill_policy.write().unwrap();
*prefill_policy = Some(policy);
}
/// Set the decode policy for PD mode
pub fn set_decode_policy(&self, policy: Arc<dyn LoadBalancingPolicy>) {
let mut decode_policy = self.decode_policy.write().unwrap();
*decode_policy = Some(policy);
}
/// Get the prefill policy for PD mode, or default if not set
pub fn get_prefill_policy(&self) -> Arc<dyn LoadBalancingPolicy> {
let prefill_policy = self.prefill_policy.read().unwrap();
prefill_policy
.as_ref()
.map(Arc::clone)
.unwrap_or_else(|| self.get_default_policy())
}
/// Get the decode policy for PD mode, or default if not set
pub fn get_decode_policy(&self) -> Arc<dyn LoadBalancingPolicy> {
let decode_policy = self.decode_policy.read().unwrap();
decode_policy
.as_ref()
.map(Arc::clone)
.unwrap_or_else(|| self.get_default_policy())
}
}
impl std::fmt::Debug for PolicyRegistry {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("PolicyRegistry")
.field("model_policies", &self.model_policies)
.field("model_worker_counts", &self.model_worker_counts)
.field("default_policy", &self.default_policy.name())
.finish()
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_policy_registry_basic() {
let registry = PolicyRegistry::new(PolicyConfig::RoundRobin);
// First worker of a model sets the policy
let policy1 = registry.on_worker_added("llama-3", Some("cache_aware"));
assert_eq!(policy1.name(), "cache_aware");
// Second worker of same model uses existing policy
let policy2 = registry.on_worker_added("llama-3", Some("round_robin"));
assert_eq!(policy2.name(), "cache_aware"); // Ignores hint, uses existing
// Different model can have different policy
let policy3 = registry.on_worker_added("gpt-4", Some("random"));
assert_eq!(policy3.name(), "random");
// Check mappings
let mappings = registry.get_all_mappings();
assert_eq!(mappings.get("llama-3").unwrap(), "cache_aware");
assert_eq!(mappings.get("gpt-4").unwrap(), "random");
// Check worker counts
let counts = registry.get_worker_counts();
assert_eq!(*counts.get("llama-3").unwrap(), 2);
assert_eq!(*counts.get("gpt-4").unwrap(), 1);
}
#[test]
fn test_policy_registry_cleanup() {
let registry = PolicyRegistry::new(PolicyConfig::RoundRobin);
// Add workers
registry.on_worker_added("llama-3", Some("cache_aware"));
registry.on_worker_added("llama-3", None);
assert_eq!(registry.get_worker_counts().get("llama-3"), Some(&2));
// Remove one worker - policy should remain
registry.on_worker_removed("llama-3");
assert!(registry.get_policy("llama-3").is_some());
assert_eq!(registry.get_worker_counts().get("llama-3"), Some(&1));
// Remove last worker - policy should be cleaned up
registry.on_worker_removed("llama-3");
assert!(registry.get_policy("llama-3").is_none());
assert_eq!(registry.get_worker_counts().get("llama-3"), None);
}
#[test]
fn test_default_policy() {
let registry = PolicyRegistry::new(PolicyConfig::RoundRobin);
// No hint, no template - uses default
let policy = registry.on_worker_added("unknown-model", None);
assert_eq!(policy.name(), "round_robin");
// Get default directly
let default = registry.get_default_policy();
assert_eq!(default.name(), "round_robin");
}
}

View File

@@ -4,6 +4,7 @@ use super::{get_healthy_worker_indices, LoadBalancingPolicy};
use crate::core::Worker;
use crate::metrics::RouterMetrics;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
/// Round-robin selection policy
///
@@ -24,7 +25,7 @@ impl RoundRobinPolicy {
impl LoadBalancingPolicy for RoundRobinPolicy {
fn select_worker(
&self,
workers: &[Box<dyn Worker>],
workers: &[Arc<dyn Worker>],
_request_text: Option<&str>,
) -> Option<usize> {
let healthy_indices = get_healthy_worker_indices(workers);
@@ -64,16 +65,16 @@ mod tests {
#[test]
fn test_round_robin_selection() {
let policy = RoundRobinPolicy::new();
let workers: Vec<Box<dyn Worker>> = vec![
Box::new(BasicWorker::new(
let workers: Vec<Arc<dyn Worker>> = vec![
Arc::new(BasicWorker::new(
"http://w1:8000".to_string(),
WorkerType::Regular,
)),
Box::new(BasicWorker::new(
Arc::new(BasicWorker::new(
"http://w2:8000".to_string(),
WorkerType::Regular,
)),
Box::new(BasicWorker::new(
Arc::new(BasicWorker::new(
"http://w3:8000".to_string(),
WorkerType::Regular,
)),
@@ -90,16 +91,16 @@ mod tests {
#[test]
fn test_round_robin_with_unhealthy_workers() {
let policy = RoundRobinPolicy::new();
let workers: Vec<Box<dyn Worker>> = vec![
Box::new(BasicWorker::new(
let workers: Vec<Arc<dyn Worker>> = vec![
Arc::new(BasicWorker::new(
"http://w1:8000".to_string(),
WorkerType::Regular,
)),
Box::new(BasicWorker::new(
Arc::new(BasicWorker::new(
"http://w2:8000".to_string(),
WorkerType::Regular,
)),
Box::new(BasicWorker::new(
Arc::new(BasicWorker::new(
"http://w3:8000".to_string(),
WorkerType::Regular,
)),
@@ -118,12 +119,12 @@ mod tests {
#[test]
fn test_round_robin_reset() {
let policy = RoundRobinPolicy::new();
let workers: Vec<Box<dyn Worker>> = vec![
Box::new(BasicWorker::new(
let workers: Vec<Arc<dyn Worker>> = vec![
Arc::new(BasicWorker::new(
"http://w1:8000".to_string(),
WorkerType::Regular,
)),
Box::new(BasicWorker::new(
Arc::new(BasicWorker::new(
"http://w2:8000".to_string(),
WorkerType::Regular,
)),