[router] cache-aware load-balancing router v1 (#2114)

This commit is contained in:
Byron Hsu
2024-11-23 08:34:48 -08:00
committed by GitHub
parent ad47749b82
commit cbedd1db1d
17 changed files with 1963 additions and 602 deletions

View File

@@ -1,7 +1,6 @@
// Python Binding
use pyo3::prelude::*;
pub mod router;
mod server;
pub mod server;
pub mod tree;
#[pyclass(eq)]
@@ -9,7 +8,7 @@ pub mod tree;
pub enum PolicyType {
Random,
RoundRobin,
ApproxTree,
CacheAware,
}
#[pyclass]
@@ -18,8 +17,10 @@ struct Router {
port: u16,
worker_urls: Vec<String>,
policy: PolicyType,
tokenizer_path: Option<String>,
cache_threshold: Option<f32>,
cache_threshold: f32,
cache_routing_prob: f32,
eviction_interval_secs: u64,
max_tree_size: usize,
}
#[pymethods]
@@ -30,33 +31,30 @@ impl Router {
policy = PolicyType::RoundRobin,
host = String::from("127.0.0.1"),
port = 3001,
tokenizer_path = None,
cache_threshold = Some(0.50)
cache_threshold = 0.50,
cache_routing_prob = 1.0,
eviction_interval_secs = 60,
max_tree_size = 2usize.pow(24)
))]
fn new(
worker_urls: Vec<String>,
policy: PolicyType,
host: String,
port: u16,
tokenizer_path: Option<String>,
cache_threshold: Option<f32>,
cache_threshold: f32,
cache_routing_prob: f32,
eviction_interval_secs: u64,
max_tree_size: usize,
) -> PyResult<Self> {
// Validate required parameters for approx_tree policy
if matches!(policy, PolicyType::ApproxTree) {
if tokenizer_path.is_none() {
return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
"tokenizer_path is required for approx_tree policy",
));
}
}
Ok(Router {
host,
port,
worker_urls,
policy,
tokenizer_path,
cache_threshold,
cache_routing_prob,
eviction_interval_secs,
max_tree_size,
})
}
@@ -68,14 +66,11 @@ impl Router {
let policy_config = match &self.policy {
PolicyType::Random => router::PolicyConfig::RandomConfig,
PolicyType::RoundRobin => router::PolicyConfig::RoundRobinConfig,
PolicyType::ApproxTree => router::PolicyConfig::ApproxTreeConfig {
tokenizer_path: self
.tokenizer_path
.clone()
.expect("tokenizer_path is required for approx_tree policy"),
cache_threshold: self
.cache_threshold
.expect("cache_threshold is required for approx_tree policy"),
PolicyType::CacheAware => router::PolicyConfig::CacheAwareConfig {
cache_threshold: self.cache_threshold,
cache_routing_prob: self.cache_routing_prob,
eviction_interval_secs: self.eviction_interval_secs,
max_tree_size: self.max_tree_size,
},
};

View File

@@ -1,18 +1,14 @@
// src/main.rs
use clap::Parser;
use clap::ValueEnum;
// declare child modules
mod router;
mod server;
mod tree;
use crate::router::PolicyConfig;
use sglang_router_rs::{router::PolicyConfig, server};
#[derive(Debug, Clone, ValueEnum)]
pub enum PolicyType {
Random,
RoundRobin,
ApproxTree,
CacheAware,
}
#[derive(Parser, Debug)]
@@ -21,44 +17,70 @@ struct Args {
#[arg(
long,
default_value = "127.0.0.1",
help = "Host address to bind the server to"
help = "Host address to bind the router server to. Default: 127.0.0.1"
)]
host: String,
#[arg(long, default_value_t = 3001, help = "Port number to listen on")]
#[arg(
long,
default_value_t = 3001,
help = "Port number to bind the router server to. Default: 3001"
)]
port: u16,
#[arg(
long,
value_delimiter = ',',
help = "Comma-separated list of worker URLs to distribute requests to"
help = "Comma-separated list of worker URLs that will handle the requests. Each URL should include the protocol, host, and port (e.g., http://worker1:8000,http://worker2:8000)"
)]
worker_urls: Vec<String>,
#[arg(
long,
default_value_t = PolicyType::RoundRobin,
default_value_t = PolicyType::CacheAware,
value_enum,
help = "Load balancing policy to use: random, round_robin, or approx_tree"
help = "Load balancing policy to use for request distribution:\n\
- random: Randomly select workers\n\
- round_robin: Distribute requests in round-robin fashion\n\
- cache_aware: Distribute requests in cache-aware fashion\n"
)]
policy: PolicyType,
#[arg(
long,
default_value_t = 0.5,
requires = "policy",
required_if_eq("policy", "approx_tree"),
help = "Path to the tokenizer file, required when using approx_tree policy"
required_if_eq("policy", "cache_aware"),
help = "Cache threshold (0.0-1.0) for cache-aware routing. Routes to cached worker if the match rate exceeds threshold, otherwise routes to the worker with the smallest tree. Default: 0.5"
)]
tokenizer_path: Option<String>,
cache_threshold: f32,
#[arg(
long,
default_value = "0.50",
default_value_t = 1.0,
requires = "policy",
required_if_eq("policy", "approx_tree"),
help = "Cache threshold (0.0-1.0) for approx_tree routing. Routes to cached worker if match rate exceeds threshold, otherwise routes to shortest queue worker"
required_if_eq("policy", "cache_aware"),
help = "Probability of using cache-aware routing (0.0-1.0). Default 1.0 for full cache-aware routing, suitable for perfectly divided prefix workloads. For uneven workloads, use a lower value to better distribute requests"
)]
cache_threshold: Option<f32>,
cache_routing_prob: f32,
#[arg(
long,
default_value_t = 60,
requires = "policy",
required_if_eq("policy", "cache_aware"),
help = "Interval in seconds between cache eviction operations in cache-aware routing. Default: 60"
)]
eviction_interval_secs: u64,
#[arg(
long,
default_value_t = 2usize.pow(24),
requires = "policy",
required_if_eq("policy", "cache_aware"),
help = "Maximum size of the approximation tree for cache-aware routing. Default: 2^24"
)]
max_tree_size: usize,
}
impl Args {
@@ -66,14 +88,11 @@ impl Args {
match self.policy {
PolicyType::Random => PolicyConfig::RandomConfig,
PolicyType::RoundRobin => PolicyConfig::RoundRobinConfig,
PolicyType::ApproxTree => PolicyConfig::ApproxTreeConfig {
tokenizer_path: self
.tokenizer_path
.clone()
.expect("tokenizer_path is required for approx_tree policy"),
cache_threshold: self
.cache_threshold
.expect("cache_threshold is required for approx_tree policy"),
PolicyType::CacheAware => PolicyConfig::CacheAwareConfig {
cache_threshold: self.cache_threshold,
cache_routing_prob: self.cache_routing_prob,
eviction_interval_secs: self.eviction_interval_secs,
max_tree_size: self.max_tree_size,
},
}
}

View File

@@ -1,13 +1,16 @@
use crate::tree::RadixTree;
use crate::tree::Tree;
use actix_web::http::header::{HeaderValue, CONTENT_TYPE};
use actix_web::{HttpRequest, HttpResponse};
use bytes::Bytes;
use futures_util::TryStreamExt;
use futures_util::{Stream, StreamExt, TryStreamExt};
use std::collections::HashMap;
use std::fmt::Debug;
use std::hash::Hash;
use std::pin::Pin;
use std::sync::atomic::AtomicUsize;
use std::sync::{Arc, Mutex};
use tokenizers::tokenizer::Tokenizer;
use std::thread;
use std::time::Duration;
#[derive(Debug)]
pub enum Router {
@@ -18,34 +21,88 @@ pub enum Router {
Random {
worker_urls: Vec<String>,
},
ApproxTree {
CacheAware {
/*
Cache-Aware Load Balancing Router
This router combines two strategies to optimize both cache utilization and request distribution:
1. Cache-Aware Routing (Approximate Tree)
2. Load Balancing (Shortest Queue)
For each incoming request, the router chooses between these strategies:
- With probability P: Uses cache-aware routing
- With probability (1-P): Uses load balancing
where P is configured via `cache_routing_prob`
Strategy Details:
1. Cache-Aware Routing (Approximate Tree)
-------------------------------------------
This strategy maintains an approximate radix tree for each worker based on request history,
eliminating the need for direct cache state queries. The tree stores raw text characters
instead of token IDs to avoid tokenization overhead.
Process:
a. For each request, find the worker with the highest prefix match
b. If match rate > cache_threshold:
Route to the worker with highest match (likely has relevant data cached)
c. If match rate ≤ cache_threshold:
Route to the worker with smallest tree size (most available cache capacity)
d. Background maintenance:
Periodically evict least recently used leaf nodes to prevent memory overflow
2. Load Balancing (Shortest Queue)
-------------------------------------------
This strategy tracks pending request counts per worker and routes new requests
to the least busy worker for optimal load distribution.
Configuration Parameters:
------------------------
1. cache_routing_prob: (float, 0.0 to 1.0)
- 0.0: Exclusively use load balancing
- 1.0: Exclusively use cache-aware routing
- Between 0-1: Probability of using cache-aware routing vs load balancing
2. cache_threshold: (float, 0.0 to 1.0)
Minimum prefix match ratio to use highest-match routing.
Below this threshold, routes to worker with most available cache space.
3. eviction_interval_secs: (integer)
Interval between LRU eviction cycles for the approximate trees.
4. max_tree_size: (integer)
Maximum nodes per tree. When exceeded, LRU leaf nodes are evicted
during the next eviction cycle.
*/
worker_urls: Vec<String>,
// TODO: don't lock the whole tree
url_to_tree: Arc<Mutex<HashMap<String, RadixTree>>>,
tokenizer: Tokenizer,
url_to_count: Arc<Mutex<HashMap<String, usize>>>,
tree: Arc<Mutex<Tree>>,
running_queue: Arc<Mutex<HashMap<String, usize>>>,
processed_queue: Arc<Mutex<HashMap<String, usize>>>,
cache_threshold: f32,
cache_routing_prob: f32,
_eviction_thread: Option<thread::JoinHandle<()>>, // Store thread handle
},
}
#[derive(Debug)]
pub enum PolicyConfig {
RandomConfig,
RoundRobinConfig,
ApproxTreeConfig {
tokenizer_path: String,
CacheAwareConfig {
cache_threshold: f32,
cache_routing_prob: f32,
eviction_interval_secs: u64,
max_tree_size: usize,
},
}
fn get_token_ids_from_request(body: &Bytes, tokenizer: &Tokenizer) -> Vec<u32> {
fn get_text_from_request(body: &Bytes) -> String {
// 1. convert body to json
let json = serde_json::from_slice::<serde_json::Value>(body).unwrap();
// 2. get the text field
let text = json.get("text").and_then(|t| t.as_str()).unwrap_or("");
// 3. tokenize the text field
let tokens = tokenizer.encode(text, false).unwrap();
tokens.get_ids().to_vec()
return text.to_string();
}
impl Router {
@@ -56,25 +113,56 @@ impl Router {
worker_urls,
current_index: std::sync::atomic::AtomicUsize::new(0),
},
PolicyConfig::ApproxTreeConfig {
tokenizer_path,
PolicyConfig::CacheAwareConfig {
cache_threshold,
cache_routing_prob,
eviction_interval_secs,
max_tree_size,
} => {
let mut url_to_tree = HashMap::new();
let mut url_to_count = HashMap::new();
let mut running_queue = HashMap::new();
for url in &worker_urls {
url_to_tree.insert(url.clone(), RadixTree::new());
url_to_count.insert(url.clone(), 0);
running_queue.insert(url.clone(), 0);
}
Router::ApproxTree {
let mut processed_queue = HashMap::new();
for url in &worker_urls {
processed_queue.insert(url.clone(), 0);
}
let tree = Arc::new(Mutex::new(Tree::new()));
let running_queue = Arc::new(Mutex::new(running_queue));
let processed_queue = Arc::new(Mutex::new(processed_queue));
// Create background eviction thread
let tree_clone = Arc::clone(&tree);
let processed_queue_clone = Arc::clone(&processed_queue);
let eviction_thread = thread::spawn(move || {
loop {
// Sleep for the specified interval
thread::sleep(Duration::from_secs(eviction_interval_secs));
let locked_tree_clone = tree_clone.lock().unwrap();
// Run eviction
locked_tree_clone.evict_tenant_data(max_tree_size);
// Print the process queue
let locked_processed_queue = processed_queue_clone.lock().unwrap();
println!("Processed Queue: {:?}", locked_processed_queue);
}
});
for url in &worker_urls {
tree.lock().unwrap().insert(&"".to_string(), url);
}
Router::CacheAware {
worker_urls,
url_to_tree: Arc::new(Mutex::new(url_to_tree)),
// TODO: rust ::from_pretrained cannot load from local file, so use ::from_file to load local file
tokenizer: Tokenizer::from_file(tokenizer_path).unwrap(),
url_to_count: Arc::new(Mutex::new(url_to_count)),
tree,
running_queue,
processed_queue,
cache_threshold,
cache_routing_prob,
_eviction_thread: Some(eviction_thread),
}
}
}
@@ -84,7 +172,7 @@ impl Router {
match self {
Router::RoundRobin { worker_urls, .. }
| Router::Random { worker_urls }
| Router::ApproxTree { worker_urls, .. } => {
| Router::CacheAware { worker_urls, .. } => {
if worker_urls.is_empty() {
None
} else {
@@ -100,10 +188,7 @@ impl Router {
req: HttpRequest,
body: Bytes,
) -> HttpResponse {
let mut input_ids: Vec<u32> = Vec::new();
if let Router::ApproxTree { tokenizer, .. } = self {
input_ids = get_token_ids_from_request(&body, tokenizer);
}
let text = get_text_from_request(&body);
let worker_url = match self {
Router::RoundRobin {
@@ -125,78 +210,73 @@ impl Router {
worker_urls[rand::random::<usize>() % worker_urls.len()].clone()
}
Router::ApproxTree {
Router::CacheAware {
worker_urls,
url_to_tree,
url_to_count,
tree,
running_queue,
processed_queue,
cache_threshold,
cache_routing_prob,
..
} => {
// TODO: pipeline the locks. Release one earlier.
// even though the tree is thread-safe, we still put a lock to ensure the whole op (tree read + queue read + tree write + queue write) is atomic to handle some edge cases (e.g. multiple requests with long prefix entering at the same time)
let mut max_matched_rate = 0.0;
let mut max_matched_idx = 0;
let mut tree = tree.lock().unwrap();
let mut running_queue = running_queue.lock().unwrap();
let locked_url_to_tree = url_to_tree.lock().unwrap();
// Generate a random float between 0 and 1 for probability check
let sampled_p: f32 = rand::random();
// 1. Find the highest matched worker
for (i, url) in worker_urls.iter().enumerate() {
let tree = locked_url_to_tree.get(url).unwrap();
let matched = tree.prefix_match(&input_ids[..]).len();
let matched_rate = matched as f32 / input_ids.len() as f32;
let selected_url = if sampled_p < *cache_routing_prob {
// Cache-aware routing logic
let (matched_text, matched_worker) = tree.prefix_match(&text);
let matched_rate =
matched_text.chars().count() as f32 / text.chars().count() as f32;
if matched_rate > max_matched_rate {
max_matched_rate = matched_rate;
max_matched_idx = i;
if matched_rate > *cache_threshold {
matched_worker.to_string()
} else {
let m_map: HashMap<String, usize> = tree
.tenant_char_count
.iter()
.map(|entry| (entry.key().clone(), *entry.value()))
.collect();
println!("map: {:?}, mmap: {:?}", tree.get_tenant_char_count(), m_map);
tree.get_smallest_tenant()
}
}
// 2. If the rate is higher than the threshold, select the worker. If not, select the worker with the shortest queue
if max_matched_rate > *cache_threshold {
worker_urls[max_matched_idx].clone()
} else {
// pick the shortest queue from url_to_count
let locked_url_to_count = url_to_count.lock().unwrap();
// Shortest queue routing logic
running_queue
.iter()
.min_by_key(|(_url, &count)| count)
.map(|(url, _)| url.clone())
.unwrap_or_else(|| worker_urls[0].clone())
};
let mut min_count = std::usize::MAX;
let mut min_count_id = 0;
// Update running queue
let count = running_queue.get_mut(&selected_url).unwrap();
*count += 1;
for (i, url) in worker_urls.iter().enumerate() {
let count = locked_url_to_count.get(url).unwrap();
if *count < min_count {
min_count = *count;
min_count_id = i;
}
}
// Update processed queue
let mut locked_processed_queue = processed_queue.lock().unwrap();
let count = locked_processed_queue.get_mut(&selected_url).unwrap();
*count += 1;
worker_urls[min_count_id].clone()
}
// Update tree with the new request
tree.insert(&text, &selected_url);
selected_url
}
};
if let Router::ApproxTree {
url_to_tree,
url_to_count,
..
} = self
{
// Insert input_ids to the tree
let mut locked_url_to_tree = url_to_tree.lock().unwrap();
let selected_tree = locked_url_to_tree.get_mut(&worker_url).unwrap();
selected_tree.insert(&input_ids[..]);
let mut locked_url_to_count = url_to_count.lock().unwrap();
let count = locked_url_to_count.get_mut(&worker_url).unwrap();
*count += 1;
}
// Check if client requested streaming
let is_stream = serde_json::from_slice::<serde_json::Value>(&body)
.map(|v| v.get("stream").and_then(|s| s.as_bool()).unwrap_or(false))
.unwrap_or(false);
let res = match client
.post(format!("{}/generate", worker_url))
.post(format!("{}/generate", worker_url.clone()))
.header(
"Content-Type",
req.headers()
@@ -216,23 +296,53 @@ impl Router {
.unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR);
if !is_stream {
// TODO: do the correction on the tree based on the cached input_ids
if let Router::ApproxTree { url_to_count, .. } = self {
let mut locked_url_to_count = url_to_count.lock().unwrap();
let count = locked_url_to_count.get_mut(&worker_url).unwrap();
*count -= 1;
}
match res.bytes().await {
// For non-streaming requests, get response first
let response = match res.bytes().await {
Ok(body) => HttpResponse::build(status).body(body.to_vec()),
Err(_) => HttpResponse::InternalServerError().finish(),
};
// Then decrement running queue counter if using CacheAware
if let Router::CacheAware { running_queue, .. } = self {
if let Ok(mut queue) = running_queue.lock() {
if let Some(count) = queue.get_mut(&worker_url) {
*count = count.saturating_sub(1);
}
}
}
response
} else if let Router::CacheAware { running_queue, .. } = self {
let running_queue = Arc::clone(running_queue);
let worker_url = worker_url.clone();
HttpResponse::build(status)
.insert_header((CONTENT_TYPE, HeaderValue::from_static("text/event-stream")))
.streaming(
res.bytes_stream()
.map_err(|_| {
actix_web::error::ErrorInternalServerError("Failed to read stream")
})
.inspect(move |bytes| {
let bytes = bytes.as_ref().unwrap();
if bytes
.as_ref()
.windows(12)
.any(|window| window == b"data: [DONE]")
{
let mut locked_queue = running_queue.lock().unwrap();
let count = locked_queue.get_mut(&worker_url).unwrap();
*count = count.saturating_sub(1);
// print
// println!("streaming is done!!")
}
}),
)
} else {
// TODO: do the correction on the tree based on the cached input_ids. The streaming might be trickier to handle
HttpResponse::build(status)
.insert_header((CONTENT_TYPE, HeaderValue::from_static("text/event-stream")))
.streaming(res.bytes_stream().map_err(|_| {
actix_web::error::ErrorInternalServerError("Failed to read string")
actix_web::error::ErrorInternalServerError("Failed to read stream")
}))
}
}

View File

@@ -76,6 +76,7 @@ pub async fn startup(
) -> std::io::Result<()> {
println!("Starting server on {}:{}", host, port);
println!("Worker URLs: {:?}", worker_urls);
println!("Policy Config: {:?}", policy_config);
// Create client once with configuration
let client = reqwest::Client::builder()

File diff suppressed because it is too large Load Diff