[Refactor] Clean up radix cache related API (#7303)
Co-authored-by: Zhiqiang Xie <xiezhq@stanford.edu>
This commit is contained in:
@@ -7,6 +7,7 @@ from typing import List, Optional
|
||||
import torch
|
||||
|
||||
from sglang.srt.managers.cache_controller import HiCacheController
|
||||
from sglang.srt.mem_cache.base_prefix_cache import MatchResult
|
||||
from sglang.srt.mem_cache.memory_pool import (
|
||||
MHATokenToKVPool,
|
||||
MLATokenToKVPool,
|
||||
@@ -283,41 +284,44 @@ class HiRadixCache(RadixCache):
|
||||
def init_load_back(
|
||||
self,
|
||||
last_node: TreeNode,
|
||||
prefix_indices: torch.Tensor,
|
||||
host_hit_length: int,
|
||||
mem_quota: Optional[int] = None,
|
||||
):
|
||||
assert (
|
||||
len(prefix_indices) == 0 or prefix_indices.is_cuda
|
||||
), "indices of device kV caches should be on GPU"
|
||||
_ = host_hit_length # unused, but kept for compatibility
|
||||
if last_node.evicted:
|
||||
loading_values = self.load_back(last_node, mem_quota)
|
||||
if loading_values is not None:
|
||||
prefix_indices = (
|
||||
loading_values
|
||||
if len(prefix_indices) == 0
|
||||
else torch.cat([prefix_indices, loading_values])
|
||||
)
|
||||
logger.debug(
|
||||
f"loading back {len(loading_values)} tokens for node {last_node.id}"
|
||||
)
|
||||
return loading_values, last_node
|
||||
|
||||
while last_node.evicted:
|
||||
last_node = last_node.parent
|
||||
|
||||
return last_node, prefix_indices
|
||||
return (
|
||||
torch.empty((0,), dtype=torch.int64, device=self.device),
|
||||
last_node,
|
||||
)
|
||||
|
||||
def ready_to_load_cache(self):
|
||||
def ready_to_load_host_cache(self):
|
||||
producer_index = self.cache_controller.layer_done_counter.next_producer()
|
||||
self.load_cache_event.set()
|
||||
return producer_index
|
||||
|
||||
def match_prefix(self, key: List[int], include_evicted=False, **kwargs):
|
||||
def check_hicache_events(self):
|
||||
self.writing_check()
|
||||
self.loading_check()
|
||||
|
||||
def match_prefix(self, key: List[int], **kwargs):
|
||||
empty_value = torch.empty((0,), dtype=torch.int64, device=self.device)
|
||||
if self.disable or len(key) == 0:
|
||||
if include_evicted:
|
||||
return empty_value, self.root_node, self.root_node
|
||||
else:
|
||||
return empty_value, self.root_node
|
||||
return MatchResult(
|
||||
device_indices=empty_value,
|
||||
last_device_node=self.root_node,
|
||||
last_host_node=self.root_node,
|
||||
host_hit_length=0,
|
||||
)
|
||||
|
||||
if self.page_size != 1:
|
||||
page_aligned_len = len(key) // self.page_size * self.page_size
|
||||
@@ -329,14 +333,18 @@ class HiRadixCache(RadixCache):
|
||||
else:
|
||||
value = empty_value
|
||||
|
||||
last_node_global = last_node
|
||||
host_hit_length = 0
|
||||
last_host_node = last_node
|
||||
while last_node.evicted:
|
||||
host_hit_length += len(last_node.host_value)
|
||||
last_node = last_node.parent
|
||||
|
||||
if include_evicted:
|
||||
return value, last_node, last_node_global
|
||||
else:
|
||||
return value, last_node
|
||||
return MatchResult(
|
||||
device_indices=value,
|
||||
last_device_node=last_node,
|
||||
last_host_node=last_host_node,
|
||||
host_hit_length=host_hit_length,
|
||||
)
|
||||
|
||||
def _match_prefix_helper(self, node: TreeNode, key: List):
|
||||
node.last_access_time = time.monotonic()
|
||||
|
||||
Reference in New Issue
Block a user