[router] create worker removal step and clean up worker manager (#11921)

Simo Lin
2025-10-22 13:26:06 -07:00
committed by GitHub
parent eec9e471ca
commit 5dccf69713
23 changed files with 758 additions and 1905 deletions
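The diff below migrates the E2E tests and the RouterManager helper from the old synchronous /add_worker, /remove_worker, and /list_workers endpoints to an asynchronous /workers resource: POST /workers and DELETE /workers/{url} now return 202 ACCEPTED, and the tests poll GET /workers (or GET /workers/{url}) until the worker is registered, healthy, or gone. As a minimal sketch of that lifecycle (the base_url, worker_url, and timeout values are placeholders, not values from this commit; the endpoint paths, 202 responses, and is_healthy field are taken from the test changes):

    # Sketch of the asynchronous worker lifecycle exercised by the updated tests.
    import time
    from urllib.parse import quote

    import requests

    base_url = "http://127.0.0.1:30000"    # placeholder router address
    worker_url = "http://127.0.0.1:31000"  # placeholder worker address

    # Registration is asynchronous: POST /workers returns 202 ACCEPTED.
    r = requests.post(f"{base_url}/workers", json={"url": worker_url}, timeout=180)
    assert r.status_code == 202, r.text

    # Poll the per-worker endpoint until the worker reports healthy.
    encoded = quote(worker_url, safe="")
    deadline = time.time() + 60
    while time.time() < deadline:
        try:
            r = requests.get(f"{base_url}/workers/{encoded}", timeout=5)
            if r.status_code == 200 and r.json().get("is_healthy", False):
                break
        except requests.RequestException:
            pass
        time.sleep(0.5)
    else:
        raise TimeoutError(f"worker {worker_url} never became healthy")

    # Removal is also asynchronous: DELETE /workers/{url} returns 202, and the
    # worker is fully removed once GET /workers/{url} starts returning 404.
    r = requests.delete(f"{base_url}/workers/{encoded}", timeout=60)
    assert r.status_code == 202, r.text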

View File

@@ -85,6 +85,8 @@ def _popen_launch_router(
str(prom_port),
"--router-prometheus-host",
"127.0.0.1",
"--router-log-level",
"warn",
]
proc = subprocess.Popen(cmd)

View File

@@ -1,9 +1,31 @@
import time
from types import SimpleNamespace
import pytest
import requests
def _wait_for_workers(
base_url: str, expected_count: int, timeout: float = 60.0, headers: dict = None
) -> None:
"""Poll /workers endpoint until expected number of workers are registered."""
start = time.perf_counter()
with requests.Session() as session:
while time.perf_counter() - start < timeout:
try:
r = session.get(f"{base_url}/workers", headers=headers, timeout=5)
if r.status_code == 200:
workers = r.json().get("workers", [])
if len(workers) >= expected_count:
return
except requests.RequestException:
pass
time.sleep(0.5)
raise TimeoutError(
f"Expected {expected_count} workers at {base_url}, timed out after {timeout}s"
)
@pytest.mark.e2e
def test_embeddings_basic(
e2e_router_only_rr, e2e_primary_embedding_worker, e2e_embedding_model
@@ -12,8 +34,11 @@ def test_embeddings_basic(
worker_url = e2e_primary_embedding_worker.url
# Attach embedding worker to router-only instance
r = requests.post(f"{base}/add_worker", params={"url": worker_url}, timeout=180)
r.raise_for_status()
r = requests.post(f"{base}/workers", json={"url": worker_url}, timeout=180)
assert r.status_code == 202, f"Expected 202 ACCEPTED, got {r.status_code}: {r.text}"
# Wait for worker to be registered
_wait_for_workers(base, expected_count=1, timeout=60.0)
# Simple embedding request with two inputs
payload = {

View File

@@ -198,6 +198,8 @@ def pd_cluster(e2e_model: str):
"--policy",
"round_robin",
"--pd-disaggregation",
"--log-level",
"warn",
]
for url, bport in prefill:
cmd += ["--prefill", url, str(bport)]

View File

@@ -8,13 +8,39 @@ import requests
from sglang.test.run_eval import run_eval
def _wait_for_workers(
base_url: str, expected_count: int, timeout: float = 60.0, headers: dict = None
) -> None:
"""Poll /workers endpoint until expected number of workers are registered."""
start = time.perf_counter()
with requests.Session() as session:
while time.perf_counter() - start < timeout:
try:
r = session.get(f"{base_url}/workers", headers=headers, timeout=5)
if r.status_code == 200:
workers = r.json().get("workers", [])
if len(workers) >= expected_count:
return
except requests.RequestException:
pass
time.sleep(0.5)
raise TimeoutError(
f"Expected {expected_count} workers at {base_url}, timed out after {timeout}s"
)
@pytest.mark.e2e
def test_mmlu(e2e_router_only_rr, e2e_two_workers_dp2, e2e_model):
# Attach two dp=2 workers (total 4 GPUs) to a fresh router-only instance
base = e2e_router_only_rr.url
for w in e2e_two_workers_dp2:
r = requests.post(f"{base}/add_worker", params={"url": w.url}, timeout=180)
r.raise_for_status()
r = requests.post(f"{base}/workers", json={"url": w.url}, timeout=180)
assert (
r.status_code == 202
), f"Expected 202 ACCEPTED, got {r.status_code}: {r.text}"
# Wait for workers to be registered
_wait_for_workers(base, expected_count=2, timeout=60.0)
args = SimpleNamespace(
base_url=base,
@@ -35,8 +61,13 @@ def test_genai_bench(
"""Attach a worker to the regular router and run a short genai-bench."""
base = e2e_router_only_rr.url
for w in e2e_two_workers_dp2:
r = requests.post(f"{base}/add_worker", params={"url": w.url}, timeout=180)
r.raise_for_status()
r = requests.post(f"{base}/workers", json={"url": w.url}, timeout=180)
assert (
r.status_code == 202
), f"Expected 202 ACCEPTED, got {r.status_code}: {r.text}"
# Wait for workers to be registered
_wait_for_workers(base, expected_count=2, timeout=60.0)
genai_bench_runner(
router_url=base,
@@ -59,8 +90,11 @@ def test_add_and_remove_worker_live(e2e_router_only_rr, e2e_primary_worker, e2e_
base = e2e_router_only_rr.url
worker_url = e2e_primary_worker.url
r = requests.post(f"{base}/add_worker", params={"url": worker_url}, timeout=180)
r.raise_for_status()
r = requests.post(f"{base}/workers", json={"url": worker_url}, timeout=180)
assert r.status_code == 202, f"Expected 202 ACCEPTED, got {r.status_code}: {r.text}"
# Wait for worker to be registered
_wait_for_workers(base, expected_count=1, timeout=60.0)
with requests.Session() as s:
for i in range(8):
@@ -77,8 +111,11 @@ def test_add_and_remove_worker_live(e2e_router_only_rr, e2e_primary_worker, e2e_
r.raise_for_status()
# Remove the worker
r = requests.post(f"{base}/remove_worker", params={"url": worker_url}, timeout=60)
r.raise_for_status()
from urllib.parse import quote
encoded_url = quote(worker_url, safe="")
r = requests.delete(f"{base}/workers/{encoded_url}", timeout=60)
assert r.status_code == 202, f"Expected 202 ACCEPTED, got {r.status_code}: {r.text}"
@pytest.mark.e2e
@@ -86,8 +123,11 @@ def test_lazy_fault_tolerance_live(e2e_router_only_rr, e2e_primary_worker, e2e_m
base = e2e_router_only_rr.url
worker = e2e_primary_worker
r = requests.post(f"{base}/add_worker", params={"url": worker.url}, timeout=180)
r.raise_for_status()
r = requests.post(f"{base}/workers", json={"url": worker.url}, timeout=180)
assert r.status_code == 202, f"Expected 202 ACCEPTED, got {r.status_code}: {r.text}"
# Wait for worker to be registered
_wait_for_workers(base, expected_count=1, timeout=60.0)
def killer():
time.sleep(10)
@@ -129,20 +169,30 @@ def test_dp_aware_worker_expansion_and_api_key(
# Attach worker; router should expand to dp_size logical workers
r = requests.post(
f"{router_url}/add_worker",
params={"url": worker_url, "api_key": api_key},
f"{router_url}/workers",
json={"url": worker_url, "api_key": api_key},
headers={"Authorization": f"Bearer {api_key}"},
timeout=180,
)
r.raise_for_status()
assert r.status_code == 202, f"Expected 202 ACCEPTED, got {r.status_code}: {r.text}"
# Wait for workers to be registered and expanded
_wait_for_workers(
router_url,
expected_count=2,
timeout=60.0,
headers={"Authorization": f"Bearer {api_key}"},
)
# Verify the expanded workers have correct URLs
r = requests.get(
f"{router_url}/list_workers",
f"{router_url}/workers",
headers={"Authorization": f"Bearer {api_key}"},
timeout=30,
)
r.raise_for_status()
urls = r.json().get("urls", [])
workers = r.json().get("workers", [])
urls = [w["url"] for w in workers]
assert len(urls) == 2
assert set(urls) == {f"{worker_url}@0", f"{worker_url}@1"}

View File

@@ -267,6 +267,8 @@ def popen_launch_workers_and_router(
policy,
"--model-path",
model,
"--log-level",
"warn",
]
# Add worker URLs

View File

@@ -133,19 +133,90 @@ class RouterManager:
time.sleep(0.2)
raise TimeoutError(f"Router at {base_url} did not become healthy")
def add_worker(self, base_url: str, worker_url: str) -> None:
r = requests.post(f"{base_url}/add_worker", params={"url": worker_url})
assert r.status_code == 200, f"add_worker failed: {r.status_code} {r.text}"
def add_worker(self, base_url: str, worker_url: str, timeout: float = 30.0) -> None:
r = requests.post(f"{base_url}/workers", json={"url": worker_url})
assert (
r.status_code == 202
), f"add_worker failed: {r.status_code} {r.text}" # ACCEPTED status
def remove_worker(self, base_url: str, worker_url: str) -> None:
r = requests.post(f"{base_url}/remove_worker", params={"url": worker_url})
assert r.status_code == 200, f"remove_worker failed: {r.status_code} {r.text}"
# Poll until worker is actually added and healthy
from urllib.parse import quote
encoded_url = quote(worker_url, safe="")
start = time.time()
with requests.Session() as s:
while time.time() - start < timeout:
try:
r = s.get(f"{base_url}/workers/{encoded_url}", timeout=2)
if r.status_code == 200:
data = r.json()
# Check if registration job failed
job_status = data.get("job_status")
if job_status and job_status.get("state") == "failed":
raise RuntimeError(
f"Worker registration failed: {job_status.get('message', 'Unknown error')}"
)
# Check if worker is healthy and registered (not just in job queue)
if data.get("is_healthy", False):
return
# Worker not ready yet, continue polling
except requests.RequestException:
pass
time.sleep(0.1)
raise TimeoutError(
f"Worker {worker_url} was not added and healthy after {timeout}s"
)
def remove_worker(
self, base_url: str, worker_url: str, timeout: float = 30.0
) -> None:
# URL encode the worker_url for path parameter
from urllib.parse import quote
encoded_url = quote(worker_url, safe="")
r = requests.delete(f"{base_url}/workers/{encoded_url}")
assert (
r.status_code == 202
), f"remove_worker failed: {r.status_code} {r.text}" # ACCEPTED status
# Poll until worker is actually removed (GET returns 404) or timeout
start = time.time()
last_status = None
with requests.Session() as s:
while time.time() - start < timeout:
try:
r = s.get(f"{base_url}/workers/{encoded_url}", timeout=2)
if r.status_code == 404:
# Worker successfully removed
return
elif r.status_code == 200:
# Check if removal job failed
data = r.json()
job_status = data.get("job_status")
if job_status:
last_status = job_status
if job_status.get("state") == "failed":
raise RuntimeError(
f"Worker removal failed: {job_status.get('message', 'Unknown error')}"
)
# Worker still being processed, continue polling
except requests.RequestException:
pass
time.sleep(0.1)
# Provide detailed timeout error with last known status
error_msg = f"Worker {worker_url} was not removed after {timeout}s"
if last_status:
error_msg += f". Last job status: {last_status}"
raise TimeoutError(error_msg)
def list_workers(self, base_url: str) -> list[str]:
r = requests.get(f"{base_url}/list_workers")
r = requests.get(f"{base_url}/workers")
assert r.status_code == 200, f"list_workers failed: {r.status_code} {r.text}"
data = r.json()
return data.get("urls", [])
# Extract URLs from WorkerInfo objects
workers = data.get("workers", [])
return [w["url"] for w in workers]
def stop_all(self):
for p in self._children:

View File

@@ -2,7 +2,7 @@ import os
import subprocess
import time
from pathlib import Path
from typing import Dict, Iterable, List, Tuple
from typing import Dict, Iterable, List, Optional, Tuple
import pytest
import requests
@@ -84,7 +84,7 @@ def mock_workers():
procs: List[subprocess.Popen] = []
def _start(n: int, args: List[str] | None = None):
def _start(n: int, args: Optional[List[str]] = None):
args = args or []
new_procs: List[subprocess.Popen] = []
urls: List[str] = []