[feat]: Emit fixed-size KV blocks events (#6824)
This commit is contained in:
@@ -461,23 +461,47 @@ class RadixCache(BasePrefixCache):
|
|||||||
return ret_list
|
return ret_list
|
||||||
|
|
||||||
def _record_store_event(self, node: TreeNode):
|
def _record_store_event(self, node: TreeNode):
|
||||||
|
# One BlockStored per ``page_size`` chunk.
|
||||||
if self.enable_kv_cache_events:
|
if self.enable_kv_cache_events:
|
||||||
block_hash = hash(tuple(node.key))
|
# First chunk links to the last page of the parent node (if any).
|
||||||
parent_block_hash = hash(tuple(node.parent.key))
|
if node.parent is None:
|
||||||
self.kv_event_queue.append(
|
parent_block_hash = None
|
||||||
BlockStored(
|
else:
|
||||||
block_hashes=[block_hash],
|
last_page_start = (
|
||||||
parent_block_hash=parent_block_hash,
|
(len(node.parent.key) - 1) // self.page_size
|
||||||
token_ids=node.key,
|
) * self.page_size
|
||||||
block_size=len(node.key),
|
parent_parent_tokens = node.parent.key[last_page_start:]
|
||||||
lora_id=None,
|
parent_block_hash = hash(tuple(parent_parent_tokens))
|
||||||
|
|
||||||
|
for start in range(0, len(node.key), self.page_size):
|
||||||
|
page_tokens = node.key[start : start + self.page_size]
|
||||||
|
if not page_tokens:
|
||||||
|
continue
|
||||||
|
|
||||||
|
block_hash = hash(tuple(page_tokens))
|
||||||
|
|
||||||
|
self.kv_event_queue.append(
|
||||||
|
BlockStored(
|
||||||
|
block_hashes=[block_hash],
|
||||||
|
parent_block_hash=parent_block_hash,
|
||||||
|
token_ids=page_tokens,
|
||||||
|
block_size=len(page_tokens),
|
||||||
|
lora_id=None,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
)
|
|
||||||
|
# Chain next chunk to this one.
|
||||||
|
parent_block_hash = block_hash
|
||||||
|
|
||||||
def _record_remove_event(self, node: TreeNode):
|
def _record_remove_event(self, node: TreeNode):
|
||||||
|
# One BlockRemoved per chunk.
|
||||||
if self.enable_kv_cache_events:
|
if self.enable_kv_cache_events:
|
||||||
block_hash = hash(tuple(node.key))
|
for start in range(0, len(node.key), self.page_size):
|
||||||
self.kv_event_queue.append(BlockRemoved(block_hashes=[block_hash]))
|
page_tokens = node.key[start : start + self.page_size]
|
||||||
|
if not page_tokens:
|
||||||
|
continue
|
||||||
|
block_hash = hash(tuple(page_tokens))
|
||||||
|
self.kv_event_queue.append(BlockRemoved(block_hashes=[block_hash]))
|
||||||
|
|
||||||
def _record_all_cleared_event(self):
|
def _record_all_cleared_event(self):
|
||||||
if self.enable_kv_cache_events:
|
if self.enable_kv_cache_events:
|
||||||
|
|||||||
Reference in New Issue
Block a user