[router] Replace print with logger (#2183)

This commit is contained in:
Byron Hsu
2024-11-25 13:36:02 -08:00
committed by GitHub
parent e1e595d702
commit 4d62bca542
10 changed files with 249 additions and 47 deletions

View File

@@ -1,5 +1,6 @@
import argparse
import dataclasses
import logging
import sys
from typing import List, Optional
@@ -7,6 +8,22 @@ from sglang_router import Router
from sglang_router_rs import PolicyType
def setup_logger():
    """Configure and return the logger named ``"router"``.

    The logger level is set to INFO. If the logger has no handlers yet, a
    ``logging.StreamHandler`` is attached whose records are formatted as
    ``[Router (Python)] <timestamp> - <level> - <message>`` with a
    ``YYYY-MM-DD HH:MM:SS`` timestamp.

    The function is safe to call more than once: a repeated call reuses the
    handler that is already attached instead of stacking another one.

    Returns:
        logging.Logger: The configured ``"router"`` logger.
    """
    logger = logging.getLogger("router")
    logger.setLevel(logging.INFO)

    # logging.getLogger returns the same object for a given name, so a handler
    # added by an earlier call is still attached. Adding a second one would
    # make every record appear twice.
    if not logger.handlers:
        formatter = logging.Formatter(
            "[Router (Python)] %(asctime)s - %(levelname)s - %(message)s",
            datefmt="%Y-%m-%d %H:%M:%S",
        )
        handler = logging.StreamHandler()
        handler.setFormatter(formatter)
        logger.addHandler(handler)

    return logger
@dataclasses.dataclass
class RouterArgs:
# Worker configuration
@@ -21,6 +38,7 @@ class RouterArgs:
balance_rel_threshold: float = 1.0001
eviction_interval: int = 60
max_tree_size: int = 2**24
verbose: bool = False
@staticmethod
def add_cli_args(
@@ -98,6 +116,11 @@ class RouterArgs:
default=RouterArgs.max_tree_size,
help="Maximum size of the approximation tree for cache-aware routing",
)
parser.add_argument(
f"--{prefix}verbose",
action="store_true",
help="Enable verbose logging",
)
@classmethod
def from_cli_args(
@@ -121,6 +144,7 @@ class RouterArgs:
balance_rel_threshold=getattr(args, f"{prefix}balance_rel_threshold"),
eviction_interval=getattr(args, f"{prefix}eviction_interval"),
max_tree_size=getattr(args, f"{prefix}max_tree_size"),
verbose=getattr(args, f"{prefix}verbose", False),
)
@@ -145,6 +169,7 @@ def launch_router(args: argparse.Namespace) -> Optional[Router]:
Returns:
Router instance if successful, None if failed
"""
logger = logging.getLogger("router")
try:
# Convert to RouterArgs if needed
if not isinstance(args, RouterArgs):
@@ -162,13 +187,14 @@ def launch_router(args: argparse.Namespace) -> Optional[Router]:
balance_rel_threshold=router_args.balance_rel_threshold,
eviction_interval_secs=router_args.eviction_interval,
max_tree_size=router_args.max_tree_size,
verbose=router_args.verbose,
)
router.start()
return router
except Exception as e:
print(f"Error starting router: {e}", file=sys.stderr)
# Logger.error() has no `file=` keyword (that belongs to print()); passing it raises TypeError inside this except handler. The logger's handler decides the output stream.
logger.error(f"Error starting router: {e}")
return None
@@ -202,6 +228,7 @@ Examples:
def main() -> None:
logger = setup_logger()
router_args = parse_router_args(sys.argv[1:])
router = launch_router(router_args)

View File

@@ -1,5 +1,6 @@
import argparse
import copy
import logging
import multiprocessing as mp
import os
import random
@@ -17,6 +18,22 @@ from sglang.srt.utils import is_port_available
from sglang.utils import get_exception_traceback
def setup_logger():
    """Configure and return the logger named ``"router"``.

    The logger level is set to INFO. If the logger has no handlers yet, a
    ``logging.StreamHandler`` is attached whose records are formatted as
    ``[Router (Python)] <timestamp> - <level> - <message>`` with a
    ``YYYY-MM-DD HH:MM:SS`` timestamp.

    The function is safe to call more than once: a repeated call reuses the
    handler that is already attached instead of stacking another one.

    Returns:
        logging.Logger: The configured ``"router"`` logger.
    """
    logger = logging.getLogger("router")
    logger.setLevel(logging.INFO)

    # logging.getLogger returns the same object for a given name, so a handler
    # added by an earlier call is still attached. Adding a second one would
    # make every record appear twice.
    if not logger.handlers:
        formatter = logging.Formatter(
            "[Router (Python)] %(asctime)s - %(levelname)s - %(message)s",
            datefmt="%Y-%m-%d %H:%M:%S",
        )
        handler = logging.StreamHandler()
        handler.setFormatter(formatter)
        logger.addHandler(handler)

    return logger
# Create new process group
def run_server(server_args, dp_rank):
os.setpgrp() # Create new process group
@@ -42,20 +59,20 @@ def launch_server_process(
def cleanup_processes(processes: List[mp.Process]):
"""Clean up all processes using process groups."""
# NOTE(review): this hunk is a rendered diff with the +/- markers lost, so old and new lines appear together.
# Going by the commit title ("Replace print with logger"), the print() below is presumably the removed line
# and the logger calls that follow are its replacement -- confirm against the actual file.
print("\nCleaning up processes...")
# Look up the logger named "router" and report through it.
logger = logging.getLogger("router")
logger.info("Cleaning up processes...")
# Only processes that still report as alive are signalled.
for proc in processes:
if proc.is_alive():
try:
# Kill the entire process group
os.killpg(os.getpgid(proc.pid), signal.SIGTERM)
# Give processes some time to terminate gracefully
proc.join(timeout=3)
# If process is still alive, force kill
if proc.is_alive():
logger.warning(
f"Process {proc.pid} did not terminate gracefully, force killing..."
)
# Escalate to SIGKILL for the whole process group.
os.killpg(os.getpgid(proc.pid), signal.SIGKILL)
# A ProcessLookupError raised by any call in the try body is deliberately ignored.
except ProcessLookupError:
# NOTE(review): the two pass lines below are presumably the before/after versions of one statement in the diff -- confirm only one exists in the file.
pass # Process already terminated
pass
def setup_signal_handlers(cleanup_func):
@@ -101,6 +118,8 @@ def find_available_ports(base_port: int, count: int) -> List[int]:
def main():
logger = setup_logger()
# CUDA runtime isn't fork-safe, which can lead to subtle bugs or crashes
mp.set_start_method("spawn")
@@ -130,8 +149,8 @@ def main():
server_processes = []
try:
# Launch server processes
for i, worker_port in enumerate(worker_ports):
logger.info(f"Launching DP server process {i} on port {worker_port}")
proc = launch_server_process(server_args, worker_port, i)
server_processes.append(proc)
@@ -140,18 +159,19 @@ def main():
# Wait for all servers to be healthy
all_healthy = True
for port in worker_ports:
if not wait_for_server_health(server_args.host, port):
print(f"Server on port {port} failed to become healthy")
logger.error(f"Server on port {port} failed to become healthy")
all_healthy = False
break
if not all_healthy:
print("Not all servers are healthy. Shutting down...")
logger.error("Not all servers are healthy. Shutting down...")
cleanup_processes(server_processes)
sys.exit(1)
print("All servers are healthy. Starting router...")
logger.info("All servers are healthy. Starting router...")
# Update router args with worker URLs
router_args.worker_urls = [
@@ -162,16 +182,17 @@ def main():
router = launch_router(router_args)
if router is None:
print("Failed to start router. Shutting down...")
logger.error("Failed to start router. Shutting down...")
cleanup_processes(server_processes)
sys.exit(1)
except KeyboardInterrupt:
print("\nReceived shutdown signal...")
logger.info("Received shutdown signal...")
except Exception as e:
print(f"Error occurred: {e}")
print(get_exception_traceback())
logger.error(f"Error occurred: {e}")
logger.error(get_exception_traceback())
finally:
logger.info("Cleaning up processes...")
cleanup_processes(server_processes)

View File

@@ -27,6 +27,7 @@ class Router:
eviction_interval_secs: Interval in seconds between cache eviction operations in cache-aware
routing. Default: 60
max_tree_size: Maximum size of the approximation tree for cache-aware routing. Default: 2^24
verbose: Enable verbose logging. Default: False
"""
def __init__(
@@ -40,6 +41,7 @@ class Router:
balance_rel_threshold: float = 1.0001,
eviction_interval_secs: int = 60,
max_tree_size: int = 2**24,
verbose: bool = False,
):
self._router = _Router(
worker_urls=worker_urls,
@@ -51,6 +53,7 @@ class Router:
balance_rel_threshold=balance_rel_threshold,
eviction_interval_secs=eviction_interval_secs,
max_tree_size=max_tree_size,
verbose=verbose,
)
def start(self) -> None: