[Refactor] Clean up radix cache related API (#7303)

Co-authored-by: Zhiqiang Xie <xiezhq@stanford.edu>
This commit is contained in:
DarkSharpness
2025-06-19 09:58:48 -07:00
committed by GitHub
parent 650127a173
commit 47367b768d
7 changed files with 153 additions and 122 deletions

View File

@@ -1,5 +1,31 @@
from abc import ABC, abstractmethod
from typing import Any, List, Tuple
from typing import TYPE_CHECKING, Any, List, NamedTuple, Tuple
import torch
if TYPE_CHECKING:
from sglang.srt.managers.schedule_batch import Req
else:
Req = Any # Placeholder for Req type when not type checking
class MatchResult(NamedTuple):
"""Result of a prefix match operation.
Attributes:
device_indices : Indices of the KV cache on the device matched by common prefix.
last_device_node: The last TreeNode on the device that was matched.
last_host_node : The last TreeNode on the host that was matched.
Note that if HiCache is not enabled,
this **must** be the same as `last_device_node`.
host_hit_length : Length of the KV cache hit on the host, if applicable.
0 if HiCache is not enabled.
"""
device_indices: torch.Tensor
last_device_node: Any
last_host_node: Any
host_hit_length: int = 0
class BasePrefixCache(ABC):
@@ -10,19 +36,15 @@ class BasePrefixCache(ABC):
pass
@abstractmethod
def match_prefix(self, **kwargs) -> Tuple[List[int], int]:
def match_prefix(self, key: List[int], **kwargs) -> MatchResult:
pass
@abstractmethod
def insert(self, **kwargs):
def cache_finished_req(self, req: Req, **kwargs):
pass
@abstractmethod
def cache_finished_req(self, **kwargs):
pass
@abstractmethod
def cache_unfinished_req(self, **kwargs):
def cache_unfinished_req(self, req: Req, **kwargs):
pass
@abstractmethod
@@ -49,5 +71,27 @@ class BasePrefixCache(ABC):
def pretty_print(self):
raise NotImplementedError()
def init_load_back(
self,
last_host_node: Any,
host_hit_length: int,
) -> Tuple[torch.Tensor, Any]:
"""
Preparing KV cache loading from host to device.
"""
raise NotImplementedError()
def ready_to_load_host_cache(self) -> Any:
"""
Notify the cache controller to start the KV cache loading
"""
raise NotImplementedError()
def check_hicache_events(self) -> Any:
"""
Check HiCache-related activities to update the radix tree and synchronize across TP workers if needed
"""
raise NotImplementedError()
def take_events(self):
return []

View File

@@ -6,19 +6,13 @@ from typing import TYPE_CHECKING, Any, Callable, List, Tuple
import torch
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache, MatchResult
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
if TYPE_CHECKING:
from sglang.srt.managers.schedule_batch import Req
class ChunkCacheEntry:
def __init__(self, rid: str, value: torch.Tensor):
self.rid = rid
self.value = value
class ChunkCache(BasePrefixCache):
def __init__(
self,
@@ -29,13 +23,16 @@ class ChunkCache(BasePrefixCache):
self.req_to_token_pool = req_to_token_pool
self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
self.page_size = page_size
self.disable = True
def reset(self):
pass
def match_prefix(self, **unused_kwargs) -> Tuple[List[int], int]:
return [], None
def match_prefix(self, **unused_kwargs) -> MatchResult:
return MatchResult(
device_indices=torch.empty((0,), dtype=torch.int64),
last_device_node=None,
last_host_node=None,
)
def cache_finished_req(self, req: Req):
kv_indices = self.req_to_token_pool.req_to_token[
@@ -54,9 +51,6 @@ class ChunkCache(BasePrefixCache):
# `req.prefix_indices` will be used in `PrefillAdder::add_chunked_req` later
req.prefix_indices = kv_indices
def insert(self):
raise NotImplementedError()
def evict(self, num_tokens: int):
pass

View File

@@ -7,6 +7,7 @@ from typing import List, Optional
import torch
from sglang.srt.managers.cache_controller import HiCacheController
from sglang.srt.mem_cache.base_prefix_cache import MatchResult
from sglang.srt.mem_cache.memory_pool import (
MHATokenToKVPool,
MLATokenToKVPool,
@@ -283,41 +284,44 @@ class HiRadixCache(RadixCache):
def init_load_back(
self,
last_node: TreeNode,
prefix_indices: torch.Tensor,
host_hit_length: int,
mem_quota: Optional[int] = None,
):
assert (
len(prefix_indices) == 0 or prefix_indices.is_cuda
), "indices of device kV caches should be on GPU"
_ = host_hit_length # unused, but kept for compatibility
if last_node.evicted:
loading_values = self.load_back(last_node, mem_quota)
if loading_values is not None:
prefix_indices = (
loading_values
if len(prefix_indices) == 0
else torch.cat([prefix_indices, loading_values])
)
logger.debug(
f"loading back {len(loading_values)} tokens for node {last_node.id}"
)
return loading_values, last_node
while last_node.evicted:
last_node = last_node.parent
return last_node, prefix_indices
return (
torch.empty((0,), dtype=torch.int64, device=self.device),
last_node,
)
def ready_to_load_cache(self):
def ready_to_load_host_cache(self):
producer_index = self.cache_controller.layer_done_counter.next_producer()
self.load_cache_event.set()
return producer_index
def match_prefix(self, key: List[int], include_evicted=False, **kwargs):
def check_hicache_events(self):
self.writing_check()
self.loading_check()
def match_prefix(self, key: List[int], **kwargs):
empty_value = torch.empty((0,), dtype=torch.int64, device=self.device)
if self.disable or len(key) == 0:
if include_evicted:
return empty_value, self.root_node, self.root_node
else:
return empty_value, self.root_node
return MatchResult(
device_indices=empty_value,
last_device_node=self.root_node,
last_host_node=self.root_node,
host_hit_length=0,
)
if self.page_size != 1:
page_aligned_len = len(key) // self.page_size * self.page_size
@@ -329,14 +333,18 @@ class HiRadixCache(RadixCache):
else:
value = empty_value
last_node_global = last_node
host_hit_length = 0
last_host_node = last_node
while last_node.evicted:
host_hit_length += len(last_node.host_value)
last_node = last_node.parent
if include_evicted:
return value, last_node, last_node_global
else:
return value, last_node
return MatchResult(
device_indices=value,
last_device_node=last_node,
last_host_node=last_host_node,
host_hit_length=host_hit_length,
)
def _match_prefix_helper(self, node: TreeNode, key: List):
node.last_access_time = time.monotonic()

View File

@@ -33,8 +33,7 @@ from sglang.srt.disaggregation.kv_events import (
BlockStored,
KVCacheEvent,
)
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache, MatchResult
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
if TYPE_CHECKING:
@@ -47,9 +46,9 @@ class TreeNode:
def __init__(self, id: Optional[int] = None):
self.children = defaultdict(TreeNode)
self.parent = None
self.key = None
self.value = None
self.parent: TreeNode = None
self.key: List[int] = None
self.value: Optional[torch.Tensor] = None
self.lock_ref = 0
self.last_access_time = time.monotonic()
@@ -57,7 +56,7 @@ class TreeNode:
# indicating the node is loading KV cache from host
self.loading = False
# store the host indices of KV cache
self.host_value = None
self.host_value: Optional[torch.Tensor] = None
self.id = TreeNode.counter if id is None else id
TreeNode.counter += 1
@@ -135,7 +134,7 @@ class RadixCache(BasePrefixCache):
self.protected_size_ = 0
self._record_all_cleared_event()
def match_prefix(self, key: List[int], **kwargs) -> Tuple[torch.Tensor, int]:
def match_prefix(self, key: List[int], **kwargs) -> MatchResult:
"""Find the matching prefix from the radix tree.
Args:
key: A list of token IDs to find a matching prefix.
@@ -147,13 +146,14 @@ class RadixCache(BasePrefixCache):
than the last node's value.
"""
if self.disable or len(key) == 0:
return (
torch.empty(
return MatchResult(
device_indices=torch.empty(
(0,),
dtype=torch.int64,
device=self.device,
),
self.root_node,
last_device_node=self.root_node,
last_host_node=self.root_node,
)
if self.page_size != 1:
@@ -165,7 +165,11 @@ class RadixCache(BasePrefixCache):
value = torch.cat(value)
else:
value = torch.empty((0,), dtype=torch.int64, device=self.device)
return value, last_node
return MatchResult(
device_indices=value,
last_device_node=last_node,
last_host_node=last_node,
)
def insert(self, key: List, value=None):
if self.disable:
@@ -235,7 +239,7 @@ class RadixCache(BasePrefixCache):
)
# The prefix indices could be updated, reuse it
new_indices, new_last_node = self.match_prefix(page_aligned_token_ids)
new_indices, new_last_node, _, _ = self.match_prefix(page_aligned_token_ids)
self.req_to_token_pool.write(
(req.req_pool_idx, slice(len(req.prefix_indices), len(new_indices))),
new_indices[len(req.prefix_indices) :],