[rust] cache-aware DP - approx tree (#1934)
This commit is contained in:
103
benchmark/multi_turn_chat/long_prompt_multi_turn.py
Normal file
103
benchmark/multi_turn_chat/long_prompt_multi_turn.py
Normal file
@@ -0,0 +1,103 @@
|
|||||||
|
import itertools
|
||||||
|
import json
|
||||||
|
import random
|
||||||
|
import string
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
from argparse import ArgumentParser
|
||||||
|
|
||||||
|
import sglang as sgl
|
||||||
|
from sglang.srt.hf_transformers_utils import get_tokenize
|
||||||
|
from sglang.test.test_utils import (
|
||||||
|
add_common_sglang_args_and_parse,
|
||||||
|
select_sglang_backend,
|
||||||
|
)
|
||||||
|
from sglang.utils import dump_state_text
|
||||||
|
|
||||||
|
random.seed(42)
|
||||||
|
|
||||||
|
|
||||||
|
def gen_prompt(tokenizer, token_num):
|
||||||
|
all_available_tokens = list(tokenizer.get_vocab().values())
|
||||||
|
selected_tokens = random.choices(all_available_tokens, k=token_num)
|
||||||
|
ret = tokenizer.decode(selected_tokens)
|
||||||
|
return ret
|
||||||
|
|
||||||
|
|
||||||
|
def gen_arguments(args, tokenizer):
|
||||||
|
multi_qas = [
|
||||||
|
{"system_prompt": gen_prompt(tokenizer, args.system_prompt_len), "qas": []}
|
||||||
|
for _ in range(args.num_qa)
|
||||||
|
]
|
||||||
|
for i in range(args.num_qa):
|
||||||
|
qas = multi_qas[i]["qas"]
|
||||||
|
for j in range(args.turns):
|
||||||
|
qas.append(
|
||||||
|
{
|
||||||
|
"prompt": gen_prompt(tokenizer, args.len_q),
|
||||||
|
"new_tokens": args.len_a,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return multi_qas
|
||||||
|
|
||||||
|
|
||||||
|
@sgl.function
|
||||||
|
def multi_turns(s, system_prompt, qas):
|
||||||
|
s += system_prompt
|
||||||
|
|
||||||
|
for qa in qas:
|
||||||
|
s += qa["prompt"]
|
||||||
|
s += sgl.gen(max_tokens=qa["new_tokens"], ignore_eos=True)
|
||||||
|
|
||||||
|
|
||||||
|
def main(args):
|
||||||
|
tokenizer = get_tokenizer(args.tokenizer, trust_remote_code=args.trust_remote_code)
|
||||||
|
|
||||||
|
multi_qas = gen_arguments(args, tokenizer)
|
||||||
|
|
||||||
|
backend = select_sglang_backend(args)
|
||||||
|
|
||||||
|
tic = time.time()
|
||||||
|
states = multi_turns.run_batch(
|
||||||
|
multi_qas,
|
||||||
|
temperature=0,
|
||||||
|
backend=backend,
|
||||||
|
num_threads=args.parallel,
|
||||||
|
progress_bar=True,
|
||||||
|
)
|
||||||
|
latency = time.time() - tic
|
||||||
|
|
||||||
|
print(f"Latency: {latency:.3f}")
|
||||||
|
|
||||||
|
dump_state_text(f"tmp_output_{args.backend}.txt", states)
|
||||||
|
|
||||||
|
with open(args.result_file, "a") as fout:
|
||||||
|
value = {
|
||||||
|
"task": "multi_turn_system_prompt_chat",
|
||||||
|
"backend": args.backend,
|
||||||
|
"num_gpus": 1,
|
||||||
|
"latency": round(latency, 3),
|
||||||
|
"num_requests": args.num_qa,
|
||||||
|
"num_turns": args.turns,
|
||||||
|
"other": {
|
||||||
|
"parallel": args.parallel,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
fout.write(json.dumps(value) + "\n")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = ArgumentParser()
|
||||||
|
parser.add_argument("--turns", type=int, default=8)
|
||||||
|
parser.add_argument("--num-qa", type=int, default=128)
|
||||||
|
parser.add_argument("--system-prompt-len", type=int, default=2048)
|
||||||
|
parser.add_argument("--len-q", type=int, default=32)
|
||||||
|
parser.add_argument("--len-a", type=int, default=128)
|
||||||
|
parser.add_argument(
|
||||||
|
"--tokenizer", type=str, default="meta-llama/Meta-Llama-3-8B-Instruct"
|
||||||
|
)
|
||||||
|
parser.add_argument("--trust-remote-code", action="store_true")
|
||||||
|
args = add_common_sglang_args_and_parse(parser)
|
||||||
|
|
||||||
|
print(args)
|
||||||
|
main(args)
|
||||||
1021
rust/Cargo.lock
generated
1021
rust/Cargo.lock
generated
File diff suppressed because it is too large
Load Diff
@@ -21,5 +21,6 @@ bytes = "1.8.0"
|
|||||||
rand = "0.8.5"
|
rand = "0.8.5"
|
||||||
reqwest = { version = "0.12.8", features = ["stream"] }
|
reqwest = { version = "0.12.8", features = ["stream"] }
|
||||||
futures-util = "0.3"
|
futures-util = "0.3"
|
||||||
serde_json = "=1.0.1"
|
serde_json = "1.0"
|
||||||
pyo3 = { version = "0.22.5", features = ["extension-module"] }
|
pyo3 = { version = "0.22.5", features = ["extension-module"] }
|
||||||
|
tokenizers = { version = "0.20.3", features = ["http"] }
|
||||||
|
|||||||
156
rust/py_src/dp_demo.py
Normal file
156
rust/py_src/dp_demo.py
Normal file
@@ -0,0 +1,156 @@
|
|||||||
|
import argparse
|
||||||
|
import os
|
||||||
|
import signal
|
||||||
|
import subprocess
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
from typing import Dict, List
|
||||||
|
|
||||||
|
import requests
|
||||||
|
from sglang_router import PolicyType, Router
|
||||||
|
|
||||||
|
# Global processes list for cleanup
|
||||||
|
_processes: List[subprocess.Popen] = []
|
||||||
|
|
||||||
|
|
||||||
|
def cleanup_processes(signum=None, frame=None):
|
||||||
|
"""Cleanup function to kill all worker processes."""
|
||||||
|
print("\nCleaning up processes...")
|
||||||
|
for process in _processes:
|
||||||
|
try:
|
||||||
|
# Kill the entire process group
|
||||||
|
pgid = os.getpgid(process.pid)
|
||||||
|
os.killpg(pgid, signal.SIGKILL)
|
||||||
|
process.wait()
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
|
||||||
|
# Register signal handlers
|
||||||
|
signal.signal(signal.SIGINT, cleanup_processes)
|
||||||
|
signal.signal(signal.SIGTERM, cleanup_processes)
|
||||||
|
|
||||||
|
|
||||||
|
def parse_args():
|
||||||
|
"""Parse command line arguments."""
|
||||||
|
parser = argparse.ArgumentParser(description="Launch SGLang Router Server")
|
||||||
|
parser.add_argument(
|
||||||
|
"--host", type=str, default="localhost", help="Host address to bind the server"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--port", type=int, default=30000, help="Base port number for workers"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--dp",
|
||||||
|
type=int,
|
||||||
|
default=2,
|
||||||
|
help="Number of worker processes (degree of parallelism)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--model-path", type=str, required=True, help="Path to the model"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--local-tokenizer-path",
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help="Path to the local tokenizer",
|
||||||
|
)
|
||||||
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
def launch_workers(args) -> tuple[List[subprocess.Popen], List[str]]:
|
||||||
|
"""Launch all worker processes concurrently using subprocess."""
|
||||||
|
processes = []
|
||||||
|
worker_urls = []
|
||||||
|
|
||||||
|
# Launch each worker process
|
||||||
|
for i in range(args.dp):
|
||||||
|
port = args.port + i
|
||||||
|
url = f"http://{args.host}:{port}"
|
||||||
|
worker_urls.append(url)
|
||||||
|
# TODO: replace this with launch_server, and move this file to sglang/ because it depends on sglang
|
||||||
|
# We don't
|
||||||
|
command = f"export CUDA_VISIBLE_DEVICES={i}; python -m sglang.launch_server --model-path {args.model_path} --host {args.host} --port {port}"
|
||||||
|
print(command)
|
||||||
|
process = subprocess.Popen(command, shell=True)
|
||||||
|
processes.append(process)
|
||||||
|
_processes.append(process) # Add to global list for cleanup
|
||||||
|
|
||||||
|
return processes, worker_urls
|
||||||
|
|
||||||
|
|
||||||
|
def wait_for_healthy_workers(worker_urls: List[str], timeout: int = 300) -> bool:
|
||||||
|
"""Block until all workers are healthy or timeout is reached."""
|
||||||
|
start_time = time.time()
|
||||||
|
healthy_workers: Dict[str, bool] = {url: False for url in worker_urls}
|
||||||
|
|
||||||
|
while time.time() - start_time < timeout:
|
||||||
|
print("checking healthiness...")
|
||||||
|
all_healthy = True
|
||||||
|
|
||||||
|
for url in worker_urls:
|
||||||
|
if not healthy_workers[url]: # Only check workers that aren't healthy yet
|
||||||
|
try:
|
||||||
|
response = requests.get(f"{url}/health")
|
||||||
|
if response.status_code == 200:
|
||||||
|
print(f"Worker at {url} is healthy")
|
||||||
|
healthy_workers[url] = True
|
||||||
|
else:
|
||||||
|
all_healthy = False
|
||||||
|
except requests.RequestException:
|
||||||
|
all_healthy = False
|
||||||
|
|
||||||
|
if all_healthy:
|
||||||
|
print("All workers are healthy!")
|
||||||
|
return True
|
||||||
|
|
||||||
|
time.sleep(5)
|
||||||
|
|
||||||
|
# If we get here, we've timed out
|
||||||
|
unhealthy_workers = [url for url, healthy in healthy_workers.items() if not healthy]
|
||||||
|
print(f"Timeout waiting for workers: {unhealthy_workers}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
"""Main function to launch the router and workers."""
|
||||||
|
args = parse_args()
|
||||||
|
processes = None
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Launch all workers concurrently
|
||||||
|
processes, worker_urls = launch_workers(args)
|
||||||
|
|
||||||
|
# Block until all workers are healthy
|
||||||
|
if not wait_for_healthy_workers(worker_urls):
|
||||||
|
raise RuntimeError("Failed to start all workers")
|
||||||
|
|
||||||
|
# Initialize and start the router
|
||||||
|
router = Router(
|
||||||
|
worker_urls=worker_urls,
|
||||||
|
policy=PolicyType.ApproxTree,
|
||||||
|
tokenizer_path=args.local_tokenizer_path,
|
||||||
|
)
|
||||||
|
|
||||||
|
print("Starting router...")
|
||||||
|
router.start()
|
||||||
|
|
||||||
|
# Keep the main process running
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
time.sleep(1)
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
print("\nShutting down...")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error: {e}")
|
||||||
|
finally:
|
||||||
|
# Cleanup: Kill all worker processes
|
||||||
|
if processes:
|
||||||
|
for process in processes:
|
||||||
|
process.kill()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
12
rust/py_src/main.py
Normal file
12
rust/py_src/main.py
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
from sglang_router import PolicyType, Router
|
||||||
|
|
||||||
|
router = Router(
|
||||||
|
worker_urls=[
|
||||||
|
"http://localhost:30000",
|
||||||
|
"http://localhost:30001",
|
||||||
|
],
|
||||||
|
policy=PolicyType.ApproxTree,
|
||||||
|
tokenizer_path="/shared/public/elr-models/meta-llama/Meta-Llama-3.1-8B-Instruct/07eb05b21d191a58c577b4a45982fe0c049d0693/tokenizer.json",
|
||||||
|
)
|
||||||
|
|
||||||
|
router.start()
|
||||||
@@ -2,6 +2,11 @@
|
|||||||
|
|
||||||
SGLang router is a standalone module implemented in Rust to achieve data parallelism across SGLang instances.
|
SGLang router is a standalone module implemented in Rust to achieve data parallelism across SGLang instances.
|
||||||
|
|
||||||
|
## Architecture
|
||||||
|
|
||||||
|
1. `src/`: rust impl of the router
|
||||||
|
2. `py_src/`: lightweight python interafce on top of rust python binding. This will be published as `sglang-router` pypi package
|
||||||
|
|
||||||
## Installation
|
## Installation
|
||||||
|
|
||||||
WIP. Ideally just
|
WIP. Ideally just
|
||||||
@@ -83,6 +88,23 @@ $ maturin develop
|
|||||||
🛠 Installed sglang_router-0.0.0
|
🛠 Installed sglang_router-0.0.0
|
||||||
```
|
```
|
||||||
|
|
||||||
|
4. Alternatively, if you don't want to create a venv, you can also build the binding as a wheel and install it
|
||||||
|
|
||||||
|
```bash
|
||||||
|
$ maturin build --interpreter python
|
||||||
|
...
|
||||||
|
Compiling pyo3 v0.22.6
|
||||||
|
Compiling pyo3-macros v0.22.6
|
||||||
|
Compiling sglang_router v0.0.0 (/home/jobuser/sglang/rust)
|
||||||
|
Finished `dev` profile [unoptimized + debuginfo] target(s) in 9.67s
|
||||||
|
🖨 Copied external shared libraries to package sglang_router.libs directory:
|
||||||
|
/usr/lib/libssl.so.1.1.1k
|
||||||
|
/usr/lib/libcrypto.so.1.1.1k
|
||||||
|
📦 Built wheel for CPython 3.10 to <wheel path>
|
||||||
|
|
||||||
|
$ pip install <wheel path>
|
||||||
|
```
|
||||||
|
|
||||||
## Usage
|
## Usage
|
||||||
|
|
||||||
1. Launch worker instances
|
1. Launch worker instances
|
||||||
|
|||||||
1
rust/sglang
Submodule
1
rust/sglang
Submodule
Submodule rust/sglang added at 760552e068
@@ -1,37 +1,86 @@
|
|||||||
|
// Python Binding
|
||||||
use pyo3::prelude::*;
|
use pyo3::prelude::*;
|
||||||
pub mod router;
|
pub mod router;
|
||||||
mod server;
|
mod server;
|
||||||
pub mod tree;
|
pub mod tree;
|
||||||
|
|
||||||
// Python binding
|
#[pyclass(eq)]
|
||||||
|
#[derive(Clone, PartialEq)]
|
||||||
|
pub enum PolicyType {
|
||||||
|
Random,
|
||||||
|
RoundRobin,
|
||||||
|
ApproxTree,
|
||||||
|
}
|
||||||
|
|
||||||
#[pyclass]
|
#[pyclass]
|
||||||
struct Router {
|
struct Router {
|
||||||
host: String,
|
host: String,
|
||||||
port: u16,
|
port: u16,
|
||||||
worker_urls: Vec<String>,
|
worker_urls: Vec<String>,
|
||||||
policy: String,
|
policy: PolicyType,
|
||||||
|
tokenizer_path: Option<String>,
|
||||||
|
cache_threshold: Option<f32>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[pymethods]
|
#[pymethods]
|
||||||
impl Router {
|
impl Router {
|
||||||
#[new]
|
#[new]
|
||||||
fn new(host: String, port: u16, worker_urls: Vec<String>, policy: String) -> Self {
|
#[pyo3(signature = (
|
||||||
Router {
|
worker_urls,
|
||||||
|
policy = PolicyType::RoundRobin,
|
||||||
|
host = String::from("127.0.0.1"),
|
||||||
|
port = 3001,
|
||||||
|
tokenizer_path = None,
|
||||||
|
cache_threshold = Some(0.50)
|
||||||
|
))]
|
||||||
|
fn new(
|
||||||
|
worker_urls: Vec<String>,
|
||||||
|
policy: PolicyType,
|
||||||
|
host: String,
|
||||||
|
port: u16,
|
||||||
|
tokenizer_path: Option<String>,
|
||||||
|
cache_threshold: Option<f32>,
|
||||||
|
) -> PyResult<Self> {
|
||||||
|
// Validate required parameters for approx_tree policy
|
||||||
|
if matches!(policy, PolicyType::ApproxTree) {
|
||||||
|
if tokenizer_path.is_none() {
|
||||||
|
return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
|
||||||
|
"tokenizer_path is required for approx_tree policy",
|
||||||
|
));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(Router {
|
||||||
host,
|
host,
|
||||||
port,
|
port,
|
||||||
worker_urls,
|
worker_urls,
|
||||||
policy,
|
policy,
|
||||||
}
|
tokenizer_path,
|
||||||
|
cache_threshold,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn start(&self) -> PyResult<()> {
|
fn start(&self) -> PyResult<()> {
|
||||||
let host = self.host.clone();
|
let host = self.host.clone();
|
||||||
let port = self.port;
|
let port = self.port;
|
||||||
let worker_urls = self.worker_urls.clone();
|
let worker_urls = self.worker_urls.clone();
|
||||||
let policy = self.policy.clone();
|
|
||||||
|
let policy_config = match &self.policy {
|
||||||
|
PolicyType::Random => router::PolicyConfig::RandomConfig,
|
||||||
|
PolicyType::RoundRobin => router::PolicyConfig::RoundRobinConfig,
|
||||||
|
PolicyType::ApproxTree => router::PolicyConfig::ApproxTreeConfig {
|
||||||
|
tokenizer_path: self
|
||||||
|
.tokenizer_path
|
||||||
|
.clone()
|
||||||
|
.expect("tokenizer_path is required for approx_tree policy"),
|
||||||
|
cache_threshold: self
|
||||||
|
.cache_threshold
|
||||||
|
.expect("cache_threshold is required for approx_tree policy"),
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
actix_web::rt::System::new().block_on(async move {
|
actix_web::rt::System::new().block_on(async move {
|
||||||
server::startup(host, port, worker_urls, policy)
|
server::startup(host, port, worker_urls, policy_config)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
});
|
});
|
||||||
@@ -40,9 +89,9 @@ impl Router {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// python usage: `from sglang_router import Router`
|
|
||||||
#[pymodule]
|
#[pymodule]
|
||||||
fn sglang_router(m: &Bound<'_, PyModule>) -> PyResult<()> {
|
fn sglang_router(m: &Bound<'_, PyModule>) -> PyResult<()> {
|
||||||
|
m.add_class::<PolicyType>()?;
|
||||||
m.add_class::<Router>()?;
|
m.add_class::<Router>()?;
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,29 +1,87 @@
|
|||||||
// src/main.rs
|
// src/main.rs
|
||||||
use clap::builder::PossibleValuesParser;
|
|
||||||
use clap::Parser;
|
use clap::Parser;
|
||||||
|
use clap::ValueEnum;
|
||||||
// declare child modules
|
// declare child modules
|
||||||
mod router;
|
mod router;
|
||||||
mod server;
|
mod server;
|
||||||
mod tree;
|
mod tree;
|
||||||
|
|
||||||
|
use crate::router::PolicyConfig;
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, ValueEnum)]
|
||||||
|
pub enum PolicyType {
|
||||||
|
Random,
|
||||||
|
RoundRobin,
|
||||||
|
ApproxTree,
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Parser, Debug)]
|
#[derive(Parser, Debug)]
|
||||||
#[command(author, version, about, long_about = None)]
|
#[command(author, version, about, long_about = None)]
|
||||||
struct Args {
|
struct Args {
|
||||||
#[arg(long, default_value = "127.0.0.1")]
|
#[arg(
|
||||||
|
long,
|
||||||
|
default_value = "127.0.0.1",
|
||||||
|
help = "Host address to bind the server to"
|
||||||
|
)]
|
||||||
host: String,
|
host: String,
|
||||||
|
|
||||||
#[arg(long, default_value_t = 3001)]
|
#[arg(long, default_value_t = 3001, help = "Port number to listen on")]
|
||||||
port: u16,
|
port: u16,
|
||||||
|
|
||||||
#[arg(long, value_delimiter = ',')]
|
#[arg(
|
||||||
|
long,
|
||||||
|
value_delimiter = ',',
|
||||||
|
help = "Comma-separated list of worker URLs to distribute requests to"
|
||||||
|
)]
|
||||||
worker_urls: Vec<String>,
|
worker_urls: Vec<String>,
|
||||||
|
|
||||||
#[arg(long, default_value = "round_robin", value_parser = PossibleValuesParser::new(&["round_robin", "random"]))]
|
#[arg(
|
||||||
policy: String,
|
long,
|
||||||
|
default_value_t = PolicyType::RoundRobin,
|
||||||
|
value_enum,
|
||||||
|
help = "Load balancing policy to use: random, round_robin, or approx_tree"
|
||||||
|
)]
|
||||||
|
policy: PolicyType,
|
||||||
|
|
||||||
|
#[arg(
|
||||||
|
long,
|
||||||
|
requires = "policy",
|
||||||
|
required_if_eq("policy", "approx_tree"),
|
||||||
|
help = "Path to the tokenizer file, required when using approx_tree policy"
|
||||||
|
)]
|
||||||
|
tokenizer_path: Option<String>,
|
||||||
|
|
||||||
|
#[arg(
|
||||||
|
long,
|
||||||
|
default_value = "0.50",
|
||||||
|
requires = "policy",
|
||||||
|
required_if_eq("policy", "approx_tree"),
|
||||||
|
help = "Cache threshold (0.0-1.0) for approx_tree routing. Routes to cached worker if match rate exceeds threshold, otherwise routes to shortest queue worker"
|
||||||
|
)]
|
||||||
|
cache_threshold: Option<f32>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Args {
|
||||||
|
fn get_policy_config(&self) -> PolicyConfig {
|
||||||
|
match self.policy {
|
||||||
|
PolicyType::Random => PolicyConfig::RandomConfig,
|
||||||
|
PolicyType::RoundRobin => PolicyConfig::RoundRobinConfig,
|
||||||
|
PolicyType::ApproxTree => PolicyConfig::ApproxTreeConfig {
|
||||||
|
tokenizer_path: self
|
||||||
|
.tokenizer_path
|
||||||
|
.clone()
|
||||||
|
.expect("tokenizer_path is required for approx_tree policy"),
|
||||||
|
cache_threshold: self
|
||||||
|
.cache_threshold
|
||||||
|
.expect("cache_threshold is required for approx_tree policy"),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[actix_web::main]
|
#[actix_web::main]
|
||||||
async fn main() -> std::io::Result<()> {
|
async fn main() -> std::io::Result<()> {
|
||||||
let args = Args::parse();
|
let args = Args::parse();
|
||||||
server::startup(args.host, args.port, args.worker_urls, args.policy).await
|
let policy_config = args.get_policy_config();
|
||||||
|
server::startup(args.host, args.port, args.worker_urls, policy_config).await
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,38 +1,90 @@
|
|||||||
|
use crate::tree::RadixTree;
|
||||||
use actix_web::http::header::{HeaderValue, CONTENT_TYPE};
|
use actix_web::http::header::{HeaderValue, CONTENT_TYPE};
|
||||||
use actix_web::{HttpRequest, HttpResponse};
|
use actix_web::{HttpRequest, HttpResponse};
|
||||||
use bytes::Bytes;
|
use bytes::Bytes;
|
||||||
use futures_util::TryStreamExt;
|
use futures_util::TryStreamExt;
|
||||||
|
use std::collections::HashMap;
|
||||||
use std::fmt::Debug;
|
use std::fmt::Debug;
|
||||||
|
use std::sync::atomic::AtomicUsize;
|
||||||
|
use std::sync::{Arc, Mutex};
|
||||||
|
use tokenizers::tokenizer::Tokenizer;
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub enum Router {
|
pub enum Router {
|
||||||
RoundRobin {
|
RoundRobin {
|
||||||
worker_urls: Vec<String>,
|
worker_urls: Vec<String>,
|
||||||
current_index: std::sync::atomic::AtomicUsize,
|
current_index: AtomicUsize,
|
||||||
},
|
},
|
||||||
Random {
|
Random {
|
||||||
worker_urls: Vec<String>,
|
worker_urls: Vec<String>,
|
||||||
},
|
},
|
||||||
|
ApproxTree {
|
||||||
|
worker_urls: Vec<String>,
|
||||||
|
// TODO: don't lock the whole tree
|
||||||
|
url_to_tree: Arc<Mutex<HashMap<String, RadixTree>>>,
|
||||||
|
tokenizer: Tokenizer,
|
||||||
|
url_to_count: Arc<Mutex<HashMap<String, usize>>>,
|
||||||
|
cache_threshold: f32,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
pub enum PolicyConfig {
|
||||||
|
RandomConfig,
|
||||||
|
RoundRobinConfig,
|
||||||
|
ApproxTreeConfig {
|
||||||
|
tokenizer_path: String,
|
||||||
|
cache_threshold: f32,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_token_ids_from_request(body: &Bytes, tokenizer: &Tokenizer) -> Vec<u32> {
|
||||||
|
// 1. convert body to json
|
||||||
|
let json = serde_json::from_slice::<serde_json::Value>(body).unwrap();
|
||||||
|
// 2. get the text field
|
||||||
|
let text = json.get("text").and_then(|t| t.as_str()).unwrap_or("");
|
||||||
|
// 3. tokenize the text field
|
||||||
|
let tokens = tokenizer.encode(text, false).unwrap();
|
||||||
|
|
||||||
|
tokens.get_ids().to_vec()
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Router {
|
impl Router {
|
||||||
pub fn new(worker_urls: Vec<String>, policy: String) -> Self {
|
pub fn new(worker_urls: Vec<String>, policy_config: PolicyConfig) -> Self {
|
||||||
match policy.to_lowercase().as_str() {
|
match policy_config {
|
||||||
"random" => Router::Random { worker_urls },
|
PolicyConfig::RandomConfig => Router::Random { worker_urls },
|
||||||
"round_robin" => Router::RoundRobin {
|
PolicyConfig::RoundRobinConfig => Router::RoundRobin {
|
||||||
worker_urls,
|
worker_urls,
|
||||||
current_index: std::sync::atomic::AtomicUsize::new(0),
|
current_index: std::sync::atomic::AtomicUsize::new(0),
|
||||||
},
|
},
|
||||||
_ => panic!(
|
PolicyConfig::ApproxTreeConfig {
|
||||||
"Unknown routing policy: {}. The available policies are 'random' and 'round_robin'",
|
tokenizer_path,
|
||||||
policy
|
cache_threshold,
|
||||||
),
|
} => {
|
||||||
|
let mut url_to_tree = HashMap::new();
|
||||||
|
let mut url_to_count = HashMap::new();
|
||||||
|
|
||||||
|
for url in &worker_urls {
|
||||||
|
url_to_tree.insert(url.clone(), RadixTree::new());
|
||||||
|
url_to_count.insert(url.clone(), 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
Router::ApproxTree {
|
||||||
|
worker_urls,
|
||||||
|
url_to_tree: Arc::new(Mutex::new(url_to_tree)),
|
||||||
|
// TODO: rust ::from_pretrained cannot load from local file, so use ::from_file to load local file
|
||||||
|
tokenizer: Tokenizer::from_file(tokenizer_path).unwrap(),
|
||||||
|
url_to_count: Arc::new(Mutex::new(url_to_count)),
|
||||||
|
cache_threshold,
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn get_first(&self) -> Option<String> {
|
pub fn get_first(&self) -> Option<String> {
|
||||||
match self {
|
match self {
|
||||||
Router::RoundRobin { worker_urls, .. } | Router::Random { worker_urls } => {
|
Router::RoundRobin { worker_urls, .. }
|
||||||
|
| Router::Random { worker_urls }
|
||||||
|
| Router::ApproxTree { worker_urls, .. } => {
|
||||||
if worker_urls.is_empty() {
|
if worker_urls.is_empty() {
|
||||||
None
|
None
|
||||||
} else {
|
} else {
|
||||||
@@ -48,26 +100,96 @@ impl Router {
|
|||||||
req: HttpRequest,
|
req: HttpRequest,
|
||||||
body: Bytes,
|
body: Bytes,
|
||||||
) -> HttpResponse {
|
) -> HttpResponse {
|
||||||
|
let mut input_ids: Vec<u32> = Vec::new();
|
||||||
|
if let Router::ApproxTree { tokenizer, .. } = self {
|
||||||
|
input_ids = get_token_ids_from_request(&body, tokenizer);
|
||||||
|
}
|
||||||
|
|
||||||
let worker_url = match self {
|
let worker_url = match self {
|
||||||
Router::RoundRobin {
|
Router::RoundRobin {
|
||||||
worker_urls,
|
worker_urls,
|
||||||
current_index,
|
current_index,
|
||||||
} => {
|
} => {
|
||||||
current_index
|
let idx = current_index
|
||||||
.fetch_update(
|
.fetch_update(
|
||||||
std::sync::atomic::Ordering::SeqCst,
|
std::sync::atomic::Ordering::SeqCst,
|
||||||
std::sync::atomic::Ordering::SeqCst,
|
std::sync::atomic::Ordering::SeqCst,
|
||||||
|x| Some((x + 1) % worker_urls.len()),
|
|x| Some((x + 1) % worker_urls.len()),
|
||||||
)
|
)
|
||||||
.expect_err("Error updating index in round robin");
|
.unwrap();
|
||||||
|
|
||||||
&worker_urls[current_index.load(std::sync::atomic::Ordering::SeqCst)]
|
worker_urls[idx].clone()
|
||||||
}
|
}
|
||||||
|
|
||||||
Router::Random { worker_urls } => {
|
Router::Random { worker_urls } => {
|
||||||
&worker_urls[rand::random::<usize>() % worker_urls.len()]
|
worker_urls[rand::random::<usize>() % worker_urls.len()].clone()
|
||||||
|
}
|
||||||
|
|
||||||
|
Router::ApproxTree {
|
||||||
|
worker_urls,
|
||||||
|
url_to_tree,
|
||||||
|
url_to_count,
|
||||||
|
cache_threshold,
|
||||||
|
..
|
||||||
|
} => {
|
||||||
|
// TODO: pipeline the locks. Release one earlier.
|
||||||
|
|
||||||
|
let mut max_matched_rate = 0.0;
|
||||||
|
let mut max_matched_idx = 0;
|
||||||
|
|
||||||
|
let locked_url_to_tree = url_to_tree.lock().unwrap();
|
||||||
|
|
||||||
|
// 1. Find the highest matched worker
|
||||||
|
for (i, url) in worker_urls.iter().enumerate() {
|
||||||
|
let tree = locked_url_to_tree.get(url).unwrap();
|
||||||
|
let matched = tree.prefix_match(&input_ids[..]).len();
|
||||||
|
let matched_rate = matched as f32 / input_ids.len() as f32;
|
||||||
|
|
||||||
|
if matched_rate > max_matched_rate {
|
||||||
|
max_matched_rate = matched_rate;
|
||||||
|
max_matched_idx = i;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2. If the rate is higher than the threshold, select the worker. If not, select the worker with the shortest queue
|
||||||
|
if max_matched_rate > *cache_threshold {
|
||||||
|
worker_urls[max_matched_idx].clone()
|
||||||
|
} else {
|
||||||
|
// pick the shortest queue from url_to_count
|
||||||
|
let locked_url_to_count = url_to_count.lock().unwrap();
|
||||||
|
|
||||||
|
let mut min_count = std::usize::MAX;
|
||||||
|
let mut min_count_id = 0;
|
||||||
|
|
||||||
|
for (i, url) in worker_urls.iter().enumerate() {
|
||||||
|
let count = locked_url_to_count.get(url).unwrap();
|
||||||
|
if *count < min_count {
|
||||||
|
min_count = *count;
|
||||||
|
min_count_id = i;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
worker_urls[min_count_id].clone()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
if let Router::ApproxTree {
|
||||||
|
url_to_tree,
|
||||||
|
url_to_count,
|
||||||
|
..
|
||||||
|
} = self
|
||||||
|
{
|
||||||
|
// Insert input_ids to the tree
|
||||||
|
let mut locked_url_to_tree = url_to_tree.lock().unwrap();
|
||||||
|
let selected_tree = locked_url_to_tree.get_mut(&worker_url).unwrap();
|
||||||
|
selected_tree.insert(&input_ids[..]);
|
||||||
|
|
||||||
|
let mut locked_url_to_count = url_to_count.lock().unwrap();
|
||||||
|
let count = locked_url_to_count.get_mut(&worker_url).unwrap();
|
||||||
|
*count += 1;
|
||||||
|
}
|
||||||
|
|
||||||
// Check if client requested streaming
|
// Check if client requested streaming
|
||||||
let is_stream = serde_json::from_slice::<serde_json::Value>(&body)
|
let is_stream = serde_json::from_slice::<serde_json::Value>(&body)
|
||||||
.map(|v| v.get("stream").and_then(|s| s.as_bool()).unwrap_or(false))
|
.map(|v| v.get("stream").and_then(|s| s.as_bool()).unwrap_or(false))
|
||||||
@@ -94,11 +216,19 @@ impl Router {
|
|||||||
.unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR);
|
.unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR);
|
||||||
|
|
||||||
if !is_stream {
|
if !is_stream {
|
||||||
|
// TODO: do the correction on the tree based on the cached input_ids
|
||||||
|
if let Router::ApproxTree { url_to_count, .. } = self {
|
||||||
|
let mut locked_url_to_count = url_to_count.lock().unwrap();
|
||||||
|
let count = locked_url_to_count.get_mut(&worker_url).unwrap();
|
||||||
|
*count -= 1;
|
||||||
|
}
|
||||||
|
|
||||||
match res.bytes().await {
|
match res.bytes().await {
|
||||||
Ok(body) => HttpResponse::build(status).body(body.to_vec()),
|
Ok(body) => HttpResponse::build(status).body(body.to_vec()),
|
||||||
Err(_) => HttpResponse::InternalServerError().finish(),
|
Err(_) => HttpResponse::InternalServerError().finish(),
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
// TODO: do the correction on the tree based on the cached input_ids. The streaming might be tricker to handle
|
||||||
HttpResponse::build(status)
|
HttpResponse::build(status)
|
||||||
.insert_header((CONTENT_TYPE, HeaderValue::from_static("text/event-stream")))
|
.insert_header((CONTENT_TYPE, HeaderValue::from_static("text/event-stream")))
|
||||||
.streaming(res.bytes_stream().map_err(|_| {
|
.streaming(res.bytes_stream().map_err(|_| {
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
use crate::router::PolicyConfig;
|
||||||
use crate::router::Router;
|
use crate::router::Router;
|
||||||
use actix_web::{get, post, web, App, HttpRequest, HttpResponse, HttpServer, Responder};
|
use actix_web::{get, post, web, App, HttpRequest, HttpResponse, HttpServer, Responder};
|
||||||
use bytes::Bytes;
|
use bytes::Bytes;
|
||||||
@@ -9,9 +10,13 @@ pub struct AppState {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl AppState {
|
impl AppState {
|
||||||
pub fn new(worker_urls: Vec<String>, policy: String, client: reqwest::Client) -> Self {
|
pub fn new(
|
||||||
|
worker_urls: Vec<String>,
|
||||||
|
client: reqwest::Client,
|
||||||
|
policy_config: PolicyConfig,
|
||||||
|
) -> Self {
|
||||||
// Create router based on policy
|
// Create router based on policy
|
||||||
let router = Router::new(worker_urls, policy);
|
let router = Router::new(worker_urls, policy_config);
|
||||||
|
|
||||||
Self { router, client }
|
Self { router, client }
|
||||||
}
|
}
|
||||||
@@ -40,7 +45,6 @@ async fn forward_request(
|
|||||||
|
|
||||||
#[get("/v1/models")]
|
#[get("/v1/models")]
|
||||||
async fn v1_model(data: web::Data<AppState>) -> impl Responder {
|
async fn v1_model(data: web::Data<AppState>) -> impl Responder {
|
||||||
// TODO: extract forward_to_route
|
|
||||||
let worker_url = match data.router.get_first() {
|
let worker_url = match data.router.get_first() {
|
||||||
Some(url) => url,
|
Some(url) => url,
|
||||||
None => return HttpResponse::InternalServerError().finish(),
|
None => return HttpResponse::InternalServerError().finish(),
|
||||||
@@ -59,7 +63,6 @@ async fn get_model_info(data: web::Data<AppState>) -> impl Responder {
|
|||||||
forward_request(&data.client, worker_url, "/get_model_info".to_string()).await
|
forward_request(&data.client, worker_url, "/get_model_info".to_string()).await
|
||||||
}
|
}
|
||||||
|
|
||||||
// no deser and ser, just forward and return
|
|
||||||
#[post("/generate")]
|
#[post("/generate")]
|
||||||
async fn generate(req: HttpRequest, body: Bytes, data: web::Data<AppState>) -> impl Responder {
|
async fn generate(req: HttpRequest, body: Bytes, data: web::Data<AppState>) -> impl Responder {
|
||||||
data.router.dispatch(&data.client, req, body).await
|
data.router.dispatch(&data.client, req, body).await
|
||||||
@@ -69,7 +72,7 @@ pub async fn startup(
|
|||||||
host: String,
|
host: String,
|
||||||
port: u16,
|
port: u16,
|
||||||
worker_urls: Vec<String>,
|
worker_urls: Vec<String>,
|
||||||
routing_policy: String,
|
policy_config: PolicyConfig,
|
||||||
) -> std::io::Result<()> {
|
) -> std::io::Result<()> {
|
||||||
println!("Starting server on {}:{}", host, port);
|
println!("Starting server on {}:{}", host, port);
|
||||||
println!("Worker URLs: {:?}", worker_urls);
|
println!("Worker URLs: {:?}", worker_urls);
|
||||||
@@ -80,7 +83,7 @@ pub async fn startup(
|
|||||||
.expect("Failed to create HTTP client");
|
.expect("Failed to create HTTP client");
|
||||||
|
|
||||||
// Store both worker_urls and client in AppState
|
// Store both worker_urls and client in AppState
|
||||||
let app_state = web::Data::new(AppState::new(worker_urls, routing_policy, client));
|
let app_state = web::Data::new(AppState::new(worker_urls, client, policy_config));
|
||||||
|
|
||||||
HttpServer::new(move || {
|
HttpServer::new(move || {
|
||||||
App::new()
|
App::new()
|
||||||
|
|||||||
@@ -1,18 +1,19 @@
|
|||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::mem;
|
use std::mem;
|
||||||
|
|
||||||
#[derive(Clone)]
|
#[derive(Debug)]
|
||||||
pub struct Node {
|
pub struct Node {
|
||||||
pub children: HashMap<usize, Node>, // the key is first id of the child because each child must have unique first id
|
pub children: HashMap<u32, Node>, // the key is first id of the child because each child must have unique first id
|
||||||
pub ids: Vec<usize>,
|
pub ids: Vec<u32>,
|
||||||
pub count: usize,
|
pub count: u32,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
pub struct RadixTree {
|
pub struct RadixTree {
|
||||||
pub root: Node,
|
pub root: Node,
|
||||||
}
|
}
|
||||||
|
|
||||||
fn common_prefix_len(a: &[usize], b: &[usize]) -> usize {
|
fn common_prefix_len(a: &[u32], b: &[u32]) -> usize {
|
||||||
let mut i = 0;
|
let mut i = 0;
|
||||||
while i < a.len() && i < b.len() && a[i] == b[i] {
|
while i < a.len() && i < b.len() && a[i] == b[i] {
|
||||||
i += 1;
|
i += 1;
|
||||||
@@ -37,7 +38,7 @@ impl RadixTree {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn insert(&mut self, input_ids: &[usize]) {
|
pub fn insert(&mut self, input_ids: &[u32]) {
|
||||||
let mut curr = &mut self.root;
|
let mut curr = &mut self.root;
|
||||||
curr.count += 1;
|
curr.count += 1;
|
||||||
|
|
||||||
@@ -93,7 +94,7 @@ impl RadixTree {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn prefix_match<'a>(&self, input_ids: &'a [usize]) -> &'a [usize] {
|
pub fn prefix_match<'a>(&self, input_ids: &'a [u32]) -> &'a [u32] {
|
||||||
let mut curr = &self.root;
|
let mut curr = &self.root;
|
||||||
|
|
||||||
let mut curr_idx = 0;
|
let mut curr_idx = 0;
|
||||||
@@ -121,7 +122,7 @@ impl RadixTree {
|
|||||||
&input_ids[..curr_idx]
|
&input_ids[..curr_idx]
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn delete(&mut self, input_ids: &[usize]) {
|
pub fn delete(&mut self, input_ids: &[u32]) {
|
||||||
let mut curr = &mut self.root;
|
let mut curr = &mut self.root;
|
||||||
curr.count -= 1;
|
curr.count -= 1;
|
||||||
|
|
||||||
|
|||||||
@@ -67,7 +67,7 @@ fn test_prefix_match_partial() {
|
|||||||
fn test_prefix_match_no_match() {
|
fn test_prefix_match_no_match() {
|
||||||
let mut tree = RadixTree::new();
|
let mut tree = RadixTree::new();
|
||||||
tree.insert(&[1, 2, 3, 4]);
|
tree.insert(&[1, 2, 3, 4]);
|
||||||
let empty_slices: &[usize] = &[];
|
let empty_slices: &[u32] = &[];
|
||||||
assert_eq!(tree.prefix_match(&[5, 6, 7]), empty_slices);
|
assert_eq!(tree.prefix_match(&[5, 6, 7]), empty_slices);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -124,7 +124,7 @@ fn test_delete_nonexistent() {
|
|||||||
#[test]
|
#[test]
|
||||||
fn test_empty_input() {
|
fn test_empty_input() {
|
||||||
let mut tree = RadixTree::new();
|
let mut tree = RadixTree::new();
|
||||||
let empty_slice: &[usize] = &[];
|
let empty_slice: &[u32] = &[];
|
||||||
tree.insert(empty_slice);
|
tree.insert(empty_slice);
|
||||||
assert_eq!(tree.prefix_match(empty_slice), empty_slice);
|
assert_eq!(tree.prefix_match(empty_slice), empty_slice);
|
||||||
tree.delete(empty_slice); // Should not panic
|
tree.delete(empty_slice); // Should not panic
|
||||||
|
|||||||
Reference in New Issue
Block a user