Support l3 cache (mooncake store) for hiradix cache (#7211)

Co-authored-by: Zhiqiang Xie <xiezhq@stanford.edu>
Co-authored-by: AniZpZ <zhuangsen.zp@antgroup.com>
Co-authored-by: zuoyuan <zhangzuo21@mails.tsinghua.edu.cn>
Co-authored-by: @wangyueneng.wyn <wangyueneng.wyn@antgroup.com>
Co-authored-by: JinYan Su <jinyansu792@gmail.com>
This commit is contained in:
huangtingwei
2025-07-31 14:15:51 +08:00
committed by GitHub
parent 26c8a310bd
commit d904959233
8 changed files with 607 additions and 53 deletions

View File

@@ -2,7 +2,7 @@ import hashlib
import logging
import os
from abc import ABC, abstractmethod
from typing import List, Optional
from typing import Any, List, Optional
import torch
@@ -39,7 +39,10 @@ class HiCacheStorage(ABC):
@abstractmethod
def get(
self, key: str, target_location: Optional[torch.Tensor] = None
self,
key: str,
target_location: Optional[Any] = None,
target_sizes: Optional[Any] = None,
) -> torch.Tensor | None:
"""
Retrieve the value associated with the given key.
@@ -49,7 +52,10 @@ class HiCacheStorage(ABC):
@abstractmethod
def batch_get(
self, keys: List[str], target_locations: Optional[List[torch.Tensor]] = None
self,
keys: List[str],
target_locations: Optional[Any] = None,
target_sizes: Optional[Any] = None,
) -> List[torch.Tensor | None]:
"""
Retrieve values for multiple keys.
@@ -58,7 +64,13 @@ class HiCacheStorage(ABC):
pass
@abstractmethod
def set(self, key, value) -> bool:
def set(
self,
key: str,
value: Optional[Any] = None,
target_location: Optional[Any] = None,
target_sizes: Optional[Any] = None,
) -> bool:
"""
Store the value associated with the given key.
Returns True if the operation was successful, False otherwise.
@@ -66,7 +78,13 @@ class HiCacheStorage(ABC):
pass
@abstractmethod
def batch_set(self, keys: List[str], values: List[torch.Tensor]) -> bool:
def batch_set(
self,
keys: List[str],
values: Optional[Any] = None,
target_locations: Optional[Any] = None,
target_sizes: Optional[Any] = None,
) -> bool:
"""
Store multiple key-value pairs.
Returns True if all operations were successful, False otherwise.
@@ -74,7 +92,7 @@ class HiCacheStorage(ABC):
pass
@abstractmethod
def exists(self, key: str) -> bool:
def exists(self, key: str) -> bool | dict:
"""
Check if the key exists in the storage.
Returns True if the key exists, False otherwise.
@@ -97,7 +115,10 @@ class HiCacheFile(HiCacheStorage):
return key + self.tp_suffix
def get(
self, key: str, target_location: Optional[torch.Tensor] = None
self,
key: str,
target_location: Optional[Any] = None,
target_sizes: Optional[Any] = None,
) -> torch.Tensor | None:
key = self._get_suffixed_key(key)
tensor_path = os.path.join(self.file_path, f"{key}.bin")
@@ -115,7 +136,8 @@ class HiCacheFile(HiCacheStorage):
def batch_get(
self,
keys: List[str],
target_locations: Optional[List[torch.Tensor]] = None,
target_locations: Optional[Any] = None,
target_sizes: Optional[Any] = None,
) -> List[torch.Tensor | None]:
return [
self.get(key, target_location)
@@ -124,7 +146,13 @@ class HiCacheFile(HiCacheStorage):
)
]
def set(self, key: str, value: torch.Tensor) -> bool:
def set(
self,
key: str,
value: Optional[Any] = None,
target_location: Optional[Any] = None,
target_sizes: Optional[Any] = None,
) -> bool:
key = self._get_suffixed_key(key)
tensor_path = os.path.join(self.file_path, f"{key}.bin")
if self.exists(key):
@@ -137,7 +165,13 @@ class HiCacheFile(HiCacheStorage):
logger.error(f"Failed to save tensor {key}: {e}")
return False
def batch_set(self, keys: List[str], values: List[torch.Tensor]) -> bool:
def batch_set(
self,
keys: List[str],
values: Optional[Any] = None,
target_locations: Optional[Any] = None,
target_sizes: Optional[Any] = None,
) -> bool:
for key, value in zip(keys, values):
if not self.set(key, value):
return False

View File

@@ -594,6 +594,10 @@ class HiRadixCache(RadixCache):
if child.backuped:
new_node.host_value = child.host_value[:split_len]
child.host_value = child.host_value[split_len:]
if child.hash_value:
new_node.hash_value = child.hash_value[: split_len // self.page_size]
child.hash_value = child.hash_value[split_len // self.page_size :]
child.parent = new_node
child.key = child.key[split_len:]
new_node.parent.children[self.get_child_key_fn(key)] = new_node

View File

@@ -265,6 +265,43 @@ class MHATokenToKVPoolHost(HostKVCache):
self.head_dim,
)
def get_buffer_meta(self, keys, indices):
ptr_list = []
key_list = []
kv_buffer_data_ptr = self.kv_buffer.data_ptr()
v_offset = (
self.layer_num
* self.size
* self.head_num
* self.head_dim
* self.dtype.itemsize
)
for index in range(0, len(indices), self.page_size):
for layer_id in range(self.layer_num):
k_ptr = (
kv_buffer_data_ptr
+ indices[index]
* self.head_num
* self.head_dim
* self.dtype.itemsize
+ layer_id
* self.size
* self.head_num
* self.head_dim
* self.dtype.itemsize
)
v_ptr = k_ptr + v_offset
ptr_list.append(k_ptr)
ptr_list.append(v_ptr)
key_ = keys[index // self.page_size]
key_list.append(f"{key_}_{layer_id}_k")
key_list.append(f"{key_}_{layer_id}_v")
element_size = (
self.dtype.itemsize * self.page_size * self.head_num * self.head_dim
)
element_size_list = [element_size] * len(key_list)
return key_list, ptr_list, element_size_list
@property
def k_buffer(self):
    # K-cache view: element 0 of the stacked kv_buffer (the V half follows —
    # see the v_offset computed in get_buffer_meta).
    return self.kv_buffer[0]
@@ -325,3 +362,30 @@ class MLATokenToKVPoolHost(HostKVCache):
1,
self.kv_lora_rank + self.qk_rope_head_dim,
)
def get_buffer_meta(self, keys, indices):
ptr_list = []
key_list = []
kv_buffer_data_ptr = self.kv_buffer.data_ptr()
for index in range(0, len(indices), self.page_size):
for layer_id in range(self.layer_num):
k_ptr = (
kv_buffer_data_ptr
+ indices[index]
* (self.kv_lora_rank + self.qk_rope_head_dim)
* self.dtype.itemsize
+ layer_id
* self.size
* (self.kv_lora_rank + self.qk_rope_head_dim)
* self.dtype.itemsize
)
ptr_list.append(k_ptr)
key_ = keys[index // self.page_size]
key_list.append(f"{key_}_{layer_id}_k")
element_size = (
self.dtype.itemsize
* self.page_size
* (self.kv_lora_rank + self.qk_rope_head_dim)
)
element_size_list = [element_size] * len(key_list)
return key_list, ptr_list, element_size_list

View File

@@ -0,0 +1,71 @@
# Mooncake as L3 KV Cache
This document describes how to use Mooncake as the L3 KV cache for SGLang.
For more details about Mooncake, please refer to: https://kvcache-ai.github.io/
## Install Mooncake
### Method 1: with pip
```bash
pip install mooncake-transfer-engine
```
### Method 2: from source
Clone Mooncake project:
```bash
git clone https://github.com/kvcache-ai/Mooncake --recursive
```
Install dependencies:
```bash
cd Mooncake
bash dependencies.sh
```
Build the project. For additional build options, please refer to [the official guide](https://kvcache-ai.github.io/Mooncake/getting_started/build.html).
```bash
mkdir build
cd build
cmake ..
make -j
```
Install Mooncake:
```bash
sudo make install
```
## Use Mooncake
Launch Mooncake master server:
```bash
mooncake_master
```
Launch Mooncake meta server:
```bash
python -m mooncake.http_metadata_server
```
Start the SGLang server with Mooncake enabled. Mooncake configuration can be provided via environment variables:
```bash
MOONCAKE_TE_META_DATA_SERVER="http://127.0.0.1:8080/metadata" \
MOONCAKE_GLOBAL_SEGMENT_SIZE=4294967296 \
MOONCAKE_LOCAL_BUFFER_SIZE=134217728 \
MOONCAKE_PROTOCOL="rdma" \
MOONCAKE_DEVICE="erdma_0,erdma_1" \
MOONCAKE_MASTER=127.0.0.1:50051 \
python -m sglang.launch_server \
--enable-hierarchical-cache \
    --hicache-storage-backend mooncake \
--model-path [model_path]
```

View File

@@ -0,0 +1,264 @@
import hashlib
import json
import logging
import os
import uuid
from dataclasses import dataclass
from typing import Any, List, Optional
import numpy as np
import torch
from sglang.srt.distributed import get_tensor_model_parallel_rank
from sglang.srt.mem_cache.hicache_storage import HiCacheStorage
DEFAULT_GLOBAL_SEGMENT_SIZE = 4 * 1024 * 1024 * 1024 # 4 GiB
DEFAULT_LOCAL_BUFFER_SIZE = 128 * 1024 * 1024 # 128 MB
logger = logging.getLogger(__name__)
def get_hash_str_mooncake(current_page_ids: List, prefix_block_key: str):
    """Derive a rank-qualified storage key for one page of token ids.

    The key is ``{prefix_hash}_{page_hash}_{tp_rank}`` where ``prefix_hash``
    is the sha256 of the previous block's key (empty string when there is no
    prefix) and ``page_hash`` is the first 64 bits of the sha256 of the page's
    token ids. The TP rank suffix keeps different ranks' shards distinct.

    Args:
        current_page_ids: token ids of the current page.
        prefix_block_key: key of the preceding block, or empty/None for the
            first page.

    Returns:
        The composed key string.
    """
    local_rank = get_tensor_model_parallel_rank()
    prefix_str = ""
    # Collapsed the original redundant double truthiness check
    # (`if prefix_block_key: if len(prefix_block_key):`) into one guard.
    if prefix_block_key:
        prefix_str = hashlib.sha256(prefix_block_key.encode()).hexdigest()
    current_token_ids_bytes = np.array(current_page_ids).tobytes()
    current_hash_hex = hashlib.sha256(current_token_ids_bytes).hexdigest()
    return f"{prefix_str}_{int(current_hash_hex[:16], 16)}_{local_rank}"
@dataclass
class MooncakeStoreConfig:
    """Connection and sizing parameters for a Mooncake distributed store."""

    local_hostname: str
    metadata_server: str
    global_segment_size: int  # bytes this node contributes to the global pool
    local_buffer_size: int  # bytes for the local transfer buffer
    protocol: str  # e.g. "tcp" or "rdma"
    device_name: str  # RDMA device list, or "auto" for auto-discovery
    master_server_address: str  # host:port of the Mooncake master

    @staticmethod
    def from_file() -> "MooncakeStoreConfig":
        """Load the config from the JSON file named by MOONCAKE_CONFIG_PATH.

        Missing size fields fall back to the module defaults; protocol
        defaults to "tcp" and device_name to "auto".

        Raises:
            ValueError: if MOONCAKE_CONFIG_PATH is not set.
        """
        file_path = os.getenv("MOONCAKE_CONFIG_PATH")
        if file_path is None:
            raise ValueError(
                "The environment variable 'MOONCAKE_CONFIG_PATH' is not set."
            )
        with open(file_path) as fin:
            config = json.load(fin)
        return MooncakeStoreConfig(
            local_hostname=config.get("local_hostname"),
            metadata_server=config.get("metadata_server"),
            global_segment_size=config.get(
                "global_segment_size", DEFAULT_GLOBAL_SEGMENT_SIZE
            ),
            local_buffer_size=config.get(
                "local_buffer_size", DEFAULT_LOCAL_BUFFER_SIZE
            ),
            protocol=config.get("protocol", "tcp"),
            device_name=config.get("device_name", "auto"),
            master_server_address=config.get("master_server_address"),
        )

    @staticmethod
    def load_from_env() -> "MooncakeStoreConfig":
        """Build the config directly from environment variables.

        Expected variables, for example:
            export MOONCAKE_MASTER=10.13.3.232:50051
            export MOONCAKE_PROTOCOL="rdma"
            export MOONCAKE_DEVICE="auto"
            export MOONCAKE_TE_META_DATA_SERVER="P2PHANDSHAKE"

        Raises:
            ValueError: if MOONCAKE_MASTER is not set.
        """
        # other required environment variables...
        if not os.getenv("MOONCAKE_MASTER"):
            raise ValueError("The environment variable 'MOONCAKE_MASTER' is not set.")
        return MooncakeStoreConfig(
            local_hostname=os.getenv("LOCAL_HOSTNAME", "localhost"),
            metadata_server=os.getenv("MOONCAKE_TE_META_DATA_SERVER", "P2PHANDSHAKE"),
            global_segment_size=int(
                os.getenv("MOONCAKE_GLOBAL_SEGMENT_SIZE", DEFAULT_GLOBAL_SEGMENT_SIZE)
            ),
            local_buffer_size=int(
                os.getenv("MOONCAKE_LOCAL_BUFFER_SIZE", DEFAULT_LOCAL_BUFFER_SIZE)
            ),
            protocol=os.getenv("MOONCAKE_PROTOCOL", "tcp"),
            device_name=os.getenv("MOONCAKE_DEVICE", "auto"),
            master_server_address=os.getenv("MOONCAKE_MASTER"),
        )

    def __post_init__(self):
        # "auto" device selection: enable Mooncake's NIC auto-discovery,
        # restricted to the filter list below.
        # NOTE(review): the mlx5_bond_* filter list looks site-specific —
        # confirm it matches the target deployment's RDMA devices.
        if self.device_name == "auto":
            os.environ["MC_MS_AUTO_DISC"] = "1"
            os.environ["MC_MS_FILTERS"] = (
                "mlx5_bond_0, mlx5_bond_1, mlx5_bond_2, mlx5_bond_3"
            )
class MooncakeStore(HiCacheStorage):
    """HiCacheStorage backend backed by a Mooncake distributed store (L3 cache).

    All operations are batch-oriented: ``key``/``keys`` is a list of storage
    keys, ``target_location`` a parallel list of raw host pointers into a
    buffer previously registered via :meth:`register_buffer`, and
    ``target_sizes`` the per-entry byte counts. Transfers are zero-copy.
    """

    def __init__(self):
        try:
            from mooncake.store import MooncakeDistributedStore
        except ImportError as e:
            raise ImportError(
                "Please install mooncake by following the instructions at "
                "https://kvcache-ai.github.io/Mooncake/getting_started/build.html"
                "to run SGLang with MooncakeConnector."
            ) from e

        try:
            self.store = MooncakeDistributedStore()
            self.config = MooncakeStoreConfig.load_from_env()
            logger.info("Mooncake Configuration loaded from env successfully.")
            ret_code = self.store.setup(
                self.config.local_hostname,
                self.config.metadata_server,
                self.config.global_segment_size,
                self.config.local_buffer_size,
                self.config.protocol,
                self.config.device_name,
                self.config.master_server_address,
            )
            if ret_code:
                # NOTE(review): setup failure is only logged; warmup() below
                # will most likely fail anyway — consider raising here.
                logger.error(f"failed to setup mooncake store, error code: {ret_code}")
            logger.info("Connect to Mooncake store successfully.")
            self.warmup()
            logger.info("Mooncake store warmup successfully.")
        except ValueError as e:
            logger.error("Configuration loading failed: %s", e)
            raise
        except Exception as exc:
            logger.error("An error occurred while loading the configuration: %s", exc)
            raise

    def warmup(self):
        """Exercise one put/exist/get/remove round-trip so the first real
        request does not pay first-touch/connection latency."""
        warmup_key = "sglang_mooncake_store_warmup_key" + uuid.uuid4().hex
        # 10 MB
        warmup_value = bytes(10 * 1024 * 1024)
        self.store.put(warmup_key, warmup_value)
        assert self.store.is_exist(warmup_key) == 1
        self.store.get(warmup_key)
        self.store.remove(warmup_key)

    def register_buffer(self, buffer: torch.Tensor) -> None:
        """Register a host tensor with the transfer engine for zero-copy I/O.

        Raises:
            TypeError: if the underlying register call rejects the arguments.
        """
        try:
            buffer_ptr = buffer.data_ptr()
            buffer_size = buffer.numel() * buffer.element_size()
            ret_code = self.store.register_buffer(buffer_ptr, buffer_size)
            if ret_code:
                logger.error(f"failed to register buffer, error code: {ret_code}")
        except TypeError as err:
            logger.error("Failed to register buffer to Mooncake Store: %s", err)
            raise TypeError("Mooncake Store Register Buffer Error.") from err

    @staticmethod
    def _batch_args_valid(keys, ptrs, sizes) -> bool:
        # Shared validation for all batch entry points: every argument must be
        # present, the three lists equal-length and non-empty, and no entry
        # may be None. (The original asserted lengths, which raised TypeError
        # on a None list and was stripped under `python -O`.)
        if keys is None or ptrs is None or sizes is None:
            return False
        if not (len(keys) == len(ptrs) == len(sizes)) or len(keys) == 0:
            return False
        return all(
            k is not None and p is not None and s is not None
            for k, p, s in zip(keys, ptrs, sizes)
        )

    def set(
        self,
        key,
        value: Optional[Any] = None,
        target_location: Optional[List[int]] = None,
        target_sizes: Optional[List[int]] = None,
    ) -> bool:
        """Store a batch of buffers; returns True on success, False on
        invalid/empty arguments (fixes the original, which never returned a
        bool despite its declared return type)."""
        if not self._batch_args_valid(key, target_location, target_sizes):
            return False
        self._put_batch_zero_copy_impl(key, target_location, target_sizes)
        return True

    def batch_set(
        self,
        keys: List[str],
        value: Optional[Any] = None,
        target_location: Optional[List[int]] = None,
        target_sizes: Optional[List[int]] = None,
    ) -> bool:
        """Identical to :meth:`set`; kept for HiCacheStorage interface parity."""
        if not self._batch_args_valid(keys, target_location, target_sizes):
            return False
        self._put_batch_zero_copy_impl(keys, target_location, target_sizes)
        return True

    def get(
        self,
        key,
        target_location: Optional[Any] = None,
        target_sizes: Optional[Any] = None,
    ) -> torch.Tensor | None:
        """Read a batch of entries into the given pointers; returns None on
        invalid/empty arguments."""
        if not self._batch_args_valid(key, target_location, target_sizes):
            return None
        return self._get_batch_zero_copy_impl(key, target_location, target_sizes)

    def batch_get(
        self,
        keys: List[str],
        target_location: Optional[Any] = None,
        target_sizes: Optional[Any] = None,
    ) -> torch.Tensor | None:
        """Identical to :meth:`get`; kept for HiCacheStorage interface parity."""
        if not self._batch_args_valid(keys, target_location, target_sizes):
            return None
        return self._get_batch_zero_copy_impl(keys, target_location, target_sizes)

    def exists(self, keys) -> bool | dict:
        """Map each key to its existence flag in the store.

        Returns None if any key is None (NOTE: departs from the declared
        return type, matching the original behavior).
        """
        _keys = []
        local_rank = torch.cuda.current_device()
        for key in keys:
            if key is None:
                return None
            # Since mooncake store is stored in layer by layer,
            # only the first layer is checked here.
            # NOTE(review): the suffix uses the CUDA device index where stored
            # keys use a layer id (f"{key}_{layer_id}_k" in get_buffer_meta) —
            # confirm the device-index == layer-0 assumption on multi-GPU ranks.
            _keys.append(f"{key}_{local_rank}_k")
        result = {k: v for k, v in zip(keys, self.store.batch_is_exist(_keys))}
        return result

    def delete(self, key) -> None:
        """Deletion is not supported by this backend."""
        raise NotImplementedError

    def close(self):
        # MooncakeDistributedStore will automatically call the destructor, so
        # it is unnecessary to close it manually.
        pass

    def clear(self) -> None:
        """Bulk clearing is not supported by this backend."""
        raise NotImplementedError

    def _put_batch_zero_copy_impl(
        self, key_strs: List[str], buffer_ptrs: List[int], buffer_sizes: List[int]
    ) -> None:
        # Zero-copy batch write from registered host memory.
        try:
            self.store.batch_put_from(key_strs, buffer_ptrs, buffer_sizes)
        except TypeError as err:
            logger.error("Failed to put value to Mooncake Store: %s", err)
            raise TypeError("Mooncake Store Put Type Error.") from err

    def _get_batch_zero_copy_impl(
        self, key_strs: List[str], buffer_ptrs: List[int], buffer_sizes: List[int]
    ) -> None:
        # Zero-copy batch read into registered host memory.
        try:
            self.store.batch_get_into(key_strs, buffer_ptrs, buffer_sizes)
        except TypeError as err:
            logger.error("Failed to get value from Mooncake Store: %s", err)
            raise TypeError("Mooncake Store Get Type Error.") from err

View File

@@ -0,0 +1,40 @@
import torch
from mooncake_store import MooncakeStore
def test_init_and_warmup():
    # Integration test: requires a live Mooncake master + metadata server
    # (see env vars in MooncakeStoreConfig.load_from_env). The constructor
    # connects and runs a put/get warmup round-trip.
    store = MooncakeStore()
    assert store.store is not None
def test_register_buffer():
    # Registers a host tensor with the transfer engine; success == no exception.
    store = MooncakeStore()
    tensor = torch.zeros(1024, dtype=torch.float32)
    store.register_buffer(tensor)
def test_set_and_get():
    # Round-trips two keys that share one CUDA buffer (requires a GPU).
    # The store API takes parallel lists of keys, raw pointers, and sizes.
    store = MooncakeStore()
    key = ["test_key_" + str(i) for i in range(2)]
    tensor = torch.arange(256, dtype=torch.float32).cuda()
    ptrs = [tensor.data_ptr(), tensor.data_ptr()]
    sizes = [tensor.numel() * tensor.element_size()] * 2
    store.set(key, target_location=ptrs, target_sizes=sizes)
    store.get(key, target_location=ptrs, target_sizes=sizes)
def test_exists():
    # exists() returns a dict keyed by the original (un-suffixed) keys.
    store = MooncakeStore()
    keys = ["test_key_0", "non_existent_key"]
    result = store.exists(keys)
    assert isinstance(result, dict)
    assert "test_key_0" in result
# Ad-hoc integration runner: executes each check in sequence against a live
# Mooncake deployment (master + metadata server must already be running).
if __name__ == "__main__":
    test_init_and_warmup()
    test_register_buffer()
    test_set_and_get()
    test_exists()