From d904959233bb21f5bb713ac3da46da616160d3f3 Mon Sep 17 00:00:00 2001 From: huangtingwei <141888744+huangtingwei9988@users.noreply.github.com> Date: Thu, 31 Jul 2025 14:15:51 +0800 Subject: [PATCH] Support l3 cache (mooncake store) for hiradix cache (#7211) Co-authored-by: Zhiqiang Xie Co-authored-by: AniZpZ Co-authored-by: zuoyuan Co-authored-by: @wangyueneng.wyn Co-authored-by: JinYan Su --- .../sglang/srt/managers/cache_controller.py | 161 ++++++++--- .../sglang/srt/mem_cache/hicache_storage.py | 54 +++- python/sglang/srt/mem_cache/hiradix_cache.py | 4 + .../sglang/srt/mem_cache/memory_pool_host.py | 64 +++++ .../srt/mem_cache/mooncake_store/README.md | 71 +++++ .../mooncake_store/mooncake_store.py | 264 ++++++++++++++++++ .../srt/mem_cache/mooncake_store/unit_test.py | 40 +++ python/sglang/srt/server_args.py | 2 +- 8 files changed, 607 insertions(+), 53 deletions(-) create mode 100644 python/sglang/srt/mem_cache/mooncake_store/README.md create mode 100644 python/sglang/srt/mem_cache/mooncake_store/mooncake_store.py create mode 100644 python/sglang/srt/mem_cache/mooncake_store/unit_test.py diff --git a/python/sglang/srt/managers/cache_controller.py b/python/sglang/srt/managers/cache_controller.py index 629e77748..a6e48961c 100644 --- a/python/sglang/srt/managers/cache_controller.py +++ b/python/sglang/srt/managers/cache_controller.py @@ -26,6 +26,10 @@ if TYPE_CHECKING: from sglang.srt.mem_cache.memory_pool_host import HostKVCache from sglang.srt.mem_cache.hicache_storage import HiCacheFile, get_hash_str +from sglang.srt.mem_cache.mooncake_store.mooncake_store import ( + MooncakeStore, + get_hash_str_mooncake, +) from sglang.srt.mem_cache.storage.hf3fs.storage_hf3fs import HiCacheHF3FS logger = logging.getLogger(__name__) @@ -125,7 +129,7 @@ class TransferBuffer: """ def __init__( - self, stop_event, buffer_count: int = 3, max_buffer_size: int = 1000 + self, stop_event, buffer_count: int = 3, max_buffer_size: int = 1024 ) -> None: self.stop_event = stop_event 
self.buffers = Queue(maxsize=buffer_count) @@ -260,6 +264,11 @@ class HiCacheController: if storage_backend == "file": self.storage_backend = HiCacheFile() + self.get_hash_str = get_hash_str + elif storage_backend == "mooncake": + self.storage_backend = MooncakeStore() + self.get_hash_str = get_hash_str_mooncake + self.storage_backend.register_buffer(self.mem_pool_host.kv_buffer) elif storage_backend == "hf3fs": from sglang.srt.distributed import get_tensor_model_parallel_rank @@ -271,6 +280,7 @@ class HiCacheController: self.storage_backend = HiCacheHF3FS.from_env_config( rank, bytes_per_page, dtype ) + self.get_hash_str = get_hash_str else: raise NotImplementedError( f"Unsupported storage backend: {storage_backend}" @@ -532,6 +542,37 @@ class HiCacheController: operation.mark_done() return operation.completed_tokens, operation.hash_value + def generic_page_transfer(self, operation, batch_size=8): + for i in range(0, len(operation.hash_value), batch_size): + page_hashes = operation.hash_value[i : i + batch_size] + page_data = self.storage_backend.batch_get(page_hashes) + if page_data is None: + logger.warning( + f"Prefetch operation {operation.request_id} failed to retrieve page {page_hashes}." 
+ ) + break + completed_tokens = operation.completed_tokens + if operation.increment(self.page_size * len(page_hashes)): + for i in range(len(page_hashes)): + self.mem_pool_host.set_from_flat_data_page( + operation.host_indices[completed_tokens], + page_data[i], + ) + completed_tokens += self.page_size + else: + # operation terminated by controller, release pre-allocated memory + self.mem_pool_host.free( + operation.host_indices[operation.completed_tokens :] + ) + break + + def mooncake_page_transfer(self, operation): + key_strs, buffer_ptrs, buffer_sizes = self.mem_pool_host.get_buffer_meta( + operation.hash_value, operation.host_indices + ) + self.storage_backend.batch_get(key_strs, buffer_ptrs, buffer_sizes) + operation.increment(len(operation.hash_value) * self.page_size) + def prefetch_io_aux_func(self): """ Auxiliary function conducting IO operations for prefetching. @@ -539,26 +580,10 @@ class HiCacheController: while not self.stop_event.is_set(): try: operation = self.prefetch_buffer.get(block=True, timeout=1) - page_datas = self.storage_backend.batch_get(operation.hash_value) - for h, page_data in zip(operation.hash_value, page_datas): - if page_data is None: - logger.warning( - f"Prefetch operation {operation.request_id} failed to retrieve page {h}." 
- ) - break - if operation.increment(self.page_size): - self.mem_pool_host.set_from_flat_data_page( - operation.host_indices[ - operation.completed_tokens - self.page_size - ], - page_data, - ) - else: - # operation terminated by controller, release pre-allocated memory - self.mem_pool_host.free( - operation.host_indices[operation.completed_tokens :] - ) - break + if isinstance(self.storage_backend, MooncakeStore): + self.mooncake_page_transfer(operation) + else: + self.generic_page_transfer(operation) except Empty: continue @@ -582,18 +607,27 @@ class HiCacheController: remaining_tokens = len(tokens_to_fetch) hash_value = [] while remaining_tokens >= self.page_size: - last_hash = get_hash_str( + last_hash = self.get_hash_str( tokens_to_fetch[ storage_hit_count : storage_hit_count + self.page_size ], last_hash, ) - if self.storage_backend.exists(last_hash): - storage_hit_count += self.page_size - hash_value.append(last_hash) - remaining_tokens -= self.page_size - else: - break + + # todo, more unified interface + if not isinstance(self.storage_backend, MooncakeStore): + if not self.storage_backend.exists(last_hash): + break + hash_value.append(last_hash) + storage_hit_count += self.page_size + remaining_tokens -= self.page_size + + if isinstance(self.storage_backend, MooncakeStore): + # deferring to batch exists for mooncake store + exist_result = self.storage_backend.exists(hash_value) + storage_hit_count = ( + sum(1 for v in exist_result.values() if v != 0) * self.page_size + ) if self.tp_world_size > 1: storage_hit_count_tensor = torch.tensor( @@ -641,6 +675,47 @@ class HiCacheController: self.backup_queue.put(operation) return operation.id + def generic_page_backup(self, operation, batch_size=8): + for i in range(0, len(operation.hash_value), batch_size): + page_hashes = operation.hash_value[i : i + batch_size] + page_data = [ + self.mem_pool_host.get_flat_data_pages( + operation.host_indices[j * self.page_size] + ) + for j in range(i, i + len(page_hashes)) + ] 
+ success = self.storage_backend.batch_set(page_hashes, page_data) + if not success: + logger.warning(f"Failed to write page {page_hashes} to storage.") + break + operation.completed_tokens += self.page_size * len(page_hashes) + + def mooncake_page_backup(self, operation): + if len(operation.hash_value): + exist_hashvalues = self.storage_backend.exists(operation.hash_value) + indices = operation.host_indices.tolist() + non_exist_keys = [] + non_exist_indices = [] + for i in range(len(operation.hash_value)): + if not exist_hashvalues[operation.hash_value[i]]: + non_exist_keys.append(operation.hash_value[i]) + non_exist_indices.extend( + indices[i * self.page_size : (i + 1) * self.page_size] + ) + if len(non_exist_keys) > 0: + key_strs, buffer_ptrs, buffer_sizes = ( + self.mem_pool_host.get_buffer_meta( + non_exist_keys, non_exist_indices + ) + ) + # TODO: check the return value of batch set to see how many tokens are set successfully + self.storage_backend.batch_set( + key_strs, + target_location=buffer_ptrs, + target_sizes=buffer_sizes, + ) + operation.completed_tokens += len(operation.hash_value) * self.page_size + def backup_thread_func(self): """ Manage backup operations from host memory to storage backend. 
@@ -654,23 +729,25 @@ class HiCacheController: last_hash = operation.last_hash tokens_to_backup = operation.token_ids - last_hashes, data_pages = [], [] - for i in range(0, len(tokens_to_backup), self.page_size): - last_hash = get_hash_str( - tokens_to_backup[i : i + self.page_size], last_hash + backup_hit_count = 0 + remaining_tokens = len(tokens_to_backup) + hash_value = [] + while remaining_tokens >= self.page_size: + last_hash = self.get_hash_str( + tokens_to_backup[ + backup_hit_count : backup_hit_count + self.page_size + ], + last_hash, ) - data_page = self.mem_pool_host.get_flat_data_page( - operation.host_indices[i] - ) - last_hashes.append(last_hash) - data_pages.append(data_page) + backup_hit_count += self.page_size + hash_value.append(last_hash) + remaining_tokens -= self.page_size + operation.hash_value = hash_value - success = self.storage_backend.batch_set(last_hashes, data_pages) - if not success: - logger.warning(f"Failed to write page {last_hashes} to storage.") + if isinstance(self.storage_backend, MooncakeStore): + self.mooncake_page_backup(operation) else: - operation.completed_tokens += len(tokens_to_backup) - operation.hash_value.extend(last_hashes) + self.generic_page_backup(operation) min_completed_tokens = operation.completed_tokens if self.tp_world_size > 1: diff --git a/python/sglang/srt/mem_cache/hicache_storage.py b/python/sglang/srt/mem_cache/hicache_storage.py index 0e4a7184c..d0dec8ac9 100644 --- a/python/sglang/srt/mem_cache/hicache_storage.py +++ b/python/sglang/srt/mem_cache/hicache_storage.py @@ -2,7 +2,7 @@ import hashlib import logging import os from abc import ABC, abstractmethod -from typing import List, Optional +from typing import Any, List, Optional import torch @@ -39,7 +39,10 @@ class HiCacheStorage(ABC): @abstractmethod def get( - self, key: str, target_location: Optional[torch.Tensor] = None + self, + key: str, + target_location: Optional[Any] = None, + target_sizes: Optional[Any] = None, ) -> torch.Tensor | None: """ 
Retrieve the value associated with the given key. @@ -49,7 +52,10 @@ class HiCacheStorage(ABC): @abstractmethod def batch_get( - self, keys: List[str], target_locations: Optional[List[torch.Tensor]] = None + self, + keys: List[str], + target_locations: Optional[Any] = None, + target_sizes: Optional[Any] = None, ) -> List[torch.Tensor | None]: """ Retrieve values for multiple keys. @@ -58,7 +64,13 @@ class HiCacheStorage(ABC): pass @abstractmethod - def set(self, key, value) -> bool: + def set( + self, + key: str, + value: Optional[Any] = None, + target_location: Optional[Any] = None, + target_sizes: Optional[Any] = None, + ) -> bool: """ Store the value associated with the given key. Returns True if the operation was successful, False otherwise. @@ -66,7 +78,13 @@ class HiCacheStorage(ABC): pass @abstractmethod - def batch_set(self, keys: List[str], values: List[torch.Tensor]) -> bool: + def batch_set( + self, + keys: List[str], + values: Optional[Any] = None, + target_locations: Optional[Any] = None, + target_sizes: Optional[Any] = None, + ) -> bool: """ Store multiple key-value pairs. Returns True if all operations were successful, False otherwise. @@ -74,7 +92,7 @@ class HiCacheStorage(ABC): pass @abstractmethod - def exists(self, key: str) -> bool: + def exists(self, key: str) -> bool | dict: """ Check if the key exists in the storage. Returns True if the key exists, False otherwise. 
@@ -97,7 +115,10 @@ class HiCacheFile(HiCacheStorage): return key + self.tp_suffix def get( - self, key: str, target_location: Optional[torch.Tensor] = None + self, + key: str, + target_location: Optional[Any] = None, + target_sizes: Optional[Any] = None, ) -> torch.Tensor | None: key = self._get_suffixed_key(key) tensor_path = os.path.join(self.file_path, f"{key}.bin") @@ -115,7 +136,8 @@ class HiCacheFile(HiCacheStorage): def batch_get( self, keys: List[str], - target_locations: Optional[List[torch.Tensor]] = None, + target_locations: Optional[Any] = None, + target_sizes: Optional[Any] = None, ) -> List[torch.Tensor | None]: return [ self.get(key, target_location) @@ -124,7 +146,13 @@ class HiCacheFile(HiCacheStorage): ) ] - def set(self, key: str, value: torch.Tensor) -> bool: + def set( + self, + key: str, + value: Optional[Any] = None, + target_location: Optional[Any] = None, + target_sizes: Optional[Any] = None, + ) -> bool: key = self._get_suffixed_key(key) tensor_path = os.path.join(self.file_path, f"{key}.bin") if self.exists(key): @@ -137,7 +165,13 @@ class HiCacheFile(HiCacheStorage): logger.error(f"Failed to save tensor {key}: {e}") return False - def batch_set(self, keys: List[str], values: List[torch.Tensor]) -> bool: + def batch_set( + self, + keys: List[str], + values: Optional[Any] = None, + target_locations: Optional[Any] = None, + target_sizes: Optional[Any] = None, + ) -> bool: for key, value in zip(keys, values): if not self.set(key, value): return False diff --git a/python/sglang/srt/mem_cache/hiradix_cache.py b/python/sglang/srt/mem_cache/hiradix_cache.py index ef61101d7..681985ad1 100644 --- a/python/sglang/srt/mem_cache/hiradix_cache.py +++ b/python/sglang/srt/mem_cache/hiradix_cache.py @@ -594,6 +594,10 @@ class HiRadixCache(RadixCache): if child.backuped: new_node.host_value = child.host_value[:split_len] child.host_value = child.host_value[split_len:] + + if child.hash_value: + new_node.hash_value = child.hash_value[: split_len // 
self.page_size] + child.hash_value = child.hash_value[split_len // self.page_size :] child.parent = new_node child.key = child.key[split_len:] new_node.parent.children[self.get_child_key_fn(key)] = new_node diff --git a/python/sglang/srt/mem_cache/memory_pool_host.py b/python/sglang/srt/mem_cache/memory_pool_host.py index c2fb4fa46..4202db801 100644 --- a/python/sglang/srt/mem_cache/memory_pool_host.py +++ b/python/sglang/srt/mem_cache/memory_pool_host.py @@ -265,6 +265,43 @@ class MHATokenToKVPoolHost(HostKVCache): self.head_dim, ) + def get_buffer_meta(self, keys, indices): + ptr_list = [] + key_list = [] + kv_buffer_data_ptr = self.kv_buffer.data_ptr() + v_offset = ( + self.layer_num + * self.size + * self.head_num + * self.head_dim + * self.dtype.itemsize + ) + for index in range(0, len(indices), self.page_size): + for layer_id in range(self.layer_num): + k_ptr = ( + kv_buffer_data_ptr + + indices[index] + * self.head_num + * self.head_dim + * self.dtype.itemsize + + layer_id + * self.size + * self.head_num + * self.head_dim + * self.dtype.itemsize + ) + v_ptr = k_ptr + v_offset + ptr_list.append(k_ptr) + ptr_list.append(v_ptr) + key_ = keys[index // self.page_size] + key_list.append(f"{key_}_{layer_id}_k") + key_list.append(f"{key_}_{layer_id}_v") + element_size = ( + self.dtype.itemsize * self.page_size * self.head_num * self.head_dim + ) + element_size_list = [element_size] * len(key_list) + return key_list, ptr_list, element_size_list + @property def k_buffer(self): return self.kv_buffer[0] @@ -325,3 +362,30 @@ class MLATokenToKVPoolHost(HostKVCache): 1, self.kv_lora_rank + self.qk_rope_head_dim, ) + + def get_buffer_meta(self, keys, indices): + ptr_list = [] + key_list = [] + kv_buffer_data_ptr = self.kv_buffer.data_ptr() + for index in range(0, len(indices), self.page_size): + for layer_id in range(self.layer_num): + k_ptr = ( + kv_buffer_data_ptr + + indices[index] + * (self.kv_lora_rank + self.qk_rope_head_dim) + * self.dtype.itemsize + + layer_id + * 
self.size + * (self.kv_lora_rank + self.qk_rope_head_dim) + * self.dtype.itemsize + ) + ptr_list.append(k_ptr) + key_ = keys[index // self.page_size] + key_list.append(f"{key_}_{layer_id}_k") + element_size = ( + self.dtype.itemsize + * self.page_size + * (self.kv_lora_rank + self.qk_rope_head_dim) + ) + element_size_list = [element_size] * len(key_list) + return key_list, ptr_list, element_size_list diff --git a/python/sglang/srt/mem_cache/mooncake_store/README.md b/python/sglang/srt/mem_cache/mooncake_store/README.md new file mode 100644 index 000000000..6ad71821e --- /dev/null +++ b/python/sglang/srt/mem_cache/mooncake_store/README.md @@ -0,0 +1,71 @@ +# Mooncake as L3 KV Cache + +This document describes how to use Mooncake as the L3 KV cache for SGLang. +For more details about Mooncake, please refer to: https://kvcache-ai.github.io/ + +## Install Mooncake + +### Method 1: with pip + +```bash +pip install mooncake-transfer-engine +``` + +### Method 2: from source + +Clone the Mooncake project: + +```bash +git clone https://github.com/kvcache-ai/Mooncake --recursive +``` + +Install dependencies: + +```bash +cd Mooncake +bash dependencies.sh +``` + +Build the project. For additional build options, please refer to [the official guide](https://kvcache-ai.github.io/Mooncake/getting_started/build.html). + +```bash +mkdir build +cd build +cmake .. +make -j +``` + +Install Mooncake: + +```bash +sudo make install +``` + +## Use Mooncake + +Launch the Mooncake master server: + +```bash +mooncake_master +``` + +Launch the Mooncake metadata server: + +```bash +python -m mooncake.http_metadata_server +``` + +Start the SGLang server with Mooncake enabled.
Mooncake configuration can be provided via environment variables: + +```bash +MOONCAKE_TE_META_DATA_SERVER="http://127.0.0.1:8080/metadata" \ +MOONCAKE_GLOBAL_SEGMENT_SIZE=4294967296 \ +MOONCAKE_LOCAL_BUFFER_SIZE=134217728 \ +MOONCAKE_PROTOCOL="rdma" \ +MOONCAKE_DEVICE="erdma_0,erdma_1" \ +MOONCAKE_MASTER=127.0.0.1:50051 \ +python -m sglang.launch_server \ + --enable-hierarchical-cache \ + --hicache-storage-backend mooncake \ + --model-path [model_path] +``` diff --git a/python/sglang/srt/mem_cache/mooncake_store/mooncake_store.py b/python/sglang/srt/mem_cache/mooncake_store/mooncake_store.py new file mode 100644 index 000000000..05dc7a3ce --- /dev/null +++ b/python/sglang/srt/mem_cache/mooncake_store/mooncake_store.py @@ -0,0 +1,264 @@ +import hashlib +import json +import logging +import os +import uuid +from dataclasses import dataclass +from typing import Any, List, Optional + +import numpy as np +import torch + +from sglang.srt.distributed import get_tensor_model_parallel_rank +from sglang.srt.mem_cache.hicache_storage import HiCacheStorage + +DEFAULT_GLOBAL_SEGMENT_SIZE = 4 * 1024 * 1024 * 1024 # 4 GiB +DEFAULT_LOCAL_BUFFER_SIZE = 128 * 1024 * 1024 # 128 MiB + +logger = logging.getLogger(__name__) + + +def get_hash_str_mooncake(current_page_ids: List, prefix_block_key: str): + local_rank = get_tensor_model_parallel_rank() + prefix_str = "" + if prefix_block_key: + if len(prefix_block_key): + prefix_str = hashlib.sha256(prefix_block_key.encode()).hexdigest() + current_token_ids_bytes = np.array(current_page_ids).tobytes() + current_hash_object = hashlib.sha256(current_token_ids_bytes) + current_hash_hex = current_hash_object.hexdigest() + return f"{prefix_str}_{int(current_hash_hex[:16], 16)}_{local_rank}" + + +@dataclass +class MooncakeStoreConfig: + local_hostname: str + metadata_server: str + global_segment_size: int + local_buffer_size: int + protocol: str + device_name: str + master_server_address: str + + @staticmethod + def from_file() -> 
"MooncakeStoreConfig": + """Load the config from a JSON file.""" + file_path = os.getenv("MOONCAKE_CONFIG_PATH") + if file_path is None: + raise ValueError( + "The environment variable 'MOONCAKE_CONFIG_PATH' is not set." + ) + with open(file_path) as fin: + config = json.load(fin) + return MooncakeStoreConfig( + local_hostname=config.get("local_hostname"), + metadata_server=config.get("metadata_server"), + global_segment_size=config.get( + "global_segment_size", DEFAULT_GLOBAL_SEGMENT_SIZE + ), + local_buffer_size=config.get( + "local_buffer_size", DEFAULT_LOCAL_BUFFER_SIZE + ), + protocol=config.get("protocol", "tcp"), + device_name=config.get("device_name", "auto"), + master_server_address=config.get("master_server_address"), + ) + + @staticmethod + def load_from_env() -> "MooncakeStoreConfig": + """Load config from a file specified in the environment variable. + export MOONCAKE_MASTER=10.13.3.232:50051 + export MOONCAKE_PROTOCOL="rdma" + export MOONCAKE_DEVICE="auto" + export MOONCAKE_TE_META_DATA_SERVER="P2PHANDSHAKE" + """ + # other required environment variables... 
+ if not os.getenv("MOONCAKE_MASTER"): + raise ValueError("The environment variable 'MOONCAKE_MASTER' is not set.") + return MooncakeStoreConfig( + local_hostname=os.getenv("LOCAL_HOSTNAME", "localhost"), + metadata_server=os.getenv("MOONCAKE_TE_META_DATA_SERVER", "P2PHANDSHAKE"), + global_segment_size=int( + os.getenv("MOONCAKE_GLOBAL_SEGMENT_SIZE", DEFAULT_GLOBAL_SEGMENT_SIZE) + ), + local_buffer_size=int( + os.getenv("MOONCAKE_LOCAL_BUFFER_SIZE", DEFAULT_LOCAL_BUFFER_SIZE) + ), + protocol=os.getenv("MOONCAKE_PROTOCOL", "tcp"), + device_name=os.getenv("MOONCAKE_DEVICE", "auto"), + master_server_address=os.getenv("MOONCAKE_MASTER"), + ) + + def __post_init__(self): + if self.device_name == "auto": + os.environ["MC_MS_AUTO_DISC"] = "1" + os.environ["MC_MS_FILTERS"] = ( + "mlx5_bond_0, mlx5_bond_1, mlx5_bond_2, mlx5_bond_3" + ) + + +class MooncakeStore(HiCacheStorage): + def __init__(self): + try: + from mooncake.store import MooncakeDistributedStore + except ImportError as e: + raise ImportError( + "Please install mooncake by following the instructions at " + "https://kvcache-ai.github.io/Mooncake/getting_started/build.html" + "to run SGLang with MooncakeConnector." 
+ ) from e + + try: + self.store = MooncakeDistributedStore() + self.config = MooncakeStoreConfig.load_from_env() + logger.info("Mooncake Configuration loaded from env successfully.") + + ret_code = self.store.setup( + self.config.local_hostname, + self.config.metadata_server, + self.config.global_segment_size, + self.config.local_buffer_size, + self.config.protocol, + self.config.device_name, + self.config.master_server_address, + ) + if ret_code: + logger.error(f"failed to setup mooncake store, error code: {ret_code}") + + logger.info("Connect to Mooncake store successfully.") + self.warmup() + logger.info("Mooncake store warmup successfully.") + + except ValueError as e: + logger.error("Configuration loading failed: %s", e) + raise + except Exception as exc: + logger.error("An error occurred while loading the configuration: %s", exc) + raise + + def warmup(self): + warmup_key = "sglang_mooncake_store_warmup_key" + uuid.uuid4().hex + # 10 MB + warmup_value = bytes(10 * 1024 * 1024) + self.store.put(warmup_key, warmup_value) + assert self.store.is_exist(warmup_key) == 1 + self.store.get(warmup_key) + self.store.remove(warmup_key) + + def register_buffer(self, buffer: torch.Tensor) -> None: + try: + buffer_ptr = buffer.data_ptr() + buffer_size = buffer.numel() * buffer.element_size() + ret_code = self.store.register_buffer(buffer_ptr, buffer_size) + if ret_code: + logger.error(f"failed to register buffer, error code: {ret_code}") + except TypeError as err: + logger.error("Failed to register buffer to Mooncake Store: %s", err) + raise TypeError("Mooncake Store Register Buffer Error.") from err + + def set( + self, + key, + value: Optional[Any] = None, + target_location: Optional[List[int]] = None, + target_sizes: Optional[List[int]] = None, + ) -> bool: + assert len(key) == len(target_location) == len(target_sizes) + if len(key) == 0: + return + + for i in range(len(key)): + if key[i] is None or target_location[i] is None or target_sizes[i] is None: + return + + 
self._put_batch_zero_copy_impl(key, target_location, target_sizes) + + def batch_set( + self, + keys: List[str], + value: Optional[Any] = None, + target_location: Optional[List[int]] = None, + target_sizes: Optional[List[int]] = None, + ) -> bool: + assert len(keys) == len(target_location) == len(target_sizes) + if len(keys) == 0: + return + + for i in range(len(keys)): + if keys[i] is None or target_location[i] is None or target_sizes[i] is None: + return + + self._put_batch_zero_copy_impl(keys, target_location, target_sizes) + + def get( + self, + key, + target_location: Optional[Any] = None, + target_sizes: Optional[Any] = None, + ) -> torch.Tensor | None: + assert len(key) == len(target_location) == len(target_sizes) + if len(key) == 0: + return + + for i in range(len(key)): + if key[i] is None or target_location[i] is None or target_sizes[i] is None: + return + + return self._get_batch_zero_copy_impl(key, target_location, target_sizes) + + def batch_get( + self, + keys: List[str], + target_location: Optional[Any] = None, + target_sizes: Optional[Any] = None, + ) -> torch.Tensor | None: + assert len(keys) == len(target_location) == len(target_sizes) + if len(keys) == 0: + return + + for i in range(len(keys)): + if keys[i] is None or target_location[i] is None or target_sizes[i] is None: + return + + return self._get_batch_zero_copy_impl(keys, target_location, target_sizes) + + def exists(self, keys) -> bool | dict: + _keys = [] + local_rank = torch.cuda.current_device() + for key in keys: + if key is None: + return None + # Since mooncake store is stored in layer by layer, + # only the first layer is checked here. + _keys.append(f"{key}_{local_rank}_k") + result = {k: v for k, v in zip(keys, self.store.batch_is_exist(_keys))} + return result + + def delete(self, key) -> None: + raise (NotImplementedError) + + def close(self): + # MooncakeDistributedStore will automatically call the destructor, so + # it is unnecessary to close it manually. 
+ pass + + def clear(self) -> None: + raise (NotImplementedError) + + def _put_batch_zero_copy_impl( + self, key_strs: List[str], buffer_ptrs: List[int], buffer_sizes: List[int] + ) -> None: + try: + self.store.batch_put_from(key_strs, buffer_ptrs, buffer_sizes) + except TypeError as err: + logger.error("Failed to put value to Mooncake Store: %s", err) + raise TypeError("Mooncake Store Put Type Error.") from err + + def _get_batch_zero_copy_impl( + self, key_strs: List[str], buffer_ptrs: List[int], buffer_sizes: List[int] + ) -> None: + try: + self.store.batch_get_into(key_strs, buffer_ptrs, buffer_sizes) + except TypeError as err: + logger.error("Failed to get value from Mooncake Store: %s", err) + raise TypeError("Mooncake Store Get Type Error.") from err diff --git a/python/sglang/srt/mem_cache/mooncake_store/unit_test.py b/python/sglang/srt/mem_cache/mooncake_store/unit_test.py new file mode 100644 index 000000000..801b0ec1b --- /dev/null +++ b/python/sglang/srt/mem_cache/mooncake_store/unit_test.py @@ -0,0 +1,40 @@ +import torch +from mooncake_store import MooncakeStore + + +def test_init_and_warmup(): + store = MooncakeStore() + assert store.store is not None + + +def test_register_buffer(): + store = MooncakeStore() + tensor = torch.zeros(1024, dtype=torch.float32) + store.register_buffer(tensor) + + +def test_set_and_get(): + store = MooncakeStore() + + key = ["test_key_" + str(i) for i in range(2)] + tensor = torch.arange(256, dtype=torch.float32).cuda() + ptrs = [tensor.data_ptr(), tensor.data_ptr()] + sizes = [tensor.numel() * tensor.element_size()] * 2 + + store.set(key, target_location=ptrs, target_sizes=sizes) + store.get(key, target_location=ptrs, target_sizes=sizes) + + +def test_exists(): + store = MooncakeStore() + keys = ["test_key_0", "non_existent_key"] + result = store.exists(keys) + assert isinstance(result, dict) + assert "test_key_0" in result + + +if __name__ == "__main__": + test_init_and_warmup() + test_register_buffer() + 
test_set_and_get() + test_exists() diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index d53558211..992905437 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -1476,7 +1476,7 @@ class ServerArgs: parser.add_argument( "--hicache-storage-backend", type=str, - choices=["file", "hf3fs"], # todo, mooncake + choices=["file", "mooncake", "hf3fs"], default=ServerArgs.hicache_storage_backend, help="The storage backend for hierarchical KV cache.", )