Support l3 cache (mooncake store) for hiradix cache (#7211)
Co-authored-by: Zhiqiang Xie <xiezhq@stanford.edu> Co-authored-by: AniZpZ <zhuangsen.zp@antgroup.com> Co-authored-by: zuoyuan <zhangzuo21@mails.tsinghua.edu.cn> Co-authored-by: @wangyueneng.wyn <wangyueneng.wyn@antgroup.com> Co-authored-by: JinYan Su <jinyansu792@gmail.com>
This commit is contained in:
@@ -26,6 +26,10 @@ if TYPE_CHECKING:
|
|||||||
from sglang.srt.mem_cache.memory_pool_host import HostKVCache
|
from sglang.srt.mem_cache.memory_pool_host import HostKVCache
|
||||||
|
|
||||||
from sglang.srt.mem_cache.hicache_storage import HiCacheFile, get_hash_str
|
from sglang.srt.mem_cache.hicache_storage import HiCacheFile, get_hash_str
|
||||||
|
from sglang.srt.mem_cache.mooncake_store.mooncake_store import (
|
||||||
|
MooncakeStore,
|
||||||
|
get_hash_str_mooncake,
|
||||||
|
)
|
||||||
from sglang.srt.mem_cache.storage.hf3fs.storage_hf3fs import HiCacheHF3FS
|
from sglang.srt.mem_cache.storage.hf3fs.storage_hf3fs import HiCacheHF3FS
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -125,7 +129,7 @@ class TransferBuffer:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, stop_event, buffer_count: int = 3, max_buffer_size: int = 1000
|
self, stop_event, buffer_count: int = 3, max_buffer_size: int = 1024
|
||||||
) -> None:
|
) -> None:
|
||||||
self.stop_event = stop_event
|
self.stop_event = stop_event
|
||||||
self.buffers = Queue(maxsize=buffer_count)
|
self.buffers = Queue(maxsize=buffer_count)
|
||||||
@@ -260,6 +264,11 @@ class HiCacheController:
|
|||||||
|
|
||||||
if storage_backend == "file":
|
if storage_backend == "file":
|
||||||
self.storage_backend = HiCacheFile()
|
self.storage_backend = HiCacheFile()
|
||||||
|
self.get_hash_str = get_hash_str
|
||||||
|
elif storage_backend == "mooncake":
|
||||||
|
self.storage_backend = MooncakeStore()
|
||||||
|
self.get_hash_str = get_hash_str_mooncake
|
||||||
|
self.storage_backend.register_buffer(self.mem_pool_host.kv_buffer)
|
||||||
elif storage_backend == "hf3fs":
|
elif storage_backend == "hf3fs":
|
||||||
from sglang.srt.distributed import get_tensor_model_parallel_rank
|
from sglang.srt.distributed import get_tensor_model_parallel_rank
|
||||||
|
|
||||||
@@ -271,6 +280,7 @@ class HiCacheController:
|
|||||||
self.storage_backend = HiCacheHF3FS.from_env_config(
|
self.storage_backend = HiCacheHF3FS.from_env_config(
|
||||||
rank, bytes_per_page, dtype
|
rank, bytes_per_page, dtype
|
||||||
)
|
)
|
||||||
|
self.get_hash_str = get_hash_str
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
f"Unsupported storage backend: {storage_backend}"
|
f"Unsupported storage backend: {storage_backend}"
|
||||||
@@ -532,6 +542,37 @@ class HiCacheController:
|
|||||||
operation.mark_done()
|
operation.mark_done()
|
||||||
return operation.completed_tokens, operation.hash_value
|
return operation.completed_tokens, operation.hash_value
|
||||||
|
|
||||||
|
def generic_page_transfer(self, operation, batch_size=8):
    """Prefetch pages from the storage backend into host memory in batches.

    Page hashes are requested ``batch_size`` at a time through
    ``storage_backend.batch_get``; each returned page is written to the host
    pool at the host index that starts that page. The loop stops at the first
    batch that is not fully retrieved, or as soon as ``operation.increment``
    reports that the operation was terminated.

    Args:
        operation: Prefetch operation. This method reads ``hash_value``,
            ``host_indices``, ``completed_tokens`` and ``request_id`` and
            calls ``increment``.
        batch_size: Number of page hashes fetched per backend call.
    """
    for start in range(0, len(operation.hash_value), batch_size):
        page_hashes = operation.hash_value[start : start + batch_size]
        page_data = self.storage_backend.batch_get(page_hashes)
        # A miss can surface either as a bare None or as None entries inside
        # the returned list (HiCacheFile.batch_get returns one entry per key).
        # Either way nothing from this batch may be copied into host memory.
        if page_data is None or any(page is None for page in page_data):
            logger.warning(
                f"Prefetch operation {operation.request_id} failed to retrieve page {page_hashes}."
            )
            break
        # Snapshot the write position before increment() advances it.
        completed_tokens = operation.completed_tokens
        if operation.increment(self.page_size * len(page_hashes)):
            for offset in range(len(page_hashes)):
                self.mem_pool_host.set_from_flat_data_page(
                    operation.host_indices[completed_tokens],
                    page_data[offset],
                )
                completed_tokens += self.page_size
        else:
            # operation terminated by controller, release pre-allocated memory
            self.mem_pool_host.free(
                operation.host_indices[operation.completed_tokens :]
            )
            break
