Rename rust folder to sgl-router (#2464)
Signed-off-by: Ata Fatahi <immrata@gmail.com>
This commit is contained in:
3106
sgl-router/Cargo.lock
generated
Normal file
3106
sgl-router/Cargo.lock
generated
Normal file
File diff suppressed because it is too large
Load Diff
32
sgl-router/Cargo.toml
Normal file
32
sgl-router/Cargo.toml
Normal file
@@ -0,0 +1,32 @@
|
||||
[package]
|
||||
name = "sglang_router_rs"
|
||||
version = "0.0.0"
|
||||
edition = "2021"
|
||||
|
||||
[lib]
|
||||
name = "sglang_router_rs"
|
||||
# Pure Rust library: Just omit crate-type (defaults to rlib)
|
||||
# Python/C binding + Rust library: Use ["cdylib", "rlib"]
|
||||
crate-type = ["cdylib", "rlib"]
|
||||
|
||||
[dependencies]
|
||||
actix-web = "4.0"
|
||||
serde = { version = "1.0", features = ["derive"] }
|
||||
clap = { version = "4.4", features = ["derive"] }
|
||||
bytes = "1.8.0"
|
||||
rand = "0.8.5"
|
||||
reqwest = { version = "0.12.8", features = ["stream", "blocking"] }
|
||||
futures-util = "0.3"
|
||||
serde_json = "1.0"
|
||||
pyo3 = { version = "0.22.5", features = ["extension-module"] }
|
||||
tokenizers = { version = "0.20.3", features = ["http"] }
|
||||
dashmap = "6.1.0"
|
||||
http = "1.1.0"
|
||||
env_logger = "0.11.5"
|
||||
log = "0.4.22"
|
||||
chrono = "0.4.38"
|
||||
tokio = "1.42.0"
|
||||
|
||||
[profile.release]
|
||||
lto = "thin"
|
||||
codegen-units = 1
|
||||
3
sgl-router/MANIFEST.in
Normal file
3
sgl-router/MANIFEST.in
Normal file
@@ -0,0 +1,3 @@
|
||||
# Must include:
|
||||
include Cargo.toml # Rust project configuration
|
||||
recursive-include src *.rs # Rust source files
|
||||
87
sgl-router/README.md
Normal file
87
sgl-router/README.md
Normal file
@@ -0,0 +1,87 @@
|
||||
# SGLang Router
|
||||
|
||||
SGLang router is a standalone module implemented in Rust to achieve data parallelism across SGLang instances.
|
||||
|
||||
## User docs
|
||||
|
||||
Please check https://sgl-project.github.io/router/router.html
|
||||
|
||||
## Developer docs
|
||||
|
||||
### Prerequisites
|
||||
|
||||
- Rust and Cargo installed
|
||||
|
||||
```bash
|
||||
# Install rustup (Rust installer and version manager)
|
||||
curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh
|
||||
|
||||
# Follow the installation prompts, then reload your shell
|
||||
source $HOME/.cargo/env
|
||||
|
||||
# Verify installation
|
||||
rustc --version
|
||||
cargo --version
|
||||
```
|
||||
|
||||
- Python with pip installed
|
||||
|
||||
|
||||
### Build Process
|
||||
|
||||
#### 1. Build Rust Project
|
||||
|
||||
```bash
|
||||
$ cargo build
|
||||
```
|
||||
|
||||
#### 2. Build Python Binding
|
||||
|
||||
##### Option A: Build and Install Wheel
|
||||
1. Build the wheel package:
|
||||
```bash
|
||||
$ pip install setuptools-rust wheel build
|
||||
$ python -m build
|
||||
```
|
||||
|
||||
2. Install the generated wheel:
|
||||
```bash
|
||||
$ pip install <path-to-wheel>
|
||||
```
|
||||
|
||||
If you want one handy command to do build + install for every change you make:
|
||||
|
||||
```bash
|
||||
$ python -m build && pip install --force-reinstall dist/*.whl
|
||||
```
|
||||
|
||||
##### Option B: Development Mode
|
||||
|
||||
For development purposes, you can install the package in editable mode:
|
||||
|
||||
Warning: Using editable python binding can suffer from performance degradation!! Please build a fresh wheel for every update if you want to test performance.
|
||||
|
||||
```bash
|
||||
$ pip install -e .
|
||||
```
|
||||
|
||||
**Note:** When modifying Rust code, you must rebuild the wheel for changes to take effect.
|
||||
|
||||
### CI/CD Setup
|
||||
|
||||
The continuous integration pipeline consists of three main steps:
|
||||
|
||||
#### 1. Build Wheels
|
||||
- Uses `cibuildwheel` to create manylinux x86_64 packages
|
||||
- Compatible with major Linux distributions (Ubuntu, CentOS, etc.)
|
||||
- Additional configurations can be added to support other OS/architectures
|
||||
- Reference: [cibuildwheel documentation](https://cibuildwheel.pypa.io/en/stable/)
|
||||
|
||||
#### 2. Build Source Distribution
|
||||
- Creates a source distribution containing the raw, unbuilt code
|
||||
- Enables `pip` to build the package from source when prebuilt wheels are unavailable
|
||||
|
||||
#### 3. Publish to PyPI
|
||||
- Uploads both wheels and source distribution to PyPI
|
||||
|
||||
The CI configuration is based on the [tiktoken workflow](https://github.com/openai/tiktoken/blob/63527649963def8c759b0f91f2eb69a40934e468/.github/workflows/build_wheels.yml#L1).
|
||||
11
sgl-router/py_src/sglang_router/__init__.py
Normal file
11
sgl-router/py_src/sglang_router/__init__.py
Normal file
@@ -0,0 +1,11 @@
|
||||
# a lightweihgt wrapper on router with argument type and comments
|
||||
from sglang_router_rs import PolicyType
|
||||
|
||||
# no wrapper on policy type => direct export
|
||||
from .router import Router
|
||||
|
||||
__all__ = ["Router", "PolicyType"]
|
||||
|
||||
from sglang_router.version import __version__
|
||||
|
||||
__all__ += ["__version__"]
|
||||
249
sgl-router/py_src/sglang_router/launch_router.py
Normal file
249
sgl-router/py_src/sglang_router/launch_router.py
Normal file
@@ -0,0 +1,249 @@
|
||||
import argparse
|
||||
import dataclasses
|
||||
import logging
|
||||
import sys
|
||||
from typing import List, Optional
|
||||
|
||||
from sglang_router import Router
|
||||
from sglang_router_rs import PolicyType
|
||||
|
||||
|
||||
def setup_logger():
|
||||
logger = logging.getLogger("router")
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
formatter = logging.Formatter(
|
||||
"[Router (Python)] %(asctime)s - %(levelname)s - %(message)s",
|
||||
datefmt="%Y-%m-%d %H:%M:%S",
|
||||
)
|
||||
|
||||
handler = logging.StreamHandler()
|
||||
handler.setFormatter(formatter)
|
||||
logger.addHandler(handler)
|
||||
|
||||
return logger
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class RouterArgs:
|
||||
# Worker configuration
|
||||
worker_urls: List[str]
|
||||
host: str = "127.0.0.1"
|
||||
port: int = 30000
|
||||
|
||||
# Routing policy
|
||||
policy: str = "cache_aware"
|
||||
cache_threshold: float = 0.5
|
||||
balance_abs_threshold: int = 32
|
||||
balance_rel_threshold: float = 1.0001
|
||||
eviction_interval: int = 60
|
||||
max_tree_size: int = 2**24
|
||||
max_payload_size: int = 4 * 1024 * 1024 # 4MB
|
||||
verbose: bool = False
|
||||
|
||||
@staticmethod
|
||||
def add_cli_args(
|
||||
parser: argparse.ArgumentParser,
|
||||
use_router_prefix: bool = False,
|
||||
exclude_host_port: bool = False,
|
||||
):
|
||||
"""
|
||||
Add router-specific arguments to an argument parser.
|
||||
|
||||
Args:
|
||||
parser: The argument parser to add arguments to
|
||||
use_router_prefix: If True, prefix all arguments with 'router-' to avoid conflicts
|
||||
exclude_host_port: If True, don't add host and port arguments (used when inheriting from server)
|
||||
"""
|
||||
prefix = "router-" if use_router_prefix else ""
|
||||
|
||||
# Worker configuration
|
||||
if not exclude_host_port:
|
||||
parser.add_argument(
|
||||
"--host",
|
||||
type=str,
|
||||
default=RouterArgs.host,
|
||||
help="Host address to bind the router server",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--port",
|
||||
type=int,
|
||||
default=RouterArgs.port,
|
||||
help="Port number to bind the router server",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--worker-urls",
|
||||
type=str,
|
||||
nargs="+",
|
||||
help="List of worker URLs (e.g., http://worker1:8000 http://worker2:8000)",
|
||||
)
|
||||
|
||||
# Routing policy configuration
|
||||
parser.add_argument(
|
||||
f"--{prefix}policy",
|
||||
type=str,
|
||||
default=RouterArgs.policy,
|
||||
choices=["random", "round_robin", "cache_aware"],
|
||||
help="Load balancing policy to use",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}cache-threshold",
|
||||
type=float,
|
||||
default=RouterArgs.cache_threshold,
|
||||
help="Cache threshold (0.0-1.0) for cache-aware routing",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}balance-abs-threshold",
|
||||
type=int,
|
||||
default=RouterArgs.balance_abs_threshold,
|
||||
help="Load balancing is triggered when (max_load - min_load) > abs_threshold AND max_load > min_load * rel_threshold. Otherwise, use cache aware",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}balance-rel-threshold",
|
||||
type=float,
|
||||
default=RouterArgs.balance_rel_threshold,
|
||||
help="Load balancing is triggered when (max_load - min_load) > abs_threshold AND max_load > min_load * rel_threshold. Otherwise, use cache aware",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}eviction-interval",
|
||||
type=int,
|
||||
default=RouterArgs.eviction_interval,
|
||||
help="Interval in seconds between cache eviction operations",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}max-tree-size",
|
||||
type=int,
|
||||
default=RouterArgs.max_tree_size,
|
||||
help="Maximum size of the approximation tree for cache-aware routing",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}max-payload-size",
|
||||
type=int,
|
||||
default=RouterArgs.max_payload_size,
|
||||
help="Maximum payload size in bytes",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}verbose",
|
||||
action="store_true",
|
||||
help="Enable verbose logging",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_cli_args(
|
||||
cls, args: argparse.Namespace, use_router_prefix: bool = False
|
||||
) -> "RouterArgs":
|
||||
"""
|
||||
Create RouterArgs instance from parsed command line arguments.
|
||||
|
||||
Args:
|
||||
args: Parsed command line arguments
|
||||
use_router_prefix: If True, look for arguments with 'router-' prefix
|
||||
"""
|
||||
prefix = "router_" if use_router_prefix else ""
|
||||
return cls(
|
||||
worker_urls=args.worker_urls,
|
||||
host=args.host,
|
||||
port=args.port,
|
||||
policy=getattr(args, f"{prefix}policy"),
|
||||
cache_threshold=getattr(args, f"{prefix}cache_threshold"),
|
||||
balance_abs_threshold=getattr(args, f"{prefix}balance_abs_threshold"),
|
||||
balance_rel_threshold=getattr(args, f"{prefix}balance_rel_threshold"),
|
||||
eviction_interval=getattr(args, f"{prefix}eviction_interval"),
|
||||
max_tree_size=getattr(args, f"{prefix}max_tree_size"),
|
||||
max_payload_size=getattr(args, f"{prefix}max_payload_size"),
|
||||
verbose=getattr(args, f"{prefix}verbose", False),
|
||||
)
|
||||
|
||||
|
||||
def policy_from_str(policy_str: str) -> PolicyType:
|
||||
"""Convert policy string to PolicyType enum."""
|
||||
policy_map = {
|
||||
"random": PolicyType.Random,
|
||||
"round_robin": PolicyType.RoundRobin,
|
||||
"cache_aware": PolicyType.CacheAware,
|
||||
}
|
||||
return policy_map[policy_str]
|
||||
|
||||
|
||||
def launch_router(args: argparse.Namespace) -> Optional[Router]:
|
||||
"""
|
||||
Launch the SGLang router with the configuration from parsed arguments.
|
||||
|
||||
Args:
|
||||
args: Namespace object containing router configuration
|
||||
Can be either raw argparse.Namespace or converted RouterArgs
|
||||
|
||||
Returns:
|
||||
Router instance if successful, None if failed
|
||||
"""
|
||||
logger = logging.getLogger("router")
|
||||
try:
|
||||
# Convert to RouterArgs if needed
|
||||
if not isinstance(args, RouterArgs):
|
||||
router_args = RouterArgs.from_cli_args(args)
|
||||
else:
|
||||
router_args = args
|
||||
|
||||
router = Router(
|
||||
worker_urls=router_args.worker_urls,
|
||||
policy=policy_from_str(router_args.policy),
|
||||
host=router_args.host,
|
||||
port=router_args.port,
|
||||
cache_threshold=router_args.cache_threshold,
|
||||
balance_abs_threshold=router_args.balance_abs_threshold,
|
||||
balance_rel_threshold=router_args.balance_rel_threshold,
|
||||
eviction_interval_secs=router_args.eviction_interval,
|
||||
max_tree_size=router_args.max_tree_size,
|
||||
max_payload_size=router_args.max_payload_size,
|
||||
verbose=router_args.verbose,
|
||||
)
|
||||
|
||||
router.start()
|
||||
return router
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error starting router: {e}")
|
||||
return None
|
||||
|
||||
|
||||
class CustomHelpFormatter(
|
||||
argparse.RawDescriptionHelpFormatter, argparse.ArgumentDefaultsHelpFormatter
|
||||
):
|
||||
"""Custom formatter that preserves both description formatting and shows defaults"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
def parse_router_args(args: List[str]) -> RouterArgs:
|
||||
"""Parse command line arguments and return RouterArgs instance."""
|
||||
parser = argparse.ArgumentParser(
|
||||
description="""SGLang Router - High-performance request distribution across worker nodes
|
||||
|
||||
Usage:
|
||||
This launcher enables starting a router with individual worker instances. It is useful for
|
||||
multi-node setups or when you want to start workers and router separately.
|
||||
|
||||
Examples:
|
||||
python -m sglang_router.launch_router --worker-urls http://worker1:8000 http://worker2:8000
|
||||
python -m sglang_router.launch_router --worker-urls http://worker1:8000 http://worker2:8000 --cache-threshold 0.7 --balance-abs-threshold 64 --balance-rel-threshold 1.2
|
||||
|
||||
""",
|
||||
formatter_class=CustomHelpFormatter,
|
||||
)
|
||||
|
||||
RouterArgs.add_cli_args(parser, use_router_prefix=False)
|
||||
return RouterArgs.from_cli_args(parser.parse_args(args), use_router_prefix=False)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
logger = setup_logger()
|
||||
router_args = parse_router_args(sys.argv[1:])
|
||||
router = launch_router(router_args)
|
||||
|
||||
if router is None:
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
180
sgl-router/py_src/sglang_router/launch_server.py
Normal file
180
sgl-router/py_src/sglang_router/launch_server.py
Normal file
@@ -0,0 +1,180 @@
|
||||
import argparse
|
||||
import copy
|
||||
import logging
|
||||
import multiprocessing as mp
|
||||
import os
|
||||
import random
|
||||
import signal
|
||||
import sys
|
||||
import time
|
||||
from typing import List
|
||||
|
||||
import requests
|
||||
from setproctitle import setproctitle
|
||||
from sglang_router.launch_router import RouterArgs, launch_router
|
||||
|
||||
from sglang.srt.server import launch_server
|
||||
from sglang.srt.server_args import ServerArgs
|
||||
from sglang.srt.utils import is_port_available
|
||||
|
||||
|
||||
def setup_logger():
|
||||
logger = logging.getLogger("router")
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
formatter = logging.Formatter(
|
||||
"[Router (Python)] %(asctime)s - %(levelname)s - %(message)s",
|
||||
datefmt="%Y-%m-%d %H:%M:%S",
|
||||
)
|
||||
|
||||
handler = logging.StreamHandler()
|
||||
handler.setFormatter(formatter)
|
||||
logger.addHandler(handler)
|
||||
|
||||
return logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
# Create new process group
|
||||
def run_server(server_args, dp_rank):
|
||||
"""
|
||||
Note:
|
||||
|
||||
1. Without os.setpgrp(), all processes share the same PGID. When you press Ctrl+C, the terminal sends SIGINT to all processes in the group simultaneously.
|
||||
This can cause leaf processes to terminate first, which messes up the cleaning order and produces orphaned processes.
|
||||
|
||||
Terminal (PGID=100)
|
||||
└── Main Python Process (PGID=100)
|
||||
└── Server Process 1 (PGID=100)
|
||||
└── Scheduler 1
|
||||
└── Detokenizer 1
|
||||
└── Server Process 2 (PGID=100)
|
||||
└── Scheduler 2
|
||||
└── Detokenizer 2
|
||||
|
||||
2. With os.setpgrp(), the main Python process and its children are in a separate group. Now:
|
||||
|
||||
Terminal (PGID=100)
|
||||
└── Main Python Process (PGID=200)
|
||||
└── Server Process 1 (PGID=300)
|
||||
└── Scheduler 1
|
||||
└── Detokenizer 1
|
||||
└── Server Process 2 (PGID=400)
|
||||
└── Scheduler 2
|
||||
└── Detokenizer 2
|
||||
"""
|
||||
# create new process group
|
||||
os.setpgrp()
|
||||
|
||||
setproctitle(f"sglang::server")
|
||||
# Set SGLANG_DP_RANK environment variable
|
||||
os.environ["SGLANG_DP_RANK"] = str(dp_rank)
|
||||
|
||||
launch_server(server_args)
|
||||
|
||||
|
||||
def launch_server_process(
|
||||
server_args: ServerArgs, worker_port: int, dp_id: int
|
||||
) -> mp.Process:
|
||||
"""Launch a single server process with the given args and port."""
|
||||
server_args = copy.deepcopy(server_args)
|
||||
server_args.port = worker_port
|
||||
server_args.base_gpu_id = dp_id * server_args.tp_size
|
||||
server_args.dp_size = 1
|
||||
|
||||
proc = mp.Process(target=run_server, args=(server_args, dp_id))
|
||||
proc.start()
|
||||
return proc
|
||||
|
||||
|
||||
def wait_for_server_health(host: str, port: int, timeout: int = 300) -> bool:
|
||||
"""Wait for server to be healthy by checking /health endpoint."""
|
||||
start_time = time.time()
|
||||
url = f"http://{host}:{port}/health"
|
||||
|
||||
while time.time() - start_time < timeout:
|
||||
try:
|
||||
response = requests.get(url, timeout=5)
|
||||
if response.status_code == 200:
|
||||
return True
|
||||
except requests.exceptions.RequestException:
|
||||
pass
|
||||
time.sleep(1)
|
||||
return False
|
||||
|
||||
|
||||
def find_available_ports(base_port: int, count: int) -> List[int]:
|
||||
"""Find consecutive available ports starting from base_port."""
|
||||
available_ports = []
|
||||
current_port = base_port
|
||||
|
||||
while len(available_ports) < count:
|
||||
if is_port_available(current_port):
|
||||
available_ports.append(current_port)
|
||||
current_port += random.randint(100, 1000)
|
||||
|
||||
return available_ports
|
||||
|
||||
|
||||
def cleanup_processes(processes: List[mp.Process]):
|
||||
for process in processes:
|
||||
logger.info(f"Terminating process {process.pid}")
|
||||
process.terminate()
|
||||
logger.info("All processes terminated")
|
||||
|
||||
|
||||
def main():
|
||||
# CUDA runtime isn't fork-safe, which can lead to subtle bugs or crashes
|
||||
mp.set_start_method("spawn")
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Launch SGLang router and server processes"
|
||||
)
|
||||
|
||||
ServerArgs.add_cli_args(parser)
|
||||
RouterArgs.add_cli_args(parser, use_router_prefix=True, exclude_host_port=True)
|
||||
parser.add_argument(
|
||||
"--router-dp-worker-base-port",
|
||||
type=int,
|
||||
default=31000,
|
||||
help="Base port number for data parallel workers",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
server_args = ServerArgs.from_cli_args(args)
|
||||
router_args = RouterArgs.from_cli_args(args, use_router_prefix=True)
|
||||
|
||||
# Find available ports for workers
|
||||
worker_ports = find_available_ports(
|
||||
args.router_dp_worker_base_port, server_args.dp_size
|
||||
)
|
||||
|
||||
# Start server processes
|
||||
server_processes = []
|
||||
|
||||
for i, worker_port in enumerate(worker_ports):
|
||||
logger.info(f"Launching DP server process {i} on port {worker_port}")
|
||||
proc = launch_server_process(server_args, worker_port, i)
|
||||
server_processes.append(proc)
|
||||
|
||||
signal.signal(signal.SIGINT, lambda sig, frame: cleanup_processes(server_processes))
|
||||
signal.signal(
|
||||
signal.SIGTERM, lambda sig, frame: cleanup_processes(server_processes)
|
||||
)
|
||||
signal.signal(
|
||||
signal.SIGQUIT, lambda sig, frame: cleanup_processes(server_processes)
|
||||
)
|
||||
|
||||
# Update router args with worker URLs
|
||||
router_args.worker_urls = [
|
||||
f"http://{server_args.host}:{port}" for port in worker_ports
|
||||
]
|
||||
|
||||
# Start the router
|
||||
router = launch_router(router_args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
67
sgl-router/py_src/sglang_router/router.py
Normal file
67
sgl-router/py_src/sglang_router/router.py
Normal file
@@ -0,0 +1,67 @@
|
||||
from typing import List, Optional
|
||||
|
||||
from sglang_router_rs import PolicyType
|
||||
from sglang_router_rs import Router as _Router
|
||||
|
||||
|
||||
class Router:
|
||||
"""
|
||||
A high-performance router for distributing requests across worker nodes.
|
||||
|
||||
Args:
|
||||
worker_urls: List of URLs for worker nodes that will handle requests. Each URL should include
|
||||
the protocol, host, and port (e.g., ['http://worker1:8000', 'http://worker2:8000'])
|
||||
policy: Load balancing policy to use. Options:
|
||||
- PolicyType.Random: Randomly select workers
|
||||
- PolicyType.RoundRobin: Distribute requests in round-robin fashion
|
||||
- PolicyType.CacheAware: Distribute requests based on cache state and load balance
|
||||
host: Host address to bind the router server. Default: '127.0.0.1'
|
||||
port: Port number to bind the router server. Default: 3001
|
||||
cache_threshold: Cache threshold (0.0-1.0) for cache-aware routing. Routes to cached worker
|
||||
if the match rate exceeds threshold, otherwise routes to the worker with the smallest
|
||||
tree. Default: 0.5
|
||||
balance_abs_threshold: Load balancing is triggered when (max_load - min_load) > abs_threshold
|
||||
AND max_load > min_load * rel_threshold. Otherwise, use cache aware. Default: 32
|
||||
balance_rel_threshold: Load balancing is triggered when (max_load - min_load) > abs_threshold
|
||||
AND max_load > min_load * rel_threshold. Otherwise, use cache aware. Default: 1.0001
|
||||
eviction_interval_secs: Interval in seconds between cache eviction operations in cache-aware
|
||||
routing. Default: 60
|
||||
max_payload_size: Maximum payload size in bytes. Default: 4MB
|
||||
max_tree_size: Maximum size of the approximation tree for cache-aware routing. Default: 2^24
|
||||
verbose: Enable verbose logging. Default: False
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
worker_urls: List[str],
|
||||
policy: PolicyType = PolicyType.RoundRobin,
|
||||
host: str = "127.0.0.1",
|
||||
port: int = 3001,
|
||||
cache_threshold: float = 0.50,
|
||||
balance_abs_threshold: int = 32,
|
||||
balance_rel_threshold: float = 1.0001,
|
||||
eviction_interval_secs: int = 60,
|
||||
max_tree_size: int = 2**24,
|
||||
max_payload_size: int = 4 * 1024 * 1024, # 4MB
|
||||
verbose: bool = False,
|
||||
):
|
||||
self._router = _Router(
|
||||
worker_urls=worker_urls,
|
||||
policy=policy,
|
||||
host=host,
|
||||
port=port,
|
||||
cache_threshold=cache_threshold,
|
||||
balance_abs_threshold=balance_abs_threshold,
|
||||
balance_rel_threshold=balance_rel_threshold,
|
||||
eviction_interval_secs=eviction_interval_secs,
|
||||
max_tree_size=max_tree_size,
|
||||
max_payload_size=max_payload_size,
|
||||
verbose=verbose,
|
||||
)
|
||||
|
||||
def start(self) -> None:
|
||||
"""Start the router server.
|
||||
|
||||
This method blocks until the server is shut down.
|
||||
"""
|
||||
self._router.start()
|
||||
1
sgl-router/py_src/sglang_router/version.py
Normal file
1
sgl-router/py_src/sglang_router/version.py
Normal file
@@ -0,0 +1 @@
|
||||
__version__ = "0.1.1"
|
||||
19
sgl-router/py_test/run_suite.py
Normal file
19
sgl-router/py_test/run_suite.py
Normal file
@@ -0,0 +1,19 @@
|
||||
import argparse
|
||||
import glob
|
||||
|
||||
from sglang.test.test_utils import run_unittest_files
|
||||
|
||||
if __name__ == "__main__":
|
||||
arg_parser = argparse.ArgumentParser()
|
||||
arg_parser.add_argument(
|
||||
"--timeout-per-file",
|
||||
type=int,
|
||||
default=1000,
|
||||
help="The time limit for running one file in seconds.",
|
||||
)
|
||||
args = arg_parser.parse_args()
|
||||
|
||||
files = glob.glob("**/test_*.py", recursive=True)
|
||||
|
||||
exit_code = run_unittest_files(files, args.timeout_per_file)
|
||||
exit(exit_code)
|
||||
67
sgl-router/py_test/test_launch_router.py
Normal file
67
sgl-router/py_test/test_launch_router.py
Normal file
@@ -0,0 +1,67 @@
|
||||
import multiprocessing
|
||||
import time
|
||||
import unittest
|
||||
from types import SimpleNamespace
|
||||
|
||||
|
||||
def terminate_process(process: multiprocessing.Process, timeout: float = 1.0) -> None:
|
||||
"""Terminate a process gracefully, with forced kill as fallback.
|
||||
|
||||
Args:
|
||||
process: The process to terminate
|
||||
timeout: Seconds to wait for graceful termination before forcing kill
|
||||
"""
|
||||
if not process.is_alive():
|
||||
return
|
||||
|
||||
process.terminate()
|
||||
process.join(timeout=timeout)
|
||||
if process.is_alive():
|
||||
process.kill() # Force kill if terminate didn't work
|
||||
process.join()
|
||||
|
||||
|
||||
class TestLaunchRouter(unittest.TestCase):
|
||||
def test_launch_router_no_exception(self):
|
||||
|
||||
# Create SimpleNamespace with default arguments
|
||||
args = SimpleNamespace(
|
||||
worker_urls=["http://localhost:8000"],
|
||||
host="127.0.0.1",
|
||||
port=30000,
|
||||
policy="cache_aware",
|
||||
cache_threshold=0.5,
|
||||
balance_abs_threshold=32,
|
||||
balance_rel_threshold=1.0001,
|
||||
eviction_interval=60,
|
||||
max_tree_size=2**24,
|
||||
max_payload_size=4 * 1024 * 1024, # 4MB
|
||||
verbose=False,
|
||||
)
|
||||
|
||||
def run_router():
|
||||
try:
|
||||
from sglang_router.launch_router import launch_router
|
||||
|
||||
router = launch_router(args)
|
||||
if router is None:
|
||||
return 1
|
||||
return 0
|
||||
except Exception as e:
|
||||
print(e)
|
||||
return 1
|
||||
|
||||
# Start router in separate process
|
||||
process = multiprocessing.Process(target=run_router)
|
||||
try:
|
||||
process.start()
|
||||
# Wait 3 seconds
|
||||
time.sleep(3)
|
||||
# Process is still running means router started successfully
|
||||
self.assertTrue(process.is_alive())
|
||||
finally:
|
||||
terminate_process(process)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
338
sgl-router/py_test/test_launch_server.py
Normal file
338
sgl-router/py_test/test_launch_server.py
Normal file
@@ -0,0 +1,338 @@
|
||||
import socket
|
||||
import subprocess
|
||||
import time
|
||||
import unittest
|
||||
from types import SimpleNamespace
|
||||
|
||||
import requests
|
||||
|
||||
from sglang.srt.utils import kill_process_tree
|
||||
from sglang.test.run_eval import run_eval
|
||||
from sglang.test.test_utils import (
|
||||
DEFAULT_MODEL_NAME_FOR_TEST,
|
||||
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
DEFAULT_URL_FOR_TEST,
|
||||
)
|
||||
|
||||
|
||||
def popen_launch_router(
|
||||
model: str,
|
||||
base_url: str,
|
||||
dp_size: int,
|
||||
timeout: float,
|
||||
policy: str = "cache_aware",
|
||||
max_payload_size: int = None,
|
||||
):
|
||||
"""
|
||||
Launch the router server process.
|
||||
|
||||
Args:
|
||||
model: Model path/name
|
||||
base_url: Server base URL
|
||||
dp_size: Data parallel size
|
||||
timeout: Server launch timeout
|
||||
policy: Router policy, one of "cache_aware", "round_robin", "random"
|
||||
max_payload_size: Maximum payload size in bytes
|
||||
"""
|
||||
_, host, port = base_url.split(":")
|
||||
host = host[2:]
|
||||
|
||||
command = [
|
||||
"python3",
|
||||
"-m",
|
||||
"sglang_router.launch_server",
|
||||
"--model-path",
|
||||
model,
|
||||
"--host",
|
||||
host,
|
||||
"--port",
|
||||
port,
|
||||
"--dp",
|
||||
str(dp_size),
|
||||
"--router-eviction-interval",
|
||||
"5",
|
||||
"--router-policy",
|
||||
policy,
|
||||
]
|
||||
|
||||
if max_payload_size is not None:
|
||||
command.extend(["--router-max-payload-size", str(max_payload_size)])
|
||||
|
||||
process = subprocess.Popen(command, stdout=None, stderr=None)
|
||||
|
||||
start_time = time.time()
|
||||
with requests.Session() as session:
|
||||
while time.time() - start_time < timeout:
|
||||
try:
|
||||
response = session.get(f"{base_url}/health")
|
||||
if response.status_code == 200:
|
||||
print(f"Router {base_url} is healthy")
|
||||
return process
|
||||
except requests.RequestException:
|
||||
pass
|
||||
time.sleep(10)
|
||||
|
||||
raise TimeoutError("Router failed to start within the timeout period.")
|
||||
|
||||
|
||||
def find_available_port():
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||
s.bind(("127.0.0.1", 0))
|
||||
return s.getsockname()[1]
|
||||
|
||||
|
||||
def popen_launch_server(
|
||||
model: str,
|
||||
base_url: str,
|
||||
timeout: float,
|
||||
):
|
||||
_, host, port = base_url.split(":")
|
||||
host = host[2:]
|
||||
|
||||
command = [
|
||||
"python3",
|
||||
"-m",
|
||||
"sglang.launch_server",
|
||||
"--model-path",
|
||||
model,
|
||||
"--host",
|
||||
host,
|
||||
"--port",
|
||||
port,
|
||||
"--base-gpu-id",
|
||||
"1",
|
||||
]
|
||||
|
||||
process = subprocess.Popen(command, stdout=None, stderr=None)
|
||||
|
||||
# intentionally don't wait and defer the job to the router health check
|
||||
return process
|
||||
|
||||
|
||||
def terminate_and_wait(process, timeout=300):
|
||||
"""Terminate a process and wait until it is terminated.
|
||||
|
||||
Args:
|
||||
process: subprocess.Popen object
|
||||
timeout: maximum time to wait in seconds
|
||||
|
||||
Raises:
|
||||
TimeoutError: if process does not terminate within timeout
|
||||
"""
|
||||
if process is None:
|
||||
return
|
||||
|
||||
process.terminate()
|
||||
start_time = time.time()
|
||||
|
||||
while process.poll() is None:
|
||||
print(f"Terminating process {process.pid}")
|
||||
if time.time() - start_time > timeout:
|
||||
raise TimeoutError(
|
||||
f"Process {process.pid} failed to terminate within {timeout}s"
|
||||
)
|
||||
time.sleep(1)
|
||||
|
||||
print(f"Process {process.pid} is successfully terminated")
|
||||
|
||||
|
||||
class TestLaunchServer(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.model = DEFAULT_MODEL_NAME_FOR_TEST
|
||||
self.base_url = DEFAULT_URL_FOR_TEST
|
||||
self.process = None
|
||||
self.other_process = []
|
||||
|
||||
def tearDown(self):
|
||||
print("Running tearDown...")
|
||||
if self.process:
|
||||
terminate_and_wait(self.process)
|
||||
for process in self.other_process:
|
||||
terminate_and_wait(process)
|
||||
print("tearDown done")
|
||||
|
||||
def test_1_mmlu(self):
|
||||
print("Running test_1_mmlu...")
|
||||
# DP size = 2
|
||||
self.process = popen_launch_router(
|
||||
self.model,
|
||||
self.base_url,
|
||||
dp_size=2,
|
||||
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
policy="cache_aware",
|
||||
)
|
||||
|
||||
args = SimpleNamespace(
|
||||
base_url=self.base_url,
|
||||
model=self.model,
|
||||
eval_name="mmlu",
|
||||
num_examples=64,
|
||||
num_threads=32,
|
||||
temperature=0.1,
|
||||
)
|
||||
|
||||
metrics = run_eval(args)
|
||||
score = metrics["score"]
|
||||
THRESHOLD = 0.65
|
||||
passed = score >= THRESHOLD
|
||||
msg = f"MMLU test {'passed' if passed else 'failed'} with score {score:.3f} (threshold: {THRESHOLD})"
|
||||
self.assertGreaterEqual(score, THRESHOLD, msg)
|
||||
|
||||
def test_2_add_and_remove_worker(self):
|
||||
print("Running test_2_add_and_remove_worker...")
|
||||
# DP size = 1
|
||||
self.process = popen_launch_router(
|
||||
self.model,
|
||||
self.base_url,
|
||||
dp_size=1,
|
||||
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
policy="round_robin", # use round robin to make sure every worker processes requests
|
||||
)
|
||||
# 1. start a worker
|
||||
port = find_available_port()
|
||||
worker_url = f"http://127.0.0.1:{port}"
|
||||
worker_process = popen_launch_server(
|
||||
self.model, worker_url, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
|
||||
)
|
||||
self.other_process.append(worker_process)
|
||||
|
||||
# 2. use /add_worker api to add it the the router. It will be used by router after it is healthy
|
||||
with requests.Session() as session:
|
||||
response = session.post(f"{self.base_url}/add_worker?url={worker_url}")
|
||||
print(f"status code: {response.status_code}, response: {response.text}")
|
||||
self.assertEqual(response.status_code, 200)
|
||||
|
||||
# 3. run mmlu
|
||||
args = SimpleNamespace(
|
||||
base_url=self.base_url,
|
||||
model=self.model,
|
||||
eval_name="mmlu",
|
||||
num_examples=64,
|
||||
num_threads=32,
|
||||
temperature=0.1,
|
||||
)
|
||||
metrics = run_eval(args)
|
||||
score = metrics["score"]
|
||||
THRESHOLD = 0.65
|
||||
passed = score >= THRESHOLD
|
||||
msg = f"MMLU test {'passed' if passed else 'failed'} with score {score:.3f} (threshold: {THRESHOLD})"
|
||||
self.assertGreaterEqual(score, THRESHOLD, msg)
|
||||
|
||||
# 4. use /remove_worker api to remove it from the router
|
||||
with requests.Session() as session:
|
||||
response = session.post(f"{self.base_url}/remove_worker?url={worker_url}")
|
||||
print(f"status code: {response.status_code}, response: {response.text}")
|
||||
self.assertEqual(response.status_code, 200)
|
||||
|
||||
# 5. run mmlu again
|
||||
metrics = run_eval(args)
|
||||
score = metrics["score"]
|
||||
THRESHOLD = 0.65
|
||||
passed = score >= THRESHOLD
|
||||
msg = f"MMLU test {'passed' if passed else 'failed'} with score {score:.3f} (threshold: {THRESHOLD})"
|
||||
self.assertGreaterEqual(score, THRESHOLD, msg)
|
||||
|
||||
def test_3_lazy_fault_tolerance(self):
|
||||
print("Running test_3_lazy_fault_tolerance...")
|
||||
# DP size = 1
|
||||
self.process = popen_launch_router(
|
||||
self.model,
|
||||
self.base_url,
|
||||
dp_size=1,
|
||||
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
policy="round_robin",
|
||||
)
|
||||
|
||||
# 1. start a worker
|
||||
port = find_available_port()
|
||||
worker_url = f"http://127.0.0.1:{port}"
|
||||
worker_process = popen_launch_server(
|
||||
self.model, worker_url, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
|
||||
)
|
||||
self.other_process.append(worker_process)
|
||||
|
||||
# 2. use /add_worker api to add it the the router. It will be used by router after it is healthy
|
||||
with requests.Session() as session:
|
||||
response = session.post(f"{self.base_url}/add_worker?url={worker_url}")
|
||||
print(f"status code: {response.status_code}, response: {response.text}")
|
||||
self.assertEqual(response.status_code, 200)
|
||||
|
||||
# Start a thread to kill the worker after 10 seconds to mimic abrupt worker failure
|
||||
def kill_worker():
|
||||
time.sleep(10)
|
||||
kill_process_tree(worker_process.pid)
|
||||
print("Worker process killed")
|
||||
|
||||
import threading
|
||||
|
||||
kill_thread = threading.Thread(target=kill_worker)
|
||||
kill_thread.daemon = True
|
||||
kill_thread.start()
|
||||
|
||||
# 3. run mmlu
|
||||
args = SimpleNamespace(
|
||||
base_url=self.base_url,
|
||||
model=self.model,
|
||||
eval_name="mmlu",
|
||||
num_examples=256,
|
||||
num_threads=32,
|
||||
temperature=0.1,
|
||||
)
|
||||
metrics = run_eval(args)
|
||||
score = metrics["score"]
|
||||
THRESHOLD = 0.65
|
||||
passed = score >= THRESHOLD
|
||||
msg = f"MMLU test {'passed' if passed else 'failed'} with score {score:.3f} (threshold: {THRESHOLD})"
|
||||
self.assertGreaterEqual(score, THRESHOLD, msg)
|
||||
|
||||
def test_4_payload_size(self):
|
||||
print("Running test_4_payload_size...")
|
||||
# Start router with 3MB limit
|
||||
self.process = popen_launch_router(
|
||||
self.model,
|
||||
self.base_url,
|
||||
dp_size=1,
|
||||
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
policy="round_robin",
|
||||
max_payload_size=1 * 1024 * 1024, # 1MB limit
|
||||
)
|
||||
|
||||
# Test case 1: Payload just under 1MB should succeed
|
||||
payload_0_5_mb = {
|
||||
"text": "x" * int(0.5 * 1024 * 1024), # 0.5MB of text
|
||||
"temperature": 0.0,
|
||||
}
|
||||
|
||||
with requests.Session() as session:
|
||||
response = session.post(
|
||||
f"{self.base_url}/generate",
|
||||
json=payload_0_5_mb,
|
||||
headers={"Content-Type": "application/json"},
|
||||
)
|
||||
self.assertEqual(
|
||||
response.status_code,
|
||||
200,
|
||||
f"0.5MB payload should succeed but got status {response.status_code}",
|
||||
)
|
||||
|
||||
# Test case 2: Payload over 1MB should fail
|
||||
payload_1_plus_mb = {
|
||||
"text": "x" * int((1.2 * 1024 * 1024)), # 1.2MB of text
|
||||
"temperature": 0.0,
|
||||
}
|
||||
|
||||
with requests.Session() as session:
|
||||
response = session.post(
|
||||
f"{self.base_url}/generate",
|
||||
json=payload_1_plus_mb,
|
||||
headers={"Content-Type": "application/json"},
|
||||
)
|
||||
self.assertEqual(
|
||||
response.status_code,
|
||||
413, # Payload Too Large
|
||||
f"1.2MB payload should fail with 413 but got status {response.status_code}",
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
26
sgl-router/pyproject.toml
Normal file
26
sgl-router/pyproject.toml
Normal file
@@ -0,0 +1,26 @@
|
||||
[build-system]
|
||||
requires = ["setuptools>=45", "wheel", "setuptools-rust>=1.5.2"]
|
||||
build-backend = "setuptools.build_meta"
|
||||
|
||||
[project]
|
||||
name = "sglang-router"
|
||||
version = "0.1.1"
|
||||
description = "SGLang router is a standalone module implemented in Rust to achieve data parallelism across SGLang instances."
|
||||
authors = [{name = "Byron Hsu", email = "byronhsu1230@gmail.com"}]
|
||||
requires-python = ">=3.8"
|
||||
readme = "README.md"
|
||||
license = { file = "LICENSE" }
|
||||
classifiers = [
|
||||
"Programming Language :: Python :: Implementation :: CPython",
|
||||
"Programming Language :: Rust",
|
||||
"Programming Language :: Python :: 3",
|
||||
]
|
||||
|
||||
# https://github.com/PyO3/setuptools-rust?tab=readme-ov-file
|
||||
[tool.setuptools.packages]
|
||||
find = { where = ["py_src"] }
|
||||
|
||||
[[tool.setuptools-rust.ext-modules]]
|
||||
target = "sglang_router_rs"
|
||||
path = "Cargo.toml"
|
||||
binding = "PyO3"
|
||||
108
sgl-router/src/lib.rs
Normal file
108
sgl-router/src/lib.rs
Normal file
@@ -0,0 +1,108 @@
|
||||
use pyo3::prelude::*;
|
||||
pub mod router;
|
||||
pub mod server;
|
||||
pub mod tree;
|
||||
|
||||
#[pyclass(eq)]
|
||||
#[derive(Clone, PartialEq)]
|
||||
pub enum PolicyType {
|
||||
Random,
|
||||
RoundRobin,
|
||||
CacheAware,
|
||||
}
|
||||
|
||||
#[pyclass]
|
||||
struct Router {
|
||||
host: String,
|
||||
port: u16,
|
||||
worker_urls: Vec<String>,
|
||||
policy: PolicyType,
|
||||
cache_threshold: f32,
|
||||
balance_abs_threshold: usize,
|
||||
balance_rel_threshold: f32,
|
||||
eviction_interval_secs: u64,
|
||||
max_tree_size: usize,
|
||||
max_payload_size: usize,
|
||||
verbose: bool,
|
||||
}
|
||||
|
||||
#[pymethods]
|
||||
impl Router {
|
||||
#[new]
|
||||
#[pyo3(signature = (
|
||||
worker_urls,
|
||||
policy = PolicyType::RoundRobin,
|
||||
host = String::from("127.0.0.1"),
|
||||
port = 3001,
|
||||
cache_threshold = 0.50,
|
||||
balance_abs_threshold = 32,
|
||||
balance_rel_threshold = 1.0001,
|
||||
eviction_interval_secs = 60,
|
||||
max_tree_size = 2usize.pow(24),
|
||||
max_payload_size = 4 * 1024 * 1024,
|
||||
verbose = false
|
||||
))]
|
||||
fn new(
|
||||
worker_urls: Vec<String>,
|
||||
policy: PolicyType,
|
||||
host: String,
|
||||
port: u16,
|
||||
cache_threshold: f32,
|
||||
balance_abs_threshold: usize,
|
||||
balance_rel_threshold: f32,
|
||||
eviction_interval_secs: u64,
|
||||
max_tree_size: usize,
|
||||
max_payload_size: usize,
|
||||
verbose: bool,
|
||||
) -> PyResult<Self> {
|
||||
Ok(Router {
|
||||
host,
|
||||
port,
|
||||
worker_urls,
|
||||
policy,
|
||||
cache_threshold,
|
||||
balance_abs_threshold,
|
||||
balance_rel_threshold,
|
||||
eviction_interval_secs,
|
||||
max_tree_size,
|
||||
max_payload_size,
|
||||
verbose,
|
||||
})
|
||||
}
|
||||
|
||||
fn start(&self) -> PyResult<()> {
|
||||
let policy_config = match &self.policy {
|
||||
PolicyType::Random => router::PolicyConfig::RandomConfig,
|
||||
PolicyType::RoundRobin => router::PolicyConfig::RoundRobinConfig,
|
||||
PolicyType::CacheAware => router::PolicyConfig::CacheAwareConfig {
|
||||
cache_threshold: self.cache_threshold,
|
||||
balance_abs_threshold: self.balance_abs_threshold,
|
||||
balance_rel_threshold: self.balance_rel_threshold,
|
||||
eviction_interval_secs: self.eviction_interval_secs,
|
||||
max_tree_size: self.max_tree_size,
|
||||
},
|
||||
};
|
||||
|
||||
actix_web::rt::System::new().block_on(async move {
|
||||
server::startup(server::ServerConfig {
|
||||
host: self.host.clone(),
|
||||
port: self.port,
|
||||
worker_urls: self.worker_urls.clone(),
|
||||
policy_config,
|
||||
verbose: self.verbose,
|
||||
max_payload_size: self.max_payload_size,
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
});
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[pymodule]
|
||||
fn sglang_router_rs(m: &Bound<'_, PyModule>) -> PyResult<()> {
|
||||
m.add_class::<PolicyType>()?;
|
||||
m.add_class::<Router>()?;
|
||||
Ok(())
|
||||
}
|
||||
702
sgl-router/src/router.rs
Normal file
702
sgl-router/src/router.rs
Normal file
@@ -0,0 +1,702 @@
|
||||
use crate::tree::Tree;
|
||||
use actix_web::http::header::{HeaderValue, CONTENT_TYPE};
|
||||
use actix_web::{HttpRequest, HttpResponse};
|
||||
use bytes::Bytes;
|
||||
use futures_util::{StreamExt, TryStreamExt};
|
||||
use log::{debug, info, warn};
|
||||
use std::collections::HashMap;
|
||||
use std::fmt::Debug;
|
||||
use std::sync::atomic::AtomicUsize;
|
||||
use std::sync::{Arc, Mutex, RwLock};
|
||||
use std::thread;
|
||||
use std::time::Duration;
|
||||
use tokio;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum Router {
|
||||
RoundRobin {
|
||||
worker_urls: Arc<RwLock<Vec<String>>>,
|
||||
current_index: AtomicUsize,
|
||||
},
|
||||
Random {
|
||||
worker_urls: Arc<RwLock<Vec<String>>>,
|
||||
},
|
||||
CacheAware {
|
||||
/*
|
||||
Cache-Aware Load Balancing Router
|
||||
|
||||
This router combines two strategies to optimize both cache utilization and request distribution:
|
||||
|
||||
1. Cache-Aware Routing (Approximate Tree)
|
||||
2. Load Balancing (Shortest Queue with Balance Thresholds)
|
||||
|
||||
The router dynamically switches between these strategies based on load conditions:
|
||||
- Uses load balancing when the system is imbalanced
|
||||
- Uses cache-aware routing when the system is balanced
|
||||
|
||||
A system is considered imbalanced if both conditions are met:
|
||||
1. (max - min) > abs_threshold
|
||||
2. max > rel_threshold * min
|
||||
|
||||
Strategy Details:
|
||||
|
||||
1. Cache-Aware Routing (Approximate Tree)
|
||||
-------------------------------------------
|
||||
This strategy maintains an approximate radix tree for each worker based on request history,
|
||||
eliminating the need for direct cache state queries. The tree stores raw text characters
|
||||
instead of token IDs to avoid tokenization overhead.
|
||||
|
||||
Process:
|
||||
a. For each request, find the worker with the highest prefix match
|
||||
b. If match rate > cache_threshold:
|
||||
Route to the worker with highest match (likely has relevant data cached)
|
||||
c. If match rate ≤ cache_threshold:
|
||||
Route to the worker with smallest tree size (most available cache capacity)
|
||||
d. Background maintenance:
|
||||
Periodically evict least recently used leaf nodes to prevent memory overflow
|
||||
|
||||
2. Load Balancing (Shortest Queue)
|
||||
-------------------------------------------
|
||||
This strategy tracks pending request counts per worker and routes new requests
|
||||
to the least busy worker when the system is detected to be imbalanced.
|
||||
|
||||
Configuration Parameters:
|
||||
------------------------
|
||||
1. cache_threshold: (float, 0.0 to 1.0)
|
||||
Minimum prefix match ratio to use highest-match routing.
|
||||
Below this threshold, routes to worker with most available cache space.
|
||||
|
||||
2. balance_abs_threshold: (integer)
|
||||
Absolute difference threshold for load imbalance detection.
|
||||
System is potentially imbalanced if (max_load - min_load) > abs_threshold
|
||||
|
||||
3. balance_rel_threshold: (float)
|
||||
Relative ratio threshold for load imbalance detection.
|
||||
System is potentially imbalanced if max_load > min_load * rel_threshold
|
||||
Used in conjunction with abs_threshold to determine final imbalance state.
|
||||
|
||||
4. eviction_interval_secs: (integer)
|
||||
Interval between LRU eviction cycles for the approximate trees.
|
||||
|
||||
5. max_tree_size: (integer)
|
||||
Maximum nodes per tree. When exceeded, LRU leaf nodes are evicted
|
||||
during the next eviction cycle.
|
||||
*/
|
||||
worker_urls: Arc<RwLock<Vec<String>>>,
|
||||
tree: Arc<Mutex<Tree>>,
|
||||
running_queue: Arc<Mutex<HashMap<String, usize>>>,
|
||||
processed_queue: Arc<Mutex<HashMap<String, usize>>>,
|
||||
cache_threshold: f32,
|
||||
balance_abs_threshold: usize,
|
||||
balance_rel_threshold: f32,
|
||||
_eviction_thread: Option<thread::JoinHandle<()>>,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum PolicyConfig {
|
||||
RandomConfig,
|
||||
RoundRobinConfig,
|
||||
CacheAwareConfig {
|
||||
cache_threshold: f32,
|
||||
balance_abs_threshold: usize,
|
||||
balance_rel_threshold: f32,
|
||||
eviction_interval_secs: u64,
|
||||
max_tree_size: usize,
|
||||
},
|
||||
}
|
||||
|
||||
impl Router {
|
||||
pub fn new(worker_urls: Vec<String>, policy_config: PolicyConfig) -> Result<Self, String> {
|
||||
// Wait until all workers are healthy
|
||||
Self::wait_for_healthy_workers(&worker_urls, 300, 10)?;
|
||||
|
||||
// Create router based on policy...
|
||||
Ok(match policy_config {
|
||||
PolicyConfig::RandomConfig => Router::Random {
|
||||
worker_urls: Arc::new(RwLock::new(worker_urls)),
|
||||
},
|
||||
PolicyConfig::RoundRobinConfig => Router::RoundRobin {
|
||||
worker_urls: Arc::new(RwLock::new(worker_urls)),
|
||||
current_index: std::sync::atomic::AtomicUsize::new(0),
|
||||
},
|
||||
PolicyConfig::CacheAwareConfig {
|
||||
cache_threshold,
|
||||
balance_abs_threshold,
|
||||
balance_rel_threshold,
|
||||
eviction_interval_secs,
|
||||
max_tree_size,
|
||||
} => {
|
||||
let mut running_queue = HashMap::new();
|
||||
for url in &worker_urls {
|
||||
running_queue.insert(url.clone(), 0);
|
||||
}
|
||||
|
||||
let mut processed_queue = HashMap::new();
|
||||
for url in &worker_urls {
|
||||
processed_queue.insert(url.clone(), 0);
|
||||
}
|
||||
|
||||
let tree = Arc::new(Mutex::new(Tree::new()));
|
||||
let running_queue = Arc::new(Mutex::new(running_queue));
|
||||
let processed_queue = Arc::new(Mutex::new(processed_queue));
|
||||
|
||||
// Create background eviction thread
|
||||
let tree_clone = Arc::clone(&tree);
|
||||
let processed_queue_clone = Arc::clone(&processed_queue);
|
||||
let running_queue_clone = Arc::clone(&running_queue);
|
||||
let eviction_thread = thread::spawn(move || {
|
||||
loop {
|
||||
// Sleep for the specified interval
|
||||
thread::sleep(Duration::from_secs(eviction_interval_secs));
|
||||
|
||||
let locked_tree_clone = tree_clone.lock().unwrap();
|
||||
// Run eviction
|
||||
locked_tree_clone.evict_tenant_by_size(max_tree_size);
|
||||
|
||||
// Print the process queue
|
||||
let locked_processed_queue = processed_queue_clone.lock().unwrap();
|
||||
info!("Processed Queue: {:?}", locked_processed_queue);
|
||||
|
||||
// Print the running queue
|
||||
let locked_running_queue = running_queue_clone.lock().unwrap();
|
||||
info!("Running Queue: {:?}", locked_running_queue);
|
||||
}
|
||||
});
|
||||
|
||||
for url in &worker_urls {
|
||||
tree.lock().unwrap().insert(&"".to_string(), url);
|
||||
}
|
||||
|
||||
Router::CacheAware {
|
||||
worker_urls: Arc::new(RwLock::new(worker_urls)),
|
||||
tree,
|
||||
running_queue,
|
||||
processed_queue,
|
||||
cache_threshold,
|
||||
balance_abs_threshold,
|
||||
balance_rel_threshold,
|
||||
_eviction_thread: Some(eviction_thread),
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
fn wait_for_healthy_workers(
|
||||
worker_urls: &[String],
|
||||
timeout_secs: u64,
|
||||
interval_secs: u64,
|
||||
) -> Result<(), String> {
|
||||
let start_time = std::time::Instant::now();
|
||||
let sync_client = reqwest::blocking::Client::new();
|
||||
|
||||
loop {
|
||||
if start_time.elapsed() > Duration::from_secs(timeout_secs) {
|
||||
return Err(format!(
|
||||
"Timeout {}s waiting for workers to become healthy",
|
||||
timeout_secs
|
||||
));
|
||||
}
|
||||
|
||||
let mut all_healthy = true;
|
||||
let mut unhealthy_workers = Vec::new();
|
||||
|
||||
for url in worker_urls {
|
||||
match sync_client.get(&format!("{}/health", url)).send() {
|
||||
Ok(res) => {
|
||||
if !res.status().is_success() {
|
||||
info!(
|
||||
"Worker {} health check is pending with status: {}.",
|
||||
url,
|
||||
res.status()
|
||||
);
|
||||
all_healthy = false;
|
||||
unhealthy_workers.push((url, format!("Status: {}", res.status())));
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
info!("Worker {} health check is pending with error: {}", url, e);
|
||||
all_healthy = false;
|
||||
unhealthy_workers.push((url, format!("Error: {}", e)));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if all_healthy {
|
||||
info!("All workers are healthy");
|
||||
return Ok(());
|
||||
} else {
|
||||
info!("Unhealthy workers:");
|
||||
for (url, reason) in &unhealthy_workers {
|
||||
info!(" {} - {}", url, reason);
|
||||
}
|
||||
thread::sleep(Duration::from_secs(interval_secs));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn select_first_worker(&self) -> Result<String, String> {
|
||||
match self {
|
||||
Router::RoundRobin { worker_urls, .. }
|
||||
| Router::Random { worker_urls }
|
||||
| Router::CacheAware { worker_urls, .. } => {
|
||||
if worker_urls.read().unwrap().is_empty() {
|
||||
Err("No workers are available".to_string())
|
||||
} else {
|
||||
Ok(worker_urls.read().unwrap()[0].clone())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn send_request(
|
||||
&self,
|
||||
client: &reqwest::Client,
|
||||
worker_url: &str,
|
||||
route: &str,
|
||||
) -> HttpResponse {
|
||||
match client.get(format!("{}{}", worker_url, route)).send().await {
|
||||
Ok(res) => {
|
||||
let status = actix_web::http::StatusCode::from_u16(res.status().as_u16())
|
||||
.unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR);
|
||||
|
||||
match res.bytes().await {
|
||||
Ok(body) => HttpResponse::build(status).body(body.to_vec()),
|
||||
Err(e) => HttpResponse::InternalServerError()
|
||||
.body(format!("Failed to read response body: {}", e)),
|
||||
}
|
||||
}
|
||||
Err(e) => HttpResponse::InternalServerError().body(format!(
|
||||
"Failed to send request to worker {}: {}",
|
||||
worker_url, e
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn route_to_first(&self, client: &reqwest::Client, route: &str) -> HttpResponse {
|
||||
const MAX_REQUEST_RETRIES: u32 = 3;
|
||||
const MAX_TOTAL_RETRIES: u32 = 6;
|
||||
let mut total_retries = 0;
|
||||
|
||||
while total_retries < MAX_TOTAL_RETRIES {
|
||||
match self.select_first_worker() {
|
||||
Ok(worker_url) => {
|
||||
let mut request_retries = 0;
|
||||
|
||||
// Try the same worker multiple times
|
||||
while request_retries < MAX_REQUEST_RETRIES {
|
||||
if total_retries >= 1 {
|
||||
info!("Retrying request after {} failed attempts", total_retries);
|
||||
}
|
||||
|
||||
let response = self.send_request(client, &worker_url, route).await;
|
||||
|
||||
if response.status().is_success() {
|
||||
return response;
|
||||
}
|
||||
|
||||
warn!(
|
||||
"Request to {} failed (attempt {}/{})",
|
||||
worker_url,
|
||||
request_retries + 1,
|
||||
MAX_REQUEST_RETRIES
|
||||
);
|
||||
|
||||
request_retries += 1;
|
||||
total_retries += 1;
|
||||
|
||||
if request_retries == MAX_REQUEST_RETRIES {
|
||||
warn!("Removing failed worker: {}", worker_url);
|
||||
self.remove_worker(&worker_url);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => return HttpResponse::InternalServerError().body(e),
|
||||
}
|
||||
}
|
||||
|
||||
HttpResponse::InternalServerError().body("All retry attempts failed")
|
||||
}
|
||||
|
||||
fn get_text_from_request(&self, body: &Bytes, route: &str) -> String {
|
||||
// convert body to json
|
||||
let json = serde_json::from_slice::<serde_json::Value>(body).unwrap();
|
||||
|
||||
if route == "generate" {
|
||||
// get the "text" field
|
||||
let text = json.get("text").and_then(|t| t.as_str()).unwrap_or("");
|
||||
return text.to_string();
|
||||
} else if route == "v1/chat/completions" {
|
||||
// get the messages field as raw text
|
||||
if let Some(messages) = json.get("messages") {
|
||||
// Convert messages back to a string, preserving all JSON formatting
|
||||
return serde_json::to_string(messages).unwrap_or_default();
|
||||
}
|
||||
} else if route == "v1/completions" {
|
||||
let prompt = json.get("prompt").and_then(|t| t.as_str()).unwrap_or("");
|
||||
return prompt.to_string();
|
||||
}
|
||||
|
||||
return "".to_string();
|
||||
}
|
||||
|
||||
// TODO: return Result<String, String> instead of panicking
|
||||
fn select_generate_worker(&self, body: &Bytes, route: &str) -> String {
|
||||
let text = self.get_text_from_request(&body, route);
|
||||
|
||||
let worker_url = match self {
|
||||
Router::RoundRobin {
|
||||
worker_urls,
|
||||
current_index,
|
||||
} => {
|
||||
let idx = current_index
|
||||
.fetch_update(
|
||||
std::sync::atomic::Ordering::SeqCst,
|
||||
std::sync::atomic::Ordering::SeqCst,
|
||||
|x| Some((x + 1) % worker_urls.read().unwrap().len()),
|
||||
)
|
||||
.unwrap();
|
||||
worker_urls.read().unwrap()[idx].clone()
|
||||
}
|
||||
|
||||
Router::Random { worker_urls } => worker_urls.read().unwrap()
|
||||
[rand::random::<usize>() % worker_urls.read().unwrap().len()]
|
||||
.clone(),
|
||||
|
||||
Router::CacheAware {
|
||||
worker_urls,
|
||||
tree,
|
||||
running_queue,
|
||||
processed_queue,
|
||||
cache_threshold,
|
||||
balance_abs_threshold,
|
||||
balance_rel_threshold,
|
||||
..
|
||||
} => {
|
||||
// TODO: delay scheduling if cache hit rate is high because it may cause imbalance. prioritize low hit rate ones
|
||||
|
||||
let tree = tree.lock().unwrap();
|
||||
let mut running_queue = running_queue.lock().unwrap();
|
||||
|
||||
// Get current load statistics
|
||||
let max_load = *running_queue.values().max().unwrap_or(&0);
|
||||
let min_load = *running_queue.values().min().unwrap_or(&0);
|
||||
|
||||
// Load is considered imbalanced if:
|
||||
// 1. (max - min) > abs_threshold AND
|
||||
// 2. max > rel_threshold * min
|
||||
let is_imbalanced = max_load.saturating_sub(min_load) > *balance_abs_threshold
|
||||
&& (max_load as f32) > (min_load as f32 * balance_rel_threshold);
|
||||
|
||||
let selected_url = if is_imbalanced {
|
||||
// Log load balancing trigger and current queue state
|
||||
info!(
|
||||
"Load balancing triggered due to workload imbalance:\n\
|
||||
Max load: {}, Min load: {}\n\
|
||||
Current running queue: {:?}",
|
||||
max_load, min_load, running_queue
|
||||
);
|
||||
|
||||
// Use shortest queue routing when load is imbalanced
|
||||
running_queue
|
||||
.iter()
|
||||
.min_by_key(|(_url, &count)| count)
|
||||
.map(|(url, _)| url.clone())
|
||||
.unwrap_or_else(|| worker_urls.read().unwrap()[0].clone())
|
||||
} else {
|
||||
// Use cache-aware routing when load is balanced
|
||||
let (matched_text, matched_worker) = tree.prefix_match(&text);
|
||||
let matched_rate =
|
||||
matched_text.chars().count() as f32 / text.chars().count() as f32;
|
||||
|
||||
if matched_rate > *cache_threshold {
|
||||
matched_worker.to_string()
|
||||
} else {
|
||||
tree.get_smallest_tenant()
|
||||
}
|
||||
};
|
||||
|
||||
// Update queues and tree
|
||||
*running_queue.get_mut(&selected_url).unwrap() += 1;
|
||||
|
||||
*processed_queue
|
||||
.lock()
|
||||
.unwrap()
|
||||
.get_mut(&selected_url)
|
||||
.unwrap() += 1;
|
||||
tree.insert(&text, &selected_url);
|
||||
|
||||
selected_url
|
||||
}
|
||||
};
|
||||
|
||||
worker_url
|
||||
}
|
||||
|
||||
async fn send_generate_request(
|
||||
&self,
|
||||
client: &reqwest::Client,
|
||||
req: &HttpRequest,
|
||||
body: &Bytes,
|
||||
route: &str,
|
||||
worker_url: &str,
|
||||
) -> HttpResponse {
|
||||
let is_stream = serde_json::from_slice::<serde_json::Value>(&body)
|
||||
.map(|v| v.get("stream").and_then(|s| s.as_bool()).unwrap_or(false))
|
||||
.unwrap_or(false);
|
||||
|
||||
let res = match client
|
||||
.post(format!("{}{}", worker_url, route))
|
||||
.header(
|
||||
"Content-Type",
|
||||
req.headers()
|
||||
.get("Content-Type")
|
||||
.and_then(|h| h.to_str().ok())
|
||||
.unwrap_or("application/json"),
|
||||
)
|
||||
.body(body.to_vec())
|
||||
.send()
|
||||
.await
|
||||
{
|
||||
Ok(res) => res,
|
||||
Err(_) => return HttpResponse::InternalServerError().finish(),
|
||||
};
|
||||
|
||||
let status = actix_web::http::StatusCode::from_u16(res.status().as_u16())
|
||||
.unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR);
|
||||
|
||||
if !is_stream {
|
||||
// For non-streaming requests, get response first
|
||||
let response = match res.bytes().await {
|
||||
Ok(body) => HttpResponse::build(status).body(body.to_vec()),
|
||||
Err(e) => {
|
||||
let error_msg = format!("Failed to get response body: {}", e);
|
||||
HttpResponse::InternalServerError().body(error_msg)
|
||||
}
|
||||
};
|
||||
|
||||
// Then decrement running queue counter if using CacheAware
|
||||
if let Router::CacheAware { running_queue, .. } = self {
|
||||
if let Ok(mut queue) = running_queue.lock() {
|
||||
if let Some(count) = queue.get_mut(worker_url) {
|
||||
*count = count.saturating_sub(1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
response
|
||||
} else if let Router::CacheAware { running_queue, .. } = self {
|
||||
let running_queue = Arc::clone(running_queue);
|
||||
let worker_url = worker_url.to_string();
|
||||
|
||||
HttpResponse::build(status)
|
||||
.insert_header((CONTENT_TYPE, HeaderValue::from_static("text/event-stream")))
|
||||
.streaming(
|
||||
res.bytes_stream()
|
||||
.map_err(|_| {
|
||||
actix_web::error::ErrorInternalServerError("Failed to read stream")
|
||||
})
|
||||
.inspect(move |bytes| {
|
||||
let bytes = bytes.as_ref().unwrap();
|
||||
if bytes
|
||||
.as_ref()
|
||||
.windows(12)
|
||||
.any(|window| window == b"data: [DONE]")
|
||||
{
|
||||
let mut locked_queue = running_queue.lock().unwrap();
|
||||
let count = locked_queue.get_mut(&worker_url).unwrap();
|
||||
*count = count.saturating_sub(1);
|
||||
debug!("Streaming is done!!")
|
||||
}
|
||||
}),
|
||||
)
|
||||
} else {
|
||||
HttpResponse::build(status)
|
||||
.insert_header((CONTENT_TYPE, HeaderValue::from_static("text/event-stream")))
|
||||
.streaming(res.bytes_stream().map_err(|_| {
|
||||
actix_web::error::ErrorInternalServerError("Failed to read stream")
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn route_generate_request(
|
||||
&self,
|
||||
client: &reqwest::Client,
|
||||
req: &HttpRequest,
|
||||
body: &Bytes,
|
||||
route: &str,
|
||||
) -> HttpResponse {
|
||||
const MAX_REQUEST_RETRIES: u32 = 3;
|
||||
const MAX_TOTAL_RETRIES: u32 = 6;
|
||||
let mut total_retries = 0;
|
||||
|
||||
while total_retries < MAX_TOTAL_RETRIES {
|
||||
let worker_url = self.select_generate_worker(body, route);
|
||||
let mut request_retries = 0;
|
||||
|
||||
// Try the same worker multiple times
|
||||
while request_retries < MAX_REQUEST_RETRIES {
|
||||
if total_retries >= 1 {
|
||||
info!("Retrying request after {} failed attempts", total_retries);
|
||||
}
|
||||
let response = self
|
||||
.send_generate_request(client, req, body, route, &worker_url)
|
||||
.await;
|
||||
|
||||
if response.status().is_success() {
|
||||
return response;
|
||||
}
|
||||
|
||||
warn!(
|
||||
"Generate request to {} failed (attempt {}/{})",
|
||||
worker_url,
|
||||
request_retries + 1,
|
||||
MAX_REQUEST_RETRIES
|
||||
);
|
||||
|
||||
request_retries += 1;
|
||||
total_retries += 1;
|
||||
|
||||
if request_retries == MAX_REQUEST_RETRIES {
|
||||
warn!("Removing failed worker: {}", worker_url);
|
||||
self.remove_worker(&worker_url);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
HttpResponse::InternalServerError().body("All retry attempts failed")
|
||||
}
|
||||
|
||||
pub async fn add_worker(&self, worker_url: &str) -> Result<String, String> {
|
||||
let interval_secs = 10; // check every 10 seconds
|
||||
let timeout_secs = 300; // 5 minutes
|
||||
|
||||
let start_time = std::time::Instant::now();
|
||||
let client = reqwest::Client::new();
|
||||
|
||||
loop {
|
||||
if start_time.elapsed() > Duration::from_secs(timeout_secs) {
|
||||
return Err(format!(
|
||||
"Timeout {}s waiting for worker {} to become healthy",
|
||||
timeout_secs, worker_url
|
||||
));
|
||||
}
|
||||
|
||||
match client.get(&format!("{}/health", worker_url)).send().await {
|
||||
Ok(res) => {
|
||||
if res.status().is_success() {
|
||||
match self {
|
||||
Router::RoundRobin { worker_urls, .. }
|
||||
| Router::Random { worker_urls }
|
||||
| Router::CacheAware { worker_urls, .. } => {
|
||||
info!("Worker {} health check passed", worker_url);
|
||||
let mut urls = worker_urls.write().unwrap();
|
||||
if urls.contains(&worker_url.to_string()) {
|
||||
return Err(format!("Worker {} already exists", worker_url));
|
||||
}
|
||||
info!("Added worker: {}", worker_url);
|
||||
urls.push(worker_url.to_string());
|
||||
}
|
||||
}
|
||||
|
||||
// If cache aware, initialize the queues for the new worker
|
||||
if let Router::CacheAware {
|
||||
running_queue,
|
||||
processed_queue,
|
||||
tree,
|
||||
..
|
||||
} = self
|
||||
{
|
||||
// Add worker to running queue with initial count of 0
|
||||
running_queue
|
||||
.lock()
|
||||
.unwrap()
|
||||
.insert(worker_url.to_string(), 0);
|
||||
|
||||
// Add worker to processed queue with initial count of 0
|
||||
processed_queue
|
||||
.lock()
|
||||
.unwrap()
|
||||
.insert(worker_url.to_string(), 0);
|
||||
|
||||
// Add worker to tree
|
||||
tree.lock().unwrap().insert(&"".to_string(), &worker_url);
|
||||
}
|
||||
|
||||
return Ok(format!("Successfully added worker: {}", worker_url));
|
||||
} else {
|
||||
info!(
|
||||
"Worker {} health check is pending with status: {}.",
|
||||
worker_url,
|
||||
res.status()
|
||||
);
|
||||
// if the url does not have http or https prefix, warn users
|
||||
if !worker_url.starts_with("http://") && !worker_url.starts_with("https://")
|
||||
{
|
||||
warn!("The worker url {} does not have http or https prefix. Please add the prefix to the url.", worker_url);
|
||||
}
|
||||
|
||||
tokio::time::sleep(Duration::from_secs(interval_secs)).await;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
info!(
|
||||
"Worker {} health check is pending with error: {}",
|
||||
worker_url, e
|
||||
);
|
||||
|
||||
// if the url does not have http or https prefix, warn users
|
||||
if !worker_url.starts_with("http://") && !worker_url.starts_with("https://") {
|
||||
warn!("The worker url {} does not have http or https prefix. Please add the prefix to the url.", worker_url);
|
||||
}
|
||||
|
||||
tokio::time::sleep(Duration::from_secs(interval_secs)).await;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn remove_worker(&self, worker_url: &str) {
|
||||
match self {
|
||||
Router::RoundRobin { worker_urls, .. }
|
||||
| Router::Random { worker_urls }
|
||||
| Router::CacheAware { worker_urls, .. } => {
|
||||
let mut urls = worker_urls.write().unwrap();
|
||||
if let Some(index) = urls.iter().position(|url| url == &worker_url) {
|
||||
urls.remove(index);
|
||||
info!("Removed worker: {}", worker_url);
|
||||
} else {
|
||||
warn!("Worker {} not found, skipping removal", worker_url);
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// if cache aware, remove the worker from the tree
|
||||
if let Router::CacheAware {
|
||||
tree,
|
||||
running_queue,
|
||||
processed_queue,
|
||||
..
|
||||
} = self
|
||||
{
|
||||
tree.lock().unwrap().remove_tenant(&worker_url);
|
||||
running_queue
|
||||
.lock()
|
||||
.unwrap()
|
||||
.remove(&worker_url.to_string());
|
||||
processed_queue
|
||||
.lock()
|
||||
.unwrap()
|
||||
.remove(&worker_url.to_string());
|
||||
info!(
|
||||
"Removed worker from tree and cleaned up queues: {}",
|
||||
worker_url
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
192
sgl-router/src/server.rs
Normal file
192
sgl-router/src/server.rs
Normal file
@@ -0,0 +1,192 @@
|
||||
use crate::router::PolicyConfig;
|
||||
use crate::router::Router;
|
||||
use actix_web::{get, post, web, App, HttpRequest, HttpResponse, HttpServer, Responder};
|
||||
use bytes::Bytes;
|
||||
use env_logger::Builder;
|
||||
use log::{info, LevelFilter};
|
||||
use std::collections::HashMap;
|
||||
use std::io::Write;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct AppState {
|
||||
router: Router,
|
||||
client: reqwest::Client,
|
||||
}
|
||||
|
||||
impl AppState {
|
||||
pub fn new(
|
||||
worker_urls: Vec<String>,
|
||||
client: reqwest::Client,
|
||||
policy_config: PolicyConfig,
|
||||
) -> Self {
|
||||
// Create router based on policy
|
||||
let router = match Router::new(worker_urls, policy_config) {
|
||||
Ok(router) => router,
|
||||
Err(error) => panic!("Failed to create router: {}", error),
|
||||
};
|
||||
|
||||
Self { router, client }
|
||||
}
|
||||
}
|
||||
|
||||
#[get("/health")]
|
||||
async fn health(data: web::Data<AppState>) -> impl Responder {
|
||||
data.router.route_to_first(&data.client, "/health").await
|
||||
}
|
||||
|
||||
#[get("/health_generate")]
|
||||
async fn health_generate(data: web::Data<AppState>) -> impl Responder {
|
||||
data.router
|
||||
.route_to_first(&data.client, "/health_generate")
|
||||
.await
|
||||
}
|
||||
|
||||
#[get("/get_server_info")]
|
||||
async fn get_server_info(data: web::Data<AppState>) -> impl Responder {
|
||||
data.router
|
||||
.route_to_first(&data.client, "/get_server_info")
|
||||
.await
|
||||
}
|
||||
|
||||
#[get("/v1/models")]
|
||||
async fn v1_models(data: web::Data<AppState>) -> impl Responder {
|
||||
data.router.route_to_first(&data.client, "/v1/models").await
|
||||
}
|
||||
|
||||
#[get("/get_model_info")]
|
||||
async fn get_model_info(data: web::Data<AppState>) -> impl Responder {
|
||||
data.router
|
||||
.route_to_first(&data.client, "/get_model_info")
|
||||
.await
|
||||
}
|
||||
|
||||
#[post("/generate")]
|
||||
async fn generate(req: HttpRequest, body: Bytes, data: web::Data<AppState>) -> impl Responder {
|
||||
data.router
|
||||
.route_generate_request(&data.client, &req, &body, "/generate")
|
||||
.await
|
||||
}
|
||||
|
||||
#[post("/v1/chat/completions")]
|
||||
async fn v1_chat_completions(
|
||||
req: HttpRequest,
|
||||
body: Bytes,
|
||||
data: web::Data<AppState>,
|
||||
) -> impl Responder {
|
||||
data.router
|
||||
.route_generate_request(&data.client, &req, &body, "/v1/chat/completions")
|
||||
.await
|
||||
}
|
||||
|
||||
#[post("/v1/completions")]
|
||||
async fn v1_completions(
|
||||
req: HttpRequest,
|
||||
body: Bytes,
|
||||
data: web::Data<AppState>,
|
||||
) -> impl Responder {
|
||||
data.router
|
||||
.route_generate_request(&data.client, &req, &body, "/v1/completions")
|
||||
.await
|
||||
}
|
||||
|
||||
#[post("/add_worker")]
|
||||
async fn add_worker(
|
||||
query: web::Query<HashMap<String, String>>,
|
||||
data: web::Data<AppState>,
|
||||
) -> impl Responder {
|
||||
let worker_url = match query.get("url") {
|
||||
Some(url) => url.to_string(),
|
||||
None => {
|
||||
return HttpResponse::BadRequest()
|
||||
.body("Worker URL required. Provide 'url' query parameter")
|
||||
}
|
||||
};
|
||||
|
||||
match data.router.add_worker(&worker_url).await {
|
||||
Ok(message) => HttpResponse::Ok().body(message),
|
||||
Err(error) => HttpResponse::BadRequest().body(error),
|
||||
}
|
||||
}
|
||||
|
||||
#[post("/remove_worker")]
|
||||
async fn remove_worker(
|
||||
query: web::Query<HashMap<String, String>>,
|
||||
data: web::Data<AppState>,
|
||||
) -> impl Responder {
|
||||
let worker_url = match query.get("url") {
|
||||
Some(url) => url.to_string(),
|
||||
None => return HttpResponse::BadRequest().finish(),
|
||||
};
|
||||
data.router.remove_worker(&worker_url);
|
||||
HttpResponse::Ok().body(format!("Successfully removed worker: {}", worker_url))
|
||||
}
|
||||
|
||||
pub struct ServerConfig {
|
||||
pub host: String,
|
||||
pub port: u16,
|
||||
pub worker_urls: Vec<String>,
|
||||
pub policy_config: PolicyConfig,
|
||||
pub verbose: bool,
|
||||
pub max_payload_size: usize,
|
||||
}
|
||||
|
||||
pub async fn startup(config: ServerConfig) -> std::io::Result<()> {
|
||||
Builder::new()
|
||||
.format(|buf, record| {
|
||||
use chrono::Local;
|
||||
writeln!(
|
||||
buf,
|
||||
"[Router (Rust)] {} - {} - {}",
|
||||
Local::now().format("%Y-%m-%d %H:%M:%S"),
|
||||
record.level(),
|
||||
record.args()
|
||||
)
|
||||
})
|
||||
.filter(
|
||||
None,
|
||||
if config.verbose {
|
||||
LevelFilter::Debug
|
||||
} else {
|
||||
LevelFilter::Info
|
||||
},
|
||||
)
|
||||
.init();
|
||||
|
||||
let client = reqwest::Client::builder()
|
||||
.build()
|
||||
.expect("Failed to create HTTP client");
|
||||
|
||||
let app_state = web::Data::new(AppState::new(
|
||||
config.worker_urls.clone(),
|
||||
client,
|
||||
config.policy_config.clone(),
|
||||
));
|
||||
|
||||
info!("✅ Starting router on {}:{}", config.host, config.port);
|
||||
info!("✅ Serving Worker URLs: {:?}", config.worker_urls);
|
||||
info!("✅ Policy Config: {:?}", config.policy_config);
|
||||
info!(
|
||||
"✅ Max payload size: {} MB",
|
||||
config.max_payload_size / (1024 * 1024)
|
||||
);
|
||||
|
||||
HttpServer::new(move || {
|
||||
App::new()
|
||||
.app_data(app_state.clone())
|
||||
.app_data(web::JsonConfig::default().limit(config.max_payload_size))
|
||||
.app_data(web::PayloadConfig::default().limit(config.max_payload_size))
|
||||
.service(generate)
|
||||
.service(v1_chat_completions)
|
||||
.service(v1_completions)
|
||||
.service(v1_models)
|
||||
.service(get_model_info)
|
||||
.service(health)
|
||||
.service(health_generate)
|
||||
.service(get_server_info)
|
||||
.service(add_worker)
|
||||
.service(remove_worker)
|
||||
})
|
||||
.bind((config.host, config.port))?
|
||||
.run()
|
||||
.await
|
||||
}
|
||||
1483
sgl-router/src/tree.rs
Normal file
1483
sgl-router/src/tree.rs
Normal file
File diff suppressed because it is too large
Load Diff
63
sgl-router/v0.1.0.md
Normal file
63
sgl-router/v0.1.0.md
Normal file
@@ -0,0 +1,63 @@
|
||||
# SGLang Router v0.1.0: Dynamic Scaling and Fault Tolerance
|
||||
|
||||
We have released `sglang-router` v0.1.0 equipped with dynamic scaling and fault tolerance! It is essential for the router to be able to dynamically scale the number of workers and handle worker failures. To achieve this, we have implemented the following features:
|
||||
|
||||
## 1. Dynamic scaling: The router can dynamically scale the number of workers based on the request load.
|
||||
|
||||
We offer `/add_worker` and `/remove_worker` APIs to dynamically add or remove workers from the router.
|
||||
|
||||
- `/add_worker`
|
||||
|
||||
Usage:
|
||||
|
||||
```bash
|
||||
$ curl -X POST http://localhost:30000/add_worker?url=http://worker_url_1
|
||||
```
|
||||
|
||||
Example:
|
||||
|
||||
```bash
|
||||
$ python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --port 30001
|
||||
$ curl -X POST http://localhost:30000/add_worker?url=http://127.0.0.1:30001
|
||||
Successfully added worker: http://127.0.0.1:30001
|
||||
```
|
||||
|
||||
- `/remove_worker`
|
||||
|
||||
Usage:
|
||||
|
||||
```bash
|
||||
$ curl -X POST http://localhost:30000/remove_worker?url=http://worker_url_1
|
||||
```
|
||||
|
||||
Example:
|
||||
|
||||
```bash
|
||||
$ curl -X POST http://localhost:30000/remove_worker?url=http://127.0.0.1:30001
|
||||
Successfully removed worker: http://127.0.0.1:30001
|
||||
```
|
||||
|
||||
Note:
|
||||
|
||||
- For cache-aware router, the worker will be removed from the tree and the queues.
|
||||
|
||||
## 2. Fault tolerance: The router can handle worker failures and automatically remove the failed worker from the router.
|
||||
|
||||
We provide retries based for failure tolerance.
|
||||
|
||||
1. If the request to a worker fails for `max_worker_retries` times, the router will remove the worker from the router and move on to the next worker.
|
||||
2. If the total number of retries exceeds `max_total_retries`, the router will return an error.
|
||||
|
||||
Note:
|
||||
|
||||
- `max_worker_retries` is 3 and `max_total_retries` is 6 by default.
|
||||
|
||||
## Closing remarks:
|
||||
|
||||
1. Please read the full usage at https://sgl-project.github.io/router/router.html
|
||||
2. The feature is still under active improvement, so please don't hesitate to raise issues or submit PRs if you have any suggestions or feedback.
|
||||
|
||||
|
||||
# Release Instructions
|
||||
|
||||
Update the version in `rust/pyproject.toml` and `py_src/sglang_router/version.py`.
|
||||
Reference in New Issue
Block a user