From 0edda32001938b578976409216bc6f9f36f719df Mon Sep 17 00:00:00 2001 From: huangtingwei <141888744+huangtingwei9988@users.noreply.github.com> Date: Wed, 13 Aug 2025 06:59:26 +0800 Subject: [PATCH] Support page first layout zero copy for mooncake store (#8651) Co-authored-by: Zhiqiang Xie --- .../sglang/srt/managers/cache_controller.py | 1 + .../sglang/srt/mem_cache/memory_pool_host.py | 66 +++++++++---------- .../storage/mooncake_store/mooncake_store.py | 6 +- python/sglang/srt/server_args.py | 5 ++ 4 files changed, 39 insertions(+), 39 deletions(-) diff --git a/python/sglang/srt/managers/cache_controller.py b/python/sglang/srt/managers/cache_controller.py index 1edc88751..08e2146af 100644 --- a/python/sglang/srt/managers/cache_controller.py +++ b/python/sglang/srt/managers/cache_controller.py @@ -260,6 +260,7 @@ class HiCacheController: self.storage_backend = MooncakeStore() self.get_hash_str = get_hash_str_mooncake self.storage_backend.register_buffer(self.mem_pool_host.kv_buffer) + assert self.mem_pool_host.layout == "page_first" elif storage_backend == "hf3fs": from sglang.srt.distributed import get_tensor_model_parallel_rank from sglang.srt.mem_cache.storage.hf3fs.storage_hf3fs import ( diff --git a/python/sglang/srt/mem_cache/memory_pool_host.py b/python/sglang/srt/mem_cache/memory_pool_host.py index 83b19375c..cfc7f36c5 100644 --- a/python/sglang/srt/mem_cache/memory_pool_host.py +++ b/python/sglang/srt/mem_cache/memory_pool_host.py @@ -472,27 +472,26 @@ class MHATokenToKVPoolHost(HostKVCache): * self.dtype.itemsize ) for index in range(0, len(indices), self.page_size): - for layer_id in range(self.layer_num): - k_ptr = ( - kv_buffer_data_ptr - + indices[index] - * self.head_num - * self.head_dim - * self.dtype.itemsize - + layer_id - * self.size - * self.head_num - * self.head_dim - * self.dtype.itemsize - ) - v_ptr = k_ptr + v_offset - ptr_list.append(k_ptr) - ptr_list.append(v_ptr) - key_ = keys[index // self.page_size] - key_list.append(f"{key_}_{layer_id}_k") - key_list.append(f"{key_}_{layer_id}_v") + k_ptr = ( + kv_buffer_data_ptr + + indices[index] + * self.layer_num + * self.head_num + * self.head_dim + * self.dtype.itemsize + ) + v_ptr = k_ptr + v_offset + ptr_list.append(k_ptr) + ptr_list.append(v_ptr) + key_ = keys[index // self.page_size] + key_list.append(f"{key_}_k") + key_list.append(f"{key_}_v") element_size = ( - self.dtype.itemsize * self.page_size * self.head_num * self.head_dim + self.layer_num + * self.dtype.itemsize + * self.page_size + * self.head_num + * self.head_dim ) element_size_list = [element_size] * len(key_list) return key_list, ptr_list, element_size_list @@ -687,22 +686,19 @@ class MLATokenToKVPoolHost(HostKVCache): key_list = [] kv_buffer_data_ptr = self.kv_buffer.data_ptr() for index in range(0, len(indices), self.page_size): - for layer_id in range(self.layer_num): - k_ptr = ( - kv_buffer_data_ptr - + indices[index] - * (self.kv_lora_rank + self.qk_rope_head_dim) - * self.dtype.itemsize - + layer_id - * self.size - * (self.kv_lora_rank + self.qk_rope_head_dim) - * self.dtype.itemsize - ) - ptr_list.append(k_ptr) - key_ = keys[index // self.page_size] - key_list.append(f"{key_}_{layer_id}_k") + k_ptr = ( + kv_buffer_data_ptr + + indices[index] + * self.layer_num + * (self.kv_lora_rank + self.qk_rope_head_dim) + * self.dtype.itemsize + ) + ptr_list.append(k_ptr) + key_ = keys[index // self.page_size] + key_list.append(f"{key_}_k") element_size = ( - self.dtype.itemsize + self.layer_num + * self.dtype.itemsize * self.page_size * (self.kv_lora_rank + self.qk_rope_head_dim) ) diff --git a/python/sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py b/python/sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py index 38700d55e..51b47335e 100644 --- a/python/sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +++ b/python/sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py @@ -223,13 +223,11 @@ class MooncakeStore(HiCacheStorage): def exists(self, keys) -> bool | dict: _keys = [] - local_rank = torch.cuda.current_device() for key in keys: if key is None: return None - # Since mooncake store is stored in layer by layer, - # only the first layer is checked here. - _keys.append(f"{key}_{local_rank}_k") + + _keys.append(f"{key}_k") result = {k: v for k, v in zip(keys, self.store.batch_is_exist(_keys))} return result diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 457ed17f7..0c76e7d1c 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -575,6 +575,11 @@ class ServerArgs: "Pipeline parallelism is incompatible with overlap schedule." ) + if self.hicache_storage_backend == "mooncake": + # to use mooncake storage backend, the following conditions must be met: + self.hicache_io_backend = "kernel" + self.hicache_mem_layout = "page_first" + # Speculative Decoding if self.speculative_algorithm == "NEXTN": # NEXTN shares the same implementation of EAGLE