[feature] [sgl-router] Add a dp-aware routing strategy (#6869)
This commit is contained in:
@@ -30,6 +30,7 @@ def popen_launch_router(
|
||||
service_discovery_namespace: str = None,
|
||||
prometheus_port: int = None,
|
||||
prometheus_host: str = None,
|
||||
dp_aware: bool = False,
|
||||
):
|
||||
"""
|
||||
Launch the router server process.
|
||||
@@ -49,6 +50,7 @@ def popen_launch_router(
|
||||
service_discovery_namespace: Kubernetes namespace to watch for pods. If None, watches all namespaces.
|
||||
prometheus_port: Port to expose Prometheus metrics. If None, Prometheus metrics are disabled.
|
||||
prometheus_host: Host address to bind the Prometheus metrics server.
|
||||
dp_aware: Enable data parallelism aware routing strategy.
|
||||
"""
|
||||
_, host, port = base_url.split(":")
|
||||
host = host[2:]
|
||||
@@ -69,10 +71,12 @@ def popen_launch_router(
|
||||
"5",
|
||||
"--router-policy",
|
||||
policy,
|
||||
"--allow-auto-truncate",
|
||||
]
|
||||
|
||||
if api_key is not None:
|
||||
command.extend(["--api-key", api_key])
|
||||
command.extend(["--router-api-key", api_key])
|
||||
|
||||
if max_payload_size is not None:
|
||||
command.extend(["--router-max-payload-size", str(max_payload_size)])
|
||||
@@ -100,6 +104,9 @@ def popen_launch_router(
|
||||
if log_dir is not None:
|
||||
command.extend(["--log-dir", log_dir])
|
||||
|
||||
if dp_aware:
|
||||
command.append("--router-dp-aware")
|
||||
|
||||
process = subprocess.Popen(command, stdout=None, stderr=None)
|
||||
|
||||
start_time = time.perf_counter()
|
||||
@@ -127,6 +134,7 @@ def popen_launch_server(
|
||||
model: str,
|
||||
base_url: str,
|
||||
timeout: float,
|
||||
api_key: str = None,
|
||||
):
|
||||
_, host, port = base_url.split(":")
|
||||
host = host[2:]
|
||||
@@ -145,6 +153,9 @@ def popen_launch_server(
|
||||
"1",
|
||||
]
|
||||
|
||||
if api_key is not None:
|
||||
command.extend(["--api-key", api_key])
|
||||
|
||||
process = subprocess.Popen(command, stdout=None, stderr=None)
|
||||
|
||||
# intentionally don't wait and defer the job to the router health check
|
||||
@@ -426,6 +437,274 @@ class TestLaunchServer(unittest.TestCase):
|
||||
response.status_code, 200, "Request with correct api key should succeed"
|
||||
)
|
||||
|
||||
def test_6_mmlu_with_dp_aware(self):
|
||||
print("Running test_6_mmlu_with_dp_aware...")
|
||||
# DP size = 2
|
||||
self.process = popen_launch_router(
|
||||
self.model,
|
||||
self.base_url,
|
||||
dp_size=2,
|
||||
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
policy="cache_aware",
|
||||
dp_aware=True,
|
||||
)
|
||||
|
||||
args = SimpleNamespace(
|
||||
base_url=self.base_url,
|
||||
model=self.model,
|
||||
eval_name="mmlu",
|
||||
num_examples=64,
|
||||
num_threads=32,
|
||||
temperature=0.1,
|
||||
)
|
||||
|
||||
metrics = run_eval(args)
|
||||
score = metrics["score"]
|
||||
THRESHOLD = 0.65
|
||||
passed = score >= THRESHOLD
|
||||
msg = f"dp aware MMLU test {'passed' if passed else 'failed'} with score {score:.3f} (threshold: {THRESHOLD})"
|
||||
self.assertGreaterEqual(score, THRESHOLD, msg)
|
||||
|
||||
def test_7_add_and_remove_worker_with_dp_aware(self):
|
||||
print("Running test_7_add_and_remove_worker_with_dp_aware...")
|
||||
|
||||
# Set dp_size = 1
|
||||
self.process = popen_launch_router(
|
||||
self.model,
|
||||
self.base_url,
|
||||
dp_size=1,
|
||||
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
policy="round_robin", # make sure every worker processes requests
|
||||
dp_aware=True, # dp aware strategy should work well with RR
|
||||
)
|
||||
|
||||
# 1. Start a worker
|
||||
port = find_available_port()
|
||||
worker_url = f"http://127.0.0.1:{port}"
|
||||
worker_process = popen_launch_server(
|
||||
self.model, worker_url, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
|
||||
)
|
||||
self.other_process.append(worker_process)
|
||||
|
||||
# 2. Use the /add_worker API to add it to the router
|
||||
# It will be used by router after it is healthy
|
||||
with requests.Session() as session:
|
||||
response = session.post(f"{self.base_url}/add_worker?url={worker_url}")
|
||||
print(f"status code: {response.status_code}, response: {response.text}")
|
||||
self.assertEqual(response.status_code, 200)
|
||||
|
||||
# 3. Run mmlu
|
||||
args = SimpleNamespace(
|
||||
base_url=self.base_url,
|
||||
model=self.model,
|
||||
eval_name="mmlu",
|
||||
num_examples=64,
|
||||
num_threads=32,
|
||||
temperature=0.1,
|
||||
)
|
||||
metrics = run_eval(args)
|
||||
score = metrics["score"]
|
||||
THRESHOLD = 0.65
|
||||
passed = score >= THRESHOLD
|
||||
msg = f"MMLU test {'passed' if passed else 'failed'} with score {score:.3f} (threshold: {THRESHOLD})"
|
||||
self.assertGreaterEqual(score, THRESHOLD, msg)
|
||||
|
||||
# 4. Use the /remove_worker API to remove it from the router
|
||||
with requests.Session() as session:
|
||||
response = session.post(f"{self.base_url}/remove_worker?url={worker_url}")
|
||||
print(f"status code: {response.status_code}, response: {response.text}")
|
||||
self.assertEqual(response.status_code, 200)
|
||||
|
||||
# 5. Run mmlu again
|
||||
metrics = run_eval(args)
|
||||
score = metrics["score"]
|
||||
THRESHOLD = 0.65
|
||||
passed = score >= THRESHOLD
|
||||
msg = f"MMLU test {'passed' if passed else 'failed'} with score {score:.3f} (threshold: {THRESHOLD})"
|
||||
self.assertGreaterEqual(score, THRESHOLD, msg)
|
||||
|
||||
# 6. Start another worker with api_key set
|
||||
terminate_and_wait(worker_process) # terminate the old worker process
|
||||
port = find_available_port()
|
||||
worker_url = f"http://127.0.0.1:{port}"
|
||||
worker_process = popen_launch_server(
|
||||
self.model,
|
||||
worker_url,
|
||||
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
api_key="correct_api_key",
|
||||
)
|
||||
self.other_process.append(worker_process)
|
||||
|
||||
# 7. Use the /add_worker API to add it to the router
|
||||
# Should fail since the router would contact the worker's
|
||||
# /get_server_info endpoint for the dp_size info, but it
|
||||
# has no knowledge of the api key
|
||||
with requests.Session() as session:
|
||||
response = session.post(f"{self.base_url}/add_worker?url={worker_url}")
|
||||
print(f"status code: {response.status_code}, response: {response.text}")
|
||||
self.assertNotEqual(response.status_code, 200)
|
||||
|
||||
def test_8_lazy_fault_tolerance_with_dp_aware(self):
|
||||
print("Running test_8_lazy_fault_tolerance_with_dp_aware...")
|
||||
|
||||
# Set dp_size = 1
|
||||
self.process = popen_launch_router(
|
||||
self.model,
|
||||
self.base_url,
|
||||
dp_size=1,
|
||||
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
policy="round_robin",
|
||||
dp_aware=True,
|
||||
)
|
||||
|
||||
# 1. Start a worker
|
||||
port = find_available_port()
|
||||
worker_url = f"http://127.0.0.1:{port}"
|
||||
worker_process = popen_launch_server(
|
||||
self.model, worker_url, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
|
||||
)
|
||||
self.other_process.append(worker_process)
|
||||
|
||||
# 2. Use the /add_worker API to add it to the router
|
||||
# It will be used by router after it is healthy
|
||||
with requests.Session() as session:
|
||||
response = session.post(f"{self.base_url}/add_worker?url={worker_url}")
|
||||
print(f"status code: {response.status_code}, response: {response.text}")
|
||||
self.assertEqual(response.status_code, 200)
|
||||
|
||||
# Start a thread to kill the worker after 10 seconds to mimic
|
||||
# abrupt worker failure
|
||||
def kill_worker():
|
||||
time.sleep(10)
|
||||
kill_process_tree(worker_process.pid)
|
||||
print("Worker process killed")
|
||||
|
||||
import threading
|
||||
|
||||
kill_thread = threading.Thread(target=kill_worker)
|
||||
kill_thread.daemon = True
|
||||
kill_thread.start()
|
||||
|
||||
# 3. Run mmlu
|
||||
args = SimpleNamespace(
|
||||
base_url=self.base_url,
|
||||
model=self.model,
|
||||
eval_name="mmlu",
|
||||
num_examples=256,
|
||||
num_threads=32,
|
||||
temperature=0.1,
|
||||
)
|
||||
metrics = run_eval(args)
|
||||
score = metrics["score"]
|
||||
THRESHOLD = 0.65
|
||||
passed = score >= THRESHOLD
|
||||
msg = f"MMLU test {'passed' if passed else 'failed'} with score {score:.3f} (threshold: {THRESHOLD})"
|
||||
self.assertGreaterEqual(score, THRESHOLD, msg)
|
||||
|
||||
def test_9_payload_size_with_dp_aware(self):
|
||||
print("Running test_9_payload_size_with_dp_aware...")
|
||||
|
||||
# Start the router with 1MB limit
|
||||
self.process = popen_launch_router(
|
||||
self.model,
|
||||
self.base_url,
|
||||
dp_size=1,
|
||||
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
policy="round_robin",
|
||||
max_payload_size=1 * 1024 * 1024, # 1MB limit
|
||||
dp_aware=True,
|
||||
)
|
||||
|
||||
# Test case 1: Payload just under 1MB should succeed
|
||||
payload_0_5_mb = {
|
||||
"text": "x" * int(0.5 * 1024 * 1024), # 0.5MB of text
|
||||
"temperature": 0.0,
|
||||
}
|
||||
|
||||
with requests.Session() as session:
|
||||
response = session.post(
|
||||
f"{self.base_url}/generate",
|
||||
json=payload_0_5_mb,
|
||||
headers={"Content-Type": "application/json"},
|
||||
)
|
||||
self.assertEqual(
|
||||
response.status_code,
|
||||
200,
|
||||
f"0.5MB payload should succeed but got status {response.status_code}",
|
||||
)
|
||||
|
||||
# Test case 2: Payload over 1MB should fail
|
||||
payload_1_plus_mb = {
|
||||
"text": "x" * int((1.2 * 1024 * 1024)), # 1.2MB of text
|
||||
"temperature": 0.0,
|
||||
}
|
||||
|
||||
with requests.Session() as session:
|
||||
response = session.post(
|
||||
f"{self.base_url}/generate",
|
||||
json=payload_1_plus_mb,
|
||||
headers={"Content-Type": "application/json"},
|
||||
)
|
||||
self.assertEqual(
|
||||
response.status_code,
|
||||
413, # Payload Too Large
|
||||
f"1.2MB payload should fail with 413 but got status {response.status_code}",
|
||||
)
|
||||
|
||||
def test_10_api_key_with_dp_aware(self):
|
||||
print("Running test_10_api_key_with_dp_aware...")
|
||||
|
||||
self.process = popen_launch_router(
|
||||
self.model,
|
||||
self.base_url,
|
||||
dp_size=1,
|
||||
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
policy="round_robin",
|
||||
api_key="correct_api_key",
|
||||
dp_aware=True,
|
||||
)
|
||||
|
||||
# Test case 1: request without api key should fail
|
||||
with requests.Session() as session:
|
||||
response = session.post(
|
||||
f"{self.base_url}/generate",
|
||||
json={"text": "Kanye west is, ", "temperature": 0},
|
||||
)
|
||||
print(f"status code: {response.status_code}, response: {response.text}")
|
||||
self.assertEqual(
|
||||
response.status_code,
|
||||
401,
|
||||
f"Request without api key should fail with 401 but got status {response.status_code}",
|
||||
)
|
||||
|
||||
# Test case 2: request with invalid api key should fail
|
||||
with requests.Session() as session:
|
||||
response = requests.post(
|
||||
f"{self.base_url}/generate",
|
||||
json={"text": "Kanye west is, ", "temperature": 0},
|
||||
headers={"Authorization": "Bearer 123"},
|
||||
)
|
||||
print(f"status code: {response.status_code}, response: {response.text}")
|
||||
self.assertEqual(
|
||||
response.status_code,
|
||||
401,
|
||||
f"Request without api key should fail with 401 but got status {response.status_code}",
|
||||
)
|
||||
|
||||
# Test case 3: request with correct api key should succeed
|
||||
with requests.Session() as session:
|
||||
response = session.post(
|
||||
f"{self.base_url}/generate",
|
||||
json={"text": "Kanye west is ", "temperature": 0},
|
||||
headers={"Authorization": "Bearer correct_api_key"},
|
||||
)
|
||||
print(f"status code: {response.status_code}, response: {response.text}")
|
||||
self.assertEqual(
|
||||
response.status_code,
|
||||
200,
|
||||
f"Request with correct api key should succeed but got status {response.status_code}",
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user