[feature] [sgl-router] Add a dp-aware routing strategy (#6869)

This commit is contained in:
Rui Chen
2025-07-30 20:58:48 +08:00
committed by GitHub
parent 55ecdc0a8e
commit a730ce8162
19 changed files with 726 additions and 16 deletions

View File

@@ -30,6 +30,7 @@ def popen_launch_router(
service_discovery_namespace: str = None,
prometheus_port: int = None,
prometheus_host: str = None,
dp_aware: bool = False,
):
"""
Launch the router server process.
@@ -49,6 +50,7 @@ def popen_launch_router(
service_discovery_namespace: Kubernetes namespace to watch for pods. If None, watches all namespaces.
prometheus_port: Port to expose Prometheus metrics. If None, Prometheus metrics are disabled.
prometheus_host: Host address to bind the Prometheus metrics server.
dp_aware: Enable data parallelism aware routing strategy.
"""
_, host, port = base_url.split(":")
host = host[2:]
@@ -69,10 +71,12 @@ def popen_launch_router(
"5",
"--router-policy",
policy,
"--allow-auto-truncate",
]
if api_key is not None:
command.extend(["--api-key", api_key])
command.extend(["--router-api-key", api_key])
if max_payload_size is not None:
command.extend(["--router-max-payload-size", str(max_payload_size)])
@@ -100,6 +104,9 @@ def popen_launch_router(
if log_dir is not None:
command.extend(["--log-dir", log_dir])
if dp_aware:
command.append("--router-dp-aware")
process = subprocess.Popen(command, stdout=None, stderr=None)
start_time = time.perf_counter()
@@ -127,6 +134,7 @@ def popen_launch_server(
model: str,
base_url: str,
timeout: float,
api_key: str = None,
):
_, host, port = base_url.split(":")
host = host[2:]
@@ -145,6 +153,9 @@ def popen_launch_server(
"1",
]
if api_key is not None:
command.extend(["--api-key", api_key])
process = subprocess.Popen(command, stdout=None, stderr=None)
# intentionally don't wait and defer the job to the router health check
@@ -426,6 +437,274 @@ class TestLaunchServer(unittest.TestCase):
response.status_code, 200, "Request with correct api key should succeed"
)
def test_6_mmlu_with_dp_aware(self):
print("Running test_6_mmlu_with_dp_aware...")
# DP size = 2
self.process = popen_launch_router(
self.model,
self.base_url,
dp_size=2,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
policy="cache_aware",
dp_aware=True,
)
args = SimpleNamespace(
base_url=self.base_url,
model=self.model,
eval_name="mmlu",
num_examples=64,
num_threads=32,
temperature=0.1,
)
metrics = run_eval(args)
score = metrics["score"]
THRESHOLD = 0.65
passed = score >= THRESHOLD
msg = f"dp aware MMLU test {'passed' if passed else 'failed'} with score {score:.3f} (threshold: {THRESHOLD})"
self.assertGreaterEqual(score, THRESHOLD, msg)
def test_7_add_and_remove_worker_with_dp_aware(self):
print("Running test_7_add_and_remove_worker_with_dp_aware...")
# Set dp_size = 1
self.process = popen_launch_router(
self.model,
self.base_url,
dp_size=1,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
policy="round_robin", # make sure every worker processes requests
dp_aware=True, # dp aware strategy should work well with RR
)
# 1. Start a worker
port = find_available_port()
worker_url = f"http://127.0.0.1:{port}"
worker_process = popen_launch_server(
self.model, worker_url, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
)
self.other_process.append(worker_process)
# 2. Use the /add_worker API to add it to the router
# It will be used by router after it is healthy
with requests.Session() as session:
response = session.post(f"{self.base_url}/add_worker?url={worker_url}")
print(f"status code: {response.status_code}, response: {response.text}")
self.assertEqual(response.status_code, 200)
# 3. Run mmlu
args = SimpleNamespace(
base_url=self.base_url,
model=self.model,
eval_name="mmlu",
num_examples=64,
num_threads=32,
temperature=0.1,
)
metrics = run_eval(args)
score = metrics["score"]
THRESHOLD = 0.65
passed = score >= THRESHOLD
msg = f"MMLU test {'passed' if passed else 'failed'} with score {score:.3f} (threshold: {THRESHOLD})"
self.assertGreaterEqual(score, THRESHOLD, msg)
# 4. Use the /remove_worker API to remove it from the router
with requests.Session() as session:
response = session.post(f"{self.base_url}/remove_worker?url={worker_url}")
print(f"status code: {response.status_code}, response: {response.text}")
self.assertEqual(response.status_code, 200)
# 5. Run mmlu again
metrics = run_eval(args)
score = metrics["score"]
THRESHOLD = 0.65
passed = score >= THRESHOLD
msg = f"MMLU test {'passed' if passed else 'failed'} with score {score:.3f} (threshold: {THRESHOLD})"
self.assertGreaterEqual(score, THRESHOLD, msg)
# 6. Start another worker with api_key set
terminate_and_wait(worker_process) # terminate the old worker process
port = find_available_port()
worker_url = f"http://127.0.0.1:{port}"
worker_process = popen_launch_server(
self.model,
worker_url,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
api_key="correct_api_key",
)
self.other_process.append(worker_process)
# 7. Use the /add_worker API to add it to the router
# Should fail since the router would contact the worker's
# /get_server_info endpoint for the dp_size info, but it
# has no knowledge of the api key
with requests.Session() as session:
response = session.post(f"{self.base_url}/add_worker?url={worker_url}")
print(f"status code: {response.status_code}, response: {response.text}")
self.assertNotEqual(response.status_code, 200)
def test_8_lazy_fault_tolerance_with_dp_aware(self):
print("Running test_8_lazy_fault_tolerance_with_dp_aware...")
# Set dp_size = 1
self.process = popen_launch_router(
self.model,
self.base_url,
dp_size=1,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
policy="round_robin",
dp_aware=True,
)
# 1. Start a worker
port = find_available_port()
worker_url = f"http://127.0.0.1:{port}"
worker_process = popen_launch_server(
self.model, worker_url, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
)
self.other_process.append(worker_process)
# 2. Use the /add_worker API to add it to the router
# It will be used by router after it is healthy
with requests.Session() as session:
response = session.post(f"{self.base_url}/add_worker?url={worker_url}")
print(f"status code: {response.status_code}, response: {response.text}")
self.assertEqual(response.status_code, 200)
# Start a thread to kill the worker after 10 seconds to mimic
# abrupt worker failure
def kill_worker():
time.sleep(10)
kill_process_tree(worker_process.pid)
print("Worker process killed")
import threading
kill_thread = threading.Thread(target=kill_worker)
kill_thread.daemon = True
kill_thread.start()
# 3. Run mmlu
args = SimpleNamespace(
base_url=self.base_url,
model=self.model,
eval_name="mmlu",
num_examples=256,
num_threads=32,
temperature=0.1,
)
metrics = run_eval(args)
score = metrics["score"]
THRESHOLD = 0.65
passed = score >= THRESHOLD
msg = f"MMLU test {'passed' if passed else 'failed'} with score {score:.3f} (threshold: {THRESHOLD})"
self.assertGreaterEqual(score, THRESHOLD, msg)
def test_9_payload_size_with_dp_aware(self):
print("Running test_9_payload_size_with_dp_aware...")
# Start the router with 1MB limit
self.process = popen_launch_router(
self.model,
self.base_url,
dp_size=1,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
policy="round_robin",
max_payload_size=1 * 1024 * 1024, # 1MB limit
dp_aware=True,
)
# Test case 1: Payload just under 1MB should succeed
payload_0_5_mb = {
"text": "x" * int(0.5 * 1024 * 1024), # 0.5MB of text
"temperature": 0.0,
}
with requests.Session() as session:
response = session.post(
f"{self.base_url}/generate",
json=payload_0_5_mb,
headers={"Content-Type": "application/json"},
)
self.assertEqual(
response.status_code,
200,
f"0.5MB payload should succeed but got status {response.status_code}",
)
# Test case 2: Payload over 1MB should fail
payload_1_plus_mb = {
"text": "x" * int((1.2 * 1024 * 1024)), # 1.2MB of text
"temperature": 0.0,
}
with requests.Session() as session:
response = session.post(
f"{self.base_url}/generate",
json=payload_1_plus_mb,
headers={"Content-Type": "application/json"},
)
self.assertEqual(
response.status_code,
413, # Payload Too Large
f"1.2MB payload should fail with 413 but got status {response.status_code}",
)
def test_10_api_key_with_dp_aware(self):
print("Running test_10_api_key_with_dp_aware...")
self.process = popen_launch_router(
self.model,
self.base_url,
dp_size=1,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
policy="round_robin",
api_key="correct_api_key",
dp_aware=True,
)
# Test case 1: request without api key should fail
with requests.Session() as session:
response = session.post(
f"{self.base_url}/generate",
json={"text": "Kanye west is, ", "temperature": 0},
)
print(f"status code: {response.status_code}, response: {response.text}")
self.assertEqual(
response.status_code,
401,
f"Request without api key should fail with 401 but got status {response.status_code}",
)
# Test case 2: request with invalid api key should fail
with requests.Session() as session:
response = requests.post(
f"{self.base_url}/generate",
json={"text": "Kanye west is, ", "temperature": 0},
headers={"Authorization": "Bearer 123"},
)
print(f"status code: {response.status_code}, response: {response.text}")
self.assertEqual(
response.status_code,
401,
f"Request without api key should fail with 401 but got status {response.status_code}",
)
# Test case 3: request with correct api key should succeed
with requests.Session() as session:
response = session.post(
f"{self.base_url}/generate",
json={"text": "Kanye west is ", "temperature": 0},
headers={"Authorization": "Bearer correct_api_key"},
)
print(f"status code: {response.status_code}, response: {response.text}")
self.assertEqual(
response.status_code,
200,
f"Request with correct api key should succeed but got status {response.status_code}",
)
if __name__ == "__main__":
unittest.main()