diff --git a/python/sglang/srt/managers/cache_controller.py b/python/sglang/srt/managers/cache_controller.py index 57b0a47c4..1edc88751 100644 --- a/python/sglang/srt/managers/cache_controller.py +++ b/python/sglang/srt/managers/cache_controller.py @@ -169,12 +169,13 @@ class StorageOperation: host_indices: torch.Tensor, token_ids: List[int], last_hash: Optional[str] = None, + hash_value: Optional[List[str]] = None, ): self.host_indices = host_indices self.token_ids = token_ids self.last_hash = last_hash self.completed_tokens = 0 - self.hash_value = [] + self.hash_value = hash_value if hash_value is not None else [] self.id = StorageOperation.counter StorageOperation.counter += 1 @@ -702,12 +703,12 @@ class HiCacheController: self, host_indices: torch.Tensor, token_ids: List[int], - last_hash: Optional[str] = None, + hash_value: Optional[List[str]] = None, ) -> int: """ Write KV caches from host memory to storage backend. """ - operation = StorageOperation(host_indices, token_ids, last_hash) + operation = StorageOperation(host_indices, token_ids, hash_value=hash_value) self.backup_queue.put(operation) return operation.id @@ -762,24 +763,6 @@ class HiCacheController: if operation is None: continue - last_hash = operation.last_hash - tokens_to_backup = operation.token_ids - - backup_hit_count = 0 - remaining_tokens = len(tokens_to_backup) - hash_value = [] - while remaining_tokens >= self.page_size: - last_hash = self.get_hash_str( - tokens_to_backup[ - backup_hit_count : backup_hit_count + self.page_size - ], - last_hash, - ) - backup_hit_count += self.page_size - hash_value.append(last_hash) - remaining_tokens -= self.page_size - operation.hash_value = hash_value - if self.is_mooncake_backend(): self.mooncake_page_backup(operation) elif self.storage_backend_type == "hf3fs": @@ -802,7 +785,6 @@ class HiCacheController: self.ack_backup_queue.put( ( operation.id, - operation.hash_value[: min_completed_tokens // self.page_size], min_completed_tokens, ) ) diff --git 
a/python/sglang/srt/mem_cache/hicache_storage.py b/python/sglang/srt/mem_cache/hicache_storage.py index 8ebdecfda..90a468cc3 100644 --- a/python/sglang/srt/mem_cache/hicache_storage.py +++ b/python/sglang/srt/mem_cache/hicache_storage.py @@ -15,7 +15,7 @@ from sglang.srt.distributed import ( ) -def get_hash_str(token_ids: List[int], prior_hash: Optional[str] = None) -> str: +def get_hash_str(token_ids: List[int], prior_hash: Optional[str] = None) -> str: hasher = hashlib.sha256() if prior_hash: diff --git a/python/sglang/srt/mem_cache/hiradix_cache.py b/python/sglang/srt/mem_cache/hiradix_cache.py index e11b9e64d..342ca7dd2 100644 --- a/python/sglang/srt/mem_cache/hiradix_cache.py +++ b/python/sglang/srt/mem_cache/hiradix_cache.py @@ -151,7 +151,7 @@ class HiRadixCache(RadixCache): def write_backup_storage(self, node: TreeNode): operation_id = self.cache_controller.write_storage( - node.host_value, node.key, node.parent.get_last_hash_value() + node.host_value, node.key, node.hash_value ) self.ongoing_backup[operation_id] = node node.protect_host() @@ -414,18 +414,18 @@ class HiRadixCache(RadixCache): group=self.tp_group, ) for _ in range(queue_size.item()): - ack_id, hash_value, completed_tokens = ( - self.cache_controller.ack_backup_queue.get() - ) + ack_id, completed_tokens = self.cache_controller.ack_backup_queue.get() host_node = self.ongoing_backup[ack_id] - if completed_tokens == 0: - host_node.hash_value = None - elif completed_tokens < len(host_node.key): - # backup is only partially successful, split the node - new_node = self._split_node(host_node.key, host_node, completed_tokens) - new_node.hash_value = hash_value - else: - host_node.hash_value = hash_value + + if completed_tokens > 0: + if completed_tokens < len(host_node.key): + # backup is only partially successful, split the node + new_node = self._split_node( + host_node.key, host_node, completed_tokens + ) + new_node.backuped_storage = True + else: + host_node.backuped_storage = True host_node.release_host() 
del self.ongoing_backup[ack_id] @@ -717,6 +717,21 @@ class HiRadixCache(RadixCache): node.children[child_key] = new_node self.evictable_size_ += len(value) + if self.enable_storage: + last_hash = node.get_last_hash_value() + assert (node == self.root_node) or ( + last_hash is not None + ), "Parent node must have a hash value with storage enabled" + new_node.hash_value = [] + for idx in range(0, len(key), self.page_size): + new_node.hash_value.append( + self.cache_controller.get_hash_str( + key[idx : idx + self.page_size], + prior_hash=last_hash, + ) + ) + last_hash = new_node.hash_value[-1] + if self.cache_controller.write_policy != "write_back": self.inc_hit_count(new_node) return total_prefix_length diff --git a/python/sglang/srt/mem_cache/radix_cache.py b/python/sglang/srt/mem_cache/radix_cache.py index 0826990c2..21fe3b2b9 100644 --- a/python/sglang/srt/mem_cache/radix_cache.py +++ b/python/sglang/srt/mem_cache/radix_cache.py @@ -62,6 +62,7 @@ class TreeNode: self.host_value: Optional[torch.Tensor] = None # store hash values of each pages self.hash_value: Optional[List[str]] = None + self.backuped_storage = False self.id = TreeNode.counter if id is None else id TreeNode.counter += 1 @@ -74,10 +75,6 @@ class TreeNode: def backuped(self): return self.host_value is not None - @property - def backuped_storage(self): - return self.hash_value is not None and len(self.hash_value) > 0 - def protect_host(self): """Protect the host value from eviction.""" self.host_ref_counter += 1 diff --git a/python/sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py b/python/sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py index 05dc7a3ce..38700d55e 100644 --- a/python/sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +++ b/python/sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py @@ -18,13 +18,12 @@ DEFAULT_LOCAL_BUFFER_SIZE = 128 * 1024 * 1024 # 128 MB logger = logging.getLogger(__name__) -def get_hash_str_mooncake(current_page_ids: 
List, prefix_block_key: str): +def get_hash_str_mooncake(token_ids: List[int], prior_hash: Optional[str] = None): local_rank = get_tensor_model_parallel_rank() prefix_str = "" - if prefix_block_key: - if len(prefix_block_key): - prefix_str = hashlib.sha256(prefix_block_key.encode()).hexdigest() - current_token_ids_bytes = np.array(current_page_ids).tobytes() + if prior_hash: + prefix_str = hashlib.sha256(prior_hash.encode()).hexdigest() + current_token_ids_bytes = np.array(token_ids).tobytes() current_hash_object = hashlib.sha256(current_token_ids_bytes) current_hash_hex = current_hash_object.hexdigest() return f"{prefix_str}_{int(current_hash_hex[:16], 16)}_{local_rank}"