HiCache Storage: generate hash when inserting new nodes (#9053)
This commit is contained in:
@@ -169,12 +169,13 @@ class StorageOperation:
|
||||
host_indices: torch.Tensor,
|
||||
token_ids: List[int],
|
||||
last_hash: Optional[str] = None,
|
||||
hash_value: Optional[List[str]] = None,
|
||||
):
|
||||
self.host_indices = host_indices
|
||||
self.token_ids = token_ids
|
||||
self.last_hash = last_hash
|
||||
self.completed_tokens = 0
|
||||
self.hash_value = []
|
||||
self.hash_value = hash_value if hash_value is not None else []
|
||||
|
||||
self.id = StorageOperation.counter
|
||||
StorageOperation.counter += 1
|
||||
@@ -702,12 +703,12 @@ class HiCacheController:
|
||||
self,
|
||||
host_indices: torch.Tensor,
|
||||
token_ids: List[int],
|
||||
last_hash: Optional[str] = None,
|
||||
hash_value: Optional[List[str]] = None,
|
||||
) -> int:
|
||||
"""
|
||||
Write KV caches from host memory to storage backend.
|
||||
"""
|
||||
operation = StorageOperation(host_indices, token_ids, last_hash)
|
||||
operation = StorageOperation(host_indices, token_ids, hash_value=hash_value)
|
||||
self.backup_queue.put(operation)
|
||||
return operation.id
|
||||
|
||||
@@ -762,24 +763,6 @@ class HiCacheController:
|
||||
if operation is None:
|
||||
continue
|
||||
|
||||
last_hash = operation.last_hash
|
||||
tokens_to_backup = operation.token_ids
|
||||
|
||||
backup_hit_count = 0
|
||||
remaining_tokens = len(tokens_to_backup)
|
||||
hash_value = []
|
||||
while remaining_tokens >= self.page_size:
|
||||
last_hash = self.get_hash_str(
|
||||
tokens_to_backup[
|
||||
backup_hit_count : backup_hit_count + self.page_size
|
||||
],
|
||||
last_hash,
|
||||
)
|
||||
backup_hit_count += self.page_size
|
||||
hash_value.append(last_hash)
|
||||
remaining_tokens -= self.page_size
|
||||
operation.hash_value = hash_value
|
||||
|
||||
if self.is_mooncake_backend():
|
||||
self.mooncake_page_backup(operation)
|
||||
elif self.storage_backend_type == "hf3fs":
|
||||
@@ -802,7 +785,6 @@ class HiCacheController:
|
||||
self.ack_backup_queue.put(
|
||||
(
|
||||
operation.id,
|
||||
operation.hash_value[: min_completed_tokens // self.page_size],
|
||||
min_completed_tokens,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -15,7 +15,7 @@ from sglang.srt.distributed import (
|
||||
)
|
||||
|
||||
|
||||
def get_hash_str(token_ids: List[int], prior_hash: Optional[str] = None) -> str:
|
||||
def get_hash_str(token_ids: List[int], prior_hash: str = None) -> str:
|
||||
hasher = hashlib.sha256()
|
||||
|
||||
if prior_hash:
|
||||
|
||||
@@ -151,7 +151,7 @@ class HiRadixCache(RadixCache):
|
||||
|
||||
def write_backup_storage(self, node: TreeNode):
|
||||
operation_id = self.cache_controller.write_storage(
|
||||
node.host_value, node.key, node.parent.get_last_hash_value()
|
||||
node.host_value, node.key, node.hash_value
|
||||
)
|
||||
self.ongoing_backup[operation_id] = node
|
||||
node.protect_host()
|
||||
@@ -414,18 +414,18 @@ class HiRadixCache(RadixCache):
|
||||
group=self.tp_group,
|
||||
)
|
||||
for _ in range(queue_size.item()):
|
||||
ack_id, hash_value, completed_tokens = (
|
||||
self.cache_controller.ack_backup_queue.get()
|
||||
)
|
||||
ack_id, completed_tokens = self.cache_controller.ack_backup_queue.get()
|
||||
host_node = self.ongoing_backup[ack_id]
|
||||
if completed_tokens == 0:
|
||||
host_node.hash_value = None
|
||||
elif completed_tokens < len(host_node.key):
|
||||
# backup is only partially successful, split the node
|
||||
new_node = self._split_node(host_node.key, host_node, completed_tokens)
|
||||
new_node.hash_value = hash_value
|
||||
else:
|
||||
host_node.hash_value = hash_value
|
||||
|
||||
if completed_tokens > 0:
|
||||
if completed_tokens < len(host_node.key):
|
||||
# backup is only partially successful, split the node
|
||||
new_node = self._split_node(
|
||||
host_node.key, host_node, completed_tokens
|
||||
)
|
||||
new_node.backuped_storage = True
|
||||
else:
|
||||
host_node.backuped_storage = True
|
||||
host_node.release_host()
|
||||
del self.ongoing_backup[ack_id]
|
||||
|
||||
@@ -717,6 +717,21 @@ class HiRadixCache(RadixCache):
|
||||
node.children[child_key] = new_node
|
||||
self.evictable_size_ += len(value)
|
||||
|
||||
if self.enable_storage:
|
||||
last_hash = node.get_last_hash_value()
|
||||
assert (node == self.root_node) or (
|
||||
last_hash is not None
|
||||
), "Parent node must have a hash value with storage enabled"
|
||||
new_node.hash_value = []
|
||||
for idx in range(0, len(key), self.page_size):
|
||||
new_node.hash_value.append(
|
||||
self.cache_controller.get_hash_str(
|
||||
key[idx : idx + self.page_size],
|
||||
prior_hash=last_hash,
|
||||
)
|
||||
)
|
||||
last_hash = new_node.hash_value[-1]
|
||||
|
||||
if self.cache_controller.write_policy != "write_back":
|
||||
self.inc_hit_count(new_node)
|
||||
return total_prefix_length
|
||||
|
||||
@@ -62,6 +62,7 @@ class TreeNode:
|
||||
self.host_value: Optional[torch.Tensor] = None
|
||||
# store hash values of each pages
|
||||
self.hash_value: Optional[List[str]] = None
|
||||
self.backuped_storage = False
|
||||
|
||||
self.id = TreeNode.counter if id is None else id
|
||||
TreeNode.counter += 1
|
||||
@@ -74,10 +75,6 @@ class TreeNode:
|
||||
def backuped(self):
|
||||
return self.host_value is not None
|
||||
|
||||
@property
|
||||
def backuped_storage(self):
|
||||
return self.hash_value is not None and len(self.hash_value) > 0
|
||||
|
||||
def protect_host(self):
|
||||
"""Protect the host value from eviction."""
|
||||
self.host_ref_counter += 1
|
||||
|
||||
@@ -18,13 +18,12 @@ DEFAULT_LOCAL_BUFFER_SIZE = 128 * 1024 * 1024 # 128 MB
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_hash_str_mooncake(current_page_ids: List, prefix_block_key: str):
|
||||
def get_hash_str_mooncake(token_ids: List[int], prior_hash: str = None):
|
||||
local_rank = get_tensor_model_parallel_rank()
|
||||
prefix_str = ""
|
||||
if prefix_block_key:
|
||||
if len(prefix_block_key):
|
||||
prefix_str = hashlib.sha256(prefix_block_key.encode()).hexdigest()
|
||||
current_token_ids_bytes = np.array(current_page_ids).tobytes()
|
||||
if prior_hash:
|
||||
prefix_str = hashlib.sha256(prior_hash.encode()).hexdigest()
|
||||
current_token_ids_bytes = np.array(token_ids).tobytes()
|
||||
current_hash_object = hashlib.sha256(current_token_ids_bytes)
|
||||
current_hash_hex = current_hash_object.hexdigest()
|
||||
return f"{prefix_str}_{int(current_hash_hex[:16], 16)}_{local_rank}"
|
||||
|
||||
Reference in New Issue
Block a user