[router] Improve cleanup logic (#2411)

This commit is contained in:
Byron Hsu
2024-12-08 15:24:02 -08:00
committed by GitHub
parent a6ca736c8e
commit a1e697b25b
2 changed files with 78 additions and 92 deletions

View File

@@ -10,12 +10,12 @@ import time
from typing import List
import requests
from setproctitle import setproctitle
from sglang_router.launch_router import RouterArgs, launch_router
from sglang.srt.server import launch_server
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import is_port_available
from sglang.utils import get_exception_traceback
def setup_logger():
@@ -34,10 +34,12 @@ def setup_logger():
return logger
logger = setup_logger()
# Create new process group
def run_server(server_args, dp_rank):
os.setpgrp() # Create new process group
setproctitle(f"sglang::server")
# Set SGLANG_DP_RANK environment variable
os.environ["SGLANG_DP_RANK"] = str(dp_rank)
@@ -58,36 +60,6 @@ def launch_server_process(
return proc
def cleanup_processes(processes: List[mp.Process]):
    """Stop every live process in ``processes`` along with its process group.

    For each process that is still alive, SIGTERM is sent to the process
    group it belongs to and the process is given up to 3 seconds to exit.
    If it is still alive afterwards, the group is sent SIGKILL and the
    process is joined again so that the killed child is reaped instead of
    being left behind as a zombie.

    Args:
        processes: Process handles to shut down. Handles that are no longer
            alive are skipped.
    """
    logger = logging.getLogger("router")
    logger.info("Cleaning up processes...")
    for proc in processes:
        if not proc.is_alive():
            continue
        try:
            # Signal the whole group so descendants of the process get it too.
            os.killpg(os.getpgid(proc.pid), signal.SIGTERM)
            proc.join(timeout=3)
            if proc.is_alive():
                logger.warning(
                    f"Process {proc.pid} did not terminate gracefully, force killing..."
                )
                os.killpg(os.getpgid(proc.pid), signal.SIGKILL)
                # Reap the killed child; bounded so cleanup can never hang here.
                proc.join(timeout=3)
        except ProcessLookupError:
            # The process (or its group) vanished between the liveness check
            # and the signal; nothing is left to clean up for this entry.
            continue
def setup_signal_handlers(cleanup_func):
    """Install a shared handler for SIGTERM, SIGINT and (if present) SIGQUIT.

    When one of these signals arrives, ``cleanup_func`` is called with no
    arguments and the interpreter then exits with status 1.

    Args:
        cleanup_func: Zero-argument callable run before exiting.
    """

    def _on_signal(signum, frame):
        cleanup_func()
        sys.exit(1)

    signums = [signal.SIGTERM, signal.SIGINT]
    # SIGQUIT is not defined on every platform, so register it only if available.
    sigquit = getattr(signal, "SIGQUIT", None)
    if sigquit is not None:
        signums.append(sigquit)

    for signum in signums:
        signal.signal(signum, _on_signal)
def wait_for_server_health(host: str, port: int, timeout: int = 300) -> bool:
"""Wait for server to be healthy by checking /health endpoint."""
start_time = time.time()
@@ -117,8 +89,12 @@ def find_available_ports(base_port: int, count: int) -> List[int]:
return available_ports
def cleanup_processes(processes: List[mp.Process]):
    """Call ``terminate()`` on each process in ``processes``, in order.

    The function only requests termination; it does not wait for the
    processes to exit.

    Args:
        processes: Process handles to terminate.
    """
    for proc in processes:
        proc.terminate()
def main():
logger = setup_logger()
# CUDA runtime isn't fork-safe, which can lead to subtle bugs or crashes
mp.set_start_method("spawn")
@@ -148,52 +124,33 @@ def main():
# Start server processes
server_processes = []
try:
for i, worker_port in enumerate(worker_ports):
logger.info(f"Launching DP server process {i} on port {worker_port}")
proc = launch_server_process(server_args, worker_port, i)
server_processes.append(proc)
for i, worker_port in enumerate(worker_ports):
logger.info(f"Launching DP server process {i} on port {worker_port}")
proc = launch_server_process(server_args, worker_port, i)
server_processes.append(proc)
# Setup cleanup handler
setup_signal_handlers(lambda: cleanup_processes(server_processes))
signal.signal(signal.SIGINT, lambda sig, frame: cleanup_processes(server_processes))
signal.signal(
signal.SIGTERM, lambda sig, frame: cleanup_processes(server_processes)
)
signal.signal(
signal.SIGQUIT, lambda sig, frame: cleanup_processes(server_processes)
)
# Wait for all servers to be healthy
all_healthy = True
for port in worker_ports:
if not wait_for_server_health(server_args.host, port):
logger.error(f"Server on port {port} failed to become healthy")
break
for port in worker_ports:
if not wait_for_server_health(server_args.host, port):
logger.error(f"Server on port {port} failed to become healthy")
all_healthy = False
break
logger.info("All servers are healthy. Starting router...")
if not all_healthy:
logger.error("Not all servers are healthy. Shutting down...")
cleanup_processes(server_processes)
sys.exit(1)
# Update router args with worker URLs
router_args.worker_urls = [
f"http://{server_args.host}:{port}" for port in worker_ports
]
logger.info("All servers are healthy. Starting router...")
# Update router args with worker URLs
router_args.worker_urls = [
f"http://{server_args.host}:{port}" for port in worker_ports
]
# Start the router
router = launch_router(router_args)
if router is None:
logger.error("Failed to start router. Shutting down...")
cleanup_processes(server_processes)
sys.exit(1)
except KeyboardInterrupt:
logger.info("Received shutdown signal...")
except Exception as e:
logger.error(f"Error occurred: {e}")
logger.error(get_exception_traceback())
finally:
logger.info("Cleaning up processes...")
cleanup_processes(server_processes)
# Start the router
router = launch_router(router_args)
if __name__ == "__main__":