Files
sglang/sgl-router/py_test/integration/load_balancing/test_cache_aware.py

74 lines
2.3 KiB
Python

import collections
import concurrent.futures
import uuid
import pytest
import requests
@pytest.mark.integration
def test_cache_aware_affinity(mock_workers, router_manager):
    """Check that a repeated prompt keeps being routed to the same worker.

    Starts two mock workers behind a router configured with the
    ``cache_aware`` policy, posts the same completion request 12 times over
    one HTTP session, and asserts that at least 10 of the responses were
    attributed to a single worker.
    """
    # Two workers; same prompt should stick to one due to cache tree
    _, urls, _ = mock_workers(n=2)
    rh = router_manager.start_router(worker_urls=urls, policy="cache_aware")

    # The payload is identical for every request, so build it once.
    payload = {
        "model": "test-model",
        "prompt": "repeated prompt for cache",
        "max_tokens": 1,
        "stream": False,
    }

    counts = collections.Counter()
    with requests.Session() as s:
        for _ in range(12):
            # Bound the wait so a stalled router fails the test instead of
            # hanging the whole run.
            r = s.post(f"{rh.url}/v1/completions", json=payload, timeout=5)
            assert r.status_code == 200
            wid = r.headers.get("X-Worker-Id") or r.json().get("worker_id")
            # Without this check, responses lacking any worker id would all
            # be tallied under None and satisfy the skew assertion vacuously.
            assert wid is not None, "response did not identify a worker"
            counts[wid] += 1

    # Expect strong skew toward one worker (tree match); majority > 80%
    top = max(counts.values())
    assert top >= 10, counts
@pytest.mark.integration
def test_cache_aware_diverse_prompts_balances(mock_workers, router_manager):
    """Check that unrelated prompts get spread over more than one worker.

    Starts three mock workers (each launched with ``--latency-ms 30``) behind
    a ``cache_aware`` router, fires 40 completion requests with random UUID
    prompts from a 16-thread pool, and asserts that at least two distinct
    workers answered.
    """
    # Add latency so concurrent requests overlap and influence load-based selection
    _, urls, _ = mock_workers(n=3, args=["--latency-ms", "30"])
    rh = router_manager.start_router(
        worker_urls=urls,
        policy="cache_aware",
        extra={
            "cache_threshold": 0.99,
            "balance_abs_threshold": 0,
            "balance_rel_threshold": 1.0,
        },
    )

    def call(i):
        """Send one request with a unique prompt and return the worker id."""
        # Use diverse, unrelated prompts to avoid prefix matches entirely
        prompt = str(uuid.uuid4())
        r = requests.post(
            f"{rh.url}/v1/completions",
            json={
                "model": "test-model",
                "prompt": prompt,
                "max_tokens": 1,
                "stream": False,
            },
            timeout=5,
        )
        assert r.status_code == 200
        wid = r.headers.get("X-Worker-Id") or r.json().get("worker_id")
        # A missing id would otherwise be tallied under None and be counted
        # as if it were a participating worker.
        assert wid is not None, "response did not identify a worker"
        return wid

    counts = collections.Counter()
    with concurrent.futures.ThreadPoolExecutor(max_workers=16) as ex:
        # Iterating the map results re-raises any assertion failure that
        # occurred inside a worker thread.
        for wid in ex.map(call, range(40)):
            counts[wid] += 1

    # Expect participation of at least two workers. Every key in `counts` was
    # incremented at least once, so the number of keys is the number of
    # workers that served a request.
    assert len(counts) >= 2, counts