[router] cache-aware load-balancing router v1 (#2114)

This commit is contained in:
Byron Hsu
2024-11-23 08:34:48 -08:00
committed by GitHub
parent ad47749b82
commit cbedd1db1d
17 changed files with 1963 additions and 602 deletions

View File

@@ -0,0 +1,204 @@
import argparse
import dataclasses
import sys
from typing import List, Optional
from sglang_router import Router
from sglang_router_rs import PolicyType
@dataclasses.dataclass
class RouterArgs:
    """CLI-facing configuration for the SGLang router.

    Mirrors the router's constructor options and knows how to register
    itself on an argparse parser (``add_cli_args``) and rebuild itself
    from the parsed result (``from_cli_args``).
    """

    # Worker configuration
    worker_urls: List[str]  # full worker URLs, e.g. ["http://worker1:8000", ...]
    host: str = "127.0.0.1"
    port: int = 30000

    # Routing policy
    policy: str = "cache_aware"
    cache_threshold: float = 0.5  # match-rate threshold (0.0-1.0) for cache-aware routing
    cache_routing_prob: float = 1.0  # probability of taking the cache-aware path
    eviction_interval: int = 60  # seconds between cache eviction passes
    max_tree_size: int = 2**24  # cap on the cache-aware approximation tree

    @staticmethod
    def add_cli_args(
        parser: argparse.ArgumentParser,
        use_router_prefix: bool = False,
        exclude_host_port: bool = False,
    ):
        """
        Add router-specific arguments to an argument parser.

        Args:
            parser: The argument parser to add arguments to
            use_router_prefix: If True, prefix all arguments with 'router-' to avoid conflicts
            exclude_host_port: If True, don't add host and port arguments (used when inheriting from server)
        """
        prefix = "router-" if use_router_prefix else ""

        # Worker configuration
        if not exclude_host_port:
            parser.add_argument(
                "--host",
                type=str,
                default=RouterArgs.host,
                help="Host address to bind the router server",
            )
            parser.add_argument(
                "--port",
                type=int,
                default=RouterArgs.port,
                help="Port number to bind the router server",
            )
        parser.add_argument(
            "--worker-urls",
            type=str,
            nargs="+",
            help="List of worker URLs (e.g., http://worker1:8000 http://worker2:8000)",
        )

        # Routing policy configuration
        parser.add_argument(
            f"--{prefix}policy",
            type=str,
            default=RouterArgs.policy,
            choices=["random", "round_robin", "cache_aware"],
            help="Load balancing policy to use",
        )
        parser.add_argument(
            f"--{prefix}cache-threshold",
            type=float,
            default=RouterArgs.cache_threshold,
            help="Cache threshold (0.0-1.0) for cache-aware routing",
        )
        parser.add_argument(
            f"--{prefix}cache-routing-prob",
            type=float,
            default=RouterArgs.cache_routing_prob,
            help="Probability of using cache-aware routing (0.0-1.0)",
        )
        parser.add_argument(
            f"--{prefix}eviction-interval",
            type=int,
            default=RouterArgs.eviction_interval,
            help="Interval in seconds between cache eviction operations",
        )
        parser.add_argument(
            f"--{prefix}max-tree-size",
            type=int,
            default=RouterArgs.max_tree_size,
            help="Maximum size of the approximation tree for cache-aware routing",
        )

    @classmethod
    def from_cli_args(
        cls, args: argparse.Namespace, use_router_prefix: bool = False
    ) -> "RouterArgs":
        """
        Create RouterArgs instance from parsed command line arguments.

        Args:
            args: Parsed command line arguments
            use_router_prefix: If True, look for arguments with 'router-' prefix
        """
        # argparse converts dashes to underscores, hence "router_" here
        # versus the "router-" prefix used when registering the flags.
        prefix = "router_" if use_router_prefix else ""
        return cls(
            worker_urls=args.worker_urls,
            # host/port are absent when add_cli_args() was called with
            # exclude_host_port=True and the surrounding parser does not
            # define them itself; fall back to the dataclass defaults
            # instead of raising AttributeError.
            host=getattr(args, "host", cls.host),
            port=getattr(args, "port", cls.port),
            policy=getattr(args, f"{prefix}policy"),
            cache_threshold=getattr(args, f"{prefix}cache_threshold"),
            cache_routing_prob=getattr(args, f"{prefix}cache_routing_prob"),
            eviction_interval=getattr(args, f"{prefix}eviction_interval"),
            max_tree_size=getattr(args, f"{prefix}max_tree_size"),
        )
def policy_from_str(policy_str: str) -> PolicyType:
    """Convert policy string to PolicyType enum."""
    if policy_str == "random":
        return PolicyType.Random
    if policy_str == "round_robin":
        return PolicyType.RoundRobin
    if policy_str == "cache_aware":
        return PolicyType.CacheAware
    # Unknown name: mirror the mapping-lookup failure of a dict access.
    raise KeyError(policy_str)
def launch_router(args: argparse.Namespace) -> Optional[Router]:
    """
    Launch the SGLang router with the configuration from parsed arguments.

    Args:
        args: Namespace object containing router configuration
              Can be either raw argparse.Namespace or converted RouterArgs

    Returns:
        Router instance if successful, None if failed
    """
    try:
        # Normalize: accept either a ready RouterArgs or a raw namespace.
        router_args = (
            args if isinstance(args, RouterArgs) else RouterArgs.from_cli_args(args)
        )

        router = Router(
            worker_urls=router_args.worker_urls,
            policy=policy_from_str(router_args.policy),
            host=router_args.host,
            port=router_args.port,
            cache_threshold=router_args.cache_threshold,
            cache_routing_prob=router_args.cache_routing_prob,
            eviction_interval_secs=router_args.eviction_interval,
            max_tree_size=router_args.max_tree_size,
        )
        router.start()
        return router
    except Exception as e:
        # Best-effort launcher: report the failure and signal it via None.
        print(f"Error starting router: {e}", file=sys.stderr)
        return None
class CustomHelpFormatter(
    argparse.RawDescriptionHelpFormatter, argparse.ArgumentDefaultsHelpFormatter
):
    """Custom formatter that preserves both description formatting and shows defaults"""
def parse_router_args(args: List[str]) -> RouterArgs:
    """Parse command line arguments and return RouterArgs instance."""
    parser = argparse.ArgumentParser(
        description="""SGLang Router - High-performance request distribution across worker nodes

Usage:
This launcher enables starting a router with individual worker instances. It is useful for
multi-node setups or when you want to start workers and router separately.

Examples:
python -m sglang_router.launch_router --worker-urls http://worker1:8000 http://worker2:8000
python -m sglang_router.launch_router --worker-urls http://worker1:8000 http://worker2:8000 --cache-threshold 0.7 --cache-routing-prob 0.5
""",
        formatter_class=CustomHelpFormatter,
    )
    RouterArgs.add_cli_args(parser, use_router_prefix=False)
    parsed = parser.parse_args(args)
    return RouterArgs.from_cli_args(parsed, use_router_prefix=False)
def main() -> None:
    """Entry point: parse CLI args, launch the router, exit non-zero on failure."""
    router_args = parse_router_args(sys.argv[1:])
    if launch_router(router_args) is None:
        sys.exit(1)
# Allow running this module directly as a script.
if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,178 @@
import argparse
import copy
import multiprocessing as mp
import os
import signal
import sys
import time
from typing import List
import requests
from sglang_router.launch_router import RouterArgs, launch_router
from sglang.srt.server import launch_server
from sglang.srt.server_args import ServerArgs, prepare_server_args
from sglang.srt.utils import is_port_available
from sglang.utils import get_exception_traceback
def run_server(server_args, dp_rank):
    """Worker-process entry point: run one server for data-parallel rank *dp_rank*."""
    # Detach into a fresh process group so shutdown can signal the whole
    # worker tree (see cleanup in the launcher) in one killpg call.
    os.setpgrp()

    # Expose the data-parallel rank to the server through the environment.
    os.environ["DP_RANK"] = str(dp_rank)

    launch_server(server_args)
def launch_server_process(
    server_args: ServerArgs, worker_port: int, dp_id: int
) -> mp.Process:
    """Launch a single server process with the given args and port."""
    # Work on a private copy so the caller's template args stay untouched.
    worker_args = copy.deepcopy(server_args)
    worker_args.port = worker_port
    # Each data-parallel replica gets its own contiguous GPU slice.
    worker_args.base_gpu_id = dp_id * worker_args.tp_size
    worker_args.dp_size = 1

    proc = mp.Process(target=run_server, args=(worker_args, dp_id))
    proc.start()
    return proc
def cleanup_processes(processes: List[mp.Process]):
    """Clean up all processes using process groups."""
    print("\nCleaning up processes...")
    for proc in processes:
        if not proc.is_alive():
            continue
        try:
            # Signal the whole process group, not just the direct child.
            os.killpg(os.getpgid(proc.pid), signal.SIGTERM)
            # Short grace period for a clean shutdown before escalating.
            proc.join(timeout=3)
            if proc.is_alive():
                os.killpg(os.getpgid(proc.pid), signal.SIGKILL)
        except ProcessLookupError:
            # Process (group) already gone — nothing left to clean up.
            pass
def setup_signal_handlers(cleanup_func):
    """Setup handlers for various termination signals."""

    def _handle(signum, frame):
        cleanup_func()
        sys.exit(1)

    for sig in (signal.SIGTERM, signal.SIGINT):
        signal.signal(sig, _handle)
    # SIGQUIT does not exist on every platform (e.g. Windows).
    if hasattr(signal, "SIGQUIT"):
        signal.signal(signal.SIGQUIT, _handle)
def wait_for_server_health(host: str, port: int, timeout: int = 300) -> bool:
    """Wait for server to be healthy by checking /health endpoint."""
    url = f"http://{host}:{port}/health"
    deadline = time.time() + timeout
    while time.time() < deadline:
        try:
            if requests.get(url, timeout=5).status_code == 200:
                return True
        except requests.exceptions.RequestException:
            # Server not accepting connections yet — poll again after a pause.
            pass
        time.sleep(1)
    return False
def find_available_ports(base_port: int, count: int) -> List[int]:
    """Find `count` available ports, scanning upward from `base_port`.

    Busy ports are skipped, so the result is ascending but not necessarily
    consecutive (the original docstring's "consecutive" claim was wrong).

    Args:
        base_port: First port number to probe.
        count: Number of free ports required.

    Returns:
        List of `count` distinct available port numbers.

    Raises:
        RuntimeError: If the scan runs past the valid port range (65535)
            before `count` free ports are found (previously this looped
            forever probing invalid port numbers).
    """
    available_ports: List[int] = []
    current_port = base_port
    while len(available_ports) < count:
        if current_port > 65535:
            raise RuntimeError(
                f"Could not find {count} free ports starting at {base_port}"
            )
        if is_port_available(current_port):
            available_ports.append(current_port)
        current_port += 1
    return available_ports
def main():
    """Launch one router plus `dp_size` worker servers and supervise them.

    Flow: parse combined server+router CLI args, pick worker ports, fork one
    server process per data-parallel rank, wait for all /health checks, then
    start the router pointed at the healthy workers. Workers are always torn
    down in the `finally` block.
    """
    # CUDA runtime isn't fork-safe, which can lead to subtle bugs or crashes
    mp.set_start_method("spawn")

    parser = argparse.ArgumentParser(
        description="Launch SGLang router and server processes"
    )

    # Router args use the 'router-' prefix and skip host/port to avoid
    # clashing with the server's own --host/--port flags.
    ServerArgs.add_cli_args(parser)
    RouterArgs.add_cli_args(parser, use_router_prefix=True, exclude_host_port=True)
    parser.add_argument(
        "--router-dp-worker-base-port",
        type=int,
        default=31000,
        help="Base port number for data parallel workers",
    )

    args = parser.parse_args()
    server_args = ServerArgs.from_cli_args(args)
    router_args = RouterArgs.from_cli_args(args, use_router_prefix=True)

    # Find available ports for workers (one per data-parallel rank)
    worker_ports = find_available_ports(
        args.router_dp_worker_base_port, server_args.dp_size
    )

    # Start server processes
    server_processes = []

    try:
        # Launch server processes, one per worker port / dp rank
        for i, worker_port in enumerate(worker_ports):
            proc = launch_server_process(server_args, worker_port, i)
            server_processes.append(proc)

        # Setup cleanup handler so SIGTERM/SIGINT tear down the workers too
        setup_signal_handlers(lambda: cleanup_processes(server_processes))

        # Wait for all servers to be healthy; bail out early on first failure
        all_healthy = True

        for port in worker_ports:
            if not wait_for_server_health(server_args.host, port):
                print(f"Server on port {port} failed to become healthy")
                all_healthy = False
                break

        if not all_healthy:
            print("Not all servers are healthy. Shutting down...")
            cleanup_processes(server_processes)
            sys.exit(1)

        print("All servers are healthy. Starting router...")

        # Update router args with worker URLs now that the ports are known
        router_args.worker_urls = [
            f"http://{server_args.host}:{port}" for port in worker_ports
        ]

        # Start the router (blocks in router.start(); returns None on failure)
        router = launch_router(router_args)

        if router is None:
            print("Failed to start router. Shutting down...")
            cleanup_processes(server_processes)
            sys.exit(1)

    except KeyboardInterrupt:
        print("\nReceived shutdown signal...")
    except Exception as e:
        print(f"Error occurred: {e}")
        print(get_exception_traceback())
    finally:
        # Always reap worker processes, whatever exit path was taken.
        cleanup_processes(server_processes)
# Allow running this module directly as a script.
if __name__ == "__main__":
    main()

View File

@@ -9,16 +9,23 @@ class Router:
A high-performance router for distributing requests across worker nodes.
Args:
worker_urls: List of URLs for worker nodes that will handle requests
worker_urls: List of URLs for worker nodes that will handle requests. Each URL should include
the protocol, host, and port (e.g., ['http://worker1:8000', 'http://worker2:8000'])
policy: Load balancing policy to use. Options:
- PolicyType.Random: Randomly select workers
- PolicyType.RoundRobin: Distribute requests in round-robin fashion
- PolicyType.ApproxTree: Tree-based routing using tokenizer similarity
host: Host address to bind the router server
port: Port number to bind the router server
tokenizer_path: Path to tokenizer model file (required for ApproxTree policy)
cache_threshold: Caching threshold value between 0-1
- PolicyType.CacheAware: Distribute requests in cache-aware fashion
host: Host address to bind the router server. Default: '127.0.0.1'
port: Port number to bind the router server. Default: 3001
cache_threshold: Cache threshold (0.0-1.0) for cache-aware routing. Routes to cached worker
if the match rate exceeds threshold, otherwise routes to the worker with the smallest
tree. Default: 0.5
cache_routing_prob: Probability of using cache-aware routing (0.0-1.0). Default 1.0 for
full cache-aware routing, suitable for perfectly divided prefix workloads. For uneven
workloads, use a lower value to better distribute requests
eviction_interval_secs: Interval in seconds between cache eviction operations in cache-aware
routing. Default: 60
max_tree_size: Maximum size of the approximation tree for cache-aware routing. Default: 2^24
"""
def __init__(
@@ -27,17 +34,20 @@ class Router:
policy: PolicyType = PolicyType.RoundRobin,
host: str = "127.0.0.1",
port: int = 3001,
tokenizer_path: Optional[str] = None,
cache_threshold: float = 0.50,
cache_routing_prob: float = 1.0,
eviction_interval_secs: int = 60,
max_tree_size: int = 2**24,
):
self._router = _Router(
worker_urls=worker_urls,
policy=policy,
host=host,
port=port,
tokenizer_path=tokenizer_path,
cache_threshold=cache_threshold,
cache_routing_prob=cache_routing_prob,
eviction_interval_secs=eviction_interval_secs,
max_tree_size=max_tree_size,
)
def start(self) -> None: