From 777688b8929c877e4e28c2eac208d776abe4c3af Mon Sep 17 00:00:00 2001 From: Faradawn Yang <73060648+faradawn@users.noreply.github.com> Date: Wed, 11 Jun 2025 15:07:58 -0500 Subject: [PATCH] [feat]: Emit fixed-size KV blocks events (#6824) --- python/sglang/srt/mem_cache/radix_cache.py | 48 ++++++++++++++++------ 1 file changed, 36 insertions(+), 12 deletions(-) diff --git a/python/sglang/srt/mem_cache/radix_cache.py b/python/sglang/srt/mem_cache/radix_cache.py index bdcb7640f..377784302 100644 --- a/python/sglang/srt/mem_cache/radix_cache.py +++ b/python/sglang/srt/mem_cache/radix_cache.py @@ -461,23 +461,47 @@ class RadixCache(BasePrefixCache): return ret_list def _record_store_event(self, node: TreeNode): + # One BlockStored per ``page_size`` chunk. if self.enable_kv_cache_events: - block_hash = hash(tuple(node.key)) - parent_block_hash = hash(tuple(node.parent.key)) - self.kv_event_queue.append( - BlockStored( - block_hashes=[block_hash], - parent_block_hash=parent_block_hash, - token_ids=node.key, - block_size=len(node.key), - lora_id=None, + # First chunk links to the last page of the parent node (if any). + if node.parent is None: + parent_block_hash = None + else: + last_page_start = ( + (len(node.parent.key) - 1) // self.page_size + ) * self.page_size + parent_parent_tokens = node.parent.key[last_page_start:] + parent_block_hash = hash(tuple(parent_parent_tokens)) + + for start in range(0, len(node.key), self.page_size): + page_tokens = node.key[start : start + self.page_size] + if not page_tokens: + continue + + block_hash = hash(tuple(page_tokens)) + + self.kv_event_queue.append( + BlockStored( + block_hashes=[block_hash], + parent_block_hash=parent_block_hash, + token_ids=page_tokens, + block_size=len(page_tokens), + lora_id=None, + ) ) - ) + + # Chain next chunk to this one. + parent_block_hash = block_hash def _record_remove_event(self, node: TreeNode): + # One BlockRemoved per chunk. if self.enable_kv_cache_events: - block_hash = hash(tuple(node.key)) - self.kv_event_queue.append(BlockRemoved(block_hashes=[block_hash])) + for start in range(0, len(node.key), self.page_size): + page_tokens = node.key[start : start + self.page_size] + if not page_tokens: + continue + block_hash = hash(tuple(page_tokens)) + self.kv_event_queue.append(BlockRemoved(block_hashes=[block_hash])) def _record_all_cleared_event(self): if self.enable_kv_cache_events: