Rename rust folder to sgl-router (#2464)

Signed-off-by: Ata Fatahi <immrata@gmail.com>
This commit is contained in:
Ata Fatahi
2024-12-12 12:40:41 -05:00
committed by GitHub
parent 2673fa29d4
commit e3b3acfa6f
22 changed files with 13 additions and 13 deletions

3106
sgl-router/Cargo.lock generated Normal file

File diff suppressed because it is too large Load Diff

32
sgl-router/Cargo.toml Normal file
View File

@@ -0,0 +1,32 @@
[package]
name = "sglang_router_rs"
version = "0.0.0"
edition = "2021"
[lib]
name = "sglang_router_rs"
# Pure Rust library: Just omit crate-type (defaults to rlib)
# Python/C binding + Rust library: Use ["cdylib", "rlib"]
crate-type = ["cdylib", "rlib"]
[dependencies]
actix-web = "4.0"
serde = { version = "1.0", features = ["derive"] }
clap = { version = "4.4", features = ["derive"] }
bytes = "1.8.0"
rand = "0.8.5"
reqwest = { version = "0.12.8", features = ["stream", "blocking"] }
futures-util = "0.3"
serde_json = "1.0"
pyo3 = { version = "0.22.5", features = ["extension-module"] }
tokenizers = { version = "0.20.3", features = ["http"] }
dashmap = "6.1.0"
http = "1.1.0"
env_logger = "0.11.5"
log = "0.4.22"
chrono = "0.4.38"
tokio = "1.42.0"
[profile.release]
lto = "thin"
codegen-units = 1

3
sgl-router/MANIFEST.in Normal file
View File

@@ -0,0 +1,3 @@
# Must include:
include Cargo.toml # Rust project configuration
recursive-include src *.rs # Rust source files

87
sgl-router/README.md Normal file
View File

@@ -0,0 +1,87 @@
# SGLang Router
SGLang router is a standalone module implemented in Rust to achieve data parallelism across SGLang instances.
## User docs
Please check https://sgl-project.github.io/router/router.html
## Developer docs
### Prerequisites
- Rust and Cargo installed
```bash
# Install rustup (Rust installer and version manager)
curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh
# Follow the installation prompts, then reload your shell
source $HOME/.cargo/env
# Verify installation
rustc --version
cargo --version
```
- Python with pip installed
### Build Process
#### 1. Build Rust Project
```bash
$ cargo build
```
#### 2. Build Python Binding
##### Option A: Build and Install Wheel
1. Build the wheel package:
```bash
$ pip install setuptools-rust wheel build
$ python -m build
```
2. Install the generated wheel:
```bash
$ pip install <path-to-wheel>
```
If you want one handy command to do build + install for every change you make:
```bash
$ python -m build && pip install --force-reinstall dist/*.whl
```
##### Option B: Development Mode
For development purposes, you can install the package in editable mode:
Warning: Using editable python binding can suffer from performance degradation!! Please build a fresh wheel for every update if you want to test performance.
```bash
$ pip install -e .
```
**Note:** When modifying Rust code, you must rebuild the wheel for changes to take effect.
### CI/CD Setup
The continuous integration pipeline consists of three main steps:
#### 1. Build Wheels
- Uses `cibuildwheel` to create manylinux x86_64 packages
- Compatible with major Linux distributions (Ubuntu, CentOS, etc.)
- Additional configurations can be added to support other OS/architectures
- Reference: [cibuildwheel documentation](https://cibuildwheel.pypa.io/en/stable/)
#### 2. Build Source Distribution
- Creates a source distribution containing the raw, unbuilt code
- Enables `pip` to build the package from source when prebuilt wheels are unavailable
#### 3. Publish to PyPI
- Uploads both wheels and source distribution to PyPI
The CI configuration is based on the [tiktoken workflow](https://github.com/openai/tiktoken/blob/63527649963def8c759b0f91f2eb69a40934e468/.github/workflows/build_wheels.yml#L1).

View File

@@ -0,0 +1,11 @@
# a lightweihgt wrapper on router with argument type and comments
from sglang_router_rs import PolicyType
# no wrapper on policy type => direct export
from .router import Router
__all__ = ["Router", "PolicyType"]
from sglang_router.version import __version__
__all__ += ["__version__"]

View File

@@ -0,0 +1,249 @@
import argparse
import dataclasses
import logging
import sys
from typing import List, Optional
from sglang_router import Router
from sglang_router_rs import PolicyType
def setup_logger():
logger = logging.getLogger("router")
logger.setLevel(logging.INFO)
formatter = logging.Formatter(
"[Router (Python)] %(asctime)s - %(levelname)s - %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
handler = logging.StreamHandler()
handler.setFormatter(formatter)
logger.addHandler(handler)
return logger
@dataclasses.dataclass
class RouterArgs:
# Worker configuration
worker_urls: List[str]
host: str = "127.0.0.1"
port: int = 30000
# Routing policy
policy: str = "cache_aware"
cache_threshold: float = 0.5
balance_abs_threshold: int = 32
balance_rel_threshold: float = 1.0001
eviction_interval: int = 60
max_tree_size: int = 2**24
max_payload_size: int = 4 * 1024 * 1024 # 4MB
verbose: bool = False
@staticmethod
def add_cli_args(
parser: argparse.ArgumentParser,
use_router_prefix: bool = False,
exclude_host_port: bool = False,
):
"""
Add router-specific arguments to an argument parser.
Args:
parser: The argument parser to add arguments to
use_router_prefix: If True, prefix all arguments with 'router-' to avoid conflicts
exclude_host_port: If True, don't add host and port arguments (used when inheriting from server)
"""
prefix = "router-" if use_router_prefix else ""
# Worker configuration
if not exclude_host_port:
parser.add_argument(
"--host",
type=str,
default=RouterArgs.host,
help="Host address to bind the router server",
)
parser.add_argument(
"--port",
type=int,
default=RouterArgs.port,
help="Port number to bind the router server",
)
parser.add_argument(
"--worker-urls",
type=str,
nargs="+",
help="List of worker URLs (e.g., http://worker1:8000 http://worker2:8000)",
)
# Routing policy configuration
parser.add_argument(
f"--{prefix}policy",
type=str,
default=RouterArgs.policy,
choices=["random", "round_robin", "cache_aware"],
help="Load balancing policy to use",
)
parser.add_argument(
f"--{prefix}cache-threshold",
type=float,
default=RouterArgs.cache_threshold,
help="Cache threshold (0.0-1.0) for cache-aware routing",
)
parser.add_argument(
f"--{prefix}balance-abs-threshold",
type=int,
default=RouterArgs.balance_abs_threshold,
help="Load balancing is triggered when (max_load - min_load) > abs_threshold AND max_load > min_load * rel_threshold. Otherwise, use cache aware",
)
parser.add_argument(
f"--{prefix}balance-rel-threshold",
type=float,
default=RouterArgs.balance_rel_threshold,
help="Load balancing is triggered when (max_load - min_load) > abs_threshold AND max_load > min_load * rel_threshold. Otherwise, use cache aware",
)
parser.add_argument(
f"--{prefix}eviction-interval",
type=int,
default=RouterArgs.eviction_interval,
help="Interval in seconds between cache eviction operations",
)
parser.add_argument(
f"--{prefix}max-tree-size",
type=int,
default=RouterArgs.max_tree_size,
help="Maximum size of the approximation tree for cache-aware routing",
)
parser.add_argument(
f"--{prefix}max-payload-size",
type=int,
default=RouterArgs.max_payload_size,
help="Maximum payload size in bytes",
)
parser.add_argument(
f"--{prefix}verbose",
action="store_true",
help="Enable verbose logging",
)
@classmethod
def from_cli_args(
cls, args: argparse.Namespace, use_router_prefix: bool = False
) -> "RouterArgs":
"""
Create RouterArgs instance from parsed command line arguments.
Args:
args: Parsed command line arguments
use_router_prefix: If True, look for arguments with 'router-' prefix
"""
prefix = "router_" if use_router_prefix else ""
return cls(
worker_urls=args.worker_urls,
host=args.host,
port=args.port,
policy=getattr(args, f"{prefix}policy"),
cache_threshold=getattr(args, f"{prefix}cache_threshold"),
balance_abs_threshold=getattr(args, f"{prefix}balance_abs_threshold"),
balance_rel_threshold=getattr(args, f"{prefix}balance_rel_threshold"),
eviction_interval=getattr(args, f"{prefix}eviction_interval"),
max_tree_size=getattr(args, f"{prefix}max_tree_size"),
max_payload_size=getattr(args, f"{prefix}max_payload_size"),
verbose=getattr(args, f"{prefix}verbose", False),
)
def policy_from_str(policy_str: str) -> PolicyType:
"""Convert policy string to PolicyType enum."""
policy_map = {
"random": PolicyType.Random,
"round_robin": PolicyType.RoundRobin,
"cache_aware": PolicyType.CacheAware,
}
return policy_map[policy_str]
def launch_router(args: argparse.Namespace) -> Optional[Router]:
"""
Launch the SGLang router with the configuration from parsed arguments.
Args:
args: Namespace object containing router configuration
Can be either raw argparse.Namespace or converted RouterArgs
Returns:
Router instance if successful, None if failed
"""
logger = logging.getLogger("router")
try:
# Convert to RouterArgs if needed
if not isinstance(args, RouterArgs):
router_args = RouterArgs.from_cli_args(args)
else:
router_args = args
router = Router(
worker_urls=router_args.worker_urls,
policy=policy_from_str(router_args.policy),
host=router_args.host,
port=router_args.port,
cache_threshold=router_args.cache_threshold,
balance_abs_threshold=router_args.balance_abs_threshold,
balance_rel_threshold=router_args.balance_rel_threshold,
eviction_interval_secs=router_args.eviction_interval,
max_tree_size=router_args.max_tree_size,
max_payload_size=router_args.max_payload_size,
verbose=router_args.verbose,
)
router.start()
return router
except Exception as e:
logger.error(f"Error starting router: {e}")
return None
class CustomHelpFormatter(
argparse.RawDescriptionHelpFormatter, argparse.ArgumentDefaultsHelpFormatter
):
"""Custom formatter that preserves both description formatting and shows defaults"""
pass
def parse_router_args(args: List[str]) -> RouterArgs:
"""Parse command line arguments and return RouterArgs instance."""
parser = argparse.ArgumentParser(
description="""SGLang Router - High-performance request distribution across worker nodes
Usage:
This launcher enables starting a router with individual worker instances. It is useful for
multi-node setups or when you want to start workers and router separately.
Examples:
python -m sglang_router.launch_router --worker-urls http://worker1:8000 http://worker2:8000
python -m sglang_router.launch_router --worker-urls http://worker1:8000 http://worker2:8000 --cache-threshold 0.7 --balance-abs-threshold 64 --balance-rel-threshold 1.2
""",
formatter_class=CustomHelpFormatter,
)
RouterArgs.add_cli_args(parser, use_router_prefix=False)
return RouterArgs.from_cli_args(parser.parse_args(args), use_router_prefix=False)
def main() -> None:
logger = setup_logger()
router_args = parse_router_args(sys.argv[1:])
router = launch_router(router_args)
if router is None:
sys.exit(1)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,180 @@
import argparse
import copy
import logging
import multiprocessing as mp
import os
import random
import signal
import sys
import time
from typing import List
import requests
from setproctitle import setproctitle
from sglang_router.launch_router import RouterArgs, launch_router
from sglang.srt.server import launch_server
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import is_port_available
def setup_logger():
logger = logging.getLogger("router")
logger.setLevel(logging.INFO)
formatter = logging.Formatter(
"[Router (Python)] %(asctime)s - %(levelname)s - %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
handler = logging.StreamHandler()
handler.setFormatter(formatter)
logger.addHandler(handler)
return logger
logger = setup_logger()
# Create new process group
def run_server(server_args, dp_rank):
"""
Note:
1. Without os.setpgrp(), all processes share the same PGID. When you press Ctrl+C, the terminal sends SIGINT to all processes in the group simultaneously.
This can cause leaf processes to terminate first, which messes up the cleaning order and produces orphaned processes.
Terminal (PGID=100)
└── Main Python Process (PGID=100)
└── Server Process 1 (PGID=100)
└── Scheduler 1
└── Detokenizer 1
└── Server Process 2 (PGID=100)
└── Scheduler 2
└── Detokenizer 2
2. With os.setpgrp(), the main Python process and its children are in a separate group. Now:
Terminal (PGID=100)
└── Main Python Process (PGID=200)
└── Server Process 1 (PGID=300)
└── Scheduler 1
└── Detokenizer 1
└── Server Process 2 (PGID=400)
└── Scheduler 2
└── Detokenizer 2
"""
# create new process group
os.setpgrp()
setproctitle(f"sglang::server")
# Set SGLANG_DP_RANK environment variable
os.environ["SGLANG_DP_RANK"] = str(dp_rank)
launch_server(server_args)
def launch_server_process(
server_args: ServerArgs, worker_port: int, dp_id: int
) -> mp.Process:
"""Launch a single server process with the given args and port."""
server_args = copy.deepcopy(server_args)
server_args.port = worker_port
server_args.base_gpu_id = dp_id * server_args.tp_size
server_args.dp_size = 1
proc = mp.Process(target=run_server, args=(server_args, dp_id))
proc.start()
return proc
def wait_for_server_health(host: str, port: int, timeout: int = 300) -> bool:
"""Wait for server to be healthy by checking /health endpoint."""
start_time = time.time()
url = f"http://{host}:{port}/health"
while time.time() - start_time < timeout:
try:
response = requests.get(url, timeout=5)
if response.status_code == 200:
return True
except requests.exceptions.RequestException:
pass
time.sleep(1)
return False
def find_available_ports(base_port: int, count: int) -> List[int]:
"""Find consecutive available ports starting from base_port."""
available_ports = []
current_port = base_port
while len(available_ports) < count:
if is_port_available(current_port):
available_ports.append(current_port)
current_port += random.randint(100, 1000)
return available_ports
def cleanup_processes(processes: List[mp.Process]):
for process in processes:
logger.info(f"Terminating process {process.pid}")
process.terminate()
logger.info("All processes terminated")
def main():
# CUDA runtime isn't fork-safe, which can lead to subtle bugs or crashes
mp.set_start_method("spawn")
parser = argparse.ArgumentParser(
description="Launch SGLang router and server processes"
)
ServerArgs.add_cli_args(parser)
RouterArgs.add_cli_args(parser, use_router_prefix=True, exclude_host_port=True)
parser.add_argument(
"--router-dp-worker-base-port",
type=int,
default=31000,
help="Base port number for data parallel workers",
)
args = parser.parse_args()
server_args = ServerArgs.from_cli_args(args)
router_args = RouterArgs.from_cli_args(args, use_router_prefix=True)
# Find available ports for workers
worker_ports = find_available_ports(
args.router_dp_worker_base_port, server_args.dp_size
)
# Start server processes
server_processes = []
for i, worker_port in enumerate(worker_ports):
logger.info(f"Launching DP server process {i} on port {worker_port}")
proc = launch_server_process(server_args, worker_port, i)
server_processes.append(proc)
signal.signal(signal.SIGINT, lambda sig, frame: cleanup_processes(server_processes))
signal.signal(
signal.SIGTERM, lambda sig, frame: cleanup_processes(server_processes)
)
signal.signal(
signal.SIGQUIT, lambda sig, frame: cleanup_processes(server_processes)
)
# Update router args with worker URLs
router_args.worker_urls = [
f"http://{server_args.host}:{port}" for port in worker_ports
]
# Start the router
router = launch_router(router_args)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,67 @@
from typing import List, Optional
from sglang_router_rs import PolicyType
from sglang_router_rs import Router as _Router
class Router:
"""
A high-performance router for distributing requests across worker nodes.
Args:
worker_urls: List of URLs for worker nodes that will handle requests. Each URL should include
the protocol, host, and port (e.g., ['http://worker1:8000', 'http://worker2:8000'])
policy: Load balancing policy to use. Options:
- PolicyType.Random: Randomly select workers
- PolicyType.RoundRobin: Distribute requests in round-robin fashion
- PolicyType.CacheAware: Distribute requests based on cache state and load balance
host: Host address to bind the router server. Default: '127.0.0.1'
port: Port number to bind the router server. Default: 3001
cache_threshold: Cache threshold (0.0-1.0) for cache-aware routing. Routes to cached worker
if the match rate exceeds threshold, otherwise routes to the worker with the smallest
tree. Default: 0.5
balance_abs_threshold: Load balancing is triggered when (max_load - min_load) > abs_threshold
AND max_load > min_load * rel_threshold. Otherwise, use cache aware. Default: 32
balance_rel_threshold: Load balancing is triggered when (max_load - min_load) > abs_threshold
AND max_load > min_load * rel_threshold. Otherwise, use cache aware. Default: 1.0001
eviction_interval_secs: Interval in seconds between cache eviction operations in cache-aware
routing. Default: 60
max_payload_size: Maximum payload size in bytes. Default: 4MB
max_tree_size: Maximum size of the approximation tree for cache-aware routing. Default: 2^24
verbose: Enable verbose logging. Default: False
"""
def __init__(
self,
worker_urls: List[str],
policy: PolicyType = PolicyType.RoundRobin,
host: str = "127.0.0.1",
port: int = 3001,
cache_threshold: float = 0.50,
balance_abs_threshold: int = 32,
balance_rel_threshold: float = 1.0001,
eviction_interval_secs: int = 60,
max_tree_size: int = 2**24,
max_payload_size: int = 4 * 1024 * 1024, # 4MB
verbose: bool = False,
):
self._router = _Router(
worker_urls=worker_urls,
policy=policy,
host=host,
port=port,
cache_threshold=cache_threshold,
balance_abs_threshold=balance_abs_threshold,
balance_rel_threshold=balance_rel_threshold,
eviction_interval_secs=eviction_interval_secs,
max_tree_size=max_tree_size,
max_payload_size=max_payload_size,
verbose=verbose,
)
def start(self) -> None:
"""Start the router server.
This method blocks until the server is shut down.
"""
self._router.start()

View File

@@ -0,0 +1 @@
__version__ = "0.1.1"

View File

@@ -0,0 +1,19 @@
import argparse
import glob
from sglang.test.test_utils import run_unittest_files
if __name__ == "__main__":
arg_parser = argparse.ArgumentParser()
arg_parser.add_argument(
"--timeout-per-file",
type=int,
default=1000,
help="The time limit for running one file in seconds.",
)
args = arg_parser.parse_args()
files = glob.glob("**/test_*.py", recursive=True)
exit_code = run_unittest_files(files, args.timeout_per_file)
exit(exit_code)

View File

@@ -0,0 +1,67 @@
import multiprocessing
import time
import unittest
from types import SimpleNamespace
def terminate_process(process: multiprocessing.Process, timeout: float = 1.0) -> None:
"""Terminate a process gracefully, with forced kill as fallback.
Args:
process: The process to terminate
timeout: Seconds to wait for graceful termination before forcing kill
"""
if not process.is_alive():
return
process.terminate()
process.join(timeout=timeout)
if process.is_alive():
process.kill() # Force kill if terminate didn't work
process.join()
class TestLaunchRouter(unittest.TestCase):
def test_launch_router_no_exception(self):
# Create SimpleNamespace with default arguments
args = SimpleNamespace(
worker_urls=["http://localhost:8000"],
host="127.0.0.1",
port=30000,
policy="cache_aware",
cache_threshold=0.5,
balance_abs_threshold=32,
balance_rel_threshold=1.0001,
eviction_interval=60,
max_tree_size=2**24,
max_payload_size=4 * 1024 * 1024, # 4MB
verbose=False,
)
def run_router():
try:
from sglang_router.launch_router import launch_router
router = launch_router(args)
if router is None:
return 1
return 0
except Exception as e:
print(e)
return 1
# Start router in separate process
process = multiprocessing.Process(target=run_router)
try:
process.start()
# Wait 3 seconds
time.sleep(3)
# Process is still running means router started successfully
self.assertTrue(process.is_alive())
finally:
terminate_process(process)
if __name__ == "__main__":
unittest.main()

View File

@@ -0,0 +1,338 @@
import socket
import subprocess
import time
import unittest
from types import SimpleNamespace
import requests
from sglang.srt.utils import kill_process_tree
from sglang.test.run_eval import run_eval
from sglang.test.test_utils import (
DEFAULT_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
)
def popen_launch_router(
model: str,
base_url: str,
dp_size: int,
timeout: float,
policy: str = "cache_aware",
max_payload_size: int = None,
):
"""
Launch the router server process.
Args:
model: Model path/name
base_url: Server base URL
dp_size: Data parallel size
timeout: Server launch timeout
policy: Router policy, one of "cache_aware", "round_robin", "random"
max_payload_size: Maximum payload size in bytes
"""
_, host, port = base_url.split(":")
host = host[2:]
command = [
"python3",
"-m",
"sglang_router.launch_server",
"--model-path",
model,
"--host",
host,
"--port",
port,
"--dp",
str(dp_size),
"--router-eviction-interval",
"5",
"--router-policy",
policy,
]
if max_payload_size is not None:
command.extend(["--router-max-payload-size", str(max_payload_size)])
process = subprocess.Popen(command, stdout=None, stderr=None)
start_time = time.time()
with requests.Session() as session:
while time.time() - start_time < timeout:
try:
response = session.get(f"{base_url}/health")
if response.status_code == 200:
print(f"Router {base_url} is healthy")
return process
except requests.RequestException:
pass
time.sleep(10)
raise TimeoutError("Router failed to start within the timeout period.")
def find_available_port():
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(("127.0.0.1", 0))
return s.getsockname()[1]
def popen_launch_server(
model: str,
base_url: str,
timeout: float,
):
_, host, port = base_url.split(":")
host = host[2:]
command = [
"python3",
"-m",
"sglang.launch_server",
"--model-path",
model,
"--host",
host,
"--port",
port,
"--base-gpu-id",
"1",
]
process = subprocess.Popen(command, stdout=None, stderr=None)
# intentionally don't wait and defer the job to the router health check
return process
def terminate_and_wait(process, timeout=300):
"""Terminate a process and wait until it is terminated.
Args:
process: subprocess.Popen object
timeout: maximum time to wait in seconds
Raises:
TimeoutError: if process does not terminate within timeout
"""
if process is None:
return
process.terminate()
start_time = time.time()
while process.poll() is None:
print(f"Terminating process {process.pid}")
if time.time() - start_time > timeout:
raise TimeoutError(
f"Process {process.pid} failed to terminate within {timeout}s"
)
time.sleep(1)
print(f"Process {process.pid} is successfully terminated")
class TestLaunchServer(unittest.TestCase):
def setUp(self):
self.model = DEFAULT_MODEL_NAME_FOR_TEST
self.base_url = DEFAULT_URL_FOR_TEST
self.process = None
self.other_process = []
def tearDown(self):
print("Running tearDown...")
if self.process:
terminate_and_wait(self.process)
for process in self.other_process:
terminate_and_wait(process)
print("tearDown done")
def test_1_mmlu(self):
print("Running test_1_mmlu...")
# DP size = 2
self.process = popen_launch_router(
self.model,
self.base_url,
dp_size=2,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
policy="cache_aware",
)
args = SimpleNamespace(
base_url=self.base_url,
model=self.model,
eval_name="mmlu",
num_examples=64,
num_threads=32,
temperature=0.1,
)
metrics = run_eval(args)
score = metrics["score"]
THRESHOLD = 0.65
passed = score >= THRESHOLD
msg = f"MMLU test {'passed' if passed else 'failed'} with score {score:.3f} (threshold: {THRESHOLD})"
self.assertGreaterEqual(score, THRESHOLD, msg)
def test_2_add_and_remove_worker(self):
print("Running test_2_add_and_remove_worker...")
# DP size = 1
self.process = popen_launch_router(
self.model,
self.base_url,
dp_size=1,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
policy="round_robin", # use round robin to make sure every worker processes requests
)
# 1. start a worker
port = find_available_port()
worker_url = f"http://127.0.0.1:{port}"
worker_process = popen_launch_server(
self.model, worker_url, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
)
self.other_process.append(worker_process)
# 2. use /add_worker api to add it the the router. It will be used by router after it is healthy
with requests.Session() as session:
response = session.post(f"{self.base_url}/add_worker?url={worker_url}")
print(f"status code: {response.status_code}, response: {response.text}")
self.assertEqual(response.status_code, 200)
# 3. run mmlu
args = SimpleNamespace(
base_url=self.base_url,
model=self.model,
eval_name="mmlu",
num_examples=64,
num_threads=32,
temperature=0.1,
)
metrics = run_eval(args)
score = metrics["score"]
THRESHOLD = 0.65
passed = score >= THRESHOLD
msg = f"MMLU test {'passed' if passed else 'failed'} with score {score:.3f} (threshold: {THRESHOLD})"
self.assertGreaterEqual(score, THRESHOLD, msg)
# 4. use /remove_worker api to remove it from the router
with requests.Session() as session:
response = session.post(f"{self.base_url}/remove_worker?url={worker_url}")
print(f"status code: {response.status_code}, response: {response.text}")
self.assertEqual(response.status_code, 200)
# 5. run mmlu again
metrics = run_eval(args)
score = metrics["score"]
THRESHOLD = 0.65
passed = score >= THRESHOLD
msg = f"MMLU test {'passed' if passed else 'failed'} with score {score:.3f} (threshold: {THRESHOLD})"
self.assertGreaterEqual(score, THRESHOLD, msg)
def test_3_lazy_fault_tolerance(self):
print("Running test_3_lazy_fault_tolerance...")
# DP size = 1
self.process = popen_launch_router(
self.model,
self.base_url,
dp_size=1,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
policy="round_robin",
)
# 1. start a worker
port = find_available_port()
worker_url = f"http://127.0.0.1:{port}"
worker_process = popen_launch_server(
self.model, worker_url, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
)
self.other_process.append(worker_process)
# 2. use /add_worker api to add it the the router. It will be used by router after it is healthy
with requests.Session() as session:
response = session.post(f"{self.base_url}/add_worker?url={worker_url}")
print(f"status code: {response.status_code}, response: {response.text}")
self.assertEqual(response.status_code, 200)
# Start a thread to kill the worker after 10 seconds to mimic abrupt worker failure
def kill_worker():
time.sleep(10)
kill_process_tree(worker_process.pid)
print("Worker process killed")
import threading
kill_thread = threading.Thread(target=kill_worker)
kill_thread.daemon = True
kill_thread.start()
# 3. run mmlu
args = SimpleNamespace(
base_url=self.base_url,
model=self.model,
eval_name="mmlu",
num_examples=256,
num_threads=32,
temperature=0.1,
)
metrics = run_eval(args)
score = metrics["score"]
THRESHOLD = 0.65
passed = score >= THRESHOLD
msg = f"MMLU test {'passed' if passed else 'failed'} with score {score:.3f} (threshold: {THRESHOLD})"
self.assertGreaterEqual(score, THRESHOLD, msg)
def test_4_payload_size(self):
print("Running test_4_payload_size...")
# Start router with 3MB limit
self.process = popen_launch_router(
self.model,
self.base_url,
dp_size=1,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
policy="round_robin",
max_payload_size=1 * 1024 * 1024, # 1MB limit
)
# Test case 1: Payload just under 1MB should succeed
payload_0_5_mb = {
"text": "x" * int(0.5 * 1024 * 1024), # 0.5MB of text
"temperature": 0.0,
}
with requests.Session() as session:
response = session.post(
f"{self.base_url}/generate",
json=payload_0_5_mb,
headers={"Content-Type": "application/json"},
)
self.assertEqual(
response.status_code,
200,
f"0.5MB payload should succeed but got status {response.status_code}",
)
# Test case 2: Payload over 1MB should fail
payload_1_plus_mb = {
"text": "x" * int((1.2 * 1024 * 1024)), # 1.2MB of text
"temperature": 0.0,
}
with requests.Session() as session:
response = session.post(
f"{self.base_url}/generate",
json=payload_1_plus_mb,
headers={"Content-Type": "application/json"},
)
self.assertEqual(
response.status_code,
413, # Payload Too Large
f"1.2MB payload should fail with 413 but got status {response.status_code}",
)
if __name__ == "__main__":
unittest.main()

26
sgl-router/pyproject.toml Normal file
View File

@@ -0,0 +1,26 @@
[build-system]
requires = ["setuptools>=45", "wheel", "setuptools-rust>=1.5.2"]
build-backend = "setuptools.build_meta"
[project]
name = "sglang-router"
version = "0.1.1"
description = "SGLang router is a standalone module implemented in Rust to achieve data parallelism across SGLang instances."
authors = [{name = "Byron Hsu", email = "byronhsu1230@gmail.com"}]
requires-python = ">=3.8"
readme = "README.md"
license = { file = "LICENSE" }
classifiers = [
"Programming Language :: Python :: Implementation :: CPython",
"Programming Language :: Rust",
"Programming Language :: Python :: 3",
]
# https://github.com/PyO3/setuptools-rust?tab=readme-ov-file
[tool.setuptools.packages]
find = { where = ["py_src"] }
[[tool.setuptools-rust.ext-modules]]
target = "sglang_router_rs"
path = "Cargo.toml"
binding = "PyO3"

108
sgl-router/src/lib.rs Normal file
View File

@@ -0,0 +1,108 @@
use pyo3::prelude::*;
pub mod router;
pub mod server;
pub mod tree;
#[pyclass(eq)]
#[derive(Clone, PartialEq)]
pub enum PolicyType {
Random,
RoundRobin,
CacheAware,
}
#[pyclass]
struct Router {
host: String,
port: u16,
worker_urls: Vec<String>,
policy: PolicyType,
cache_threshold: f32,
balance_abs_threshold: usize,
balance_rel_threshold: f32,
eviction_interval_secs: u64,
max_tree_size: usize,
max_payload_size: usize,
verbose: bool,
}
#[pymethods]
impl Router {
#[new]
#[pyo3(signature = (
worker_urls,
policy = PolicyType::RoundRobin,
host = String::from("127.0.0.1"),
port = 3001,
cache_threshold = 0.50,
balance_abs_threshold = 32,
balance_rel_threshold = 1.0001,
eviction_interval_secs = 60,
max_tree_size = 2usize.pow(24),
max_payload_size = 4 * 1024 * 1024,
verbose = false
))]
fn new(
worker_urls: Vec<String>,
policy: PolicyType,
host: String,
port: u16,
cache_threshold: f32,
balance_abs_threshold: usize,
balance_rel_threshold: f32,
eviction_interval_secs: u64,
max_tree_size: usize,
max_payload_size: usize,
verbose: bool,
) -> PyResult<Self> {
Ok(Router {
host,
port,
worker_urls,
policy,
cache_threshold,
balance_abs_threshold,
balance_rel_threshold,
eviction_interval_secs,
max_tree_size,
max_payload_size,
verbose,
})
}
fn start(&self) -> PyResult<()> {
let policy_config = match &self.policy {
PolicyType::Random => router::PolicyConfig::RandomConfig,
PolicyType::RoundRobin => router::PolicyConfig::RoundRobinConfig,
PolicyType::CacheAware => router::PolicyConfig::CacheAwareConfig {
cache_threshold: self.cache_threshold,
balance_abs_threshold: self.balance_abs_threshold,
balance_rel_threshold: self.balance_rel_threshold,
eviction_interval_secs: self.eviction_interval_secs,
max_tree_size: self.max_tree_size,
},
};
actix_web::rt::System::new().block_on(async move {
server::startup(server::ServerConfig {
host: self.host.clone(),
port: self.port,
worker_urls: self.worker_urls.clone(),
policy_config,
verbose: self.verbose,
max_payload_size: self.max_payload_size,
})
.await
.unwrap();
});
Ok(())
}
}
#[pymodule]
fn sglang_router_rs(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<PolicyType>()?;
m.add_class::<Router>()?;
Ok(())
}

702
sgl-router/src/router.rs Normal file
View File

@@ -0,0 +1,702 @@
use crate::tree::Tree;
use actix_web::http::header::{HeaderValue, CONTENT_TYPE};
use actix_web::{HttpRequest, HttpResponse};
use bytes::Bytes;
use futures_util::{StreamExt, TryStreamExt};
use log::{debug, info, warn};
use std::collections::HashMap;
use std::fmt::Debug;
use std::sync::atomic::AtomicUsize;
use std::sync::{Arc, Mutex, RwLock};
use std::thread;
use std::time::Duration;
use tokio;
#[derive(Debug)]
pub enum Router {
RoundRobin {
worker_urls: Arc<RwLock<Vec<String>>>,
current_index: AtomicUsize,
},
Random {
worker_urls: Arc<RwLock<Vec<String>>>,
},
CacheAware {
/*
Cache-Aware Load Balancing Router
This router combines two strategies to optimize both cache utilization and request distribution:
1. Cache-Aware Routing (Approximate Tree)
2. Load Balancing (Shortest Queue with Balance Thresholds)
The router dynamically switches between these strategies based on load conditions:
- Uses load balancing when the system is imbalanced
- Uses cache-aware routing when the system is balanced
A system is considered imbalanced if both conditions are met:
1. (max - min) > abs_threshold
2. max > rel_threshold * min
Strategy Details:
1. Cache-Aware Routing (Approximate Tree)
-------------------------------------------
This strategy maintains an approximate radix tree for each worker based on request history,
eliminating the need for direct cache state queries. The tree stores raw text characters
instead of token IDs to avoid tokenization overhead.
Process:
a. For each request, find the worker with the highest prefix match
b. If match rate > cache_threshold:
Route to the worker with highest match (likely has relevant data cached)
c. If match rate ≤ cache_threshold:
Route to the worker with smallest tree size (most available cache capacity)
d. Background maintenance:
Periodically evict least recently used leaf nodes to prevent memory overflow
2. Load Balancing (Shortest Queue)
-------------------------------------------
This strategy tracks pending request counts per worker and routes new requests
to the least busy worker when the system is detected to be imbalanced.
Configuration Parameters:
------------------------
1. cache_threshold: (float, 0.0 to 1.0)
Minimum prefix match ratio to use highest-match routing.
Below this threshold, routes to worker with most available cache space.
2. balance_abs_threshold: (integer)
Absolute difference threshold for load imbalance detection.
System is potentially imbalanced if (max_load - min_load) > abs_threshold
3. balance_rel_threshold: (float)
Relative ratio threshold for load imbalance detection.
System is potentially imbalanced if max_load > min_load * rel_threshold
Used in conjunction with abs_threshold to determine final imbalance state.
4. eviction_interval_secs: (integer)
Interval between LRU eviction cycles for the approximate trees.
5. max_tree_size: (integer)
Maximum nodes per tree. When exceeded, LRU leaf nodes are evicted
during the next eviction cycle.
*/
worker_urls: Arc<RwLock<Vec<String>>>,
tree: Arc<Mutex<Tree>>,
running_queue: Arc<Mutex<HashMap<String, usize>>>,
processed_queue: Arc<Mutex<HashMap<String, usize>>>,
cache_threshold: f32,
balance_abs_threshold: usize,
balance_rel_threshold: f32,
_eviction_thread: Option<thread::JoinHandle<()>>,
},
}
#[derive(Debug, Clone)]
pub enum PolicyConfig {
RandomConfig,
RoundRobinConfig,
CacheAwareConfig {
cache_threshold: f32,
balance_abs_threshold: usize,
balance_rel_threshold: f32,
eviction_interval_secs: u64,
max_tree_size: usize,
},
}
impl Router {
pub fn new(worker_urls: Vec<String>, policy_config: PolicyConfig) -> Result<Self, String> {
// Wait until all workers are healthy
Self::wait_for_healthy_workers(&worker_urls, 300, 10)?;
// Create router based on policy...
Ok(match policy_config {
PolicyConfig::RandomConfig => Router::Random {
worker_urls: Arc::new(RwLock::new(worker_urls)),
},
PolicyConfig::RoundRobinConfig => Router::RoundRobin {
worker_urls: Arc::new(RwLock::new(worker_urls)),
current_index: std::sync::atomic::AtomicUsize::new(0),
},
PolicyConfig::CacheAwareConfig {
cache_threshold,
balance_abs_threshold,
balance_rel_threshold,
eviction_interval_secs,
max_tree_size,
} => {
let mut running_queue = HashMap::new();
for url in &worker_urls {
running_queue.insert(url.clone(), 0);
}
let mut processed_queue = HashMap::new();
for url in &worker_urls {
processed_queue.insert(url.clone(), 0);
}
let tree = Arc::new(Mutex::new(Tree::new()));
let running_queue = Arc::new(Mutex::new(running_queue));
let processed_queue = Arc::new(Mutex::new(processed_queue));
// Create background eviction thread
let tree_clone = Arc::clone(&tree);
let processed_queue_clone = Arc::clone(&processed_queue);
let running_queue_clone = Arc::clone(&running_queue);
let eviction_thread = thread::spawn(move || {
loop {
// Sleep for the specified interval
thread::sleep(Duration::from_secs(eviction_interval_secs));
let locked_tree_clone = tree_clone.lock().unwrap();
// Run eviction
locked_tree_clone.evict_tenant_by_size(max_tree_size);
// Print the process queue
let locked_processed_queue = processed_queue_clone.lock().unwrap();
info!("Processed Queue: {:?}", locked_processed_queue);
// Print the running queue
let locked_running_queue = running_queue_clone.lock().unwrap();
info!("Running Queue: {:?}", locked_running_queue);
}
});
for url in &worker_urls {
tree.lock().unwrap().insert(&"".to_string(), url);
}
Router::CacheAware {
worker_urls: Arc::new(RwLock::new(worker_urls)),
tree,
running_queue,
processed_queue,
cache_threshold,
balance_abs_threshold,
balance_rel_threshold,
_eviction_thread: Some(eviction_thread),
}
}
})
}
fn wait_for_healthy_workers(
worker_urls: &[String],
timeout_secs: u64,
interval_secs: u64,
) -> Result<(), String> {
let start_time = std::time::Instant::now();
let sync_client = reqwest::blocking::Client::new();
loop {
if start_time.elapsed() > Duration::from_secs(timeout_secs) {
return Err(format!(
"Timeout {}s waiting for workers to become healthy",
timeout_secs
));
}
let mut all_healthy = true;
let mut unhealthy_workers = Vec::new();
for url in worker_urls {
match sync_client.get(&format!("{}/health", url)).send() {
Ok(res) => {
if !res.status().is_success() {
info!(
"Worker {} health check is pending with status: {}.",
url,
res.status()
);
all_healthy = false;
unhealthy_workers.push((url, format!("Status: {}", res.status())));
}
}
Err(e) => {
info!("Worker {} health check is pending with error: {}", url, e);
all_healthy = false;
unhealthy_workers.push((url, format!("Error: {}", e)));
}
}
}
if all_healthy {
info!("All workers are healthy");
return Ok(());
} else {
info!("Unhealthy workers:");
for (url, reason) in &unhealthy_workers {
info!(" {} - {}", url, reason);
}
thread::sleep(Duration::from_secs(interval_secs));
}
}
}
fn select_first_worker(&self) -> Result<String, String> {
match self {
Router::RoundRobin { worker_urls, .. }
| Router::Random { worker_urls }
| Router::CacheAware { worker_urls, .. } => {
if worker_urls.read().unwrap().is_empty() {
Err("No workers are available".to_string())
} else {
Ok(worker_urls.read().unwrap()[0].clone())
}
}
}
}
async fn send_request(
&self,
client: &reqwest::Client,
worker_url: &str,
route: &str,
) -> HttpResponse {
match client.get(format!("{}{}", worker_url, route)).send().await {
Ok(res) => {
let status = actix_web::http::StatusCode::from_u16(res.status().as_u16())
.unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR);
match res.bytes().await {
Ok(body) => HttpResponse::build(status).body(body.to_vec()),
Err(e) => HttpResponse::InternalServerError()
.body(format!("Failed to read response body: {}", e)),
}
}
Err(e) => HttpResponse::InternalServerError().body(format!(
"Failed to send request to worker {}: {}",
worker_url, e
)),
}
}
pub async fn route_to_first(&self, client: &reqwest::Client, route: &str) -> HttpResponse {
const MAX_REQUEST_RETRIES: u32 = 3;
const MAX_TOTAL_RETRIES: u32 = 6;
let mut total_retries = 0;
while total_retries < MAX_TOTAL_RETRIES {
match self.select_first_worker() {
Ok(worker_url) => {
let mut request_retries = 0;
// Try the same worker multiple times
while request_retries < MAX_REQUEST_RETRIES {
if total_retries >= 1 {
info!("Retrying request after {} failed attempts", total_retries);
}
let response = self.send_request(client, &worker_url, route).await;
if response.status().is_success() {
return response;
}
warn!(
"Request to {} failed (attempt {}/{})",
worker_url,
request_retries + 1,
MAX_REQUEST_RETRIES
);
request_retries += 1;
total_retries += 1;
if request_retries == MAX_REQUEST_RETRIES {
warn!("Removing failed worker: {}", worker_url);
self.remove_worker(&worker_url);
break;
}
}
}
Err(e) => return HttpResponse::InternalServerError().body(e),
}
}
HttpResponse::InternalServerError().body("All retry attempts failed")
}
fn get_text_from_request(&self, body: &Bytes, route: &str) -> String {
// convert body to json
let json = serde_json::from_slice::<serde_json::Value>(body).unwrap();
if route == "generate" {
// get the "text" field
let text = json.get("text").and_then(|t| t.as_str()).unwrap_or("");
return text.to_string();
} else if route == "v1/chat/completions" {
// get the messages field as raw text
if let Some(messages) = json.get("messages") {
// Convert messages back to a string, preserving all JSON formatting
return serde_json::to_string(messages).unwrap_or_default();
}
} else if route == "v1/completions" {
let prompt = json.get("prompt").and_then(|t| t.as_str()).unwrap_or("");
return prompt.to_string();
}
return "".to_string();
}
// TODO: return Result<String, String> instead of panicking
fn select_generate_worker(&self, body: &Bytes, route: &str) -> String {
let text = self.get_text_from_request(&body, route);
let worker_url = match self {
Router::RoundRobin {
worker_urls,
current_index,
} => {
let idx = current_index
.fetch_update(
std::sync::atomic::Ordering::SeqCst,
std::sync::atomic::Ordering::SeqCst,
|x| Some((x + 1) % worker_urls.read().unwrap().len()),
)
.unwrap();
worker_urls.read().unwrap()[idx].clone()
}
Router::Random { worker_urls } => worker_urls.read().unwrap()
[rand::random::<usize>() % worker_urls.read().unwrap().len()]
.clone(),
Router::CacheAware {
worker_urls,
tree,
running_queue,
processed_queue,
cache_threshold,
balance_abs_threshold,
balance_rel_threshold,
..
} => {
// TODO: delay scheduling if cache hit rate is high because it may cause imbalance. prioritize low hit rate ones
let tree = tree.lock().unwrap();
let mut running_queue = running_queue.lock().unwrap();
// Get current load statistics
let max_load = *running_queue.values().max().unwrap_or(&0);
let min_load = *running_queue.values().min().unwrap_or(&0);
// Load is considered imbalanced if:
// 1. (max - min) > abs_threshold AND
// 2. max > rel_threshold * min
let is_imbalanced = max_load.saturating_sub(min_load) > *balance_abs_threshold
&& (max_load as f32) > (min_load as f32 * balance_rel_threshold);
let selected_url = if is_imbalanced {
// Log load balancing trigger and current queue state
info!(
"Load balancing triggered due to workload imbalance:\n\
Max load: {}, Min load: {}\n\
Current running queue: {:?}",
max_load, min_load, running_queue
);
// Use shortest queue routing when load is imbalanced
running_queue
.iter()
.min_by_key(|(_url, &count)| count)
.map(|(url, _)| url.clone())
.unwrap_or_else(|| worker_urls.read().unwrap()[0].clone())
} else {
// Use cache-aware routing when load is balanced
let (matched_text, matched_worker) = tree.prefix_match(&text);
let matched_rate =
matched_text.chars().count() as f32 / text.chars().count() as f32;
if matched_rate > *cache_threshold {
matched_worker.to_string()
} else {
tree.get_smallest_tenant()
}
};
// Update queues and tree
*running_queue.get_mut(&selected_url).unwrap() += 1;
*processed_queue
.lock()
.unwrap()
.get_mut(&selected_url)
.unwrap() += 1;
tree.insert(&text, &selected_url);
selected_url
}
};
worker_url
}
async fn send_generate_request(
&self,
client: &reqwest::Client,
req: &HttpRequest,
body: &Bytes,
route: &str,
worker_url: &str,
) -> HttpResponse {
let is_stream = serde_json::from_slice::<serde_json::Value>(&body)
.map(|v| v.get("stream").and_then(|s| s.as_bool()).unwrap_or(false))
.unwrap_or(false);
let res = match client
.post(format!("{}{}", worker_url, route))
.header(
"Content-Type",
req.headers()
.get("Content-Type")
.and_then(|h| h.to_str().ok())
.unwrap_or("application/json"),
)
.body(body.to_vec())
.send()
.await
{
Ok(res) => res,
Err(_) => return HttpResponse::InternalServerError().finish(),
};
let status = actix_web::http::StatusCode::from_u16(res.status().as_u16())
.unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR);
if !is_stream {
// For non-streaming requests, get response first
let response = match res.bytes().await {
Ok(body) => HttpResponse::build(status).body(body.to_vec()),
Err(e) => {
let error_msg = format!("Failed to get response body: {}", e);
HttpResponse::InternalServerError().body(error_msg)
}
};
// Then decrement running queue counter if using CacheAware
if let Router::CacheAware { running_queue, .. } = self {
if let Ok(mut queue) = running_queue.lock() {
if let Some(count) = queue.get_mut(worker_url) {
*count = count.saturating_sub(1);
}
}
}
response
} else if let Router::CacheAware { running_queue, .. } = self {
let running_queue = Arc::clone(running_queue);
let worker_url = worker_url.to_string();
HttpResponse::build(status)
.insert_header((CONTENT_TYPE, HeaderValue::from_static("text/event-stream")))
.streaming(
res.bytes_stream()
.map_err(|_| {
actix_web::error::ErrorInternalServerError("Failed to read stream")
})
.inspect(move |bytes| {
let bytes = bytes.as_ref().unwrap();
if bytes
.as_ref()
.windows(12)
.any(|window| window == b"data: [DONE]")
{
let mut locked_queue = running_queue.lock().unwrap();
let count = locked_queue.get_mut(&worker_url).unwrap();
*count = count.saturating_sub(1);
debug!("Streaming is done!!")
}
}),
)
} else {
HttpResponse::build(status)
.insert_header((CONTENT_TYPE, HeaderValue::from_static("text/event-stream")))
.streaming(res.bytes_stream().map_err(|_| {
actix_web::error::ErrorInternalServerError("Failed to read stream")
}))
}
}
pub async fn route_generate_request(
&self,
client: &reqwest::Client,
req: &HttpRequest,
body: &Bytes,
route: &str,
) -> HttpResponse {
const MAX_REQUEST_RETRIES: u32 = 3;
const MAX_TOTAL_RETRIES: u32 = 6;
let mut total_retries = 0;
while total_retries < MAX_TOTAL_RETRIES {
let worker_url = self.select_generate_worker(body, route);
let mut request_retries = 0;
// Try the same worker multiple times
while request_retries < MAX_REQUEST_RETRIES {
if total_retries >= 1 {
info!("Retrying request after {} failed attempts", total_retries);
}
let response = self
.send_generate_request(client, req, body, route, &worker_url)
.await;
if response.status().is_success() {
return response;
}
warn!(
"Generate request to {} failed (attempt {}/{})",
worker_url,
request_retries + 1,
MAX_REQUEST_RETRIES
);
request_retries += 1;
total_retries += 1;
if request_retries == MAX_REQUEST_RETRIES {
warn!("Removing failed worker: {}", worker_url);
self.remove_worker(&worker_url);
break;
}
}
}
HttpResponse::InternalServerError().body("All retry attempts failed")
}
pub async fn add_worker(&self, worker_url: &str) -> Result<String, String> {
let interval_secs = 10; // check every 10 seconds
let timeout_secs = 300; // 5 minutes
let start_time = std::time::Instant::now();
let client = reqwest::Client::new();
loop {
if start_time.elapsed() > Duration::from_secs(timeout_secs) {
return Err(format!(
"Timeout {}s waiting for worker {} to become healthy",
timeout_secs, worker_url
));
}
match client.get(&format!("{}/health", worker_url)).send().await {
Ok(res) => {
if res.status().is_success() {
match self {
Router::RoundRobin { worker_urls, .. }
| Router::Random { worker_urls }
| Router::CacheAware { worker_urls, .. } => {
info!("Worker {} health check passed", worker_url);
let mut urls = worker_urls.write().unwrap();
if urls.contains(&worker_url.to_string()) {
return Err(format!("Worker {} already exists", worker_url));
}
info!("Added worker: {}", worker_url);
urls.push(worker_url.to_string());
}
}
// If cache aware, initialize the queues for the new worker
if let Router::CacheAware {
running_queue,
processed_queue,
tree,
..
} = self
{
// Add worker to running queue with initial count of 0
running_queue
.lock()
.unwrap()
.insert(worker_url.to_string(), 0);
// Add worker to processed queue with initial count of 0
processed_queue
.lock()
.unwrap()
.insert(worker_url.to_string(), 0);
// Add worker to tree
tree.lock().unwrap().insert(&"".to_string(), &worker_url);
}
return Ok(format!("Successfully added worker: {}", worker_url));
} else {
info!(
"Worker {} health check is pending with status: {}.",
worker_url,
res.status()
);
// if the url does not have http or https prefix, warn users
if !worker_url.starts_with("http://") && !worker_url.starts_with("https://")
{
warn!("The worker url {} does not have http or https prefix. Please add the prefix to the url.", worker_url);
}
tokio::time::sleep(Duration::from_secs(interval_secs)).await;
continue;
}
}
Err(e) => {
info!(
"Worker {} health check is pending with error: {}",
worker_url, e
);
// if the url does not have http or https prefix, warn users
if !worker_url.starts_with("http://") && !worker_url.starts_with("https://") {
warn!("The worker url {} does not have http or https prefix. Please add the prefix to the url.", worker_url);
}
tokio::time::sleep(Duration::from_secs(interval_secs)).await;
continue;
}
}
}
}
pub fn remove_worker(&self, worker_url: &str) {
match self {
Router::RoundRobin { worker_urls, .. }
| Router::Random { worker_urls }
| Router::CacheAware { worker_urls, .. } => {
let mut urls = worker_urls.write().unwrap();
if let Some(index) = urls.iter().position(|url| url == &worker_url) {
urls.remove(index);
info!("Removed worker: {}", worker_url);
} else {
warn!("Worker {} not found, skipping removal", worker_url);
return;
}
}
}
// if cache aware, remove the worker from the tree
if let Router::CacheAware {
tree,
running_queue,
processed_queue,
..
} = self
{
tree.lock().unwrap().remove_tenant(&worker_url);
running_queue
.lock()
.unwrap()
.remove(&worker_url.to_string());
processed_queue
.lock()
.unwrap()
.remove(&worker_url.to_string());
info!(
"Removed worker from tree and cleaned up queues: {}",
worker_url
);
}
}
}

192
sgl-router/src/server.rs Normal file
View File

@@ -0,0 +1,192 @@
use crate::router::PolicyConfig;
use crate::router::Router;
use actix_web::{get, post, web, App, HttpRequest, HttpResponse, HttpServer, Responder};
use bytes::Bytes;
use env_logger::Builder;
use log::{info, LevelFilter};
use std::collections::HashMap;
use std::io::Write;
#[derive(Debug)]
pub struct AppState {
router: Router,
client: reqwest::Client,
}
impl AppState {
pub fn new(
worker_urls: Vec<String>,
client: reqwest::Client,
policy_config: PolicyConfig,
) -> Self {
// Create router based on policy
let router = match Router::new(worker_urls, policy_config) {
Ok(router) => router,
Err(error) => panic!("Failed to create router: {}", error),
};
Self { router, client }
}
}
#[get("/health")]
async fn health(data: web::Data<AppState>) -> impl Responder {
data.router.route_to_first(&data.client, "/health").await
}
#[get("/health_generate")]
async fn health_generate(data: web::Data<AppState>) -> impl Responder {
data.router
.route_to_first(&data.client, "/health_generate")
.await
}
#[get("/get_server_info")]
async fn get_server_info(data: web::Data<AppState>) -> impl Responder {
data.router
.route_to_first(&data.client, "/get_server_info")
.await
}
#[get("/v1/models")]
async fn v1_models(data: web::Data<AppState>) -> impl Responder {
data.router.route_to_first(&data.client, "/v1/models").await
}
#[get("/get_model_info")]
async fn get_model_info(data: web::Data<AppState>) -> impl Responder {
data.router
.route_to_first(&data.client, "/get_model_info")
.await
}
#[post("/generate")]
async fn generate(req: HttpRequest, body: Bytes, data: web::Data<AppState>) -> impl Responder {
data.router
.route_generate_request(&data.client, &req, &body, "/generate")
.await
}
#[post("/v1/chat/completions")]
async fn v1_chat_completions(
req: HttpRequest,
body: Bytes,
data: web::Data<AppState>,
) -> impl Responder {
data.router
.route_generate_request(&data.client, &req, &body, "/v1/chat/completions")
.await
}
#[post("/v1/completions")]
async fn v1_completions(
req: HttpRequest,
body: Bytes,
data: web::Data<AppState>,
) -> impl Responder {
data.router
.route_generate_request(&data.client, &req, &body, "/v1/completions")
.await
}
#[post("/add_worker")]
async fn add_worker(
query: web::Query<HashMap<String, String>>,
data: web::Data<AppState>,
) -> impl Responder {
let worker_url = match query.get("url") {
Some(url) => url.to_string(),
None => {
return HttpResponse::BadRequest()
.body("Worker URL required. Provide 'url' query parameter")
}
};
match data.router.add_worker(&worker_url).await {
Ok(message) => HttpResponse::Ok().body(message),
Err(error) => HttpResponse::BadRequest().body(error),
}
}
#[post("/remove_worker")]
async fn remove_worker(
query: web::Query<HashMap<String, String>>,
data: web::Data<AppState>,
) -> impl Responder {
let worker_url = match query.get("url") {
Some(url) => url.to_string(),
None => return HttpResponse::BadRequest().finish(),
};
data.router.remove_worker(&worker_url);
HttpResponse::Ok().body(format!("Successfully removed worker: {}", worker_url))
}
pub struct ServerConfig {
pub host: String,
pub port: u16,
pub worker_urls: Vec<String>,
pub policy_config: PolicyConfig,
pub verbose: bool,
pub max_payload_size: usize,
}
pub async fn startup(config: ServerConfig) -> std::io::Result<()> {
Builder::new()
.format(|buf, record| {
use chrono::Local;
writeln!(
buf,
"[Router (Rust)] {} - {} - {}",
Local::now().format("%Y-%m-%d %H:%M:%S"),
record.level(),
record.args()
)
})
.filter(
None,
if config.verbose {
LevelFilter::Debug
} else {
LevelFilter::Info
},
)
.init();
let client = reqwest::Client::builder()
.build()
.expect("Failed to create HTTP client");
let app_state = web::Data::new(AppState::new(
config.worker_urls.clone(),
client,
config.policy_config.clone(),
));
info!("✅ Starting router on {}:{}", config.host, config.port);
info!("✅ Serving Worker URLs: {:?}", config.worker_urls);
info!("✅ Policy Config: {:?}", config.policy_config);
info!(
"✅ Max payload size: {} MB",
config.max_payload_size / (1024 * 1024)
);
HttpServer::new(move || {
App::new()
.app_data(app_state.clone())
.app_data(web::JsonConfig::default().limit(config.max_payload_size))
.app_data(web::PayloadConfig::default().limit(config.max_payload_size))
.service(generate)
.service(v1_chat_completions)
.service(v1_completions)
.service(v1_models)
.service(get_model_info)
.service(health)
.service(health_generate)
.service(get_server_info)
.service(add_worker)
.service(remove_worker)
})
.bind((config.host, config.port))?
.run()
.await
}

