diff --git a/python/sglang/srt/managers/cache_controller.py b/python/sglang/srt/managers/cache_controller.py index 384dceb31..6c96c80a3 100644 --- a/python/sglang/srt/managers/cache_controller.py +++ b/python/sglang/srt/managers/cache_controller.py @@ -289,8 +289,6 @@ class HiCacheController: ) self.storage_backend = MooncakeStore(self.storage_config) - self.storage_backend.register_buffer(self.mem_pool_host.kv_buffer) - assert self.mem_pool_host.layout == "page_first" elif storage_backend == "hf3fs": from sglang.srt.mem_cache.storage.hf3fs.storage_hf3fs import ( HiCacheHF3FS, @@ -313,6 +311,8 @@ class HiCacheController: f"Unsupported storage backend: {storage_backend}" ) + self.storage_backend.register_mem_pool_host(self.mem_pool_host) + self.enable_storage = True # todo: threshold policy for prefetching self.prefetch_threshold = max(prefetch_threshold, self.page_size) @@ -335,18 +335,10 @@ class HiCacheController: # Select the get and set functions self.page_get_func = self._generic_page_get self.page_set_func = self._generic_page_set - self.batch_exists_func = self.storage_backend.batch_exists - self.is_3fs_zerocopy = ( - self.storage_backend_type == "hf3fs" - and self.mem_pool_host.layout == "page_first" - ) - if self.storage_backend_type == "mooncake": - self.page_get_func = self._mooncake_page_get - self.page_set_func = self._mooncake_page_set - elif self.is_3fs_zerocopy: - self.page_get_func = self._3fs_zero_copy_page_get - self.page_set_func = self._3fs_zero_copy_page_set - self.batch_exists_func = self._3fs_zero_copy_batch_exists + + if self.storage_backend_type in ["hf3fs", "mooncake"]: + self.page_get_func = self._page_get_zero_copy + self.page_set_func = self._page_set_zero_copy self.device = self.mem_pool_device.device self.layer_num = self.mem_pool_device.layer_num @@ -630,42 +622,19 @@ class HiCacheController: for chunk in chunks: self.host_mem_release_queue.put(chunk) - def _3fs_zero_copy_batch_exists(self, batch_hashes): - _batch_hashes, _, factor 
= self.mem_pool_host.get_buffer_with_hash(batch_hashes) - hit_page_num = self.storage_backend.batch_exists(_batch_hashes) // factor - return hit_page_num - - def _3fs_zero_copy_page_get(self, operation, hash_values, host_indices): - hashes, dsts, factor = self.mem_pool_host.get_buffer_with_hash( - hash_values, host_indices - ) - page_data = self.storage_backend.batch_get(hashes, dsts) - if page_data: - inc = self.page_size * len(hashes) // factor - operation.increment(inc) - else: - logger.warning( - f"Prefetch operation {operation.request_id} failed to retrieve page {hashes}." - ) - - def _mooncake_page_get(self, operation, hash_values, host_indices): - key_strs, buffer_ptrs, buffer_sizes = self.mem_pool_host.get_buffer_meta( - hash_values, - host_indices, - self.storage_config.tp_rank, - ) - get_result = self.storage_backend.batch_get( - key_strs, - target_locations=buffer_ptrs, - target_sizes=buffer_sizes, - ) - if get_result != len(hash_values): - logger.warning( - f"Prefetch operation {operation.request_id} failed or partially failed." - ) - if get_result != 0: - operation.increment(get_result * self.page_size) + def _page_get_zero_copy(self, operation, hash_values, host_indices): + results = self.storage_backend.batch_get_v1(hash_values, host_indices) + inc = 0 + for i in range(len(hash_values)): + if not results[i]: + logger.warning( + f"Prefetch operation {operation.request_id} failed to retrieve page {hash_values[i]}." 
+ ) + break + inc += self.page_size + operation.increment(inc) + # todo: deprecate def _generic_page_get(self, operation, hash_values, host_indices): dummy_page_dst = [ self.mem_pool_host.get_dummy_flat_data_page() for _ in hash_values @@ -755,7 +724,7 @@ class HiCacheController: batch_tokens[i : i + self.page_size], last_hash ) batch_hashes.append(last_hash) - hit_page_num = self.batch_exists_func(batch_hashes) + hit_page_num = self.storage_backend.batch_exists(batch_hashes) hash_value.extend(batch_hashes[:hit_page_num]) storage_query_count += hit_page_num * self.page_size if hit_page_num < len(batch_hashes): @@ -824,34 +793,16 @@ class HiCacheController: self.backup_queue.put(operation) return operation.id - # non-zero copy + # todo: deprecate def _generic_page_set(self, hash_values, host_indices) -> bool: data = [ - self.mem_pool_host.get_flat_data_page(host_indices[i * self.page_size]) + self.mem_pool_host.get_data_page(host_indices[i * self.page_size]) for i in range(len(hash_values)) ] return self.storage_backend.batch_set(hash_values, data) - # zero copy - def _mooncake_page_set(self, hash_values, host_indices) -> bool: - key_strs, buffer_ptrs, buffer_sizes = self.mem_pool_host.get_buffer_meta( - hash_values, - host_indices, - self.storage_config.tp_rank, - ) - success = self.storage_backend.batch_set( - key_strs, - target_locations=buffer_ptrs, - target_sizes=buffer_sizes, - ) - return success - - # zero copy - def _3fs_zero_copy_page_set(self, hash_values, host_indices) -> bool: - hashes, dsts, _ = self.mem_pool_host.get_buffer_with_hash( - hash_values, host_indices - ) - return self.storage_backend.batch_set(hashes, dsts) + def _page_set_zero_copy(self, hash_values, host_indices) -> bool: + return all(self.storage_backend.batch_set_v1(hash_values, host_indices)) # Backup batch by batch def _page_backup(self, operation): diff --git a/python/sglang/srt/mem_cache/hicache_storage.py b/python/sglang/srt/mem_cache/hicache_storage.py index 6ec077db5..8b21446b9 
100644 --- a/python/sglang/srt/mem_cache/hicache_storage.py +++ b/python/sglang/srt/mem_cache/hicache_storage.py @@ -7,6 +7,8 @@ from typing import Any, List, Optional import torch +from sglang.srt.mem_cache.memory_pool_host import HostKVCache + logger = logging.getLogger(__name__) @@ -32,15 +34,46 @@ class HiCacheStorageConfig: extra_config: Optional[dict] = None +@dataclass +class HiCacheStorageExtraInfo: + extra_info: Optional[dict] = None + + class HiCacheStorage(ABC): """ HiCacheStorage is a class that provides a generic key-value interface for storing and retrieving KV cache. It abstracts the underlying storage mechanism, allowing different implementations to be used. """ - # todo, potentially pass model and TP configs into storage backend # todo, the page size of storage backend does not have to be the same as the same as host memory pool + def register_mem_pool_host(self, mem_pool_host: HostKVCache): + self.mem_pool_host = mem_pool_host + + def batch_get_v1( + self, + keys: List[str], + host_indices: torch.Tensor, + extra_info: Optional[HiCacheStorageExtraInfo] = None, + ) -> List[bool]: + """ + Retrieve values for multiple keys. + Returns a list of tensors or None for each key. + """ + pass + + def batch_set_v1( + self, + keys: List[str], + host_indices: torch.Tensor, + extra_info: Optional[HiCacheStorageExtraInfo] = None, + ) -> List[bool]: + """ + Retrieve values for multiple keys. + Returns a list of tensors or None for each key. + """ + pass + @abstractmethod def get( self, @@ -54,6 +87,7 @@ class HiCacheStorage(ABC): """ pass + # TODO: Deprecate @abstractmethod def batch_get( self, @@ -81,6 +115,7 @@ class HiCacheStorage(ABC): """ pass + # TODO: Deprecate @abstractmethod def batch_set( self, @@ -103,6 +138,7 @@ class HiCacheStorage(ABC): """ pass + # TODO: Use a finer-grained return type (e.g., List[bool]) def batch_exists(self, keys: List[str]) -> int: """ Check if the keys exist in the storage. 
@@ -114,6 +150,9 @@ class HiCacheStorage(ABC): return i return len(keys) + def clear(self) -> None: + pass + def get_stats(self): return None diff --git a/python/sglang/srt/mem_cache/memory_pool_host.py b/python/sglang/srt/mem_cache/memory_pool_host.py index 079dc0a64..ab7538465 100644 --- a/python/sglang/srt/mem_cache/memory_pool_host.py +++ b/python/sglang/srt/mem_cache/memory_pool_host.py @@ -140,7 +140,7 @@ class HostKVCache(abc.ABC): raise NotImplementedError() @abc.abstractmethod - def get_flat_data_page(self, index) -> torch.Tensor: + def get_data_page(self, index, flat: bool = True) -> torch.Tensor: """ Get a flat data page from the host memory pool. """ @@ -461,16 +461,19 @@ class MHATokenToKVPoolHost(HostKVCache): else: raise ValueError(f"Unsupported IO backend: {io_backend}") - def get_flat_data_page(self, index) -> torch.Tensor: + def get_data_page(self, index, flat: bool = True) -> torch.Tensor: if self.layout == "layer_first": - return self.kv_buffer[:, :, index : index + self.page_size, :, :].flatten() + data_page = self.kv_buffer[:, :, index : index + self.page_size, :, :] elif self.layout == "page_first": - return self.kv_buffer[:, index : index + self.page_size, :, :, :].flatten() + data_page = self.kv_buffer[:, index : index + self.page_size, :, :, :] elif self.layout == "page_first_direct": real_index = index // self.page_size - return self.kv_buffer[:, real_index : real_index + 1, :, :, :, :].flatten() + data_page = self.kv_buffer[:, real_index : real_index + 1, :, :, :, :] else: raise ValueError(f"Unsupported layout: {self.layout}") + if flat: + data_page = data_page.flatten() + return data_page def get_dummy_flat_data_page(self) -> torch.Tensor: return torch.zeros( @@ -507,9 +510,12 @@ class MHATokenToKVPoolHost(HostKVCache): else: raise ValueError(f"Unsupported layout: {self.layout}") - def get_buffer_meta(self, keys, indices, local_rank): + def get_page_buffer_meta(self, indices): + """ " + meta data for zero copy + """ + assert 
len(indices) % self.page_size == 0 ptr_list = [] - key_list = [] kv_buffer_data_ptr = self.kv_buffer.data_ptr() indices = indices.tolist() v_offset = ( @@ -519,48 +525,52 @@ class MHATokenToKVPoolHost(HostKVCache): * self.head_dim * self.dtype.itemsize ) - for index in range(0, len(indices), self.page_size): - k_ptr = ( - kv_buffer_data_ptr - + indices[index] - * self.layer_num + if self.layout == "layer_first": + for index in range(0, len(indices), self.page_size): + for layer_id in range(self.layer_num): + k_ptr = ( + kv_buffer_data_ptr + + indices[index] + * self.head_num + * self.head_dim + * self.dtype.itemsize + + layer_id + * self.size + * self.head_num + * self.head_dim + * self.dtype.itemsize + ) + v_ptr = k_ptr + v_offset + ptr_list.append(k_ptr) + ptr_list.append(v_ptr) + element_size = ( + self.dtype.itemsize * self.page_size * self.head_num * self.head_dim + ) + element_size_list = [element_size] * len(ptr_list) + elif self.layout in ["page_first", "page_first_direct"]: + for index in range(0, len(indices), self.page_size): + k_ptr = ( + kv_buffer_data_ptr + + indices[index] + * self.layer_num + * self.head_num + * self.head_dim + * self.dtype.itemsize + ) + v_ptr = k_ptr + v_offset + ptr_list.append(k_ptr) + ptr_list.append(v_ptr) + element_size = ( + self.layer_num + * self.dtype.itemsize + * self.page_size * self.head_num * self.head_dim - * self.dtype.itemsize ) - v_ptr = k_ptr + v_offset - ptr_list.append(k_ptr) - ptr_list.append(v_ptr) - key_ = keys[index // self.page_size] - key_list.append(f"{key_}_{local_rank}_k") - key_list.append(f"{key_}_{local_rank}_v") - element_size = ( - self.layer_num - * self.dtype.itemsize - * self.page_size - * self.head_num - * self.head_dim - ) - element_size_list = [element_size] * len(key_list) - return key_list, ptr_list, element_size_list - - def get_buffer_with_hash(self, keys, indices=None): - assert self.layout == "page_first" - assert indices is None or (len(keys) == (len(indices) // self.page_size)) - - 
key_list = [] - buf_list = [] - - for i in range(len(keys)): - key = keys[i] - key_list.append(f"{key}-k") - key_list.append(f"{key}-v") - if indices is not None: - index = indices[i * self.page_size] - buf_list.append(self.k_buffer[index : index + self.page_size]) - buf_list.append(self.v_buffer[index : index + self.page_size]) - - return key_list, buf_list, 2 + element_size_list = [element_size] * len(ptr_list) + else: + raise ValueError(f"Unsupported layout: {self.layout}") + return ptr_list, element_size_list class MLATokenToKVPoolHost(HostKVCache): @@ -736,16 +746,19 @@ class MLATokenToKVPoolHost(HostKVCache): else: raise ValueError(f"Unsupported IO backend: {io_backend}") - def get_flat_data_page(self, index) -> torch.Tensor: + def get_data_page(self, index, flat: bool = True) -> torch.Tensor: if self.layout == "layer_first": - return self.kv_buffer[:, index : index + self.page_size, :, :].flatten() + data_page = self.kv_buffer[:, index : index + self.page_size, :, :] elif self.layout == "page_first": - return self.kv_buffer[index : index + self.page_size, :, :, :].flatten() + data_page = self.kv_buffer[index : index + self.page_size, :, :, :] elif self.layout == "page_first_direct": real_index = index // self.page_size - return self.kv_buffer[real_index : real_index + 1, :, :, :, :].flatten() + data_page = self.kv_buffer[real_index : real_index + 1, :, :, :, :] else: raise ValueError(f"Unsupported layout: {self.layout}") + if flat: + data_page = data_page.flatten() + return data_page def get_dummy_flat_data_page(self) -> torch.Tensor: return torch.zeros( @@ -787,40 +800,51 @@ class MLATokenToKVPoolHost(HostKVCache): else: raise ValueError(f"Unsupported layout: {self.layout}") - def get_buffer_meta(self, keys, indices, local_rank): + def get_page_buffer_meta(self, indices): + """ " + meta data for zero copy + """ + assert len(indices) % self.page_size == 0 ptr_list = [] - key_list = [] kv_buffer_data_ptr = self.kv_buffer.data_ptr() indices = indices.tolist() 
- for index in range(0, len(indices), self.page_size): - k_ptr = ( - kv_buffer_data_ptr - + indices[index] - * self.layer_num + if self.layout == "layer_first": + for index in range(0, len(indices), self.page_size): + for layer_id in range(self.layer_num): + k_ptr = ( + kv_buffer_data_ptr + + indices[index] + * (self.kv_lora_rank + self.qk_rope_head_dim) + * self.dtype.itemsize + + layer_id + * self.size + * (self.kv_lora_rank + self.qk_rope_head_dim) + * self.dtype.itemsize + ) + ptr_list.append(k_ptr) + element_size = ( + self.dtype.itemsize + * self.page_size * (self.kv_lora_rank + self.qk_rope_head_dim) - * self.dtype.itemsize ) - ptr_list.append(k_ptr) - key_ = keys[index // self.page_size] - key_list.append(f"{key_}_k") - element_size = ( - self.layer_num - * self.dtype.itemsize - * self.page_size - * (self.kv_lora_rank + self.qk_rope_head_dim) - ) - element_size_list = [element_size] * len(key_list) - return key_list, ptr_list, element_size_list - - def get_buffer_with_hash(self, keys, indices=None): - assert self.layout == "page_first" - assert indices is None or (len(keys) == (len(indices) // self.page_size)) - - buf_list = [] - - if indices is not None: - for i in range(len(keys)): - index = indices[i * self.page_size] - buf_list.append(self.kv_buffer[index : index + self.page_size]) - - return keys, buf_list, 1 + element_size_list = [element_size] * len(ptr_list) + elif self.layout in ["page_first", "page_first_direct"]: + for index in range(0, len(indices), self.page_size): + k_ptr = ( + kv_buffer_data_ptr + + indices[index] + * self.layer_num + * (self.kv_lora_rank + self.qk_rope_head_dim) + * self.dtype.itemsize + ) + ptr_list.append(k_ptr) + element_size = ( + self.layer_num + * self.dtype.itemsize + * self.page_size + * (self.kv_lora_rank + self.qk_rope_head_dim) + ) + element_size_list = [element_size] * len(ptr_list) + else: + raise ValueError(f"Unsupported layout: {self.layout}") + return ptr_list, element_size_list diff --git 
a/python/sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py b/python/sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py index 9595e7204..2a159e493 100644 --- a/python/sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +++ b/python/sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py @@ -12,7 +12,12 @@ from typing import Any, List, Optional, Tuple import torch -from sglang.srt.mem_cache.hicache_storage import HiCacheStorage, HiCacheStorageConfig +from sglang.srt.mem_cache.hicache_storage import ( + HiCacheStorage, + HiCacheStorageConfig, + HiCacheStorageExtraInfo, +) +from sglang.srt.mem_cache.memory_pool_host import HostKVCache from sglang.srt.mem_cache.storage.hf3fs.hf3fs_client import Hf3fsClient from sglang.srt.metrics.collector import StorageMetrics @@ -178,11 +183,14 @@ class HiCacheHF3FS(HiCacheStorage): self.skip_backup = True self.rank = 0 + self.is_zero_copy = False + logger.info( f"[Rank {self.rank}] HiCacheHF3FS Client Initializing: " f"file_path={self.file_path}, " f"file_size={self.file_size / (2 ** 30):.2f} GB, " - f"num_pages={self.num_pages}" + f"num_pages={self.num_pages}, " + f"is_mla_model={self.is_mla_model}" ) self.ac = AtomicCounter(self.numjobs) @@ -323,25 +331,12 @@ class HiCacheHF3FS(HiCacheStorage): use_mock_client=use_mock_client, ) - def get( - self, - key: str, - target_location: Optional[Any] = None, - target_sizes: Optional[Any] = None, - ) -> torch.Tensor | None: - return self.batch_get( - [key], - [target_location] if target_location is not None else None, - [target_sizes] if target_sizes is not None else None, - )[0] - @synchronized() - def batch_get( + def _batch_get( self, keys: List[str], - target_locations: Optional[Any] = None, - target_sizes: Optional[Any] = None, - ) -> List[torch.Tensor | None]: + values: List[torch.Tensor], + ) -> List[bool]: page_indices = self.metadata_client.get_page_indices(self.rank, keys) batch_indices, file_offsets = [], [] @@ -350,15 +345,9 @@ class HiCacheHF3FS(HiCacheStorage): 
batch_indices.append(i) file_offsets.append(page_index * self.bytes_per_page) - if target_locations is not None: - for target_location in target_locations: - assert target_location.is_contiguous() - file_results = target_locations - else: - file_results = [ - torch.empty(self.numel, dtype=self.dtype) - for _ in range(len(batch_indices)) - ] + for target_location in values: + assert target_location.is_contiguous() + file_results = values start_time = time.perf_counter() @@ -379,12 +368,10 @@ ionum / (end_time - start_time) * self.gb_per_page ) - results = [None] * len(keys) - for batch_index, file_result, read_result in zip( - batch_indices, file_results, read_results - ): + results = [False] * len(keys) + for batch_index, read_result in zip(batch_indices, read_results): if read_result == self.bytes_per_page: - results[batch_index] = file_result + results[batch_index] = True else: logger.error( f"[Rank {self.rank}] HiCacheHF3FS get {keys[batch_index]} failed" @@ -392,28 +379,12 @@ return results - def set( - self, - key: str, - value: Optional[Any] = None, - target_location: Optional[Any] = None, - target_sizes: Optional[Any] = None, - ) -> bool: - return self.batch_set( - [key], - [value] if value is not None else None, - [target_location] if target_location is not None else None, - [target_sizes] if target_sizes is not None else None, - ) - @synchronized() - def batch_set( + def _batch_set( self, keys: List[str], values: Optional[Any] = None, - target_locations: Optional[Any] = None, - target_sizes: Optional[Any] = None, - ) -> bool: + ) -> List[bool]: # In MLA backend, only one rank needs to backup the KV cache if self.skip_backup: - return True + return [True] * len(keys) @@ -474,7 +445,7 @@ self.rank, written_keys_to_confirm, pages_to_release ) - return all(results) + return results def delete(self, key: str) -> None: self.metadata_client.delete_keys(self.rank, [key]) @@ -484,21 
+455,25 @@ class HiCacheHF3FS(HiCacheStorage): return result[0] if result else False def batch_exists(self, keys: List[str]) -> int: + factor = 1 + if self.is_zero_copy and not self.is_mla_model: + keys = self._get_mha_zero_copy_keys(keys) + factor = 2 + results = self.metadata_client.exists(self.rank, keys) - for i in range(len(keys)): - if not results[i]: - return i - return len(keys) + i = 0 + while i < len(keys) and results[i]: + i += 1 - def clear(self) -> bool: + return i // factor + + def clear(self) -> None: try: self.metadata_client.clear(self.rank) logger.info(f"Cleared HiCacheHF3FS for rank {self.rank}") - return True except Exception as e: logger.error(f"Failed to clear HiCacheHF3FS: {e}") - return False def close(self) -> None: try: @@ -521,3 +496,139 @@ class HiCacheHF3FS(HiCacheStorage): self.prefetch_bandwidth.clear() self.backup_bandwidth.clear() return storage_metrics + + def register_mem_pool_host(self, mem_pool_host: HostKVCache): + super().register_mem_pool_host(mem_pool_host) + self.is_zero_copy = self.mem_pool_host.layout == "page_first" + logger.info(f"{self.is_zero_copy=}") + + def _get_mha_zero_copy_keys(self, keys: List[str]) -> List[str]: + _keys = [] + for k in keys: + _keys.append(f"{k}-k") + _keys.append(f"{k}-v") + return _keys + + def _get_mha_zero_copy_values( + self, values: List[torch.Tensor] + ) -> List[torch.Tensor]: + _values = [] + for value in values: + _values.append(value[0]) + _values.append(value[1]) + return _values + + def _batch_get_preprocess(self, keys, host_indices): + page_num = len(host_indices) // self.mem_pool_host.page_size + # host_indices to kv_buffer + flat = not self.is_zero_copy + values = ( + [ + self.mem_pool_host.get_data_page(host_indices[i * self.mem_pool_host.page_size], flat=flat) + for i in range(page_num) + ] + if self.is_zero_copy + else [ + self.mem_pool_host.get_dummy_flat_data_page() for _ in range(page_num) + ] + ) + + if self.is_zero_copy and not self.is_mla_model: + keys = self._get_mha_zero_copy_keys(keys) + 
values = self._get_mha_zero_copy_values(values) + + return keys, values + + def _batch_get_postprocess(self, host_indices, values, results): + page_num = len(host_indices) // self.mem_pool_host.page_size + + if self.is_zero_copy: + if not self.is_mla_model: + results = [ + (results[2 * i] and results[2 * i + 1]) for i in range(page_num) + ] + results = results[:page_num] + return results + + for i in range(page_num): + if not results[i]: + break + self.mem_pool_host.set_from_flat_data_page( + host_indices[i * self.mem_pool_host.page_size], values[i] + ) + + return results + + def batch_get_v1( + self, + keys: List[str], + host_indices: torch.Tensor, + extra_info: Optional[HiCacheStorageExtraInfo] = None, + ) -> List[bool]: + keys, values = self._batch_get_preprocess(keys, host_indices) + results = self._batch_get(keys, values) + return self._batch_get_postprocess(host_indices, values, results) + + def _batch_set_preprocess(self, keys, host_indices): + page_num = len(host_indices) // self.mem_pool_host.page_size + # host_indices to kv_buffer + flat = not self.is_zero_copy + values = [ + self.mem_pool_host.get_data_page(host_indices[i * self.mem_pool_host.page_size], flat=flat) + for i in range(page_num) + ] + + if self.is_zero_copy and not self.is_mla_model: + keys = self._get_mha_zero_copy_keys(keys) + values = self._get_mha_zero_copy_values(values) + + return keys, values + + def batch_set_v1( + self, + keys: List[str], + host_indices: torch.Tensor, + extra_info: Optional[HiCacheStorageExtraInfo] = None, + ) -> List[bool]: + len_keys = len(keys) + keys, values = self._batch_set_preprocess(keys, host_indices) + results = self._batch_set(keys, values) + return results + + # Deprecated + def get( + self, + key: str, + target_location: Optional[Any] = None, + target_sizes: Optional[Any] = None, + ) -> torch.Tensor | None: + pass + + # Deprecated + def batch_get( + self, + keys: List[str], + target_locations: Optional[Any] = None, + target_sizes: Optional[Any] = None, + ) -> 
List[torch.Tensor | None] | int: + pass + + # Deprecated + def set( + self, + key: str, + value: Optional[Any] = None, + target_location: Optional[Any] = None, + target_sizes: Optional[Any] = None, + ) -> bool: + pass + + # Deprecated + def batch_set( + self, + keys: List[str], + values: Optional[Any] = None, + target_locations: Optional[Any] = None, + target_sizes: Optional[Any] = None, + ) -> bool: + pass diff --git a/python/sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py b/python/sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py index 2704581e6..a2e05b4dd 100644 --- a/python/sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +++ b/python/sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py @@ -7,7 +7,12 @@ from typing import Any, List, Optional import torch -from sglang.srt.mem_cache.hicache_storage import HiCacheStorage, HiCacheStorageConfig +from sglang.srt.mem_cache.hicache_storage import ( + HiCacheStorage, + HiCacheStorageConfig, + HiCacheStorageExtraInfo, +) +from sglang.srt.mem_cache.memory_pool_host import HostKVCache DEFAULT_GLOBAL_SEGMENT_SIZE = 4 * 1024 * 1024 * 1024 # 4 GiB DEFAULT_LOCAL_BUFFER_SIZE = 16 * 1024 * 1024 # 16 MB @@ -183,7 +188,12 @@ class MooncakeStore(HiCacheStorage): assert self.store.is_exist(warmup_key) == 1 assert self.store.get(warmup_key) == warmup_value - def register_buffer(self, buffer: torch.Tensor) -> None: + def register_mem_pool_host(self, mem_pool_host: HostKVCache): + super().register_mem_pool_host(mem_pool_host) + assert ( + self.mem_pool_host.layout == "page_first" + ), "mooncake store storage backend only support page first layout" + buffer = self.mem_pool_host.kv_buffer try: buffer_ptr = buffer.data_ptr() buffer_size = buffer.numel() * buffer.element_size() @@ -194,6 +204,97 @@ class MooncakeStore(HiCacheStorage): logger.error("Failed to register buffer to Mooncake Store: %s", err) raise TypeError("Mooncake Store Register Buffer Error.") from err + def 
_get_mha_buffer_meta(self, keys, indices): + ptr_list, element_size_list = self.mem_pool_host.get_page_buffer_meta(indices) + key_list = [] + for key_ in keys: + key_list.append(f"{key_}_{self.local_rank}_k") + key_list.append(f"{key_}_{self.local_rank}_v") + assert len(key_list) == len(ptr_list) + return key_list, ptr_list, element_size_list + + def _get_mla_buffer_meta(self, keys, indices): + ptr_list, element_size_list = self.mem_pool_host.get_page_buffer_meta(indices) + key_list = [] + for key_ in keys: + key_list.append(f"{key_}_k") + assert len(key_list) == len(ptr_list) + return key_list, ptr_list, element_size_list + + def _batch_preprocess(self, keys, host_indices): + assert len(keys) > 0 + assert len(keys) == len(host_indices) // self.mem_pool_host.page_size + if self.is_mla_backend: + return self._get_mla_buffer_meta(keys, host_indices) + else: + return self._get_mha_buffer_meta(keys, host_indices) + + def _batch_postprocess(self, results: List[int], is_set_operate=False): + """ + refer to https://github.com/kvcache-ai/Mooncake/blob/main/mooncake-store/include/pybind_client.h + for batch_get_into, results is Vector of integers, + where each element is the number of bytes read on success, or a negative value on error + for batch_put_from, results is Vector of integers, + where each element is 0 on success, or a negative value on error + """ + if self.is_mla_backend: + return [k_res == 0 if is_set_operate else k_res > 0 for k_res in results] + else: + kv_pairs = zip(results[::2], results[1::2]) + return [ + ( + (k_res == 0 and v_res == 0) + if is_set_operate + else (k_res > 0 and v_res > 0) + ) + for k_res, v_res in kv_pairs + ] + + def batch_get_v1( + self, + keys: List[str], + host_indices: torch.Tensor, + extra_info: Optional[HiCacheStorageExtraInfo] = None, + ) -> List[bool]: + key_strs, buffer_ptrs, buffer_sizes = self._batch_preprocess(keys, host_indices) + get_results = self._get_batch_zero_copy_impl( + key_strs, buffer_ptrs, buffer_sizes + ) + 
return self._batch_postprocess(get_results, is_set_operate=False) + + def batch_set_v1( + self, + keys: List[str], + host_indices: torch.Tensor, + extra_info: Optional[HiCacheStorageExtraInfo] = None, + ) -> List[bool]: + key_strs, buffer_ptrs, buffer_sizes = self._batch_preprocess(keys, host_indices) + exist_result = self._batch_exist(key_strs) + + set_keys = [] + set_buffer_ptrs = [] + set_buffer_sizes = [] + set_indices = [] + set_results = [-1] * len(key_strs) + for i in range(len(key_strs)): + if exist_result[i] != 1: + set_keys.append(key_strs[i]) + set_buffer_ptrs.append(buffer_ptrs[i]) + set_buffer_sizes.append(buffer_sizes[i]) + set_indices.append(i) + else: + set_results[i] = 0 + + # Only set non-existing keys to storage + if len(set_keys) > 0: + put_results = self._put_batch_zero_copy_impl( + set_keys, set_buffer_ptrs, set_buffer_sizes + ) + for i in range(len(set_indices)): + set_results[set_indices[i]] = put_results[i] + + return self._batch_postprocess(set_results, is_set_operate=True) + def set( self, key,