[router] Introduce router integration tests (#10086)
This commit is contained in:
@@ -0,0 +1 @@
|
||||
"""Load balancing integration tests."""
|
||||
@@ -0,0 +1,73 @@
|
||||
import collections
|
||||
import concurrent.futures
|
||||
import uuid
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
|
||||
@pytest.mark.integration
def test_cache_aware_affinity(mock_workers, router_manager):
    """Identical prompts should stick to a single worker under cache_aware.

    With two workers behind the router, repeating the same prompt should be
    routed to one worker almost every time, because the cache-aware policy
    matches the prompt prefix against the worker that already cached it.
    """
    _, worker_urls, _worker_ids = mock_workers(n=2)
    handle = router_manager.start_router(worker_urls=worker_urls, policy="cache_aware")

    # Same body every time: the prompt repetition is what drives affinity.
    payload = {
        "model": "test-model",
        "prompt": "repeated prompt for cache",
        "max_tokens": 1,
        "stream": False,
    }

    hits = collections.Counter()
    with requests.Session() as session:
        for _ in range(12):
            resp = session.post(f"{handle.url}/v1/completions", json=payload)
            assert resp.status_code == 200
            worker = resp.headers.get("X-Worker-Id") or resp.json().get("worker_id")
            hits[worker] += 1

    # Expect strong skew toward one worker (prefix-tree match); majority > 80%.
    assert max(hits.values()) >= 10, hits
|
||||
|
||||
|
||||
@pytest.mark.integration
def test_cache_aware_diverse_prompts_balances(mock_workers, router_manager):
    """Unrelated prompts should spread across workers under cache_aware.

    The thresholds below push the policy toward load-based selection, and
    the per-request latency makes concurrent requests overlap so measured
    load actually differs between workers.
    """
    # Latency makes in-flight requests overlap and influence load-based choice.
    _, worker_urls, _worker_ids = mock_workers(n=3, args=["--latency-ms", "30"])
    handle = router_manager.start_router(
        worker_urls=worker_urls,
        policy="cache_aware",
        extra={
            "cache_threshold": 0.99,
            "balance_abs_threshold": 0,
            "balance_rel_threshold": 1.0,
        },
    )

    def issue_request(_idx):
        # A fresh UUID per request guarantees no prefix overlap at all.
        resp = requests.post(
            f"{handle.url}/v1/completions",
            json={
                "model": "test-model",
                "prompt": str(uuid.uuid4()),
                "max_tokens": 1,
                "stream": False,
            },
            timeout=5,
        )
        assert resp.status_code == 200
        return resp.headers.get("X-Worker-Id") or resp.json().get("worker_id")

    tally = collections.Counter()
    with concurrent.futures.ThreadPoolExecutor(max_workers=16) as pool:
        for worker in pool.map(issue_request, range(40)):
            tally[worker] += 1

    # At least two of the three workers must have served traffic.
    assert sum(1 for served in tally.values() if served > 0) >= 2, tally
|
||||
@@ -0,0 +1,89 @@
|
||||
import collections
|
||||
import concurrent.futures
|
||||
import time
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
|
||||
@pytest.mark.integration
def test_power_of_two_prefers_less_loaded(mock_workers, router_manager):
    """Power-of-two routing should favor the less-loaded worker.

    The router monitors each worker's /get_load and the power_of_two policy
    picks the lighter of two sampled workers from those cached loads, so a
    worker with 200 ms latency (and therefore more in-flight requests) must
    end up serving fewer of the measured requests than a zero-latency one.
    """

    def _payload(prompt):
        # Minimal completion body shared by every request in this test.
        return {
            "model": "test-model",
            "prompt": prompt,
            "max_tokens": 1,
            "stream": False,
        }

    # One slow worker and one fast worker via the fixture factory.
    _procs_slow, urls_slow, ids_slow = mock_workers(n=1, args=["--latency-ms", "200"])
    _procs_fast, urls_fast, ids_fast = mock_workers(n=1, args=["--latency-ms", "0"])
    urls = urls_slow + urls_fast
    ids = ids_slow + ids_fast
    slow_id = ids_slow[0]
    # Fix: `slow_url` was previously never defined, so _direct_load below
    # raised NameError on every call and the error was silently swallowed —
    # the background load on the slow worker never actually happened.
    slow_url = urls_slow[0]

    rh = router_manager.start_router(
        worker_urls=urls,
        policy="power_of_two",
        extra={"worker_startup_check_interval": 1},
    )

    def _prime_call(i):
        # Best-effort warm-up through the router; failures are ignored on
        # purpose — only the resulting load matters, not the responses.
        try:
            requests.post(
                f"{rh.url}/v1/completions",
                json=_payload(f"warm-{i}"),
                timeout=5,
            )
        except Exception:
            pass

    # Prime: fire a burst to create measurable load on the slow worker,
    # then wait for the router's load monitor to take a reading.
    with concurrent.futures.ThreadPoolExecutor(max_workers=32) as ex:
        list(ex.map(_prime_call, range(128)))
    time.sleep(2)

    def _direct_load(i):
        # Best-effort load aimed straight at the slow worker (bypassing the
        # router) to amplify the load difference; failures are ignored.
        try:
            requests.post(
                f"{slow_url}/v1/completions",
                json=_payload(f"bg-{i}"),
                timeout=5,
            )
        except Exception:
            pass

    with concurrent.futures.ThreadPoolExecutor(max_workers=32) as ex:
        list(ex.map(_direct_load, range(128)))
    time.sleep(1)

    def call(i):
        # Measured request through the router; these must all succeed.
        r = requests.post(
            f"{rh.url}/v1/completions",
            json=_payload(f"p{i}"),
            timeout=5,
        )
        assert r.status_code == 200
        return r.headers.get("X-Worker-Id") or r.json().get("worker_id")

    counts = collections.Counter()
    with concurrent.futures.ThreadPoolExecutor(max_workers=32) as ex:
        for wid in ex.map(call, range(200)):
            counts[wid] += 1

    # The slow (more loaded) worker should have received fewer requests.
    fast_worker_id = next(i for i in ids if i != slow_id)
    assert counts[slow_id] < counts[fast_worker_id], counts
|
||||
33
sgl-router/py_test/integration/load_balancing/test_random.py
Normal file
33
sgl-router/py_test/integration/load_balancing/test_random.py
Normal file
@@ -0,0 +1,33 @@
|
||||
import collections
|
||||
import math
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
|
||||
@pytest.mark.integration
def test_random_distribution(mock_workers, router_manager):
    """Random policy should spread requests roughly evenly over 4 workers."""
    _procs, worker_urls, worker_ids = mock_workers(n=4)
    handle = router_manager.start_router(worker_urls=worker_urls, policy="random")

    total = 200
    tally = collections.Counter()
    with requests.Session() as session:
        for idx in range(total):
            resp = session.post(
                f"{handle.url}/v1/completions",
                json={
                    "model": "test-model",
                    "prompt": f"p{idx}",
                    "max_tokens": 1,
                    "stream": False,
                },
            )
            assert resp.status_code == 200
            worker = resp.headers.get("X-Worker-Id") or resp.json().get("worker_id")
            tally[worker] += 1

    # Loose statistical bound: every worker within ±50% of the mean share.
    expected = total / len(worker_ids)
    for worker in worker_ids:
        assert 0.5 * expected <= tally[worker] <= 1.5 * expected, tally
|
||||
@@ -0,0 +1,34 @@
|
||||
import collections
|
||||
import time
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
|
||||
@pytest.mark.integration
def test_round_robin_distribution(mock_workers, router_manager):
    """Round-robin should split 30 requests near-evenly over 3 workers."""
    _procs, worker_urls, worker_ids = mock_workers(n=3)

    handle = router_manager.start_router(worker_urls=worker_urls, policy="round_robin")

    tally = collections.Counter()
    with requests.Session() as session:
        for idx in range(30):
            resp = session.post(
                f"{handle.url}/v1/completions",
                json={
                    "model": "test-model",
                    "prompt": f"hello {idx}",
                    "max_tokens": 1,
                    "stream": False,
                },
            )
            assert resp.status_code == 200
            worker = resp.headers.get("X-Worker-Id") or resp.json().get("worker_id")
            assert worker in worker_ids
            tally[worker] += 1

    # 30 requests over 3 workers -> ideally 10 each; allow ±3 tolerance.
    for worker in worker_ids:
        assert 7 <= tally[worker] <= 13, tally
|
||||
Reference in New Issue
Block a user