diff --git a/python/sglang/srt/managers/cache_controller.py b/python/sglang/srt/managers/cache_controller.py index e031c3ada..d89d2b634 100644 --- a/python/sglang/srt/managers/cache_controller.py +++ b/python/sglang/srt/managers/cache_controller.py @@ -27,7 +27,7 @@ if TYPE_CHECKING: from sglang.srt.mem_cache.memory_pool_host import HostKVCache from sglang.srt.distributed import get_tensor_model_parallel_rank -from sglang.srt.mem_cache.memory_pool_host import MLATokenToKVPoolHost +from sglang.srt.mem_cache.memory_pool import MHATokenToKVPool, MLATokenToKVPool logger = logging.getLogger(__name__) @@ -240,28 +240,38 @@ class HiCacheController: self.io_backend = io_backend self.enable_storage = False - self.is_mla = isinstance(self.mem_pool_host, MLATokenToKVPoolHost) + # todo: move backend initialization to storage backend module if storage_backend is not None: self.storage_backend_type = storage_backend - from sglang.srt.mem_cache.hicache_storage import HiCacheFile, get_hash_str + from sglang.srt.mem_cache.hicache_storage import get_hash_str + self.get_hash_str = get_hash_str + + # Currently, AscendMLAPagedTokenToKVPool is the subclass of MLATokenToKVPool. + is_mla_backend = isinstance(self.mem_pool_device, MLATokenToKVPool) + # In MLA backend, only one rank needs to backup the KV cache + self.backup_skip = ( + is_mla_backend + # todo: for load balancing, decide which rank to backup the KV cache by hash value + and get_tensor_model_parallel_rank() != 0 + # todo: support other storage backends + and self.storage_backend_type in ["file", "mooncake"] + ) if storage_backend == "file": - self.storage_backend = HiCacheFile(is_mla=self.is_mla) - self.get_hash_str = get_hash_str + from sglang.srt.mem_cache.hicache_storage import HiCacheFile + + self.storage_backend = HiCacheFile(is_mla_backend=is_mla_backend) elif storage_backend == "nixl": from sglang.srt.mem_cache.storage.nixl.hicache_nixl import HiCacheNixl self.storage_backend = HiCacheNixl() - self.get_hash_str = get_hash_str elif storage_backend == "mooncake": from sglang.srt.mem_cache.storage.mooncake_store.mooncake_store import ( MooncakeStore, - get_hash_str_mooncake, ) - self.storage_backend = MooncakeStore(is_mla=self.is_mla) - self.get_hash_str = get_hash_str_mooncake + self.storage_backend = MooncakeStore(is_mla_backend=is_mla_backend) self.storage_backend.register_buffer(self.mem_pool_host.kv_buffer) assert self.mem_pool_host.layout == "page_first" elif storage_backend == "hf3fs": @@ -281,7 +291,6 @@ class HiCacheController: self.storage_backend = HiCacheHF3FS.from_env_config( bytes_per_page, dtype ) - self.get_hash_str = get_hash_str else: raise NotImplementedError( f"Unsupported storage backend: {storage_backend}" @@ -400,15 +409,6 @@ class HiCacheController: self.prefetch_thread.start() self.backup_thread.start() - @property - def backup_skip(self): - return ( - self.is_mla - and get_tensor_model_parallel_rank() != 0 - # todo: only support file and mooncake - and self.storage_backend_type in ["file", "mooncake"] - ) - def write( self, device_indices: torch.Tensor, @@ -570,57 +570,91 @@ class HiCacheController: operation.mark_done() return operation.completed_tokens, operation.hash_value - def zerocopy_page_transfer(self, operation, batch_size=8): + # zero copy + def _3fs_zero_copy_page_get(self, operation, hash_values, host_indices): hashes, dsts = self.mem_pool_host.get_buffer_with_hash( - operation.hash_value, operation.host_indices + hash_values, host_indices ) - for i in range(0, len(hashes), batch_size): - page_hashes = hashes[i : i + batch_size] - page_dsts = dsts[i : i + batch_size] - page_data = self.storage_backend.batch_get(page_hashes, page_dsts) - if page_data is None: - logger.warning( - f"Prefetch operation {operation.request_id} failed to retrieve page {page_hashes}." - ) - break - completed_tokens = operation.completed_tokens - if operation.increment(self.page_size * len(page_hashes)): - for i in range(len(page_hashes)): - completed_tokens += self.page_size - else: - break + page_data = self.storage_backend.batch_get(hashes, dsts) + if page_data: + operation.increment(self.page_size * len(hashes)) + else: + logger.warning( + f"Prefetch operation {operation.request_id} failed to retrieve page {hashes}." + ) - def generic_page_transfer(self, operation, batch_size=8): - for i in range(0, len(operation.hash_value), batch_size): - page_hashes = operation.hash_value[i : i + batch_size] - # todo: zero copy - dummy_page_dst = [ - self.mem_pool_host.get_dummy_flat_data_page() - for _ in range(len(page_hashes)) - ] - page_data = self.storage_backend.batch_get(page_hashes, dummy_page_dst) - if page_data is None: - logger.warning( - f"Prefetch operation {operation.request_id} failed to retrieve page {page_hashes}." - ) - break - completed_tokens = operation.completed_tokens - if operation.increment(self.page_size * len(page_hashes)): - for i in range(len(page_hashes)): - self.mem_pool_host.set_from_flat_data_page( - operation.host_indices[completed_tokens], - page_data[i], - ) - completed_tokens += self.page_size - else: - break - - def mooncake_page_transfer(self, operation): + # zero copy + def _mooncake_page_get(self, operation, hash_values, host_indices): key_strs, buffer_ptrs, buffer_sizes = self.mem_pool_host.get_buffer_meta( - operation.hash_value, operation.host_indices + hash_values, + host_indices, ) - self.storage_backend.batch_get(key_strs, buffer_ptrs, buffer_sizes) - operation.increment(len(operation.hash_value) * self.page_size) + get_result = self.storage_backend.batch_get( + key_strs, + target_location=buffer_ptrs, + target_sizes=buffer_sizes, + ) + if get_result != len(hash_values): + logger.warning( + f"Prefetch operation {operation.request_id} failed or partially failed." + ) + if get_result != 0: + operation.increment(get_result * self.page_size) + + # non-zero copy + def _generic_page_get(self, operation, hash_values, host_indices): + # todo: zero copy + dummy_page_dst = [self.mem_pool_host.get_dummy_flat_data_page()] * len( + hash_values + ) + page_data = self.storage_backend.batch_get(hash_values, dummy_page_dst) + if page_data is None: + return + for i in range(len(hash_values)): + if page_data[i] is None: + logger.warning( + f"Prefetch operation {operation.request_id} failed to retrieve page {hash_values[i]}." + ) + break + self.mem_pool_host.set_from_flat_data_page( + host_indices[operation.completed_tokens], + page_data[i], + ) + if not operation.increment(self.page_size): + break # Operation terminated by controller + + def _page_transfer(self, operation): + # Select the get function and batch size + if self.is_mooncake_backend(): + get_func = self._mooncake_page_get + batch_size = 128 + elif self.storage_backend_type == "hf3fs": + if self.mem_pool_host.layout == "page_first": + get_func = self._3fs_zero_copy_page_get + elif self.mem_pool_host.layout == "layer_first": + get_func = self._generic_page_get + batch_size = 128 + else: + get_func = self._generic_page_get + batch_size = 8 + + # Transfer batch by batch + for i in range(0, len(operation.hash_value), batch_size): + batch_hashes = operation.hash_value[i : i + batch_size] + batch_host_indices = operation.host_indices[ + i * self.page_size : (i + len(batch_hashes)) * self.page_size + ] + prev_completed_tokens = operation.completed_tokens + # Get one batch token, and update the completed_tokens if succeed + get_func(operation, batch_hashes, batch_host_indices) + # Check termination + if ( + operation.completed_tokens + != prev_completed_tokens + len(batch_hashes) * self.page_size + ): + break # Some operations fail or operation terminated by controller + # release pre-allocated memory + self.mem_pool_host.free(operation.host_indices[operation.completed_tokens :]) def is_mooncake_backend(self): return self.storage_backend_type == "mooncake" @@ -632,15 +666,7 @@ class HiCacheController: while not self.stop_event.is_set(): try: operation = self.prefetch_buffer.get(block=True, timeout=1) - if self.is_mooncake_backend(): - self.mooncake_page_transfer(operation) - elif self.storage_backend_type == "hf3fs": - if self.mem_pool_host.layout == "page_first": - self.zerocopy_page_transfer(operation, batch_size=128) - elif self.mem_pool_host.layout == "layer_first": - self.generic_page_transfer(operation, batch_size=128) - else: - self.generic_page_transfer(operation) + self._page_transfer(operation) if self.tp_world_size > 1: # to ensure all TP workers release the host memory at the same time @@ -662,6 +688,27 @@ class HiCacheController: # todo: more sophisticated rate limiting based on storage backend performance return True + def _generic_storage_hit_query(self, operation) -> tuple[list[str], int]: + last_hash = operation.last_hash + tokens_to_fetch = operation.token_ids + + storage_query_count = 0 + remaining_tokens = len(tokens_to_fetch) + hash_value = [] + while remaining_tokens >= self.page_size: + last_hash = self.get_hash_str( + tokens_to_fetch[ + storage_query_count : storage_query_count + self.page_size + ], + last_hash, + ) + hash_value.append(last_hash) + storage_query_count += self.page_size + remaining_tokens -= self.page_size + # deferring to batch exists + hit_page_num = self.storage_backend.batch_exists(hash_value) + return hash_value[:hit_page_num], hit_page_num * self.page_size + def prefetch_thread_func(self): """ Manage prefetching operations from storage backend to host memory. @@ -675,38 +722,12 @@ class HiCacheController: if operation is None: continue - storage_hit_count = 0 if ( operation.host_indices is not None ) and self.prefetch_rate_limit_check(): - last_hash = operation.last_hash - tokens_to_fetch = operation.token_ids - - remaining_tokens = len(tokens_to_fetch) - hash_value = [] - while remaining_tokens >= self.page_size: - last_hash = self.get_hash_str( - tokens_to_fetch[ - storage_hit_count : storage_hit_count + self.page_size - ], - last_hash, - ) - - # todo, more unified interface - if not self.is_mooncake_backend(): - if not self.storage_backend.exists(last_hash): - break - hash_value.append(last_hash) - storage_hit_count += self.page_size - remaining_tokens -= self.page_size - - if self.is_mooncake_backend(): - # deferring to batch exists for mooncake store - exist_result = self.storage_backend.exists(hash_value) - storage_hit_count = ( - sum(1 for v in exist_result.values() if v != 0) - * self.page_size - ) + hash_value, storage_hit_count = self._generic_storage_hit_query( + operation + ) if self.tp_world_size > 1: storage_hit_count_tensor = torch.tensor( @@ -755,59 +776,64 @@ class HiCacheController: self.backup_queue.put(operation) return operation.id - def zerocopy_page_backup(self, operation, batch_size=8): - hashes, dsts = self.mem_pool_host.get_buffer_with_hash( - operation.hash_value, operation.host_indices + # non-zero copy + def _generic_page_set(self, hash_values, host_indices) -> bool: + data = [ + self.mem_pool_host.get_flat_data_page(host_indices[i * self.page_size]) + for i in range(len(hash_values)) + ] + return self.storage_backend.batch_set(hash_values, data) + + # zero copy + def _mooncake_page_set(self, hash_values, host_indices) -> bool: + key_strs, buffer_ptrs, buffer_sizes = self.mem_pool_host.get_buffer_meta( + hash_values, + host_indices, ) - for i in range(0, len(hashes), batch_size): - page_hashes = hashes[i : i + batch_size] - page_data = dsts[i : i + batch_size] - success = self.storage_backend.batch_set(page_hashes, page_data) - if not success: - logger.warning(f"Failed to write page {page_hashes} to storage.") - break - operation.completed_tokens += self.page_size * len(page_hashes) + success = self.storage_backend.batch_set( + key_strs, + target_location=buffer_ptrs, + target_sizes=buffer_sizes, + ) + return success - def generic_page_backup(self, operation, batch_size=8): + # zero copy + def _3fs_zero_copy_page_set(self, hash_values, host_indices) -> bool: + hashes, dsts = self.mem_pool_host.get_buffer_with_hash( + hash_values, host_indices + ) + return self.storage_backend.batch_set(hashes, dsts) + + # Backup batch by batch + def _page_backup(self, operation): + # Select the set function and batch size + if self.is_mooncake_backend(): + backup_set_func = self._mooncake_page_set + batch_size = 128 + elif self.storage_backend_type == "hf3fs": + if self.mem_pool_host.layout == "page_first": + backup_set_func = self._3fs_zero_copy_page_set + elif self.mem_pool_host.layout == "layer_first": + backup_set_func = self._generic_page_set + batch_size = 128 + else: + backup_set_func = self._generic_page_set + batch_size = 8 + # Backup batch by batch for i in range(0, len(operation.hash_value), batch_size): - page_hashes = operation.hash_value[i : i + batch_size] - page_data = [ - self.mem_pool_host.get_flat_data_page( - operation.host_indices[j * self.page_size] - ) - for j in range(i, i + len(page_hashes)) + batch_hashes = operation.hash_value[i : i + batch_size] + batch_host_indices = operation.host_indices[ + i * self.page_size : (i + len(batch_hashes)) * self.page_size ] - success = self.storage_backend.batch_set(page_hashes, page_data) + # Set one batch token, and record if success. + # todo: allow partial success + success = backup_set_func(batch_hashes, batch_host_indices) if not success: - logger.warning(f"Failed to write page {page_hashes} to storage.") + logger.warning( + f"Write page to storage: {len(batch_hashes)} pages failed." + ) break - operation.completed_tokens += self.page_size * len(page_hashes) - - def mooncake_page_backup(self, operation): - if len(operation.hash_value): - exist_hashvalues = self.storage_backend.exists(operation.hash_value) - indices = operation.host_indices.tolist() - non_exist_keys = [] - non_exist_indices = [] - for i in range(len(operation.hash_value)): - if not exist_hashvalues[operation.hash_value[i]]: - non_exist_keys.append(operation.hash_value[i]) - non_exist_indices.extend( - indices[i * self.page_size : (i + 1) * self.page_size] - ) - if len(non_exist_keys) > 0: - key_strs, buffer_ptrs, buffer_sizes = ( - self.mem_pool_host.get_buffer_meta( - non_exist_keys, non_exist_indices - ) - ) - # TODO: check the return value of batch set to see how many tokens are set successfully - self.storage_backend.batch_set( - key_strs, - target_location=buffer_ptrs, - target_sizes=buffer_sizes, - ) - operation.completed_tokens += len(operation.hash_value) * self.page_size + operation.completed_tokens += self.page_size * len(batch_hashes) def backup_thread_func(self): """ @@ -820,15 +846,7 @@ class HiCacheController: continue if not self.backup_skip: - if self.is_mooncake_backend(): - self.mooncake_page_backup(operation) - elif self.storage_backend_type == "hf3fs": - if self.mem_pool_host.layout == "page_first": - self.zerocopy_page_backup(operation, batch_size=128) - elif self.mem_pool_host.layout == "layer_first": - self.generic_page_backup(operation, batch_size=128) - else: - self.generic_page_backup(operation) + self._page_backup(operation) min_completed_tokens = operation.completed_tokens else: min_completed_tokens = len(operation.token_ids) diff --git a/python/sglang/srt/mem_cache/hicache_storage.py b/python/sglang/srt/mem_cache/hicache_storage.py index a391b8acc..907d1b4b8 100644 --- a/python/sglang/srt/mem_cache/hicache_storage.py +++ b/python/sglang/srt/mem_cache/hicache_storage.py @@ -60,7 +60,7 @@ class HiCacheStorage(ABC): keys: List[str], target_locations: Optional[Any] = None, target_sizes: Optional[Any] = None, - ) -> List[torch.Tensor | None]: + ) -> List[torch.Tensor | None] | int: """ Retrieve values for multiple keys. Returns a list of tensors or None for each key. @@ -96,17 +96,28 @@ class HiCacheStorage(ABC): pass @abstractmethod - def exists(self, key: str) -> bool | dict: + def exists(self, key: str) -> bool: """ Check if the key exists in the storage. Returns True if the key exists, False otherwise. """ pass + def batch_exists(self, keys: List[str]) -> int: + """ + Check if the keys exist in the storage. + return the number of consecutive existing keys from the start. + Can be overridden by subclasses for more efficient implementation. + """ + for i in range(len(keys)): + if not self.exists(keys[i]): + return i + return len(keys) + class HiCacheFile(HiCacheStorage): - def __init__(self, file_path: str = "/tmp/hicache", is_mla: bool = False): + def __init__(self, file_path: str = "/tmp/hicache", is_mla_backend: bool = False): self.file_path = os.getenv("SGLANG_HICACHE_FILE_BACKEND_STORAGE_DIR", file_path) if is_dp_attention_enabled(): tp_rank = get_attention_tp_rank() @@ -115,7 +126,9 @@ class HiCacheFile(HiCacheStorage): tp_rank = get_tensor_model_parallel_rank() tp_size = get_tensor_model_parallel_world_size() - self.tp_suffix = f"_{tp_rank}_{tp_size}" if tp_size > 1 and not is_mla else "" + self.tp_suffix = ( + f"_{tp_rank}_{tp_size}" if tp_size > 1 and not is_mla_backend else "" + ) if not os.path.exists(self.file_path) and tp_rank == 0: os.makedirs(self.file_path) logger.info(f"Created HiCacheFile storage directory at {self.file_path}") diff --git a/python/sglang/srt/mem_cache/memory_pool_host.py b/python/sglang/srt/mem_cache/memory_pool_host.py index a2cc5bd37..13b707ba7 100644 --- a/python/sglang/srt/mem_cache/memory_pool_host.py +++ b/python/sglang/srt/mem_cache/memory_pool_host.py @@ -465,6 +465,7 @@ class MHATokenToKVPoolHost(HostKVCache): raise ValueError(f"Unsupported layout: {self.layout}") def get_buffer_meta(self, keys, indices): + local_rank = get_tensor_model_parallel_rank() ptr_list = [] key_list = [] kv_buffer_data_ptr = self.kv_buffer.data_ptr() @@ -488,8 +489,8 @@ class MHATokenToKVPoolHost(HostKVCache): ptr_list.append(k_ptr) ptr_list.append(v_ptr) key_ = keys[index // self.page_size] - key_list.append(f"{key_}_{get_tensor_model_parallel_rank()}_k") - key_list.append(f"{key_}_{get_tensor_model_parallel_rank()}_v") + key_list.append(f"{key_}_{local_rank}_k") + key_list.append(f"{key_}_{local_rank}_v") element_size = ( self.layer_num * self.dtype.itemsize @@ -704,6 +705,7 @@ class MLATokenToKVPoolHost(HostKVCache): raise ValueError(f"Unsupported layout: {self.layout}") def get_buffer_meta(self, keys, indices): + local_rank = get_tensor_model_parallel_rank() ptr_list = [] key_list = [] kv_buffer_data_ptr = self.kv_buffer.data_ptr() @@ -717,7 +719,7 @@ class MLATokenToKVPoolHost(HostKVCache): ) ptr_list.append(k_ptr) key_ = keys[index // self.page_size] - key_list.append(f"{key_}_k") + key_list.append(f"{key_}_{local_rank}_k") element_size = ( self.layer_num * self.dtype.itemsize diff --git a/python/sglang/srt/mem_cache/storage/mooncake_store/README.md b/python/sglang/srt/mem_cache/storage/mooncake_store/README.md index 6ad71821e..e42bffcfd 100644 --- a/python/sglang/srt/mem_cache/storage/mooncake_store/README.md +++ b/python/sglang/srt/mem_cache/storage/mooncake_store/README.md @@ -55,12 +55,11 @@ Launch Mooncake meta server: python -m mooncake.http_metadata_server ``` -Start the SGLang server with Mooncake enabled. Mooncake configuration can be provided via environment variables: +Start the SGLang server with Mooncake enabled. Mooncake configuration can be provided via environment variables. Note that, for optimal performance, the Mooncake backend currently supports only the `page_first` layout. ```bash MOONCAKE_TE_META_DATA_SERVER="http://127.0.0.1:8080/metadata" \ MOONCAKE_GLOBAL_SEGMENT_SIZE=4294967296 \ -MOONCAKE_LOCAL_BUFFER_SIZE=134217728 \ MOONCAKE_PROTOCOL="rdma" \ MOONCAKE_DEVICE="erdma_0,erdma_1" \ MOONCAKE_MASTER=127.0.0.1:50051 \ diff --git a/python/sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py b/python/sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py index 1cddd0092..704f6787e 100644 --- a/python/sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +++ b/python/sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py @@ -13,21 +13,11 @@ from sglang.srt.distributed import get_tensor_model_parallel_rank from sglang.srt.mem_cache.hicache_storage import HiCacheStorage DEFAULT_GLOBAL_SEGMENT_SIZE = 4 * 1024 * 1024 * 1024 # 4 GiB -DEFAULT_LOCAL_BUFFER_SIZE = 128 * 1024 * 1024 # 128 MB +DEFAULT_LOCAL_BUFFER_SIZE = 16 * 1024 * 1024 # 16 MB logger = logging.getLogger(__name__) -def get_hash_str_mooncake(token_ids: List[int], prior_hash: str = None): - prefix_str = "" - if prior_hash: - prefix_str = hashlib.sha256(prior_hash.encode()).hexdigest() - current_token_ids_bytes = np.array(token_ids).tobytes() - current_hash_object = hashlib.sha256(current_token_ids_bytes) - current_hash_hex = current_hash_object.hexdigest() - return f"{prefix_str}_{int(current_hash_hex[:16], 16)}" - - @dataclass class MooncakeStoreConfig: local_hostname: str @@ -54,9 +44,8 @@ class MooncakeStoreConfig: global_segment_size=config.get( "global_segment_size", DEFAULT_GLOBAL_SEGMENT_SIZE ), - local_buffer_size=config.get( - "local_buffer_size", DEFAULT_LOCAL_BUFFER_SIZE - ), + # Zero copy interface does not need local buffer + local_buffer_size=DEFAULT_LOCAL_BUFFER_SIZE, protocol=config.get("protocol", "tcp"), device_name=config.get("device_name", "auto"), master_server_address=config.get("master_server_address"), @@ -79,9 +68,8 @@ class MooncakeStoreConfig: global_segment_size=int( os.getenv("MOONCAKE_GLOBAL_SEGMENT_SIZE", DEFAULT_GLOBAL_SEGMENT_SIZE) ), - local_buffer_size=int( - os.getenv("MOONCAKE_LOCAL_BUFFER_SIZE", DEFAULT_LOCAL_BUFFER_SIZE) - ), + # Zero copy interface does not need local buffer + local_buffer_size=DEFAULT_LOCAL_BUFFER_SIZE, protocol=os.getenv("MOONCAKE_PROTOCOL", "tcp"), device_name=os.getenv("MOONCAKE_DEVICE", "auto"), master_server_address=os.getenv("MOONCAKE_MASTER"), @@ -96,7 +84,15 @@ class MooncakeStoreConfig: class MooncakeStore(HiCacheStorage): - def __init__(self, is_mla: bool = False): + def __init__(self, is_mla_backend: bool = False): + """ + Initialize MooncakeStore. + + Args: + is_mla_backend: If the backend is MLA + """ + self.is_mla_backend = is_mla_backend + try: from mooncake.store import MooncakeDistributedStore except ImportError as e: @@ -126,7 +122,6 @@ class MooncakeStore(HiCacheStorage): logger.info("Connect to Mooncake store successfully.") self.warmup() logger.info("Mooncake store warmup successfully.") - self.is_mla = is_mla except ValueError as e: logger.error("Configuration loading failed: %s", e) @@ -135,14 +130,14 @@ class MooncakeStore(HiCacheStorage): logger.error("An error occurred while loading the configuration: %s", exc) raise + self.local_rank = get_tensor_model_parallel_rank() + def warmup(self): warmup_key = "sglang_mooncake_store_warmup_key" + uuid.uuid4().hex - # 10 MB - warmup_value = bytes(10 * 1024 * 1024) - self.store.put(warmup_key, warmup_value) + warmup_value = bytes(4 * 1024) # 4 KB + assert self.store.put(warmup_key, warmup_value) == 0 assert self.store.is_exist(warmup_key) == 1 - self.store.get(warmup_key) - self.store.remove(warmup_key) + assert self.store.get(warmup_key) == warmup_value def register_buffer(self, buffer: torch.Tensor) -> None: try: @@ -162,78 +157,95 @@ class MooncakeStore(HiCacheStorage): target_location: Optional[List[int]] = None, target_sizes: Optional[List[int]] = None, ) -> bool: - assert len(key) == len(target_location) == len(target_sizes) - if len(key) == 0: - return - - for i in range(len(key)): - if key[i] is None or target_location[i] is None or target_sizes[i] is None: - return - - self._put_batch_zero_copy_impl(key, target_location, target_sizes) + return self.batch_set([key], [value], [target_location], [target_sizes]) def batch_set( self, keys: List[str], - value: Optional[Any] = None, target_location: Optional[List[int]] = None, target_sizes: Optional[List[int]] = None, ) -> bool: assert len(keys) == len(target_location) == len(target_sizes) if len(keys) == 0: - return + return False for i in range(len(keys)): if keys[i] is None or target_location[i] is None or target_sizes[i] is None: - return + return False - self._put_batch_zero_copy_impl(keys, target_location, target_sizes) + exist_result = self._batch_exist(keys) + set_keys = [] + set_target_locations = [] + set_target_sizes = [] + set_indices = [] + for i in range(len(keys)): + if exist_result[i] != 1: + set_keys.append(keys[i]) + set_target_locations.append(target_location[i]) + set_target_sizes.append(target_sizes[i]) + set_indices.append(i) + # Only set non-existing keys to storage + put_result = self._put_batch_zero_copy_impl( + set_keys, set_target_locations, set_target_sizes + ) + for i in range(len(set_indices)): + if put_result[i] == 0: + exist_result[set_indices[i]] = 1 + + success_count = 0 + for i in range(len(keys)): + if exist_result[i] == 0: + break + success_count += 1 + # TODO: return the number of consecutive successful operations from the start. + return success_count == len(keys) def get( self, key, target_location: Optional[Any] = None, target_sizes: Optional[Any] = None, - ) -> torch.Tensor | None: - assert len(key) == len(target_location) == len(target_sizes) - if len(key) == 0: - return - - for i in range(len(key)): - if key[i] is None or target_location[i] is None or target_sizes[i] is None: - return - - return self._get_batch_zero_copy_impl(key, target_location, target_sizes) + ) -> bool: + return self.batch_get([key], [target_location], [target_sizes]) == 1 def batch_get( self, keys: List[str], target_location: Optional[Any] = None, target_sizes: Optional[Any] = None, - ) -> torch.Tensor | None: + ) -> int: assert len(keys) == len(target_location) == len(target_sizes) if len(keys) == 0: - return - + return 0 + get_result = self._get_batch_zero_copy_impl(keys, target_location, target_sizes) + if self.is_mla_backend: + key_multiplier = 1 + else: + key_multiplier = 2 for i in range(len(keys)): - if keys[i] is None or target_location[i] is None or target_sizes[i] is None: - return + if get_result[i] < 0: + return i // key_multiplier + return len(keys) // key_multiplier - return self._get_batch_zero_copy_impl(keys, target_location, target_sizes) + def exists(self, key) -> bool: + return self.batch_exists([key]) > 0 - def exists(self, keys) -> bool | dict: - _keys = [] - local_rank = get_tensor_model_parallel_rank() - for key in keys: - if key is None: - return None + def batch_exists(self, keys) -> int: + if self.is_mla_backend: + query_keys = [f"{key}_k" for key in keys] + key_multiplier = 1 + else: + query_keys = [] + for key in keys: + query_keys.append(f"{key}_{self.local_rank}_k") + query_keys.append(f"{key}_{self.local_rank}_v") + key_multiplier = 2 - if self.is_mla: - _keys.append(f"{key}_k") - else: - _keys.append(f"{key}_{local_rank}_k") - result = {k: v for k, v in zip(keys, self.store.batch_is_exist(_keys))} - return result + exist_result = self._batch_exist(query_keys) + for i in range(len(query_keys)): + if exist_result[i] != 1: + return i // key_multiplier + return len(query_keys) // key_multiplier def delete(self, key) -> None: raise (NotImplementedError) @@ -248,18 +260,13 @@ class MooncakeStore(HiCacheStorage): def _put_batch_zero_copy_impl( self, key_strs: List[str], buffer_ptrs: List[int], buffer_sizes: List[int] - ) -> None: - try: - self.store.batch_put_from(key_strs, buffer_ptrs, buffer_sizes) - except TypeError as err: - logger.error("Failed to put value to Mooncake Store: %s", err) - raise TypeError("Mooncake Store Put Type Error.") from err + ) -> List[int]: + return self.store.batch_put_from(key_strs, buffer_ptrs, buffer_sizes) def _get_batch_zero_copy_impl( self, key_strs: List[str], buffer_ptrs: List[int], buffer_sizes: List[int] - ) -> None: - try: - self.store.batch_get_into(key_strs, buffer_ptrs, buffer_sizes) - except TypeError as err: - logger.error("Failed to get value from Mooncake Store: %s", err) - raise TypeError("Mooncake Store Get Type Error.") from err + ) -> List[int]: + return self.store.batch_get_into(key_strs, buffer_ptrs, buffer_sizes) + + def _batch_exist(self, key_strs: List[str]) -> List[int]: + return self.store.batch_is_exist(key_strs)