[router] Add PD router mmlu test (#10256)
This commit is contained in:
232
sgl-router/py_test/e2e/test_pd_router.py
Normal file
232
sgl-router/py_test/e2e/test_pd_router.py
Normal file
@@ -0,0 +1,232 @@
|
||||
import socket
|
||||
import subprocess
|
||||
import time
|
||||
from types import SimpleNamespace
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
from sglang.test.run_eval import run_eval
|
||||
|
||||
|
||||
def _find_available_port() -> int:
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||
s.bind(("127.0.0.1", 0))
|
||||
return s.getsockname()[1]
|
||||
|
||||
|
||||
def _wait_health(url: str, timeout: float = 180.0) -> None:
|
||||
start = time.perf_counter()
|
||||
with requests.Session() as session:
|
||||
while time.perf_counter() - start < timeout:
|
||||
try:
|
||||
r = session.get(f"{url}/health", timeout=5)
|
||||
if r.status_code == 200:
|
||||
return
|
||||
except requests.RequestException:
|
||||
pass
|
||||
time.sleep(1)
|
||||
raise TimeoutError(f"Service at {url} failed to become healthy in time")
|
||||
|
||||
|
||||
def _detect_ib_device() -> Optional[str]:
|
||||
"""Return first active IB device name (e.g., mlx5_0) or None if unavailable."""
|
||||
# Fast check that ibv_devinfo exists
|
||||
try:
|
||||
subprocess.run(
|
||||
["ibv_devinfo", "-l"],
|
||||
stdout=subprocess.DEVNULL,
|
||||
stderr=subprocess.DEVNULL,
|
||||
timeout=1,
|
||||
)
|
||||
except (FileNotFoundError, subprocess.TimeoutExpired):
|
||||
return None
|
||||
|
||||
for i in range(12):
|
||||
dev = f"mlx5_{i}"
|
||||
try:
|
||||
res = subprocess.run(
|
||||
["ibv_devinfo", dev],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=2,
|
||||
)
|
||||
if res.returncode == 0 and ("state:" in res.stdout):
|
||||
for line in res.stdout.splitlines():
|
||||
if "state:" in line and "PORT_ACTIVE" in line:
|
||||
return dev
|
||||
except Exception:
|
||||
pass
|
||||
return None
|
||||
|
||||
|
||||
def _popen_launch_prefill_worker(
|
||||
model: str,
|
||||
bootstrap_port: int,
|
||||
ib_device: Optional[str] = None,
|
||||
base_gpu_id: int = 0,
|
||||
) -> SimpleNamespace:
|
||||
port = _find_available_port()
|
||||
url = f"http://127.0.0.1:{port}"
|
||||
cmd = [
|
||||
"python3",
|
||||
"-m",
|
||||
"sglang.launch_server",
|
||||
"--model-path",
|
||||
model,
|
||||
"--disaggregation-mode",
|
||||
"prefill",
|
||||
"--host",
|
||||
"127.0.0.1",
|
||||
"--port",
|
||||
str(port),
|
||||
"--disaggregation-bootstrap-port",
|
||||
str(bootstrap_port),
|
||||
"--base-gpu-id",
|
||||
str(base_gpu_id),
|
||||
]
|
||||
if ib_device:
|
||||
cmd += ["--disaggregation-ib-device", ib_device]
|
||||
proc = subprocess.Popen(cmd)
|
||||
_wait_health(url, timeout=300.0)
|
||||
return SimpleNamespace(proc=proc, url=url, bootstrap_port=bootstrap_port)
|
||||
|
||||
|
||||
def _popen_launch_decode_worker(
|
||||
model: str, ib_device: Optional[str] = None, base_gpu_id: int = 0
|
||||
) -> SimpleNamespace:
|
||||
port = _find_available_port()
|
||||
url = f"http://127.0.0.1:{port}"
|
||||
cmd = [
|
||||
"python3",
|
||||
"-m",
|
||||
"sglang.launch_server",
|
||||
"--model-path",
|
||||
model,
|
||||
"--disaggregation-mode",
|
||||
"decode",
|
||||
"--host",
|
||||
"127.0.0.1",
|
||||
"--port",
|
||||
str(port),
|
||||
"--base-gpu-id",
|
||||
str(base_gpu_id),
|
||||
]
|
||||
if ib_device:
|
||||
cmd += ["--disaggregation-ib-device", ib_device]
|
||||
proc = subprocess.Popen(cmd)
|
||||
_wait_health(url, timeout=300.0)
|
||||
return SimpleNamespace(proc=proc, url=url)
|
||||
|
||||
|
||||
def _terminate(proc: subprocess.Popen, timeout: float = 120) -> None:
|
||||
if proc is None:
|
||||
return
|
||||
proc.terminate()
|
||||
start = time.perf_counter()
|
||||
while proc.poll() is None:
|
||||
if time.perf_counter() - start > timeout:
|
||||
proc.kill()
|
||||
break
|
||||
time.sleep(1)
|
||||
|
||||
|
||||
@pytest.mark.e2e
|
||||
def test_pd_mmlu(e2e_model: str):
|
||||
"""
|
||||
Launch 4 workers, start a PD router (2 prefill + 2 decode), then run MMLU.
|
||||
"""
|
||||
# Environment capability checks: require sgl_kernel and GPU backend
|
||||
try:
|
||||
import sgl_kernel # noqa: F401
|
||||
except Exception as e: # pragma: no cover - environment dependent
|
||||
pytest.fail(f"PD e2e requires sgl_kernel but it is not available: {e}")
|
||||
|
||||
try:
|
||||
import torch # noqa: F401
|
||||
except Exception as e: # pragma: no cover - environment dependent
|
||||
pytest.fail(
|
||||
f"PD e2e requires torch but it is not available or misconfigured: {e}"
|
||||
)
|
||||
|
||||
if not torch.cuda.is_available(): # pragma: no cover - environment dependent
|
||||
pytest.fail("PD e2e requires CUDA backend, but CUDA is not available")
|
||||
|
||||
# Start two prefill workers (with bootstrap ports) and two decode workers
|
||||
workers: list[SimpleNamespace] = []
|
||||
try:
|
||||
ib_device = _detect_ib_device()
|
||||
|
||||
# Launch 4 workers across 4 GPUs: prefill on 0,1 and decode on 2,3
|
||||
pf1 = _popen_launch_prefill_worker(
|
||||
e2e_model,
|
||||
bootstrap_port=_find_available_port(),
|
||||
ib_device=ib_device,
|
||||
base_gpu_id=0,
|
||||
)
|
||||
pf2 = _popen_launch_prefill_worker(
|
||||
e2e_model,
|
||||
bootstrap_port=_find_available_port(),
|
||||
ib_device=ib_device,
|
||||
base_gpu_id=1,
|
||||
)
|
||||
dc1 = _popen_launch_decode_worker(e2e_model, ib_device=ib_device, base_gpu_id=2)
|
||||
dc2 = _popen_launch_decode_worker(e2e_model, ib_device=ib_device, base_gpu_id=3)
|
||||
prefills = [pf1, pf2]
|
||||
decodes = [dc1, dc2]
|
||||
workers.extend(prefills + decodes)
|
||||
|
||||
# PD router with two prefill and two decode endpoints
|
||||
rport = _find_available_port()
|
||||
router_url = f"http://127.0.0.1:{rport}"
|
||||
pport = _find_available_port()
|
||||
|
||||
prefill = [(pf.url, pf.bootstrap_port) for pf in prefills]
|
||||
decode = [dc.url for dc in decodes]
|
||||
|
||||
cmd = [
|
||||
"python3",
|
||||
"-m",
|
||||
"sglang_router.launch_router",
|
||||
"--host",
|
||||
"127.0.0.1",
|
||||
"--port",
|
||||
str(rport),
|
||||
"--policy",
|
||||
"round_robin",
|
||||
"--pd-disaggregation",
|
||||
# prefill URLs (explicitly pass 'none' for bootstrap port)
|
||||
]
|
||||
for url, bport in prefill:
|
||||
cmd += ["--prefill", url, str(bport)]
|
||||
for url in decode:
|
||||
cmd += ["--decode", url]
|
||||
cmd += [
|
||||
# prometheus (avoid collisions across tests)
|
||||
"--prometheus-port",
|
||||
str(pport),
|
||||
"--prometheus-host",
|
||||
"127.0.0.1",
|
||||
]
|
||||
|
||||
router_proc = subprocess.Popen(cmd)
|
||||
try:
|
||||
_wait_health(router_url, timeout=180.0)
|
||||
|
||||
# Run a modest MMLU eval through the PD router
|
||||
args = SimpleNamespace(
|
||||
base_url=router_url,
|
||||
model=e2e_model,
|
||||
eval_name="mmlu",
|
||||
num_examples=64,
|
||||
num_threads=32,
|
||||
temperature=0.1,
|
||||
)
|
||||
metrics = run_eval(args)
|
||||
assert metrics["score"] >= 0.65
|
||||
finally:
|
||||
_terminate(router_proc)
|
||||
finally:
|
||||
for w in workers:
|
||||
_terminate(w.proc)
|
||||
Reference in New Issue
Block a user