|
||||||
|
|
||||||
|
def mooncake_page_transfer(self, operation):
    """Prefetch every page of ``operation`` with one zero-copy Mooncake read.

    The host pool supplies per-key buffer addresses and sizes through
    ``get_buffer_meta``; those are passed to ``storage_backend.batch_get``.
    Progress is then reported for all pages at once.

    Args:
        operation: Prefetch operation. This method reads ``hash_value``,
            ``host_indices`` and ``completed_tokens`` and calls ``increment``.
    """
    key_strs, buffer_ptrs, buffer_sizes = self.mem_pool_host.get_buffer_meta(
        operation.hash_value, operation.host_indices
    )
    # NOTE(review): the result of batch_get is not inspected, so a failed
    # read is still counted as completed below — confirm with the backend API.
    self.storage_backend.batch_get(key_strs, buffer_ptrs, buffer_sizes)
    if not operation.increment(len(operation.hash_value) * self.page_size):
        # operation terminated by controller, release pre-allocated memory
        # (same handling as the generic transfer path)
        self.mem_pool_host.free(
            operation.host_indices[operation.completed_tokens :]
        )
|
||||||
|
|
||||||
def prefetch_io_aux_func(self):
|
def prefetch_io_aux_func(self):
|
||||||
"""
|
"""
|
||||||
Auxiliary function conducting IO operations for prefetching.
|
Auxiliary function conducting IO operations for prefetching.
|
||||||
@@ -539,26 +580,10 @@ class HiCacheController:
|
|||||||
while not self.stop_event.is_set():
|
while not self.stop_event.is_set():
|
||||||
try:
|
try:
|
||||||
operation = self.prefetch_buffer.get(block=True, timeout=1)
|
operation = self.prefetch_buffer.get(block=True, timeout=1)
|
||||||
page_datas = self.storage_backend.batch_get(operation.hash_value)
|
if isinstance(self.storage_backend, MooncakeStore):
|
||||||
for h, page_data in zip(operation.hash_value, page_datas):
|
self.mooncake_page_transfer(operation)
|
||||||
if page_data is None:
|
|
||||||
logger.warning(
|
|
||||||
f"Prefetch operation {operation.request_id} failed to retrieve page {h}."
|
|
||||||
)
|
|
||||||
break
|
|
||||||
if operation.increment(self.page_size):
|
|
||||||
self.mem_pool_host.set_from_flat_data_page(
|
|
||||||
operation.host_indices[
|
|
||||||
operation.completed_tokens - self.page_size
|
|
||||||
],
|
|
||||||
page_data,
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
# operation terminated by controller, release pre-allocated memory
|
self.generic_page_transfer(operation)
|
||||||
self.mem_pool_host.free(
|
|
||||||
operation.host_indices[operation.completed_tokens :]
|
|
||||||
)
|
|
||||||
break
|
|
||||||
except Empty:
|
except Empty:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
@@ -582,18 +607,27 @@ class HiCacheController:
|
|||||||
remaining_tokens = len(tokens_to_fetch)
|
remaining_tokens = len(tokens_to_fetch)
|
||||||
hash_value = []
|
hash_value = []
|
||||||
while remaining_tokens >= self.page_size:
|
while remaining_tokens >= self.page_size:
|
||||||
last_hash = get_hash_str(
|
last_hash = self.get_hash_str(
|
||||||
tokens_to_fetch[
|
tokens_to_fetch[
|
||||||
storage_hit_count : storage_hit_count + self.page_size
|
storage_hit_count : storage_hit_count + self.page_size
|
||||||
],
|
],
|
||||||
last_hash,
|
last_hash,
|
||||||
)
|
)
|
||||||
if self.storage_backend.exists(last_hash):
|
|
||||||
storage_hit_count += self.page_size
|
# todo, more unified interface
|
||||||
hash_value.append(last_hash)
|
if not isinstance(self.storage_backend, MooncakeStore):
|
||||||
remaining_tokens -= self.page_size
|
if not self.storage_backend.exists(last_hash):
|
||||||
else:
|
|
||||||
break
|
break
|
||||||
|
hash_value.append(last_hash)
|
||||||
|
storage_hit_count += self.page_size
|
||||||
|
remaining_tokens -= self.page_size
|
||||||
|
|
||||||
|
if isinstance(self.storage_backend, MooncakeStore):
|
||||||
|
# deferring to batch exists for mooncake store
|
||||||
|
exist_result = self.storage_backend.exists(hash_value)
|
||||||
|
storage_hit_count = (
|
||||||
|
sum(1 for v in exist_result.values() if v != 0) * self.page_size
|
||||||
|
)
|
||||||
|
|
||||||
if self.tp_world_size > 1:
|
if self.tp_world_size > 1:
|
||||||
storage_hit_count_tensor = torch.tensor(
|
storage_hit_count_tensor = torch.tensor(
|
||||||
@@ -641,6 +675,47 @@ class HiCacheController:
|
|||||||
self.backup_queue.put(operation)
|
self.backup_queue.put(operation)
|
||||||
return operation.id
|
return operation.id
|
||||||
|
|
||||||
|
def generic_page_backup(self, operation, batch_size=8):
    """Write the pages of ``operation`` from host memory to the storage backend.

    Pages are read one at a time from the host pool, grouped into batches of
    ``batch_size`` and written with ``storage_backend.batch_set``. The loop
    stops at the first batch the backend reports as failed;
    ``operation.completed_tokens`` only counts fully written batches.

    Args:
        operation: Backup operation. This method reads ``hash_value`` and
            ``host_indices`` and advances ``completed_tokens``.
        batch_size: Number of pages written per backend call.
    """
    for start in range(0, len(operation.hash_value), batch_size):
        page_hashes = operation.hash_value[start : start + batch_size]
        # One flat page per hash, addressed by the first host index of the
        # page. Uses the single-page accessor, matching the code this method
        # replaced and its counterpart set_from_flat_data_page.
        page_data = [
            self.mem_pool_host.get_flat_data_page(
                operation.host_indices[page_no * self.page_size]
            )
            for page_no in range(start, start + len(page_hashes))
        ]
        success = self.storage_backend.batch_set(page_hashes, page_data)
        if not success:
            logger.warning(f"Failed to write page {page_hashes} to storage.")
            break
        operation.completed_tokens += self.page_size * len(page_hashes)
|
||||||
|
|
||||||
|
def mooncake_page_backup(self, operation):
    """Back up the pages of ``operation`` to Mooncake, skipping stored ones.

    The backend is asked which page hashes already exist; only the missing
    pages are written, using the buffer addresses and sizes that the host
    pool reports for them. All pages are counted as completed afterwards.

    Args:
        operation: Backup operation. This method reads ``hash_value`` and
            ``host_indices`` (which must provide ``tolist()``) and advances
            ``completed_tokens``.
    """
    if len(operation.hash_value):
        exist_hashvalues = self.storage_backend.exists(operation.hash_value)
        indices = operation.host_indices.tolist()
        non_exist_keys = []
        non_exist_indices = []
        for page_no, page_hash in enumerate(operation.hash_value):
            if not exist_hashvalues[page_hash]:
                non_exist_keys.append(page_hash)
                # Each page owns page_size consecutive entries of host_indices.
                non_exist_indices.extend(
                    indices[page_no * self.page_size : (page_no + 1) * self.page_size]
                )
        if non_exist_keys:
            key_strs, buffer_ptrs, buffer_sizes = self.mem_pool_host.get_buffer_meta(
                non_exist_keys, non_exist_indices
            )
            # TODO: check the return value of batch set to see how many tokens are set successfully
            self.storage_backend.batch_set(
                key_strs,
                target_location=buffer_ptrs,
                target_sizes=buffer_sizes,
            )
    operation.completed_tokens += len(operation.hash_value) * self.page_size
|
||||||
|
|
||||||
def backup_thread_func(self):
|
def backup_thread_func(self):
|
||||||
"""
|
"""
|
||||||
Manage backup operations from host memory to storage backend.
|
Manage backup operations from host memory to storage backend.
|
||||||
@@ -654,23 +729,25 @@ class HiCacheController:
|
|||||||
last_hash = operation.last_hash
|
last_hash = operation.last_hash
|
||||||
tokens_to_backup = operation.token_ids
|
tokens_to_backup = operation.token_ids
|
||||||
|
|
||||||
last_hashes, data_pages = [], []
|
backup_hit_count = 0
|
||||||
for i in range(0, len(tokens_to_backup), self.page_size):
|
remaining_tokens = len(tokens_to_backup)
|
||||||
last_hash = get_hash_str(
|
hash_value = []
|
||||||
tokens_to_backup[i : i + self.page_size], last_hash
|
while remaining_tokens >= self.page_size:
|
||||||
|
last_hash = self.get_hash_str(
|
||||||
|
tokens_to_backup[
|
||||||
|
backup_hit_count : backup_hit_count + self.page_size
|
||||||
|
],
|
||||||
|
last_hash,
|
||||||
)
|
)
|
||||||
data_page = self.mem_pool_host.get_flat_data_page(
|
backup_hit_count += self.page_size
|
||||||
operation.host_indices[i]
|
hash_value.append(last_hash)
|
||||||
)
|
remaining_tokens -= self.page_size
|
||||||
last_hashes.append(last_hash)
|
operation.hash_value = hash_value
|
||||||
data_pages.append(data_page)
|
|
||||||
|
|
||||||
success = self.storage_backend.batch_set(last_hashes, data_pages)
|
if isinstance(self.storage_backend, MooncakeStore):
|
||||||
if not success:
|
self.mooncake_page_backup(operation)
|
||||||
logger.warning(f"Failed to write page {last_hashes} to storage.")
|
|
||||||
else:
|
else:
|
||||||
operation.completed_tokens += len(tokens_to_backup)
|
self.generic_page_backup(operation)
|
||||||
operation.hash_value.extend(last_hashes)
|
|
||||||
|
|
||||||
min_completed_tokens = operation.completed_tokens
|
min_completed_tokens = operation.completed_tokens
|
||||||
if self.tp_world_size > 1:
|
if self.tp_world_size > 1:
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ import hashlib
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import List, Optional
|
from typing import Any, List, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
@@ -39,7 +39,10 @@ class HiCacheStorage(ABC):
|
|||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get(
|
def get(
|
||||||
self, key: str, target_location: Optional[torch.Tensor] = None
|
self,
|
||||||
|
key: str,
|
||||||
|
target_location: Optional[Any] = None,
|
||||||
|
target_sizes: Optional[Any] = None,
|
||||||
) -> torch.Tensor | None:
|
) -> torch.Tensor | None:
|
||||||
"""
|
"""
|
||||||
Retrieve the value associated with the given key.
|
Retrieve the value associated with the given key.
|
||||||
@@ -49,7 +52,10 @@ class HiCacheStorage(ABC):
|
|||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def batch_get(
|
def batch_get(
|
||||||
self, keys: List[str], target_locations: Optional[List[torch.Tensor]] = None
|
self,
|
||||||
|
keys: List[str],
|
||||||
|
target_locations: Optional[Any] = None,
|
||||||
|
target_sizes: Optional[Any] = None,
|
||||||
) -> List[torch.Tensor | None]:
|
) -> List[torch.Tensor | None]:
|
||||||
"""
|
"""
|
||||||
Retrieve values for multiple keys.
|
Retrieve values for multiple keys.
|
||||||
@@ -58,7 +64,13 @@ class HiCacheStorage(ABC):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def set(self, key, value) -> bool:
|
def set(
|
||||||
|
self,
|
||||||
|
key: str,
|
||||||
|
value: Optional[Any] = None,
|
||||||
|
target_location: Optional[Any] = None,
|
||||||
|
target_sizes: Optional[Any] = None,
|
||||||
|
) -> bool:
|
||||||
"""
|
"""
|
||||||
Store the value associated with the given key.
|
Store the value associated with the given key.
|
||||||
Returns True if the operation was successful, False otherwise.
|
Returns True if the operation was successful, False otherwise.
|
||||||
@@ -66,7 +78,13 @@ class HiCacheStorage(ABC):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def batch_set(self, keys: List[str], values: List[torch.Tensor]) -> bool:
|
def batch_set(
|
||||||
|
self,
|
||||||
|
keys: List[str],
|
||||||
|
values: Optional[Any] = None,
|
||||||
|
target_locations: Optional[Any] = None,
|
||||||
|
target_sizes: Optional[Any] = None,
|
||||||
|
) -> bool:
|
||||||
"""
|
"""
|
||||||
Store multiple key-value pairs.
|
Store multiple key-value pairs.
|
||||||
Returns True if all operations were successful, False otherwise.
|
Returns True if all operations were successful, False otherwise.
|
||||||
@@ -74,7 +92,7 @@ class HiCacheStorage(ABC):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def exists(self, key: str) -> bool:
|
def exists(self, key: str) -> bool | dict:
|
||||||
"""
|
"""
|
||||||
Check if the key exists in the storage.
|
Check if the key exists in the storage.
|
||||||
Returns True if the key exists, False otherwise.
|
Returns True if the key exists, False otherwise.
|
||||||
@@ -97,7 +115,10 @@ class HiCacheFile(HiCacheStorage):
|
|||||||
return key + self.tp_suffix
|
return key + self.tp_suffix
|
||||||
|
|
||||||
def get(
|
def get(
|
||||||
self, key: str, target_location: Optional[torch.Tensor] = None
|
self,
|
||||||
|
key: str,
|
||||||
|
target_location: Optional[Any] = None,
|
||||||
|
target_sizes: Optional[Any] = None,
|
||||||
) -> torch.Tensor | None:
|
) -> torch.Tensor | None:
|
||||||
key = self._get_suffixed_key(key)
|
key = self._get_suffixed_key(key)
|
||||||
tensor_path = os.path.join(self.file_path, f"{key}.bin")
|
tensor_path = os.path.join(self.file_path, f"{key}.bin")
|
||||||
@@ -115,7 +136,8 @@ class HiCacheFile(HiCacheStorage):
|
|||||||
def batch_get(
|
def batch_get(
|
||||||
self,
|
self,
|
||||||
keys: List[str],
|
keys: List[str],
|
||||||
target_locations: Optional[List[torch.Tensor]] = None,
|
target_locations: Optional[Any] = None,
|
||||||
|
target_sizes: Optional[Any] = None,
|
||||||
) -> List[torch.Tensor | None]:
|
) -> List[torch.Tensor | None]:
|
||||||
return [
|
return [
|
||||||
self.get(key, target_location)
|
self.get(key, target_location)
|
||||||
@@ -124,7 +146,13 @@ class HiCacheFile(HiCacheStorage):
|
|||||||
)
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
def set(self, key: str, value: torch.Tensor) -> bool:
|
def set(
|
||||||
|
self,
|
||||||
|
key: str,
|
||||||
|
value: Optional[Any] = None,
|
||||||
|
target_location: Optional[Any] = None,
|
||||||
|
target_sizes: Optional[Any] = None,
|
||||||
|
) -> bool:
|
||||||
key = self._get_suffixed_key(key)
|
key = self._get_suffixed_key(key)
|
||||||
tensor_path = os.path.join(self.file_path, f"{key}.bin")
|
tensor_path = os.path.join(self.file_path, f"{key}.bin")
|
||||||
if self.exists(key):
|
if self.exists(key):
|
||||||
@@ -137,7 +165,13 @@ class HiCacheFile(HiCacheStorage):
|
|||||||
logger.error(f"Failed to save tensor {key}: {e}")
|
logger.error(f"Failed to save tensor {key}: {e}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def batch_set(self, keys: List[str], values: List[torch.Tensor]) -> bool:
|
def batch_set(
|
||||||
|
self,
|
||||||
|
keys: List[str],
|
||||||
|
values: Optional[Any] = None,
|
||||||
|
target_locations: Optional[Any] = None,
|
||||||
|
target_sizes: Optional[Any] = None,
|
||||||
|
) -> bool:
|
||||||
for key, value in zip(keys, values):
|
for key, value in zip(keys, values):
|
||||||
if not self.set(key, value):
|
if not self.set(key, value):
|
||||||
return False
|
return False
|
||||||
|
|||||||
@@ -594,6 +594,10 @@ class HiRadixCache(RadixCache):
|
|||||||
if child.backuped:
|
if child.backuped:
|
||||||
new_node.host_value = child.host_value[:split_len]
|
new_node.host_value = child.host_value[:split_len]
|
||||||
child.host_value = child.host_value[split_len:]
|
child.host_value = child.host_value[split_len:]
|
||||||
|
|
||||||
|
if child.hash_value:
|
||||||
|
new_node.hash_value = child.hash_value[: split_len // self.page_size]
|
||||||
|
child.hash_value = child.hash_value[split_len // self.page_size :]
|
||||||
child.parent = new_node
|
child.parent = new_node
|
||||||
child.key = child.key[split_len:]
|
child.key = child.key[split_len:]
|
||||||
new_node.parent.children[self.get_child_key_fn(key)] = new_node
|
new_node.parent.children[self.get_child_key_fn(key)] = new_node
|
||||||
|
|||||||
@@ -265,6 +265,43 @@ class MHATokenToKVPoolHost(HostKVCache):
|
|||||||
self.head_dim,
|
self.head_dim,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def get_buffer_meta(self, keys, indices):
    """Build per-layer K and V buffer descriptors for a set of pages.

    For every page (one per ``page_size`` entries of ``indices``) and every
    layer, two entries are produced: the K slice and the V slice of that
    page in that layer.

    Args:
        keys: One key per page; entry ``p`` names the page that starts at
            ``indices[p * page_size]``.
        indices: Host token indices, ``page_size`` consecutive entries per
            page. Only the first index of each page is used.

    Returns:
        ``(key_list, ptr_list, element_size_list)`` of equal length. Keys are
        ``"{key}_{layer_id}_k"`` / ``"{key}_{layer_id}_v"``, pointers are
        plain ``int`` byte addresses, and every size is one page in bytes.
    """
    ptr_list = []
    key_list = []
    kv_buffer_data_ptr = self.kv_buffer.data_ptr()
    # Byte strides, hoisted out of the loops.
    # NOTE(review): the arithmetic presumes kv_buffer is contiguous with V
    # stored layer_num * size tokens after K — confirm against the allocation.
    token_bytes = self.head_num * self.head_dim * self.dtype.itemsize
    layer_bytes = self.size * token_bytes
    v_offset = self.layer_num * layer_bytes
    for index in range(0, len(indices), self.page_size):
        # int() keeps the addresses plain integers even when ``indices`` is a
        # tensor or array, as when host_indices is passed in unconverted.
        page_offset = int(indices[index]) * token_bytes
        key_ = keys[index // self.page_size]
        for layer_id in range(self.layer_num):
            k_ptr = kv_buffer_data_ptr + page_offset + layer_id * layer_bytes
            v_ptr = k_ptr + v_offset
            ptr_list.append(k_ptr)
            ptr_list.append(v_ptr)
            key_list.append(f"{key_}_{layer_id}_k")
            key_list.append(f"{key_}_{layer_id}_v")
    element_size = (
        self.dtype.itemsize * self.page_size * self.head_num * self.head_dim
    )
    element_size_list = [element_size] * len(key_list)
    return key_list, ptr_list, element_size_list
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def k_buffer(self):
|
def k_buffer(self):
|
||||||
return self.kv_buffer[0]
|
return self.kv_buffer[0]
|
||||||
@@ -325,3 +362,30 @@ class MLATokenToKVPoolHost(HostKVCache):
|
|||||||
1,
|
1,
|
||||||
self.kv_lora_rank + self.qk_rope_head_dim,
|
self.kv_lora_rank + self.qk_rope_head_dim,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def get_buffer_meta(self, keys, indices):
    """Build per-layer buffer descriptors for a set of pages.

    For every page (one per ``page_size`` entries of ``indices``) and every
    layer, one entry is produced for that page's slice in that layer.

    Args:
        keys: One key per page; entry ``p`` names the page that starts at
            ``indices[p * page_size]``.
        indices: Host token indices, ``page_size`` consecutive entries per
            page. Only the first index of each page is used.

    Returns:
        ``(key_list, ptr_list, element_size_list)`` of equal length. Keys are
        ``"{key}_{layer_id}_k"``, pointers are plain ``int`` byte addresses,
        and every size is one page in bytes.
    """
    ptr_list = []
    key_list = []
    kv_buffer_data_ptr = self.kv_buffer.data_ptr()
    # Byte strides, hoisted out of the loops.
    # NOTE(review): the arithmetic presumes kv_buffer is contiguous with
    # layers stored back to back — confirm against the allocation.
    token_bytes = (self.kv_lora_rank + self.qk_rope_head_dim) * self.dtype.itemsize
    layer_bytes = self.size * token_bytes
    for index in range(0, len(indices), self.page_size):
        # int() keeps the addresses plain integers even when ``indices`` is a
        # tensor or array, as when host_indices is passed in unconverted.
        page_offset = int(indices[index]) * token_bytes
        key_ = keys[index // self.page_size]
        for layer_id in range(self.layer_num):
            k_ptr = kv_buffer_data_ptr + page_offset + layer_id * layer_bytes
            ptr_list.append(k_ptr)
            key_list.append(f"{key_}_{layer_id}_k")
    element_size = (
        self.dtype.itemsize
        * self.page_size
        * (self.kv_lora_rank + self.qk_rope_head_dim)
    )
    element_size_list = [element_size] * len(key_list)
    return key_list, ptr_list, element_size_list
|
||||||
|
|||||||
71
python/sglang/srt/mem_cache/mooncake_store/README.md
Normal file
71
python/sglang/srt/mem_cache/mooncake_store/README.md
Normal file
@@ -0,0 +1,71 @@
|
|||||||
|
# Mooncake as L3 KV Cache
|
||||||
|
|
||||||
|
This document describes how to use Mooncake as the L3 KV cache for SGLang.
|
||||||
|
For more details about Mooncake, please refer to: https://kvcache-ai.github.io/
|
||||||
|
|
||||||
|
## Install Mooncake
|
||||||
|
|
||||||
|
### Method 1: with pip
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pip install mooncake-transfer-engine
|
||||||
|
```
|
||||||
|
|
||||||
|
### Method 2: from source
|
||||||
|
|
||||||
|
Clone Mooncake project:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
git clone https://github.com/kvcache-ai/Mooncake --recursive
|
||||||
|
```
|
||||||
|
|
||||||
|
Install dependencies:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cd Mooncake
|
||||||
|
bash dependencies.sh
|
||||||
|
```
|
||||||
|
|
||||||
|
Build the project. For additional build options, please refer to [the official guide](https://kvcache-ai.github.io/Mooncake/getting_started/build.html).
|
||||||
|
|
||||||
|
```bash
|
||||||
|
mkdir build
|
||||||
|
cd build
|
||||||
|
cmake ..
|
||||||
|
make -j
|
||||||
|
```
|
||||||
|
|
||||||
|
Install Mooncake:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
sudo make install
|
||||||
|
```
|
||||||
|
|
||||||
|
## Use Mooncake
|
||||||
|
|
||||||
|
Launch Mooncake master server:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
mooncake_master
|
||||||
|
```
|
||||||
|
|
||||||
|
Launch Mooncake meta server:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python -m mooncake.http_metadata_server
|
||||||
|
```
|
||||||
|
|
||||||
|
Start the SGLang server with Mooncake enabled. Mooncake configuration can be provided via environment variables:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
MOONCAKE_TE_META_DATA_SERVER="http://127.0.0.1:8080/metadata" \
|
||||||
|
MOONCAKE_GLOBAL_SEGMENT_SIZE=4294967296 \
|
||||||
|
MOONCAKE_LOCAL_BUFFER_SIZE=134217728 \
|
||||||
|
MOONCAKE_PROTOCOL="rdma" \
|
||||||
|
MOONCAKE_DEVICE="erdma_0,erdma_1" \
|
||||||
|
MOONCAKE_MASTER=127.0.0.1:50051 \
|
||||||
|
python -m sglang.launch_server \
|
||||||
|
--enable-hierarchical-cache \
|
||||||
|
--hicache-storage-backend mooncake \
|
||||||
|
--model-path [model_path]
|
||||||
|
```
|
||||||
264
python/sglang/srt/mem_cache/mooncake_store/mooncake_store.py
Normal file
264
python/sglang/srt/mem_cache/mooncake_store/mooncake_store.py
Normal file
@@ -0,0 +1,264 @@
|
|||||||
|
import hashlib
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import uuid
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Any, List, Optional
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from sglang.srt.distributed import get_tensor_model_parallel_rank
|
||||||
|
from sglang.srt.mem_cache.hicache_storage import HiCacheStorage
|
||||||
|
|
||||||
|
DEFAULT_GLOBAL_SEGMENT_SIZE = 4 * 1024 * 1024 * 1024 # 4 GiB
|
||||||
|
DEFAULT_LOCAL_BUFFER_SIZE = 128 * 1024 * 1024 # 128 MB
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def get_hash_str_mooncake(current_page_ids: List, prefix_block_key: str):
    """Build the Mooncake key for one page of token ids.

    The key has the form ``"{prefix}_{page}_{rank}"``:

    * ``prefix`` is the SHA-256 hex digest of ``prefix_block_key``, or an
      empty string when that key is ``None`` or empty;
    * ``page`` is the first 16 hex digits of the SHA-256 digest of the token
      ids' raw bytes, rendered as a decimal integer;
    * ``rank`` is the tensor-model-parallel rank of the calling process.

    Args:
        current_page_ids: Token ids of the page being hashed.
        prefix_block_key: Key of the preceding page, or a falsy value for
            the first page.

    Returns:
        The key string described above.
    """
    local_rank = get_tensor_model_parallel_rank()
    # A truthy str is necessarily non-empty, so one check covers None and "".
    prefix_str = (
        hashlib.sha256(prefix_block_key.encode()).hexdigest()
        if prefix_block_key
        else ""
    )
    # NOTE(review): np.array() picks the platform default integer dtype, so
    # the hashed bytes may differ between platforms — confirm that all
    # participants share one.
    current_token_ids_bytes = np.array(current_page_ids).tobytes()
    current_hash_hex = hashlib.sha256(current_token_ids_bytes).hexdigest()
    return f"{prefix_str}_{int(current_hash_hex[:16], 16)}_{local_rank}"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class MooncakeStoreConfig:
    """Settings that ``MooncakeStore`` passes to ``MooncakeDistributedStore.setup``."""

    local_hostname: str
    metadata_server: str
    global_segment_size: int
    local_buffer_size: int
    protocol: str
    device_name: str
    master_server_address: str

    @staticmethod
    def from_file() -> "MooncakeStoreConfig":
        """Load the config from the JSON file named by ``MOONCAKE_CONFIG_PATH``.

        Keys missing from the file fall back to the defaults used below;
        ``local_hostname``, ``metadata_server`` and ``master_server_address``
        have no default and become ``None`` when absent.

        Raises:
            ValueError: If ``MOONCAKE_CONFIG_PATH`` is not set.
        """
        file_path = os.getenv("MOONCAKE_CONFIG_PATH")
        if file_path is None:
            raise ValueError(
                "The environment variable 'MOONCAKE_CONFIG_PATH' is not set."
            )
        # JSON is UTF-8 by specification; do not depend on the locale encoding.
        with open(file_path, encoding="utf-8") as fin:
            config = json.load(fin)
        return MooncakeStoreConfig(
            local_hostname=config.get("local_hostname"),
            metadata_server=config.get("metadata_server"),
            global_segment_size=config.get(
                "global_segment_size", DEFAULT_GLOBAL_SEGMENT_SIZE
            ),
            local_buffer_size=config.get(
                "local_buffer_size", DEFAULT_LOCAL_BUFFER_SIZE
            ),
            protocol=config.get("protocol", "tcp"),
            device_name=config.get("device_name", "auto"),
            master_server_address=config.get("master_server_address"),
        )

    @staticmethod
    def load_from_env() -> "MooncakeStoreConfig":
        """Load the config from environment variables.

        Only ``MOONCAKE_MASTER`` is required; every other variable has a
        default. Example:

            export MOONCAKE_MASTER=10.13.3.232:50051
            export MOONCAKE_PROTOCOL="rdma"
            export MOONCAKE_DEVICE="auto"
            export MOONCAKE_TE_META_DATA_SERVER="P2PHANDSHAKE"

        Raises:
            ValueError: If ``MOONCAKE_MASTER`` is unset or empty.
        """
        # other required environment variables...
        if not os.getenv("MOONCAKE_MASTER"):
            raise ValueError("The environment variable 'MOONCAKE_MASTER' is not set.")
        return MooncakeStoreConfig(
            local_hostname=os.getenv("LOCAL_HOSTNAME", "localhost"),
            metadata_server=os.getenv("MOONCAKE_TE_META_DATA_SERVER", "P2PHANDSHAKE"),
            global_segment_size=int(
                os.getenv("MOONCAKE_GLOBAL_SEGMENT_SIZE", DEFAULT_GLOBAL_SEGMENT_SIZE)
            ),
            local_buffer_size=int(
                os.getenv("MOONCAKE_LOCAL_BUFFER_SIZE", DEFAULT_LOCAL_BUFFER_SIZE)
            ),
            protocol=os.getenv("MOONCAKE_PROTOCOL", "tcp"),
            device_name=os.getenv("MOONCAKE_DEVICE", "auto"),
            master_server_address=os.getenv("MOONCAKE_MASTER"),
        )

    def __post_init__(self):
        """Export the device auto-discovery variables when ``device_name`` is "auto"."""
        if self.device_name == "auto":
            os.environ["MC_MS_AUTO_DISC"] = "1"
            # Keep a filter list the user already exported; the hard-coded
            # NIC names are only a fallback.
            os.environ.setdefault(
                "MC_MS_FILTERS",
                "mlx5_bond_0, mlx5_bond_1, mlx5_bond_2, mlx5_bond_3",
            )
|
||||||
|
|
||||||
|
|
||||||
|
class MooncakeStore(HiCacheStorage):
|
||||||
|
def __init__(self):
|
||||||
|
try:
|
||||||
|
from mooncake.store import MooncakeDistributedStore
|
||||||
|
except ImportError as e:
|
||||||
|
raise ImportError(
|
||||||
|
"Please install mooncake by following the instructions at "
|
||||||
|
"https://kvcache-ai.github.io/Mooncake/getting_started/build.html"
|
||||||
|
"to run SGLang with MooncakeConnector."
|
||||||
|
) from e
|
||||||
|
|
||||||
|
try:
|
||||||
|
self.store = MooncakeDistributedStore()
|
||||||
|
self.config = MooncakeStoreConfig.load_from_env()
|
||||||
|
logger.info("Mooncake Configuration loaded from env successfully.")
|
||||||
|
|
||||||
|
ret_code = self.store.setup(
|
||||||
|
self.config.local_hostname,
|
||||||
|
self.config.metadata_server,
|
||||||
|
self.config.global_segment_size,
|
||||||
|
self.config.local_buffer_size,
|
||||||
|
self.config.protocol,
|
||||||
|
self.config.device_name,
|
||||||
|
self.config.master_server_address,
|
||||||
|
)
|
||||||
|
if ret_code:
|
||||||
|
logger.error(f"failed to setup mooncake store, error code: {ret_code}")
|
||||||
|
|
||||||
|
logger.info("Connect to Mooncake store successfully.")
|
||||||
|
self.warmup()
|
||||||
|
logger.info("Mooncake store warmup successfully.")
|
||||||
|
|
||||||
|
except ValueError as e:
|
||||||
|
logger.error("Configuration loading failed: %s", e)
|
||||||
|
raise
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error("An error occurred while loading the configuration: %s", exc)
|
||||||
|
raise
|
||||||
|
|
||||||
|
def warmup(self):
|
||||||
|
warmup_key = "sglang_mooncake_store_warmup_key" + uuid.uuid4().hex
|
||||||
|
# 10 MB
|
||||||
|
warmup_value = bytes(10 * 1024 * 1024)
|
||||||
|
self.store.put(warmup_key, warmup_value)
|
||||||
|
assert self.store.is_exist(warmup_key) == 1
|
||||||
|
self.store.get(warmup_key)
|
||||||
|
self.store.remove(warmup_key)
|
||||||
|
|
||||||
|
def register_buffer(self, buffer: torch.Tensor) -> None:
|
||||||
|
try:
|
||||||
|
buffer_ptr = buffer.data_ptr()
|
||||||
|
buffer_size = buffer.numel() * buffer.element_size()
|
||||||
|
ret_code = self.store.register_buffer(buffer_ptr, buffer_size)
|
||||||
|
if ret_code:
|
||||||
|
logger.error(f"failed to register buffer, error code: {ret_code}")
|
||||||
|
except TypeError as err:
|
||||||
|
logger.error("Failed to register buffer to Mooncake Store: %s", err)
|
||||||
|
raise TypeError("Mooncake Store Register Buffer Error.") from err
|
||||||
|
|
||||||
|
def set(
|
||||||
|
self,
|
||||||
|
key,
|
||||||
|
value: Optional[Any] = None,
|
||||||
|
target_location: Optional[List[int]] = None,
|
||||||
|
target_sizes: Optional[List[int]] = None,
|
||||||
|
) -> bool:
|
||||||
|
assert len(key) == len(target_location) == len(target_sizes)
|
||||||
|
if len(key) == 0:
|
||||||
|
return
|
||||||
|
|
||||||
|
for i in range(len(key)):
|
||||||
|
if key[i] is None or target_location[i] is None or target_sizes[i] is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
self._put_batch_zero_copy_impl(key, target_location, target_sizes)
|
||||||
|
|
||||||
|
def batch_set(
|
||||||
|
self,
|
||||||
|
keys: List[str],
|
||||||
|
value: Optional[Any] = None,
|
||||||
|
target_location: Optional[List[int]] = None,
|
||||||
|
target_sizes: Optional[List[int]] = None,
|
||||||
|
) -> bool:
|
||||||
|
assert len(keys) == len(target_location) == len(target_sizes)
|
||||||
|
if len(keys) == 0:
|
||||||
|
return
|
||||||
|
|
||||||
|
for i in range(len(keys)):
|
||||||
|
if keys[i] is None or target_location[i] is None or target_sizes[i] is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
self._put_batch_zero_copy_impl(keys, target_location, target_sizes)
|
||||||
|
|
||||||
|
def get(
|
||||||
|
self,
|
||||||
|
key,
|
||||||
|
target_location: Optional[Any] = None,
|
||||||
|
target_sizes: Optional[Any] = None,
|
||||||
|
) -> torch.Tensor | None:
|
||||||
|
assert len(key) == len(target_location) == len(target_sizes)
|
||||||
|
if len(key) == 0:
|
||||||
|
return
|
||||||
|
|
||||||
|
for i in range(len(key)):
|
||||||
|
if key[i] is None or target_location[i] is None or target_sizes[i] is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
return self._get_batch_zero_copy_impl(key, target_location, target_sizes)
|
||||||
|
|
||||||
|
def batch_get(
|
||||||
|
self,
|
||||||
|
keys: List[str],
|
||||||
|
target_location: Optional[Any] = None,
|
||||||
|
target_sizes: Optional[Any] = None,
|
||||||
|
) -> torch.Tensor | None:
|
||||||
|
assert len(keys) == len(target_location) == len(target_sizes)
|
||||||
|
if len(keys) == 0:
|
||||||
|
return
|
||||||
|
|
||||||
|
for i in range(len(keys)):
|
||||||
|
if keys[i] is None or target_location[i] is None or target_sizes[i] is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
return self._get_batch_zero_copy_impl(keys, target_location, target_sizes)
|
||||||
|
|
||||||
|
def exists(self, keys) -> bool | dict:
    """Report, per key, what the store says about that key's presence.

    Each key is probed under the name ``"<key>_<device>_k"``, where
    ``<device>`` is ``torch.cuda.current_device()``.

    Returns:
        A dict mapping each original key to the matching entry returned by
        ``self.store.batch_is_exist``; ``None`` as soon as a ``None`` key
        is encountered.
    """
    device_index = torch.cuda.current_device()
    probe_names = []
    for name in keys:
        if name is None:
            return None
        # Entries are stored layer by layer, so only the first layer's
        # key is probed.
        probe_names.append(f"{name}_{device_index}_k")

    presence = self.store.batch_is_exist(probe_names)
    return dict(zip(keys, presence))
|
||||||
|
|
||||||
|
def delete(self, key) -> None:
    """Deleting a key is not implemented; always raises NotImplementedError."""
    raise NotImplementedError
|
||||||
|
|
||||||
|
def close(self):
    """Do nothing; no explicit shutdown is performed here."""
    # MooncakeDistributedStore runs its destructor automatically, so a
    # manual close is unnecessary.
    return None
|
||||||
|
|
||||||
|
def clear(self) -> None:
    """Clearing the store is not implemented; always raises NotImplementedError."""
    raise NotImplementedError
|
||||||
|
|
||||||
|
def _put_batch_zero_copy_impl(
|
||||||
|
self, key_strs: List[str], buffer_ptrs: List[int], buffer_sizes: List[int]
|
||||||
|
) -> None:
|
||||||
|
try:
|
||||||
|
self.store.batch_put_from(key_strs, buffer_ptrs, buffer_sizes)
|
||||||
|
except TypeError as err:
|
||||||
|
logger.error("Failed to put value to Mooncake Store: %s", err)
|
||||||
|
raise TypeError("Mooncake Store Put Type Error.") from err
|
||||||
|
|
||||||
|
def _get_batch_zero_copy_impl(
|
||||||
|
self, key_strs: List[str], buffer_ptrs: List[int], buffer_sizes: List[int]
|
||||||
|
) -> None:
|
||||||
|
try:
|
||||||
|
self.store.batch_get_into(key_strs, buffer_ptrs, buffer_sizes)
|
||||||
|
except TypeError as err:
|
||||||
|
logger.error("Failed to get value from Mooncake Store: %s", err)
|
||||||
|
raise TypeError("Mooncake Store Get Type Error.") from err
|
||||||
40
python/sglang/srt/mem_cache/mooncake_store/unit_test.py
Normal file
40
python/sglang/srt/mem_cache/mooncake_store/unit_test.py
Normal file
@@ -0,0 +1,40 @@
|
|||||||
|
import torch
|
||||||
|
from mooncake_store import MooncakeStore
|
||||||
|
|
||||||
|
|
||||||
|
def test_init_and_warmup():
    """A freshly constructed MooncakeStore exposes a non-None ``store``."""
    instance = MooncakeStore()
    assert instance.store is not None
|
||||||
|
|
||||||
|
|
||||||
|
def test_register_buffer():
    """Smoke test: registering a float32 tensor must not raise."""
    instance = MooncakeStore()
    buffer = torch.zeros(1024, dtype=torch.float32)
    instance.register_buffer(buffer)
|
||||||
|
|
||||||
|
|
||||||
|
def test_set_and_get():
    """Smoke test: ``set`` then ``get`` two keys against one CUDA tensor.

    Both keys use the same tensor's pointer and byte size. Nothing is
    asserted about the data read back; the test only checks that neither
    call raises.
    """
    instance = MooncakeStore()

    names = [f"test_key_{index}" for index in range(2)]
    payload = torch.arange(256, dtype=torch.float32).cuda()
    nbytes = payload.numel() * payload.element_size()
    locations = [payload.data_ptr(), payload.data_ptr()]
    lengths = [nbytes] * 2

    instance.set(names, target_location=locations, target_sizes=lengths)
    instance.get(names, target_location=locations, target_sizes=lengths)
|
||||||
|
|
||||||
|
|
||||||
|
def test_exists():
    """``exists`` returns a dict that includes every queried key."""
    instance = MooncakeStore()
    queried = ["test_key_0", "non_existent_key"]
    outcome = instance.exists(queried)
    assert isinstance(outcome, dict)
    assert "test_key_0" in outcome
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Run the checks in a fixed order when this file is executed directly.
    for check in (
        test_init_and_warmup,
        test_register_buffer,
        test_set_and_get,
        test_exists,
    ):
        check()
|
||||||
@@ -1476,7 +1476,7 @@ class ServerArgs:
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--hicache-storage-backend",
|
"--hicache-storage-backend",
|
||||||
type=str,
|
type=str,
|
||||||
choices=["file", "hf3fs"], # todo, mooncake
|
choices=["file", "mooncake", "hf3fs"],
|
||||||
default=ServerArgs.hicache_storage_backend,
|
default=ServerArgs.hicache_storage_backend,
|
||||||
help="The storage backend for hierarchical KV cache.",
|
help="The storage backend for hierarchical KV cache.",
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user