74 lines
2.3 KiB
Python
74 lines
2.3 KiB
Python
import collections
|
|
import concurrent.futures
|
|
import uuid
|
|
|
|
import pytest
|
|
import requests
|
|
|
|
|
|
@pytest.mark.integration
def test_cache_aware_affinity(mock_workers, router_manager):
    """Repeated identical prompts should stick to one worker under ``cache_aware``.

    Starts two mock workers behind a router using the ``cache_aware`` policy,
    sends the same non-streaming completion request 12 times over a single
    HTTP session, and checks that at least 10 of them (> 80%) were served by
    the same worker.
    """
    # Two workers; same prompt should stick to one due to cache tree
    _, urls, _ = mock_workers(n=2)
    rh = router_manager.start_router(worker_urls=urls, policy="cache_aware")

    num_requests = 12
    counts = collections.Counter()
    with requests.Session() as s:
        for _ in range(num_requests):
            r = s.post(
                f"{rh.url}/v1/completions",
                json={
                    "model": "test-model",
                    "prompt": "repeated prompt for cache",
                    "max_tokens": 1,
                    "stream": False,
                },
                # Bound each request so a stuck router fails the test instead
                # of hanging the suite.
                timeout=5,
            )
            assert r.status_code == 200
            # The serving worker is read from the header, falling back to the body.
            wid = r.headers.get("X-Worker-Id") or r.json().get("worker_id")
            # Without this check, a missing id would bucket every response
            # under None and trivially satisfy the skew assertion below.
            assert wid is not None, r.text
            counts[wid] += 1

    # Expect strong skew toward one worker (tree match); 10 of 12 is > 80%.
    top = max(counts.values())
    assert top >= 10, counts
|
|
|
|
|
|
@pytest.mark.integration
def test_cache_aware_diverse_prompts_balances(mock_workers, router_manager):
    """Unrelated prompts sent concurrently should spread over several workers.

    Starts three mock workers with 30 ms of added latency behind a
    ``cache_aware`` router. It then fires 40 completion requests, each with a
    random UUID as the prompt, from a pool of 16 threads. The test checks that
    at least two distinct workers served requests.
    """
    # Add latency so concurrent requests overlap and influence load-based selection
    _, urls, _ = mock_workers(n=3, args=["--latency-ms", "30"])
    rh = router_manager.start_router(
        worker_urls=urls,
        policy="cache_aware",
        # NOTE(review): these values appear chosen to make load balancing,
        # not prefix matching, drive selection; confirm against router config docs.
        extra={
            "cache_threshold": 0.99,
            "balance_abs_threshold": 0,
            "balance_rel_threshold": 1.0,
        },
    )

    def call(_i):
        """Send one request with a fresh random prompt; return the serving worker id."""
        # Use diverse, unrelated prompts to avoid prefix matches entirely
        prompt = str(uuid.uuid4())
        # Plain requests.post (no shared Session) since this runs on many threads.
        r = requests.post(
            f"{rh.url}/v1/completions",
            json={
                "model": "test-model",
                "prompt": prompt,
                "max_tokens": 1,
                "stream": False,
            },
            timeout=5,
        )
        assert r.status_code == 200
        wid = r.headers.get("X-Worker-Id") or r.json().get("worker_id")
        # A missing id would otherwise be counted as a distinct "worker" (None)
        # and could make the participation check pass spuriously.
        assert wid is not None, r.text
        return wid

    # Exceptions raised inside call() are re-raised here while ex.map is consumed.
    with concurrent.futures.ThreadPoolExecutor(max_workers=16) as ex:
        counts = collections.Counter(ex.map(call, range(40)))

    # Expect participation of at least two workers. Counter only holds keys it
    # has seen, so len() is the number of workers that served a request.
    assert len(counts) >= 2, counts
|