[router] cache-aware load-balancing router v1 (#2114)
This commit is contained in:
@@ -1,21 +1,24 @@
|
||||
import itertools
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
import string
|
||||
import threading
|
||||
import time
|
||||
from argparse import ArgumentParser
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
import sglang as sgl
|
||||
from sglang.srt.hf_transformers_utils import get_tokenize
|
||||
from sglang.srt.hf_transformers_utils import get_tokenizer
|
||||
from sglang.test.test_utils import (
|
||||
add_common_sglang_args_and_parse,
|
||||
select_sglang_backend,
|
||||
)
|
||||
from sglang.utils import dump_state_text
|
||||
|
||||
random.seed(42)
|
||||
|
||||
|
||||
def gen_prompt(tokenizer, token_num):
|
||||
all_available_tokens = list(tokenizer.get_vocab().values())
|
||||
@@ -24,12 +27,34 @@ def gen_prompt(tokenizer, token_num):
|
||||
return ret
|
||||
|
||||
|
||||
def get_cache_path(args):
|
||||
# Create cache directory under ~/.cache/sglang
|
||||
cache_dir = Path.home() / ".cache" / "sglang"
|
||||
|
||||
# Create a unique cache filename based on the arguments that affect generation
|
||||
cache_key = f"qa_{args.num_qa}_{args.turns}_{args.system_prompt_len}_{args.len_q}_{args.len_a}_{args.tokenizer.replace('/', '_')}.json"
|
||||
return cache_dir / cache_key
|
||||
|
||||
|
||||
def gen_arguments(args, tokenizer):
|
||||
multi_qas = [
|
||||
{"system_prompt": gen_prompt(tokenizer, args.system_prompt_len), "qas": []}
|
||||
for _ in range(args.num_qa)
|
||||
]
|
||||
for i in range(args.num_qa):
|
||||
cache_path = get_cache_path(args)
|
||||
|
||||
# Try to load from cache first
|
||||
if cache_path.exists():
|
||||
print(f"Loading cached arguments from {cache_path}")
|
||||
with open(cache_path, "r") as f:
|
||||
return json.load(f)
|
||||
|
||||
print("Generating new arguments...")
|
||||
# First progress bar for system prompts
|
||||
multi_qas = []
|
||||
for _ in tqdm(range(args.num_qa), desc="Generating system prompts"):
|
||||
multi_qas.append(
|
||||
{"system_prompt": gen_prompt(tokenizer, args.system_prompt_len), "qas": []}
|
||||
)
|
||||
|
||||
# Nested progress bars for QA pairs
|
||||
for i in tqdm(range(args.num_qa), desc="Generating QA pairs"):
|
||||
qas = multi_qas[i]["qas"]
|
||||
for j in range(args.turns):
|
||||
qas.append(
|
||||
@@ -38,6 +63,13 @@ def gen_arguments(args, tokenizer):
|
||||
"new_tokens": args.len_a,
|
||||
}
|
||||
)
|
||||
|
||||
# Save to cache
|
||||
cache_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with open(cache_path, "w") as f:
|
||||
json.dump(multi_qas, f)
|
||||
print(f"Cached arguments saved to {cache_path}")
|
||||
|
||||
return multi_qas
|
||||
|
||||
|
||||
@@ -45,7 +77,7 @@ def gen_arguments(args, tokenizer):
|
||||
def multi_turns(s, system_prompt, qas):
|
||||
s += system_prompt
|
||||
|
||||
for qa in qas:
|
||||
for i, qa in enumerate(qas):
|
||||
s += qa["prompt"]
|
||||
s += sgl.gen(max_tokens=qa["new_tokens"], ignore_eos=True)
|
||||
|
||||
@@ -62,7 +94,7 @@ def main(args):
|
||||
multi_qas,
|
||||
temperature=0,
|
||||
backend=backend,
|
||||
num_threads=args.parallel,
|
||||
num_threads="auto",
|
||||
progress_bar=True,
|
||||
)
|
||||
latency = time.time() - tic
|
||||
@@ -75,7 +107,6 @@ def main(args):
|
||||
value = {
|
||||
"task": "multi_turn_system_prompt_chat",
|
||||
"backend": args.backend,
|
||||
"num_gpus": 1,
|
||||
"latency": round(latency, 3),
|
||||
"num_requests": args.num_qa,
|
||||
"num_turns": args.turns,
|
||||
|
||||
@@ -727,9 +727,9 @@ def sample_generated_shared_prefix_requests(
|
||||
total_input_tokens = 0
|
||||
total_output_tokens = 0
|
||||
|
||||
for group_idx in range(num_groups):
|
||||
for group_idx in tqdm(range(num_groups), desc="Generating system prompt"):
|
||||
system_prompt = system_prompts[group_idx]
|
||||
for prompt_idx in range(prompts_per_group):
|
||||
for prompt_idx in tqdm(range(prompts_per_group), desc="Generating questions"):
|
||||
question = questions[group_idx * prompts_per_group + prompt_idx]
|
||||
full_prompt = f"{system_prompt}\n\n{question}"
|
||||
prompt_len = len(tokenizer.encode(full_prompt))
|
||||
|
||||
@@ -48,9 +48,13 @@ def run_eval(args):
|
||||
# Select backend
|
||||
set_default_backend(RuntimeEndpoint(f"{args.host}:{args.port}"))
|
||||
|
||||
# Read data
|
||||
url = "https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl"
|
||||
filename = download_and_cache_file(url)
|
||||
if args.data_path is None:
|
||||
# Read data
|
||||
url = "https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl"
|
||||
filename = download_and_cache_file(url)
|
||||
else:
|
||||
filename = args.data_path
|
||||
|
||||
lines = list(read_jsonl(filename))
|
||||
|
||||
# Construct prompts
|
||||
|
||||
24
rust/Cargo.lock
generated
24
rust/Cargo.lock
generated
@@ -591,6 +591,20 @@ dependencies = [
|
||||
"syn",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "dashmap"
|
||||
version = "6.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5041cc499144891f3790297212f32a74fb938e5136a14943f338ef9e0ae276cf"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"crossbeam-utils",
|
||||
"hashbrown 0.14.5",
|
||||
"lock_api",
|
||||
"once_cell",
|
||||
"parking_lot_core",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "deranged"
|
||||
version = "0.3.11"
|
||||
@@ -904,6 +918,12 @@ dependencies = [
|
||||
"tracing",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "hashbrown"
|
||||
version = "0.14.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1"
|
||||
|
||||
[[package]]
|
||||
name = "hashbrown"
|
||||
version = "0.15.1"
|
||||
@@ -1226,7 +1246,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "707907fe3c25f5424cce2cb7e1cbcafee6bdbe735ca90ef77c29e84591e5b9da"
|
||||
dependencies = [
|
||||
"equivalent",
|
||||
"hashbrown",
|
||||
"hashbrown 0.15.1",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -2097,7 +2117,9 @@ dependencies = [
|
||||
"actix-web",
|
||||
"bytes",
|
||||
"clap",
|
||||
"dashmap",
|
||||
"futures-util",
|
||||
"http 1.1.0",
|
||||
"pyo3",
|
||||
"rand",
|
||||
"reqwest",
|
||||
|
||||
@@ -24,6 +24,8 @@ futures-util = "0.3"
|
||||
serde_json = "1.0"
|
||||
pyo3 = { version = "0.22.5", features = ["extension-module"] }
|
||||
tokenizers = { version = "0.20.3", features = ["http"] }
|
||||
dashmap = "6.1.0"
|
||||
http = "1.1.0"
|
||||
|
||||
[profile.release]
|
||||
lto = "thin"
|
||||
|
||||
@@ -46,6 +46,9 @@ pip install <path-to-wheel>
|
||||
#### Option B: Development Mode
|
||||
|
||||
For development purposes, you can install the package in editable mode:
|
||||
|
||||
Warning: Using editable python binding can suffer from performance degradation!! Please build a fresh wheel for every update if you want to test performance.
|
||||
|
||||
```bash
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
10
rust/demo.py
10
rust/demo.py
@@ -1,10 +0,0 @@
|
||||
from sglang_router import PolicyType, Router
|
||||
|
||||
router = Router(
|
||||
worker_urls=[
|
||||
"http://localhost:30000",
|
||||
"http://localhost:30001",
|
||||
]
|
||||
)
|
||||
|
||||
router.start()
|
||||
156
rust/dp_demo.py
156
rust/dp_demo.py
@@ -1,156 +0,0 @@
|
||||
import argparse
|
||||
import os
|
||||
import signal
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
from typing import Dict, List
|
||||
|
||||
import requests
|
||||
from sglang_router import PolicyType, Router
|
||||
|
||||
# Global processes list for cleanup
|
||||
_processes: List[subprocess.Popen] = []
|
||||
|
||||
|
||||
def cleanup_processes(signum=None, frame=None):
|
||||
"""Cleanup function to kill all worker processes."""
|
||||
print("\nCleaning up processes...")
|
||||
for process in _processes:
|
||||
try:
|
||||
# Kill the entire process group
|
||||
pgid = os.getpgid(process.pid)
|
||||
os.killpg(pgid, signal.SIGKILL)
|
||||
process.wait()
|
||||
except:
|
||||
pass
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
# Register signal handlers
|
||||
signal.signal(signal.SIGINT, cleanup_processes)
|
||||
signal.signal(signal.SIGTERM, cleanup_processes)
|
||||
|
||||
|
||||
def parse_args():
|
||||
"""Parse command line arguments."""
|
||||
parser = argparse.ArgumentParser(description="Launch SGLang Router Server")
|
||||
parser.add_argument(
|
||||
"--host", type=str, default="localhost", help="Host address to bind the server"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--port", type=int, default=30000, help="Base port number for workers"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dp",
|
||||
type=int,
|
||||
default=2,
|
||||
help="Number of worker processes (degree of parallelism)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model-path", type=str, required=True, help="Path to the model"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--local-tokenizer-path",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the local tokenizer",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def launch_workers(args) -> tuple[List[subprocess.Popen], List[str]]:
|
||||
"""Launch all worker processes concurrently using subprocess."""
|
||||
processes = []
|
||||
worker_urls = []
|
||||
|
||||
# Launch each worker process
|
||||
for i in range(args.dp):
|
||||
port = args.port + i
|
||||
url = f"http://{args.host}:{port}"
|
||||
worker_urls.append(url)
|
||||
# TODO: replace this with launch_server, and move this file to sglang/ because it depends on sglang
|
||||
# We don't
|
||||
command = f"export CUDA_VISIBLE_DEVICES={i}; python -m sglang.launch_server --model-path {args.model_path} --host {args.host} --port {port}"
|
||||
print(command)
|
||||
process = subprocess.Popen(command, shell=True)
|
||||
processes.append(process)
|
||||
_processes.append(process) # Add to global list for cleanup
|
||||
|
||||
return processes, worker_urls
|
||||
|
||||
|
||||
def wait_for_healthy_workers(worker_urls: List[str], timeout: int = 300) -> bool:
|
||||
"""Block until all workers are healthy or timeout is reached."""
|
||||
start_time = time.time()
|
||||
healthy_workers: Dict[str, bool] = {url: False for url in worker_urls}
|
||||
|
||||
while time.time() - start_time < timeout:
|
||||
print("checking healthiness...")
|
||||
all_healthy = True
|
||||
|
||||
for url in worker_urls:
|
||||
if not healthy_workers[url]: # Only check workers that aren't healthy yet
|
||||
try:
|
||||
response = requests.get(f"{url}/health")
|
||||
if response.status_code == 200:
|
||||
print(f"Worker at {url} is healthy")
|
||||
healthy_workers[url] = True
|
||||
else:
|
||||
all_healthy = False
|
||||
except requests.RequestException:
|
||||
all_healthy = False
|
||||
|
||||
if all_healthy:
|
||||
print("All workers are healthy!")
|
||||
return True
|
||||
|
||||
time.sleep(5)
|
||||
|
||||
# If we get here, we've timed out
|
||||
unhealthy_workers = [url for url, healthy in healthy_workers.items() if not healthy]
|
||||
print(f"Timeout waiting for workers: {unhealthy_workers}")
|
||||
return False
|
||||
|
||||
|
||||
def main():
|
||||
"""Main function to launch the router and workers."""
|
||||
args = parse_args()
|
||||
processes = None
|
||||
|
||||
try:
|
||||
# Launch all workers concurrently
|
||||
processes, worker_urls = launch_workers(args)
|
||||
|
||||
# Block until all workers are healthy
|
||||
if not wait_for_healthy_workers(worker_urls):
|
||||
raise RuntimeError("Failed to start all workers")
|
||||
|
||||
# Initialize and start the router
|
||||
router = Router(
|
||||
worker_urls=worker_urls,
|
||||
policy=PolicyType.ApproxTree,
|
||||
tokenizer_path=args.local_tokenizer_path,
|
||||
)
|
||||
|
||||
print("Starting router...")
|
||||
router.start()
|
||||
|
||||
# Keep the main process running
|
||||
try:
|
||||
while True:
|
||||
time.sleep(1)
|
||||
except KeyboardInterrupt:
|
||||
print("\nShutting down...")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error: {e}")
|
||||
finally:
|
||||
# Cleanup: Kill all worker processes
|
||||
if processes:
|
||||
for process in processes:
|
||||
process.kill()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
204
rust/py_src/sglang_router/launch_router.py
Normal file
204
rust/py_src/sglang_router/launch_router.py
Normal file
@@ -0,0 +1,204 @@
|
||||
import argparse
|
||||
import dataclasses
|
||||
import sys
|
||||
from typing import List, Optional
|
||||
|
||||
from sglang_router import Router
|
||||
from sglang_router_rs import PolicyType
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class RouterArgs:
|
||||
# Worker configuration
|
||||
worker_urls: List[str]
|
||||
host: str = "127.0.0.1"
|
||||
port: int = 30000
|
||||
|
||||
# Routing policy
|
||||
policy: str = "cache_aware"
|
||||
cache_threshold: float = 0.5
|
||||
cache_routing_prob: float = 1.0
|
||||
eviction_interval: int = 60
|
||||
max_tree_size: int = 2**24
|
||||
|
||||
@staticmethod
|
||||
def add_cli_args(
|
||||
parser: argparse.ArgumentParser,
|
||||
use_router_prefix: bool = False,
|
||||
exclude_host_port: bool = False,
|
||||
):
|
||||
"""
|
||||
Add router-specific arguments to an argument parser.
|
||||
|
||||
Args:
|
||||
parser: The argument parser to add arguments to
|
||||
use_router_prefix: If True, prefix all arguments with 'router-' to avoid conflicts
|
||||
exclude_host_port: If True, don't add host and port arguments (used when inheriting from server)
|
||||
"""
|
||||
prefix = "router-" if use_router_prefix else ""
|
||||
|
||||
# Worker configuration
|
||||
if not exclude_host_port:
|
||||
parser.add_argument(
|
||||
"--host",
|
||||
type=str,
|
||||
default=RouterArgs.host,
|
||||
help="Host address to bind the router server",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--port",
|
||||
type=int,
|
||||
default=RouterArgs.port,
|
||||
help="Port number to bind the router server",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--worker-urls",
|
||||
type=str,
|
||||
nargs="+",
|
||||
help="List of worker URLs (e.g., http://worker1:8000 http://worker2:8000)",
|
||||
)
|
||||
|
||||
# Routing policy configuration
|
||||
parser.add_argument(
|
||||
f"--{prefix}policy",
|
||||
type=str,
|
||||
default=RouterArgs.policy,
|
||||
choices=["random", "round_robin", "cache_aware"],
|
||||
help="Load balancing policy to use",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}cache-threshold",
|
||||
type=float,
|
||||
default=RouterArgs.cache_threshold,
|
||||
help="Cache threshold (0.0-1.0) for cache-aware routing",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}cache-routing-prob",
|
||||
type=float,
|
||||
default=RouterArgs.cache_routing_prob,
|
||||
help="Probability of using cache-aware routing (0.0-1.0)",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}eviction-interval",
|
||||
type=int,
|
||||
default=RouterArgs.eviction_interval,
|
||||
help="Interval in seconds between cache eviction operations",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}max-tree-size",
|
||||
type=int,
|
||||
default=RouterArgs.max_tree_size,
|
||||
help="Maximum size of the approximation tree for cache-aware routing",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_cli_args(
|
||||
cls, args: argparse.Namespace, use_router_prefix: bool = False
|
||||
) -> "RouterArgs":
|
||||
"""
|
||||
Create RouterArgs instance from parsed command line arguments.
|
||||
|
||||
Args:
|
||||
args: Parsed command line arguments
|
||||
use_router_prefix: If True, look for arguments with 'router-' prefix
|
||||
"""
|
||||
prefix = "router_" if use_router_prefix else ""
|
||||
return cls(
|
||||
worker_urls=args.worker_urls,
|
||||
host=args.host,
|
||||
port=args.port,
|
||||
policy=getattr(args, f"{prefix}policy"),
|
||||
cache_threshold=getattr(args, f"{prefix}cache_threshold"),
|
||||
cache_routing_prob=getattr(args, f"{prefix}cache_routing_prob"),
|
||||
eviction_interval=getattr(args, f"{prefix}eviction_interval"),
|
||||
max_tree_size=getattr(args, f"{prefix}max_tree_size"),
|
||||
)
|
||||
|
||||
|
||||
def policy_from_str(policy_str: str) -> PolicyType:
|
||||
"""Convert policy string to PolicyType enum."""
|
||||
policy_map = {
|
||||
"random": PolicyType.Random,
|
||||
"round_robin": PolicyType.RoundRobin,
|
||||
"cache_aware": PolicyType.CacheAware,
|
||||
}
|
||||
return policy_map[policy_str]
|
||||
|
||||
|
||||
def launch_router(args: argparse.Namespace) -> Optional[Router]:
|
||||
"""
|
||||
Launch the SGLang router with the configuration from parsed arguments.
|
||||
|
||||
Args:
|
||||
args: Namespace object containing router configuration
|
||||
Can be either raw argparse.Namespace or converted RouterArgs
|
||||
|
||||
Returns:
|
||||
Router instance if successful, None if failed
|
||||
"""
|
||||
try:
|
||||
# Convert to RouterArgs if needed
|
||||
if not isinstance(args, RouterArgs):
|
||||
router_args = RouterArgs.from_cli_args(args)
|
||||
else:
|
||||
router_args = args
|
||||
|
||||
router = Router(
|
||||
worker_urls=router_args.worker_urls,
|
||||
policy=policy_from_str(router_args.policy),
|
||||
host=router_args.host,
|
||||
port=router_args.port,
|
||||
cache_threshold=router_args.cache_threshold,
|
||||
cache_routing_prob=router_args.cache_routing_prob,
|
||||
eviction_interval_secs=router_args.eviction_interval,
|
||||
max_tree_size=router_args.max_tree_size,
|
||||
)
|
||||
|
||||
router.start()
|
||||
return router
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error starting router: {e}", file=sys.stderr)
|
||||
return None
|
||||
|
||||
|
||||
class CustomHelpFormatter(
|
||||
argparse.RawDescriptionHelpFormatter, argparse.ArgumentDefaultsHelpFormatter
|
||||
):
|
||||
"""Custom formatter that preserves both description formatting and shows defaults"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
def parse_router_args(args: List[str]) -> RouterArgs:
|
||||
"""Parse command line arguments and return RouterArgs instance."""
|
||||
parser = argparse.ArgumentParser(
|
||||
description="""SGLang Router - High-performance request distribution across worker nodes
|
||||
|
||||
Usage:
|
||||
This launcher enables starting a router with individual worker instances. It is useful for
|
||||
multi-node setups or when you want to start workers and router separately.
|
||||
|
||||
Examples:
|
||||
python -m sglang_router.launch_router --worker-urls http://worker1:8000 http://worker2:8000
|
||||
python -m sglang_router.launch_router --worker-urls http://worker1:8000 http://worker2:8000 --cache-threshold 0.7 --cache-routing-prob 0.5
|
||||
|
||||
""",
|
||||
formatter_class=CustomHelpFormatter,
|
||||
)
|
||||
|
||||
RouterArgs.add_cli_args(parser, use_router_prefix=False)
|
||||
return RouterArgs.from_cli_args(parser.parse_args(args), use_router_prefix=False)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
router_args = parse_router_args(sys.argv[1:])
|
||||
router = launch_router(router_args)
|
||||
|
||||
if router is None:
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
178
rust/py_src/sglang_router/launch_server.py
Normal file
178
rust/py_src/sglang_router/launch_server.py
Normal file
@@ -0,0 +1,178 @@
|
||||
import argparse
|
||||
import copy
|
||||
import multiprocessing as mp
|
||||
import os
|
||||
import signal
|
||||
import sys
|
||||
import time
|
||||
from typing import List
|
||||
|
||||
import requests
|
||||
from sglang_router.launch_router import RouterArgs, launch_router
|
||||
|
||||
from sglang.srt.server import launch_server
|
||||
from sglang.srt.server_args import ServerArgs, prepare_server_args
|
||||
from sglang.srt.utils import is_port_available
|
||||
from sglang.utils import get_exception_traceback
|
||||
|
||||
|
||||
# Create new process group
|
||||
def run_server(server_args, dp_rank):
|
||||
os.setpgrp() # Create new process group
|
||||
|
||||
# Set DP_RANK environment variable
|
||||
os.environ["DP_RANK"] = str(dp_rank)
|
||||
|
||||
launch_server(server_args)
|
||||
|
||||
|
||||
def launch_server_process(
|
||||
server_args: ServerArgs, worker_port: int, dp_id: int
|
||||
) -> mp.Process:
|
||||
"""Launch a single server process with the given args and port."""
|
||||
server_args = copy.deepcopy(server_args)
|
||||
server_args.port = worker_port
|
||||
server_args.base_gpu_id = dp_id * server_args.tp_size
|
||||
server_args.dp_size = 1
|
||||
|
||||
proc = mp.Process(target=run_server, args=(server_args, dp_id))
|
||||
proc.start()
|
||||
return proc
|
||||
|
||||
|
||||
def cleanup_processes(processes: List[mp.Process]):
|
||||
"""Clean up all processes using process groups."""
|
||||
print("\nCleaning up processes...")
|
||||
for proc in processes:
|
||||
if proc.is_alive():
|
||||
try:
|
||||
# Kill the entire process group
|
||||
os.killpg(os.getpgid(proc.pid), signal.SIGTERM)
|
||||
# Give processes some time to terminate gracefully
|
||||
proc.join(timeout=3)
|
||||
# If process is still alive, force kill
|
||||
if proc.is_alive():
|
||||
os.killpg(os.getpgid(proc.pid), signal.SIGKILL)
|
||||
except ProcessLookupError:
|
||||
pass # Process already terminated
|
||||
|
||||
|
||||
def setup_signal_handlers(cleanup_func):
|
||||
"""Setup handlers for various termination signals."""
|
||||
|
||||
def signal_handler(signum, frame):
|
||||
cleanup_func()
|
||||
sys.exit(1)
|
||||
|
||||
signal.signal(signal.SIGTERM, signal_handler)
|
||||
signal.signal(signal.SIGINT, signal_handler)
|
||||
if hasattr(signal, "SIGQUIT"):
|
||||
signal.signal(signal.SIGQUIT, signal_handler)
|
||||
|
||||
|
||||
def wait_for_server_health(host: str, port: int, timeout: int = 300) -> bool:
|
||||
"""Wait for server to be healthy by checking /health endpoint."""
|
||||
start_time = time.time()
|
||||
url = f"http://{host}:{port}/health"
|
||||
|
||||
while time.time() - start_time < timeout:
|
||||
try:
|
||||
response = requests.get(url, timeout=5)
|
||||
if response.status_code == 200:
|
||||
return True
|
||||
except requests.exceptions.RequestException:
|
||||
pass
|
||||
time.sleep(1)
|
||||
return False
|
||||
|
||||
|
||||
def find_available_ports(base_port: int, count: int) -> List[int]:
|
||||
"""Find consecutive available ports starting from base_port."""
|
||||
available_ports = []
|
||||
current_port = base_port
|
||||
|
||||
while len(available_ports) < count:
|
||||
if is_port_available(current_port):
|
||||
available_ports.append(current_port)
|
||||
current_port += 1
|
||||
|
||||
return available_ports
|
||||
|
||||
|
||||
def main():
|
||||
# CUDA runtime isn't fork-safe, which can lead to subtle bugs or crashes
|
||||
mp.set_start_method("spawn")
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Launch SGLang router and server processes"
|
||||
)
|
||||
|
||||
ServerArgs.add_cli_args(parser)
|
||||
RouterArgs.add_cli_args(parser, use_router_prefix=True, exclude_host_port=True)
|
||||
parser.add_argument(
|
||||
"--router-dp-worker-base-port",
|
||||
type=int,
|
||||
default=31000,
|
||||
help="Base port number for data parallel workers",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
server_args = ServerArgs.from_cli_args(args)
|
||||
router_args = RouterArgs.from_cli_args(args, use_router_prefix=True)
|
||||
|
||||
# Find available ports for workers
|
||||
worker_ports = find_available_ports(
|
||||
args.router_dp_worker_base_port, server_args.dp_size
|
||||
)
|
||||
|
||||
# Start server processes
|
||||
server_processes = []
|
||||
|
||||
try:
|
||||
# Launch server processes
|
||||
for i, worker_port in enumerate(worker_ports):
|
||||
proc = launch_server_process(server_args, worker_port, i)
|
||||
server_processes.append(proc)
|
||||
|
||||
# Setup cleanup handler
|
||||
setup_signal_handlers(lambda: cleanup_processes(server_processes))
|
||||
|
||||
# Wait for all servers to be healthy
|
||||
all_healthy = True
|
||||
for port in worker_ports:
|
||||
if not wait_for_server_health(server_args.host, port):
|
||||
print(f"Server on port {port} failed to become healthy")
|
||||
all_healthy = False
|
||||
break
|
||||
|
||||
if not all_healthy:
|
||||
print("Not all servers are healthy. Shutting down...")
|
||||
cleanup_processes(server_processes)
|
||||
sys.exit(1)
|
||||
|
||||
print("All servers are healthy. Starting router...")
|
||||
|
||||
# Update router args with worker URLs
|
||||
router_args.worker_urls = [
|
||||
f"http://{server_args.host}:{port}" for port in worker_ports
|
||||
]
|
||||
|
||||
# Start the router
|
||||
router = launch_router(router_args)
|
||||
|
||||
if router is None:
|
||||
print("Failed to start router. Shutting down...")
|
||||
cleanup_processes(server_processes)
|
||||
sys.exit(1)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\nReceived shutdown signal...")
|
||||
except Exception as e:
|
||||
print(f"Error occurred: {e}")
|
||||
print(get_exception_traceback())
|
||||
finally:
|
||||
cleanup_processes(server_processes)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -9,16 +9,23 @@ class Router:
|
||||
A high-performance router for distributing requests across worker nodes.
|
||||
|
||||
Args:
|
||||
worker_urls: List of URLs for worker nodes that will handle requests
|
||||
worker_urls: List of URLs for worker nodes that will handle requests. Each URL should include
|
||||
the protocol, host, and port (e.g., ['http://worker1:8000', 'http://worker2:8000'])
|
||||
policy: Load balancing policy to use. Options:
|
||||
- PolicyType.Random: Randomly select workers
|
||||
- PolicyType.RoundRobin: Distribute requests in round-robin fashion
|
||||
- PolicyType.ApproxTree: Tree-based routing using tokenizer similarity
|
||||
host: Host address to bind the router server
|
||||
port: Port number to bind the router server
|
||||
tokenizer_path: Path to tokenizer model file (required for ApproxTree policy)
|
||||
cache_threshold: Caching threshold value between 0-1
|
||||
|
||||
- PolicyType.CacheAware: Distribute requests in cache-aware fashion
|
||||
host: Host address to bind the router server. Default: '127.0.0.1'
|
||||
port: Port number to bind the router server. Default: 3001
|
||||
cache_threshold: Cache threshold (0.0-1.0) for cache-aware routing. Routes to cached worker
|
||||
if the match rate exceeds threshold, otherwise routes to the worker with the smallest
|
||||
tree. Default: 0.5
|
||||
cache_routing_prob: Probability of using cache-aware routing (0.0-1.0). Default 1.0 for
|
||||
full cache-aware routing, suitable for perfectly divided prefix workloads. For uneven
|
||||
workloads, use a lower value to better distribute requests
|
||||
eviction_interval_secs: Interval in seconds between cache eviction operations in cache-aware
|
||||
routing. Default: 60
|
||||
max_tree_size: Maximum size of the approximation tree for cache-aware routing. Default: 2^24
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -27,17 +34,20 @@ class Router:
|
||||
policy: PolicyType = PolicyType.RoundRobin,
|
||||
host: str = "127.0.0.1",
|
||||
port: int = 3001,
|
||||
tokenizer_path: Optional[str] = None,
|
||||
cache_threshold: float = 0.50,
|
||||
cache_routing_prob: float = 1.0,
|
||||
eviction_interval_secs: int = 60,
|
||||
max_tree_size: int = 2**24,
|
||||
):
|
||||
|
||||
self._router = _Router(
|
||||
worker_urls=worker_urls,
|
||||
policy=policy,
|
||||
host=host,
|
||||
port=port,
|
||||
tokenizer_path=tokenizer_path,
|
||||
cache_threshold=cache_threshold,
|
||||
cache_routing_prob=cache_routing_prob,
|
||||
eviction_interval_secs=eviction_interval_secs,
|
||||
max_tree_size=max_tree_size,
|
||||
)
|
||||
|
||||
def start(self) -> None:
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
// Python Binding
|
||||
use pyo3::prelude::*;
|
||||
pub mod router;
|
||||
mod server;
|
||||
pub mod server;
|
||||
pub mod tree;
|
||||
|
||||
#[pyclass(eq)]
|
||||
@@ -9,7 +8,7 @@ pub mod tree;
|
||||
pub enum PolicyType {
|
||||
Random,
|
||||
RoundRobin,
|
||||
ApproxTree,
|
||||
CacheAware,
|
||||
}
|
||||
|
||||
#[pyclass]
|
||||
@@ -18,8 +17,10 @@ struct Router {
|
||||
port: u16,
|
||||
worker_urls: Vec<String>,
|
||||
policy: PolicyType,
|
||||
tokenizer_path: Option<String>,
|
||||
cache_threshold: Option<f32>,
|
||||
cache_threshold: f32,
|
||||
cache_routing_prob: f32,
|
||||
eviction_interval_secs: u64,
|
||||
max_tree_size: usize,
|
||||
}
|
||||
|
||||
#[pymethods]
|
||||
@@ -30,33 +31,30 @@ impl Router {
|
||||
policy = PolicyType::RoundRobin,
|
||||
host = String::from("127.0.0.1"),
|
||||
port = 3001,
|
||||
tokenizer_path = None,
|
||||
cache_threshold = Some(0.50)
|
||||
cache_threshold = 0.50,
|
||||
cache_routing_prob = 1.0,
|
||||
eviction_interval_secs = 60,
|
||||
max_tree_size = 2usize.pow(24)
|
||||
))]
|
||||
fn new(
|
||||
worker_urls: Vec<String>,
|
||||
policy: PolicyType,
|
||||
host: String,
|
||||
port: u16,
|
||||
tokenizer_path: Option<String>,
|
||||
cache_threshold: Option<f32>,
|
||||
cache_threshold: f32,
|
||||
cache_routing_prob: f32,
|
||||
eviction_interval_secs: u64,
|
||||
max_tree_size: usize,
|
||||
) -> PyResult<Self> {
|
||||
// Validate required parameters for approx_tree policy
|
||||
if matches!(policy, PolicyType::ApproxTree) {
|
||||
if tokenizer_path.is_none() {
|
||||
return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
|
||||
"tokenizer_path is required for approx_tree policy",
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(Router {
|
||||
host,
|
||||
port,
|
||||
worker_urls,
|
||||
policy,
|
||||
tokenizer_path,
|
||||
cache_threshold,
|
||||
cache_routing_prob,
|
||||
eviction_interval_secs,
|
||||
max_tree_size,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -68,14 +66,11 @@ impl Router {
|
||||
let policy_config = match &self.policy {
|
||||
PolicyType::Random => router::PolicyConfig::RandomConfig,
|
||||
PolicyType::RoundRobin => router::PolicyConfig::RoundRobinConfig,
|
||||
PolicyType::ApproxTree => router::PolicyConfig::ApproxTreeConfig {
|
||||
tokenizer_path: self
|
||||
.tokenizer_path
|
||||
.clone()
|
||||
.expect("tokenizer_path is required for approx_tree policy"),
|
||||
cache_threshold: self
|
||||
.cache_threshold
|
||||
.expect("cache_threshold is required for approx_tree policy"),
|
||||
PolicyType::CacheAware => router::PolicyConfig::CacheAwareConfig {
|
||||
cache_threshold: self.cache_threshold,
|
||||
cache_routing_prob: self.cache_routing_prob,
|
||||
eviction_interval_secs: self.eviction_interval_secs,
|
||||
max_tree_size: self.max_tree_size,
|
||||
},
|
||||
};
|
||||
|
||||
|
||||
@@ -1,18 +1,14 @@
|
||||
// src/main.rs
|
||||
use clap::Parser;
|
||||
use clap::ValueEnum;
|
||||
// declare child modules
|
||||
mod router;
|
||||
mod server;
|
||||
mod tree;
|
||||
|
||||
use crate::router::PolicyConfig;
|
||||
use sglang_router_rs::{router::PolicyConfig, server};
|
||||
|
||||
#[derive(Debug, Clone, ValueEnum)]
|
||||
pub enum PolicyType {
|
||||
Random,
|
||||
RoundRobin,
|
||||
ApproxTree,
|
||||
CacheAware,
|
||||
}
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
@@ -21,44 +17,70 @@ struct Args {
|
||||
#[arg(
|
||||
long,
|
||||
default_value = "127.0.0.1",
|
||||
help = "Host address to bind the server to"
|
||||
help = "Host address to bind the router server to. Default: 127.0.0.1"
|
||||
)]
|
||||
host: String,
|
||||
|
||||
#[arg(long, default_value_t = 3001, help = "Port number to listen on")]
|
||||
#[arg(
|
||||
long,
|
||||
default_value_t = 3001,
|
||||
help = "Port number to bind the router server to. Default: 3001"
|
||||
)]
|
||||
port: u16,
|
||||
|
||||
#[arg(
|
||||
long,
|
||||
value_delimiter = ',',
|
||||
help = "Comma-separated list of worker URLs to distribute requests to"
|
||||
help = "Comma-separated list of worker URLs that will handle the requests. Each URL should include the protocol, host, and port (e.g., http://worker1:8000,http://worker2:8000)"
|
||||
)]
|
||||
worker_urls: Vec<String>,
|
||||
|
||||
#[arg(
|
||||
long,
|
||||
default_value_t = PolicyType::RoundRobin,
|
||||
default_value_t = PolicyType::CacheAware,
|
||||
value_enum,
|
||||
help = "Load balancing policy to use: random, round_robin, or approx_tree"
|
||||
help = "Load balancing policy to use for request distribution:\n\
|
||||
- random: Randomly select workers\n\
|
||||
- round_robin: Distribute requests in round-robin fashion\n\
|
||||
- cache_aware: Distribute requests in cache-aware fashion\n"
|
||||
)]
|
||||
policy: PolicyType,
|
||||
|
||||
#[arg(
|
||||
long,
|
||||
default_value_t = 0.5,
|
||||
requires = "policy",
|
||||
required_if_eq("policy", "approx_tree"),
|
||||
help = "Path to the tokenizer file, required when using approx_tree policy"
|
||||
required_if_eq("policy", "cache_aware"),
|
||||
help = "Cache threshold (0.0-1.0) for cache-aware routing. Routes to cached worker if the match rate exceeds threshold, otherwise routes to the worker with the smallest tree. Default: 0.5"
|
||||
)]
|
||||
tokenizer_path: Option<String>,
|
||||
cache_threshold: f32,
|
||||
|
||||
#[arg(
|
||||
long,
|
||||
default_value = "0.50",
|
||||
default_value_t = 1.0,
|
||||
requires = "policy",
|
||||
required_if_eq("policy", "approx_tree"),
|
||||
help = "Cache threshold (0.0-1.0) for approx_tree routing. Routes to cached worker if match rate exceeds threshold, otherwise routes to shortest queue worker"
|
||||
required_if_eq("policy", "cache_aware"),
|
||||
help = "Probability of using cache-aware routing (0.0-1.0). Default 1.0 for full cache-aware routing, suitable for perfectly divided prefix workloads. For uneven workloads, use a lower value to better distribute requests"
|
||||
)]
|
||||
cache_threshold: Option<f32>,
|
||||
cache_routing_prob: f32,
|
||||
|
||||
#[arg(
|
||||
long,
|
||||
default_value_t = 60,
|
||||
requires = "policy",
|
||||
required_if_eq("policy", "cache_aware"),
|
||||
help = "Interval in seconds between cache eviction operations in cache-aware routing. Default: 60"
|
||||
)]
|
||||
eviction_interval_secs: u64,
|
||||
|
||||
#[arg(
|
||||
long,
|
||||
default_value_t = 2usize.pow(24),
|
||||
requires = "policy",
|
||||
required_if_eq("policy", "cache_aware"),
|
||||
help = "Maximum size of the approximation tree for cache-aware routing. Default: 2^24"
|
||||
)]
|
||||
max_tree_size: usize,
|
||||
}
|
||||
|
||||
impl Args {
|
||||
@@ -66,14 +88,11 @@ impl Args {
|
||||
match self.policy {
|
||||
PolicyType::Random => PolicyConfig::RandomConfig,
|
||||
PolicyType::RoundRobin => PolicyConfig::RoundRobinConfig,
|
||||
PolicyType::ApproxTree => PolicyConfig::ApproxTreeConfig {
|
||||
tokenizer_path: self
|
||||
.tokenizer_path
|
||||
.clone()
|
||||
.expect("tokenizer_path is required for approx_tree policy"),
|
||||
cache_threshold: self
|
||||
.cache_threshold
|
||||
.expect("cache_threshold is required for approx_tree policy"),
|
||||
PolicyType::CacheAware => PolicyConfig::CacheAwareConfig {
|
||||
cache_threshold: self.cache_threshold,
|
||||
cache_routing_prob: self.cache_routing_prob,
|
||||
eviction_interval_secs: self.eviction_interval_secs,
|
||||
max_tree_size: self.max_tree_size,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,13 +1,16 @@
|
||||
use crate::tree::RadixTree;
|
||||
use crate::tree::Tree;
|
||||
use actix_web::http::header::{HeaderValue, CONTENT_TYPE};
|
||||
use actix_web::{HttpRequest, HttpResponse};
|
||||
use bytes::Bytes;
|
||||
use futures_util::TryStreamExt;
|
||||
use futures_util::{Stream, StreamExt, TryStreamExt};
|
||||
use std::collections::HashMap;
|
||||
use std::fmt::Debug;
|
||||
use std::hash::Hash;
|
||||
use std::pin::Pin;
|
||||
use std::sync::atomic::AtomicUsize;
|
||||
use std::sync::{Arc, Mutex};
|
||||
use tokenizers::tokenizer::Tokenizer;
|
||||
use std::thread;
|
||||
use std::time::Duration;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum Router {
|
||||
@@ -18,34 +21,88 @@ pub enum Router {
|
||||
Random {
|
||||
worker_urls: Vec<String>,
|
||||
},
|
||||
ApproxTree {
|
||||
CacheAware {
|
||||
/*
|
||||
Cache-Aware Load Balancing Router
|
||||
|
||||
This router combines two strategies to optimize both cache utilization and request distribution:
|
||||
|
||||
1. Cache-Aware Routing (Approximate Tree)
|
||||
2. Load Balancing (Shortest Queue)
|
||||
|
||||
For each incoming request, the router chooses between these strategies:
|
||||
- With probability P: Uses cache-aware routing
|
||||
- With probability (1-P): Uses load balancing
|
||||
where P is configured via `cache_routing_prob`
|
||||
|
||||
Strategy Details:
|
||||
|
||||
1. Cache-Aware Routing (Approximate Tree)
|
||||
-------------------------------------------
|
||||
This strategy maintains an approximate radix tree for each worker based on request history,
|
||||
eliminating the need for direct cache state queries. The tree stores raw text characters
|
||||
instead of token IDs to avoid tokenization overhead.
|
||||
|
||||
Process:
|
||||
a. For each request, find the worker with the highest prefix match
|
||||
b. If match rate > cache_threshold:
|
||||
Route to the worker with highest match (likely has relevant data cached)
|
||||
c. If match rate ≤ cache_threshold:
|
||||
Route to the worker with smallest tree size (most available cache capacity)
|
||||
d. Background maintenance:
|
||||
Periodically evict least recently used leaf nodes to prevent memory overflow
|
||||
|
||||
2. Load Balancing (Shortest Queue)
|
||||
-------------------------------------------
|
||||
This strategy tracks pending request counts per worker and routes new requests
|
||||
to the least busy worker for optimal load distribution.
|
||||
|
||||
Configuration Parameters:
|
||||
------------------------
|
||||
1. cache_routing_prob: (float, 0.0 to 1.0)
|
||||
- 0.0: Exclusively use load balancing
|
||||
- 1.0: Exclusively use cache-aware routing
|
||||
- Between 0-1: Probability of using cache-aware routing vs load balancing
|
||||
|
||||
2. cache_threshold: (float, 0.0 to 1.0)
|
||||
Minimum prefix match ratio to use highest-match routing.
|
||||
Below this threshold, routes to worker with most available cache space.
|
||||
|
||||
3. eviction_interval_secs: (integer)
|
||||
Interval between LRU eviction cycles for the approximate trees.
|
||||
|
||||
4. max_tree_size: (integer)
|
||||
Maximum nodes per tree. When exceeded, LRU leaf nodes are evicted
|
||||
during the next eviction cycle.
|
||||
*/
|
||||
worker_urls: Vec<String>,
|
||||
// TODO: don't lock the whole tree
|
||||
url_to_tree: Arc<Mutex<HashMap<String, RadixTree>>>,
|
||||
tokenizer: Tokenizer,
|
||||
url_to_count: Arc<Mutex<HashMap<String, usize>>>,
|
||||
tree: Arc<Mutex<Tree>>,
|
||||
running_queue: Arc<Mutex<HashMap<String, usize>>>,
|
||||
processed_queue: Arc<Mutex<HashMap<String, usize>>>,
|
||||
cache_threshold: f32,
|
||||
cache_routing_prob: f32,
|
||||
_eviction_thread: Option<thread::JoinHandle<()>>, // Store thread handle
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum PolicyConfig {
|
||||
RandomConfig,
|
||||
RoundRobinConfig,
|
||||
ApproxTreeConfig {
|
||||
tokenizer_path: String,
|
||||
CacheAwareConfig {
|
||||
cache_threshold: f32,
|
||||
cache_routing_prob: f32,
|
||||
eviction_interval_secs: u64,
|
||||
max_tree_size: usize,
|
||||
},
|
||||
}
|
||||
|
||||
fn get_token_ids_from_request(body: &Bytes, tokenizer: &Tokenizer) -> Vec<u32> {
|
||||
fn get_text_from_request(body: &Bytes) -> String {
|
||||
// 1. convert body to json
|
||||
let json = serde_json::from_slice::<serde_json::Value>(body).unwrap();
|
||||
// 2. get the text field
|
||||
let text = json.get("text").and_then(|t| t.as_str()).unwrap_or("");
|
||||
// 3. tokenize the text field
|
||||
let tokens = tokenizer.encode(text, false).unwrap();
|
||||
|
||||
tokens.get_ids().to_vec()
|
||||
return text.to_string();
|
||||
}
|
||||
|
||||
impl Router {
|
||||
@@ -56,25 +113,56 @@ impl Router {
|
||||
worker_urls,
|
||||
current_index: std::sync::atomic::AtomicUsize::new(0),
|
||||
},
|
||||
PolicyConfig::ApproxTreeConfig {
|
||||
tokenizer_path,
|
||||
PolicyConfig::CacheAwareConfig {
|
||||
cache_threshold,
|
||||
cache_routing_prob,
|
||||
eviction_interval_secs,
|
||||
max_tree_size,
|
||||
} => {
|
||||
let mut url_to_tree = HashMap::new();
|
||||
let mut url_to_count = HashMap::new();
|
||||
|
||||
let mut running_queue = HashMap::new();
|
||||
for url in &worker_urls {
|
||||
url_to_tree.insert(url.clone(), RadixTree::new());
|
||||
url_to_count.insert(url.clone(), 0);
|
||||
running_queue.insert(url.clone(), 0);
|
||||
}
|
||||
|
||||
Router::ApproxTree {
|
||||
let mut processed_queue = HashMap::new();
|
||||
for url in &worker_urls {
|
||||
processed_queue.insert(url.clone(), 0);
|
||||
}
|
||||
|
||||
let tree = Arc::new(Mutex::new(Tree::new()));
|
||||
let running_queue = Arc::new(Mutex::new(running_queue));
|
||||
let processed_queue = Arc::new(Mutex::new(processed_queue));
|
||||
|
||||
// Create background eviction thread
|
||||
let tree_clone = Arc::clone(&tree);
|
||||
let processed_queue_clone = Arc::clone(&processed_queue);
|
||||
let eviction_thread = thread::spawn(move || {
|
||||
loop {
|
||||
// Sleep for the specified interval
|
||||
thread::sleep(Duration::from_secs(eviction_interval_secs));
|
||||
|
||||
let locked_tree_clone = tree_clone.lock().unwrap();
|
||||
// Run eviction
|
||||
locked_tree_clone.evict_tenant_data(max_tree_size);
|
||||
|
||||
// Print the process queue
|
||||
let locked_processed_queue = processed_queue_clone.lock().unwrap();
|
||||
println!("Processed Queue: {:?}", locked_processed_queue);
|
||||
}
|
||||
});
|
||||
|
||||
for url in &worker_urls {
|
||||
tree.lock().unwrap().insert(&"".to_string(), url);
|
||||
}
|
||||
|
||||
Router::CacheAware {
|
||||
worker_urls,
|
||||
url_to_tree: Arc::new(Mutex::new(url_to_tree)),
|
||||
// TODO: rust ::from_pretrained cannot load from local file, so use ::from_file to load local file
|
||||
tokenizer: Tokenizer::from_file(tokenizer_path).unwrap(),
|
||||
url_to_count: Arc::new(Mutex::new(url_to_count)),
|
||||
tree,
|
||||
running_queue,
|
||||
processed_queue,
|
||||
cache_threshold,
|
||||
cache_routing_prob,
|
||||
_eviction_thread: Some(eviction_thread),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -84,7 +172,7 @@ impl Router {
|
||||
match self {
|
||||
Router::RoundRobin { worker_urls, .. }
|
||||
| Router::Random { worker_urls }
|
||||
| Router::ApproxTree { worker_urls, .. } => {
|
||||
| Router::CacheAware { worker_urls, .. } => {
|
||||
if worker_urls.is_empty() {
|
||||
None
|
||||
} else {
|
||||
@@ -100,10 +188,7 @@ impl Router {
|
||||
req: HttpRequest,
|
||||
body: Bytes,
|
||||
) -> HttpResponse {
|
||||
let mut input_ids: Vec<u32> = Vec::new();
|
||||
if let Router::ApproxTree { tokenizer, .. } = self {
|
||||
input_ids = get_token_ids_from_request(&body, tokenizer);
|
||||
}
|
||||
let text = get_text_from_request(&body);
|
||||
|
||||
let worker_url = match self {
|
||||
Router::RoundRobin {
|
||||
@@ -125,78 +210,73 @@ impl Router {
|
||||
worker_urls[rand::random::<usize>() % worker_urls.len()].clone()
|
||||
}
|
||||
|
||||
Router::ApproxTree {
|
||||
Router::CacheAware {
|
||||
worker_urls,
|
||||
url_to_tree,
|
||||
url_to_count,
|
||||
tree,
|
||||
running_queue,
|
||||
processed_queue,
|
||||
cache_threshold,
|
||||
cache_routing_prob,
|
||||
..
|
||||
} => {
|
||||
// TODO: pipeline the locks. Release one earlier.
|
||||
// even though the tree is thread-safe, we still put a lock to ensure the whole op (tree read + queue read + tree write + queue write) is atomic to handle some edge cases (e.g. multiple requests with long prefix entering at the same time)
|
||||
|
||||
let mut max_matched_rate = 0.0;
|
||||
let mut max_matched_idx = 0;
|
||||
let mut tree = tree.lock().unwrap();
|
||||
let mut running_queue = running_queue.lock().unwrap();
|
||||
|
||||
let locked_url_to_tree = url_to_tree.lock().unwrap();
|
||||
// Generate a random float between 0 and 1 for probability check
|
||||
let sampled_p: f32 = rand::random();
|
||||
|
||||
// 1. Find the highest matched worker
|
||||
for (i, url) in worker_urls.iter().enumerate() {
|
||||
let tree = locked_url_to_tree.get(url).unwrap();
|
||||
let matched = tree.prefix_match(&input_ids[..]).len();
|
||||
let matched_rate = matched as f32 / input_ids.len() as f32;
|
||||
let selected_url = if sampled_p < *cache_routing_prob {
|
||||
// Cache-aware routing logic
|
||||
let (matched_text, matched_worker) = tree.prefix_match(&text);
|
||||
let matched_rate =
|
||||
matched_text.chars().count() as f32 / text.chars().count() as f32;
|
||||
|
||||
if matched_rate > max_matched_rate {
|
||||
max_matched_rate = matched_rate;
|
||||
max_matched_idx = i;
|
||||
if matched_rate > *cache_threshold {
|
||||
matched_worker.to_string()
|
||||
} else {
|
||||
let m_map: HashMap<String, usize> = tree
|
||||
.tenant_char_count
|
||||
.iter()
|
||||
.map(|entry| (entry.key().clone(), *entry.value()))
|
||||
.collect();
|
||||
|
||||
println!("map: {:?}, mmap: {:?}", tree.get_tenant_char_count(), m_map);
|
||||
|
||||
tree.get_smallest_tenant()
|
||||
}
|
||||
}
|
||||
|
||||
// 2. If the rate is higher than the threshold, select the worker. If not, select the worker with the shortest queue
|
||||
if max_matched_rate > *cache_threshold {
|
||||
worker_urls[max_matched_idx].clone()
|
||||
} else {
|
||||
// pick the shortest queue from url_to_count
|
||||
let locked_url_to_count = url_to_count.lock().unwrap();
|
||||
// Shortest queue routing logic
|
||||
running_queue
|
||||
.iter()
|
||||
.min_by_key(|(_url, &count)| count)
|
||||
.map(|(url, _)| url.clone())
|
||||
.unwrap_or_else(|| worker_urls[0].clone())
|
||||
};
|
||||
|
||||
let mut min_count = std::usize::MAX;
|
||||
let mut min_count_id = 0;
|
||||
// Update running queue
|
||||
let count = running_queue.get_mut(&selected_url).unwrap();
|
||||
*count += 1;
|
||||
|
||||
for (i, url) in worker_urls.iter().enumerate() {
|
||||
let count = locked_url_to_count.get(url).unwrap();
|
||||
if *count < min_count {
|
||||
min_count = *count;
|
||||
min_count_id = i;
|
||||
}
|
||||
}
|
||||
// Update processed queue
|
||||
let mut locked_processed_queue = processed_queue.lock().unwrap();
|
||||
let count = locked_processed_queue.get_mut(&selected_url).unwrap();
|
||||
*count += 1;
|
||||
|
||||
worker_urls[min_count_id].clone()
|
||||
}
|
||||
// Update tree with the new request
|
||||
tree.insert(&text, &selected_url);
|
||||
|
||||
selected_url
|
||||
}
|
||||
};
|
||||
|
||||
if let Router::ApproxTree {
|
||||
url_to_tree,
|
||||
url_to_count,
|
||||
..
|
||||
} = self
|
||||
{
|
||||
// Insert input_ids to the tree
|
||||
let mut locked_url_to_tree = url_to_tree.lock().unwrap();
|
||||
let selected_tree = locked_url_to_tree.get_mut(&worker_url).unwrap();
|
||||
selected_tree.insert(&input_ids[..]);
|
||||
|
||||
let mut locked_url_to_count = url_to_count.lock().unwrap();
|
||||
let count = locked_url_to_count.get_mut(&worker_url).unwrap();
|
||||
*count += 1;
|
||||
}
|
||||
|
||||
// Check if client requested streaming
|
||||
let is_stream = serde_json::from_slice::<serde_json::Value>(&body)
|
||||
.map(|v| v.get("stream").and_then(|s| s.as_bool()).unwrap_or(false))
|
||||
.unwrap_or(false);
|
||||
|
||||
let res = match client
|
||||
.post(format!("{}/generate", worker_url))
|
||||
.post(format!("{}/generate", worker_url.clone()))
|
||||
.header(
|
||||
"Content-Type",
|
||||
req.headers()
|
||||
@@ -216,23 +296,53 @@ impl Router {
|
||||
.unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR);
|
||||
|
||||
if !is_stream {
|
||||
// TODO: do the correction on the tree based on the cached input_ids
|
||||
if let Router::ApproxTree { url_to_count, .. } = self {
|
||||
let mut locked_url_to_count = url_to_count.lock().unwrap();
|
||||
let count = locked_url_to_count.get_mut(&worker_url).unwrap();
|
||||
*count -= 1;
|
||||
}
|
||||
|
||||
match res.bytes().await {
|
||||
// For non-streaming requests, get response first
|
||||
let response = match res.bytes().await {
|
||||
Ok(body) => HttpResponse::build(status).body(body.to_vec()),
|
||||
Err(_) => HttpResponse::InternalServerError().finish(),
|
||||
};
|
||||
|
||||
// Then decrement running queue counter if using CacheAware
|
||||
if let Router::CacheAware { running_queue, .. } = self {
|
||||
if let Ok(mut queue) = running_queue.lock() {
|
||||
if let Some(count) = queue.get_mut(&worker_url) {
|
||||
*count = count.saturating_sub(1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
response
|
||||
} else if let Router::CacheAware { running_queue, .. } = self {
|
||||
let running_queue = Arc::clone(running_queue);
|
||||
let worker_url = worker_url.clone();
|
||||
|
||||
HttpResponse::build(status)
|
||||
.insert_header((CONTENT_TYPE, HeaderValue::from_static("text/event-stream")))
|
||||
.streaming(
|
||||
res.bytes_stream()
|
||||
.map_err(|_| {
|
||||
actix_web::error::ErrorInternalServerError("Failed to read stream")
|
||||
})
|
||||
.inspect(move |bytes| {
|
||||
let bytes = bytes.as_ref().unwrap();
|
||||
if bytes
|
||||
.as_ref()
|
||||
.windows(12)
|
||||
.any(|window| window == b"data: [DONE]")
|
||||
{
|
||||
let mut locked_queue = running_queue.lock().unwrap();
|
||||
let count = locked_queue.get_mut(&worker_url).unwrap();
|
||||
*count = count.saturating_sub(1);
|
||||
// print
|
||||
// println!("streaming is done!!")
|
||||
}
|
||||
}),
|
||||
)
|
||||
} else {
|
||||
// TODO: do the correction on the tree based on the cached input_ids. The streaming might be tricker to handle
|
||||
HttpResponse::build(status)
|
||||
.insert_header((CONTENT_TYPE, HeaderValue::from_static("text/event-stream")))
|
||||
.streaming(res.bytes_stream().map_err(|_| {
|
||||
actix_web::error::ErrorInternalServerError("Failed to read string")
|
||||
actix_web::error::ErrorInternalServerError("Failed to read stream")
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -76,6 +76,7 @@ pub async fn startup(
|
||||
) -> std::io::Result<()> {
|
||||
println!("Starting server on {}:{}", host, port);
|
||||
println!("Worker URLs: {:?}", worker_urls);
|
||||
println!("Policy Config: {:?}", policy_config);
|
||||
|
||||
// Create client once with configuration
|
||||
let client = reqwest::Client::builder()
|
||||
|
||||
1343
rust/src/tree.rs
1343
rust/src/tree.rs
File diff suppressed because it is too large
Load Diff
@@ -1,131 +0,0 @@
|
||||
use sglang_router_rs::tree::RadixTree;
|
||||
|
||||
#[test]
|
||||
fn test_new_tree() {
|
||||
let tree = RadixTree::new();
|
||||
assert_eq!(tree.root.count, 0);
|
||||
assert!(tree.root.children.is_empty());
|
||||
assert!(tree.root.ids.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_single_insertion() {
|
||||
let mut tree = RadixTree::new();
|
||||
tree.insert(&[1, 2, 3]);
|
||||
|
||||
assert_eq!(tree.root.count, 1);
|
||||
assert_eq!(tree.root.children.len(), 1);
|
||||
assert_eq!(tree.root.children[&1].ids, vec![1, 2, 3]);
|
||||
assert_eq!(tree.root.children[&1].count, 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_multiple_insertions_no_split() {
|
||||
let mut tree = RadixTree::new();
|
||||
tree.insert(&[1, 2, 3]);
|
||||
tree.insert(&[4, 5, 6]);
|
||||
|
||||
assert_eq!(tree.root.count, 2);
|
||||
assert_eq!(tree.root.children.len(), 2);
|
||||
assert_eq!(tree.root.children[&1].ids, vec![1, 2, 3]);
|
||||
assert_eq!(tree.root.children[&4].ids, vec![4, 5, 6]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_insertion_with_split() {
|
||||
let mut tree = RadixTree::new();
|
||||
tree.insert(&[1, 2, 3, 4]);
|
||||
tree.insert(&[1, 2, 5, 6]);
|
||||
|
||||
assert_eq!(tree.root.count, 2);
|
||||
assert_eq!(tree.root.children.len(), 1);
|
||||
assert_eq!(tree.root.children[&1].ids, vec![1, 2]);
|
||||
assert_eq!(tree.root.children[&1].children.len(), 2);
|
||||
assert_eq!(tree.root.children[&1].children[&3].ids, vec![3, 4]);
|
||||
assert_eq!(tree.root.children[&1].children[&5].ids, vec![5, 6]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_prefix_match_exact() {
|
||||
let mut tree = RadixTree::new();
|
||||
tree.insert(&[1, 2, 3, 4]);
|
||||
|
||||
assert_eq!(tree.prefix_match(&[1, 2, 3, 4]), &[1, 2, 3, 4]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_prefix_match_partial() {
|
||||
let mut tree = RadixTree::new();
|
||||
tree.insert(&[1, 2, 3, 4]);
|
||||
|
||||
assert_eq!(tree.prefix_match(&[1, 2, 3, 5]), &[1, 2, 3]);
|
||||
assert_eq!(tree.prefix_match(&[1, 2, 5]), &[1, 2]);
|
||||
assert_eq!(tree.prefix_match(&[1, 5]), &[1]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_prefix_match_no_match() {
|
||||
let mut tree = RadixTree::new();
|
||||
tree.insert(&[1, 2, 3, 4]);
|
||||
let empty_slices: &[u32] = &[];
|
||||
assert_eq!(tree.prefix_match(&[5, 6, 7]), empty_slices);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_delete_leaf() {
|
||||
let mut tree = RadixTree::new();
|
||||
tree.insert(&[1, 2, 3]);
|
||||
tree.delete(&[1, 2, 3]);
|
||||
|
||||
assert_eq!(tree.root.count, 0);
|
||||
assert_eq!(tree.root.children.len(), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_delete_with_siblings() {
|
||||
let mut tree = RadixTree::new();
|
||||
tree.insert(&[1, 2, 3]);
|
||||
tree.insert(&[1, 2, 4]);
|
||||
tree.delete(&[1, 2, 3]);
|
||||
|
||||
assert_eq!(tree.root.count, 1);
|
||||
assert_eq!(tree.root.children[&1].children[&4].ids, vec![4]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_multiple_operations() {
|
||||
let mut tree = RadixTree::new();
|
||||
|
||||
// Insert several paths
|
||||
tree.insert(&[1, 2, 3]);
|
||||
tree.insert(&[1, 2, 4]);
|
||||
tree.insert(&[1, 5, 6]);
|
||||
|
||||
// Verify structure
|
||||
assert_eq!(tree.root.count, 3);
|
||||
assert_eq!(tree.prefix_match(&[1, 2, 3]), &[1, 2, 3]);
|
||||
assert_eq!(tree.prefix_match(&[1, 2, 4]), &[1, 2, 4]);
|
||||
assert_eq!(tree.prefix_match(&[1, 5, 6]), &[1, 5, 6]);
|
||||
|
||||
// Delete and verify
|
||||
tree.delete(&[1, 2, 3]);
|
||||
assert_eq!(tree.root.count, 2);
|
||||
assert_eq!(tree.prefix_match(&[1, 2, 3]), &[1, 2]); // Now only matches prefix
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic(expected = "No match found")]
|
||||
fn test_delete_nonexistent() {
|
||||
let mut tree = RadixTree::new();
|
||||
tree.insert(&[1, 2, 3]);
|
||||
tree.delete(&[4, 5, 6]); // Should panic
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_empty_input() {
|
||||
let mut tree = RadixTree::new();
|
||||
let empty_slice: &[u32] = &[];
|
||||
tree.insert(empty_slice);
|
||||
assert_eq!(tree.prefix_match(empty_slice), empty_slice);
|
||||
tree.delete(empty_slice); // Should not panic
|
||||
}
|
||||
Reference in New Issue
Block a user