[Refactor] Clean up radix cache related API (#7303)
Co-authored-by: Zhiqiang Xie <xiezhq@stanford.edu>

sglang/srt/mem_cache/base_prefix_cache.py

@@ -1,5 +1,31 @@
from abc import ABC, abstractmethod
from typing import Any, List, Tuple
from typing import TYPE_CHECKING, Any, List, NamedTuple, Tuple

import torch

if TYPE_CHECKING:
    from sglang.srt.managers.schedule_batch import Req
else:
    Req = Any  # Placeholder for Req type when not type checking


class MatchResult(NamedTuple):
    """Result of a prefix match operation.

    Attributes:
        device_indices  : Indices of the KV cache on the device matched by common prefix.
        last_device_node: The last TreeNode on the device that was matched.
        last_host_node  : The last TreeNode on the host that was matched.
                          Note that if HiCache is not enabled,
                          this **must** be the same as `last_device_node`.
        host_hit_length : Length of the KV cache hit on the host, if applicable.
                          0 if HiCache is not enabled.
    """

    device_indices: torch.Tensor
    last_device_node: Any
    last_host_node: Any
    host_hit_length: int = 0


class BasePrefixCache(ABC):
@@ -10,19 +36,15 @@ class BasePrefixCache(ABC):
        pass

    @abstractmethod
    def match_prefix(self, **kwargs) -> Tuple[List[int], int]:
    def match_prefix(self, key: List[int], **kwargs) -> MatchResult:
        pass

    @abstractmethod
    def insert(self, **kwargs):
    def cache_finished_req(self, req: Req, **kwargs):
        pass

    @abstractmethod
    def cache_finished_req(self, **kwargs):
        pass

    @abstractmethod
    def cache_unfinished_req(self, **kwargs):
    def cache_unfinished_req(self, req: Req, **kwargs):
        pass

    @abstractmethod
@@ -49,5 +71,27 @@ class BasePrefixCache(ABC):
    def pretty_print(self):
        raise NotImplementedError()

    def init_load_back(
        self,
        last_host_node: Any,
        host_hit_length: int,
    ) -> Tuple[torch.Tensor, Any]:
        """
        Preparing KV cache loading from host to device.
        """
        raise NotImplementedError()

    def ready_to_load_host_cache(self) -> Any:
        """
        Notify the cache controller to start the KV cache loading
        """
        raise NotImplementedError()

    def check_hicache_events(self) -> Any:
        """
        Check HiCache related activities to update radix tree and synchronize across TP workers if needed
        """
        raise NotImplementedError()

    def take_events(self):
        return []
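
For orientation, the sketch below shows how a caller might consume the new MatchResult return type instead of the old (indices, last_node) tuple. It is not part of this commit: the _DummyCache class and the token IDs are made up for illustration; only the MatchResult definition mirrors the one above.

from typing import Any, List, NamedTuple

import torch


class MatchResult(NamedTuple):
    device_indices: torch.Tensor
    last_device_node: Any
    last_host_node: Any
    host_hit_length: int = 0


class _DummyCache:
    """Toy stand-in for a BasePrefixCache implementation (illustration only)."""

    def match_prefix(self, key: List[int], **kwargs) -> MatchResult:
        # A real cache would walk its radix tree here; this stub reports no hit.
        empty = torch.empty((0,), dtype=torch.int64)
        return MatchResult(
            device_indices=empty, last_device_node=None, last_host_node=None
        )


cache = _DummyCache()
result = cache.match_prefix(key=[1, 2, 3, 4])

# Fields can be read by name ...
print(len(result.device_indices), result.host_hit_length)

# ... and, because MatchResult is a NamedTuple, call sites may still unpack it
# positionally, as radix_cache.py does further down in this diff.
indices, device_node, host_node, host_hits = result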

sglang/srt/mem_cache/chunk_cache.py

@@ -6,19 +6,13 @@ from typing import TYPE_CHECKING, Any, Callable, List, Tuple

import torch

from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache, MatchResult
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator

if TYPE_CHECKING:
    from sglang.srt.managers.schedule_batch import Req


class ChunkCacheEntry:
    def __init__(self, rid: str, value: torch.Tensor):
        self.rid = rid
        self.value = value


class ChunkCache(BasePrefixCache):
    def __init__(
        self,
@@ -29,13 +23,16 @@ class ChunkCache(BasePrefixCache):
        self.req_to_token_pool = req_to_token_pool
        self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
        self.page_size = page_size
        self.disable = True

    def reset(self):
        pass

    def match_prefix(self, **unused_kwargs) -> Tuple[List[int], int]:
        return [], None
    def match_prefix(self, **unused_kwargs) -> MatchResult:
        return MatchResult(
            device_indices=torch.empty((0,), dtype=torch.int64),
            last_device_node=None,
            last_host_node=None,
        )

    def cache_finished_req(self, req: Req):
        kv_indices = self.req_to_token_pool.req_to_token[
@@ -54,9 +51,6 @@ class ChunkCache(BasePrefixCache):
        # `req.prefix_indices` will be used in `PrefillAdder::add_chunked_req` later
        req.prefix_indices = kv_indices

    def insert(self):
        raise NotImplementedError()

    def evict(self, num_tokens: int):
        pass

sglang/srt/mem_cache/hi_radix_cache.py

@@ -7,6 +7,7 @@ from typing import List, Optional
import torch

from sglang.srt.managers.cache_controller import HiCacheController
from sglang.srt.mem_cache.base_prefix_cache import MatchResult
from sglang.srt.mem_cache.memory_pool import (
    MHATokenToKVPool,
    MLATokenToKVPool,
@@ -283,41 +284,44 @@ class HiRadixCache(RadixCache):
    def init_load_back(
        self,
        last_node: TreeNode,
        prefix_indices: torch.Tensor,
        host_hit_length: int,
        mem_quota: Optional[int] = None,
    ):
        assert (
            len(prefix_indices) == 0 or prefix_indices.is_cuda
        ), "indices of device kV caches should be on GPU"
        _ = host_hit_length  # unused, but kept for compatibility
        if last_node.evicted:
            loading_values = self.load_back(last_node, mem_quota)
            if loading_values is not None:
                prefix_indices = (
                    loading_values
                    if len(prefix_indices) == 0
                    else torch.cat([prefix_indices, loading_values])
                )
                logger.debug(
                    f"loading back {len(loading_values)} tokens for node {last_node.id}"
                )
                return loading_values, last_node

            while last_node.evicted:
                last_node = last_node.parent

        return last_node, prefix_indices
        return (
            torch.empty((0,), dtype=torch.int64, device=self.device),
            last_node,
        )

    def ready_to_load_cache(self):
    def ready_to_load_host_cache(self):
        producer_index = self.cache_controller.layer_done_counter.next_producer()
        self.load_cache_event.set()
        return producer_index

    def match_prefix(self, key: List[int], include_evicted=False, **kwargs):
    def check_hicache_events(self):
        self.writing_check()
        self.loading_check()

    def match_prefix(self, key: List[int], **kwargs):
        empty_value = torch.empty((0,), dtype=torch.int64, device=self.device)
        if self.disable or len(key) == 0:
            if include_evicted:
                return empty_value, self.root_node, self.root_node
            else:
                return empty_value, self.root_node
            return MatchResult(
                device_indices=empty_value,
                last_device_node=self.root_node,
                last_host_node=self.root_node,
                host_hit_length=0,
            )

        if self.page_size != 1:
            page_aligned_len = len(key) // self.page_size * self.page_size
@@ -329,14 +333,18 @@ class HiRadixCache(RadixCache):
        else:
            value = empty_value

        last_node_global = last_node
        host_hit_length = 0
        last_host_node = last_node
        while last_node.evicted:
            host_hit_length += len(last_node.host_value)
            last_node = last_node.parent

        if include_evicted:
            return value, last_node, last_node_global
        else:
            return value, last_node
        return MatchResult(
            device_indices=value,
            last_device_node=last_node,
            last_host_node=last_host_node,
            host_hit_length=host_hit_length,
        )

    def _match_prefix_helper(self, node: TreeNode, key: List):
        node.last_access_time = time.monotonic()
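
To make the intended use of the new hooks concrete, here is a small, hypothetical sketch of how a scheduler could drive them when HiCache is enabled. It is not code from this commit; the acquire_prefix helper and its arguments are invented for illustration, and only the hook names (match_prefix, init_load_back, ready_to_load_host_cache, check_hicache_events) come from the diff above.

from typing import Any, List, Tuple

import torch


def acquire_prefix(cache: Any, token_ids: List[int]) -> Tuple[torch.Tensor, Any]:
    """Hypothetical helper showing one plausible ordering of the HiCache hooks."""
    result = cache.match_prefix(key=token_ids)
    device_indices = result.device_indices
    last_device_node = result.last_device_node

    # If part of the matched prefix lives only in host memory, ask the cache to
    # stage a host-to-device load starting from the deepest matched host node.
    if result.host_hit_length > 0:
        device_indices, last_device_node = cache.init_load_back(
            result.last_host_node, result.host_hit_length
        )
        # Tell the cache controller that the queued load may start.
        cache.ready_to_load_host_cache()

    # Poll pending write/load activity so the radix tree stays in sync
    # (a scheduler would typically do this once per iteration).
    cache.check_hicache_events()

    return device_indices, last_device_node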

sglang/srt/mem_cache/radix_cache.py

@@ -33,8 +33,7 @@ from sglang.srt.disaggregation.kv_events import (
    BlockStored,
    KVCacheEvent,
)
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache, MatchResult
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator

if TYPE_CHECKING:
@@ -47,9 +46,9 @@ class TreeNode:

    def __init__(self, id: Optional[int] = None):
        self.children = defaultdict(TreeNode)
        self.parent = None
        self.key = None
        self.value = None
        self.parent: TreeNode = None
        self.key: List[int] = None
        self.value: Optional[torch.Tensor] = None
        self.lock_ref = 0
        self.last_access_time = time.monotonic()

@@ -57,7 +56,7 @@ class TreeNode:
        # indicating the node is loading KV cache from host
        self.loading = False
        # store the host indices of KV cache
        self.host_value = None
        self.host_value: Optional[torch.Tensor] = None

        self.id = TreeNode.counter if id is None else id
        TreeNode.counter += 1
@@ -135,7 +134,7 @@ class RadixCache(BasePrefixCache):
        self.protected_size_ = 0
        self._record_all_cleared_event()

    def match_prefix(self, key: List[int], **kwargs) -> Tuple[torch.Tensor, int]:
    def match_prefix(self, key: List[int], **kwargs) -> MatchResult:
        """Find the matching prefix from the radix tree.
        Args:
            key: A list of token IDs to find a matching prefix.
@@ -147,13 +146,14 @@ class RadixCache(BasePrefixCache):
            than the last node's value.
        """
        if self.disable or len(key) == 0:
            return (
                torch.empty(
            return MatchResult(
                device_indices=torch.empty(
                    (0,),
                    dtype=torch.int64,
                    device=self.device,
                ),
                self.root_node,
                last_device_node=self.root_node,
                last_host_node=self.root_node,
            )

        if self.page_size != 1:
@@ -165,7 +165,11 @@ class RadixCache(BasePrefixCache):
            value = torch.cat(value)
        else:
            value = torch.empty((0,), dtype=torch.int64, device=self.device)
        return value, last_node
        return MatchResult(
            device_indices=value,
            last_device_node=last_node,
            last_host_node=last_node,
        )

    def insert(self, key: List, value=None):
        if self.disable:
@@ -235,7 +239,7 @@ class RadixCache(BasePrefixCache):
        )

        # The prefix indices could be updated, reuse it
        new_indices, new_last_node = self.match_prefix(page_aligned_token_ids)
        new_indices, new_last_node, _, _ = self.match_prefix(page_aligned_token_ids)
        self.req_to_token_pool.write(
            (req.req_pool_idx, slice(len(req.prefix_indices), len(new_indices))),
            new_indices[len(req.prefix_indices) :],