[router] Introduce router integration tests (#10086)

This commit is contained in:
Keyang Ru
2025-09-05 18:52:53 -07:00
committed by GitHub
parent db37422c92
commit 21b9a4b435
23 changed files with 1417 additions and 2 deletions

View File

@@ -0,0 +1 @@
"""Load balancing integration tests."""

View File

@@ -0,0 +1,73 @@
import collections
import concurrent.futures
import uuid
import pytest
import requests
@pytest.mark.integration
def test_cache_aware_affinity(mock_workers, router_manager):
    """Cache-aware routing should pin a repeated prompt to one worker.

    With two workers and an identical prompt sent 12 times, the cache tree
    should route the vast majority (>80%) of requests to a single worker.
    """
    _, urls, ids = mock_workers(n=2)
    router = router_manager.start_router(worker_urls=urls, policy="cache_aware")

    payload = {
        "model": "test-model",
        "prompt": "repeated prompt for cache",
        "max_tokens": 1,
        "stream": False,
    }
    hits = collections.Counter()
    with requests.Session() as session:
        for _ in range(12):
            resp = session.post(f"{router.url}/v1/completions", json=payload)
            assert resp.status_code == 200
            worker = resp.headers.get("X-Worker-Id") or resp.json().get("worker_id")
            hits[worker] += 1

    # Expect strong skew toward one worker (prefix-tree match); majority > 80%.
    assert max(hits.values()) >= 10, hits
@pytest.mark.integration
def test_cache_aware_diverse_prompts_balances(mock_workers, router_manager):
    """Unrelated prompts under cache-aware routing should spread across workers.

    Worker latency is raised so concurrent requests overlap in flight,
    letting the load-balancing thresholds (not prefix matches) drive
    worker selection.
    """
    _, urls, ids = mock_workers(n=3, args=["--latency-ms", "30"])
    router = router_manager.start_router(
        worker_urls=urls,
        policy="cache_aware",
        extra={
            "cache_threshold": 0.99,
            "balance_abs_threshold": 0,
            "balance_rel_threshold": 1.0,
        },
    )

    def one_request(_):
        # A fresh UUID guarantees no shared prefix with any earlier prompt.
        resp = requests.post(
            f"{router.url}/v1/completions",
            json={
                "model": "test-model",
                "prompt": str(uuid.uuid4()),
                "max_tokens": 1,
                "stream": False,
            },
            timeout=5,
        )
        assert resp.status_code == 200
        return resp.headers.get("X-Worker-Id") or resp.json().get("worker_id")

    tally = collections.Counter()
    with concurrent.futures.ThreadPoolExecutor(max_workers=16) as pool:
        for worker in pool.map(one_request, range(40)):
            tally[worker] += 1

    # At least two of the three workers must have served traffic.
    assert sum(1 for served in tally.values() if served > 0) >= 2, tally

View File

@@ -0,0 +1,89 @@
import collections
import concurrent.futures
import time
import pytest
import requests
@pytest.mark.integration
def test_power_of_two_prefers_less_loaded(mock_workers, router_manager):
    """Power-of-two routing should send fewer requests to the loaded worker.

    Starts one slow (200 ms latency) and one fast (0 ms) worker. The router
    polls each worker's load on an interval, and the power-of-two policy
    compares the cached loads of two sampled workers — so the slow worker,
    which accumulates in-flight requests, should receive fewer of the
    measured requests than the fast one.
    """
    # Start one slow and one fast worker using the fixture factory.
    _, urls_slow, ids_slow = mock_workers(n=1, args=["--latency-ms", "200"])
    _, urls_fast, ids_fast = mock_workers(n=1, args=["--latency-ms", "0"])
    urls = urls_slow + urls_fast
    ids = ids_slow + ids_fast
    slow_id = ids_slow[0]
    # Bug fix: slow_url was used below but never assigned; the resulting
    # NameError was silently swallowed by _direct_load's except clause, so
    # the direct background load never actually hit the slow worker.
    slow_url = urls_slow[0]

    rh = router_manager.start_router(
        worker_urls=urls,
        policy="power_of_two",
        extra={"worker_startup_check_interval": 1},
    )

    def _post(base_url, prompt):
        # Best-effort load generation; failures must not fail the test.
        try:
            requests.post(
                f"{base_url}/v1/completions",
                json={
                    "model": "test-model",
                    "prompt": prompt,
                    "max_tokens": 1,
                    "stream": False,
                },
                timeout=5,
            )
        except Exception:
            pass

    # Prime: fire a burst through the router to create measurable load on
    # the slow worker, then wait for a monitor tick to refresh cached loads.
    with concurrent.futures.ThreadPoolExecutor(max_workers=32) as ex:
        list(ex.map(lambda i: _post(rh.url, f"warm-{i}"), range(128)))
    time.sleep(2)

    # Apply direct background load on the slow worker to amplify the diff.
    with concurrent.futures.ThreadPoolExecutor(max_workers=32) as ex:
        list(ex.map(lambda i: _post(slow_url, f"bg-{i}"), range(128)))
    time.sleep(1)

    def call(i):
        r = requests.post(
            f"{rh.url}/v1/completions",
            json={
                "model": "test-model",
                "prompt": f"p{i}",
                "max_tokens": 1,
                "stream": False,
            },
            timeout=5,
        )
        assert r.status_code == 200
        return r.headers.get("X-Worker-Id") or r.json().get("worker_id")

    counts = collections.Counter()
    with concurrent.futures.ThreadPoolExecutor(max_workers=32) as ex:
        for wid in ex.map(call, range(200)):
            counts[wid] += 1

    # The slow worker (higher latency/in-flight) should receive fewer requests.
    fast_worker_id = [i for i in ids if i != slow_id][0]
    assert counts[slow_id] < counts[fast_worker_id], counts

View File

@@ -0,0 +1,33 @@
import collections
import math
import pytest
import requests
@pytest.mark.integration
def test_random_distribution(mock_workers, router_manager):
    """Random policy should distribute requests roughly uniformly.

    Sends 200 requests across 4 workers and checks each worker's share is
    within ±50% of the uniform mean — a deliberately loose statistical bound.
    """
    _, urls, ids = mock_workers(n=4)
    router = router_manager.start_router(worker_urls=urls, policy="random")

    total = 200
    tally = collections.Counter()
    with requests.Session() as session:
        for i in range(total):
            resp = session.post(
                f"{router.url}/v1/completions",
                json={
                    "model": "test-model",
                    "prompt": f"p{i}",
                    "max_tokens": 1,
                    "stream": False,
                },
            )
            assert resp.status_code == 200
            worker = resp.headers.get("X-Worker-Id") or resp.json().get("worker_id")
            tally[worker] += 1

    # Simple statistical tolerance: each worker within ±50% of the mean.
    expected = total / len(ids)
    for worker in ids:
        assert 0.5 * expected <= tally[worker] <= 1.5 * expected, tally

View File

@@ -0,0 +1,34 @@
import collections
import time
import pytest
import requests
@pytest.mark.integration
def test_round_robin_distribution(mock_workers, router_manager):
    """Round-robin should spread 30 requests near-evenly over 3 workers.

    Ideally 10 requests land on each worker; a tolerance of ±3 absorbs
    startup/health-check jitter.
    """
    _, urls, ids = mock_workers(n=3)
    router = router_manager.start_router(worker_urls=urls, policy="round_robin")

    tally = collections.Counter()
    with requests.Session() as session:
        for i in range(30):
            resp = session.post(
                f"{router.url}/v1/completions",
                json={
                    "model": "test-model",
                    "prompt": f"hello {i}",
                    "max_tokens": 1,
                    "stream": False,
                },
            )
            assert resp.status_code == 200
            worker = resp.headers.get("X-Worker-Id") or resp.json().get("worker_id")
            assert worker in ids
            tally[worker] += 1

    # Near-even distribution: 30 requests -> ideally 10 each, allow ±3.
    for worker in ids:
        assert 7 <= tally[worker] <= 13, tally