diff --git a/rust/py_src/sglang_router/launch_server.py b/rust/py_src/sglang_router/launch_server.py index ec86e8b2a..9c482e489 100644 --- a/rust/py_src/sglang_router/launch_server.py +++ b/rust/py_src/sglang_router/launch_server.py @@ -10,12 +10,12 @@ import time from typing import List import requests +from setproctitle import setproctitle from sglang_router.launch_router import RouterArgs, launch_router from sglang.srt.server import launch_server from sglang.srt.server_args import ServerArgs from sglang.srt.utils import is_port_available -from sglang.utils import get_exception_traceback def setup_logger(): @@ -34,10 +34,12 @@ def setup_logger(): return logger +logger = setup_logger() + + # Create new process group def run_server(server_args, dp_rank): - os.setpgrp() # Create new process group - + setproctitle(f"sglang::server") # Set SGLANG_DP_RANK environment variable os.environ["SGLANG_DP_RANK"] = str(dp_rank) @@ -58,36 +60,6 @@ def launch_server_process( return proc -def cleanup_processes(processes: List[mp.Process]): - logger = logging.getLogger("router") - logger.info("Cleaning up processes...") - for proc in processes: - if proc.is_alive(): - try: - os.killpg(os.getpgid(proc.pid), signal.SIGTERM) - proc.join(timeout=3) - if proc.is_alive(): - logger.warning( - f"Process {proc.pid} did not terminate gracefully, force killing..." - ) - os.killpg(os.getpgid(proc.pid), signal.SIGKILL) - except ProcessLookupError: - pass - - -def setup_signal_handlers(cleanup_func): - """Setup handlers for various termination signals.""" - - def signal_handler(signum, frame): - cleanup_func() - sys.exit(1) - - signal.signal(signal.SIGTERM, signal_handler) - signal.signal(signal.SIGINT, signal_handler) - if hasattr(signal, "SIGQUIT"): - signal.signal(signal.SIGQUIT, signal_handler) - - def wait_for_server_health(host: str, port: int, timeout: int = 300) -> bool: """Wait for server to be healthy by checking /health endpoint.""" start_time = time.time() @@ -117,8 +89,12 @@ def find_available_ports(base_port: int, count: int) -> List[int]: return available_ports +def cleanup_processes(processes: List[mp.Process]): + for process in processes: + process.terminate() + + def main(): - logger = setup_logger() # CUDA runtime isn't fork-safe, which can lead to subtle bugs or crashes mp.set_start_method("spawn") @@ -148,52 +124,33 @@ def main(): # Start server processes server_processes = [] - try: - for i, worker_port in enumerate(worker_ports): - logger.info(f"Launching DP server process {i} on port {worker_port}") - proc = launch_server_process(server_args, worker_port, i) - server_processes.append(proc) + for i, worker_port in enumerate(worker_ports): + logger.info(f"Launching DP server process {i} on port {worker_port}") + proc = launch_server_process(server_args, worker_port, i) + server_processes.append(proc) - # Setup cleanup handler - setup_signal_handlers(lambda: cleanup_processes(server_processes)) + signal.signal(signal.SIGINT, lambda sig, frame: cleanup_processes(server_processes)) + signal.signal( + signal.SIGTERM, lambda sig, frame: cleanup_processes(server_processes) + ) + signal.signal( + signal.SIGQUIT, lambda sig, frame: cleanup_processes(server_processes) + ) - # Wait for all servers to be healthy - all_healthy = True + for port in worker_ports: + if not wait_for_server_health(server_args.host, port): + logger.error(f"Server on port {port} failed to become healthy") + break - for port in worker_ports: - if not wait_for_server_health(server_args.host, port): - logger.error(f"Server on port {port} failed to become healthy") - all_healthy = False - break + logger.info("All servers are healthy. Starting router...") - if not all_healthy: - logger.error("Not all servers are healthy. Shutting down...") - cleanup_processes(server_processes) - sys.exit(1) + # Update router args with worker URLs + router_args.worker_urls = [ + f"http://{server_args.host}:{port}" for port in worker_ports + ] - logger.info("All servers are healthy. Starting router...") - - # Update router args with worker URLs - router_args.worker_urls = [ - f"http://{server_args.host}:{port}" for port in worker_ports - ] - - # Start the router - router = launch_router(router_args) - - if router is None: - logger.error("Failed to start router. Shutting down...") - cleanup_processes(server_processes) - sys.exit(1) - - except KeyboardInterrupt: - logger.info("Received shutdown signal...") - except Exception as e: - logger.error(f"Error occurred: {e}") - logger.error(get_exception_traceback()) - finally: - logger.info("Cleaning up processes...") - cleanup_processes(server_processes) + # Start the router + router = launch_router(router_args) if __name__ == "__main__": diff --git a/rust/py_test/test_launch_server.py b/rust/py_test/test_launch_server.py index 68945d8fb..2591abb5c 100644 --- a/rust/py_test/test_launch_server.py +++ b/rust/py_test/test_launch_server.py @@ -6,7 +6,6 @@ from types import SimpleNamespace import requests -from sglang.srt.utils import kill_process_tree from sglang.test.run_eval import run_eval from sglang.test.test_utils import ( DEFAULT_MODEL_NAME_FOR_TEST, @@ -104,23 +103,52 @@ def popen_launch_server( return process +def terminate_and_wait(process, timeout=300): + """Terminate a process and wait until it is terminated. + + Args: + process: subprocess.Popen object + timeout: maximum time to wait in seconds + + Raises: + TimeoutError: if process does not terminate within timeout + """ + if process is None: + return + + process.terminate() + start_time = time.time() + + while process.poll() is None: + print(f"Terminating process {process.pid}") + if time.time() - start_time > timeout: + raise TimeoutError( + f"Process {process.pid} failed to terminate within {timeout}s" + ) + time.sleep(1) + + print(f"Process {process.pid} is successfully terminated") + + class TestLaunchServer(unittest.TestCase): - @classmethod - def setUpClass(cls): - cls.model = DEFAULT_MODEL_NAME_FOR_TEST - cls.base_url = DEFAULT_URL_FOR_TEST - cls.process = None - cls.other_process = [] + def setUp(self): + self.model = DEFAULT_MODEL_NAME_FOR_TEST + self.base_url = DEFAULT_URL_FOR_TEST + self.process = None + self.other_process = [] - @classmethod - def tearDownClass(cls): - kill_process_tree(cls.process.pid) - for process in cls.other_process: - kill_process_tree(process.pid) + def tearDown(self): + print("Running tearDown...") + if self.process: + terminate_and_wait(self.process) + for process in self.other_process: + terminate_and_wait(process) + print("tearDown done") - def test_mmlu(self): + def test_1_mmlu(self): + print("Running test_1_mmlu...") # DP size = 2 - TestLaunchServer.process = popen_launch_router( + self.process = popen_launch_router( self.model, self.base_url, dp_size=2, @@ -144,9 +172,10 @@ class TestLaunchServer(unittest.TestCase): msg = f"MMLU test {'passed' if passed else 'failed'} with score {score:.3f} (threshold: {THRESHOLD})" self.assertGreaterEqual(score, THRESHOLD, msg) - def test_add_and_remove_worker(self): + def test_2_add_and_remove_worker(self): + print("Running test_2_add_and_remove_worker...") # DP size = 1 - TestLaunchServer.process = popen_launch_router( + self.process = popen_launch_router( self.model, self.base_url, dp_size=1, @@ -159,7 +188,7 @@ class TestLaunchServer(unittest.TestCase): worker_process = popen_launch_server( self.model, worker_url, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH ) - TestLaunchServer.other_process.append(worker_process) + self.other_process.append(worker_process) # 2. use /add_worker api to add it the the router. It will be used by router after it is healthy with requests.Session() as session: