refactor zero copy (#10300)
Co-authored-by: 晟海 <huangtingwei.htw@antgroup.com>
Co-authored-by: huangtingwei <141888744+huangtingwei9988@users.noreply.github.com>
Co-authored-by: Zhiqiang Xie <xiezhq@stanford.edu>
Co-authored-by: Xinyuan Tong <115166877+JustinTong0323@users.noreply.github.com>
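The refactor in one picture: every storage backend now receives the host memory pool once via register_mem_pool_host() and exposes batch_get_v1()/batch_set_v1(), which address whole pages by host_indices, so hf3fs and Mooncake read and write the pinned host KV buffer in place. The controller then keeps a single zero-copy dispatch pair instead of one method per backend. A rough sketch of the resulting call flow (not part of the diff; the wiring around these names is assumed):

    # Sketch only: the names come from this commit, the surrounding wiring is assumed.
    controller.storage_backend.register_mem_pool_host(controller.mem_pool_host)

    if controller.storage_backend_type in ["hf3fs", "mooncake"]:
        controller.page_get_func = controller._page_get_zero_copy  # calls batch_get_v1
        controller.page_set_func = controller._page_set_zero_copy  # calls batch_set_v1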
@@ -289,8 +289,6 @@ class HiCacheController:
             )
             self.storage_backend = MooncakeStore(self.storage_config)
-            self.storage_backend.register_buffer(self.mem_pool_host.kv_buffer)
-            assert self.mem_pool_host.layout == "page_first"
         elif storage_backend == "hf3fs":
             from sglang.srt.mem_cache.storage.hf3fs.storage_hf3fs import (
                 HiCacheHF3FS,
@@ -313,6 +311,8 @@ class HiCacheController:
                 f"Unsupported storage backend: {storage_backend}"
             )
+
+        self.storage_backend.register_mem_pool_host(self.mem_pool_host)

         self.enable_storage = True
         # todo: threshold policy for prefetching
         self.prefetch_threshold = max(prefetch_threshold, self.page_size)
@@ -335,18 +335,10 @@ class HiCacheController:
         # Select the get and set functions
         self.page_get_func = self._generic_page_get
         self.page_set_func = self._generic_page_set
-        self.batch_exists_func = self.storage_backend.batch_exists
-        self.is_3fs_zerocopy = (
-            self.storage_backend_type == "hf3fs"
-            and self.mem_pool_host.layout == "page_first"
-        )
-        if self.storage_backend_type == "mooncake":
-            self.page_get_func = self._mooncake_page_get
-            self.page_set_func = self._mooncake_page_set
-        elif self.is_3fs_zerocopy:
-            self.page_get_func = self._3fs_zero_copy_page_get
-            self.page_set_func = self._3fs_zero_copy_page_set
-            self.batch_exists_func = self._3fs_zero_copy_batch_exists
+
+        if self.storage_backend_type in ["hf3fs", "mooncake"]:
+            self.page_get_func = self._page_get_zero_copy
+            self.page_set_func = self._page_set_zero_copy

         self.device = self.mem_pool_device.device
         self.layer_num = self.mem_pool_device.layer_num
@@ -630,42 +622,19 @@ class HiCacheController:
         for chunk in chunks:
             self.host_mem_release_queue.put(chunk)

-    def _3fs_zero_copy_batch_exists(self, batch_hashes):
-        _batch_hashes, _, factor = self.mem_pool_host.get_buffer_with_hash(batch_hashes)
-        hit_page_num = self.storage_backend.batch_exists(_batch_hashes) // factor
-        return hit_page_num
-
-    def _3fs_zero_copy_page_get(self, operation, hash_values, host_indices):
-        hashes, dsts, factor = self.mem_pool_host.get_buffer_with_hash(
-            hash_values, host_indices
-        )
-        page_data = self.storage_backend.batch_get(hashes, dsts)
-        if page_data:
-            inc = self.page_size * len(hashes) // factor
-            operation.increment(inc)
-        else:
-            logger.warning(
-                f"Prefetch operation {operation.request_id} failed to retrieve page {hashes}."
-            )
-
-    def _mooncake_page_get(self, operation, hash_values, host_indices):
-        key_strs, buffer_ptrs, buffer_sizes = self.mem_pool_host.get_buffer_meta(
-            hash_values,
-            host_indices,
-            self.storage_config.tp_rank,
-        )
-        get_result = self.storage_backend.batch_get(
-            key_strs,
-            target_locations=buffer_ptrs,
-            target_sizes=buffer_sizes,
-        )
-        if get_result != len(hash_values):
-            logger.warning(
-                f"Prefetch operation {operation.request_id} failed or partially failed."
-            )
-        if get_result != 0:
-            operation.increment(get_result * self.page_size)
+    def _page_get_zero_copy(self, operation, hash_values, host_indices):
+        results = self.storage_backend.batch_get_v1(hash_values, host_indices)
+        inc = 0
+        for i in range(len(hash_values)):
+            if not results[i]:
+                logger.warning(
+                    f"Prefetch operation {operation.request_id} failed to retrieve page {hash_values[i]}."
+                )
+                break
+            inc += self.page_size
+        operation.increment(inc)

     # todo: deprecate
     def _generic_page_get(self, operation, hash_values, host_indices):
         dummy_page_dst = [
             self.mem_pool_host.get_dummy_flat_data_page() for _ in hash_values
@@ -755,7 +724,7 @@ class HiCacheController:
                     batch_tokens[i : i + self.page_size], last_hash
                 )
                 batch_hashes.append(last_hash)
-            hit_page_num = self.batch_exists_func(batch_hashes)
+            hit_page_num = self.storage_backend.batch_exists(batch_hashes)
             hash_value.extend(batch_hashes[:hit_page_num])
             storage_query_count += hit_page_num * self.page_size
             if hit_page_num < len(batch_hashes):
@@ -824,34 +793,16 @@ class HiCacheController:
         self.backup_queue.put(operation)
         return operation.id

-    # non-zero copy
+    # todo: deprecate
     def _generic_page_set(self, hash_values, host_indices) -> bool:
         data = [
-            self.mem_pool_host.get_flat_data_page(host_indices[i * self.page_size])
+            self.mem_pool_host.get_data_page(host_indices[i * self.page_size])
             for i in range(len(hash_values))
         ]
         return self.storage_backend.batch_set(hash_values, data)

-    # zero copy
-    def _mooncake_page_set(self, hash_values, host_indices) -> bool:
-        key_strs, buffer_ptrs, buffer_sizes = self.mem_pool_host.get_buffer_meta(
-            hash_values,
-            host_indices,
-            self.storage_config.tp_rank,
-        )
-        success = self.storage_backend.batch_set(
-            key_strs,
-            target_locations=buffer_ptrs,
-            target_sizes=buffer_sizes,
-        )
-        return success
-
-    # zero copy
-    def _3fs_zero_copy_page_set(self, hash_values, host_indices) -> bool:
-        hashes, dsts, _ = self.mem_pool_host.get_buffer_with_hash(
-            hash_values, host_indices
-        )
-        return self.storage_backend.batch_set(hashes, dsts)
+    def _page_set_zero_copy(self, hash_values, host_indices) -> bool:
+        return all(self.storage_backend.batch_set_v1(hash_values, host_indices))

     # Backup batch by batch
     def _page_backup(self, operation):
@@ -7,6 +7,8 @@ from typing import Any, List, Optional

 import torch

+from sglang.srt.mem_cache.memory_pool_host import HostKVCache
+
 logger = logging.getLogger(__name__)

@@ -32,15 +34,46 @@ class HiCacheStorageConfig:
     extra_config: Optional[dict] = None


+@dataclass
+class HiCacheStorageExtraInfo:
+    extra_info: Optional[dict] = None
+
+
 class HiCacheStorage(ABC):
     """
     HiCacheStorage is a class that provides a generic key-value interface for storing and retrieving KV cache.
     It abstracts the underlying storage mechanism, allowing different implementations to be used.
     """

     # todo, potentially pass model and TP configs into storage backend
     # todo, the page size of the storage backend does not have to be the same as the host memory pool's

+    def register_mem_pool_host(self, mem_pool_host: HostKVCache):
+        self.mem_pool_host = mem_pool_host
+
+    def batch_get_v1(
+        self,
+        keys: List[str],
+        host_indices: torch.Tensor,
+        extra_info: Optional[HiCacheStorageExtraInfo] = None,
+    ) -> List[bool]:
+        """
+        Retrieve values for multiple keys.
+        Returns a list of booleans, one per key, indicating success.
+        """
+        pass
+
+    def batch_set_v1(
+        self,
+        keys: List[str],
+        host_indices: torch.Tensor,
+        extra_info: Optional[HiCacheStorageExtraInfo] = None,
+    ) -> List[bool]:
+        """
+        Store values for multiple keys.
+        Returns a list of booleans, one per key, indicating success.
+        """
+        pass
+
     @abstractmethod
     def get(
         self,
@@ -54,6 +87,7 @@ class HiCacheStorage(ABC):
         """
         pass

+    # TODO: Deprecate
     @abstractmethod
     def batch_get(
         self,
@@ -81,6 +115,7 @@ class HiCacheStorage(ABC):
         """
         pass

+    # TODO: Deprecate
     @abstractmethod
     def batch_set(
         self,
@@ -103,6 +138,7 @@ class HiCacheStorage(ABC):
         """
         pass

+    # TODO: Use a finer-grained return type (e.g., List[bool])
     def batch_exists(self, keys: List[str]) -> int:
         """
         Check if the keys exist in the storage.
@@ -114,6 +150,9 @@ class HiCacheStorage(ABC):
                 return i
         return len(keys)

+    def clear(self) -> None:
+        pass
+
     def get_stats(self):
         return None
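To illustrate the new interface, a toy in-memory backend could implement the two v1 hooks as below. This is a sketch, not part of the commit: the dict store is an assumption, it copies rather than doing true zero-copy, and the deprecated abstract methods (get/set/batch_get/batch_set) are omitted for brevity.

    import torch
    from typing import List, Optional

    class DictStorage(HiCacheStorage):
        # Toy backend (sketch): keeps one flat page tensor per key in a dict.

        def __init__(self):
            self.kv = {}

        def batch_set_v1(
            self,
            keys: List[str],
            host_indices: torch.Tensor,
            extra_info: Optional[HiCacheStorageExtraInfo] = None,
        ) -> List[bool]:
            page_size = self.mem_pool_host.page_size
            for i, key in enumerate(keys):
                # Copy the flat page out of the registered host pool.
                page = self.mem_pool_host.get_data_page(host_indices[i * page_size])
                self.kv[key] = page.clone()
            return [True] * len(keys)

        def batch_get_v1(
            self,
            keys: List[str],
            host_indices: torch.Tensor,
            extra_info: Optional[HiCacheStorageExtraInfo] = None,
        ) -> List[bool]:
            page_size = self.mem_pool_host.page_size
            results = []
            for i, key in enumerate(keys):
                page = self.kv.get(key)
                if page is None:
                    results.append(False)
                    continue
                # Write the stored page back into the host pool at host_indices.
                self.mem_pool_host.set_from_flat_data_page(
                    host_indices[i * page_size], page
                )
                results.append(True)
            return results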
@@ -140,7 +140,7 @@ class HostKVCache(abc.ABC):
         raise NotImplementedError()

     @abc.abstractmethod
-    def get_flat_data_page(self, index) -> torch.Tensor:
+    def get_data_page(self, index, flat: bool = True) -> torch.Tensor:
         """
         Get a flat data page from the host memory pool.
         """
@@ -461,16 +461,19 @@ class MHATokenToKVPoolHost(HostKVCache):
         else:
             raise ValueError(f"Unsupported IO backend: {io_backend}")

-    def get_flat_data_page(self, index) -> torch.Tensor:
+    def get_data_page(self, index, flat: bool = True) -> torch.Tensor:
         if self.layout == "layer_first":
-            return self.kv_buffer[:, :, index : index + self.page_size, :, :].flatten()
+            data_page = self.kv_buffer[:, :, index : index + self.page_size, :, :]
         elif self.layout == "page_first":
-            return self.kv_buffer[:, index : index + self.page_size, :, :, :].flatten()
+            data_page = self.kv_buffer[:, index : index + self.page_size, :, :, :]
         elif self.layout == "page_first_direct":
             real_index = index // self.page_size
-            return self.kv_buffer[:, real_index : real_index + 1, :, :, :, :].flatten()
+            data_page = self.kv_buffer[:, real_index : real_index + 1, :, :, :, :]
         else:
             raise ValueError(f"Unsupported layout: {self.layout}")
+        if flat:
+            data_page = data_page.flatten()
+        return data_page

     def get_dummy_flat_data_page(self) -> torch.Tensor:
         return torch.zeros(
@@ -507,9 +510,12 @@ class MHATokenToKVPoolHost(HostKVCache):
         else:
             raise ValueError(f"Unsupported layout: {self.layout}")

-    def get_buffer_meta(self, keys, indices, local_rank):
+    def get_page_buffer_meta(self, indices):
+        """
+        meta data for zero copy
+        """
+        assert len(indices) % self.page_size == 0
         ptr_list = []
-        key_list = []
         kv_buffer_data_ptr = self.kv_buffer.data_ptr()
         indices = indices.tolist()
         v_offset = (
@@ -519,48 +525,52 @@ class MHATokenToKVPoolHost(HostKVCache):
             * self.head_dim
             * self.dtype.itemsize
         )
-        for index in range(0, len(indices), self.page_size):
-            k_ptr = (
-                kv_buffer_data_ptr
-                + indices[index]
-                * self.layer_num
-                * self.head_num
-                * self.head_dim
-                * self.dtype.itemsize
-            )
-            v_ptr = k_ptr + v_offset
-            ptr_list.append(k_ptr)
-            ptr_list.append(v_ptr)
-            key_ = keys[index // self.page_size]
-            key_list.append(f"{key_}_{local_rank}_k")
-            key_list.append(f"{key_}_{local_rank}_v")
-        element_size = (
-            self.layer_num
-            * self.dtype.itemsize
-            * self.page_size
-            * self.head_num
-            * self.head_dim
-        )
-        element_size_list = [element_size] * len(key_list)
-        return key_list, ptr_list, element_size_list
-
-    def get_buffer_with_hash(self, keys, indices=None):
-        assert self.layout == "page_first"
-        assert indices is None or (len(keys) == (len(indices) // self.page_size))
-
-        key_list = []
-        buf_list = []
-
-        for i in range(len(keys)):
-            key = keys[i]
-            key_list.append(f"{key}-k")
-            key_list.append(f"{key}-v")
-            if indices is not None:
-                index = indices[i * self.page_size]
-                buf_list.append(self.k_buffer[index : index + self.page_size])
-                buf_list.append(self.v_buffer[index : index + self.page_size])
-
-        return key_list, buf_list, 2
+        if self.layout == "layer_first":
+            for index in range(0, len(indices), self.page_size):
+                for layer_id in range(self.layer_num):
+                    k_ptr = (
+                        kv_buffer_data_ptr
+                        + indices[index]
+                        * self.head_num
+                        * self.head_dim
+                        * self.dtype.itemsize
+                        + layer_id
+                        * self.size
+                        * self.head_num
+                        * self.head_dim
+                        * self.dtype.itemsize
+                    )
+                    v_ptr = k_ptr + v_offset
+                    ptr_list.append(k_ptr)
+                    ptr_list.append(v_ptr)
+            element_size = (
+                self.dtype.itemsize * self.page_size * self.head_num * self.head_dim
+            )
+            element_size_list = [element_size] * len(ptr_list)
+        elif self.layout in ["page_first", "page_first_direct"]:
+            for index in range(0, len(indices), self.page_size):
+                k_ptr = (
+                    kv_buffer_data_ptr
+                    + indices[index]
+                    * self.layer_num
+                    * self.head_num
+                    * self.head_dim
+                    * self.dtype.itemsize
+                )
+                v_ptr = k_ptr + v_offset
+                ptr_list.append(k_ptr)
+                ptr_list.append(v_ptr)
+            element_size = (
+                self.layer_num
+                * self.dtype.itemsize
+                * self.page_size
+                * self.head_num
+                * self.head_dim
+            )
+            element_size_list = [element_size] * len(ptr_list)
+        else:
+            raise ValueError(f"Unsupported layout: {self.layout}")
+        return ptr_list, element_size_list


 class MLATokenToKVPoolHost(HostKVCache):
@@ -736,16 +746,19 @@ class MLATokenToKVPoolHost(HostKVCache):
         else:
             raise ValueError(f"Unsupported IO backend: {io_backend}")

-    def get_flat_data_page(self, index) -> torch.Tensor:
+    def get_data_page(self, index, flat: bool = True) -> torch.Tensor:
         if self.layout == "layer_first":
-            return self.kv_buffer[:, index : index + self.page_size, :, :].flatten()
+            data_page = self.kv_buffer[:, index : index + self.page_size, :, :]
         elif self.layout == "page_first":
-            return self.kv_buffer[index : index + self.page_size, :, :, :].flatten()
+            data_page = self.kv_buffer[index : index + self.page_size, :, :, :]
         elif self.layout == "page_first_direct":
             real_index = index // self.page_size
-            return self.kv_buffer[real_index : real_index + 1, :, :, :, :].flatten()
+            data_page = self.kv_buffer[real_index : real_index + 1, :, :, :, :]
         else:
             raise ValueError(f"Unsupported layout: {self.layout}")
+        if flat:
+            data_page = data_page.flatten()
+        return data_page

     def get_dummy_flat_data_page(self) -> torch.Tensor:
         return torch.zeros(
@@ -787,40 +800,51 @@ class MLATokenToKVPoolHost(HostKVCache):
         else:
             raise ValueError(f"Unsupported IO backend: {io_backend}")

-    def get_buffer_meta(self, keys, indices, local_rank):
+    def get_page_buffer_meta(self, indices):
+        """
+        meta data for zero copy
+        """
+        assert len(indices) % self.page_size == 0
         ptr_list = []
-        key_list = []
         kv_buffer_data_ptr = self.kv_buffer.data_ptr()
         indices = indices.tolist()
-        for index in range(0, len(indices), self.page_size):
-            k_ptr = (
-                kv_buffer_data_ptr
-                + indices[index]
-                * self.layer_num
-                * (self.kv_lora_rank + self.qk_rope_head_dim)
-                * self.dtype.itemsize
-            )
-            ptr_list.append(k_ptr)
-            key_ = keys[index // self.page_size]
-            key_list.append(f"{key_}_k")
-        element_size = (
-            self.layer_num
-            * self.dtype.itemsize
-            * self.page_size
-            * (self.kv_lora_rank + self.qk_rope_head_dim)
-        )
-        element_size_list = [element_size] * len(key_list)
-        return key_list, ptr_list, element_size_list
-
-    def get_buffer_with_hash(self, keys, indices=None):
-        assert self.layout == "page_first"
-        assert indices is None or (len(keys) == (len(indices) // self.page_size))
-
-        buf_list = []
-
-        if indices is not None:
-            for i in range(len(keys)):
-                index = indices[i * self.page_size]
-                buf_list.append(self.kv_buffer[index : index + self.page_size])
-
-        return keys, buf_list, 1
+        if self.layout == "layer_first":
+            for index in range(0, len(indices), self.page_size):
+                for layer_id in range(self.layer_num):
+                    k_ptr = (
+                        kv_buffer_data_ptr
+                        + indices[index]
+                        * (self.kv_lora_rank + self.qk_rope_head_dim)
+                        * self.dtype.itemsize
+                        + layer_id
+                        * self.size
+                        * (self.kv_lora_rank + self.qk_rope_head_dim)
+                        * self.dtype.itemsize
+                    )
+                    ptr_list.append(k_ptr)
+            element_size = (
+                self.dtype.itemsize
+                * self.page_size
+                * (self.kv_lora_rank + self.qk_rope_head_dim)
+            )
+            element_size_list = [element_size] * len(ptr_list)
+        elif self.layout in ["page_first", "page_first_direct"]:
+            for index in range(0, len(indices), self.page_size):
+                k_ptr = (
+                    kv_buffer_data_ptr
+                    + indices[index]
+                    * self.layer_num
+                    * (self.kv_lora_rank + self.qk_rope_head_dim)
+                    * self.dtype.itemsize
+                )
+                ptr_list.append(k_ptr)
+            element_size = (
+                self.layer_num
+                * self.dtype.itemsize
+                * self.page_size
+                * (self.kv_lora_rank + self.qk_rope_head_dim)
+            )
+            element_size_list = [element_size] * len(ptr_list)
+        else:
+            raise ValueError(f"Unsupported layout: {self.layout}")
+        return ptr_list, element_size_list
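The pointer arithmetic in get_page_buffer_meta is easier to audit with concrete numbers. A worked example for the MHA page_first layout (sizes are illustrative, not taken from the diff):

    # Illustrative sanity check of the page_first pointer math (made-up sizes).
    layer_num, size, head_num, head_dim = 32, 65536, 8, 128
    page_size, itemsize = 64, 2  # fp16

    token_stride = layer_num * head_num * head_dim * itemsize     # 65,536 B per token
    element_size = token_stride * page_size                       # 4 MiB per K (or V) page
    v_offset = layer_num * size * head_num * head_dim * itemsize  # 4 GiB between K and V halves

    # For a page starting at host index t:
    #   the K page spans [base + t * token_stride, base + t * token_stride + element_size)
    #   the V page is the same span shifted by v_offset.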
@@ -12,7 +12,12 @@ from typing import Any, List, Optional, Tuple

 import torch

-from sglang.srt.mem_cache.hicache_storage import HiCacheStorage, HiCacheStorageConfig
+from sglang.srt.mem_cache.hicache_storage import (
+    HiCacheStorage,
+    HiCacheStorageConfig,
+    HiCacheStorageExtraInfo,
+)
+from sglang.srt.mem_cache.memory_pool_host import HostKVCache
 from sglang.srt.mem_cache.storage.hf3fs.hf3fs_client import Hf3fsClient
 from sglang.srt.metrics.collector import StorageMetrics
@@ -178,11 +183,14 @@ class HiCacheHF3FS(HiCacheStorage):
             self.skip_backup = True
             self.rank = 0

+        self.is_zero_copy = False
+
         logger.info(
             f"[Rank {self.rank}] HiCacheHF3FS Client Initializing: "
             f"file_path={self.file_path}, "
             f"file_size={self.file_size / (2 ** 30):.2f} GB, "
-            f"num_pages={self.num_pages}"
+            f"num_pages={self.num_pages}, "
+            f"is_mla_model={self.is_mla_model}"
         )

         self.ac = AtomicCounter(self.numjobs)
@@ -323,25 +331,12 @@ class HiCacheHF3FS(HiCacheStorage):
             use_mock_client=use_mock_client,
         )

-    def get(
-        self,
-        key: str,
-        target_location: Optional[Any] = None,
-        target_sizes: Optional[Any] = None,
-    ) -> torch.Tensor | None:
-        return self.batch_get(
-            [key],
-            [target_location] if target_location is not None else None,
-            [target_sizes] if target_sizes is not None else None,
-        )[0]
-
     @synchronized()
-    def batch_get(
+    def _batch_get(
         self,
         keys: List[str],
-        target_locations: Optional[Any] = None,
-        target_sizes: Optional[Any] = None,
-    ) -> List[torch.Tensor | None]:
+        values: List[torch.Tensor],
+    ) -> List[bool]:
         page_indices = self.metadata_client.get_page_indices(self.rank, keys)

         batch_indices, file_offsets = [], []
@@ -350,15 +345,9 @@ class HiCacheHF3FS(HiCacheStorage):
                 batch_indices.append(i)
                 file_offsets.append(page_index * self.bytes_per_page)

-        if target_locations is not None:
-            for target_location in target_locations:
-                assert target_location.is_contiguous()
-            file_results = target_locations
-        else:
-            file_results = [
-                torch.empty(self.numel, dtype=self.dtype)
-                for _ in range(len(batch_indices))
-            ]
+        for target_location in values:
+            assert target_location.is_contiguous()
+        file_results = values

         start_time = time.perf_counter()
@@ -379,12 +368,10 @@ class HiCacheHF3FS(HiCacheStorage):
                 ionum / (end_time - start_time) * self.gb_per_page
             )

-        results = [None] * len(keys)
-        for batch_index, file_result, read_result in zip(
-            batch_indices, file_results, read_results
-        ):
+        results = [False] * len(keys)
+        for batch_index, read_result in zip(batch_indices, read_results):
             if read_result == self.bytes_per_page:
-                results[batch_index] = file_result
+                results[batch_index] = True
             else:
                 logger.error(
                     f"[Rank {self.rank}] HiCacheHF3FS get {keys[batch_index]} failed"
@@ -392,28 +379,12 @@ class HiCacheHF3FS(HiCacheStorage):

         return results

-    def set(
-        self,
-        key: str,
-        value: Optional[Any] = None,
-        target_location: Optional[Any] = None,
-        target_sizes: Optional[Any] = None,
-    ) -> bool:
-        return self.batch_set(
-            [key],
-            [value] if value is not None else None,
-            [target_location] if target_location is not None else None,
-            [target_sizes] if target_sizes is not None else None,
-        )
-
     @synchronized()
-    def batch_set(
+    def _batch_set(
         self,
         keys: List[str],
-        values: Optional[Any] = None,
-        target_locations: Optional[Any] = None,
-        target_sizes: Optional[Any] = None,
-    ) -> bool:
+        values: List[torch.Tensor],
+    ) -> List[bool]:
         # In MLA backend, only one rank needs to backup the KV cache
         if self.skip_backup:
             return True
@@ -474,7 +445,7 @@ class HiCacheHF3FS(HiCacheStorage):
             self.rank, written_keys_to_confirm, pages_to_release
         )

-        return all(results)
+        return results

     def delete(self, key: str) -> None:
         self.metadata_client.delete_keys(self.rank, [key])
@@ -484,21 +455,25 @@ class HiCacheHF3FS(HiCacheStorage):
         return result[0] if result else False

     def batch_exists(self, keys: List[str]) -> int:
+        factor = 1
+        if self.is_zero_copy and not self.is_mla_model:
+            keys = self._get_mha_zero_copy_keys(keys)
+            factor = 2
+
         results = self.metadata_client.exists(self.rank, keys)
-        for i in range(len(keys)):
-            if not results[i]:
-                return i
-
-        return len(keys)
+        i = 0
+        while i < len(keys) and results[i]:
+            i += 1
+
+        return i // factor

-    def clear(self) -> bool:
+    def clear(self) -> None:
         try:
             self.metadata_client.clear(self.rank)
             logger.info(f"Cleared HiCacheHF3FS for rank {self.rank}")
-            return True
         except Exception as e:
             logger.error(f"Failed to clear HiCacheHF3FS: {e}")
-            return False

     def close(self) -> None:
         try:
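The factor in the new batch_exists accounts for the MHA zero-copy layout, where each logical page is stored as two objects (suffixes -k and -v, via _get_mha_zero_copy_keys added below); the contiguous-hit scan therefore runs over doubled keys and is divided back into whole pages. A hypothetical trace:

    # Hypothetical: two page hashes under zero-copy MHA.
    # keys            = ["h0", "h1"]
    # doubled keys    = ["h0-k", "h0-v", "h1-k", "h1-v"], factor = 2
    # metadata exists = [True, True, True, False]
    # contiguous hits i = 3, so batch_exists returns 3 // 2 == 1 complete page.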
@@ -521,3 +496,139 @@ class HiCacheHF3FS(HiCacheStorage):
         self.prefetch_bandwidth.clear()
         self.backup_bandwidth.clear()
         return storage_metrics
+
+    def register_mem_pool_host(self, mem_pool_host: HostKVCache):
+        super().register_mem_pool_host(mem_pool_host)
+        self.is_zero_copy = self.mem_pool_host.layout == "page_first"
+        logger.info(f"{self.is_zero_copy=}")
+
+    def _get_mha_zero_copy_keys(self, keys: List[str]) -> List[str]:
+        _keys = []
+        for k in keys:
+            _keys.append(f"{k}-k")
+            _keys.append(f"{k}-v")
+        return _keys
+
+    def _get_mha_zero_copy_values(
+        self, values: List[torch.Tensor]
+    ) -> List[torch.Tensor]:
+        _values = []
+        for value in values:
+            _values.append(value[0])
+            _values.append(value[1])
+        return _values
+
+    def _batch_get_preprocess(self, keys, host_indices):
+        page_num = len(host_indices) // self.mem_pool_host.page_size
+        # host_indices to kv_buffer
+        flat = not self.is_zero_copy
+        values = (
+            [
+                self.mem_pool_host.get_data_page(
+                    host_indices[i * self.mem_pool_host.page_size], flat=flat
+                )
+                for i in range(page_num)
+            ]
+            if self.is_zero_copy
+            else [
+                self.mem_pool_host.get_dummy_flat_data_page() for _ in range(page_num)
+            ]
+        )
+
+        if self.is_zero_copy and not self.is_mla_model:
+            keys = self._get_mha_zero_copy_keys(keys)
+            values = self._get_mha_zero_copy_values(values)
+
+        return keys, values
+
+    def _batch_get_postprocess(self, host_indices, values, results):
+        page_num = len(host_indices) // self.mem_pool_host.page_size
+
+        if self.is_zero_copy:
+            if not self.is_mla_model:
+                results = [
+                    (results[2 * i] and results[2 * i + 1]) for i in range(page_num)
+                ]
+            results = results[:page_num]
+            return results
+
+        for i in range(page_num):
+            if not results[i]:
+                break
+            self.mem_pool_host.set_from_flat_data_page(
+                host_indices[i * self.mem_pool_host.page_size], values[i]
+            )
+
+        return results
+
+    def batch_get_v1(
+        self,
+        keys: List[str],
+        host_indices: torch.Tensor,
+        extra_info: Optional[HiCacheStorageExtraInfo] = None,
+    ) -> List[bool]:
+        keys, values = self._batch_get_preprocess(keys, host_indices)
+        results = self._batch_get(keys, values)
+        return self._batch_get_postprocess(host_indices, values, results)
+
+    def _batch_set_preprocess(self, keys, host_indices):
+        page_num = len(host_indices) // self.mem_pool_host.page_size
+        # host_indices to kv_buffer
+        flat = not self.is_zero_copy
+        values = [
+            self.mem_pool_host.get_data_page(
+                host_indices[i * self.mem_pool_host.page_size], flat=flat
+            )
+            for i in range(page_num)
+        ]
+
+        if self.is_zero_copy and not self.is_mla_model:
+            keys = self._get_mha_zero_copy_keys(keys)
+            values = self._get_mha_zero_copy_values(values)
+
+        return keys, values
+
+    def batch_set_v1(
+        self,
+        keys: List[str],
+        host_indices: torch.Tensor,
+        extra_info: Optional[HiCacheStorageExtraInfo] = None,
+    ) -> List[bool]:
+        len_keys = len(keys)
+        keys, values = self._batch_set_preprocess(keys, host_indices)
+        results = self._batch_set(keys, values)
+        return results
+
+    # Deprecated
+    def get(
+        self,
+        key: str,
+        target_location: Optional[Any] = None,
+        target_sizes: Optional[Any] = None,
+    ) -> torch.Tensor | None:
+        pass
+
+    # Deprecated
+    def batch_get(
+        self,
+        keys: List[str],
+        target_locations: Optional[Any] = None,
+        target_sizes: Optional[Any] = None,
+    ) -> List[torch.Tensor | None] | int:
+        pass
+
+    # Deprecated
+    def set(
+        self,
+        key: str,
+        value: Optional[Any] = None,
+        target_location: Optional[Any] = None,
+        target_sizes: Optional[Any] = None,
+    ) -> bool:
+        pass
+
+    # Deprecated
+    def batch_set(
+        self,
+        keys: List[str],
+        values: Optional[Any] = None,
+        target_locations: Optional[Any] = None,
+        target_sizes: Optional[Any] = None,
+    ) -> bool:
+        pass
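End to end, batch_get_v1 composes the three helpers added above. A condensed trace of one zero-copy read (hypothetical hash h0, page_first MHA pool):

    # _batch_get_preprocess:  keys   -> ["h0-k", "h0-v"]
    #                         values -> [k_page_view, v_page_view]  (views into kv_buffer)
    # _batch_get:             reads each hf3fs page straight into those views,
    #                         returning one success flag per object, e.g. [True, True]
    # _batch_get_postprocess: pairs the per-object flags back into per-page flags -> [True]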
@@ -7,7 +7,12 @@ from typing import Any, List, Optional

 import torch

-from sglang.srt.mem_cache.hicache_storage import HiCacheStorage, HiCacheStorageConfig
+from sglang.srt.mem_cache.hicache_storage import (
+    HiCacheStorage,
+    HiCacheStorageConfig,
+    HiCacheStorageExtraInfo,
+)
+from sglang.srt.mem_cache.memory_pool_host import HostKVCache

 DEFAULT_GLOBAL_SEGMENT_SIZE = 4 * 1024 * 1024 * 1024  # 4 GiB
 DEFAULT_LOCAL_BUFFER_SIZE = 16 * 1024 * 1024  # 16 MB
@@ -183,7 +188,12 @@ class MooncakeStore(HiCacheStorage):
             assert self.store.is_exist(warmup_key) == 1
             assert self.store.get(warmup_key) == warmup_value

-    def register_buffer(self, buffer: torch.Tensor) -> None:
+    def register_mem_pool_host(self, mem_pool_host: HostKVCache):
+        super().register_mem_pool_host(mem_pool_host)
+        assert (
+            self.mem_pool_host.layout == "page_first"
+        ), "mooncake store storage backend only supports page_first layout"
+        buffer = self.mem_pool_host.kv_buffer
         try:
             buffer_ptr = buffer.data_ptr()
             buffer_size = buffer.numel() * buffer.element_size()
@@ -194,6 +204,97 @@ class MooncakeStore(HiCacheStorage):
             logger.error("Failed to register buffer to Mooncake Store: %s", err)
             raise TypeError("Mooncake Store Register Buffer Error.") from err

+    def _get_mha_buffer_meta(self, keys, indices):
+        ptr_list, element_size_list = self.mem_pool_host.get_page_buffer_meta(indices)
+        key_list = []
+        for key_ in keys:
+            key_list.append(f"{key_}_{self.local_rank}_k")
+            key_list.append(f"{key_}_{self.local_rank}_v")
+        assert len(key_list) == len(ptr_list)
+        return key_list, ptr_list, element_size_list
+
+    def _get_mla_buffer_meta(self, keys, indices):
+        ptr_list, element_size_list = self.mem_pool_host.get_page_buffer_meta(indices)
+        key_list = []
+        for key_ in keys:
+            key_list.append(f"{key_}_k")
+        assert len(key_list) == len(ptr_list)
+        return key_list, ptr_list, element_size_list
+
+    def _batch_preprocess(self, keys, host_indices):
+        assert len(keys) > 0
+        assert len(keys) == len(host_indices) // self.mem_pool_host.page_size
+        if self.is_mla_backend:
+            return self._get_mla_buffer_meta(keys, host_indices)
+        else:
+            return self._get_mha_buffer_meta(keys, host_indices)
+
+    def _batch_postprocess(self, results: List[int], is_set_operate=False):
+        """
+        Refer to https://github.com/kvcache-ai/Mooncake/blob/main/mooncake-store/include/pybind_client.h
+        For batch_get_into, results is a vector of integers, where each element
+        is the number of bytes read on success, or a negative value on error.
+        For batch_put_from, results is a vector of integers, where each element
+        is 0 on success, or a negative value on error.
+        """
+        if self.is_mla_backend:
+            return [k_res == 0 if is_set_operate else k_res > 0 for k_res in results]
+        else:
+            kv_pairs = zip(results[::2], results[1::2])
+            return [
+                (
+                    (k_res == 0 and v_res == 0)
+                    if is_set_operate
+                    else (k_res > 0 and v_res > 0)
+                )
+                for k_res, v_res in kv_pairs
+            ]
+
+    def batch_get_v1(
+        self,
+        keys: List[str],
+        host_indices: torch.Tensor,
+        extra_info: Optional[HiCacheStorageExtraInfo] = None,
+    ) -> List[bool]:
+        key_strs, buffer_ptrs, buffer_sizes = self._batch_preprocess(keys, host_indices)
+        get_results = self._get_batch_zero_copy_impl(
+            key_strs, buffer_ptrs, buffer_sizes
+        )
+        return self._batch_postprocess(get_results, is_set_operate=False)
+
+    def batch_set_v1(
+        self,
+        keys: List[str],
+        host_indices: torch.Tensor,
+        extra_info: Optional[HiCacheStorageExtraInfo] = None,
+    ) -> List[bool]:
+        key_strs, buffer_ptrs, buffer_sizes = self._batch_preprocess(keys, host_indices)
+        exist_result = self._batch_exist(key_strs)
+
+        set_keys = []
+        set_buffer_ptrs = []
+        set_buffer_sizes = []
+        set_indices = []
+        set_results = [-1] * len(keys)
+        for i in range(len(keys)):
+            if exist_result[i] != 1:
+                set_keys.append(keys[i])
+                set_buffer_ptrs.append(buffer_ptrs[i])
+                set_buffer_sizes.append(buffer_sizes[i])
+                set_indices.append(i)
+            else:
+                set_results[i] = 0
+
+        # Only set non-existing keys to storage
+        if len(set_keys) > 0:
+            put_results = self._put_batch_zero_copy_impl(
+                set_keys, set_buffer_ptrs, set_buffer_sizes
+            )
+            for i in range(len(set_indices)):
+                set_results[set_indices[i]] = put_results[i]
+
+        return self._batch_postprocess(set_results, is_set_operate=True)

     def set(
         self,
         key,
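Given the result conventions documented in _batch_postprocess (bytes read for gets, 0 for successful puts), the per-page flags fall out as in this hypothetical example (byte counts are illustrative):

    # Four objects for two MHA pages: k0, v0, k1, v1
    get_results = [4194304, 4194304, 4194304, -1]  # bytes read per object
    # _batch_postprocess(get_results)                      -> [True, False]
    set_results = [0, 0, -2, 0]                    # 0 means success per object
    # _batch_postprocess(set_results, is_set_operate=True) -> [True, False]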