[Refactor] Clean up radix cache related API (#7303)

Co-authored-by: Zhiqiang Xie <xiezhq@stanford.edu>
This commit is contained in:
DarkSharpness
2025-06-19 09:58:48 -07:00
committed by GitHub
parent 650127a173
commit 47367b768d
7 changed files with 153 additions and 122 deletions

View File

@@ -7,6 +7,7 @@ from typing import List, Optional
import torch
from sglang.srt.managers.cache_controller import HiCacheController
from sglang.srt.mem_cache.base_prefix_cache import MatchResult
from sglang.srt.mem_cache.memory_pool import (
MHATokenToKVPool,
MLATokenToKVPool,
@@ -283,41 +284,44 @@ class HiRadixCache(RadixCache):
# Load an evicted node's values back via `load_back` and return them with a tree node.
# NOTE(review): this listing is a diff view whose +/- markers and indentation were
# stripped, so lines from both sides of the change appear together (both
# `prefix_indices` and `host_hit_length` parameters, and more than one return
# statement with different tuple orders). Confirm against the real file before
# relying on any single line below.
def init_load_back(
self,
last_node: TreeNode,
prefix_indices: torch.Tensor,
host_hit_length: int,
mem_quota: Optional[int] = None,
):
# Non-empty device indices are required to live on a CUDA device.
assert (
len(prefix_indices) == 0 or prefix_indices.is_cuda
), "indices of device kV caches should be on GPU"
_ = host_hit_length # unused, but kept for compatibility
# Only an evicted node triggers a load; `load_back` may return None, which is
# handled by the `is not None` check.
if last_node.evicted:
loading_values = self.load_back(last_node, mem_quota)
if loading_values is not None:
# Append the loaded values to any existing prefix indices.
prefix_indices = (
loading_values
if len(prefix_indices) == 0
else torch.cat([prefix_indices, loading_values])
)
logger.debug(
f"loading back {len(loading_values)} tokens for node {last_node.id}"
)
# This return is ordered (values, node).
return loading_values, last_node
# Walk up through parents until reaching a node that is not evicted.
while last_node.evicted:
last_node = last_node.parent
# NOTE(review): this return is ordered (node, indices), the reverse of the other
# returns in this listing — presumably the pre-change form; verify.
return last_node, prefix_indices
# Return an empty int64 index tensor on `self.device` together with the node.
return (
torch.empty((0,), dtype=torch.int64, device=self.device),
last_node,
)
def ready_to_load_cache(self):
# Fetch the next producer index from the cache controller, set the load-cache
# event, and return that index to the caller.
def ready_to_load_host_cache(self):
# Ask the controller's layer-done counter for its next producer index.
producer_index = self.cache_controller.layer_done_counter.next_producer()
# Set the event only after the index has been obtained.
# NOTE(review): presumably this wakes a loader waiting on the event — confirm
# against the cache controller.
self.load_cache_event.set()
return producer_index
def match_prefix(self, key: List[int], include_evicted=False, **kwargs):
# Run this cache's write check and then its load check, in that order.
def check_hicache_events(self):
self.writing_check()
self.loading_check()
def match_prefix(self, key: List[int], **kwargs):
empty_value = torch.empty((0,), dtype=torch.int64, device=self.device)
if self.disable or len(key) == 0:
if include_evicted:
return empty_value, self.root_node, self.root_node
else:
return empty_value, self.root_node
return MatchResult(
device_indices=empty_value,
last_device_node=self.root_node,
last_host_node=self.root_node,
host_hit_length=0,
)
if self.page_size != 1:
page_aligned_len = len(key) // self.page_size * self.page_size
@@ -329,14 +333,18 @@ class HiRadixCache(RadixCache):
else:
value = empty_value
last_node_global = last_node
host_hit_length = 0
last_host_node = last_node
while last_node.evicted:
host_hit_length += len(last_node.host_value)
last_node = last_node.parent
if include_evicted:
return value, last_node, last_node_global
else:
return value, last_node
return MatchResult(
device_indices=value,
last_device_node=last_node,
last_host_node=last_host_node,
host_hit_length=host_hit_length,
)
def _match_prefix_helper(self, node: TreeNode, key: List):
node.last_access_time = time.monotonic()