HiCache Storage: generate hash when inserting new nodes (#9053)
This commit is contained in:
@@ -169,12 +169,13 @@ class StorageOperation:
|
|||||||
host_indices: torch.Tensor,
|
host_indices: torch.Tensor,
|
||||||
token_ids: List[int],
|
token_ids: List[int],
|
||||||
last_hash: Optional[str] = None,
|
last_hash: Optional[str] = None,
|
||||||
|
hash_value: Optional[List[str]] = None,
|
||||||
):
|
):
|
||||||
self.host_indices = host_indices
|
self.host_indices = host_indices
|
||||||
self.token_ids = token_ids
|
self.token_ids = token_ids
|
||||||
self.last_hash = last_hash
|
self.last_hash = last_hash
|
||||||
self.completed_tokens = 0
|
self.completed_tokens = 0
|
||||||
self.hash_value = []
|
self.hash_value = hash_value if hash_value is not None else []
|
||||||
|
|
||||||
self.id = StorageOperation.counter
|
self.id = StorageOperation.counter
|
||||||
StorageOperation.counter += 1
|
StorageOperation.counter += 1
|
||||||
@@ -702,12 +703,12 @@ class HiCacheController:
|
|||||||
self,
|
self,
|
||||||
host_indices: torch.Tensor,
|
host_indices: torch.Tensor,
|
||||||
token_ids: List[int],
|
token_ids: List[int],
|
||||||
last_hash: Optional[str] = None,
|
hash_value: Optional[List[str]] = None,
|
||||||
) -> int:
|
) -> int:
|
||||||
"""
|
"""
|
||||||
Write KV caches from host memory to storage backend.
|
Write KV caches from host memory to storage backend.
|
||||||
"""
|
"""
|
||||||
operation = StorageOperation(host_indices, token_ids, last_hash)
|
operation = StorageOperation(host_indices, token_ids, hash_value=hash_value)
|
||||||
self.backup_queue.put(operation)
|
self.backup_queue.put(operation)
|
||||||
return operation.id
|
return operation.id
|
||||||
|
|
||||||
@@ -762,24 +763,6 @@ class HiCacheController:
|
|||||||
if operation is None:
|
if operation is None:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
last_hash = operation.last_hash
|
|
||||||
tokens_to_backup = operation.token_ids
|
|
||||||
|
|
||||||
backup_hit_count = 0
|
|
||||||
remaining_tokens = len(tokens_to_backup)
|
|
||||||
hash_value = []
|
|
||||||
while remaining_tokens >= self.page_size:
|
|
||||||
last_hash = self.get_hash_str(
|
|
||||||
tokens_to_backup[
|
|
||||||
backup_hit_count : backup_hit_count + self.page_size
|
|
||||||
],
|
|
||||||
last_hash,
|
|
||||||
)
|
|
||||||
backup_hit_count += self.page_size
|
|
||||||
hash_value.append(last_hash)
|
|
||||||
remaining_tokens -= self.page_size
|
|
||||||
operation.hash_value = hash_value
|
|
||||||
|
|
||||||
if self.is_mooncake_backend():
|
if self.is_mooncake_backend():
|
||||||
self.mooncake_page_backup(operation)
|
self.mooncake_page_backup(operation)
|
||||||
elif self.storage_backend_type == "hf3fs":
|
elif self.storage_backend_type == "hf3fs":
|
||||||
@@ -802,7 +785,6 @@ class HiCacheController:
|
|||||||
self.ack_backup_queue.put(
|
self.ack_backup_queue.put(
|
||||||
(
|
(
|
||||||
operation.id,
|
operation.id,
|
||||||
operation.hash_value[: min_completed_tokens // self.page_size],
|
|
||||||
min_completed_tokens,
|
min_completed_tokens,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ from sglang.srt.distributed import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def get_hash_str(token_ids: List[int], prior_hash: Optional[str] = None) -> str:
|
def get_hash_str(token_ids: List[int], prior_hash: str = None) -> str:
|
||||||
hasher = hashlib.sha256()
|
hasher = hashlib.sha256()
|
||||||
|
|
||||||
if prior_hash:
|
if prior_hash:
|
||||||
|
|||||||
@@ -151,7 +151,7 @@ class HiRadixCache(RadixCache):
|
|||||||
|
|
||||||
def write_backup_storage(self, node: TreeNode):
|
def write_backup_storage(self, node: TreeNode):
|
||||||
operation_id = self.cache_controller.write_storage(
|
operation_id = self.cache_controller.write_storage(
|
||||||
node.host_value, node.key, node.parent.get_last_hash_value()
|
node.host_value, node.key, node.hash_value
|
||||||
)
|
)
|
||||||
self.ongoing_backup[operation_id] = node
|
self.ongoing_backup[operation_id] = node
|
||||||
node.protect_host()
|
node.protect_host()
|
||||||
@@ -414,18 +414,18 @@ class HiRadixCache(RadixCache):
|
|||||||
group=self.tp_group,
|
group=self.tp_group,
|
||||||
)
|
)
|
||||||
for _ in range(queue_size.item()):
|
for _ in range(queue_size.item()):
|
||||||
ack_id, hash_value, completed_tokens = (
|
ack_id, completed_tokens = self.cache_controller.ack_backup_queue.get()
|
||||||
self.cache_controller.ack_backup_queue.get()
|
|
||||||
)
|
|
||||||
host_node = self.ongoing_backup[ack_id]
|
host_node = self.ongoing_backup[ack_id]
|
||||||
if completed_tokens == 0:
|
|
||||||
host_node.hash_value = None
|
if completed_tokens > 0:
|
||||||
elif completed_tokens < len(host_node.key):
|
if completed_tokens < len(host_node.key):
|
||||||
# backup is only partially successful, split the node
|
# backup is only partially successful, split the node
|
||||||
new_node = self._split_node(host_node.key, host_node, completed_tokens)
|
new_node = self._split_node(
|
||||||
new_node.hash_value = hash_value
|
host_node.key, host_node, completed_tokens
|
||||||
else:
|
)
|
||||||
host_node.hash_value = hash_value
|
new_node.backuped_storage = True
|
||||||
|
else:
|
||||||
|
host_node.backuped_storage = True
|
||||||
host_node.release_host()
|
host_node.release_host()
|
||||||
del self.ongoing_backup[ack_id]
|
del self.ongoing_backup[ack_id]
|
||||||
|
|
||||||
@@ -717,6 +717,21 @@ class HiRadixCache(RadixCache):
|
|||||||
node.children[child_key] = new_node
|
node.children[child_key] = new_node
|
||||||
self.evictable_size_ += len(value)
|
self.evictable_size_ += len(value)
|
||||||
|
|
||||||
|
if self.enable_storage:
|
||||||
|
last_hash = node.get_last_hash_value()
|
||||||
|
assert (node == self.root_node) or (
|
||||||
|
last_hash is not None
|
||||||
|
), "Parent node must have a hash value with storage enabled"
|
||||||
|
new_node.hash_value = []
|
||||||
|
for idx in range(0, len(key), self.page_size):
|
||||||
|
new_node.hash_value.append(
|
||||||
|
self.cache_controller.get_hash_str(
|
||||||
|
key[idx : idx + self.page_size],
|
||||||
|
prior_hash=last_hash,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
last_hash = new_node.hash_value[-1]
|
||||||
|
|
||||||
if self.cache_controller.write_policy != "write_back":
|
if self.cache_controller.write_policy != "write_back":
|
||||||
self.inc_hit_count(new_node)
|
self.inc_hit_count(new_node)
|
||||||
return total_prefix_length
|
return total_prefix_length
|
||||||
|
|||||||
@@ -62,6 +62,7 @@ class TreeNode:
|
|||||||
self.host_value: Optional[torch.Tensor] = None
|
self.host_value: Optional[torch.Tensor] = None
|
||||||
# store hash values of each pages
|
# store hash values of each pages
|
||||||
self.hash_value: Optional[List[str]] = None
|
self.hash_value: Optional[List[str]] = None
|
||||||
|
self.backuped_storage = False
|
||||||
|
|
||||||
self.id = TreeNode.counter if id is None else id
|
self.id = TreeNode.counter if id is None else id
|
||||||
TreeNode.counter += 1
|
TreeNode.counter += 1
|
||||||
@@ -74,10 +75,6 @@ class TreeNode:
|
|||||||
def backuped(self):
|
def backuped(self):
|
||||||
return self.host_value is not None
|
return self.host_value is not None
|
||||||
|
|
||||||
@property
|
|
||||||
def backuped_storage(self):
|
|
||||||
return self.hash_value is not None and len(self.hash_value) > 0
|
|
||||||
|
|
||||||
def protect_host(self):
|
def protect_host(self):
|
||||||
"""Protect the host value from eviction."""
|
"""Protect the host value from eviction."""
|
||||||
self.host_ref_counter += 1
|
self.host_ref_counter += 1
|
||||||
|
|||||||
@@ -18,13 +18,12 @@ DEFAULT_LOCAL_BUFFER_SIZE = 128 * 1024 * 1024 # 128 MB
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def get_hash_str_mooncake(current_page_ids: List, prefix_block_key: str):
|
def get_hash_str_mooncake(token_ids: List[int], prior_hash: str = None):
|
||||||
local_rank = get_tensor_model_parallel_rank()
|
local_rank = get_tensor_model_parallel_rank()
|
||||||
prefix_str = ""
|
prefix_str = ""
|
||||||
if prefix_block_key:
|
if prior_hash:
|
||||||
if len(prefix_block_key):
|
prefix_str = hashlib.sha256(prior_hash.encode()).hexdigest()
|
||||||
prefix_str = hashlib.sha256(prefix_block_key.encode()).hexdigest()
|
current_token_ids_bytes = np.array(token_ids).tobytes()
|
||||||
current_token_ids_bytes = np.array(current_page_ids).tobytes()
|
|
||||||
current_hash_object = hashlib.sha256(current_token_ids_bytes)
|
current_hash_object = hashlib.sha256(current_token_ids_bytes)
|
||||||
current_hash_hex = current_hash_object.hexdigest()
|
current_hash_hex = current_hash_object.hexdigest()
|
||||||
return f"{prefix_str}_{int(current_hash_hex[:16], 16)}_{local_rank}"
|
return f"{prefix_str}_{int(current_hash_hex[:16], 16)}_{local_rank}"
|
||||||
|
|||||||
Reference in New Issue
Block a user