[router] Refactor router and policy traits with dependency injection (#7987)
Co-authored-by: Jin Pan <jpan236@wisc.edu> Co-authored-by: Keru Yang <rukeyang@gmail.com> Co-authored-by: Yingyi Huang <yingyihuang2000@outlook.com> Co-authored-by: Philip Zhu <phlipzhux@gmail.com>
This commit is contained in:
399
sgl-router/src/policies/cache_aware.rs
Normal file
399
sgl-router/src/policies/cache_aware.rs
Normal file
@@ -0,0 +1,399 @@
|
||||
/*
|
||||
Cache-Aware Load Balancing Router
|
||||
|
||||
This router combines two strategies to optimize both cache utilization and request distribution:
|
||||
|
||||
1. Cache-Aware Routing (Approximate Tree)
|
||||
2. Load Balancing (Shortest Queue with Balance Thresholds)
|
||||
|
||||
The router dynamically switches between these strategies based on load conditions:
|
||||
- Uses load balancing when the system is imbalanced
|
||||
- Uses cache-aware routing when the system is balanced
|
||||
|
||||
A system is considered imbalanced if both conditions are met:
|
||||
1. (max - min) > abs_threshold
|
||||
2. max > rel_threshold * min
|
||||
|
||||
Strategy Details:
|
||||
|
||||
1. Cache-Aware Routing (Approximate Tree)
|
||||
-------------------------------------------
|
||||
This strategy maintains an approximate radix tree for each worker based on request history,
|
||||
eliminating the need for direct cache state queries. The tree stores raw text characters
|
||||
instead of token IDs to avoid tokenization overhead.
|
||||
|
||||
Process:
|
||||
a. For each request, find the worker with the highest prefix match
|
||||
b. If match rate > cache_threshold:
|
||||
Route to the worker with highest match (likely has relevant data cached)
|
||||
c. If match rate ≤ cache_threshold:
|
||||
Route to the worker with smallest tree size (most available cache capacity)
|
||||
d. Background maintenance:
|
||||
Periodically evict least recently used leaf nodes to prevent memory overflow
|
||||
|
||||
2. Load Balancing (Shortest Queue)
|
||||
-------------------------------------------
|
||||
This strategy tracks pending request counts per worker and routes new requests
|
||||
to the least busy worker when the system is detected to be imbalanced.
|
||||
|
||||
Configuration Parameters:
|
||||
------------------------
|
||||
1. cache_threshold: (float, 0.0 to 1.0)
|
||||
Minimum prefix match ratio to use highest-match routing.
|
||||
Below this threshold, routes to worker with most available cache space.
|
||||
|
||||
2. balance_abs_threshold: (integer)
|
||||
Absolute difference threshold for load imbalance detection.
|
||||
System is potentially imbalanced if (max_load - min_load) > abs_threshold
|
||||
|
||||
3. balance_rel_threshold: (float)
|
||||
Relative ratio threshold for load imbalance detection.
|
||||
System is potentially imbalanced if max_load > min_load * rel_threshold
|
||||
Used in conjunction with abs_threshold to determine final imbalance state.
|
||||
|
||||
4. eviction_interval_secs: (integer)
|
||||
Interval between LRU eviction cycles for the approximate trees.
|
||||
|
||||
5. max_tree_size: (integer)
|
||||
Maximum nodes per tree. When exceeded, LRU leaf nodes are evicted
|
||||
during the next eviction cycle.
|
||||
*/
|
||||
|
||||
use super::{get_healthy_worker_indices, CacheAwareConfig, LoadBalancingPolicy};
|
||||
use crate::core::Worker;
|
||||
use crate::tree::Tree;
|
||||
use metrics::{counter, gauge};
|
||||
use std::sync::{Arc, Mutex};
|
||||
use std::thread;
|
||||
use std::time::Duration;
|
||||
use tracing::{debug, info};
|
||||
|
||||
/// Cache-aware routing policy
|
||||
///
|
||||
/// Routes requests based on cache affinity when load is balanced,
|
||||
/// switches to shortest-queue routing when load is imbalanced.
|
||||
#[derive(Debug)]
|
||||
pub struct CacheAwarePolicy {
|
||||
config: CacheAwareConfig,
|
||||
tree: Arc<Mutex<Tree>>,
|
||||
eviction_handle: Option<thread::JoinHandle<()>>,
|
||||
}
|
||||
|
||||
impl CacheAwarePolicy {
|
||||
pub fn new() -> Self {
|
||||
Self::with_config(CacheAwareConfig::default())
|
||||
}
|
||||
|
||||
pub fn with_config(config: CacheAwareConfig) -> Self {
|
||||
let tree = Arc::new(Mutex::new(Tree::new()));
|
||||
|
||||
// Start background eviction thread if configured
|
||||
let eviction_handle = if config.eviction_interval_secs > 0 {
|
||||
let tree_clone = Arc::clone(&tree);
|
||||
let max_tree_size = config.max_tree_size;
|
||||
let interval = config.eviction_interval_secs;
|
||||
|
||||
Some(thread::spawn(move || loop {
|
||||
thread::sleep(Duration::from_secs(interval));
|
||||
|
||||
if let Ok(tree_guard) = tree_clone.lock() {
|
||||
tree_guard.evict_tenant_by_size(max_tree_size);
|
||||
debug!("Cache eviction completed, max_size: {}", max_tree_size);
|
||||
}
|
||||
}))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
Self {
|
||||
config,
|
||||
tree,
|
||||
eviction_handle,
|
||||
}
|
||||
}
|
||||
|
||||
/// Initialize the tree with worker URLs
|
||||
pub fn init_workers(&self, workers: &[Box<dyn Worker>]) {
|
||||
if let Ok(tree) = self.tree.lock() {
|
||||
for worker in workers {
|
||||
tree.insert("", worker.url());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Remove a worker from the tree
|
||||
pub fn remove_worker(&self, url: &str) {
|
||||
if let Ok(tree) = self.tree.lock() {
|
||||
tree.remove_tenant(url);
|
||||
}
|
||||
}
|
||||
|
||||
/// Run cache eviction to prevent unbounded growth
|
||||
pub fn evict_cache(&self, max_size: usize) {
|
||||
if let Ok(tree) = self.tree.lock() {
|
||||
tree.evict_tenant_by_size(max_size);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl LoadBalancingPolicy for CacheAwarePolicy {
|
||||
fn select_worker(
|
||||
&self,
|
||||
workers: &[Box<dyn Worker>],
|
||||
request_text: Option<&str>,
|
||||
) -> Option<usize> {
|
||||
let healthy_indices = get_healthy_worker_indices(workers);
|
||||
|
||||
if healthy_indices.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
// Get current load statistics
|
||||
let loads: Vec<usize> = workers.iter().map(|w| w.load()).collect();
|
||||
let max_load = *loads.iter().max().unwrap_or(&0);
|
||||
let min_load = *loads.iter().min().unwrap_or(&0);
|
||||
|
||||
// Check if load is imbalanced
|
||||
let is_imbalanced = max_load.saturating_sub(min_load) > self.config.balance_abs_threshold
|
||||
&& (max_load as f32) > (min_load as f32 * self.config.balance_rel_threshold);
|
||||
|
||||
if is_imbalanced {
|
||||
// Log load balancing trigger
|
||||
let worker_loads: Vec<(String, usize)> = workers
|
||||
.iter()
|
||||
.map(|w| (w.url().to_string(), w.load()))
|
||||
.collect();
|
||||
|
||||
info!(
|
||||
"Load balancing triggered due to workload imbalance:\n\
|
||||
Max load: {}, Min load: {}\n\
|
||||
Current worker loads: {:?}",
|
||||
max_load, min_load, worker_loads
|
||||
);
|
||||
|
||||
counter!("sgl_router_load_balancing_events_total").increment(1);
|
||||
gauge!("sgl_router_max_load").set(max_load as f64);
|
||||
gauge!("sgl_router_min_load").set(min_load as f64);
|
||||
|
||||
// Use shortest queue when imbalanced
|
||||
let min_load_idx = healthy_indices
|
||||
.iter()
|
||||
.min_by_key(|&&idx| workers[idx].load())
|
||||
.copied()?;
|
||||
|
||||
// Increment processed counter
|
||||
workers[min_load_idx].increment_processed();
|
||||
counter!("sgl_router_processed_requests_total", "worker" => workers[min_load_idx].url().to_string())
|
||||
.increment(1);
|
||||
|
||||
return Some(min_load_idx);
|
||||
}
|
||||
|
||||
// Use cache-aware routing when balanced
|
||||
let text = request_text.unwrap_or("");
|
||||
|
||||
if let Ok(tree) = self.tree.lock() {
|
||||
let (matched_text, matched_worker) = tree.prefix_match(text);
|
||||
let match_rate = if text.is_empty() {
|
||||
0.0
|
||||
} else {
|
||||
matched_text.chars().count() as f32 / text.chars().count() as f32
|
||||
};
|
||||
|
||||
let selected_url = if match_rate > self.config.cache_threshold {
|
||||
counter!("sgl_router_cache_hits_total").increment(1);
|
||||
matched_worker.to_string()
|
||||
} else {
|
||||
counter!("sgl_router_cache_misses_total").increment(1);
|
||||
tree.get_smallest_tenant()
|
||||
};
|
||||
|
||||
// Find the index of the selected worker
|
||||
let selected_idx = workers.iter().position(|w| w.url() == selected_url)?;
|
||||
|
||||
// Only proceed if the worker is healthy
|
||||
if !workers[selected_idx].is_healthy() {
|
||||
return healthy_indices.first().copied();
|
||||
}
|
||||
|
||||
// Update the tree with this request
|
||||
tree.insert(text, &selected_url);
|
||||
|
||||
// Increment processed counter
|
||||
workers[selected_idx].increment_processed();
|
||||
counter!("sgl_router_processed_requests_total", "worker" => selected_url).increment(1);
|
||||
|
||||
return Some(selected_idx);
|
||||
}
|
||||
|
||||
// Fallback to first healthy worker if tree operations fail
|
||||
healthy_indices.first().copied()
|
||||
}
|
||||
|
||||
fn name(&self) -> &'static str {
|
||||
"cache_aware"
|
||||
}
|
||||
|
||||
fn on_request_complete(&self, worker_url: &str, success: bool) {
|
||||
// Could track success rates per worker for more intelligent routing
|
||||
if !success {
|
||||
// Optionally reduce affinity for failed requests
|
||||
tracing::debug!(
|
||||
"Request to {} completed with success={}",
|
||||
worker_url,
|
||||
success
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
fn as_any(&self) -> &dyn std::any::Any {
|
||||
self
|
||||
}
|
||||
|
||||
fn select_worker_pair(
|
||||
&self,
|
||||
prefill_workers: &[Box<dyn Worker>],
|
||||
decode_workers: &[Box<dyn Worker>],
|
||||
request_text: Option<&str>,
|
||||
) -> Option<(usize, usize)> {
|
||||
// In PD mode:
|
||||
// - Prefill: Use cache-aware routing for better cache utilization
|
||||
// - Decode: Use least-load routing for better load distribution
|
||||
|
||||
// Select prefill worker using cache-aware logic
|
||||
let prefill_idx = self.select_worker(prefill_workers, request_text)?;
|
||||
|
||||
// Select decode worker using least-load logic
|
||||
let healthy_decode = get_healthy_worker_indices(decode_workers);
|
||||
if healthy_decode.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let decode_idx = healthy_decode
|
||||
.iter()
|
||||
.min_by_key(|&&idx| decode_workers[idx].load())
|
||||
.copied()?;
|
||||
|
||||
Some((prefill_idx, decode_idx))
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for CacheAwarePolicy {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for CacheAwarePolicy {
|
||||
fn drop(&mut self) {
|
||||
// Note: We can't properly stop the eviction thread since it's in an infinite loop
|
||||
// In a production system, we'd use a channel or atomic flag to signal shutdown
|
||||
if let Some(handle) = self.eviction_handle.take() {
|
||||
// The thread will continue running until the program exits
|
||||
// This is acceptable for now since the router typically runs for the lifetime of the program
|
||||
drop(handle);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::{BasicWorker, WorkerType};

    /// Builds a default-configured policy with the eviction thread disabled.
    fn policy_without_eviction() -> CacheAwarePolicy {
        CacheAwarePolicy::with_config(CacheAwareConfig {
            eviction_interval_secs: 0, // Disable eviction thread
            ..Default::default()
        })
    }

    /// Boxes one regular worker per URL, preserving order.
    fn regular_workers(urls: &[&str]) -> Vec<Box<dyn Worker>> {
        urls.iter()
            .map(|url| {
                Box::new(BasicWorker::new(url.to_string(), WorkerType::Regular))
                    as Box<dyn Worker>
            })
            .collect()
    }

    #[test]
    fn test_cache_aware_with_balanced_load() {
        let policy = policy_without_eviction();
        let workers = regular_workers(&["http://w1:8000", "http://w2:8000"]);

        // Register the workers with the policy's tree
        policy.init_workers(&workers);

        // The first request establishes where "hello world" lives
        let first = policy.select_worker(&workers, Some("hello world")).unwrap();

        // An identical request must land on the same worker (cache hit)
        let repeat = policy.select_worker(&workers, Some("hello world")).unwrap();
        assert_eq!(first, repeat);

        // A request sharing the prefix must land there as well
        let prefix = policy.select_worker(&workers, Some("hello")).unwrap();
        assert_eq!(first, prefix);
    }

    #[test]
    fn test_cache_aware_with_imbalanced_load() {
        let config = CacheAwareConfig {
            cache_threshold: 0.5,
            balance_abs_threshold: 5,
            balance_rel_threshold: 2.0,
            eviction_interval_secs: 0, // Disable eviction thread
            max_tree_size: 10000,
        };
        let policy = CacheAwarePolicy::with_config(config);

        let busy = BasicWorker::new("http://w1:8000".to_string(), WorkerType::Regular);
        let idle = BasicWorker::new("http://w2:8000".to_string(), WorkerType::Regular);

        // Load 20 on the first worker versus 0 on the second
        for _ in 0..20 {
            busy.increment_load();
        }

        let workers: Vec<Box<dyn Worker>> = vec![Box::new(busy), Box::new(idle)];
        policy.init_workers(&workers);

        // The idle worker must win every time, regardless of cache affinity
        for _ in 0..5 {
            let chosen = policy.select_worker(&workers, Some("test")).unwrap();
            assert_eq!(chosen, 1);
        }
    }

    #[test]
    fn test_cache_aware_worker_removal() {
        let policy = policy_without_eviction();
        let workers = regular_workers(&["http://w1:8000", "http://w2:8000"]);

        policy.init_workers(&workers);

        // Route some requests
        policy.select_worker(&workers, Some("test1"));
        policy.select_worker(&workers, Some("test2"));

        // Remove the first worker from the tree and mark it unhealthy
        policy.remove_worker("http://w1:8000");
        workers[0].set_healthy(false);

        // Only the second worker remains eligible
        let chosen = policy.select_worker(&workers, Some("test1")).unwrap();
        assert_eq!(chosen, 1);
    }
}
|
||||
94
sgl-router/src/policies/factory.rs
Normal file
94
sgl-router/src/policies/factory.rs
Normal file
@@ -0,0 +1,94 @@
|
||||
//! Factory for creating load balancing policies
|
||||
|
||||
use super::{
|
||||
CacheAwareConfig, CacheAwarePolicy, LoadBalancingPolicy, PowerOfTwoPolicy, RandomPolicy,
|
||||
RoundRobinPolicy,
|
||||
};
|
||||
use crate::config::PolicyConfig;
|
||||
use std::sync::Arc;
|
||||
|
||||
/// Factory for creating policy instances
|
||||
pub struct PolicyFactory;
|
||||
|
||||
impl PolicyFactory {
|
||||
/// Create a policy from configuration
|
||||
pub fn create_from_config(config: &PolicyConfig) -> Arc<dyn LoadBalancingPolicy> {
|
||||
match config {
|
||||
PolicyConfig::Random => Arc::new(RandomPolicy::new()),
|
||||
PolicyConfig::RoundRobin => Arc::new(RoundRobinPolicy::new()),
|
||||
PolicyConfig::PowerOfTwo { .. } => Arc::new(PowerOfTwoPolicy::new()),
|
||||
PolicyConfig::CacheAware {
|
||||
cache_threshold,
|
||||
balance_abs_threshold,
|
||||
balance_rel_threshold,
|
||||
eviction_interval_secs,
|
||||
max_tree_size,
|
||||
} => {
|
||||
let config = CacheAwareConfig {
|
||||
cache_threshold: *cache_threshold,
|
||||
balance_abs_threshold: *balance_abs_threshold,
|
||||
balance_rel_threshold: *balance_rel_threshold,
|
||||
eviction_interval_secs: *eviction_interval_secs,
|
||||
max_tree_size: *max_tree_size,
|
||||
};
|
||||
Arc::new(CacheAwarePolicy::with_config(config))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a policy by name (for dynamic loading)
|
||||
pub fn create_by_name(name: &str) -> Option<Arc<dyn LoadBalancingPolicy>> {
|
||||
match name.to_lowercase().as_str() {
|
||||
"random" => Some(Arc::new(RandomPolicy::new())),
|
||||
"round_robin" | "roundrobin" => Some(Arc::new(RoundRobinPolicy::new())),
|
||||
"power_of_two" | "poweroftwo" => Some(Arc::new(PowerOfTwoPolicy::new())),
|
||||
"cache_aware" | "cacheaware" => Some(Arc::new(CacheAwarePolicy::new())),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_create_from_config() {
        // Each configuration variant paired with the policy name it must produce
        let cases = vec![
            (PolicyConfig::Random, "random"),
            (PolicyConfig::RoundRobin, "round_robin"),
            (
                PolicyConfig::PowerOfTwo {
                    load_check_interval_secs: 60,
                },
                "power_of_two",
            ),
            (
                PolicyConfig::CacheAware {
                    cache_threshold: 0.7,
                    balance_abs_threshold: 10,
                    balance_rel_threshold: 1.5,
                    eviction_interval_secs: 30,
                    max_tree_size: 1000,
                },
                "cache_aware",
            ),
        ];

        for (config, expected_name) in cases {
            let policy = PolicyFactory::create_from_config(&config);
            assert_eq!(policy.name(), expected_name);
        }
    }

    #[test]
    fn test_create_by_name() {
        // Recognised names, in assorted casing and with/without underscores
        let known = [
            "random",
            "RANDOM",
            "round_robin",
            "RoundRobin",
            "power_of_two",
            "PowerOfTwo",
            "cache_aware",
            "CacheAware",
        ];
        for name in known.iter() {
            assert!(PolicyFactory::create_by_name(name).is_some());
        }

        assert!(PolicyFactory::create_by_name("unknown").is_none());
    }
}
|
||||
143
sgl-router/src/policies/mod.rs
Normal file
143
sgl-router/src/policies/mod.rs
Normal file
@@ -0,0 +1,143 @@
|
||||
//! Load balancing policies for SGLang router
|
||||
//!
|
||||
//! This module provides a unified abstraction for routing policies that work
|
||||
//! across both regular and prefill-decode (PD) routing modes.
|
||||
|
||||
use crate::core::Worker;
|
||||
use std::fmt::Debug;
|
||||
|
||||
mod cache_aware;
|
||||
mod factory;
|
||||
mod power_of_two;
|
||||
mod random;
|
||||
mod round_robin;
|
||||
|
||||
pub use cache_aware::CacheAwarePolicy;
|
||||
pub use factory::PolicyFactory;
|
||||
pub use power_of_two::PowerOfTwoPolicy;
|
||||
pub use random::RandomPolicy;
|
||||
pub use round_robin::RoundRobinPolicy;
|
||||
|
||||
/// Core trait for load balancing policies
|
||||
///
|
||||
/// This trait provides a unified interface for implementing routing algorithms
|
||||
/// that can work with both regular single-worker selection and PD dual-worker selection.
|
||||
pub trait LoadBalancingPolicy: Send + Sync + Debug {
|
||||
/// Select a single worker from the available workers
|
||||
///
|
||||
/// This is used for regular routing mode where requests go to a single worker.
|
||||
fn select_worker(
|
||||
&self,
|
||||
workers: &[Box<dyn Worker>],
|
||||
request_text: Option<&str>,
|
||||
) -> Option<usize>;
|
||||
|
||||
/// Select a pair of workers (prefill and decode) for PD routing
|
||||
///
|
||||
/// Returns indices of (prefill_worker, decode_worker) from their respective arrays.
|
||||
/// Default implementation uses select_worker for each array independently.
|
||||
fn select_worker_pair(
|
||||
&self,
|
||||
prefill_workers: &[Box<dyn Worker>],
|
||||
decode_workers: &[Box<dyn Worker>],
|
||||
request_text: Option<&str>,
|
||||
) -> Option<(usize, usize)> {
|
||||
// Default implementation: independently select from each pool
|
||||
let prefill_idx = self.select_worker(prefill_workers, request_text)?;
|
||||
let decode_idx = self.select_worker(decode_workers, request_text)?;
|
||||
Some((prefill_idx, decode_idx))
|
||||
}
|
||||
|
||||
/// Update policy state after request completion
|
||||
///
|
||||
/// This is called when a request completes (successfully or not) to allow
|
||||
/// policies to update their internal state.
|
||||
fn on_request_complete(&self, _worker_url: &str, _success: bool) {
|
||||
// Default: no-op for stateless policies
|
||||
}
|
||||
|
||||
/// Get policy name for metrics and debugging
|
||||
fn name(&self) -> &'static str;
|
||||
|
||||
/// Update worker load information
|
||||
///
|
||||
/// This is called periodically with current load information for load-aware policies.
|
||||
fn update_loads(&self, _loads: &std::collections::HashMap<String, isize>) {
|
||||
// Default: no-op for policies that don't use load information
|
||||
}
|
||||
|
||||
/// Reset any internal state
|
||||
///
|
||||
/// This is useful for policies that maintain state (e.g., round-robin counters).
|
||||
fn reset(&self) {
|
||||
// Default: no-op for stateless policies
|
||||
}
|
||||
|
||||
/// Get as Any for downcasting
|
||||
fn as_any(&self) -> &dyn std::any::Any;
|
||||
}
|
||||
|
||||
/// Configuration for cache-aware policy
#[derive(Debug, Clone)]
pub struct CacheAwareConfig {
    /// Prefix match ratio a request must exceed to be routed to the
    /// best-matching worker.
    pub cache_threshold: f32,
    /// Absolute `max_load - min_load` gap that must be exceeded for the
    /// system to count as imbalanced.
    pub balance_abs_threshold: usize,
    /// Ratio of `max_load` to `min_load` that must also be exceeded for the
    /// system to count as imbalanced.
    pub balance_rel_threshold: f32,
    /// Seconds between background eviction cycles; `0` disables the
    /// eviction thread.
    pub eviction_interval_secs: u64,
    /// Size limit handed to the tree's eviction routine.
    pub max_tree_size: usize,
}

impl Default for CacheAwareConfig {
    /// Returns the stock settings: threshold 0.5, imbalance at a gap above
    /// 32 and a ratio above 1.1, eviction every 30 s, tree size 10000.
    fn default() -> Self {
        let cache_threshold = 0.5;
        let balance_abs_threshold = 32;
        let balance_rel_threshold = 1.1;
        let eviction_interval_secs = 30;
        let max_tree_size = 10000;

        CacheAwareConfig {
            cache_threshold,
            balance_abs_threshold,
            balance_rel_threshold,
            eviction_interval_secs,
            max_tree_size,
        }
    }
}
|
||||
|
||||
/// Helper function to filter healthy workers and return their indices
|
||||
pub(crate) fn get_healthy_worker_indices(workers: &[Box<dyn Worker>]) -> Vec<usize> {
|
||||
workers
|
||||
.iter()
|
||||
.enumerate()
|
||||
.filter(|(_, w)| w.is_healthy())
|
||||
.map(|(idx, _)| idx)
|
||||
.collect()
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::{BasicWorker, WorkerType};

    #[test]
    fn test_get_healthy_worker_indices() {
        let urls = ["http://w1:8000", "http://w2:8000", "http://w3:8000"];
        let workers: Vec<Box<dyn Worker>> = urls
            .iter()
            .map(|url| {
                Box::new(BasicWorker::new(url.to_string(), WorkerType::Regular))
                    as Box<dyn Worker>
            })
            .collect();

        // Freshly created workers are all reported
        assert_eq!(get_healthy_worker_indices(&workers), vec![0, 1, 2]);

        // Marking the middle worker unhealthy drops exactly its index
        workers[1].set_healthy(false);
        assert_eq!(get_healthy_worker_indices(&workers), vec![0, 2]);
    }
}
|
||||
201
sgl-router/src/policies/power_of_two.rs
Normal file
201
sgl-router/src/policies/power_of_two.rs
Normal file
@@ -0,0 +1,201 @@
|
||||
//! Power-of-two choices load balancing policy
|
||||
|
||||
use super::{get_healthy_worker_indices, LoadBalancingPolicy};
|
||||
use crate::core::Worker;
|
||||
use metrics::counter;
|
||||
use rand::Rng;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::RwLock;
|
||||
use tracing::info;
|
||||
|
||||
/// Power-of-two choices policy
|
||||
///
|
||||
/// Randomly selects two workers and routes to the one with lower load.
|
||||
/// This provides good load distribution with minimal coordination overhead.
|
||||
#[derive(Debug)]
|
||||
pub struct PowerOfTwoPolicy {
|
||||
/// Cached load information from external monitoring
|
||||
cached_loads: RwLock<HashMap<String, isize>>,
|
||||
}
|
||||
|
||||
impl PowerOfTwoPolicy {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
cached_loads: RwLock::new(HashMap::new()),
|
||||
}
|
||||
}
|
||||
|
||||
fn get_worker_load(&self, worker: &dyn Worker) -> isize {
|
||||
// First check cached loads (from external monitoring)
|
||||
if let Ok(loads) = self.cached_loads.read() {
|
||||
if let Some(&load) = loads.get(worker.url()) {
|
||||
return load;
|
||||
}
|
||||
}
|
||||
|
||||
// Fall back to local load counter
|
||||
worker.load() as isize
|
||||
}
|
||||
}
|
||||
|
||||
impl LoadBalancingPolicy for PowerOfTwoPolicy {
|
||||
fn select_worker(
|
||||
&self,
|
||||
workers: &[Box<dyn Worker>],
|
||||
_request_text: Option<&str>,
|
||||
) -> Option<usize> {
|
||||
let healthy_indices = get_healthy_worker_indices(workers);
|
||||
|
||||
if healthy_indices.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
if healthy_indices.len() == 1 {
|
||||
return Some(healthy_indices[0]);
|
||||
}
|
||||
|
||||
// Select two random workers
|
||||
let mut rng = rand::thread_rng();
|
||||
let idx1 = rng.gen_range(0..healthy_indices.len());
|
||||
let mut idx2 = rng.gen_range(0..healthy_indices.len());
|
||||
|
||||
// Ensure we pick two different workers
|
||||
while idx2 == idx1 {
|
||||
idx2 = rng.gen_range(0..healthy_indices.len());
|
||||
}
|
||||
|
||||
let worker_idx1 = healthy_indices[idx1];
|
||||
let worker_idx2 = healthy_indices[idx2];
|
||||
|
||||
// Compare loads and select the less loaded one
|
||||
let load1 = self.get_worker_load(workers[worker_idx1].as_ref());
|
||||
let load2 = self.get_worker_load(workers[worker_idx2].as_ref());
|
||||
|
||||
// Log selection for debugging
|
||||
let selected_idx = if load1 <= load2 {
|
||||
worker_idx1
|
||||
} else {
|
||||
worker_idx2
|
||||
};
|
||||
|
||||
info!(
|
||||
"Power-of-two selection: {}={} vs {}={} -> selected {}",
|
||||
workers[worker_idx1].url(),
|
||||
load1,
|
||||
workers[worker_idx2].url(),
|
||||
load2,
|
||||
workers[selected_idx].url()
|
||||
);
|
||||
|
||||
// Increment processed counter
|
||||
workers[selected_idx].increment_processed();
|
||||
counter!("sgl_router_processed_requests_total", "worker" => workers[selected_idx].url().to_string())
|
||||
.increment(1);
|
||||
|
||||
Some(selected_idx)
|
||||
}
|
||||
|
||||
fn name(&self) -> &'static str {
|
||||
"power_of_two"
|
||||
}
|
||||
|
||||
fn update_loads(&self, loads: &HashMap<String, isize>) {
|
||||
if let Ok(mut cached) = self.cached_loads.write() {
|
||||
*cached = loads.clone();
|
||||
}
|
||||
}
|
||||
|
||||
fn as_any(&self) -> &dyn std::any::Any {
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for PowerOfTwoPolicy {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::{BasicWorker, WorkerType};

    /// Creates an unboxed regular worker for `url`.
    fn regular(url: &str) -> BasicWorker {
        BasicWorker::new(url.to_string(), WorkerType::Regular)
    }

    #[test]
    fn test_power_of_two_selection() {
        let policy = PowerOfTwoPolicy::new();
        let heavy = regular("http://w1:8000");
        let medium = regular("http://w2:8000");
        let idle = regular("http://w3:8000");

        // Loads: 10, 5 and 0 respectively
        for _ in 0..10 {
            heavy.increment_load();
        }
        for _ in 0..5 {
            medium.increment_load();
        }

        let workers: Vec<Box<dyn Worker>> =
            vec![Box::new(heavy), Box::new(medium), Box::new(idle)];

        // Tally which worker wins each of 100 selections
        let mut hits = [0usize; 3];
        for _ in 0..100 {
            if let Some(idx) = policy.select_worker(&workers, None) {
                hits[idx] += 1;
            }
        }

        // The lower the load, the more often the worker must have been picked
        assert!(hits[2] > hits[1]);
        assert!(hits[1] > hits[0]);
    }

    #[test]
    fn test_power_of_two_with_cached_loads() {
        let policy = PowerOfTwoPolicy::new();
        let workers: Vec<Box<dyn Worker>> = vec![
            Box::new(regular("http://w1:8000")),
            Box::new(regular("http://w2:8000")),
        ];

        // Report the first worker as ten times busier than the second
        let loads: HashMap<String, isize> = [("http://w1:8000", 100), ("http://w2:8000", 10)]
            .iter()
            .map(|&(url, load)| (url.to_string(), load))
            .collect();
        policy.update_loads(&loads);

        // Count how often the second worker wins out of 50 selections
        let w2_selected = (0..50)
            .filter(|_| policy.select_worker(&workers, None) == Some(1))
            .count();

        // Worker2 should be selected significantly more often
        assert!(w2_selected > 35);
    }

    #[test]
    fn test_power_of_two_single_worker() {
        let policy = PowerOfTwoPolicy::new();
        let workers: Vec<Box<dyn Worker>> = vec![Box::new(regular("http://w1:8000"))];

        // The only worker is the only possible answer
        assert_eq!(policy.select_worker(&workers, None), Some(0));
    }
}
|
||||
116
sgl-router/src/policies/random.rs
Normal file
116
sgl-router/src/policies/random.rs
Normal file
@@ -0,0 +1,116 @@
|
||||
//! Random load balancing policy
|
||||
|
||||
use super::{get_healthy_worker_indices, LoadBalancingPolicy};
|
||||
use crate::core::Worker;
|
||||
use rand::Rng;
|
||||
|
||||
/// Random selection policy
///
/// Picks uniformly at random among the healthy workers.
#[derive(Debug, Default)]
pub struct RandomPolicy;

impl RandomPolicy {
    /// Creates the (stateless) policy.
    pub fn new() -> Self {
        RandomPolicy
    }
}
|
||||
|
||||
impl LoadBalancingPolicy for RandomPolicy {
|
||||
fn select_worker(
|
||||
&self,
|
||||
workers: &[Box<dyn Worker>],
|
||||
_request_text: Option<&str>,
|
||||
) -> Option<usize> {
|
||||
let healthy_indices = get_healthy_worker_indices(workers);
|
||||
|
||||
if healthy_indices.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let mut rng = rand::thread_rng();
|
||||
let random_idx = rng.gen_range(0..healthy_indices.len());
|
||||
Some(healthy_indices[random_idx])
|
||||
}
|
||||
|
||||
fn name(&self) -> &'static str {
|
||||
"random"
|
||||
}
|
||||
|
||||
fn as_any(&self) -> &dyn std::any::Any {
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::{BasicWorker, WorkerType};
    use std::collections::HashMap;

    /// Boxes one regular worker per URL, preserving order.
    fn regular_workers(urls: &[&str]) -> Vec<Box<dyn Worker>> {
        urls.iter()
            .map(|url| {
                Box::new(BasicWorker::new(url.to_string(), WorkerType::Regular))
                    as Box<dyn Worker>
            })
            .collect()
    }

    #[test]
    fn test_random_selection() {
        let policy = RandomPolicy::new();
        let workers = regular_workers(&["http://w1:8000", "http://w2:8000", "http://w3:8000"]);

        // Tally 100 selections per returned index
        let mut tally: HashMap<usize, i32> = HashMap::new();
        for _ in 0..100 {
            if let Some(idx) = policy.select_worker(&workers, None) {
                *tally.entry(idx).or_insert(0) += 1;
            }
        }

        // Every worker must have been chosen at least once
        assert_eq!(tally.len(), 3);
        assert!(tally.values().all(|&hits| hits > 0));
    }

    #[test]
    fn test_random_with_unhealthy_workers() {
        let policy = RandomPolicy::new();
        let workers = regular_workers(&["http://w1:8000", "http://w2:8000"]);

        // Take the first worker out of rotation
        workers[0].set_healthy(false);

        // Only index 1 remains eligible
        for _ in 0..10 {
            assert_eq!(policy.select_worker(&workers, None), Some(1));
        }
    }

    #[test]
    fn test_random_no_healthy_workers() {
        let policy = RandomPolicy::new();
        let workers = regular_workers(&["http://w1:8000"]);

        workers[0].set_healthy(false);
        assert_eq!(policy.select_worker(&workers, None), None);
    }
}
|
||||
136
sgl-router/src/policies/round_robin.rs
Normal file
136
sgl-router/src/policies/round_robin.rs
Normal file
@@ -0,0 +1,136 @@
|
||||
//! Round-robin load balancing policy
|
||||
|
||||
use super::{get_healthy_worker_indices, LoadBalancingPolicy};
|
||||
use crate::core::Worker;
|
||||
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||
|
||||
/// Load-balancing policy that hands out workers in rotation.
///
/// Each selection advances an internal counter; the counter value modulo the
/// number of currently healthy workers picks the next worker.
#[derive(Debug, Default)]
pub struct RoundRobinPolicy {
    /// Number of selections made so far; incremented atomically on each
    /// successful selection.
    counter: AtomicUsize,
}
|
||||
|
||||
impl RoundRobinPolicy {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
counter: AtomicUsize::new(0),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl LoadBalancingPolicy for RoundRobinPolicy {
|
||||
fn select_worker(
|
||||
&self,
|
||||
workers: &[Box<dyn Worker>],
|
||||
_request_text: Option<&str>,
|
||||
) -> Option<usize> {
|
||||
let healthy_indices = get_healthy_worker_indices(workers);
|
||||
|
||||
if healthy_indices.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
// Get and increment counter atomically
|
||||
let count = self.counter.fetch_add(1, Ordering::Relaxed);
|
||||
let selected_idx = count % healthy_indices.len();
|
||||
|
||||
Some(healthy_indices[selected_idx])
|
||||
}
|
||||
|
||||
fn name(&self) -> &'static str {
|
||||
"round_robin"
|
||||
}
|
||||
|
||||
fn reset(&self) {
|
||||
self.counter.store(0, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
fn as_any(&self) -> &dyn std::any::Any {
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::{BasicWorker, WorkerType};

    /// Builds one `WorkerType::Regular` worker per URL, in the given order.
    fn regular_workers(urls: &[&str]) -> Vec<Box<dyn Worker>> {
        urls.iter()
            .map(|&url| {
                Box::new(BasicWorker::new(url.to_string(), WorkerType::Regular))
                    as Box<dyn Worker>
            })
            .collect()
    }

    /// Successive selections walk the workers in order and wrap around.
    #[test]
    fn test_round_robin_selection() {
        let policy = RoundRobinPolicy::new();
        let workers = regular_workers(&["http://w1:8000", "http://w2:8000", "http://w3:8000"]);

        // Expected rotation: 0, 1, 2, then back to 0, 1.
        for expected in [0, 1, 2, 0, 1] {
            assert_eq!(policy.select_worker(&workers, None), Some(expected));
        }
    }

    /// An unhealthy worker is left out of the rotation.
    #[test]
    fn test_round_robin_with_unhealthy_workers() {
        let policy = RoundRobinPolicy::new();
        let workers = regular_workers(&["http://w1:8000", "http://w2:8000", "http://w3:8000"]);

        // Take the middle worker out of rotation.
        workers[1].set_healthy(false);

        // Only indices 0 and 2 remain, alternating.
        for expected in [0, 2, 0, 2] {
            assert_eq!(policy.select_worker(&workers, None), Some(expected));
        }
    }

    /// `reset` restarts the rotation from the first worker.
    #[test]
    fn test_round_robin_reset() {
        let policy = RoundRobinPolicy::new();
        let workers = regular_workers(&["http://w1:8000", "http://w2:8000"]);

        // Move the counter forward by two selections.
        for expected in [0, 1] {
            assert_eq!(policy.select_worker(&workers, None), Some(expected));
        }

        // After a reset the next pick is index 0 again.
        policy.reset();
        assert_eq!(policy.select_worker(&workers, None), Some(0));
    }
}
|
||||
Reference in New Issue
Block a user