[router] Improve cleanup logic (#2411)
This commit is contained in:
@@ -10,12 +10,12 @@ import time
|
|||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
|
from setproctitle import setproctitle
|
||||||
from sglang_router.launch_router import RouterArgs, launch_router
|
from sglang_router.launch_router import RouterArgs, launch_router
|
||||||
|
|
||||||
from sglang.srt.server import launch_server
|
from sglang.srt.server import launch_server
|
||||||
from sglang.srt.server_args import ServerArgs
|
from sglang.srt.server_args import ServerArgs
|
||||||
from sglang.srt.utils import is_port_available
|
from sglang.srt.utils import is_port_available
|
||||||
from sglang.utils import get_exception_traceback
|
|
||||||
|
|
||||||
|
|
||||||
def setup_logger():
|
def setup_logger():
|
||||||
@@ -34,10 +34,12 @@ def setup_logger():
|
|||||||
return logger
|
return logger
|
||||||
|
|
||||||
|
|
||||||
|
logger = setup_logger()
|
||||||
|
|
||||||
|
|
||||||
# Create new process group
|
# Create new process group
|
||||||
def run_server(server_args, dp_rank):
|
def run_server(server_args, dp_rank):
|
||||||
os.setpgrp() # Create new process group
|
setproctitle(f"sglang::server")
|
||||||
|
|
||||||
# Set SGLANG_DP_RANK environment variable
|
# Set SGLANG_DP_RANK environment variable
|
||||||
os.environ["SGLANG_DP_RANK"] = str(dp_rank)
|
os.environ["SGLANG_DP_RANK"] = str(dp_rank)
|
||||||
|
|
||||||
@@ -58,36 +60,6 @@ def launch_server_process(
|
|||||||
return proc
|
return proc
|
||||||
|
|
||||||
|
|
||||||
def cleanup_processes(processes: List[mp.Process]):
|
|
||||||
logger = logging.getLogger("router")
|
|
||||||
logger.info("Cleaning up processes...")
|
|
||||||
for proc in processes:
|
|
||||||
if proc.is_alive():
|
|
||||||
try:
|
|
||||||
os.killpg(os.getpgid(proc.pid), signal.SIGTERM)
|
|
||||||
proc.join(timeout=3)
|
|
||||||
if proc.is_alive():
|
|
||||||
logger.warning(
|
|
||||||
f"Process {proc.pid} did not terminate gracefully, force killing..."
|
|
||||||
)
|
|
||||||
os.killpg(os.getpgid(proc.pid), signal.SIGKILL)
|
|
||||||
except ProcessLookupError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
def setup_signal_handlers(cleanup_func):
|
|
||||||
"""Setup handlers for various termination signals."""
|
|
||||||
|
|
||||||
def signal_handler(signum, frame):
|
|
||||||
cleanup_func()
|
|
||||||
sys.exit(1)
|
|
||||||
|
|
||||||
signal.signal(signal.SIGTERM, signal_handler)
|
|
||||||
signal.signal(signal.SIGINT, signal_handler)
|
|
||||||
if hasattr(signal, "SIGQUIT"):
|
|
||||||
signal.signal(signal.SIGQUIT, signal_handler)
|
|
||||||
|
|
||||||
|
|
||||||
def wait_for_server_health(host: str, port: int, timeout: int = 300) -> bool:
|
def wait_for_server_health(host: str, port: int, timeout: int = 300) -> bool:
|
||||||
"""Wait for server to be healthy by checking /health endpoint."""
|
"""Wait for server to be healthy by checking /health endpoint."""
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
@@ -117,8 +89,12 @@ def find_available_ports(base_port: int, count: int) -> List[int]:
|
|||||||
return available_ports
|
return available_ports
|
||||||
|
|
||||||
|
|
||||||
|
def cleanup_processes(processes: List[mp.Process]):
|
||||||
|
for process in processes:
|
||||||
|
process.terminate()
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
logger = setup_logger()
|
|
||||||
|
|
||||||
# CUDA runtime isn't fork-safe, which can lead to subtle bugs or crashes
|
# CUDA runtime isn't fork-safe, which can lead to subtle bugs or crashes
|
||||||
mp.set_start_method("spawn")
|
mp.set_start_method("spawn")
|
||||||
@@ -148,52 +124,33 @@ def main():
|
|||||||
# Start server processes
|
# Start server processes
|
||||||
server_processes = []
|
server_processes = []
|
||||||
|
|
||||||
try:
|
for i, worker_port in enumerate(worker_ports):
|
||||||
for i, worker_port in enumerate(worker_ports):
|
logger.info(f"Launching DP server process {i} on port {worker_port}")
|
||||||
logger.info(f"Launching DP server process {i} on port {worker_port}")
|
proc = launch_server_process(server_args, worker_port, i)
|
||||||
proc = launch_server_process(server_args, worker_port, i)
|
server_processes.append(proc)
|
||||||
server_processes.append(proc)
|
|
||||||
|
|
||||||
# Setup cleanup handler
|
signal.signal(signal.SIGINT, lambda sig, frame: cleanup_processes(server_processes))
|
||||||
setup_signal_handlers(lambda: cleanup_processes(server_processes))
|
signal.signal(
|
||||||
|
signal.SIGTERM, lambda sig, frame: cleanup_processes(server_processes)
|
||||||
|
)
|
||||||
|
signal.signal(
|
||||||
|
signal.SIGQUIT, lambda sig, frame: cleanup_processes(server_processes)
|
||||||
|
)
|
||||||
|
|
||||||
# Wait for all servers to be healthy
|
for port in worker_ports:
|
||||||
all_healthy = True
|
if not wait_for_server_health(server_args.host, port):
|
||||||
|
logger.error(f"Server on port {port} failed to become healthy")
|
||||||
|
break
|
||||||
|
|
||||||
for port in worker_ports:
|
logger.info("All servers are healthy. Starting router...")
|
||||||
if not wait_for_server_health(server_args.host, port):
|
|
||||||
logger.error(f"Server on port {port} failed to become healthy")
|
|
||||||
all_healthy = False
|
|
||||||
break
|
|
||||||
|
|
||||||
if not all_healthy:
|
# Update router args with worker URLs
|
||||||
logger.error("Not all servers are healthy. Shutting down...")
|
router_args.worker_urls = [
|
||||||
cleanup_processes(server_processes)
|
f"http://{server_args.host}:{port}" for port in worker_ports
|
||||||
sys.exit(1)
|
]
|
||||||
|
|
||||||
logger.info("All servers are healthy. Starting router...")
|
# Start the router
|
||||||
|
router = launch_router(router_args)
|
||||||
# Update router args with worker URLs
|
|
||||||
router_args.worker_urls = [
|
|
||||||
f"http://{server_args.host}:{port}" for port in worker_ports
|
|
||||||
]
|
|
||||||
|
|
||||||
# Start the router
|
|
||||||
router = launch_router(router_args)
|
|
||||||
|
|
||||||
if router is None:
|
|
||||||
logger.error("Failed to start router. Shutting down...")
|
|
||||||
cleanup_processes(server_processes)
|
|
||||||
sys.exit(1)
|
|
||||||
|
|
||||||
except KeyboardInterrupt:
|
|
||||||
logger.info("Received shutdown signal...")
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error occurred: {e}")
|
|
||||||
logger.error(get_exception_traceback())
|
|
||||||
finally:
|
|
||||||
logger.info("Cleaning up processes...")
|
|
||||||
cleanup_processes(server_processes)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
@@ -6,7 +6,6 @@ from types import SimpleNamespace
|
|||||||
|
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
from sglang.srt.utils import kill_process_tree
|
|
||||||
from sglang.test.run_eval import run_eval
|
from sglang.test.run_eval import run_eval
|
||||||
from sglang.test.test_utils import (
|
from sglang.test.test_utils import (
|
||||||
DEFAULT_MODEL_NAME_FOR_TEST,
|
DEFAULT_MODEL_NAME_FOR_TEST,
|
||||||
@@ -104,23 +103,52 @@ def popen_launch_server(
|
|||||||
return process
|
return process
|
||||||
|
|
||||||
|
|
||||||
|
def terminate_and_wait(process, timeout=300):
|
||||||
|
"""Terminate a process and wait until it is terminated.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
process: subprocess.Popen object
|
||||||
|
timeout: maximum time to wait in seconds
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
TimeoutError: if process does not terminate within timeout
|
||||||
|
"""
|
||||||
|
if process is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
process.terminate()
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
while process.poll() is None:
|
||||||
|
print(f"Terminating process {process.pid}")
|
||||||
|
if time.time() - start_time > timeout:
|
||||||
|
raise TimeoutError(
|
||||||
|
f"Process {process.pid} failed to terminate within {timeout}s"
|
||||||
|
)
|
||||||
|
time.sleep(1)
|
||||||
|
|
||||||
|
print(f"Process {process.pid} is successfully terminated")
|
||||||
|
|
||||||
|
|
||||||
class TestLaunchServer(unittest.TestCase):
|
class TestLaunchServer(unittest.TestCase):
|
||||||
@classmethod
|
def setUp(self):
|
||||||
def setUpClass(cls):
|
self.model = DEFAULT_MODEL_NAME_FOR_TEST
|
||||||
cls.model = DEFAULT_MODEL_NAME_FOR_TEST
|
self.base_url = DEFAULT_URL_FOR_TEST
|
||||||
cls.base_url = DEFAULT_URL_FOR_TEST
|
self.process = None
|
||||||
cls.process = None
|
self.other_process = []
|
||||||
cls.other_process = []
|
|
||||||
|
|
||||||
@classmethod
|
def tearDown(self):
|
||||||
def tearDownClass(cls):
|
print("Running tearDown...")
|
||||||
kill_process_tree(cls.process.pid)
|
if self.process:
|
||||||
for process in cls.other_process:
|
terminate_and_wait(self.process)
|
||||||
kill_process_tree(process.pid)
|
for process in self.other_process:
|
||||||
|
terminate_and_wait(process)
|
||||||
|
print("tearDown done")
|
||||||
|
|
||||||
def test_mmlu(self):
|
def test_1_mmlu(self):
|
||||||
|
print("Running test_1_mmlu...")
|
||||||
# DP size = 2
|
# DP size = 2
|
||||||
TestLaunchServer.process = popen_launch_router(
|
self.process = popen_launch_router(
|
||||||
self.model,
|
self.model,
|
||||||
self.base_url,
|
self.base_url,
|
||||||
dp_size=2,
|
dp_size=2,
|
||||||
@@ -144,9 +172,10 @@ class TestLaunchServer(unittest.TestCase):
|
|||||||
msg = f"MMLU test {'passed' if passed else 'failed'} with score {score:.3f} (threshold: {THRESHOLD})"
|
msg = f"MMLU test {'passed' if passed else 'failed'} with score {score:.3f} (threshold: {THRESHOLD})"
|
||||||
self.assertGreaterEqual(score, THRESHOLD, msg)
|
self.assertGreaterEqual(score, THRESHOLD, msg)
|
||||||
|
|
||||||
def test_add_and_remove_worker(self):
|
def test_2_add_and_remove_worker(self):
|
||||||
|
print("Running test_2_add_and_remove_worker...")
|
||||||
# DP size = 1
|
# DP size = 1
|
||||||
TestLaunchServer.process = popen_launch_router(
|
self.process = popen_launch_router(
|
||||||
self.model,
|
self.model,
|
||||||
self.base_url,
|
self.base_url,
|
||||||
dp_size=1,
|
dp_size=1,
|
||||||
@@ -159,7 +188,7 @@ class TestLaunchServer(unittest.TestCase):
|
|||||||
worker_process = popen_launch_server(
|
worker_process = popen_launch_server(
|
||||||
self.model, worker_url, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
|
self.model, worker_url, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
|
||||||
)
|
)
|
||||||
TestLaunchServer.other_process.append(worker_process)
|
self.other_process.append(worker_process)
|
||||||
|
|
||||||
# 2. use /add_worker api to add it the the router. It will be used by router after it is healthy
|
# 2. use /add_worker api to add it the the router. It will be used by router after it is healthy
|
||||||
with requests.Session() as session:
|
with requests.Session() as session:
|
||||||
|
|||||||
Reference in New Issue
Block a user