diff --git a/python/sglang/srt/mem_cache/memory_pool_host.py b/python/sglang/srt/mem_cache/memory_pool_host.py index cab5f5ae2..079dc0a64 100644 --- a/python/sglang/srt/mem_cache/memory_pool_host.py +++ b/python/sglang/srt/mem_cache/memory_pool_host.py @@ -466,6 +466,9 @@ class MHATokenToKVPoolHost(HostKVCache): return self.kv_buffer[:, :, index : index + self.page_size, :, :].flatten() elif self.layout == "page_first": return self.kv_buffer[:, index : index + self.page_size, :, :, :].flatten() + elif self.layout == "page_first_direct": + real_index = index // self.page_size + return self.kv_buffer[:, real_index : real_index + 1, :, :, :, :].flatten() else: raise ValueError(f"Unsupported layout: {self.layout}") @@ -494,6 +497,13 @@ class MHATokenToKVPoolHost(HostKVCache): 2, self.page_size, self.layer_num, self.head_num, self.head_dim ) ) + elif self.layout == "page_first_direct": + real_index = index // self.page_size + self.kv_buffer[:, real_index : real_index + 1, :, :, :, :] = ( + data_page.reshape( + 2, 1, self.layer_num, self.page_size, self.head_num, self.head_dim + ) + ) else: raise ValueError(f"Unsupported layout: {self.layout}") @@ -731,6 +741,9 @@ class MLATokenToKVPoolHost(HostKVCache): return self.kv_buffer[:, index : index + self.page_size, :, :].flatten() elif self.layout == "page_first": return self.kv_buffer[index : index + self.page_size, :, :, :].flatten() + elif self.layout == "page_first_direct": + real_index = index // self.page_size + return self.kv_buffer[real_index : real_index + 1, :, :, :, :].flatten() else: raise ValueError(f"Unsupported layout: {self.layout}") @@ -762,6 +775,15 @@ class MLATokenToKVPoolHost(HostKVCache): 1, self.kv_lora_rank + self.qk_rope_head_dim, ) + elif self.layout == "page_first_direct": + real_index = index // self.page_size + self.kv_buffer[real_index : real_index + 1, :, :, :, :] = data_page.reshape( + 1, + self.layer_num, + self.page_size, + 1, + self.kv_lora_rank + self.qk_rope_head_dim, + ) else: raise ValueError(f"Unsupported layout: {self.layout}") diff --git a/test/srt/hicache/test_hicache_storage_file_backend.py b/test/srt/hicache/test_hicache_storage_file_backend.py index fc8a0e25d..e257fbc27 100644 --- a/test/srt/hicache/test_hicache_storage_file_backend.py +++ b/test/srt/hicache/test_hicache_storage_file_backend.py @@ -238,6 +238,19 @@ class TestHiCacheStorageLayerFirstDirectIO(HiCacheStorageBaseMixin, CustomTestCa return server_args, {} +class TestHiCacheStoragePageFirstDirectIO(HiCacheStorageBaseMixin, CustomTestCase): + """Page first direct tests for HiCache Storage functionality""" + + @classmethod + def _get_additional_server_args_and_env(cls): + """Get additional server arguments specific to configuration - override in subclasses""" + server_args = { + "--hicache-mem-layout": "page_first_direct", + "--hicache-io-backend": "direct", + } + return server_args, {} + + class TestHiCacheStoragePageFirstLayout(HiCacheStorageBaseMixin, CustomTestCase): """Page first layout tests for HiCache Storage functionality"""