# Tests for prefill/decode disaggregation combined with hierarchical KV caching
# (HiCache).  File: sglang/test/srt/hicache/test_disaggregation_hicache.py
# (Removed non-Python web-viewer header lines so the module parses.)
import os
import random
import tempfile
import time
import unittest
from typing import Dict
from urllib.parse import urlparse
import requests
from sglang.bench_serving import get_tokenizer
from sglang.test.test_disaggregation_utils import TestDisaggregationBase
from sglang.test.test_utils import (
DEFAULT_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
popen_launch_pd_server,
)
class DisaggregationHiCacheBase(TestDisaggregationBase):
    """Base class for disaggregation-with-HiCache tests.

    Launches a prefill server with hierarchical caching (write-through to a
    file storage backend), a decode server (supplied by subclasses via
    ``start_decode``), and a load balancer in front of both.
    """

    @classmethod
    def setUpClass(cls):
        """Derive server URLs, start prefill/decode servers, and launch the LB."""
        cls.model = DEFAULT_MODEL_NAME_FOR_TEST
        parsed_url = urlparse(DEFAULT_URL_FOR_TEST)
        cls.base_host = parsed_url.hostname
        base_port = str(parsed_url.port)
        # LB keeps the base port; prefill/decode get fixed offsets so the
        # three servers never collide.
        cls.lb_port = base_port
        cls.prefill_port = f"{int(base_port) + 100}"
        cls.decode_port = f"{int(base_port) + 200}"
        cls.prefill_url = f"http://{cls.base_host}:{cls.prefill_port}"
        cls.decode_url = f"http://{cls.base_host}:{cls.decode_port}"
        cls.lb_url = f"http://{cls.base_host}:{cls.lb_port}"
        print(f"{cls.base_host=} {cls.lb_port=} {cls.prefill_port=} {cls.decode_port=}")
        cls.tokenizer = get_tokenizer(cls.model)
        # Shared directory for the file-based HiCache storage backend; both
        # servers receive it via SGLANG_HICACHE_FILE_BACKEND_STORAGE_DIR.
        cls.temp_dir = tempfile.mkdtemp()
        cls.start_prefill()
        cls.start_decode()
        # Block until both servers answer their health endpoint before
        # putting the load balancer in front of them.
        cls.wait_server_ready(cls.prefill_url + "/health")
        cls.wait_server_ready(cls.decode_url + "/health")
        cls.launch_lb()

    @classmethod
    def start_prefill(cls):
        """Launch the prefill server with HiCache write-through to file storage."""
        prefill_args = [
            "--trust-remote-code",
            "--disaggregation-mode",
            "prefill",
            "--tp-size",
            "1",
            "--page-size",
            "64",
            "--enable-hierarchical-cache",
            "--hicache-ratio",
            "1.2",
            "--hicache-size",
            "0",
            "--hicache-write-policy",
            "write_through",
            "--hicache-storage-backend",
            "file",
            "--hicache-storage-prefetch-policy",
            "wait_complete",
            "--mem-fraction-static",
            "0.8",
            "--disaggregation-ib-device",
            "mlx5_roce0",
            "--disaggregation-transfer-backend",
            "mooncake",
        ]
        env = {
            **os.environ,
            "SGLANG_HICACHE_FILE_BACKEND_STORAGE_DIR": cls.temp_dir,
        }
        cls.process_prefill = popen_launch_pd_server(
            cls.model,
            cls.prefill_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=prefill_args,
            env=env,
        )

    @classmethod
    def start_decode(cls):
        """Launch the decode server; subclasses override with concrete flags."""
        pass

    def gen_prompt(self, token_num: int) -> str:
        """Return a prompt built from ``token_num`` randomly sampled vocab tokens."""
        all_available_tokens = list(self.tokenizer.get_vocab().values())
        selected_tokens = random.choices(all_available_tokens, k=token_num)
        return self.tokenizer.decode(selected_tokens)

    def send_request(
        self, prompt: str, max_tokens: int = 100, temperature: float = 0.0
    ) -> Dict:
        """Send a generate request through the LB and return its JSON body.

        Fails the test if the HTTP status is not 200.
        """
        response = requests.post(
            f"{self.lb_url}/generate",
            json={
                "text": prompt,
                "sampling_params": {
                    "temperature": temperature,
                    "max_new_tokens": max_tokens,
                    "ignore_eos": True,
                },
            },
            timeout=60,
        )
        self.assertEqual(
            response.status_code,
            200,
            f"Request failed: {response.status_code} - {response.text}",
        )
        return response.json()

    def trigger_offloading_and_flush(self):
        """Trigger HiCache offloading, then flush the prefill device cache."""
        # A small generation pushes pending KV blocks through the write path.
        self.send_request(self.gen_prompt(1), max_tokens=150)
        # Give the asynchronous offload a moment to land in storage before
        # flushing, so later requests must be served from remote storage.
        time.sleep(2)
        # Fix: requests has no default timeout — an unbounded POST here could
        # hang the whole suite if the prefill server wedges.
        requests.post(self.prefill_url + "/flush_cache", timeout=60)
