diff --git a/python/sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py b/python/sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py new file mode 100644 index 000000000..c7a485fa0 --- /dev/null +++ b/python/sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py @@ -0,0 +1,164 @@ +import logging +import os +import threading +from abc import ABC, abstractmethod +from typing import List + +import torch + + +class Hf3fsClient(ABC): + """Abstract interface for HF3FS clients.""" + + @abstractmethod + def __init__(self, path: str, size: int, bytes_per_page: int, entries: int): + """Initialize the HF3FS client. + + Args: + path: File path for storage + size: Total size of storage file + bytes_per_page: Bytes per page + entries: Number of entries for batch operations + """ + pass + + @abstractmethod + def batch_read(self, offsets: List[int], tensors: List[torch.Tensor]) -> List[int]: + """Batch read from storage.""" + pass + + @abstractmethod + def batch_write(self, offsets: List[int], tensors: List[torch.Tensor]) -> List[int]: + """Batch write to storage.""" + pass + + @abstractmethod + def check(self, offsets: List[int], tensors: List[torch.Tensor]) -> None: + """Validate batch operation parameters.""" + pass + + @abstractmethod + def get_size(self) -> int: + """Get total storage size.""" + pass + + @abstractmethod + def close(self) -> None: + """Close the client and cleanup resources.""" + pass + + @abstractmethod + def flush(self) -> None: + """Flush data to disk.""" + pass + + +logger = logging.getLogger(__name__) + + +class Hf3fsMockClient(Hf3fsClient): + """Mock implementation of Hf3fsClient for CI testing purposes.""" + + def __init__(self, path: str, size: int, bytes_per_page: int, entries: int): + """Initialize mock HF3FS client.""" + self.path = path + self.size = size + self.bytes_per_page = bytes_per_page + self.entries = entries + + # Create directory if it doesn't exist + os.makedirs(os.path.dirname(self.path), exist_ok=True) + + # Create and initialize the file + 
self.file = os.open(self.path, os.O_RDWR | os.O_CREAT) + os.ftruncate(self.file, size) + + logger.info( + f"Hf3fsMockClient initialized: path={path}, size={size}, " + f"bytes_per_page={bytes_per_page}, entries={entries}" + ) + + def batch_read(self, offsets: List[int], tensors: List[torch.Tensor]) -> List[int]: + """Batch read from mock storage.""" + self.check(offsets, tensors) + + results = [] + + for offset, tensor in zip(offsets, tensors): + size = tensor.numel() * tensor.itemsize + + try: + os.lseek(self.file, offset, os.SEEK_SET) + bytes_read = os.read(self.file, size) + + if len(bytes_read) == size: + # Convert bytes to tensor and copy to target + bytes_tensor = torch.frombuffer(bytes_read, dtype=torch.uint8) + typed_tensor = bytes_tensor.view(tensor.dtype).view(tensor.shape) + tensor.copy_(typed_tensor) + results.append(size) + else: + logger.warning( + f"Short read: expected {size}, got {len(bytes_read)}" + ) + results.append(len(bytes_read)) + + except Exception as e: + logger.error(f"Error reading from offset {offset}: {e}") + results.append(0) + + return results + + def batch_write(self, offsets: List[int], tensors: List[torch.Tensor]) -> List[int]: + """Batch write to mock storage.""" + self.check(offsets, tensors) + + results = [] + + for offset, tensor in zip(offsets, tensors): + size = tensor.numel() * tensor.itemsize + + try: + # Convert tensor to bytes and write directly to file + tensor_bytes = tensor.contiguous().view(torch.uint8).flatten() + data = tensor_bytes.numpy().tobytes() + + os.lseek(self.file, offset, os.SEEK_SET) + bytes_written = os.write(self.file, data) + + if bytes_written == size: + results.append(size) + else: + logger.warning(f"Short write: expected {size}, got {bytes_written}") + results.append(bytes_written) + + except Exception as e: + logger.error(f"Error writing to offset {offset}: {e}") + results.append(0) + + return results + + def check(self, offsets: List[int], tensors: List[torch.Tensor]) -> None: + """Validate batch 
operation parameters.""" + pass + + def get_size(self) -> int: + """Get total storage size.""" + return self.size + + def close(self) -> None: + """Close the mock client and cleanup resources.""" + try: + if hasattr(self, "file") and self.file >= 0: + os.close(self.file) + self.file = -1 # Mark as closed + logger.info(f"MockHf3fsClient closed: {self.path}") + except Exception as e: + logger.error(f"Error closing MockHf3fsClient: {e}") + + def flush(self) -> None: + """Flush data to disk.""" + try: + os.fsync(self.file) + except Exception as e: + logger.error(f"Error flushing MockHf3fsClient: {e}") diff --git a/python/sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py b/python/sglang/srt/mem_cache/storage/hf3fs/hf3fs_usrbio_client.py similarity index 96% rename from python/sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py rename to python/sglang/srt/mem_cache/storage/hf3fs/hf3fs_usrbio_client.py index 399a90118..480c18ed1 100644 --- a/python/sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +++ b/python/sglang/srt/mem_cache/storage/hf3fs/hf3fs_usrbio_client.py @@ -9,6 +9,8 @@ from typing import List import torch from torch.utils.cpp_extension import load +from sglang.srt.mem_cache.storage.hf3fs.hf3fs_client import Hf3fsClient + root = Path(__file__).parent.resolve() hf3fs_utils = load(name="hf3fs_utils", sources=[f"{root}/hf3fs_utils.cpp"]) @@ -51,7 +53,9 @@ def wsynchronized(): return _decorator -class Hf3fsClient: +class Hf3fsUsrBioClient(Hf3fsClient): + """HF3FS client implementation using usrbio.""" + def __init__(self, path: str, size: int, bytes_per_page: int, entries: int): if not HF3FS_AVAILABLE: raise ImportError( diff --git a/python/sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py b/python/sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py index 7f64eb837..9595e7204 100644 --- a/python/sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +++ b/python/sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py @@ -13,7 +13,7 @@ from typing import Any, List, 
Optional, Tuple import torch from sglang.srt.mem_cache.hicache_storage import HiCacheStorage, HiCacheStorageConfig -from sglang.srt.mem_cache.storage.hf3fs.client_hf3fs import Hf3fsClient +from sglang.srt.mem_cache.storage.hf3fs.hf3fs_client import Hf3fsClient from sglang.srt.metrics.collector import StorageMetrics logger = logging.getLogger(__name__) @@ -114,6 +114,33 @@ def synchronized(): return _decorator +def create_hf3fs_client( + path: str, size: int, bytes_per_page: int, entries: int, use_mock: bool = False +) -> Hf3fsClient: + """Factory function to create appropriate HF3FS client. + + Args: + path: File path for storage + size: Total size of storage file + bytes_per_page: Bytes per page + entries: Number of entries for batch operations + use_mock: Whether to use mock client instead of real usrbio client + + Returns: + Hf3fsMockClient if use_mock is True, otherwise Hf3fsUsrBioClient. + """ + if use_mock: + from sglang.srt.mem_cache.storage.hf3fs.hf3fs_client import Hf3fsMockClient + + logger.info("Using Hf3fsMockClient for testing") + return Hf3fsMockClient(path, size, bytes_per_page, entries) + else: + from sglang.srt.mem_cache.storage.hf3fs.hf3fs_usrbio_client import ( + Hf3fsUsrBioClient, + ) + + return Hf3fsUsrBioClient(path, size, bytes_per_page, entries) + + class HiCacheHF3FS(HiCacheStorage): """HiCache backend that stores KV cache pages in HF3FS files.""" @@ -131,6 +158,7 @@ class HiCacheHF3FS(HiCacheStorage): metadata_client: Hf3fsMetadataInterface, is_mla_model: bool = False, is_page_first_layout: bool = False, + use_mock_client: bool = False, ): self.rank = rank self.file_path = file_path @@ -159,8 +187,12 @@ class HiCacheHF3FS(HiCacheStorage): self.ac = AtomicCounter(self.numjobs) self.clients = [ - Hf3fsClient( - self.file_path, self.file_size, self.bytes_per_page, self.entries + create_hf3fs_client( + self.file_path, + self.file_size, + self.bytes_per_page, + self.entries, + use_mock_client, ) for _ in range(numjobs) ] @@ -202,14 +234,24 @@ class HiCacheHF3FS(HiCacheStorage): Hf3fsLocalMetadataClient, 
) + use_mock_client = False if storage_config is not None: rank, is_mla_model, is_page_first_layout = ( storage_config.tp_rank, storage_config.is_mla_model, storage_config.is_page_first_layout, ) + + if storage_config.extra_config is not None: + use_mock_client = storage_config.extra_config.get( + "use_mock_hf3fs_client", False + ) else: - rank, is_mla_model, is_page_first_layout = 0, False, False + rank, is_mla_model, is_page_first_layout = ( + 0, + False, + False, + ) mla_unsupported_msg = f"MLA model is not supported without global metadata server, please refer to https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/mem_cache/storage/hf3fs/docs/deploy_sglang_3fs_multinode.md" @@ -228,6 +270,7 @@ class HiCacheHF3FS(HiCacheStorage): dtype=dtype, metadata_client=Hf3fsLocalMetadataClient(), is_page_first_layout=is_page_first_layout, + use_mock_client=use_mock_client, ) try: @@ -277,6 +320,7 @@ class HiCacheHF3FS(HiCacheStorage): metadata_client=metadata_client, is_mla_model=is_mla_model, is_page_first_layout=is_page_first_layout, + use_mock_client=use_mock_client, ) def get( diff --git a/test/srt/hicache/test_hicache_storage_3fs_backend.py b/test/srt/hicache/test_hicache_storage_3fs_backend.py new file mode 100644 index 000000000..d0f519075 --- /dev/null +++ b/test/srt/hicache/test_hicache_storage_3fs_backend.py @@ -0,0 +1,135 @@ +""" +Benchmark tests for HiCache Storage with 3FS backend. 
+Usage: + python3 -m pytest test/srt/hicache/test_hicache_storage_3fs_backend.py -v +""" + +import json +import os +import time +import unittest +from types import SimpleNamespace + +from test_hicache_storage_file_backend import HiCacheStorageBaseMixin + +from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k +from sglang.test.test_utils import CustomTestCase + + +class HiCacheStorage3FSBackendBaseMixin(HiCacheStorageBaseMixin): + """Base mixin class with common setup and utilities""" + + @classmethod + def _get_additional_server_args_and_env(cls): + """Get additional server arguments specific to configuration - override in subclasses""" + # Create a temporary JSON config file for HF3FS + hf3fs_config = { + "file_path_prefix": os.path.join(cls.temp_dir, "hicache"), + "file_size": 1024 * 1024 * 1024 * 2, + "numjobs": 2, + "entries": 8, + "use_mock_hf3fs_client": True, + } + + # Write config to temporary file + config_file = os.path.join(cls.temp_dir, "hf3fs_config.json") + with open(config_file, "w") as f: + json.dump(hf3fs_config, f, indent=2) + + server_args = { + "--tp-size": 1, + "--hicache-ratio": 1.2, + "--hicache-storage-backend": "hf3fs", + "--hicache-storage-backend-extra-config": json.dumps(hf3fs_config), + } + + # Set the environment variable to point to our config file + env_vars = { + "SGLANG_HICACHE_HF3FS_CONFIG_PATH": config_file, + } + + return server_args, env_vars + + +class TestHf3fsBackendLayerFirstLayout( + HiCacheStorage3FSBackendBaseMixin, CustomTestCase +): + """Layer first layout tests for HiCache-Hf3fs backend""" + + @classmethod + def _get_additional_server_args_and_env(cls): + """Get additional server arguments specific to configuration - override in subclasses""" + server_args, env_vars = super()._get_additional_server_args_and_env() + server_args["--hicache-mem-layout"] = "layer_first" + server_args["--hicache-io-backend"] = "direct" + return server_args, env_vars + + +class TestHf3fsBackendPageFirstLayout( + 
HiCacheStorage3FSBackendBaseMixin, CustomTestCase +): + """Page first layout tests for HiCache-Hf3fs backend""" + + @classmethod + def _get_additional_server_args_and_env(cls): + """Get additional server arguments specific to configuration - override in subclasses""" + server_args, env_vars = super()._get_additional_server_args_and_env() + server_args["--hicache-mem-layout"] = "page_first" + return server_args, env_vars + + +class TestHf3fsBackendAccuracy(HiCacheStorage3FSBackendBaseMixin, CustomTestCase): + """Accuracy tests for HiCache-Hf3fs backend""" + + @classmethod + def _get_additional_server_args_and_env(cls): + """Get additional server arguments specific to configuration - override in subclasses""" + server_args, env_vars = super()._get_additional_server_args_and_env() + server_args["--hicache-ratio"] = 1.5 + server_args["--tp-size"] = 2 + return server_args, env_vars + + def test_eval_accuracy(self): + """Test eval accuracy with cache persistence across cache flushes""" + print("\n=== Testing Eval Accuracy with Cache Persistence ===") + + # First evaluation - populate cache + print("Phase 1: Running initial GSM8K evaluation to populate cache...") + args_initial = SimpleNamespace( + num_shots=5, + data_path=None, + num_questions=50, + max_new_tokens=512, + parallel=10, + host=f"http://{self.base_host}", + port=int(self.base_port), + ) + metrics_initial = run_eval_few_shot_gsm8k(args_initial) + + # Flush cache to force remote storage access + print("Phase 2: Flushing device cache...") + self.assertTrue(self.flush_cache(), "Cache flush should succeed") + time.sleep(2) + + # Second evaluation - should use remote cache + print("Phase 3: Running second GSM8K evaluation using remote cache...") + metrics_cached = run_eval_few_shot_gsm8k(args_initial) + + # Verify accuracy consistency + accuracy_diff = abs(metrics_initial["accuracy"] - metrics_cached["accuracy"]) + print(f"Accuracy difference: {accuracy_diff:.4f}") + + # Assertions + self.assertGreater( + 
metrics_initial["accuracy"], 0.6, "Initial accuracy should be reasonable" + ) + self.assertGreater( + metrics_cached["accuracy"], 0.6, "Cached accuracy should be reasonable" + ) + self.assertLess( + accuracy_diff, 0.05, "Accuracy should be consistent between cache states" + ) + + +if __name__ == "__main__": + unittest.main(verbosity=2) diff --git a/test/srt/hicache/test_hicache_storage_benchmark.py b/test/srt/hicache/test_hicache_storage_benchmark.py deleted file mode 100644 index 0c9206afb..000000000 --- a/test/srt/hicache/test_hicache_storage_benchmark.py +++ /dev/null @@ -1,192 +0,0 @@ -""" -Benchmark tests for HiCache Storage functionality. -Usage: - python3 -m pytest test/srt/hicache/test_hicache_storage_benchmark.py -v -""" - -import time -import unittest -from types import SimpleNamespace -from typing import Dict - -import requests -from test_hicache_storage_e2e import HiCacheStorageBaseTest - -from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k -from sglang.test.test_utils import is_in_ci, write_github_step_summary - - -class TestHiCacheStorageBenchmark(HiCacheStorageBaseTest): - """Benchmark tests for HiCache Storage functionality""" - - @classmethod - def _get_additional_server_args_and_env(cls): - """Get additional server arguments specific to configuration - override in subclasses""" - server_args = {"--tp-size": 2, "--hicache-ratio": 1.5} - return server_args, {} - - def flush_cache(self) -> bool: - """Flush device cache to force remote storage access""" - try: - response = requests.post(f"{self.base_url}/flush_cache", timeout=10) - return response.status_code == 200 - except requests.RequestException: - return False - - # === Accuracy Tests === - def test_eval_accuracy_with_cache_persistence(self): - """Test eval accuracy with cache persistence across cache flushes""" - print("\n=== Testing Eval Accuracy with Cache Persistence ===") - - # First evaluation - populate cache - print("Phase 1: Running initial GSM8K evaluation to 
populate cache...") - args_initial = SimpleNamespace( - num_shots=5, - data_path=None, - num_questions=400, - max_new_tokens=512, - parallel=32, - host=f"http://{self.base_host}", - port=int(self.base_port), - ) - metrics_initial = run_eval_few_shot_gsm8k(args_initial) - print(f"Evaluation metrics: {metrics_initial}") - self.assertGreater(metrics_initial["accuracy"], 0.60) - - # Flush cache to force remote storage access - print("Phase 2: Flushing device cache...") - self.assertTrue(self.flush_cache(), "Cache flush should succeed") - time.sleep(2) - - # Second evaluation - should use remote cache - print("Phase 3: Running second GSM8K evaluation using remote cache...") - - start_time = time.time() - metrics_cached = run_eval_few_shot_gsm8k(args_initial) - cached_time = time.time() - start_time - - print(f"Cached evaluation completed in {cached_time:.2f}s") - print(f"Cached accuracy: {metrics_cached['accuracy']:.3f}") - print(f"Cached throughput: {metrics_cached['output_throughput']:.2f} token/s") - - # Verify accuracy consistency - accuracy_diff = abs(metrics_initial["accuracy"] - metrics_cached["accuracy"]) - print(f"Accuracy difference: {accuracy_diff:.4f}") - - # Assertions - self.assertGreater( - metrics_initial["accuracy"], 0.5, "Initial accuracy should be reasonable" - ) - self.assertGreater( - metrics_cached["accuracy"], 0.5, "Cached accuracy should be reasonable" - ) - self.assertLess( - accuracy_diff, 0.05, "Accuracy should be consistent between cache states" - ) - - # Performance should be similar or better with cache - throughput_ratio = ( - metrics_cached["output_throughput"] / metrics_initial["output_throughput"] - ) - print(f"Throughput ratio (cached/initial): {throughput_ratio:.2f}") - - if is_in_ci(): - write_github_step_summary( - f"### HiCache Storage Accuracy Test\n" - f"Initial accuracy: {metrics_initial['accuracy']:.3f}\n" - f"Cached accuracy: {metrics_cached['accuracy']:.3f}\n" - f"Accuracy difference: {accuracy_diff:.4f}\n" - f"Throughput 
ratio: {throughput_ratio:.2f}\n" - ) - - # === Performance Benchmark Tests === - - def test_throughput_benchmark_with_hicache(self): - """Benchmark throughput performance with HiCache enabled""" - print("\n=== Benchmarking Throughput with HiCache ===") - - # throughput test - res1 = self._run_throughput_benchmark( - test_name="hicache_offline_throughput", - num_prompts=200, - request_rate=10, - additional_args=[], - ) - - # Flush cache to force remote storage access - print("Phase 2: Flushing device cache...") - self.assertTrue(self.flush_cache(), "Cache flush should succeed") - time.sleep(2) - - # Second benchmark, should use remote cache - res2 = self._run_throughput_benchmark( - test_name="hicache_online_throughput", - num_prompts=400, - request_rate=10, - additional_args=[], - ) - - if is_in_ci(): - write_github_step_summary( - f"### HiCache Storage FileBackend Benchmark Test\n" - f"First time throughput: {res1['input_throughput']:.2f} token/s\n" - f"Second time throughput: {res2['input_throughput']:.2f} token/s\n" - f"First time TTFT: {res1['mean_ttft_ms']:.2f} ms\n" - f"Second time TTFT: {res2['mean_ttft_ms']:.2f} ms\n" - ) - - def _run_throughput_benchmark( - self, - test_name: str, - num_prompts: int, - request_rate: float, - dataset_name: str = "random", - additional_args: list = None, - ) -> Dict: - """Helper method to run throughput benchmarks""" - if additional_args is None: - additional_args = [] - - print(f"Running {test_name} benchmark...") - start_time = time.time() - - try: - # Use the existing server instead of launching a new one - from sglang.bench_serving import run_benchmark - from sglang.test.test_utils import get_benchmark_args - - args = get_benchmark_args( - base_url=self.base_url, - dataset_name=dataset_name, - tokenizer=self.model, - num_prompts=num_prompts, - request_rate=request_rate, - random_input_len=1024, - random_output_len=64, - ) - - # Run benchmark - result = run_benchmark(args) - - elapsed_time = time.time() - start_time - 
print(f"{test_name} completed in {elapsed_time:.2f}s") - print( - f"Output throughput: {result.get('output_throughput', 0.0):.2f} token/s" - ) - - return result - - except Exception as e: - print(f"Benchmark {test_name} failed: {e}") - # Fallback to avoid hard failure; return minimal metrics - return { - "output_throughput": 0.0, - "input_throughput": 0.0, - "mean_ttft_ms": float("inf"), - "mean_latency_ms": float("inf"), - "p99_ttft_ms": float("inf"), - } - - -if __name__ == "__main__": - unittest.main(verbosity=2) diff --git a/test/srt/hicache/test_hicache_storage_e2e.py b/test/srt/hicache/test_hicache_storage_file_backend.py similarity index 77% rename from test/srt/hicache/test_hicache_storage_e2e.py rename to test/srt/hicache/test_hicache_storage_file_backend.py index 0c605e633..fc8a0e25d 100644 --- a/test/srt/hicache/test_hicache_storage_e2e.py +++ b/test/srt/hicache/test_hicache_storage_file_backend.py @@ -9,6 +9,7 @@ import random import tempfile import time import unittest +from types import SimpleNamespace from typing import Dict from urllib.parse import urlparse @@ -16,6 +17,7 @@ import requests from sglang.bench_serving import get_tokenizer from sglang.srt.utils import kill_process_tree +from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k from sglang.test.test_utils import ( DEFAULT_MLA_MODEL_NAME_FOR_TEST, DEFAULT_MODEL_NAME_FOR_TEST, @@ -26,8 +28,8 @@ from sglang.test.test_utils import ( ) -class HiCacheStorageBaseTest(CustomTestCase): - """Base test class with common setup and utilities""" +class HiCacheStorageBaseMixin: + """Base mixin class with common setup and utilities""" @classmethod def setUpClass(cls): @@ -166,11 +168,7 @@ class HiCacheStorageBaseTest(CustomTestCase): return False def gen_prompt(self, token_num: int) -> str: - """Generate a random prompt of specified token length using tokenizer vocabulary. 
- - This function mimics the implementation from bench_serving.py to create - realistic prompts for testing cache behavior. - """ + """Generate a random prompt of specified token length using tokenizer vocabulary.""" all_available_tokens = list(self.tokenizer.get_vocab().values()) selected_tokens = random.choices(all_available_tokens, k=token_num) return self.tokenizer.decode(selected_tokens) @@ -201,10 +199,9 @@ class HiCacheStorageBaseTest(CustomTestCase): # Second request with extended prompt - should hit remote cache print("Step 2: Testing cache hit from remote storage...") - extended_prompt = base_prompt + "\n\n" + self.gen_prompt(64) start_time = time.time() - response2 = self.send_request(extended_prompt, max_tokens=150) + response2 = self.send_request(base_prompt, max_tokens=150) retrieval_time = time.time() - start_time cached_tokens = self.get_cached_tokens(response2) @@ -213,12 +210,12 @@ class HiCacheStorageBaseTest(CustomTestCase): ) # Assert cached tokens indicate a remote hit - self.assertEqual( - cached_tokens, 768, "Expected significant cached tokens for remote hit" + self.assertGreater( + cached_tokens, 700, "Expected significant cached tokens for remote hit" ) -class TestHiCacheStorageTP(HiCacheStorageBaseTest): +class TestHiCacheStorageTP(HiCacheStorageBaseMixin, CustomTestCase): """Multi-TP tests for HiCache Storage functionality""" @classmethod @@ -228,7 +225,7 @@ class TestHiCacheStorageTP(HiCacheStorageBaseTest): return server_args, {} -class TestHiCacheStorageLayerFirstDirectIO(HiCacheStorageBaseTest): +class TestHiCacheStorageLayerFirstDirectIO(HiCacheStorageBaseMixin, CustomTestCase): """Layer first direct tests for HiCache Storage functionality""" @classmethod @@ -241,7 +238,7 @@ class TestHiCacheStorageLayerFirstDirectIO(HiCacheStorageBaseTest): return server_args, {} -class TestHiCacheStoragePageFirstLayout(HiCacheStorageBaseTest): +class TestHiCacheStoragePageFirstLayout(HiCacheStorageBaseMixin, CustomTestCase): """Page first layout 
tests for HiCache Storage functionality""" @classmethod @@ -251,7 +248,7 @@ class TestHiCacheStoragePageFirstLayout(HiCacheStorageBaseTest): return server_args, {} -class TestHiCacheStorageMLA(HiCacheStorageBaseTest): +class TestHiCacheStorageMLA(HiCacheStorageBaseMixin, CustomTestCase): """MLA Model tests for HiCache Storage functionality""" @classmethod @@ -266,6 +263,57 @@ class TestHiCacheStorageMLA(HiCacheStorageBaseTest): return server_args, {} +class TestHiCacheStorageAccuracy(HiCacheStorageBaseMixin, CustomTestCase): + """Accuracy tests for HiCache Storage functionality""" + + @classmethod + def _get_additional_server_args_and_env(cls): + """Get additional server arguments specific to configuration - override in subclasses""" + server_args = {"--tp-size": 2, "--hicache-ratio": 1.5} + return server_args, {} + + def test_eval_accuracy(self): + """Test eval accuracy with cache persistence across cache flushes""" + print("\n=== Testing Eval Accuracy with Cache Persistence ===") + + # First evaluation - populate cache + print("Phase 1: Running initial GSM8K evaluation to populate cache...") + args_initial = SimpleNamespace( + num_shots=5, + data_path=None, + num_questions=50, + max_new_tokens=512, + parallel=10, + host=f"http://{self.base_host}", + port=int(self.base_port), + ) + metrics_initial = run_eval_few_shot_gsm8k(args_initial) + + # Flush cache to force remote storage access + print("Phase 2: Flushing device cache...") + self.assertTrue(self.flush_cache(), "Cache flush should succeed") + time.sleep(2) + + # Second evaluation - should use remote cache + print("Phase 3: Running second GSM8K evaluation using remote cache...") + metrics_cached = run_eval_few_shot_gsm8k(args_initial) + + # Verify accuracy consistency + accuracy_diff = abs(metrics_initial["accuracy"] - metrics_cached["accuracy"]) + print(f"Accuracy difference: {accuracy_diff:.4f}") + + # Assertions + self.assertGreater( + metrics_initial["accuracy"], 0.6, "Initial accuracy should be 
reasonable" + ) + self.assertGreater( + metrics_cached["accuracy"], 0.6, "Cached accuracy should be reasonable" + ) + self.assertLess( + accuracy_diff, 0.05, "Accuracy should be consistent between cache states" + ) + + # TODO: Add other backends tests(3fs/mooncake) # class TestHiCacheStorageMooncakeBackend(HiCacheStorageBaseTest): # """Mooncake backend tests for HiCache Storage functionality""" diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py index a918a6339..28ab321a0 100644 --- a/test/srt/run_suite.py +++ b/test/srt/run_suite.py @@ -125,8 +125,8 @@ suites = { TestFile("test_dp_attention.py", 277), TestFile("test_patch_torch.py", 19), TestFile("test_release_memory_occupation.py", 127), - TestFile("hicache/test_hicache_storage_e2e.py", 400), - TestFile("hicache/test_hicache_storage_benchmark.py", 400), + TestFile("hicache/test_hicache_storage_file_backend.py", 400), + TestFile("hicache/test_hicache_storage_3fs_backend.py", 400), ], "per-commit-4-gpu": [ TestFile("test_gpt_oss_4gpu.py", 600),