From a98290aea39407e6bc2e19cc28a70298942e5139 Mon Sep 17 00:00:00 2001
From: Zhiqiang Xie
Date: Mon, 17 Mar 2025 17:45:00 -0700
Subject: [PATCH] Unit test for Hierarchical Caching (#4486)

---
Review notes (free text in this area, between the separator above and the
first "diff --git" line, is commentary and not part of the change itself):

* What the patch does: adds a `--hicache-ratio` server argument
  (`ServerArgs.hicache_ratio`, default 2.0), passes it from `Scheduler`
  into `HiRadixCache`, which forwards it to `MHATokenToKVPoolHost` /
  `MLATokenToKVPoolHost` as `host_to_device_ratio`; adds
  test/srt/test_hicache.py; renames test_hierarchical_mla.py to
  test_hicache_mla.py; registers both tests in test/srt/run_suite.py.
* `host_to_device_ratio` loses its defaults in memory_pool.py (3.0 on
  `HostKVCache` and `MHATokenToKVPoolHost`, 4.0 on `MLATokenToKVPoolHost`)
  and becomes a required argument. The value reaching those classes through
  `HiRadixCache` is now `hicache_ratio`, whose default is 2.0.
* NOTE(review): the new test case class in test/srt/test_hicache.py is
  named `TestPageSize` although it launches the server with
  `--enable-hierarchical-cache`; presumably a leftover name - consider
  renaming it in a follow-up.
* NOTE(review): this file was restored from a copy whose line breaks had
  been collapsed. Line counts match every hunk header, but the leading
  whitespace inside hunks (and the two-space gap before inline `#`
  comments) was reconstructed from Python syntax - confirm against the
  target tree before applying.

 python/sglang/srt/managers/scheduler.py       |  1 +
 python/sglang/srt/mem_cache/hiradix_cache.py  |  9 +++-
 python/sglang/srt/mem_cache/memory_pool.py    |  6 +--
 python/sglang/srt/server_args.py              |  8 ++++
 test/srt/run_suite.py                         |  2 +
 test/srt/test_hicache.py                      | 44 +++++++++++++++++++
 ...ierarchical_mla.py => test_hicache_mla.py} |  0
 7 files changed, 65 insertions(+), 5 deletions(-)
 create mode 100644 test/srt/test_hicache.py
 rename test/srt/{test_hierarchical_mla.py => test_hicache_mla.py} (100%)

diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index a743af97a..ef4ecaf90 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -445,6 +445,7 @@ class Scheduler(SchedulerOutputProcessorMixin):
                     token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
                     tp_cache_group=self.tp_worker.get_tp_cpu_group(),
                     page_size=self.page_size,
+                    hicache_ratio=server_args.hicache_ratio,
                 )
             else:
                 self.tree_cache = RadixCache(
diff --git a/python/sglang/srt/mem_cache/hiradix_cache.py b/python/sglang/srt/mem_cache/hiradix_cache.py
index 24b2056e0..748960f19 100644
--- a/python/sglang/srt/mem_cache/hiradix_cache.py
+++ b/python/sglang/srt/mem_cache/hiradix_cache.py
@@ -29,6 +29,7 @@ class HiRadixCache(RadixCache):
         token_to_kv_pool_allocator: TokenToKVPoolAllocator,
         tp_cache_group: torch.distributed.ProcessGroup,
         page_size: int,
+        hicache_ratio: float,
     ):
         if page_size != 1:
             raise ValueError(
@@ -36,9 +37,13 @@ class HiRadixCache(RadixCache):
             )
         self.kv_cache = token_to_kv_pool_allocator.get_kvcache()
         if isinstance(self.kv_cache, MHATokenToKVPool):
-            self.token_to_kv_pool_host = MHATokenToKVPoolHost(self.kv_cache)
+            self.token_to_kv_pool_host = MHATokenToKVPoolHost(
+                self.kv_cache, hicache_ratio
+            )
         elif isinstance(self.kv_cache, MLATokenToKVPool):
-            self.token_to_kv_pool_host = MLATokenToKVPoolHost(self.kv_cache)
+            self.token_to_kv_pool_host = MLATokenToKVPoolHost(
+                self.kv_cache, hicache_ratio
+            )
         else:
             raise ValueError(f"Only MHA and MLA supports swap kv_cache to host.")
 
diff --git a/python/sglang/srt/mem_cache/memory_pool.py b/python/sglang/srt/mem_cache/memory_pool.py
index 28689268c..b1cbb739c 100644
--- a/python/sglang/srt/mem_cache/memory_pool.py
+++ b/python/sglang/srt/mem_cache/memory_pool.py
@@ -581,7 +581,7 @@ class HostKVCache(abc.ABC):
     def __init__(
         self,
         device_pool: MHATokenToKVPool,
-        host_to_device_ratio: float = 3.0,
+        host_to_device_ratio: float,
         pin_memory: bool = False,  # no need to use pin memory with the double buffering
         device: str = "cpu",
     ):
@@ -747,7 +747,7 @@ class MHATokenToKVPoolHost(HostKVCache):
     def __init__(
         self,
         device_pool: MHATokenToKVPool,
-        host_to_device_ratio: float = 3.0,
+        host_to_device_ratio: float,
         pin_memory: bool = False,  # no need to use pin memory with the double buffering
         device: str = "cpu",
     ):
@@ -789,7 +789,7 @@ class MLATokenToKVPoolHost(HostKVCache):
     def __init__(
         self,
         device_pool: MLATokenToKVPool,
-        host_to_device_ratio: float = 4.0,
+        host_to_device_ratio: float,
         pin_memory: bool = False,  # no need to use pin memory with the double buffering
         device: str = "cpu",
     ):
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index 3d4951a77..77a97b9bc 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -173,6 +173,7 @@ class ServerArgs:
     enable_custom_logit_processor: bool = False
     tool_call_parser: str = None
     enable_hierarchical_cache: bool = False
+    hicache_ratio: float = 2.0
     enable_flashinfer_mla: bool = False
     enable_flashmla: bool = False
     flashinfer_mla_disable_ragged: bool = False
@@ -1007,6 +1008,13 @@ class ServerArgs:
             action="store_true",
             help="Enable hierarchical cache",
         )
+        parser.add_argument(
+            "--hicache-ratio",
+            type=float,
+            required=False,
+            default=ServerArgs.hicache_ratio,
+            help="The ratio of the size of host KV cache memory pool to the size of device pool.",
+        )
 
         # Server warmups
         parser.add_argument(
diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py
index dbf7a7b84..db65318ed 100644
--- a/test/srt/run_suite.py
+++ b/test/srt/run_suite.py
@@ -74,6 +74,8 @@ suites = {
         TestFile("test_w8a8_quantization.py", 46),
         TestFile("test_eval_fp8_accuracy.py", 172),
         TestFile("test_create_kvindices.py", 2),
+        TestFile("test_hicache.py", 60),
+        TestFile("test_hicache_mla.py", 90),
     ],
     "nightly": [
         TestFile("test_nightly_gsm8k_eval.py"),
diff --git a/test/srt/test_hicache.py b/test/srt/test_hicache.py
new file mode 100644
index 000000000..0b1d91366
--- /dev/null
+++ b/test/srt/test_hicache.py
@@ -0,0 +1,44 @@
+import unittest
+from types import SimpleNamespace
+
+from sglang.srt.utils import kill_process_tree
+from sglang.test.run_eval import run_eval
+from sglang.test.test_utils import (
+    DEFAULT_MODEL_NAME_FOR_TEST,
+    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+    DEFAULT_URL_FOR_TEST,
+    popen_launch_server,
+)
+
+
+class TestPageSize(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        cls.model = DEFAULT_MODEL_NAME_FOR_TEST
+        cls.base_url = DEFAULT_URL_FOR_TEST
+        cls.process = popen_launch_server(
+            cls.model,
+            cls.base_url,
+            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+            other_args=["--enable-hierarchical-cache"],
+        )
+
+    @classmethod
+    def tearDownClass(cls):
+        kill_process_tree(cls.process.pid)
+
+    def test_mmlu(self):
+        args = SimpleNamespace(
+            base_url=self.base_url,
+            model=self.model,
+            eval_name="mmlu",
+            num_examples=64,
+            num_threads=32,
+        )
+
+        metrics = run_eval(args)
+        self.assertGreaterEqual(metrics["score"], 0.65)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/test/srt/test_hierarchical_mla.py b/test/srt/test_hicache_mla.py
similarity index 100%
rename from test/srt/test_hierarchical_mla.py
rename to test/srt/test_hicache_mla.py