1483
sgl-router/src/tree.rs Normal file

File diff suppressed because it is too large Load Diff

63
sgl-router/v0.1.0.md Normal file
View File

@@ -0,0 +1,63 @@
# SGLang Router v0.1.0: Dynamic Scaling and Fault Tolerance
We have released `sglang-router` v0.1.0 equipped with dynamic scaling and fault tolerance! It is essential for the router to be able to dynamically scale the number of workers and handle worker failures. To achieve this, we have implemented the following features:
## 1. Dynamic scaling: The router can dynamically scale the number of workers based on the request load.
We offer `/add_worker` and `/remove_worker` APIs to dynamically add or remove workers from the router.
- `/add_worker`
Usage:
```bash
$ curl -X POST http://localhost:30000/add_worker?url=http://worker_url_1
```
Example:
```bash
$ python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --port 30001
$ curl -X POST http://localhost:30000/add_worker?url=http://127.0.0.1:30001
Successfully added worker: http://127.0.0.1:30001
```
- `/remove_worker`
Usage:
```bash
$ curl -X POST http://localhost:30000/remove_worker?url=http://worker_url_1
```
Example:
```bash
$ curl -X POST http://localhost:30000/remove_worker?url=http://127.0.0.1:30001
Successfully removed worker: http://127.0.0.1:30001
```
Note:
- For cache-aware router, the worker will be removed from the tree and the queues.
## 2. Fault tolerance: The router can handle worker failures and automatically remove the failed worker from the router.
We provide retries based for failure tolerance.
1. If the request to a worker fails for `max_worker_retries` times, the router will remove the worker from the router and move on to the next worker.
2. If the total number of retries exceeds `max_total_retries`, the router will return an error.
Note:
- `max_worker_retries` is 3 and `max_total_retries` is 6 by default.
## Closing remarks:
1. Please read the full usage at https://sgl-project.github.io/router/router.html
2. The feature is still under active improvement, so please don't hesitate to raise issues or submit PRs if you have any suggestions or feedback.
# Release Instructions
Update the version in `rust/pyproject.toml` and `py_src/sglang_router/version.py`.