Support page-first layout zero-copy for Mooncake store (#8651)
Co-authored-by: Zhiqiang Xie <xiezhq@stanford.edu>
This commit is contained in:
@@ -260,6 +260,7 @@ class HiCacheController:
|
|||||||
self.storage_backend = MooncakeStore()
|
self.storage_backend = MooncakeStore()
|
||||||
self.get_hash_str = get_hash_str_mooncake
|
self.get_hash_str = get_hash_str_mooncake
|
||||||
self.storage_backend.register_buffer(self.mem_pool_host.kv_buffer)
|
self.storage_backend.register_buffer(self.mem_pool_host.kv_buffer)
|
||||||
|
assert self.mem_pool_host.layout == "page_first"
|
||||||
elif storage_backend == "hf3fs":
|
elif storage_backend == "hf3fs":
|
||||||
from sglang.srt.distributed import get_tensor_model_parallel_rank
|
from sglang.srt.distributed import get_tensor_model_parallel_rank
|
||||||
from sglang.srt.mem_cache.storage.hf3fs.storage_hf3fs import (
|
from sglang.srt.mem_cache.storage.hf3fs.storage_hf3fs import (
|
||||||
|
|||||||
@@ -472,27 +472,26 @@ class MHATokenToKVPoolHost(HostKVCache):
|
|||||||
* self.dtype.itemsize
|
* self.dtype.itemsize
|
||||||
)
|
)
|
||||||
for index in range(0, len(indices), self.page_size):
|
for index in range(0, len(indices), self.page_size):
|
||||||
for layer_id in range(self.layer_num):
|
k_ptr = (
|
||||||
k_ptr = (
|
kv_buffer_data_ptr
|
||||||
kv_buffer_data_ptr
|
+ indices[index]
|
||||||
+ indices[index]
|
* self.layer_num
|
||||||
* self.head_num
|
* self.head_num
|
||||||
* self.head_dim
|
* self.head_dim
|
||||||
* self.dtype.itemsize
|
* self.dtype.itemsize
|
||||||
+ layer_id
|
)
|
||||||
* self.size
|
v_ptr = k_ptr + v_offset
|
||||||
* self.head_num
|
ptr_list.append(k_ptr)
|
||||||
* self.head_dim
|
ptr_list.append(v_ptr)
|
||||||
* self.dtype.itemsize
|
key_ = keys[index // self.page_size]
|
||||||
)
|
key_list.append(f"{key_}_k")
|
||||||
v_ptr = k_ptr + v_offset
|
key_list.append(f"{key_}_v")
|
||||||
ptr_list.append(k_ptr)
|
|
||||||
ptr_list.append(v_ptr)
|
|
||||||
key_ = keys[index // self.page_size]
|
|
||||||
key_list.append(f"{key_}_{layer_id}_k")
|
|
||||||
key_list.append(f"{key_}_{layer_id}_v")
|
|
||||||
element_size = (
|
element_size = (
|
||||||
self.dtype.itemsize * self.page_size * self.head_num * self.head_dim
|
self.layer_num
|
||||||
|
* self.dtype.itemsize
|
||||||
|
* self.page_size
|
||||||
|
* self.head_num
|
||||||
|
* self.head_dim
|
||||||
)
|
)
|
||||||
element_size_list = [element_size] * len(key_list)
|
element_size_list = [element_size] * len(key_list)
|
||||||
return key_list, ptr_list, element_size_list
|
return key_list, ptr_list, element_size_list
|
||||||
@@ -687,22 +686,19 @@ class MLATokenToKVPoolHost(HostKVCache):
|
|||||||
key_list = []
|
key_list = []
|
||||||
kv_buffer_data_ptr = self.kv_buffer.data_ptr()
|
kv_buffer_data_ptr = self.kv_buffer.data_ptr()
|
||||||
for index in range(0, len(indices), self.page_size):
|
for index in range(0, len(indices), self.page_size):
|
||||||
for layer_id in range(self.layer_num):
|
k_ptr = (
|
||||||
k_ptr = (
|
kv_buffer_data_ptr
|
||||||
kv_buffer_data_ptr
|
+ indices[index]
|
||||||
+ indices[index]
|
* self.layer_num
|
||||||
* (self.kv_lora_rank + self.qk_rope_head_dim)
|
* (self.kv_lora_rank + self.qk_rope_head_dim)
|
||||||
* self.dtype.itemsize
|
* self.dtype.itemsize
|
||||||
+ layer_id
|
)
|
||||||
* self.size
|
ptr_list.append(k_ptr)
|
||||||
* (self.kv_lora_rank + self.qk_rope_head_dim)
|
key_ = keys[index // self.page_size]
|
||||||
* self.dtype.itemsize
|
key_list.append(f"{key_}_k")
|
||||||
)
|
|
||||||
ptr_list.append(k_ptr)
|
|
||||||
key_ = keys[index // self.page_size]
|
|
||||||
key_list.append(f"{key_}_{layer_id}_k")
|
|
||||||
element_size = (
|
element_size = (
|
||||||
self.dtype.itemsize
|
self.layer_num
|
||||||
|
* self.dtype.itemsize
|
||||||
* self.page_size
|
* self.page_size
|
||||||
* (self.kv_lora_rank + self.qk_rope_head_dim)
|
* (self.kv_lora_rank + self.qk_rope_head_dim)
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -223,13 +223,11 @@ class MooncakeStore(HiCacheStorage):
|
|||||||
|
|
||||||
def exists(self, keys) -> bool | dict:
|
def exists(self, keys) -> bool | dict:
|
||||||
_keys = []
|
_keys = []
|
||||||
local_rank = torch.cuda.current_device()
|
|
||||||
for key in keys:
|
for key in keys:
|
||||||
if key is None:
|
if key is None:
|
||||||
return None
|
return None
|
||||||
# Since mooncake store is stored in layer by layer,
|
|
||||||
# only the first layer is checked here.
|
_keys.append(f"{key}_k")
|
||||||
_keys.append(f"{key}_{local_rank}_k")
|
|
||||||
result = {k: v for k, v in zip(keys, self.store.batch_is_exist(_keys))}
|
result = {k: v for k, v in zip(keys, self.store.batch_is_exist(_keys))}
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|||||||
@@ -575,6 +575,11 @@ class ServerArgs:
|
|||||||
"Pipeline parallelism is incompatible with overlap schedule."
|
"Pipeline parallelism is incompatible with overlap schedule."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if self.hicache_storage_backend == "mooncake":
|
||||||
|
# The Mooncake storage backend requires the following settings, so force them here:
|
||||||
|
self.hicache_io_backend = "kernel"
|
||||||
|
self.hicache_mem_layout = "page_first"
|
||||||
|
|
||||||
# Speculative Decoding
|
# Speculative Decoding
|
||||||
if self.speculative_algorithm == "NEXTN":
|
if self.speculative_algorithm == "NEXTN":
|
||||||
# NEXTN shares the same implementation of EAGLE
|
# NEXTN shares the same implementation of EAGLE
|
||||||
|
|||||||
Reference in New Issue
Block a user