[router] create worker removal step and clean up worker manager (#11921)
This commit is contained in:
@@ -85,6 +85,8 @@ def _popen_launch_router(
|
||||
str(prom_port),
|
||||
"--router-prometheus-host",
|
||||
"127.0.0.1",
|
||||
"--router-log-level",
|
||||
"warn",
|
||||
]
|
||||
|
||||
proc = subprocess.Popen(cmd)
|
||||
|
||||
@@ -1,9 +1,31 @@
|
||||
import time
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
|
||||
def _wait_for_workers(
    base_url: str, expected_count: int, timeout: float = 60.0, headers: dict = None
) -> None:
    """Poll ``GET {base_url}/workers`` until enough workers are listed.

    Returns as soon as a 200 response carries a ``"workers"`` list with at
    least ``expected_count`` entries. Connection errors, non-200 responses
    and malformed bodies are treated as "not ready yet" and retried.

    Args:
        base_url: Router base URL; ``/workers`` is appended to it.
        expected_count: Minimum number of listed workers to wait for.
        timeout: Overall polling budget in seconds.
        headers: Optional HTTP headers sent with every poll; ``None`` sends
            no extra headers.

    Raises:
        TimeoutError: If the count is not reached within ``timeout`` seconds.
    """
    deadline = time.perf_counter() + timeout
    last_seen = None  # most recent observation, reported if we time out
    with requests.Session() as session:
        while time.perf_counter() < deadline:
            try:
                r = session.get(f"{base_url}/workers", headers=headers, timeout=5)
                if r.status_code == 200:
                    payload = r.json()
                    workers = (
                        payload.get("workers") if isinstance(payload, dict) else None
                    )
                    count = len(workers) if isinstance(workers, list) else 0
                    if count >= expected_count:
                        return
                    last_seen = f"{count} worker(s) listed"
                else:
                    last_seen = f"HTTP {r.status_code}"
            except (requests.RequestException, ValueError) as exc:
                # ValueError covers a non-JSON body on requests versions whose
                # JSON decode error is not a RequestException; keep polling.
                last_seen = repr(exc)
            time.sleep(0.5)

    message = (
        f"Expected {expected_count} workers at {base_url}, timed out after {timeout}s"
    )
    if last_seen is not None:
        message += f". Last observation: {last_seen}"
    raise TimeoutError(message)
|
||||
|
||||
|
||||
@pytest.mark.e2e
|
||||
def test_embeddings_basic(
|
||||
e2e_router_only_rr, e2e_primary_embedding_worker, e2e_embedding_model
|
||||
@@ -12,8 +34,11 @@ def test_embeddings_basic(
|
||||
worker_url = e2e_primary_embedding_worker.url
|
||||
|
||||
# Attach embedding worker to router-only instance
|
||||
r = requests.post(f"{base}/add_worker", params={"url": worker_url}, timeout=180)
|
||||
r.raise_for_status()
|
||||
r = requests.post(f"{base}/workers", json={"url": worker_url}, timeout=180)
|
||||
assert r.status_code == 202, f"Expected 202 ACCEPTED, got {r.status_code}: {r.text}"
|
||||
|
||||
# Wait for worker to be registered
|
||||
_wait_for_workers(base, expected_count=1, timeout=60.0)
|
||||
|
||||
# Simple embedding request with two inputs
|
||||
payload = {
|
||||
|
||||
@@ -198,6 +198,8 @@ def pd_cluster(e2e_model: str):
|
||||
"--policy",
|
||||
"round_robin",
|
||||
"--pd-disaggregation",
|
||||
"--log-level",
|
||||
"warn",
|
||||
]
|
||||
for url, bport in prefill:
|
||||
cmd += ["--prefill", url, str(bport)]
|
||||
|
||||
@@ -8,13 +8,39 @@ import requests
|
||||
from sglang.test.run_eval import run_eval
|
||||
|
||||
|
||||
def _wait_for_workers(
    base_url: str, expected_count: int, timeout: float = 60.0, headers: dict = None
) -> None:
    """Poll ``GET {base_url}/workers`` until enough workers are listed.

    Returns as soon as a 200 response carries a ``"workers"`` list with at
    least ``expected_count`` entries. Connection errors, non-200 responses
    and malformed bodies are treated as "not ready yet" and retried.

    Args:
        base_url: Router base URL; ``/workers`` is appended to it.
        expected_count: Minimum number of listed workers to wait for.
        timeout: Overall polling budget in seconds.
        headers: Optional HTTP headers sent with every poll; ``None`` sends
            no extra headers.

    Raises:
        TimeoutError: If the count is not reached within ``timeout`` seconds.
    """
    deadline = time.perf_counter() + timeout
    last_seen = None  # most recent observation, reported if we time out
    with requests.Session() as session:
        while time.perf_counter() < deadline:
            try:
                r = session.get(f"{base_url}/workers", headers=headers, timeout=5)
                if r.status_code == 200:
                    payload = r.json()
                    workers = (
                        payload.get("workers") if isinstance(payload, dict) else None
                    )
                    count = len(workers) if isinstance(workers, list) else 0
                    if count >= expected_count:
                        return
                    last_seen = f"{count} worker(s) listed"
                else:
                    last_seen = f"HTTP {r.status_code}"
            except (requests.RequestException, ValueError) as exc:
                # ValueError covers a non-JSON body on requests versions whose
                # JSON decode error is not a RequestException; keep polling.
                last_seen = repr(exc)
            time.sleep(0.5)

    message = (
        f"Expected {expected_count} workers at {base_url}, timed out after {timeout}s"
    )
    if last_seen is not None:
        message += f". Last observation: {last_seen}"
    raise TimeoutError(message)
|
||||
|
||||
|
||||
@pytest.mark.e2e
|
||||
def test_mmlu(e2e_router_only_rr, e2e_two_workers_dp2, e2e_model):
|
||||
# Attach two dp=2 workers (total 4 GPUs) to a fresh router-only instance
|
||||
base = e2e_router_only_rr.url
|
||||
for w in e2e_two_workers_dp2:
|
||||
r = requests.post(f"{base}/add_worker", params={"url": w.url}, timeout=180)
|
||||
r.raise_for_status()
|
||||
r = requests.post(f"{base}/workers", json={"url": w.url}, timeout=180)
|
||||
assert (
|
||||
r.status_code == 202
|
||||
), f"Expected 202 ACCEPTED, got {r.status_code}: {r.text}"
|
||||
|
||||
# Wait for workers to be registered
|
||||
_wait_for_workers(base, expected_count=2, timeout=60.0)
|
||||
|
||||
args = SimpleNamespace(
|
||||
base_url=base,
|
||||
@@ -35,8 +61,13 @@ def test_genai_bench(
|
||||
"""Attach a worker to the regular router and run a short genai-bench."""
|
||||
base = e2e_router_only_rr.url
|
||||
for w in e2e_two_workers_dp2:
|
||||
r = requests.post(f"{base}/add_worker", params={"url": w.url}, timeout=180)
|
||||
r.raise_for_status()
|
||||
r = requests.post(f"{base}/workers", json={"url": w.url}, timeout=180)
|
||||
assert (
|
||||
r.status_code == 202
|
||||
), f"Expected 202 ACCEPTED, got {r.status_code}: {r.text}"
|
||||
|
||||
# Wait for workers to be registered
|
||||
_wait_for_workers(base, expected_count=2, timeout=60.0)
|
||||
|
||||
genai_bench_runner(
|
||||
router_url=base,
|
||||
@@ -59,8 +90,11 @@ def test_add_and_remove_worker_live(e2e_router_only_rr, e2e_primary_worker, e2e_
|
||||
base = e2e_router_only_rr.url
|
||||
worker_url = e2e_primary_worker.url
|
||||
|
||||
r = requests.post(f"{base}/add_worker", params={"url": worker_url}, timeout=180)
|
||||
r.raise_for_status()
|
||||
r = requests.post(f"{base}/workers", json={"url": worker_url}, timeout=180)
|
||||
assert r.status_code == 202, f"Expected 202 ACCEPTED, got {r.status_code}: {r.text}"
|
||||
|
||||
# Wait for worker to be registered
|
||||
_wait_for_workers(base, expected_count=1, timeout=60.0)
|
||||
|
||||
with requests.Session() as s:
|
||||
for i in range(8):
|
||||
@@ -77,8 +111,11 @@ def test_add_and_remove_worker_live(e2e_router_only_rr, e2e_primary_worker, e2e_
|
||||
r.raise_for_status()
|
||||
|
||||
# Remove the worker
|
||||
r = requests.post(f"{base}/remove_worker", params={"url": worker_url}, timeout=60)
|
||||
r.raise_for_status()
|
||||
from urllib.parse import quote
|
||||
|
||||
encoded_url = quote(worker_url, safe="")
|
||||
r = requests.delete(f"{base}/workers/{encoded_url}", timeout=60)
|
||||
assert r.status_code == 202, f"Expected 202 ACCEPTED, got {r.status_code}: {r.text}"
|
||||
|
||||
|
||||
@pytest.mark.e2e
|
||||
@@ -86,8 +123,11 @@ def test_lazy_fault_tolerance_live(e2e_router_only_rr, e2e_primary_worker, e2e_m
|
||||
base = e2e_router_only_rr.url
|
||||
worker = e2e_primary_worker
|
||||
|
||||
r = requests.post(f"{base}/add_worker", params={"url": worker.url}, timeout=180)
|
||||
r.raise_for_status()
|
||||
r = requests.post(f"{base}/workers", json={"url": worker.url}, timeout=180)
|
||||
assert r.status_code == 202, f"Expected 202 ACCEPTED, got {r.status_code}: {r.text}"
|
||||
|
||||
# Wait for worker to be registered
|
||||
_wait_for_workers(base, expected_count=1, timeout=60.0)
|
||||
|
||||
def killer():
|
||||
time.sleep(10)
|
||||
@@ -129,20 +169,30 @@ def test_dp_aware_worker_expansion_and_api_key(
|
||||
|
||||
# Attach worker; router should expand to dp_size logical workers
|
||||
r = requests.post(
|
||||
f"{router_url}/add_worker",
|
||||
params={"url": worker_url, "api_key": api_key},
|
||||
f"{router_url}/workers",
|
||||
json={"url": worker_url, "api_key": api_key},
|
||||
headers={"Authorization": f"Bearer {api_key}"},
|
||||
timeout=180,
|
||||
)
|
||||
r.raise_for_status()
|
||||
assert r.status_code == 202, f"Expected 202 ACCEPTED, got {r.status_code}: {r.text}"
|
||||
|
||||
# Wait for workers to be registered and expanded
|
||||
_wait_for_workers(
|
||||
router_url,
|
||||
expected_count=2,
|
||||
timeout=60.0,
|
||||
headers={"Authorization": f"Bearer {api_key}"},
|
||||
)
|
||||
|
||||
# Verify the expanded workers have correct URLs
|
||||
r = requests.get(
|
||||
f"{router_url}/list_workers",
|
||||
f"{router_url}/workers",
|
||||
headers={"Authorization": f"Bearer {api_key}"},
|
||||
timeout=30,
|
||||
)
|
||||
r.raise_for_status()
|
||||
urls = r.json().get("urls", [])
|
||||
workers = r.json().get("workers", [])
|
||||
urls = [w["url"] for w in workers]
|
||||
assert len(urls) == 2
|
||||
assert set(urls) == {f"{worker_url}@0", f"{worker_url}@1"}
|
||||
|
||||
|
||||
@@ -267,6 +267,8 @@ def popen_launch_workers_and_router(
|
||||
policy,
|
||||
"--model-path",
|
||||
model,
|
||||
"--log-level",
|
||||
"warn",
|
||||
]
|
||||
|
||||
# Add worker URLs
|
||||
|
||||
@@ -133,19 +133,90 @@ class RouterManager:
|
||||
time.sleep(0.2)
|
||||
raise TimeoutError(f"Router at {base_url} did not become healthy")
|
||||
|
||||
def add_worker(self, base_url: str, worker_url: str, timeout: float = 30.0) -> None:
    """Register ``worker_url`` with the router and wait until it is healthy.

    Sends ``POST {base_url}/workers`` with ``{"url": worker_url}`` and asserts
    a 202 response, then polls ``GET {base_url}/workers/{encoded worker_url}``
    until the returned JSON reports ``is_healthy``.

    Args:
        base_url: Router base URL.
        worker_url: URL of the worker to register.
        timeout: Seconds allowed for the POST request, and separately the
            budget in seconds for the polling phase that follows it.

    Raises:
        AssertionError: If the POST is not answered with 202.
        RuntimeError: If a polled response has a ``job_status`` whose
            ``state`` is ``"failed"``.
        TimeoutError: If the worker is not reported healthy in time.
    """
    from urllib.parse import quote

    # Bound the request: without a timeout, requests waits indefinitely.
    r = requests.post(
        f"{base_url}/workers", json={"url": worker_url}, timeout=timeout
    )
    assert (
        r.status_code == 202
    ), f"add_worker failed: {r.status_code} {r.text}"  # ACCEPTED status

    # safe="" percent-encodes "/" and ":" too, so the worker URL occupies a
    # single path segment.
    encoded_url = quote(worker_url, safe="")

    # Monotonic clock: the budget must not be affected by wall-clock changes.
    deadline = time.monotonic() + timeout
    with requests.Session() as s:
        while time.monotonic() < deadline:
            try:
                r = s.get(f"{base_url}/workers/{encoded_url}", timeout=2)
                if r.status_code == 200:
                    data = r.json()
                    # Check if the registration job failed.
                    job_status = data.get("job_status")
                    if job_status and job_status.get("state") == "failed":
                        raise RuntimeError(
                            f"Worker registration failed: {job_status.get('message', 'Unknown error')}"
                        )
                    # Healthy means registered, not merely queued as a job.
                    if data.get("is_healthy", False):
                        return
                # Any other outcome: worker not ready yet, continue polling.
            except (requests.RequestException, ValueError):
                # Transient connection error or non-JSON body; retry.
                pass
            time.sleep(0.1)

    raise TimeoutError(
        f"Worker {worker_url} was not added and healthy after {timeout}s"
    )
|
||||
|
||||
def remove_worker(
    self, base_url: str, worker_url: str, timeout: float = 30.0
) -> None:
    """Ask the router to drop ``worker_url`` and wait until it is gone.

    Sends ``DELETE {base_url}/workers/{encoded worker_url}`` and asserts a 202
    response, then polls ``GET`` on the same path until it returns 404.

    Args:
        base_url: Router base URL.
        worker_url: URL of the worker to remove.
        timeout: Seconds allowed for the DELETE request, and separately the
            budget in seconds for the polling phase that follows it.

    Raises:
        AssertionError: If the DELETE is not answered with 202.
        RuntimeError: If a polled response has a ``job_status`` whose
            ``state`` is ``"failed"``.
        TimeoutError: If the worker is still present after ``timeout``
            seconds; the message includes the last job status seen, if any.
    """
    from urllib.parse import quote

    # safe="" percent-encodes "/" and ":" too, so the worker URL occupies a
    # single path segment.
    encoded_url = quote(worker_url, safe="")
    worker_endpoint = f"{base_url}/workers/{encoded_url}"

    # Bound the request: without a timeout, requests waits indefinitely.
    r = requests.delete(worker_endpoint, timeout=timeout)
    assert (
        r.status_code == 202
    ), f"remove_worker failed: {r.status_code} {r.text}"  # ACCEPTED status

    # Monotonic clock: the budget must not be affected by wall-clock changes.
    deadline = time.monotonic() + timeout
    last_status = None
    with requests.Session() as s:
        while time.monotonic() < deadline:
            try:
                r = s.get(worker_endpoint, timeout=2)
                if r.status_code == 404:
                    # Worker successfully removed.
                    return
                if r.status_code == 200:
                    # Still present: check whether the removal job failed.
                    job_status = r.json().get("job_status")
                    if job_status:
                        last_status = job_status
                        if job_status.get("state") == "failed":
                            raise RuntimeError(
                                f"Worker removal failed: {job_status.get('message', 'Unknown error')}"
                            )
                # Any other outcome: still being processed, continue polling.
            except (requests.RequestException, ValueError):
                # Transient connection error or non-JSON body; retry.
                pass
            time.sleep(0.1)

    # Provide a detailed timeout error with the last known status.
    error_msg = f"Worker {worker_url} was not removed after {timeout}s"
    if last_status:
        error_msg += f". Last job status: {last_status}"
    raise TimeoutError(error_msg)
|
||||
|
||||
def list_workers(self, base_url: str) -> list[str]:
    """Return the URLs of the workers listed by ``GET {base_url}/workers``.

    Args:
        base_url: Router base URL.

    Returns:
        The ``"url"`` field of every entry in the response's ``"workers"``
        list; an empty list when that key is absent.

    Raises:
        AssertionError: If the router does not answer with 200.
    """
    r = requests.get(f"{base_url}/workers")
    assert r.status_code == 200, f"list_workers failed: {r.status_code} {r.text}"
    # Each entry is a worker-info object; callers only need its URL.
    return [w["url"] for w in r.json().get("workers", [])]
|
||||
|
||||
def stop_all(self):
|
||||
for p in self._children:
|
||||
|
||||
@@ -2,7 +2,7 @@ import os
|
||||
import subprocess
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Dict, Iterable, List, Tuple
|
||||
from typing import Dict, Iterable, List, Optional, Tuple
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
@@ -84,7 +84,7 @@ def mock_workers():
|
||||
|
||||
procs: List[subprocess.Popen] = []
|
||||
|
||||
def _start(n: int, args: List[str] | None = None):
|
||||
def _start(n: int, args: Optional[List[str]] = None):
|
||||
args = args or []
|
||||
new_procs: List[subprocess.Popen] = []
|
||||
urls: List[str] = []
|
||||
|
||||
Reference in New Issue
Block a user