[router] Add PD router mmlu test (#10256)
This commit is contained in:
232
sgl-router/py_test/e2e/test_pd_router.py
Normal file
232
sgl-router/py_test/e2e/test_pd_router.py
Normal file
@@ -0,0 +1,232 @@
|
|||||||
|
import socket
|
||||||
|
import subprocess
|
||||||
|
import time
|
||||||
|
from types import SimpleNamespace
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import requests
|
||||||
|
|
||||||
|
from sglang.test.run_eval import run_eval
|
||||||
|
|
||||||
|
|
||||||
|
def _find_available_port() -> int:
|
||||||
|
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||||
|
s.bind(("127.0.0.1", 0))
|
||||||
|
return s.getsockname()[1]
|
||||||
|
|
||||||
|
|
||||||
|
def _wait_health(url: str, timeout: float = 180.0) -> None:
|
||||||
|
start = time.perf_counter()
|
||||||
|
with requests.Session() as session:
|
||||||
|
while time.perf_counter() - start < timeout:
|
||||||
|
try:
|
||||||
|
r = session.get(f"{url}/health", timeout=5)
|
||||||
|
if r.status_code == 200:
|
||||||
|
return
|
||||||
|
except requests.RequestException:
|
||||||
|
pass
|
||||||
|
time.sleep(1)
|
||||||
|
raise TimeoutError(f"Service at {url} failed to become healthy in time")
|
||||||
|
|
||||||
|
|
||||||
|
def _detect_ib_device() -> Optional[str]:
|
||||||
|
"""Return first active IB device name (e.g., mlx5_0) or None if unavailable."""
|
||||||
|
# Fast check that ibv_devinfo exists
|
||||||
|
try:
|
||||||
|
subprocess.run(
|
||||||
|
["ibv_devinfo", "-l"],
|
||||||
|
stdout=subprocess.DEVNULL,
|
||||||
|
stderr=subprocess.DEVNULL,
|
||||||
|
timeout=1,
|
||||||
|
)
|
||||||
|
except (FileNotFoundError, subprocess.TimeoutExpired):
|
||||||
|
return None
|
||||||
|
|
||||||
|
for i in range(12):
|
||||||
|
dev = f"mlx5_{i}"
|
||||||
|
try:
|
||||||
|
res = subprocess.run(
|
||||||
|
["ibv_devinfo", dev],
|
||||||
|
capture_output=True,
|
||||||
|
text=True,
|
||||||
|
timeout=2,
|
||||||
|
)
|
||||||
|
if res.returncode == 0 and ("state:" in res.stdout):
|
||||||
|
for line in res.stdout.splitlines():
|
||||||
|
if "state:" in line and "PORT_ACTIVE" in line:
|
||||||
|
return dev
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _popen_launch_prefill_worker(
|
||||||
|
model: str,
|
||||||
|
bootstrap_port: int,
|
||||||
|
ib_device: Optional[str] = None,
|
||||||
|
base_gpu_id: int = 0,
|
||||||
|
) -> SimpleNamespace:
|
||||||
|
port = _find_available_port()
|
||||||
|
url = f"http://127.0.0.1:{port}"
|
||||||
|
cmd = [
|
||||||
|
"python3",
|
||||||
|
"-m",
|
||||||
|
"sglang.launch_server",
|
||||||
|
"--model-path",
|
||||||
|
model,
|
||||||
|
"--disaggregation-mode",
|
||||||
|
"prefill",
|
||||||
|
"--host",
|
||||||
|
"127.0.0.1",
|
||||||
|
"--port",
|
||||||
|
str(port),
|
||||||
|
"--disaggregation-bootstrap-port",
|
||||||
|
str(bootstrap_port),
|
||||||
|
"--base-gpu-id",
|
||||||
|
str(base_gpu_id),
|
||||||
|
]
|
||||||
|
if ib_device:
|
||||||
|
cmd += ["--disaggregation-ib-device", ib_device]
|
||||||
|
proc = subprocess.Popen(cmd)
|
||||||
|
_wait_health(url, timeout=300.0)
|
||||||
|
return SimpleNamespace(proc=proc, url=url, bootstrap_port=bootstrap_port)
|
||||||
|
|
||||||
|
|
||||||
|
def _popen_launch_decode_worker(
|
||||||
|
model: str, ib_device: Optional[str] = None, base_gpu_id: int = 0
|
||||||
|
) -> SimpleNamespace:
|
||||||
|
port = _find_available_port()
|
||||||
|
url = f"http://127.0.0.1:{port}"
|
||||||
|
cmd = [
|
||||||
|
"python3",
|
||||||
|
"-m",
|
||||||
|
"sglang.launch_server",
|
||||||
|
"--model-path",
|
||||||
|
model,
|
||||||
|
"--disaggregation-mode",
|
||||||
|
"decode",
|
||||||
|
"--host",
|
||||||
|
"127.0.0.1",
|
||||||
|
"--port",
|
||||||
|
str(port),
|
||||||
|
"--base-gpu-id",
|
||||||
|
str(base_gpu_id),
|
||||||
|
]
|
||||||
|
if ib_device:
|
||||||
|
cmd += ["--disaggregation-ib-device", ib_device]
|
||||||
|
proc = subprocess.Popen(cmd)
|
||||||
|
_wait_health(url, timeout=300.0)
|
||||||
|
return SimpleNamespace(proc=proc, url=url)
|
||||||
|
|
||||||
|
|
||||||
|
def _terminate(proc: subprocess.Popen, timeout: float = 120) -> None:
|
||||||
|
if proc is None:
|
||||||
|
return
|
||||||
|
proc.terminate()
|
||||||
|
start = time.perf_counter()
|
||||||
|
while proc.poll() is None:
|
||||||
|
if time.perf_counter() - start > timeout:
|
||||||
|
proc.kill()
|
||||||
|
break
|
||||||
|
time.sleep(1)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.e2e
|
||||||
|
def test_pd_mmlu(e2e_model: str):
|
||||||
|
"""
|
||||||
|
Launch 4 workers, start a PD router (2 prefill + 2 decode), then run MMLU.
|
||||||
|
"""
|
||||||
|
# Environment capability checks: require sgl_kernel and GPU backend
|
||||||
|
try:
|
||||||
|
import sgl_kernel # noqa: F401
|
||||||
|
except Exception as e: # pragma: no cover - environment dependent
|
||||||
|
pytest.fail(f"PD e2e requires sgl_kernel but it is not available: {e}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
import torch # noqa: F401
|
||||||
|
except Exception as e: # pragma: no cover - environment dependent
|
||||||
|
pytest.fail(
|
||||||
|
f"PD e2e requires torch but it is not available or misconfigured: {e}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if not torch.cuda.is_available(): # pragma: no cover - environment dependent
|
||||||
|
pytest.fail("PD e2e requires CUDA backend, but CUDA is not available")
|
||||||
|
|
||||||
|
# Start two prefill workers (with bootstrap ports) and two decode workers
|
||||||
|
workers: list[SimpleNamespace] = []
|
||||||
|
try:
|
||||||
|
ib_device = _detect_ib_device()
|
||||||
|
|
||||||
|
# Launch 4 workers across 4 GPUs: prefill on 0,1 and decode on 2,3
|
||||||
|
pf1 = _popen_launch_prefill_worker(
|
||||||
|
e2e_model,
|
||||||
|
bootstrap_port=_find_available_port(),
|
||||||
|
ib_device=ib_device,
|
||||||
|
base_gpu_id=0,
|
||||||
|
)
|
||||||
|
pf2 = _popen_launch_prefill_worker(
|
||||||
|
e2e_model,
|
||||||
|
bootstrap_port=_find_available_port(),
|
||||||
|
ib_device=ib_device,
|
||||||
|
base_gpu_id=1,
|
||||||
|
)
|
||||||
|
dc1 = _popen_launch_decode_worker(e2e_model, ib_device=ib_device, base_gpu_id=2)
|
||||||
|
dc2 = _popen_launch_decode_worker(e2e_model, ib_device=ib_device, base_gpu_id=3)
|
||||||
|
prefills = [pf1, pf2]
|
||||||
|
decodes = [dc1, dc2]
|
||||||
|
workers.extend(prefills + decodes)
|
||||||
|
|
||||||
|
# PD router with two prefill and two decode endpoints
|
||||||
|
rport = _find_available_port()
|
||||||
|
router_url = f"http://127.0.0.1:{rport}"
|
||||||
|
pport = _find_available_port()
|
||||||
|
|
||||||
|
prefill = [(pf.url, pf.bootstrap_port) for pf in prefills]
|
||||||
|
decode = [dc.url for dc in decodes]
|
||||||
|
|
||||||
|
cmd = [
|
||||||
|
"python3",
|
||||||
|
"-m",
|
||||||
|
"sglang_router.launch_router",
|
||||||
|
"--host",
|
||||||
|
"127.0.0.1",
|
||||||
|
"--port",
|
||||||
|
str(rport),
|
||||||
|
"--policy",
|
||||||
|
"round_robin",
|
||||||
|
"--pd-disaggregation",
|
||||||
|
# prefill URLs (explicitly pass 'none' for bootstrap port)
|
||||||
|
]
|
||||||
|
for url, bport in prefill:
|
||||||
|
cmd += ["--prefill", url, str(bport)]
|
||||||
|
for url in decode:
|
||||||
|
cmd += ["--decode", url]
|
||||||
|
cmd += [
|
||||||
|
# prometheus (avoid collisions across tests)
|
||||||
|
"--prometheus-port",
|
||||||
|
str(pport),
|
||||||
|
"--prometheus-host",
|
||||||
|
"127.0.0.1",
|
||||||
|
]
|
||||||
|
|
||||||
|
router_proc = subprocess.Popen(cmd)
|
||||||
|
try:
|
||||||
|
_wait_health(router_url, timeout=180.0)
|
||||||
|
|
||||||
|
# Run a modest MMLU eval through the PD router
|
||||||
|
args = SimpleNamespace(
|
||||||
|
base_url=router_url,
|
||||||
|
model=e2e_model,
|
||||||
|
eval_name="mmlu",
|
||||||
|
num_examples=64,
|
||||||
|
num_threads=32,
|
||||||
|
temperature=0.1,
|
||||||
|
)
|
||||||
|
metrics = run_eval(args)
|
||||||
|
assert metrics["score"] >= 0.65
|
||||||
|
finally:
|
||||||
|
_terminate(router_proc)
|
||||||
|
finally:
|
||||||
|
for w in workers:
|
||||||
|
_terminate(w.proc)
|
||||||
Reference in New Issue
Block a user