[Metrics] Add KV events publishing (#6098)

This commit is contained in:
Trevor Morris
2025-05-19 14:19:54 -07:00
committed by GitHub
parent 299fd22f9e
commit 7adf245ba2
7 changed files with 686 additions and 1 deletions

View File

@@ -48,3 +48,6 @@ class BasePrefixCache(ABC):
def pretty_print(self):
raise NotImplementedError()
def take_events(self):
return []

View File

@@ -27,6 +27,12 @@ from typing import TYPE_CHECKING, List, Optional, Tuple
import torch
from sglang.srt.disaggregation.kv_events import (
AllBlocksCleared,
BlockRemoved,
BlockStored,
KVCacheEvent,
)
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
@@ -96,11 +102,14 @@ class RadixCache(BasePrefixCache):
token_to_kv_pool_allocator: TokenToKVPoolAllocator,
page_size: int,
disable: bool = False,
enable_kv_cache_events: bool = False,
):
self.req_to_token_pool = req_to_token_pool
self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
self.page_size = page_size
self.disable = disable
self.enable_kv_cache_events = enable_kv_cache_events
self.kv_event_queue = []
if self.token_to_kv_pool_allocator:
self.device = self.token_to_kv_pool_allocator.device
@@ -124,6 +133,7 @@ class RadixCache(BasePrefixCache):
self.root_node.lock_ref = 1
self.evictable_size_ = 0
self.protected_size_ = 0
self._record_all_cleared_event()
def match_prefix(self, key: List[int], **kwargs) -> Tuple[torch.Tensor, int]:
"""Find the matching prefix from the radix tree.
@@ -273,6 +283,8 @@ class RadixCache(BasePrefixCache):
if len(x.parent.children) == 0:
heapq.heappush(leaves, x.parent)
self._record_remove_event(x)
def inc_lock_ref(self, node: TreeNode):
if self.disable:
return 0
@@ -348,6 +360,7 @@ class RadixCache(BasePrefixCache):
def _split_node(self, key, child: TreeNode, split_len: int):
# new_node -> child
self._record_remove_event(child)
new_node = TreeNode()
new_node.children = {self.get_child_key_fn(key[split_len:]): child}
new_node.parent = child.parent
@@ -358,6 +371,10 @@ class RadixCache(BasePrefixCache):
child.key = child.key[split_len:]
child.value = child.value[split_len:]
new_node.parent.children[self.get_child_key_fn(key)] = new_node
self._record_store_event(new_node)
self._record_store_event(child)
return new_node
def _insert_helper(self, node: TreeNode, key: List, value):
@@ -390,6 +407,7 @@ class RadixCache(BasePrefixCache):
new_node.value = value
node.children[child_key] = new_node
self.evictable_size_ += len(value)
self._record_store_event(new_node)
return total_prefix_length
def _print_helper(self, node: TreeNode, indent: int):
@@ -442,6 +460,41 @@ class RadixCache(BasePrefixCache):
return ret_list
def _record_store_event(self, node: TreeNode):
if self.enable_kv_cache_events:
block_hash = hash(tuple(node.key))
parent_block_hash = hash(tuple(node.parent.key))
self.kv_event_queue.append(
BlockStored(
block_hashes=[block_hash],
parent_block_hash=parent_block_hash,
token_ids=node.key,
block_size=len(node.key),
lora_id=None,
)
)
def _record_remove_event(self, node: TreeNode):
if self.enable_kv_cache_events:
block_hash = hash(tuple(node.key))
self.kv_event_queue.append(BlockRemoved(block_hashes=[block_hash]))
def _record_all_cleared_event(self):
if self.enable_kv_cache_events:
self.kv_event_queue.append(AllBlocksCleared())
def take_events(self):
"""Atomically takes all events and clears the queue.
Returns:
A list of KV cache events.
"""
if not self.enable_kv_cache_events:
return []
events = self.kv_event_queue
self.kv_event_queue = []
return events
if __name__ == "__main__":
tree = RadixCache(None, None, page_size=1, disable=False)