[Metrics] Add KV events publishing (#6098)
This commit is contained in:
@@ -48,3 +48,6 @@ class BasePrefixCache(ABC):
|
||||
|
||||
def pretty_print(self):
    """Not implemented in this base class; always raises ``NotImplementedError``."""
    raise NotImplementedError()
|
||||
|
||||
def take_events(self):
    """Return the pending events; this default implementation has none.

    Returns:
        A new empty list on every call.
    """
    no_events = []
    return no_events
|
||||
|
||||
@@ -27,6 +27,12 @@ from typing import TYPE_CHECKING, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.srt.disaggregation.kv_events import (
|
||||
AllBlocksCleared,
|
||||
BlockRemoved,
|
||||
BlockStored,
|
||||
KVCacheEvent,
|
||||
)
|
||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
|
||||
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
|
||||
@@ -96,11 +102,14 @@ class RadixCache(BasePrefixCache):
|
||||
token_to_kv_pool_allocator: TokenToKVPoolAllocator,
|
||||
page_size: int,
|
||||
disable: bool = False,
|
||||
enable_kv_cache_events: bool = False,
|
||||
):
|
||||
self.req_to_token_pool = req_to_token_pool
|
||||
self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
|
||||
self.page_size = page_size
|
||||
self.disable = disable
|
||||
self.enable_kv_cache_events = enable_kv_cache_events
|
||||
self.kv_event_queue = []
|
||||
|
||||
if self.token_to_kv_pool_allocator:
|
||||
self.device = self.token_to_kv_pool_allocator.device
|
||||
@@ -124,6 +133,7 @@ class RadixCache(BasePrefixCache):
|
||||
self.root_node.lock_ref = 1
|
||||
self.evictable_size_ = 0
|
||||
self.protected_size_ = 0
|
||||
self._record_all_cleared_event()
|
||||
|
||||
def match_prefix(self, key: List[int], **kwargs) -> Tuple[torch.Tensor, int]:
|
||||
"""Find the matching prefix from the radix tree.
|
||||
@@ -273,6 +283,8 @@ class RadixCache(BasePrefixCache):
|
||||
if len(x.parent.children) == 0:
|
||||
heapq.heappush(leaves, x.parent)
|
||||
|
||||
self._record_remove_event(x)
|
||||
|
||||
def inc_lock_ref(self, node: TreeNode):
|
||||
if self.disable:
|
||||
return 0
|
||||
@@ -348,6 +360,7 @@ class RadixCache(BasePrefixCache):
|
||||
|
||||
def _split_node(self, key, child: TreeNode, split_len: int):
|
||||
# new_node -> child
|
||||
self._record_remove_event(child)
|
||||
new_node = TreeNode()
|
||||
new_node.children = {self.get_child_key_fn(key[split_len:]): child}
|
||||
new_node.parent = child.parent
|
||||
@@ -358,6 +371,10 @@ class RadixCache(BasePrefixCache):
|
||||
child.key = child.key[split_len:]
|
||||
child.value = child.value[split_len:]
|
||||
new_node.parent.children[self.get_child_key_fn(key)] = new_node
|
||||
|
||||
self._record_store_event(new_node)
|
||||
self._record_store_event(child)
|
||||
|
||||
return new_node
|
||||
|
||||
def _insert_helper(self, node: TreeNode, key: List, value):
|
||||
@@ -390,6 +407,7 @@ class RadixCache(BasePrefixCache):
|
||||
new_node.value = value
|
||||
node.children[child_key] = new_node
|
||||
self.evictable_size_ += len(value)
|
||||
self._record_store_event(new_node)
|
||||
return total_prefix_length
|
||||
|
||||
def _print_helper(self, node: TreeNode, indent: int):
|
||||
@@ -442,6 +460,41 @@ class RadixCache(BasePrefixCache):
|
||||
|
||||
return ret_list
|
||||
|
||||
def _record_store_event(self, node: TreeNode):
|
||||
if self.enable_kv_cache_events:
|
||||
block_hash = hash(tuple(node.key))
|
||||
parent_block_hash = hash(tuple(node.parent.key))
|
||||
self.kv_event_queue.append(
|
||||
BlockStored(
|
||||
block_hashes=[block_hash],
|
||||
parent_block_hash=parent_block_hash,
|
||||
token_ids=node.key,
|
||||
block_size=len(node.key),
|
||||
lora_id=None,
|
||||
)
|
||||
)
|
||||
|
||||
def _record_remove_event(self, node: TreeNode):
|
||||
if self.enable_kv_cache_events:
|
||||
block_hash = hash(tuple(node.key))
|
||||
self.kv_event_queue.append(BlockRemoved(block_hashes=[block_hash]))
|
||||
|
||||
def _record_all_cleared_event(self):
|
||||
if self.enable_kv_cache_events:
|
||||
self.kv_event_queue.append(AllBlocksCleared())
|
||||
|
||||
def take_events(self):
    """Hand over every queued KV cache event and leave an empty queue behind.

    Returns:
        The list of events queued so far (the queue object itself, not a
        copy), after ``self.kv_event_queue`` has been rebound to a new empty
        list. When ``self.enable_kv_cache_events`` is falsy, a new empty list
        is returned and the queue is left untouched.
    """
    if self.enable_kv_cache_events:
        drained, self.kv_event_queue = self.kv_event_queue, []
        return drained
    return []
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
tree = RadixCache(None, None, page_size=1, disable=False)
|
||||
|
||||
Reference in New Issue
Block a user