setup router python binding ci (#1999)

This commit is contained in:
Byron Hsu
2024-11-11 12:19:32 -08:00
committed by GitHub
parent ddeb9d42de
commit 00ffde206f
13 changed files with 254 additions and 161 deletions

5
rust/py_src/__init__.py Normal file
View File

@@ -0,0 +1,5 @@
# a lightweight wrapper on router with argument type and comments
# no wrapper on policy type => direct export
from sglang_router_rs import PolicyType
from .router import Router

View File

@@ -1,156 +0,0 @@
import argparse
import os
import signal
import subprocess
import sys
import time
from typing import Dict, List
import requests
from sglang_router import PolicyType, Router
# Global list of every worker process launched by this script, kept so that
# cleanup_processes() can reach them from a signal handler.
_processes: List[subprocess.Popen] = []


def cleanup_processes(signum=None, frame=None):
    """Kill every tracked worker's process group, then exit with status 1.

    The ``(signum, frame)`` signature lets this be registered directly with
    ``signal.signal``; both arguments are ignored.

    Args:
        signum: Signal number supplied by the signal machinery (unused).
        frame: Stack frame supplied by the signal machinery (unused).

    Raises:
        SystemExit: Always, with exit code 1, after cleanup was attempted.
    """
    print("\nCleaning up processes...")
    for process in _processes:
        try:
            # Kill the entire process group, not just the direct child.
            # NOTE(review): a worker started without its own session/process
            # group shares this script's group, so killpg may also signal
            # this process -- confirm against how workers are launched.
            pgid = os.getpgid(process.pid)
            os.killpg(pgid, signal.SIGKILL)
            process.wait()
        except OSError:
            # Best effort: the process may already be gone
            # (ProcessLookupError) or not be ours to signal (PermissionError).
            # Anything else (e.g. KeyboardInterrupt) is allowed to propagate.
            pass
    sys.exit(1)
# Route both SIGINT and SIGTERM through the same cleanup handler.
for _signum in (signal.SIGINT, signal.SIGTERM):
    signal.signal(_signum, cleanup_processes)
def parse_args(argv=None):
    """Parse command line arguments for the router launcher.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default), argparse reads ``sys.argv[1:]`` as before.

    Returns:
        argparse.Namespace with ``host``, ``port``, ``dp``, ``model_path``
        and ``local_tokenizer_path`` attributes.
    """
    parser = argparse.ArgumentParser(description="Launch SGLang Router Server")
    parser.add_argument(
        "--host", type=str, default="localhost", help="Host address to bind the server"
    )
    parser.add_argument(
        "--port", type=int, default=30000, help="Base port number for workers"
    )
    parser.add_argument(
        "--dp",
        type=int,
        default=2,
        help="Number of worker processes (degree of parallelism)",
    )
    parser.add_argument(
        "--model-path", type=str, required=True, help="Path to the model"
    )
    parser.add_argument(
        "--local-tokenizer-path",
        type=str,
        required=True,
        help="Path to the local tokenizer",
    )
    return parser.parse_args(argv)
def launch_workers(args) -> tuple[List[subprocess.Popen], List[str]]:
    """Launch all worker processes concurrently using subprocess.

    One ``sglang.launch_server`` process is started per worker. Worker ``i``
    listens on ``args.port + i`` and has ``CUDA_VISIBLE_DEVICES`` set to
    ``i``. Each process is also appended to the module-level ``_processes``
    list so it can be cleaned up later.

    Args:
        args: Namespace providing ``dp``, ``port``, ``host`` and
            ``model_path``.

    Returns:
        A ``(processes, worker_urls)`` tuple, index-aligned.
    """
    import shlex  # local import: only used to render the command for logging

    processes = []
    worker_urls = []
    # Launch each worker process
    for i in range(args.dp):
        port = args.port + i
        worker_urls.append(f"http://{args.host}:{port}")
        # TODO: replace this with launch_server, and move this file to sglang/ because it depends on sglang
        # Pass an argv list instead of a shell string so paths containing
        # spaces or shell metacharacters are handed over verbatim, and the
        # Popen handle refers to the server itself rather than a wrapper shell.
        command = [
            "python",
            "-m",
            "sglang.launch_server",
            "--model-path",
            str(args.model_path),
            "--host",
            str(args.host),
            "--port",
            str(port),
        ]
        # Inherit the current environment, overriding only the GPU selection.
        env = dict(os.environ, CUDA_VISIBLE_DEVICES=str(i))
        print(f"CUDA_VISIBLE_DEVICES={i} {shlex.join(command)}")
        process = subprocess.Popen(command, env=env)
        processes.append(process)
        _processes.append(process)  # Add to global list for cleanup
    return processes, worker_urls
def wait_for_healthy_workers(
    worker_urls: List[str],
    timeout: int = 300,
    poll_interval: float = 5.0,
    request_timeout: float = 5.0,
) -> bool:
    """Block until all workers are healthy or timeout is reached.

    A worker counts as healthy once ``GET {url}/health`` returns status 200;
    healthy workers are not probed again.

    Args:
        worker_urls: Base URLs of the workers to probe.
        timeout: Overall time budget in seconds.
        poll_interval: Seconds to wait between polling rounds.
        request_timeout: Per-request timeout in seconds, so a worker that
            accepts the connection but never answers cannot stall the loop
            past the overall budget.

    Returns:
        True once every worker has reported healthy, False on timeout.
    """
    # Monotonic clock: immune to wall-clock adjustments during the wait.
    deadline = time.monotonic() + timeout
    # De-duplicate while preserving order; these are the workers still unhealthy.
    pending = list(dict.fromkeys(worker_urls))
    while time.monotonic() < deadline:
        print("checking healthiness...")
        still_pending = []
        for url in pending:
            try:
                response = requests.get(f"{url}/health", timeout=request_timeout)
            except requests.RequestException:
                still_pending.append(url)
                continue
            if response.status_code == 200:
                print(f"Worker at {url} is healthy")
            else:
                still_pending.append(url)
        pending = still_pending
        if not pending:
            print("All workers are healthy!")
            return True
        # Never sleep past the deadline.
        time.sleep(max(0.0, min(poll_interval, deadline - time.monotonic())))
    # If we get here, we've timed out
    print(f"Timeout waiting for workers: {pending}")
    return False
def main():
    """Main function to launch the router and workers.

    Parses the command line, launches the workers, waits for them to report
    healthy, then starts a Router over their URLs using the ApproxTree policy.
    Worker processes are killed and reaped on the way out.

    Raises:
        SystemExit: With exit code 1 when startup or the router fails, so
            callers and shells can detect the failure.
    """
    args = parse_args()
    processes = None
    exit_code = 0
    try:
        # Launch all workers concurrently
        processes, worker_urls = launch_workers(args)
        # Block until all workers are healthy
        if not wait_for_healthy_workers(worker_urls):
            raise RuntimeError("Failed to start all workers")
        # Initialize and start the router
        router = Router(
            worker_urls=worker_urls,
            policy=PolicyType.ApproxTree,
            tokenizer_path=args.local_tokenizer_path,
        )
        print("Starting router...")
        router.start()
        # Keep the main process running
        try:
            while True:
                time.sleep(1)
        except KeyboardInterrupt:
            print("\nShutting down...")
    except Exception as e:
        print(f"Error: {e}")
        # Remember the failure instead of silently exiting with status 0.
        exit_code = 1
    finally:
        # Cleanup: Kill all worker processes
        if processes:
            for process in processes:
                process.kill()
            # Reap the killed children so they do not linger as zombies.
            for process in processes:
                process.wait()
    if exit_code:
        sys.exit(exit_code)


if __name__ == "__main__":
    main()

View File

@@ -1,12 +0,0 @@
from sglang_router import PolicyType, Router
# Example: route requests across two local workers with the ApproxTree policy.
worker_urls = [
    "http://localhost:30000",
    "http://localhost:30001",
]
tokenizer_path = "/shared/public/elr-models/meta-llama/Meta-Llama-3.1-8B-Instruct/07eb05b21d191a58c577b4a45982fe0c049d0693/tokenizer.json"

router = Router(
    worker_urls=worker_urls,
    policy=PolicyType.ApproxTree,
    tokenizer_path=tokenizer_path,
)
router.start()

48
rust/py_src/router.py Normal file
View File

@@ -0,0 +1,48 @@
from typing import List, Optional
from sglang_router_rs import PolicyType
from sglang_router_rs import Router as _Router
class Router:
    """
    Thin Python facade over the ``sglang_router_rs`` router, which distributes
    requests across worker nodes.

    Every constructor argument is forwarded unchanged to the underlying
    router object.

    Args:
        worker_urls: URLs of the worker nodes that will handle requests.
        policy: Load balancing policy. One of:
            - PolicyType.Random: pick a worker at random
            - PolicyType.RoundRobin: cycle through the workers in order
            - PolicyType.ApproxTree: tree-based routing using tokenizer similarity
        host: Host address the router server binds to.
        port: Port number the router server binds to.
        tokenizer_path: Path to the tokenizer model file (required for the
            ApproxTree policy).
        cache_threshold: Caching threshold value between 0 and 1.
    """

    def __init__(
        self,
        worker_urls: List[str],
        policy: PolicyType = PolicyType.RoundRobin,
        host: str = "127.0.0.1",
        port: int = 3001,
        tokenizer_path: Optional[str] = None,
        cache_threshold: float = 0.50,
    ):
        # Gather the settings once, then hand them to the underlying router.
        options = {
            "worker_urls": worker_urls,
            "policy": policy,
            "host": host,
            "port": port,
            "tokenizer_path": tokenizer_path,
            "cache_threshold": cache_threshold,
        }
        self._router = _Router(**options)

    def start(self) -> None:
        """Run the router server.

        Blocks the calling thread until the server is shut down.
        """
        self._router.start()