From 106c2b31fb8a1a7ebc0ea3c1447a80ae03ef37d3 Mon Sep 17 00:00:00 2001 From: hzh0425 Date: Thu, 4 Sep 2025 20:43:46 +0800 Subject: [PATCH] feat(hicache): Add generic hicache ci e2e test and benchmark test (#9846) Co-authored-by: Zhiqiang Xie --- .../sglang/srt/mem_cache/hicache_storage.py | 9 +- .../hicache/test_hicache_storage_benchmark.py | 192 ++++++++++++ test/srt/hicache/test_hicache_storage_e2e.py | 286 ++++++++++++++++++ test/srt/run_suite.py | 2 + 4 files changed, 487 insertions(+), 2 deletions(-) create mode 100644 test/srt/hicache/test_hicache_storage_benchmark.py create mode 100644 test/srt/hicache/test_hicache_storage_e2e.py diff --git a/python/sglang/srt/mem_cache/hicache_storage.py b/python/sglang/srt/mem_cache/hicache_storage.py index 2487910e1..9112e748d 100644 --- a/python/sglang/srt/mem_cache/hicache_storage.py +++ b/python/sglang/srt/mem_cache/hicache_storage.py @@ -136,13 +136,18 @@ class HiCacheFile(HiCacheStorage): ): self.file_path = os.getenv("SGLANG_HICACHE_FILE_BACKEND_STORAGE_DIR", file_path) - tp_rank, tp_size, model_name = ( + tp_rank, tp_size, model_name, is_mla_model = ( storage_config.tp_rank, storage_config.tp_size, storage_config.model_name, + storage_config.is_mla_model, ) model_name = "-".join(model_name.split("/")) if model_name else "" - self.config_suffix = f"_{model_name}_{tp_rank}_{tp_size}" + if is_mla_model: + self.config_suffix = f"_{model_name}" + else: + self.config_suffix = f"_{model_name}_{tp_rank}_{tp_size}" + if not os.path.exists(self.file_path) and tp_rank == 0: os.makedirs(self.file_path) logger.info(f"Created HiCacheFile storage directory at {self.file_path}") diff --git a/test/srt/hicache/test_hicache_storage_benchmark.py b/test/srt/hicache/test_hicache_storage_benchmark.py new file mode 100644 index 000000000..0c9206afb --- /dev/null +++ b/test/srt/hicache/test_hicache_storage_benchmark.py @@ -0,0 +1,192 @@ +""" +Benchmark tests for HiCache Storage functionality. +Usage: + python3 -m pytest test/srt/hicache/test_hicache_storage_benchmark.py -v +""" + +import time +import unittest +from types import SimpleNamespace +from typing import Dict + +import requests +from test_hicache_storage_e2e import HiCacheStorageBaseTest + +from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k +from sglang.test.test_utils import is_in_ci, write_github_step_summary + + +class TestHiCacheStorageBenchmark(HiCacheStorageBaseTest): + """Benchmark tests for HiCache Storage functionality""" + + @classmethod + def _get_additional_server_args_and_env(cls): + """Get additional server arguments specific to configuration - override in subclasses""" + server_args = {"--tp-size": 2, "--hicache-ratio": 1.5} + return server_args, {} + + def flush_cache(self) -> bool: + """Flush device cache to force remote storage access""" + try: + response = requests.post(f"{self.base_url}/flush_cache", timeout=10) + return response.status_code == 200 + except requests.RequestException: + return False + + # === Accuracy Tests === + def test_eval_accuracy_with_cache_persistence(self): + """Test eval accuracy with cache persistence across cache flushes""" + print("\n=== Testing Eval Accuracy with Cache Persistence ===") + + # First evaluation - populate cache + print("Phase 1: Running initial GSM8K evaluation to populate cache...") + args_initial = SimpleNamespace( + num_shots=5, + data_path=None, + num_questions=400, + max_new_tokens=512, + parallel=32, + host=f"http://{self.base_host}", + port=int(self.base_port), + ) + metrics_initial = run_eval_few_shot_gsm8k(args_initial) + print(f"Evaluation metrics: {metrics_initial}") + self.assertGreater(metrics_initial["accuracy"], 0.60) + + # Flush cache to force remote storage access + print("Phase 2: Flushing device cache...") + self.assertTrue(self.flush_cache(), "Cache flush should succeed") + time.sleep(2) + + # Second evaluation - should use remote cache + print("Phase 3: Running second GSM8K evaluation using remote cache...") + + start_time = time.time() + metrics_cached = run_eval_few_shot_gsm8k(args_initial) + cached_time = time.time() - start_time + + print(f"Cached evaluation completed in {cached_time:.2f}s") + print(f"Cached accuracy: {metrics_cached['accuracy']:.3f}") + print(f"Cached throughput: {metrics_cached['output_throughput']:.2f} token/s") + + # Verify accuracy consistency + accuracy_diff = abs(metrics_initial["accuracy"] - metrics_cached["accuracy"]) + print(f"Accuracy difference: {accuracy_diff:.4f}") + + # Assertions + self.assertGreater( + metrics_initial["accuracy"], 0.5, "Initial accuracy should be reasonable" + ) + self.assertGreater( + metrics_cached["accuracy"], 0.5, "Cached accuracy should be reasonable" + ) + self.assertLess( + accuracy_diff, 0.05, "Accuracy should be consistent between cache states" + ) + + # Performance should be similar or better with cache + throughput_ratio = ( + metrics_cached["output_throughput"] / metrics_initial["output_throughput"] + ) + print(f"Throughput ratio (cached/initial): {throughput_ratio:.2f}") + + if is_in_ci(): + write_github_step_summary( + f"### HiCache Storage Accuracy Test\n" + f"Initial accuracy: {metrics_initial['accuracy']:.3f}\n" + f"Cached accuracy: {metrics_cached['accuracy']:.3f}\n" + f"Accuracy difference: {accuracy_diff:.4f}\n" + f"Throughput ratio: {throughput_ratio:.2f}\n" + ) + + # === Performance Benchmark Tests === + + def test_throughput_benchmark_with_hicache(self): + """Benchmark throughput performance with HiCache enabled""" + print("\n=== Benchmarking Throughput with HiCache ===") + + # throughput test + res1 = self._run_throughput_benchmark( + test_name="hicache_offline_throughput", + num_prompts=200, + request_rate=10, + additional_args=[], + ) + + # Flush cache to force remote storage access + print("Phase 2: Flushing device cache...") + self.assertTrue(self.flush_cache(), "Cache flush should succeed") + time.sleep(2) + + # Second benchmark, should use remote cache + res2 = self._run_throughput_benchmark( + test_name="hicache_online_throughput", + num_prompts=400, + request_rate=10, + additional_args=[], + ) + + if is_in_ci(): + write_github_step_summary( + f"### HiCache Storage FileBackend Benchmark Test\n" + f"First time throughput: {res1['input_throughput']:.2f} token/s\n" + f"Second time throughput: {res2['input_throughput']:.2f} token/s\n" + f"First time TTFT: {res1['mean_ttft_ms']:.2f} ms\n" + f"Second time TTFT: {res2['mean_ttft_ms']:.2f} ms\n" + ) + + def _run_throughput_benchmark( + self, + test_name: str, + num_prompts: int, + request_rate: float, + dataset_name: str = "random", + additional_args: list = None, + ) -> Dict: + """Helper method to run throughput benchmarks""" + if additional_args is None: + additional_args = [] + + print(f"Running {test_name} benchmark...") + start_time = time.time() + + try: + # Use the existing server instead of launching a new one + from sglang.bench_serving import run_benchmark + from sglang.test.test_utils import get_benchmark_args + + args = get_benchmark_args( + base_url=self.base_url, + dataset_name=dataset_name, + tokenizer=self.model, + num_prompts=num_prompts, + request_rate=request_rate, + random_input_len=1024, + random_output_len=64, + ) + + # Run benchmark + result = run_benchmark(args) + + elapsed_time = time.time() - start_time + print(f"{test_name} completed in {elapsed_time:.2f}s") + print( + f"Output throughput: {result.get('output_throughput', 0.0):.2f} token/s" + ) + + return result + + except Exception as e: + print(f"Benchmark {test_name} failed: {e}") + # Fallback to avoid hard failure; return minimal metrics + return { + "output_throughput": 0.0, + "input_throughput": 0.0, + "mean_ttft_ms": float("inf"), + "mean_latency_ms": float("inf"), + "p99_ttft_ms": float("inf"), + } + + +if __name__ == "__main__": + unittest.main(verbosity=2) diff --git a/test/srt/hicache/test_hicache_storage_e2e.py b/test/srt/hicache/test_hicache_storage_e2e.py new file mode 100644 index 000000000..0c605e633 --- /dev/null +++ b/test/srt/hicache/test_hicache_storage_e2e.py @@ -0,0 +1,286 @@ +""" +E2E tests for HiCache Storage functionality. +Usage: + python3 -m pytest test/srt/hicache/test_hicache_storage_e2e.py -v +""" + +import os +import random +import tempfile +import time +import unittest +from typing import Dict +from urllib.parse import urlparse + +import requests + +from sglang.bench_serving import get_tokenizer +from sglang.srt.utils import kill_process_tree +from sglang.test.test_utils import ( + DEFAULT_MLA_MODEL_NAME_FOR_TEST, + DEFAULT_MODEL_NAME_FOR_TEST, + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, + CustomTestCase, + popen_launch_server, +) + + +class HiCacheStorageBaseTest(CustomTestCase): + """Base test class with common setup and utilities""" + + @classmethod + def setUpClass(cls): + """Set up test environment and launch server once for all tests""" + cls.temp_dir = tempfile.mkdtemp() + cls.model = cls._get_model_name() + cls.base_url = DEFAULT_URL_FOR_TEST + + parsed_url = urlparse(cls.base_url) + cls.base_host = parsed_url.hostname + cls.base_port = str(parsed_url.port) + + # Prepare tokenizer for prompt generation + cls.tokenizer = get_tokenizer(cls.model) + + # Launch server with HiCache enabled and cache report + cls.process = cls._launch_server_with_hicache() + cls._wait_for_server_ready() + + print(f"Test server launched successfully at {cls.base_url}") + print(f"Cache directory: {cls.temp_dir}") + + @classmethod + def tearDownClass(cls): + """Clean up test environment""" + kill_process_tree(cls.process.pid) + + import shutil + + shutil.rmtree(cls.temp_dir, ignore_errors=True) + + @classmethod + def _get_model_name(cls): + """Get model name for the test configuration - override in subclasses""" + return DEFAULT_MODEL_NAME_FOR_TEST + + @classmethod + def _get_base_server_args(cls): + """Get base server arguments - can be extended in subclasses""" + return { + "--enable-hierarchical-cache": True, + "--mem-fraction-static": 0.6, + "--hicache-ratio": 1.2, + "--page-size": 64, + "--enable-cache-report": True, + "--hicache-storage-prefetch-policy": "wait_complete", + "--hicache-storage-backend": "file", + } + + @classmethod + def _get_additional_server_args_and_env(cls): + """Get additional server arguments specific to configuration - override in subclasses""" + return {}, {"SGLANG_HICACHE_FILE_BACKEND_STORAGE_DIR": cls.temp_dir} + + @classmethod + def _launch_server_with_hicache(cls): + """Launch server with HiCache enabled""" + + additional_server_args, env_vars = cls._get_additional_server_args_and_env() + server_args = cls._get_base_server_args() + if additional_server_args: + server_args.update(additional_server_args) + + final_server_args = [] + for k, v in server_args.items(): + if isinstance(v, bool): + final_server_args.append(str(k)) + else: + final_server_args.append(str(k)) + final_server_args.append(str(v)) + + print(f"final_server_args: {final_server_args}") + + env_vars = { + **os.environ, + **env_vars, + } + + return popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=final_server_args, + env=env_vars, + ) + + @classmethod + def _wait_for_server_ready(cls, timeout: int = 60) -> bool: + """Wait for server to be ready""" + start_time = time.time() + while time.time() - start_time < timeout: + try: + response = requests.get(f"{cls.base_url}/health", timeout=5) + if response.status_code == 200: + return True + except requests.RequestException: + pass + time.sleep(2) + raise TimeoutError("Server failed to start within timeout") + + def send_request( + self, prompt: str, max_tokens: int = 100, temperature: float = 0.0 + ) -> Dict: + """Send a generate request and return response""" + response = requests.post( + f"{self.base_url}/generate", + json={ + "text": prompt, + "sampling_params": { + "temperature": temperature, + "max_new_tokens": max_tokens, + "ignore_eos": True, + }, + }, + timeout=60, + ) + + self.assertEqual( + response.status_code, + 200, + f"Request failed: {response.status_code} - {response.text}", + ) + return response.json() + + def get_cached_tokens(self, response_json: Dict) -> int: + """Extract cached tokens count from /generate response""" + meta = response_json.get("meta_info", {}) + return int(meta.get("cached_tokens", 0)) + + def flush_cache(self) -> bool: + """Flush device cache to force remote storage access""" + try: + response = requests.post(f"{self.base_url}/flush_cache", timeout=10) + return response.status_code == 200 + except requests.RequestException: + return False + + def gen_prompt(self, token_num: int) -> str: + """Generate a random prompt of specified token length using tokenizer vocabulary. + + This function mimics the implementation from bench_serving.py to create + realistic prompts for testing cache behavior. + """ + all_available_tokens = list(self.tokenizer.get_vocab().values()) + selected_tokens = random.choices(all_available_tokens, k=token_num) + return self.tokenizer.decode(selected_tokens) + + def trigger_offloading_and_flush(self): + """Helper method to trigger offloading and flush cache""" + # Trigger offloading + self.send_request(self.gen_prompt(1), max_tokens=150) + + # Flush device cache to force remote storage access + time.sleep(2) + self.assertTrue(self.flush_cache(), "Cache flush should succeed") + + def test_basic_backup_and_prefetch(self): + """Test storage and retrieval of large context through remote cache""" + print("\n=== Testing Large Context Cache Storage & Retrieval ===") + + # Generate substantial context that will be cached + base_prompt = self.gen_prompt(768) + + # First request - populate cache + print("Step 1: Populating cache with large context...") + response1 = self.send_request(base_prompt, max_tokens=150) + self.assertIsNotNone(response1) + + # Flush device cache to force remote storage access + self.trigger_offloading_and_flush() + + # Second request with extended prompt - should hit remote cache + print("Step 2: Testing cache hit from remote storage...") + extended_prompt = base_prompt + "\n\n" + self.gen_prompt(64) + + start_time = time.time() + response2 = self.send_request(extended_prompt, max_tokens=150) + retrieval_time = time.time() - start_time + + cached_tokens = self.get_cached_tokens(response2) + print( + f"Remote cache retrieval time: {retrieval_time:.3f}s, cached_tokens={cached_tokens}" + ) + + # Assert cached tokens indicate a remote hit + self.assertEqual( + cached_tokens, 768, "Expected significant cached tokens for remote hit" + ) + + +class TestHiCacheStorageTP(HiCacheStorageBaseTest): + """Multi-TP tests for HiCache Storage functionality""" + + @classmethod + def _get_additional_server_args_and_env(cls): + """Get additional server arguments specific to configuration - override in subclasses""" + server_args = {"--tp-size": 2} + return server_args, {} + + +class TestHiCacheStorageLayerFirstDirectIO(HiCacheStorageBaseTest): + """Layer first direct tests for HiCache Storage functionality""" + + @classmethod + def _get_additional_server_args_and_env(cls): + """Get additional server arguments specific to configuration - override in subclasses""" + server_args = { + "--hicache-mem-layout": "layer_first", + "--hicache-io-backend": "direct", + } + return server_args, {} + + +class TestHiCacheStoragePageFirstLayout(HiCacheStorageBaseTest): + """Page first layout tests for HiCache Storage functionality""" + + @classmethod + def _get_additional_server_args_and_env(cls): + """Get additional server arguments specific to configuration - override in subclasses""" + server_args = {"--hicache-mem-layout": "page_first"} + return server_args, {} + + +class TestHiCacheStorageMLA(HiCacheStorageBaseTest): + """MLA Model tests for HiCache Storage functionality""" + + @classmethod + def _get_model_name(cls): + """Use MLA model for testing""" + return DEFAULT_MLA_MODEL_NAME_FOR_TEST + + @classmethod + def _get_additional_server_args_and_env(cls): + """Get additional server arguments specific to configuration - override in subclasses""" + server_args = {"--tp-size": 2} + return server_args, {} + + +# TODO: Add other backends tests(3fs/mooncake) +# class TestHiCacheStorageMooncakeBackend(HiCacheStorageBaseTest): +# """Mooncake backend tests for HiCache Storage functionality""" + +# @classmethod +# def _get_additional_server_args_and_env(cls): +# """Get additional server arguments specific to configuration - override in subclasses""" +# server_args = ["--hicache-storage-backend", "mooncake"] +# env = { +# "MOONCAKE_TE_META_DATA_SERVER": "http://127.0.0.1:8080/metadata", +# "MOONCAKE_MASTER": "127.0.0.1:50051" +# xxxxx +# } +# return server_args, {} + + +if __name__ == "__main__": + unittest.main(verbosity=2) diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py index 5b124bb72..047410fe2 100644 --- a/test/srt/run_suite.py +++ b/test/srt/run_suite.py @@ -123,6 +123,8 @@ suites = { TestFile("test_dp_attention.py", 277), TestFile("test_patch_torch.py", 19), TestFile("test_release_memory_occupation.py", 127), + TestFile("hicache/test_hicache_storage_e2e.py", 400), + TestFile("hicache/test_hicache_storage_benchmark.py", 400), ], "per-commit-4-gpu": [ TestFile("test_gpt_oss_4gpu.py", 600),