class TestDisaggregationPrefillWithHiCache(DisaggregationHiCacheBase):
    """Disaggregation test with HiCache enabled only on the prefill side."""

    @classmethod
    def start_decode(cls):
        """Launch a plain decode server (no HiCache offload flags)."""
        server_args = [
            "--trust-remote-code",
            "--disaggregation-mode",
            "decode",
            "--tp-size",
            "1",
            "--page-size",
            "64",
            "--mem-fraction-static",
            "0.8",
            "--base-gpu-id",
            "1",
            "--disaggregation-ib-device",
            "mlx5_roce0",
            "--disaggregation-transfer-backend",
            "mooncake",
        ]
        launch_env = dict(os.environ)
        launch_env["SGLANG_HICACHE_FILE_BACKEND_STORAGE_DIR"] = cls.temp_dir
        cls.process_decode = popen_launch_pd_server(
            cls.model,
            cls.decode_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=server_args,
            env=launch_env,
        )

    def test_prefill_cache_hit(self):
        """A repeated prompt should be served largely from the prefill cache."""
        repeated_prompt = self.gen_prompt(800)
        # Warm the cache with an initial (miss) request.
        self.send_request(repeated_prompt, max_tokens=100)
        # Offload to storage and drop the device cache.
        self.trigger_offloading_and_flush()
        # The repeat must now report a large cached-token count.
        second_response = self.send_request(repeated_prompt, max_tokens=100)
        self.assertGreater(second_response["meta_info"]["cached_tokens"], 700)
class TestDisaggregationDecodeWithHiCache(DisaggregationHiCacheBase):
    """Disaggregation test with HiCache enabled on both prefill and decode sides."""

    @classmethod
    def start_decode(cls):
        """Launch the decode server with KV-cache offload to the file backend."""
        server_args = [
            "--trust-remote-code",
            "--disaggregation-mode",
            "decode",
            "--tp-size",
            "1",
            "--page-size",
            "64",
            "--mem-fraction-static",
            "0.8",
            "--base-gpu-id",
            "1",
            "--disaggregation-ib-device",
            "mlx5_roce0",
            "--disaggregation-transfer-backend",
            "mooncake",
            "--disaggregation-decode-enable-offload-kvcache",
            "--hicache-ratio",
            "1.2",
            "--hicache-size",
            "0",
            "--hicache-storage-backend",
            "file",
            "--hicache-storage-prefetch-policy",
            "wait_complete",
        ]
        launch_env = dict(os.environ)
        launch_env["SGLANG_HICACHE_FILE_BACKEND_STORAGE_DIR"] = cls.temp_dir
        cls.process_decode = popen_launch_pd_server(
            cls.model,
            cls.decode_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=server_args,
            env=launch_env,
        )

    def test_multi_turn_conversation_cache(self):
        """Multi-turn conversation: cached tokens must grow turn over turn."""
        print("=== Multi-turn Conversation Cache Test ===")
        # Turn 1 seeds the conversation context.
        initial_prompt = self.gen_prompt(300)
        first_reply = self.send_request(initial_prompt, max_tokens=200, temperature=0.1)
        current_context = initial_prompt + first_reply["text"]
        previous_cached_tokens = 0
        # Turns 2-4 extend the context; each turn should reuse more cache.
        for turn in range(2, 5):
            print(f"\nTurn {turn}: Continuing from previous context")
            reply = self.send_request(
                current_context, max_tokens=200, temperature=0.1
            )
            cached_tokens = reply["meta_info"]["cached_tokens"]
            print(f"Turn {turn} cached tokens: {cached_tokens}")
            print(f"Improvement: {cached_tokens - previous_cached_tokens} tokens")
            self.assertGreater(
                cached_tokens,
                previous_cached_tokens,
                f"Turn {turn} should have more cached tokens than turn {turn-1}",
            )
            # Grow the context for the next turn.
            current_context += reply["text"]
            previous_cached_tokens = cached_tokens
            # Flush the prefill device cache so the next turn must be served
            # from the offloaded storage, not resident KV.
            # NOTE(review): source indentation was lost; in-loop placement
            # matches the test's intent — confirm against upstream history.
            self.trigger_offloading_and_flush()
# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()