[rust] cache-aware DP - approx tree (#1934)
This commit is contained in:
@@ -1,37 +1,86 @@
|
||||
// Python Binding
|
||||
use pyo3::prelude::*;
|
||||
pub mod router;
|
||||
mod server;
|
||||
pub mod tree;
|
||||
|
||||
// Python binding
|
||||
#[pyclass(eq)]
|
||||
#[derive(Clone, PartialEq)]
|
||||
pub enum PolicyType {
|
||||
Random,
|
||||
RoundRobin,
|
||||
ApproxTree,
|
||||
}
|
||||
|
||||
#[pyclass]
|
||||
struct Router {
|
||||
host: String,
|
||||
port: u16,
|
||||
worker_urls: Vec<String>,
|
||||
policy: String,
|
||||
policy: PolicyType,
|
||||
tokenizer_path: Option<String>,
|
||||
cache_threshold: Option<f32>,
|
||||
}
|
||||
|
||||
#[pymethods]
|
||||
impl Router {
|
||||
#[new]
|
||||
fn new(host: String, port: u16, worker_urls: Vec<String>, policy: String) -> Self {
|
||||
Router {
|
||||
#[pyo3(signature = (
|
||||
worker_urls,
|
||||
policy = PolicyType::RoundRobin,
|
||||
host = String::from("127.0.0.1"),
|
||||
port = 3001,
|
||||
tokenizer_path = None,
|
||||
cache_threshold = Some(0.50)
|
||||
))]
|
||||
fn new(
|
||||
worker_urls: Vec<String>,
|
||||
policy: PolicyType,
|
||||
host: String,
|
||||
port: u16,
|
||||
tokenizer_path: Option<String>,
|
||||
cache_threshold: Option<f32>,
|
||||
) -> PyResult<Self> {
|
||||
// Validate required parameters for approx_tree policy
|
||||
if matches!(policy, PolicyType::ApproxTree) {
|
||||
if tokenizer_path.is_none() {
|
||||
return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
|
||||
"tokenizer_path is required for approx_tree policy",
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(Router {
|
||||
host,
|
||||
port,
|
||||
worker_urls,
|
||||
policy,
|
||||
}
|
||||
tokenizer_path,
|
||||
cache_threshold,
|
||||
})
|
||||
}
|
||||
|
||||
fn start(&self) -> PyResult<()> {
|
||||
let host = self.host.clone();
|
||||
let port = self.port;
|
||||
let worker_urls = self.worker_urls.clone();
|
||||
let policy = self.policy.clone();
|
||||
|
||||
let policy_config = match &self.policy {
|
||||
PolicyType::Random => router::PolicyConfig::RandomConfig,
|
||||
PolicyType::RoundRobin => router::PolicyConfig::RoundRobinConfig,
|
||||
PolicyType::ApproxTree => router::PolicyConfig::ApproxTreeConfig {
|
||||
tokenizer_path: self
|
||||
.tokenizer_path
|
||||
.clone()
|
||||
.expect("tokenizer_path is required for approx_tree policy"),
|
||||
cache_threshold: self
|
||||
.cache_threshold
|
||||
.expect("cache_threshold is required for approx_tree policy"),
|
||||
},
|
||||
};
|
||||
|
||||
actix_web::rt::System::new().block_on(async move {
|
||||
server::startup(host, port, worker_urls, policy)
|
||||
server::startup(host, port, worker_urls, policy_config)
|
||||
.await
|
||||
.unwrap();
|
||||
});
|
||||
@@ -40,9 +89,9 @@ impl Router {
|
||||
}
|
||||
}
|
||||
|
||||
// python usage: `from sglang_router import Router`
|
||||
#[pymodule]
|
||||
fn sglang_router(m: &Bound<'_, PyModule>) -> PyResult<()> {
|
||||
m.add_class::<PolicyType>()?;
|
||||
m.add_class::<Router>()?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -1,29 +1,87 @@
|
||||
// src/main.rs
|
||||
use clap::builder::PossibleValuesParser;
|
||||
use clap::Parser;
|
||||
use clap::ValueEnum;
|
||||
// declare child modules
|
||||
mod router;
|
||||
mod server;
|
||||
mod tree;
|
||||
|
||||
use crate::router::PolicyConfig;
|
||||
|
||||
#[derive(Debug, Clone, ValueEnum)]
|
||||
pub enum PolicyType {
|
||||
Random,
|
||||
RoundRobin,
|
||||
ApproxTree,
|
||||
}
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
#[command(author, version, about, long_about = None)]
|
||||
struct Args {
|
||||
#[arg(long, default_value = "127.0.0.1")]
|
||||
#[arg(
|
||||
long,
|
||||
default_value = "127.0.0.1",
|
||||
help = "Host address to bind the server to"
|
||||
)]
|
||||
host: String,
|
||||
|
||||
#[arg(long, default_value_t = 3001)]
|
||||
#[arg(long, default_value_t = 3001, help = "Port number to listen on")]
|
||||
port: u16,
|
||||
|
||||
#[arg(long, value_delimiter = ',')]
|
||||
#[arg(
|
||||
long,
|
||||
value_delimiter = ',',
|
||||
help = "Comma-separated list of worker URLs to distribute requests to"
|
||||
)]
|
||||
worker_urls: Vec<String>,
|
||||
|
||||
#[arg(long, default_value = "round_robin", value_parser = PossibleValuesParser::new(&["round_robin", "random"]))]
|
||||
policy: String,
|
||||
#[arg(
|
||||
long,
|
||||
default_value_t = PolicyType::RoundRobin,
|
||||
value_enum,
|
||||
help = "Load balancing policy to use: random, round_robin, or approx_tree"
|
||||
)]
|
||||
policy: PolicyType,
|
||||
|
||||
#[arg(
|
||||
long,
|
||||
requires = "policy",
|
||||
required_if_eq("policy", "approx_tree"),
|
||||
help = "Path to the tokenizer file, required when using approx_tree policy"
|
||||
)]
|
||||
tokenizer_path: Option<String>,
|
||||
|
||||
#[arg(
|
||||
long,
|
||||
default_value = "0.50",
|
||||
requires = "policy",
|
||||
required_if_eq("policy", "approx_tree"),
|
||||
help = "Cache threshold (0.0-1.0) for approx_tree routing. Routes to cached worker if match rate exceeds threshold, otherwise routes to shortest queue worker"
|
||||
)]
|
||||
cache_threshold: Option<f32>,
|
||||
}
|
||||
|
||||
impl Args {
|
||||
fn get_policy_config(&self) -> PolicyConfig {
|
||||
match self.policy {
|
||||
PolicyType::Random => PolicyConfig::RandomConfig,
|
||||
PolicyType::RoundRobin => PolicyConfig::RoundRobinConfig,
|
||||
PolicyType::ApproxTree => PolicyConfig::ApproxTreeConfig {
|
||||
tokenizer_path: self
|
||||
.tokenizer_path
|
||||
.clone()
|
||||
.expect("tokenizer_path is required for approx_tree policy"),
|
||||
cache_threshold: self
|
||||
.cache_threshold
|
||||
.expect("cache_threshold is required for approx_tree policy"),
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[actix_web::main]
|
||||
async fn main() -> std::io::Result<()> {
|
||||
let args = Args::parse();
|
||||
server::startup(args.host, args.port, args.worker_urls, args.policy).await
|
||||
let policy_config = args.get_policy_config();
|
||||
server::startup(args.host, args.port, args.worker_urls, policy_config).await
|
||||
}
|
||||
|
||||
@@ -1,38 +1,90 @@
|
||||
use crate::tree::RadixTree;
|
||||
use actix_web::http::header::{HeaderValue, CONTENT_TYPE};
|
||||
use actix_web::{HttpRequest, HttpResponse};
|
||||
use bytes::Bytes;
|
||||
use futures_util::TryStreamExt;
|
||||
use std::collections::HashMap;
|
||||
use std::fmt::Debug;
|
||||
use std::sync::atomic::AtomicUsize;
|
||||
use std::sync::{Arc, Mutex};
|
||||
use tokenizers::tokenizer::Tokenizer;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum Router {
|
||||
RoundRobin {
|
||||
worker_urls: Vec<String>,
|
||||
current_index: std::sync::atomic::AtomicUsize,
|
||||
current_index: AtomicUsize,
|
||||
},
|
||||
Random {
|
||||
worker_urls: Vec<String>,
|
||||
},
|
||||
ApproxTree {
|
||||
worker_urls: Vec<String>,
|
||||
// TODO: don't lock the whole tree
|
||||
url_to_tree: Arc<Mutex<HashMap<String, RadixTree>>>,
|
||||
tokenizer: Tokenizer,
|
||||
url_to_count: Arc<Mutex<HashMap<String, usize>>>,
|
||||
cache_threshold: f32,
|
||||
},
|
||||
}
|
||||
|
||||
pub enum PolicyConfig {
|
||||
RandomConfig,
|
||||
RoundRobinConfig,
|
||||
ApproxTreeConfig {
|
||||
tokenizer_path: String,
|
||||
cache_threshold: f32,
|
||||
},
|
||||
}
|
||||
|
||||
fn get_token_ids_from_request(body: &Bytes, tokenizer: &Tokenizer) -> Vec<u32> {
|
||||
// 1. convert body to json
|
||||
let json = serde_json::from_slice::<serde_json::Value>(body).unwrap();
|
||||
// 2. get the text field
|
||||
let text = json.get("text").and_then(|t| t.as_str()).unwrap_or("");
|
||||
// 3. tokenize the text field
|
||||
let tokens = tokenizer.encode(text, false).unwrap();
|
||||
|
||||
tokens.get_ids().to_vec()
|
||||
}
|
||||
|
||||
impl Router {
|
||||
pub fn new(worker_urls: Vec<String>, policy: String) -> Self {
|
||||
match policy.to_lowercase().as_str() {
|
||||
"random" => Router::Random { worker_urls },
|
||||
"round_robin" => Router::RoundRobin {
|
||||
pub fn new(worker_urls: Vec<String>, policy_config: PolicyConfig) -> Self {
|
||||
match policy_config {
|
||||
PolicyConfig::RandomConfig => Router::Random { worker_urls },
|
||||
PolicyConfig::RoundRobinConfig => Router::RoundRobin {
|
||||
worker_urls,
|
||||
current_index: std::sync::atomic::AtomicUsize::new(0),
|
||||
},
|
||||
_ => panic!(
|
||||
"Unknown routing policy: {}. The available policies are 'random' and 'round_robin'",
|
||||
policy
|
||||
),
|
||||
PolicyConfig::ApproxTreeConfig {
|
||||
tokenizer_path,
|
||||
cache_threshold,
|
||||
} => {
|
||||
let mut url_to_tree = HashMap::new();
|
||||
let mut url_to_count = HashMap::new();
|
||||
|
||||
for url in &worker_urls {
|
||||
url_to_tree.insert(url.clone(), RadixTree::new());
|
||||
url_to_count.insert(url.clone(), 0);
|
||||
}
|
||||
|
||||
Router::ApproxTree {
|
||||
worker_urls,
|
||||
url_to_tree: Arc::new(Mutex::new(url_to_tree)),
|
||||
// TODO: rust ::from_pretrained cannot load from local file, so use ::from_file to load local file
|
||||
tokenizer: Tokenizer::from_file(tokenizer_path).unwrap(),
|
||||
url_to_count: Arc::new(Mutex::new(url_to_count)),
|
||||
cache_threshold,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_first(&self) -> Option<String> {
|
||||
match self {
|
||||
Router::RoundRobin { worker_urls, .. } | Router::Random { worker_urls } => {
|
||||
Router::RoundRobin { worker_urls, .. }
|
||||
| Router::Random { worker_urls }
|
||||
| Router::ApproxTree { worker_urls, .. } => {
|
||||
if worker_urls.is_empty() {
|
||||
None
|
||||
} else {
|
||||
@@ -48,26 +100,96 @@ impl Router {
|
||||
req: HttpRequest,
|
||||
body: Bytes,
|
||||
) -> HttpResponse {
|
||||
let mut input_ids: Vec<u32> = Vec::new();
|
||||
if let Router::ApproxTree { tokenizer, .. } = self {
|
||||
input_ids = get_token_ids_from_request(&body, tokenizer);
|
||||
}
|
||||
|
||||
let worker_url = match self {
|
||||
Router::RoundRobin {
|
||||
worker_urls,
|
||||
current_index,
|
||||
} => {
|
||||
current_index
|
||||
let idx = current_index
|
||||
.fetch_update(
|
||||
std::sync::atomic::Ordering::SeqCst,
|
||||
std::sync::atomic::Ordering::SeqCst,
|
||||
|x| Some((x + 1) % worker_urls.len()),
|
||||
)
|
||||
.expect_err("Error updating index in round robin");
|
||||
.unwrap();
|
||||
|
||||
&worker_urls[current_index.load(std::sync::atomic::Ordering::SeqCst)]
|
||||
worker_urls[idx].clone()
|
||||
}
|
||||
|
||||
Router::Random { worker_urls } => {
|
||||
&worker_urls[rand::random::<usize>() % worker_urls.len()]
|
||||
worker_urls[rand::random::<usize>() % worker_urls.len()].clone()
|
||||
}
|
||||
|
||||
Router::ApproxTree {
|
||||
worker_urls,
|
||||
url_to_tree,
|
||||
url_to_count,
|
||||
cache_threshold,
|
||||
..
|
||||
} => {
|
||||
// TODO: pipeline the locks. Release one earlier.
|
||||
|
||||
let mut max_matched_rate = 0.0;
|
||||
let mut max_matched_idx = 0;
|
||||
|
||||
let locked_url_to_tree = url_to_tree.lock().unwrap();
|
||||
|
||||
// 1. Find the highest matched worker
|
||||
for (i, url) in worker_urls.iter().enumerate() {
|
||||
let tree = locked_url_to_tree.get(url).unwrap();
|
||||
let matched = tree.prefix_match(&input_ids[..]).len();
|
||||
let matched_rate = matched as f32 / input_ids.len() as f32;
|
||||
|
||||
if matched_rate > max_matched_rate {
|
||||
max_matched_rate = matched_rate;
|
||||
max_matched_idx = i;
|
||||
}
|
||||
}
|
||||
|
||||
// 2. If the rate is higher than the threshold, select the worker. If not, select the worker with the shortest queue
|
||||
if max_matched_rate > *cache_threshold {
|
||||
worker_urls[max_matched_idx].clone()
|
||||
} else {
|
||||
// pick the shortest queue from url_to_count
|
||||
let locked_url_to_count = url_to_count.lock().unwrap();
|
||||
|
||||
let mut min_count = std::usize::MAX;
|
||||
let mut min_count_id = 0;
|
||||
|
||||
for (i, url) in worker_urls.iter().enumerate() {
|
||||
let count = locked_url_to_count.get(url).unwrap();
|
||||
if *count < min_count {
|
||||
min_count = *count;
|
||||
min_count_id = i;
|
||||
}
|
||||
}
|
||||
|
||||
worker_urls[min_count_id].clone()
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
if let Router::ApproxTree {
|
||||
url_to_tree,
|
||||
url_to_count,
|
||||
..
|
||||
} = self
|
||||
{
|
||||
// Insert input_ids to the tree
|
||||
let mut locked_url_to_tree = url_to_tree.lock().unwrap();
|
||||
let selected_tree = locked_url_to_tree.get_mut(&worker_url).unwrap();
|
||||
selected_tree.insert(&input_ids[..]);
|
||||
|
||||
let mut locked_url_to_count = url_to_count.lock().unwrap();
|
||||
let count = locked_url_to_count.get_mut(&worker_url).unwrap();
|
||||
*count += 1;
|
||||
}
|
||||
|
||||
// Check if client requested streaming
|
||||
let is_stream = serde_json::from_slice::<serde_json::Value>(&body)
|
||||
.map(|v| v.get("stream").and_then(|s| s.as_bool()).unwrap_or(false))
|
||||
@@ -94,11 +216,19 @@ impl Router {
|
||||
.unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR);
|
||||
|
||||
if !is_stream {
|
||||
// TODO: do the correction on the tree based on the cached input_ids
|
||||
if let Router::ApproxTree { url_to_count, .. } = self {
|
||||
let mut locked_url_to_count = url_to_count.lock().unwrap();
|
||||
let count = locked_url_to_count.get_mut(&worker_url).unwrap();
|
||||
*count -= 1;
|
||||
}
|
||||
|
||||
match res.bytes().await {
|
||||
Ok(body) => HttpResponse::build(status).body(body.to_vec()),
|
||||
Err(_) => HttpResponse::InternalServerError().finish(),
|
||||
}
|
||||
} else {
|
||||
// TODO: do the correction on the tree based on the cached input_ids. The streaming might be tricker to handle
|
||||
HttpResponse::build(status)
|
||||
.insert_header((CONTENT_TYPE, HeaderValue::from_static("text/event-stream")))
|
||||
.streaming(res.bytes_stream().map_err(|_| {
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
use crate::router::PolicyConfig;
|
||||
use crate::router::Router;
|
||||
use actix_web::{get, post, web, App, HttpRequest, HttpResponse, HttpServer, Responder};
|
||||
use bytes::Bytes;
|
||||
@@ -9,9 +10,13 @@ pub struct AppState {
|
||||
}
|
||||
|
||||
impl AppState {
|
||||
pub fn new(worker_urls: Vec<String>, policy: String, client: reqwest::Client) -> Self {
|
||||
pub fn new(
|
||||
worker_urls: Vec<String>,
|
||||
client: reqwest::Client,
|
||||
policy_config: PolicyConfig,
|
||||
) -> Self {
|
||||
// Create router based on policy
|
||||
let router = Router::new(worker_urls, policy);
|
||||
let router = Router::new(worker_urls, policy_config);
|
||||
|
||||
Self { router, client }
|
||||
}
|
||||
@@ -40,7 +45,6 @@ async fn forward_request(
|
||||
|
||||
#[get("/v1/models")]
|
||||
async fn v1_model(data: web::Data<AppState>) -> impl Responder {
|
||||
// TODO: extract forward_to_route
|
||||
let worker_url = match data.router.get_first() {
|
||||
Some(url) => url,
|
||||
None => return HttpResponse::InternalServerError().finish(),
|
||||
@@ -59,7 +63,6 @@ async fn get_model_info(data: web::Data<AppState>) -> impl Responder {
|
||||
forward_request(&data.client, worker_url, "/get_model_info".to_string()).await
|
||||
}
|
||||
|
||||
// no deser and ser, just forward and return
|
||||
#[post("/generate")]
|
||||
async fn generate(req: HttpRequest, body: Bytes, data: web::Data<AppState>) -> impl Responder {
|
||||
data.router.dispatch(&data.client, req, body).await
|
||||
@@ -69,7 +72,7 @@ pub async fn startup(
|
||||
host: String,
|
||||
port: u16,
|
||||
worker_urls: Vec<String>,
|
||||
routing_policy: String,
|
||||
policy_config: PolicyConfig,
|
||||
) -> std::io::Result<()> {
|
||||
println!("Starting server on {}:{}", host, port);
|
||||
println!("Worker URLs: {:?}", worker_urls);
|
||||
@@ -80,7 +83,7 @@ pub async fn startup(
|
||||
.expect("Failed to create HTTP client");
|
||||
|
||||
// Store both worker_urls and client in AppState
|
||||
let app_state = web::Data::new(AppState::new(worker_urls, routing_policy, client));
|
||||
let app_state = web::Data::new(AppState::new(worker_urls, client, policy_config));
|
||||
|
||||
HttpServer::new(move || {
|
||||
App::new()
|
||||
|
||||
@@ -1,18 +1,19 @@
|
||||
use std::collections::HashMap;
|
||||
use std::mem;
|
||||
|
||||
#[derive(Clone)]
|
||||
#[derive(Debug)]
|
||||
pub struct Node {
|
||||
pub children: HashMap<usize, Node>, // the key is first id of the child because each child must have unique first id
|
||||
pub ids: Vec<usize>,
|
||||
pub count: usize,
|
||||
pub children: HashMap<u32, Node>, // the key is first id of the child because each child must have unique first id
|
||||
pub ids: Vec<u32>,
|
||||
pub count: u32,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct RadixTree {
|
||||
pub root: Node,
|
||||
}
|
||||
|
||||
fn common_prefix_len(a: &[usize], b: &[usize]) -> usize {
|
||||
fn common_prefix_len(a: &[u32], b: &[u32]) -> usize {
|
||||
let mut i = 0;
|
||||
while i < a.len() && i < b.len() && a[i] == b[i] {
|
||||
i += 1;
|
||||
@@ -37,7 +38,7 @@ impl RadixTree {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn insert(&mut self, input_ids: &[usize]) {
|
||||
pub fn insert(&mut self, input_ids: &[u32]) {
|
||||
let mut curr = &mut self.root;
|
||||
curr.count += 1;
|
||||
|
||||
@@ -93,7 +94,7 @@ impl RadixTree {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn prefix_match<'a>(&self, input_ids: &'a [usize]) -> &'a [usize] {
|
||||
pub fn prefix_match<'a>(&self, input_ids: &'a [u32]) -> &'a [u32] {
|
||||
let mut curr = &self.root;
|
||||
|
||||
let mut curr_idx = 0;
|
||||
@@ -121,7 +122,7 @@ impl RadixTree {
|
||||
&input_ids[..curr_idx]
|
||||
}
|
||||
|
||||
pub fn delete(&mut self, input_ids: &[usize]) {
|
||||
pub fn delete(&mut self, input_ids: &[u32]) {
|
||||
let mut curr = &mut self.root;
|
||||
curr.count -= 1;
|
||||
|
||||
|
||||
Reference in New Issue
Block a user