From 530ff541cf272956ad629a3703ecda80ff68fc63 Mon Sep 17 00:00:00 2001 From: Byron Hsu Date: Mon, 4 Nov 2024 10:56:52 -0800 Subject: [PATCH] [router] Impl radix tree and set up CI (#1893) Co-authored-by: Lianmin Zheng --- .github/workflows/pr-test-rust.yml | 39 ++++++ rust/Cargo.lock | 2 +- rust/Cargo.toml | 10 +- rust/readme.md | 15 +++ rust/src/lib.rs | 13 +- rust/src/main.rs | 7 +- rust/src/router.rs | 15 ++- rust/src/server.rs | 99 +++++++--------- rust/src/tree.rs | 184 +++++++++++++++++++++++++++++ rust/tests/test_tree.rs | 131 ++++++++++++++++++++ scripts/ci_install_rust.sh | 15 +++ 11 files changed, 458 insertions(+), 72 deletions(-) create mode 100644 .github/workflows/pr-test-rust.yml create mode 100644 rust/src/tree.rs create mode 100644 rust/tests/test_tree.rs create mode 100644 scripts/ci_install_rust.sh diff --git a/.github/workflows/pr-test-rust.yml b/.github/workflows/pr-test-rust.yml new file mode 100644 index 000000000..b7e294d28 --- /dev/null +++ b/.github/workflows/pr-test-rust.yml @@ -0,0 +1,39 @@ +name: PR Test (Rust) + +on: + push: + branches: [ main ] + paths: + - "rust/*" + pull_request: + branches: [ main ] + paths: + - "rust/*" + workflow_dispatch: + +concurrency: + group: pr-test-rust-${{ github.ref }} + cancel-in-progress: true + +jobs: + unit-test-rust: + if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request' + runs-on: 1-gpu-runner + steps: + - name: Checkout code + uses: actions/checkout@v3 + + - name: Install dependencies + run: | + bash scripts/ci_install_rust.sh + - name: Run fmt + run: | + source "$HOME/.cargo/env" + cd rust/ + cargo fmt -- --check + - name: Run test + timeout-minutes: 20 + run: | + source "$HOME/.cargo/env" + cd rust/ + cargo test \ No newline at end of file diff --git a/rust/Cargo.lock b/rust/Cargo.lock index ffabb4253..44a5a198b 100644 --- a/rust/Cargo.lock +++ b/rust/Cargo.lock @@ -1607,7 +1607,7 @@ dependencies = [ ] [[package]] -name = "sglang-router" +name = "sglang_router" 
version = "0.0.0" dependencies = [ "actix-web", diff --git a/rust/Cargo.toml b/rust/Cargo.toml index 1daed7dd5..794ab1f7c 100644 --- a/rust/Cargo.toml +++ b/rust/Cargo.toml @@ -1,15 +1,17 @@ [package] -name = "sglang-router" +name = "sglang_router" version = "0.0.0" edition = "2021" [[bin]] -name = "router" +name = "sglang_router" path = "src/main.rs" [lib] -name = "router" -crate-type = ["cdylib"] +name = "sglang_router" +# Pure Rust library: Just omit crate-type (defaults to rlib) +# Python/C binding + Rust library: Use ["cdylib", "rlib"] +crate-type = ["cdylib", "rlib"] [dependencies] actix-web = "4.0" diff --git a/rust/readme.md b/rust/readme.md index f73dd71ca..7d10eed75 100644 --- a/rust/readme.md +++ b/rust/readme.md @@ -74,4 +74,19 @@ python -m sglang.launch_server \ $ cargo build --release $ maturin build -i /usr/bin/python $ pip install +``` + + +### Development + +1. Run test + +``` +$ cargo test +``` + +2. Run lint + +``` +$ cargo fmt ``` \ No newline at end of file diff --git a/rust/src/lib.rs b/rust/src/lib.rs index 75444555d..26e43bb8a 100644 --- a/rust/src/lib.rs +++ b/rust/src/lib.rs @@ -1,6 +1,7 @@ use pyo3::prelude::*; -mod server; pub mod router; +mod server; +pub mod tree; // Python binding #[pyclass] @@ -8,7 +9,7 @@ struct Router { host: String, port: u16, worker_urls: Vec, - policy: String + policy: String, } #[pymethods] @@ -19,7 +20,7 @@ impl Router { host, port, worker_urls, - policy + policy, } } @@ -30,7 +31,9 @@ impl Router { let policy = self.policy.clone(); actix_web::rt::System::new().block_on(async move { - server::startup(host, port, worker_urls, policy).await.unwrap(); + server::startup(host, port, worker_urls, policy) + .await + .unwrap(); }); Ok(()) @@ -42,4 +45,4 @@ impl Router { fn sglang_router(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_class::()?; Ok(()) -} \ No newline at end of file +} diff --git a/rust/src/main.rs b/rust/src/main.rs index a7566c9a5..851a85ae5 100644 --- a/rust/src/main.rs +++ b/rust/src/main.rs @@ 
-1,9 +1,10 @@ // src/main.rs -use clap::Parser; use clap::builder::PossibleValuesParser; +use clap::Parser; // declare child modules -mod server; mod router; +mod server; +mod tree; #[derive(Parser, Debug)] #[command(author, version, about, long_about = None)] @@ -25,4 +26,4 @@ struct Args { async fn main() -> std::io::Result<()> { let args = Args::parse(); server::startup(args.host, args.port, args.worker_urls, args.policy).await -} \ No newline at end of file +} diff --git a/rust/src/router.rs b/rust/src/router.rs index 064133821..9d42cc13f 100644 --- a/rust/src/router.rs +++ b/rust/src/router.rs @@ -34,7 +34,9 @@ impl Router for RoundRobinRouter { return None; } // Use relaxed because operation order doesn't matter in round robin - let index = self.current_index.fetch_add(1, std::sync::atomic::Ordering::Relaxed) + let index = self + .current_index + .fetch_add(1, std::sync::atomic::Ordering::Relaxed) % self.worker_urls.len(); Some(self.worker_urls[index].clone()) } @@ -62,11 +64,11 @@ impl RandomRouter { impl Router for RandomRouter { fn select(&self) -> Option { use rand::seq::SliceRandom; - + if self.worker_urls.is_empty() { return None; } - + self.worker_urls.choose(&mut rand::thread_rng()).cloned() } @@ -83,6 +85,9 @@ pub fn create_router(worker_urls: Vec, policy: String) -> Box Box::new(RandomRouter::new(worker_urls)), "round_robin" => Box::new(RoundRobinRouter::new(worker_urls)), - _ => panic!("Unknown routing policy: {}. The available policies are 'random' and 'round_robin'", policy), + _ => panic!( + "Unknown routing policy: {}. 
The available policies are 'random' and 'round_robin'", + policy + ), } -} \ No newline at end of file +} diff --git a/rust/src/server.rs b/rust/src/server.rs index a534d27e7..1c6c515b4 100644 --- a/rust/src/server.rs +++ b/rust/src/server.rs @@ -1,10 +1,9 @@ -use actix_web::{get, post, web, App, HttpServer, HttpResponse, HttpRequest, Responder}; +use crate::router::create_router; +use crate::router::Router; +use actix_web::http::header::{HeaderValue, CONTENT_TYPE}; +use actix_web::{get, post, web, App, HttpRequest, HttpResponse, HttpServer, Responder}; use bytes::Bytes; use futures_util::StreamExt; -use actix_web::http::header::{HeaderValue, CONTENT_TYPE}; -use crate::router::Router; -use crate::router::create_router; - #[derive(Debug)] pub struct AppState { @@ -12,89 +11,77 @@ pub struct AppState { client: reqwest::Client, } - -impl AppState -{ +impl AppState { pub fn new(worker_urls: Vec, policy: String, client: reqwest::Client) -> Self { // Create router based on policy let router = create_router(worker_urls, policy); - - Self { - router, - client, - } + + Self { router, client } } } #[get("/v1/models")] -async fn v1_model( - data: web::Data, -) -> impl Responder { - let worker_url= match data.router.get_first() { +async fn v1_model(data: web::Data) -> impl Responder { + let worker_url = match data.router.get_first() { Some(url) => url, None => return HttpResponse::InternalServerError().finish(), }; // Use the shared client - match data.client - .get(&format!("{}/v1/models", worker_url)) + match data + .client + .get(format!("{}/v1/models", worker_url)) .send() - .await + .await { Ok(res) => { let status = actix_web::http::StatusCode::from_u16(res.status().as_u16()) - .unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR); - + .unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR); + // print the status println!("Worker URL: {}, Status: {}", worker_url, status); match res.bytes().await { Ok(body) => 
HttpResponse::build(status).body(body.to_vec()), Err(_) => HttpResponse::InternalServerError().finish(), } - }, + } Err(_) => HttpResponse::InternalServerError().finish(), } } #[get("/get_model_info")] -async fn get_model_info( - data: web::Data, -) -> impl Responder { - let worker_url= match data.router.get_first() { +async fn get_model_info(data: web::Data) -> impl Responder { + let worker_url = match data.router.get_first() { Some(url) => url, None => return HttpResponse::InternalServerError().finish(), }; // Use the shared client - match data.client - .get(&format!("{}/get_model_info", worker_url)) + match data + .client + .get(format!("{}/get_model_info", worker_url)) .send() - .await + .await { Ok(res) => { let status = actix_web::http::StatusCode::from_u16(res.status().as_u16()) - .unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR); - + .unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR); + // print the status println!("Worker URL: {}, Status: {}", worker_url, status); match res.bytes().await { Ok(body) => HttpResponse::build(status).body(body.to_vec()), Err(_) => HttpResponse::InternalServerError().finish(), } - }, + } Err(_) => HttpResponse::InternalServerError().finish(), } } // no deser and ser, just forward and return #[post("/generate")] -async fn generate( - req: HttpRequest, - body: Bytes, - data: web::Data, -) -> impl Responder { - +async fn generate(req: HttpRequest, body: Bytes, data: web::Data) -> impl Responder { // create a router struct // TODO: use router abstraction for different policy - let worker_url= match data.router.select() { + let worker_url = match data.router.select() { Some(url) => url, None => return HttpResponse::InternalServerError().finish(), }; @@ -104,18 +91,19 @@ async fn generate( .map(|v| v.get("stream").and_then(|s| s.as_bool()).unwrap_or(false)) .unwrap_or(false); - let res = match data.client - .post(&format!("{}/generate", worker_url)) + let res = match data + .client + 
.post(format!("{}/generate", worker_url)) .header( - "Content-Type", + "Content-Type", req.headers() .get("Content-Type") .and_then(|h| h.to_str().ok()) - .unwrap_or("application/json") + .unwrap_or("application/json"), ) .body(body.to_vec()) .send() - .await + .await { Ok(res) => res, Err(_) => return HttpResponse::InternalServerError().finish(), @@ -128,18 +116,25 @@ async fn generate( match res.bytes().await { Ok(body) => HttpResponse::build(status).body(body.to_vec()), Err(_) => HttpResponse::InternalServerError().finish(), - } + } } else { HttpResponse::build(status) .insert_header((CONTENT_TYPE, HeaderValue::from_static("text/event-stream"))) .streaming(res.bytes_stream().map(|b| match b { Ok(b) => Ok::<_, actix_web::Error>(b), - Err(_) => Err(actix_web::Error::from(actix_web::error::ErrorInternalServerError("Failed to read stream"))), + Err(_) => Err(actix_web::error::ErrorInternalServerError( + "Failed to read stream", + )), })) } } -pub async fn startup(host: String, port: u16, worker_urls: Vec, routing_policy: String) -> std::io::Result<()> { +pub async fn startup( + host: String, + port: u16, + worker_urls: Vec, + routing_policy: String, +) -> std::io::Result<()> { println!("Starting server on {}:{}", host, port); println!("Worker URLs: {:?}", worker_urls); @@ -149,11 +144,7 @@ pub async fn startup(host: String, port: u16, worker_urls: Vec, routing_ .expect("Failed to create HTTP client"); // Store both worker_urls and client in AppState - let app_state = web::Data::new(AppState::new( - worker_urls, - routing_policy, - client, - )); + let app_state = web::Data::new(AppState::new(worker_urls, routing_policy, client)); HttpServer::new(move || { App::new() @@ -165,4 +156,4 @@ pub async fn startup(host: String, port: u16, worker_urls: Vec, routing_ .bind((host, port))? 
// Radix (compressed prefix) tree over sequences of `usize` ids.
//
// Every edge carries a run of ids instead of a single id, and every node keeps
// a counter of how many stored sequences pass through it.

use std::collections::hash_map::Entry;
use std::collections::HashMap;
use std::mem;

/// A single node of the [`RadixTree`].
#[derive(Clone, Debug, Default)]
pub struct Node {
    /// Child nodes, keyed by the first id of the child's `ids`.
    /// Siblings therefore never start with the same id.
    pub children: HashMap<usize, Node>,
    /// The run of ids stored on the edge leading into this node (empty for the root).
    pub ids: Vec<usize>,
    /// Number of inserted (and not yet deleted) sequences whose path goes through this node.
    pub count: usize,
}

/// Radix tree that stores id sequences and answers longest-prefix queries.
#[derive(Clone, Debug)]
pub struct RadixTree {
    /// Root node; its `ids` is empty and its `count` is the number of stored sequences.
    pub root: Node,
}

/// Returns the length of the longest common prefix of `a` and `b`.
fn common_prefix_len(a: &[usize], b: &[usize]) -> usize {
    a.iter().zip(b).take_while(|(x, y)| x == y).count()
}

impl Default for RadixTree {
    fn default() -> Self {
        Self::new()
    }
}

impl RadixTree {
    /// Creates an empty tree: a root with no children, no ids and a count of zero.
    pub fn new() -> Self {
        RadixTree {
            root: Node::default(),
        }
    }

    /// Inserts `input_ids`, incrementing the count of every node on its path.
    ///
    /// A node whose `ids` only partially match the remaining input is split in
    /// two so that the shared part becomes a common parent. Inserting the same
    /// sequence again only bumps the counts. An empty sequence increments the
    /// root count and nothing else.
    pub fn insert(&mut self, input_ids: &[usize]) {
        let mut curr = &mut self.root;
        curr.count += 1;

        let mut curr_idx = 0;
        while curr_idx < input_ids.len() {
            let rest = &input_ids[curr_idx..];

            // A single `entry` lookup serves both the "descend" and the
            // "create" case, so no second borrow of `curr.children` is needed.
            match curr.children.entry(rest[0]) {
                Entry::Occupied(slot) => {
                    let child = slot.into_mut();
                    let prefix_len = common_prefix_len(rest, &child.ids);

                    if prefix_len < child.ids.len() {
                        // Split: [child] -> ...  =>  [child(prefix)] -> [tail(suffix)] -> ...
                        // `prefix_len >= 1` because the child was found via its first id,
                        // so both halves are non-empty.
                        let suffix = child.ids.split_off(prefix_len);
                        let tail = Node {
                            // Move the grandchildren instead of cloning them.
                            children: mem::take(&mut child.children),
                            ids: suffix,
                            count: child.count,
                        };
                        child.children.insert(tail.ids[0], tail);
                    }

                    child.count += 1;
                    curr_idx += prefix_len;
                    curr = child;
                }
                Entry::Vacant(slot) => {
                    // Nothing shares this first id: the rest of the input becomes one leaf.
                    slot.insert(Node {
                        children: HashMap::new(),
                        ids: rest.to_vec(),
                        count: 1,
                    });
                    break;
                }
            }
        }
    }

    /// Returns the longest prefix of `input_ids` that lies on a path of the tree.
    ///
    /// The match may end in the middle of a node's `ids`. The returned slice
    /// borrows from `input_ids`, not from the tree.
    pub fn prefix_match<'a>(&self, input_ids: &'a [usize]) -> &'a [usize] {
        let mut node = &self.root;
        let mut matched = 0;

        while matched < input_ids.len() {
            let child = match node.children.get(&input_ids[matched]) {
                Some(child) => child,
                None => break,
            };

            let shared = common_prefix_len(&input_ids[matched..], &child.ids);
            matched += shared;
            if shared < child.ids.len() {
                // Diverged (or ran out of input) inside this node.
                break;
            }
            node = child;
        }

        &input_ids[..matched]
    }

    /// Removes one occurrence of `input_ids`, decrementing the counts along its path.
    ///
    /// A child whose count would drop to zero is removed together with its
    /// subtree, and the walk stops there.
    ///
    /// The path is checked in a read-only pass before anything is modified, so
    /// the tree is left untouched when this method panics.
    ///
    /// Nodes only record pass-through counts, so a sequence that was never
    /// inserted but whose path lines up with existing nodes is not detected.
    ///
    /// # Panics
    ///
    /// Panics with `No match found for ...` if the tree is empty, if no child
    /// starts with the next id, or if a visited node's `ids` are not fully
    /// matched by the remaining input.
    pub fn delete(&mut self, input_ids: &[usize]) {
        // Pass 1: validate without mutating.
        if self.root.count == 0 {
            panic!("No match found for {:?}", input_ids);
        }
        let mut node = &self.root;
        let mut idx = 0;
        while idx < input_ids.len() {
            let child = match node.children.get(&input_ids[idx]) {
                Some(child) => child,
                None => panic!("No match found for {:?}", input_ids),
            };
            if common_prefix_len(&input_ids[idx..], &child.ids) != child.ids.len() {
                panic!("No match found for {:?}", input_ids);
            }
            if child.count <= 1 {
                // Pass 2 removes this child wholesale and stops here.
                break;
            }
            idx += child.ids.len();
            node = child;
        }

        // Pass 2: apply the decrements along the validated path.
        let mut curr = &mut self.root;
        curr.count -= 1;

        let mut curr_idx = 0;
        while curr_idx < input_ids.len() {
            let first_id = input_ids[curr_idx];

            let removable = curr
                .children
                .get(&first_id)
                .map_or(false, |child| child.count <= 1);
            if removable {
                curr.children.remove(&first_id);
                break;
            }

            let child = curr
                .children
                .get_mut(&first_id)
                .expect("path was validated before mutation");
            child.count -= 1;
            curr_idx += child.ids.len();
            curr = child;
        }
    }

    /// Prints the tree to stdout (debugging aid).
    pub fn pretty_print(&self) {
        println!("RadixTree:");
        Self::print_node(&self.root, "");
    }

    /// Prints `node` and, recursively, its children, each line prefixed by `prefix`.
    ///
    /// Children are visited in `HashMap` iteration order, which is unspecified.
    fn print_node(node: &Node, prefix: &str) {
        println!("{}└── {:?} (count: {})", prefix, node.ids, node.count);

        let last = node.children.len().saturating_sub(1);
        for (i, child) in node.children.values().enumerate() {
            let child_prefix = if i == last {
                format!("{} ", prefix) // plain padding below the last child
            } else {
                format!("{}│ ", prefix) // vertical rule while siblings remain
            };
            Self::print_node(child, &child_prefix);
        }
    }
}
test_insertion_with_split() { + let mut tree = RadixTree::new(); + tree.insert(&[1, 2, 3, 4]); + tree.insert(&[1, 2, 5, 6]); + + assert_eq!(tree.root.count, 2); + assert_eq!(tree.root.children.len(), 1); + assert_eq!(tree.root.children[&1].ids, vec![1, 2]); + assert_eq!(tree.root.children[&1].children.len(), 2); + assert_eq!(tree.root.children[&1].children[&3].ids, vec![3, 4]); + assert_eq!(tree.root.children[&1].children[&5].ids, vec![5, 6]); +} + +#[test] +fn test_prefix_match_exact() { + let mut tree = RadixTree::new(); + tree.insert(&[1, 2, 3, 4]); + + assert_eq!(tree.prefix_match(&[1, 2, 3, 4]), &[1, 2, 3, 4]); +} + +#[test] +fn test_prefix_match_partial() { + let mut tree = RadixTree::new(); + tree.insert(&[1, 2, 3, 4]); + + assert_eq!(tree.prefix_match(&[1, 2, 3, 5]), &[1, 2, 3]); + assert_eq!(tree.prefix_match(&[1, 2, 5]), &[1, 2]); + assert_eq!(tree.prefix_match(&[1, 5]), &[1]); +} + +#[test] +fn test_prefix_match_no_match() { + let mut tree = RadixTree::new(); + tree.insert(&[1, 2, 3, 4]); + let empty_slices: &[usize] = &[]; + assert_eq!(tree.prefix_match(&[5, 6, 7]), empty_slices); +} + +#[test] +fn test_delete_leaf() { + let mut tree = RadixTree::new(); + tree.insert(&[1, 2, 3]); + tree.delete(&[1, 2, 3]); + + assert_eq!(tree.root.count, 0); + assert_eq!(tree.root.children.len(), 0); +} + +#[test] +fn test_delete_with_siblings() { + let mut tree = RadixTree::new(); + tree.insert(&[1, 2, 3]); + tree.insert(&[1, 2, 4]); + tree.delete(&[1, 2, 3]); + + assert_eq!(tree.root.count, 1); + assert_eq!(tree.root.children[&1].children[&4].ids, vec![4]); +} + +#[test] +fn test_multiple_operations() { + let mut tree = RadixTree::new(); + + // Insert several paths + tree.insert(&[1, 2, 3]); + tree.insert(&[1, 2, 4]); + tree.insert(&[1, 5, 6]); + + // Verify structure + assert_eq!(tree.root.count, 3); + assert_eq!(tree.prefix_match(&[1, 2, 3]), &[1, 2, 3]); + assert_eq!(tree.prefix_match(&[1, 2, 4]), &[1, 2, 4]); + assert_eq!(tree.prefix_match(&[1, 5, 6]), &[1, 5, 
6]); + + // Delete and verify + tree.delete(&[1, 2, 3]); + assert_eq!(tree.root.count, 2); + assert_eq!(tree.prefix_match(&[1, 2, 3]), &[1, 2]); // Now only matches prefix +} + +#[test] +#[should_panic(expected = "No match found")] +fn test_delete_nonexistent() { + let mut tree = RadixTree::new(); + tree.insert(&[1, 2, 3]); + tree.delete(&[4, 5, 6]); // Should panic +} + +#[test] +fn test_empty_input() { + let mut tree = RadixTree::new(); + let empty_slice: &[usize] = &[]; + tree.insert(empty_slice); + assert_eq!(tree.prefix_match(empty_slice), empty_slice); + tree.delete(empty_slice); // Should not panic +} diff --git a/scripts/ci_install_rust.sh b/scripts/ci_install_rust.sh new file mode 100644 index 000000000..23bcc3bef --- /dev/null +++ b/scripts/ci_install_rust.sh @@ -0,0 +1,15 @@ +# these are required for actix +apt-get update +apt-get install -y libssl-dev pkg-config + +# Install rustup (Rust installer and version manager) +curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y + + +# Follow the installation prompts, then reload your shell +. "$HOME/.cargo/env" +source $HOME/.cargo/env + +# Verify installation +rustc --version +cargo --version \ No newline at end of file