Refactors radix cache for extra key support (#10317)
Signed-off-by: Xinyuan Tong <xinyuantong.cs@gmail.com>
@@ -61,8 +61,8 @@ from sglang.srt.mem_cache.allocator import (
)
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
from sglang.srt.mem_cache.lora_radix_cache import LoRAKey, LoRARadixCache
from sglang.srt.mem_cache.memory_pool import HybridReqToTokenPool, ReqToTokenPool
from sglang.srt.mem_cache.radix_cache import RadixKey
from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
from sglang.srt.metrics.collector import SchedulerMetricsCollector, TimeStats
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
@@ -457,6 +457,7 @@ class Req:
vocab_size: Optional[int] = None,
priority: Optional[int] = None,
metrics_collector: Optional[SchedulerMetricsCollector] = None,
extra_key: Optional[str] = None,
):
# Input and output info
self.rid = rid
@@ -489,6 +490,14 @@ class Req:
self.sampling_params = sampling_params
self.custom_logit_processor = custom_logit_processor
self.return_hidden_states = return_hidden_states

# extra key for classifying the request (e.g. lora_id, cache_salt)
if lora_id is not None:
extra_key = (
extra_key or ""
) + lora_id # lora_id is concatenated to the extra key

self.extra_key = extra_key
self.lora_id = lora_id

# Memory pool info
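The composition rule above is small but easy to misread, so here is a minimal, self-contained sketch of it (the helper name is hypothetical; the logic mirrors the hunk):

```python
# Illustrative sketch of the extra-key composition shown above (not part of the diff).
def compose_extra_key(extra_key, lora_id):
    # extra_key may already carry a caller-supplied namespace tag such as a cache salt
    if lora_id is not None:
        extra_key = (extra_key or "") + lora_id  # lora_id is appended to the extra key
    return extra_key

assert compose_extra_key(None, None) is None
assert compose_extra_key("salt-v1", None) == "salt-v1"
assert compose_extra_key(None, "lora-abc") == "lora-abc"
assert compose_extra_key("salt-v1", "lora-abc") == "salt-v1lora-abc"
```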
@@ -679,26 +688,16 @@ class Req:
|
||||
):
|
||||
self.fill_ids = self.origin_input_ids + self.output_ids
|
||||
if tree_cache is not None:
|
||||
if isinstance(tree_cache, LoRARadixCache):
|
||||
(
|
||||
self.prefix_indices,
|
||||
self.last_node,
|
||||
self.last_host_node,
|
||||
self.host_hit_length,
|
||||
) = tree_cache.match_prefix_with_lora_id(
|
||||
key=LoRAKey(
|
||||
lora_id=self.lora_id, token_ids=self.adjust_max_prefix_ids()
|
||||
),
|
||||
)
|
||||
else:
|
||||
(
|
||||
self.prefix_indices,
|
||||
self.last_node,
|
||||
self.last_host_node,
|
||||
self.host_hit_length,
|
||||
) = tree_cache.match_prefix(
|
||||
key=self.adjust_max_prefix_ids(),
|
||||
)
|
||||
(
|
||||
self.prefix_indices,
|
||||
self.last_node,
|
||||
self.last_host_node,
|
||||
self.host_hit_length,
|
||||
) = tree_cache.match_prefix(
|
||||
key=RadixKey(
|
||||
token_ids=self.adjust_max_prefix_ids(), extra_key=self.extra_key
|
||||
),
|
||||
)
|
||||
self.extend_input_len = len(self.fill_ids) - len(self.prefix_indices)
|
||||
|
||||
def adjust_max_prefix_ids(self):
|
||||
|
||||
@@ -27,7 +27,7 @@ import torch
|
||||
from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
|
||||
from sglang.srt.mem_cache.allocator import SWATokenToKVPoolAllocator
|
||||
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
|
||||
from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode
|
||||
from sglang.srt.mem_cache.radix_cache import RadixCache, RadixKey, TreeNode
|
||||
from sglang.srt.server_args import ServerArgs
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -175,10 +175,13 @@ class SchedulePolicy:
|
||||
|
||||
for r in waiting_queue:
|
||||
prefix_ids = r.adjust_max_prefix_ids()
|
||||
extra_key = r.extra_key
|
||||
|
||||
# NOTE: the prefix_indices must always be aligned with last_node
|
||||
r.prefix_indices, r.last_node, r.last_host_node, r.host_hit_length = (
|
||||
self.tree_cache.match_prefix(rid=r.rid, key=prefix_ids)
|
||||
self.tree_cache.match_prefix(
|
||||
rid=r.rid, key=RadixKey(token_ids=prefix_ids, extra_key=extra_key)
|
||||
)
|
||||
)
|
||||
|
||||
# NOTE(sang): This logic is for in-batch prefix caching;
|
||||
@@ -191,7 +194,8 @@ class SchedulePolicy:
|
||||
if len(r.prefix_indices) <= IN_BATCH_PREFIX_CACHING_CHECK_THRESHOLD:
|
||||
in_batch_matching_prefixes, _, _, _ = (
|
||||
self.waiting_queue_radix_tree.match_prefix(
|
||||
rid=r.rid, key=prefix_ids
|
||||
rid=r.rid,
|
||||
key=RadixKey(token_ids=prefix_ids, extra_key=extra_key),
|
||||
)
|
||||
)
|
||||
if (
|
||||
@@ -202,7 +206,8 @@ class SchedulePolicy:
|
||||
else:
|
||||
# Insert with a dummy key
|
||||
self.waiting_queue_radix_tree.insert(
|
||||
prefix_ids, torch.empty(len(prefix_ids), dtype=torch.bool)
|
||||
RadixKey(token_ids=prefix_ids, extra_key=extra_key),
|
||||
torch.empty(len(prefix_ids), dtype=torch.bool),
|
||||
)
|
||||
return temporary_deprioritized
|
||||
|
||||
|
||||
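A hedged sketch of how the scheduling path above now builds its lookup keys; the concrete values and the variable names other than `RadixKey` are illustrative, not taken from the diff:

```python
import torch

from sglang.srt.mem_cache.radix_cache import RadixKey

# Sketch of the per-request key construction used in the cache-aware policy.
prefix_ids = [101, 2023, 2003]   # what r.adjust_max_prefix_ids() might return
extra_key = "lora-abc"           # r.extra_key; None when no LoRA id / cache salt is set
lookup_key = RadixKey(token_ids=prefix_ids, extra_key=extra_key)

# The in-batch prefix-caching tree stores a dummy boolean value of the same length,
# mirroring waiting_queue_radix_tree.insert(...) in the hunk above.
dummy_value = torch.empty(len(prefix_ids), dtype=torch.bool)
```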
@@ -145,7 +145,6 @@ from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
|
||||
from sglang.srt.managers.utils import DPBalanceMeta, validate_input_length
|
||||
from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
|
||||
from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
|
||||
from sglang.srt.mem_cache.lora_radix_cache import LoRARadixCache
|
||||
from sglang.srt.mem_cache.radix_cache import RadixCache
|
||||
from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardMode, PPProxyTensors
|
||||
@@ -719,19 +718,6 @@ class Scheduler(
|
||||
page_size=self.page_size,
|
||||
disable=server_args.disable_radix_cache,
|
||||
)
|
||||
elif self.enable_lora:
|
||||
assert (
|
||||
not self.enable_hierarchical_cache
|
||||
), "LoRA radix cache doesn't support hierarchical cache"
|
||||
assert (
|
||||
self.schedule_policy == "fcfs"
|
||||
), "LoRA radix cache only supports FCFS policy"
|
||||
self.tree_cache = LoRARadixCache(
|
||||
req_to_token_pool=self.req_to_token_pool,
|
||||
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
|
||||
page_size=self.page_size,
|
||||
disable=server_args.disable_radix_cache,
|
||||
)
|
||||
elif server_args.enable_lmcache:
|
||||
from sglang.srt.mem_cache.storage.lmcache.lmc_radix_cache import (
|
||||
LMCRadixCache,
|
||||
|
||||
@@ -36,7 +36,7 @@ class BasePrefixCache(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def match_prefix(self, key: List[int], **kwargs) -> MatchResult:
|
||||
def match_prefix(self, key: Any, **kwargs) -> MatchResult:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
|
||||
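Since the abstract signature above is loosened to `key: Any`, a subclass is free to accept either a plain token-id list or a `RadixKey`. A minimal, hypothetical stub under that assumption (the `MatchResult` fields follow the usage elsewhere in this diff):

```python
from typing import Any

import torch

from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache, MatchResult


class NoopPrefixCache(BasePrefixCache):
    """Hypothetical cache that never matches anything; only illustrates the new contract."""

    def match_prefix(self, key: Any, **kwargs) -> MatchResult:
        root = object()  # a real implementation would return its root TreeNode here
        return MatchResult(
            device_indices=torch.empty((0,), dtype=torch.int64),
            last_device_node=root,
            last_host_node=root,
        )
```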
@@ -19,7 +19,7 @@ from sglang.srt.mem_cache.memory_pool_host import (
|
||||
MHATokenToKVPoolHost,
|
||||
MLATokenToKVPoolHost,
|
||||
)
|
||||
from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode
|
||||
from sglang.srt.mem_cache.radix_cache import RadixCache, RadixKey, TreeNode
|
||||
from sglang.srt.metrics.collector import StorageMetricsCollector
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -570,7 +570,9 @@ class HiRadixCache(RadixCache):
|
||||
written_indices = host_indices[:min_completed_tokens]
|
||||
matched_length = self._insert_helper_host(
|
||||
last_host_node,
|
||||
fetched_token_ids,
|
||||
RadixKey(
|
||||
token_ids=fetched_token_ids, extra_key=last_host_node.key.extra_key
|
||||
),
|
||||
written_indices,
|
||||
hash_value[: min_completed_tokens // self.page_size],
|
||||
)
|
||||
@@ -592,7 +594,7 @@ class HiRadixCache(RadixCache):
|
||||
|
||||
return True
|
||||
|
||||
def match_prefix(self, key: List[int], **kwargs):
|
||||
def match_prefix(self, key: RadixKey, **kwargs):
|
||||
empty_value = torch.empty((0,), dtype=torch.int64, device=self.device)
|
||||
if self.disable or len(key) == 0:
|
||||
return MatchResult(
|
||||
@@ -666,7 +668,9 @@ class HiRadixCache(RadixCache):
|
||||
)
|
||||
self.cache_controller.prefetch_tokens_occupied += len(new_input_tokens)
|
||||
|
||||
def _insert_helper_host(self, node: TreeNode, key: List, host_value, hash_value):
|
||||
def _insert_helper_host(
|
||||
self, node: TreeNode, key: RadixKey, host_value, hash_value
|
||||
):
|
||||
node.last_access_time = time.monotonic()
|
||||
if len(key) == 0:
|
||||
return 0
|
||||
@@ -700,7 +704,7 @@ class HiRadixCache(RadixCache):
|
||||
node.children[child_key] = new_node
|
||||
return matched_length
|
||||
|
||||
def _match_prefix_helper(self, node: TreeNode, key: List):
|
||||
def _match_prefix_helper(self, node: TreeNode, key: RadixKey):
|
||||
node.last_access_time = time.monotonic()
|
||||
child_key = self.get_child_key_fn(key)
|
||||
value = []
|
||||
@@ -726,7 +730,7 @@ class HiRadixCache(RadixCache):
|
||||
|
||||
return value, node
|
||||
|
||||
def _split_node(self, key, child: TreeNode, split_len: int):
|
||||
def _split_node(self, key: RadixKey, child: TreeNode, split_len: int):
|
||||
# child node split into new_node -> child
|
||||
new_node = TreeNode()
|
||||
new_node.children = {self.get_child_key_fn(key[split_len:]): child}
|
||||
@@ -753,7 +757,7 @@ class HiRadixCache(RadixCache):
|
||||
new_node.parent.children[self.get_child_key_fn(key)] = new_node
|
||||
return new_node
|
||||
|
||||
def insert(self, key: List, value, chunked=False):
|
||||
def insert(self, key: RadixKey, value=None, chunked=False):
|
||||
if len(key) == 0:
|
||||
return 0
|
||||
|
||||
@@ -811,7 +815,7 @@ class HiRadixCache(RadixCache):
|
||||
for idx in range(0, len(key), self.page_size):
|
||||
new_node.hash_value.append(
|
||||
self.cache_controller.get_hash_str(
|
||||
key[idx : idx + self.page_size],
|
||||
key.token_ids[idx : idx + self.page_size],
|
||||
prior_hash=last_hash,
|
||||
)
|
||||
)
|
||||
|
||||
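The hunk above hashes `key.token_ids` one page at a time, chaining each page hash to the previous one via `prior_hash`. A small self-contained sketch of that pattern follows; the real code delegates to `cache_controller.get_hash_str`, whose exact algorithm is not shown here, so `page_hash` below is an assumption:

```python
import hashlib
from typing import List, Optional


def page_hash(token_ids: List[int], prior_hash: Optional[str] = None) -> str:
    # Hypothetical stand-in for cache_controller.get_hash_str: hash the page's token
    # ids together with the previous page's hash so the pages form a chain.
    h = hashlib.sha256()
    if prior_hash:
        h.update(prior_hash.encode())
    h.update(",".join(map(str, token_ids)).encode())
    return h.hexdigest()


token_ids = list(range(10))
page_size = 4
last_hash = None
hash_values = []
for idx in range(0, len(token_ids), page_size):
    last_hash = page_hash(token_ids[idx : idx + page_size], prior_hash=last_hash)
    hash_values.append(last_hash)
# hash_values now holds one chained hash per page, as new_node.hash_value does above.
```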
@@ -1,421 +0,0 @@
|
||||
"""Radix cache for LoRA. It's modified based on RadixCache with lora_id added to the key of nodes."""
|
||||
|
||||
import heapq
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from typing import TYPE_CHECKING, Any, List, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
|
||||
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache, MatchResult
|
||||
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.managers.schedule_batch import Req
|
||||
else:
|
||||
Req = Any # Placeholder for Req type when not type checking
|
||||
|
||||
|
||||
class LoRAKey:
|
||||
|
||||
def __init__(self, lora_id: str, token_ids: List[int]):
|
||||
self.lora_id = (
|
||||
lora_id # lora_id of adaptor, should be hash value of adaptor path
|
||||
)
|
||||
self.token_ids = token_ids # token_ids of the key
|
||||
|
||||
def __len__(self):
|
||||
return len(self.token_ids)
|
||||
|
||||
|
||||
def get_child_key(key: LoRAKey):
|
||||
# Here the key of children dict is the hash of lora_id + str(token_ids[0])
|
||||
# So the child key can be matched only when lora_id and token_ids[0] are the same
|
||||
if key.lora_id is None:
|
||||
return hash(str(key.token_ids[0]))
|
||||
else:
|
||||
return hash(key.lora_id + str(key.token_ids[0]))
|
||||
|
||||
|
||||
class LoRATreeNode:
|
||||
|
||||
counter = 0
|
||||
|
||||
def __init__(self, id: Optional[int] = None):
|
||||
self.children = defaultdict(LoRATreeNode)
|
||||
self.parent: LoRATreeNode = None
|
||||
self.key: LoRAKey = None
|
||||
self.value: Optional[torch.Tensor] = None
|
||||
self.lock_ref = 0
|
||||
self.last_access_time = time.monotonic()
|
||||
|
||||
self.id = LoRATreeNode.counter if id is None else id
|
||||
LoRATreeNode.counter += 1
|
||||
|
||||
@property
|
||||
def evicted(self):
|
||||
return self.value is None
|
||||
|
||||
def __lt__(self, other: "LoRATreeNode"):
|
||||
return self.last_access_time < other.last_access_time
|
||||
|
||||
|
||||
def _key_match(key0: LoRAKey, key1: LoRAKey):
|
||||
if key0.lora_id != key1.lora_id:
|
||||
raise ValueError(
|
||||
f"_key_match should be run on the same lora_id, but got key0.lora_id={key0.lora_id} != key1.lora_id={key1.lora_id}"
|
||||
)
|
||||
i = 0
|
||||
for k0, k1 in zip(key0.token_ids, key1.token_ids):
|
||||
if k0 != k1:
|
||||
break
|
||||
i += 1
|
||||
return i
|
||||
|
||||
|
||||
class LoRARadixCache(BasePrefixCache):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
req_to_token_pool: ReqToTokenPool,
|
||||
token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
|
||||
page_size: int,
|
||||
disable: bool = False,
|
||||
):
|
||||
if page_size > 1:
|
||||
raise ValueError("LoRARadixCache currently only supports page_size = 1")
|
||||
|
||||
if token_to_kv_pool_allocator is None:
|
||||
raise ValueError(
|
||||
"token_to_kv_pool_allocator is required to run LoraRadixCache"
|
||||
)
|
||||
|
||||
self.req_to_token_pool = req_to_token_pool
|
||||
self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
|
||||
self.page_size = page_size
|
||||
self.disable = disable
|
||||
self.device = self.token_to_kv_pool_allocator.device
|
||||
|
||||
self.key_match_fn = _key_match
|
||||
self.get_child_key_fn = get_child_key
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
self.root_node = LoRATreeNode()
|
||||
self.root_node.key = LoRAKey(lora_id="", token_ids=[])
|
||||
self.root_node.value = None
|
||||
self.evictable_size_ = 0
|
||||
self.protected_size_ = 0
|
||||
|
||||
def match_prefix(self, key: List[int], **kwargs) -> MatchResult:
|
||||
raise ValueError(
|
||||
"LoRARadixCache needs both token ids and lora id as inputs for matching. Please use match_prefix_with_lora_id instead."
|
||||
)
|
||||
|
||||
def match_prefix_with_lora_id(self, key: LoRAKey, **kwargs) -> MatchResult:
|
||||
"""Find the matching prefix from the lora radix tree.
|
||||
Args:
|
||||
key: A LoRAKey to find a matching prefix.
|
||||
Returns:
|
||||
A tuple of a tensor of matching prefix token IDs and
|
||||
the last node that contains the prefix values. Note that
|
||||
this API can modify the internal state of the Radix tree.
|
||||
The last node create a new child if the prefix is shorter
|
||||
than the last node's value.
|
||||
"""
|
||||
if self.disable or len(key) == 0:
|
||||
return MatchResult(
|
||||
device_indices=torch.empty(
|
||||
(0,),
|
||||
dtype=torch.int64,
|
||||
device=self.device,
|
||||
),
|
||||
last_device_node=self.root_node,
|
||||
last_host_node=self.root_node,
|
||||
)
|
||||
|
||||
value, last_node = self._match_prefix_helper(self.root_node, key)
|
||||
if value:
|
||||
value = torch.cat(value)
|
||||
else:
|
||||
value = torch.empty((0,), dtype=torch.int64, device=self.device)
|
||||
return MatchResult(
|
||||
device_indices=value,
|
||||
last_device_node=last_node,
|
||||
last_host_node=last_node,
|
||||
)
|
||||
|
||||
def insert(self, key: LoRAKey, value=None):
|
||||
if self.disable:
|
||||
return 0
|
||||
|
||||
if value is None:
|
||||
value = [x for x in key.token_ids]
|
||||
return self._insert_helper(self.root_node, key, value)
|
||||
|
||||
def cache_finished_req(self, req: Req):
|
||||
"""Cache request when it finishes."""
|
||||
if self.disable:
|
||||
kv_indices = self.req_to_token_pool.req_to_token[
|
||||
req.req_pool_idx, : len(req.origin_input_ids) + len(req.output_ids) - 1
|
||||
]
|
||||
self.token_to_kv_pool_allocator.free(kv_indices)
|
||||
self.req_to_token_pool.free(req.req_pool_idx)
|
||||
return
|
||||
|
||||
token_ids = (req.origin_input_ids + req.output_ids)[:-1]
|
||||
kv_indices = self.req_to_token_pool.req_to_token[
|
||||
req.req_pool_idx, : len(token_ids)
|
||||
]
|
||||
|
||||
page_aligned_len = len(kv_indices)
|
||||
page_aligned_kv_indices = kv_indices.to(dtype=torch.int64, copy=True)
|
||||
|
||||
# Radix Cache takes one ref in memory pool
|
||||
lora_key = LoRAKey(lora_id=req.lora_id, token_ids=token_ids[:page_aligned_len])
|
||||
new_prefix_len = self.insert(lora_key, page_aligned_kv_indices)
|
||||
self.token_to_kv_pool_allocator.free(
|
||||
kv_indices[len(req.prefix_indices) : new_prefix_len]
|
||||
)
|
||||
|
||||
# Remove req slot release the cache lock
|
||||
self.req_to_token_pool.free(req.req_pool_idx)
|
||||
self.dec_lock_ref(req.last_node)
|
||||
|
||||
def cache_unfinished_req(self, req: Req, chunked=False):
|
||||
"""Cache request when it is unfinished."""
|
||||
if self.disable:
|
||||
return
|
||||
|
||||
token_ids = req.fill_ids
|
||||
kv_indices = self.req_to_token_pool.req_to_token[
|
||||
req.req_pool_idx, : len(token_ids)
|
||||
]
|
||||
|
||||
page_aligned_len = len(kv_indices)
|
||||
page_aligned_kv_indices = kv_indices.to(dtype=torch.int64, copy=True)
|
||||
page_aligned_token_ids = token_ids[:page_aligned_len]
|
||||
|
||||
# Radix Cache takes one ref in memory pool
|
||||
inserted_key = LoRAKey(lora_id=req.lora_id, token_ids=page_aligned_token_ids)
|
||||
new_prefix_len = self.insert(inserted_key, page_aligned_kv_indices)
|
||||
self.token_to_kv_pool_allocator.free(
|
||||
kv_indices[len(req.prefix_indices) : new_prefix_len]
|
||||
)
|
||||
|
||||
# The prefix indices could be updated, reuse it
|
||||
new_indices, new_last_node, _, _ = self.match_prefix_with_lora_id(inserted_key)
|
||||
self.req_to_token_pool.write(
|
||||
(req.req_pool_idx, slice(len(req.prefix_indices), len(new_indices))),
|
||||
new_indices[len(req.prefix_indices) :],
|
||||
)
|
||||
|
||||
self.dec_lock_ref(req.last_node)
|
||||
self.inc_lock_ref(new_last_node)
|
||||
|
||||
# `req.prefix_indices` will be used in `PrefillAdder::add_chunked_req` later
|
||||
req.prefix_indices = new_indices
|
||||
req.last_node = new_last_node
|
||||
|
||||
def pretty_print(self):
|
||||
self._print_helper(self.root_node, 0)
|
||||
print(f"#tokens: {self.total_size()}")
|
||||
|
||||
def total_size(self):
|
||||
return self._total_size_helper()
|
||||
|
||||
def evict(self, num_tokens: int):
|
||||
if self.disable:
|
||||
return
|
||||
|
||||
leaves = self._collect_leaves()
|
||||
heapq.heapify(leaves)
|
||||
|
||||
num_evicted = 0
|
||||
while num_evicted < num_tokens and len(leaves):
|
||||
x = heapq.heappop(leaves)
|
||||
|
||||
if x == self.root_node:
|
||||
break
|
||||
if x.lock_ref > 0:
|
||||
continue
|
||||
|
||||
self.token_to_kv_pool_allocator.free(x.value)
|
||||
num_evicted += len(x.value)
|
||||
self._delete_leaf(x)
|
||||
|
||||
if len(x.parent.children) == 0:
|
||||
heapq.heappush(leaves, x.parent)
|
||||
|
||||
def inc_lock_ref(self, node: LoRATreeNode):
|
||||
if self.disable:
|
||||
return 0
|
||||
|
||||
delta = 0
|
||||
while node != self.root_node:
|
||||
if node.lock_ref == 0:
|
||||
self.evictable_size_ -= len(node.value)
|
||||
self.protected_size_ += len(node.value)
|
||||
delta -= len(node.value)
|
||||
node.lock_ref += 1
|
||||
node = node.parent
|
||||
return delta
|
||||
|
||||
def dec_lock_ref(self, node: LoRATreeNode):
|
||||
if self.disable:
|
||||
return 0
|
||||
|
||||
delta = 0
|
||||
while node != self.root_node:
|
||||
if node.lock_ref == 1:
|
||||
self.evictable_size_ += len(node.value)
|
||||
self.protected_size_ -= len(node.value)
|
||||
delta += len(node.value)
|
||||
node.lock_ref -= 1
|
||||
node = node.parent
|
||||
return delta
|
||||
|
||||
def evictable_size(self):
|
||||
return self.evictable_size_
|
||||
|
||||
def protected_size(self):
|
||||
# protected size refers to the size of the cache that is locked
|
||||
return self.protected_size_
|
||||
|
||||
def all_values_flatten(self):
|
||||
values = []
|
||||
|
||||
def _dfs_helper(node: LoRATreeNode):
|
||||
for _, child in node.children.items():
|
||||
values.append(child.value)
|
||||
_dfs_helper(child)
|
||||
|
||||
_dfs_helper(self.root_node)
|
||||
return torch.cat(values)
|
||||
|
||||
##### Internal Helper Functions #####
|
||||
|
||||
def _match_prefix_helper(self, node: LoRATreeNode, key: LoRAKey):
|
||||
node.last_access_time = time.monotonic()
|
||||
|
||||
child_key = self.get_child_key_fn(key)
|
||||
|
||||
value = []
|
||||
while len(key) > 0 and child_key in node.children.keys():
|
||||
child = node.children[child_key]
|
||||
child.last_access_time = time.monotonic()
|
||||
prefix_len = self.key_match_fn(child.key, key)
|
||||
if prefix_len < len(child.key):
|
||||
new_node = self._split_node(child.key, child, prefix_len)
|
||||
value.append(new_node.value)
|
||||
node = new_node
|
||||
break
|
||||
else:
|
||||
value.append(child.value)
|
||||
node = child
|
||||
key = LoRAKey(lora_id=key.lora_id, token_ids=key.token_ids[prefix_len:])
|
||||
|
||||
if len(key):
|
||||
child_key = self.get_child_key_fn(key)
|
||||
|
||||
return value, node
|
||||
|
||||
def _split_node(self, key: LoRAKey, child: LoRATreeNode, split_len: int):
|
||||
# new_node -> child
|
||||
new_node = LoRATreeNode()
|
||||
key_split_1 = LoRAKey(lora_id=key.lora_id, token_ids=key.token_ids[:split_len])
|
||||
key_split_2 = LoRAKey(lora_id=key.lora_id, token_ids=key.token_ids[split_len:])
|
||||
new_node.children = {self.get_child_key_fn(key_split_2): child}
|
||||
new_node.parent = child.parent
|
||||
new_node.lock_ref = child.lock_ref
|
||||
new_node.key = key_split_1
|
||||
new_node.value = child.value[:split_len]
|
||||
child.parent = new_node
|
||||
child.key = key_split_2
|
||||
child.value = child.value[split_len:]
|
||||
new_node.parent.children[self.get_child_key_fn(key)] = new_node
|
||||
|
||||
return new_node
|
||||
|
||||
def _insert_helper(self, node: LoRATreeNode, key: LoRAKey, value):
|
||||
node.last_access_time = time.monotonic()
|
||||
if len(key) == 0:
|
||||
return 0
|
||||
|
||||
child_key = self.get_child_key_fn(key)
|
||||
|
||||
total_prefix_length = 0
|
||||
while len(key) > 0 and child_key in node.children.keys():
|
||||
node = node.children[child_key]
|
||||
node.last_access_time = time.monotonic()
|
||||
prefix_len = self.key_match_fn(node.key, key)
|
||||
total_prefix_length += prefix_len
|
||||
key = LoRAKey(lora_id=key.lora_id, token_ids=key.token_ids[prefix_len:])
|
||||
value = value[prefix_len:]
|
||||
|
||||
if prefix_len < len(node.key):
|
||||
new_node = self._split_node(node.key, node, prefix_len)
|
||||
node = new_node
|
||||
|
||||
if len(key):
|
||||
child_key = self.get_child_key_fn(key)
|
||||
|
||||
if len(key):
|
||||
new_node = LoRATreeNode()
|
||||
new_node.parent = node
|
||||
new_node.key = key
|
||||
new_node.value = value
|
||||
node.children[child_key] = new_node
|
||||
self.evictable_size_ += len(value)
|
||||
return total_prefix_length
|
||||
|
||||
def _print_helper(self, node: LoRATreeNode, indent: int):
|
||||
"""Prints the radix tree in a human-readable format."""
|
||||
stack = [(node, indent)]
|
||||
while stack:
|
||||
current_node, current_indent = stack.pop()
|
||||
print(
|
||||
" " * current_indent,
|
||||
len(current_node.key),
|
||||
current_node.key.token_ids[:10],
|
||||
f"r={current_node.lock_ref}",
|
||||
)
|
||||
for key, child in current_node.children.items():
|
||||
stack.append((child, current_indent + 2))
|
||||
|
||||
assert key == self.get_child_key_fn(
|
||||
child.key
|
||||
), f"{key=}, {self.get_child_key_fn(child.key)=}"
|
||||
|
||||
def _delete_leaf(self, node):
|
||||
for k, v in node.parent.children.items():
|
||||
if v == node:
|
||||
break
|
||||
del node.parent.children[k]
|
||||
self.evictable_size_ -= len(node.key)
|
||||
|
||||
def _total_size_helper(self):
|
||||
total_size = 0
|
||||
stack = [self.root_node]
|
||||
while stack:
|
||||
current_node = stack.pop()
|
||||
total_size += len(current_node.value)
|
||||
for child in current_node.children.values():
|
||||
if child.evicted:
|
||||
continue
|
||||
stack.append(child)
|
||||
return total_size
|
||||
|
||||
def _collect_leaves(self):
|
||||
ret_list = []
|
||||
stack = [self.root_node]
|
||||
|
||||
while stack:
|
||||
cur_node = stack.pop()
|
||||
if len(cur_node.children) == 0:
|
||||
ret_list.append(cur_node)
|
||||
else:
|
||||
stack.extend(cur_node.children.values())
|
||||
|
||||
return ret_list
|
||||
@@ -23,7 +23,7 @@ import heapq
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from functools import partial
|
||||
from typing import TYPE_CHECKING, List, Optional
|
||||
from typing import TYPE_CHECKING, Any, Iterator, List, Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
@@ -41,6 +41,30 @@ if TYPE_CHECKING:
from sglang.srt.managers.schedule_batch import Req


class RadixKey:

def __init__(self, token_ids: List[int], extra_key: Optional[str] = None):
# token ids sequence
self.token_ids = token_ids
# extra key (e.g. lora_id, cache_salt)
self.extra_key = extra_key

def __len__(self) -> int:
return len(self.token_ids)

def __iter__(self) -> Iterator[int]:
return iter(self.token_ids)

def __getitem__(self, idx: Union[int, slice]) -> "RadixKey":
if isinstance(idx, slice):
return RadixKey(self.token_ids[idx], self.extra_key)
return RadixKey([self.token_ids[idx]], self.extra_key)

def __repr__(self) -> str:
preview = self.token_ids[:10]
return f"RadixKey(extra_key={self.extra_key!r}, token_ids={preview}{'...' if len(self.token_ids) > 10 else ''})"


class TreeNode:

counter = 0
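A small usage sketch for the `RadixKey` container added above; note that both slicing and integer indexing return new `RadixKey` objects that keep the same `extra_key`:

```python
from sglang.srt.mem_cache.radix_cache import RadixKey

key = RadixKey(token_ids=[1, 2, 3, 4, 5], extra_key="lora-abc")

assert len(key) == 5                     # __len__ delegates to token_ids
assert list(key) == [1, 2, 3, 4, 5]      # __iter__ yields the raw token ids
assert key[1:4].token_ids == [2, 3, 4]   # slicing keeps extra_key
assert key[1:4].extra_key == "lora-abc"
assert key[0].token_ids == [1]           # int indexing also returns a RadixKey
print(key)  # RadixKey(extra_key='lora-abc', token_ids=[1, 2, 3, 4, 5])
```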
@@ -48,7 +72,7 @@ class TreeNode:
|
||||
def __init__(self, id: Optional[int] = None):
|
||||
self.children = defaultdict(TreeNode)
|
||||
self.parent: TreeNode = None
|
||||
self.key: List[int] = None
|
||||
self.key: RadixKey = None
|
||||
self.value: Optional[torch.Tensor] = None
|
||||
self.lock_ref = 0
|
||||
self.last_access_time = time.monotonic()
|
||||
@@ -94,27 +118,47 @@ class TreeNode:
|
||||
return self.last_access_time < other.last_access_time
|
||||
|
||||
|
||||
def _key_match_page_size1(key0: List, key1: List):
|
||||
def _check_extra_key(key0: RadixKey, key1: RadixKey):
|
||||
if key0.extra_key != key1.extra_key:
|
||||
raise ValueError(
|
||||
f"_key_match should be run on the same extra key, but got key0.extra_key={key0.extra_key} != key1.extra_key={key1.extra_key}"
|
||||
)
|
||||
|
||||
|
||||
def _key_match_page_size1(key0: RadixKey, key1: RadixKey):
|
||||
_check_extra_key(key0, key1)
|
||||
i = 0
|
||||
for k0, k1 in zip(key0, key1):
|
||||
for k0, k1 in zip(key0.token_ids, key1.token_ids):
|
||||
if k0 != k1:
|
||||
break
|
||||
i += 1
|
||||
return i
|
||||
|
||||
|
||||
def _key_match_paged(key0: List, key1: List, page_size: int):
|
||||
def _key_match_paged(key0: RadixKey, key1: RadixKey, page_size: int):
|
||||
_check_extra_key(key0, key1)
|
||||
min_len = min(len(key0), len(key1))
|
||||
|
||||
i = 0
|
||||
while i < min_len:
|
||||
if key0[i : i + page_size] != key1[i : i + page_size]:
|
||||
if key0.token_ids[i : i + page_size] != key1.token_ids[i : i + page_size]:
|
||||
break
|
||||
i += page_size
|
||||
|
||||
return i
|
||||
|
||||
|
||||
def get_child_key(key: RadixKey, page_size: int = 1):
|
||||
if page_size == 1:
|
||||
plain_key = key.token_ids[0]
|
||||
else:
|
||||
plain_key = tuple(key.token_ids[:page_size])
|
||||
if key.extra_key is None:
|
||||
return plain_key
|
||||
else:
|
||||
return (key.extra_key, plain_key)
|
||||
|
||||
|
||||
class RadixCache(BasePrefixCache):
|
||||
def __init__(
|
||||
self,
|
||||
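The `get_child_key` helper above is what keeps different `extra_key` namespaces apart at every branching point: when an `extra_key` is present it becomes part of the child-dictionary key, so two requests with identical leading tokens but different extra keys never land on the same child node. A quick illustration:

```python
from sglang.srt.mem_cache.radix_cache import RadixKey, get_child_key

plain = RadixKey([5, 6, 7, 8])
salted = RadixKey([5, 6, 7, 8], extra_key="salt-v1")

assert get_child_key(plain) == 5                                   # page_size=1: first token id
assert get_child_key(salted) == ("salt-v1", 5)                     # namespaced by extra_key
assert get_child_key(salted, page_size=2) == ("salt-v1", (5, 6))   # paged: tuple of the first page
```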
@@ -139,10 +183,10 @@ class RadixCache(BasePrefixCache):
|
||||
|
||||
if self.page_size == 1:
|
||||
self.key_match_fn = _key_match_page_size1
|
||||
self.get_child_key_fn = lambda key: key[0]
|
||||
self.get_child_key_fn = get_child_key
|
||||
else:
|
||||
self.key_match_fn = partial(_key_match_paged, page_size=page_size)
|
||||
self.get_child_key_fn = lambda key: tuple(key[:page_size])
|
||||
self.get_child_key_fn = partial(get_child_key, page_size=page_size)
|
||||
|
||||
if eviction_policy.lower() == "lru":
|
||||
self.eviction_strategy: EvictionStrategy = LRUStrategy()
|
||||
@@ -158,7 +202,7 @@ class RadixCache(BasePrefixCache):
|
||||
|
||||
def reset(self):
|
||||
self.root_node = TreeNode()
|
||||
self.root_node.key = []
|
||||
self.root_node.key = RadixKey(token_ids=[], extra_key=None)
|
||||
self.root_node.value = []
|
||||
self.root_node.host_value = []
|
||||
self.root_node.lock_ref = 1
|
||||
@@ -166,16 +210,43 @@ class RadixCache(BasePrefixCache):
|
||||
self.protected_size_ = 0
|
||||
self._record_all_cleared_event()
|
||||
|
||||
def match_prefix(self, key: List[int], **kwargs) -> MatchResult:
|
||||
"""Find the matching prefix from the radix tree.
|
||||
def match_prefix(self, key: RadixKey, **kwargs) -> MatchResult:
|
||||
"""Find the longest cached prefix of ``key`` in the radix tree.
|
||||
|
||||
The logical namespace for prefix matching is determined by both the
|
||||
token id sequence and the optional ``extra_key`` carried by ``RadixKey``.
|
||||
Entries that share identical leading token ids but have *different*
|
||||
``extra_key`` values are intentionally kept disjoint and never share
|
||||
prefix nodes. This is useful to:
|
||||
|
||||
* Isolate KV cache lines for different LoRA / adapter IDs.
|
||||
* Separate requests that intentionally should not share state (e.g.,
|
||||
different sampling salt, cache version, or retrieval augmentation
|
||||
context) by supplying a distinct ``extra_key``.
|
||||
|
||||
Args:
|
||||
key: A list of token IDs to find a matching prefix.
|
||||
key (RadixKey): The lookup key containing a list of token ids and an
|
||||
optional ``extra_key`` namespace tag. If ``page_size > 1`` the
|
||||
length is internally truncated to a multiple of ``page_size``
|
||||
before matching. Passing an empty key returns an empty result
|
||||
with the root as the last node.
|
||||
**kwargs: Reserved for future extensions (ignored currently).
|
||||
|
||||
Returns:
|
||||
A tuple of a tensor of matching prefix token IDs and
|
||||
the last node that contains the prefix values. Note that
|
||||
this API can modify the internal state of the Radix tree.
|
||||
The last node create a new child if the prefix is shorter
|
||||
than the last node's value.
|
||||
MatchResult: ``device_indices`` is a 1-D ``torch.int64`` tensor of
|
||||
the concatenated KV cache indices corresponding to the longest
|
||||
cached prefix (may be length 0). ``last_device_node`` and
|
||||
``last_host_node`` (currently the same) are the tree node objects
|
||||
representing the terminal node of the matched prefix. This method
|
||||
may mutate internal structure by splitting an existing node if the
|
||||
match ends inside a stored segment.
|
||||
|
||||
Internal updates:
|
||||
* Refreshes access metadata (timestamps) used by the
|
||||
configured eviction strategy.
|
||||
* If the lookup ends inside a stored segment the node is split once
|
||||
to expose a precise boundary; this structural refinement improves
|
||||
subsequent match efficiency and does not duplicate data.
|
||||
"""
|
||||
if self.disable or len(key) == 0:
|
||||
return MatchResult(
|
||||
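To make the namespace-isolation behaviour described in the docstring concrete, here is a small end-to-end sketch (mirroring the `__main__` example further down in this file): two inserts that share token ids but differ in `extra_key` do not share cached prefixes.

```python
import torch

from sglang.srt.mem_cache.radix_cache import RadixCache, RadixKey

tree = RadixCache(None, None, page_size=1, disable=False)
tree.insert(
    RadixKey([1, 2, 3, 4], extra_key=None),
    torch.tensor([10, 11, 12, 13], dtype=torch.int64),
)
tree.insert(
    RadixKey([1, 2, 3, 4], extra_key="salt-v1"),
    torch.tensor([20, 21, 22, 23], dtype=torch.int64),
)

# Same token ids, different namespaces: each lookup only sees its own entries.
hit_plain = tree.match_prefix(RadixKey([1, 2, 3, 4, 5], extra_key=None))
hit_salt = tree.match_prefix(RadixKey([1, 2, 3, 4, 5], extra_key="salt-v1"))
assert hit_plain.device_indices.tolist() == [10, 11, 12, 13]
assert hit_salt.device_indices.tolist() == [20, 21, 22, 23]
```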
@@ -203,12 +274,12 @@ class RadixCache(BasePrefixCache):
|
||||
last_host_node=last_node,
|
||||
)
|
||||
|
||||
def insert(self, key: List, value=None, chunked=False):
|
||||
def insert(self, key: RadixKey, value=None, chunked=False):
|
||||
if self.disable:
|
||||
return 0
|
||||
|
||||
if value is None:
|
||||
value = [x for x in key]
|
||||
value = torch.tensor(key.token_ids, dtype=torch.int64)
|
||||
return self._insert_helper(self.root_node, key, value)
|
||||
|
||||
def cache_finished_req(self, req: Req):
|
||||
@@ -238,7 +309,8 @@ class RadixCache(BasePrefixCache):
|
||||
|
||||
# Radix Cache takes one ref in memory pool
|
||||
new_prefix_len = self.insert(
|
||||
token_ids[:page_aligned_len], page_aligned_kv_indices
|
||||
RadixKey(token_ids[:page_aligned_len], req.extra_key),
|
||||
page_aligned_kv_indices,
|
||||
)
|
||||
self.token_to_kv_pool_allocator.free(
|
||||
kv_indices[len(req.prefix_indices) : new_prefix_len]
|
||||
@@ -270,14 +342,18 @@ class RadixCache(BasePrefixCache):
|
||||
|
||||
# Radix Cache takes one ref in memory pool
|
||||
new_prefix_len = self.insert(
|
||||
page_aligned_token_ids, page_aligned_kv_indices, chunked=chunked
|
||||
RadixKey(page_aligned_token_ids, req.extra_key),
|
||||
page_aligned_kv_indices,
|
||||
chunked=chunked,
|
||||
)
|
||||
self.token_to_kv_pool_allocator.free(
|
||||
kv_indices[len(req.prefix_indices) : new_prefix_len]
|
||||
)
|
||||
|
||||
# The prefix indices could be updated, reuse it
|
||||
new_indices, new_last_node, _, _ = self.match_prefix(page_aligned_token_ids)
|
||||
new_indices, new_last_node, _, _ = self.match_prefix(
|
||||
RadixKey(token_ids=page_aligned_token_ids, extra_key=req.extra_key)
|
||||
)
|
||||
self.req_to_token_pool.write(
|
||||
(req.req_pool_idx, slice(len(req.prefix_indices), len(new_indices))),
|
||||
new_indices[len(req.prefix_indices) :],
|
||||
@@ -379,7 +455,7 @@ class RadixCache(BasePrefixCache):
|
||||
|
||||
##### Internal Helper Functions #####
|
||||
|
||||
def _match_prefix_helper(self, node: TreeNode, key: List):
|
||||
def _match_prefix_helper(self, node: TreeNode, key: RadixKey):
|
||||
node.last_access_time = time.monotonic()
|
||||
|
||||
child_key = self.get_child_key_fn(key)
|
||||
@@ -404,7 +480,7 @@ class RadixCache(BasePrefixCache):
|
||||
|
||||
return value, node
|
||||
|
||||
def _split_node(self, key, child: TreeNode, split_len: int):
|
||||
def _split_node(self, key: RadixKey, child: TreeNode, split_len: int):
|
||||
# new_node -> child
|
||||
self._record_remove_event(child)
|
||||
new_node = TreeNode()
|
||||
@@ -423,7 +499,7 @@ class RadixCache(BasePrefixCache):
|
||||
|
||||
return new_node
|
||||
|
||||
def _insert_helper(self, node: TreeNode, key: List, value):
|
||||
def _insert_helper(self, node: TreeNode, key: RadixKey, value):
|
||||
node.last_access_time = time.monotonic()
|
||||
if len(key) == 0:
|
||||
return 0
|
||||
@@ -464,7 +540,7 @@ class RadixCache(BasePrefixCache):
|
||||
print(
|
||||
" " * current_indent,
|
||||
len(current_node.key),
|
||||
current_node.key[:10],
|
||||
current_node.key.token_ids[:10],
|
||||
f"r={current_node.lock_ref}",
|
||||
)
|
||||
for key, child in current_node.children.items():
|
||||
@@ -516,11 +592,11 @@ class RadixCache(BasePrefixCache):
|
||||
last_page_start = (
|
||||
(len(node.parent.key) - 1) // self.page_size
|
||||
) * self.page_size
|
||||
parent_parent_tokens = node.parent.key[last_page_start:]
|
||||
parent_parent_tokens = node.parent.key.token_ids[last_page_start:]
|
||||
parent_block_hash = hash(tuple(parent_parent_tokens))
|
||||
|
||||
for start in range(0, len(node.key), self.page_size):
|
||||
page_tokens = node.key[start : start + self.page_size]
|
||||
page_tokens = node.key.token_ids[start : start + self.page_size]
|
||||
if not page_tokens:
|
||||
continue
|
||||
|
||||
@@ -543,7 +619,7 @@ class RadixCache(BasePrefixCache):
|
||||
# One BlockRemoved per chunk.
|
||||
if self.enable_kv_cache_events:
|
||||
for start in range(0, len(node.key), self.page_size):
|
||||
page_tokens = node.key[start : start + self.page_size]
|
||||
page_tokens = node.key.token_ids[start : start + self.page_size]
|
||||
if not page_tokens:
|
||||
continue
|
||||
block_hash = hash(tuple(page_tokens))
|
||||
@@ -569,19 +645,12 @@ class RadixCache(BasePrefixCache):
|
||||
if __name__ == "__main__":
|
||||
tree = RadixCache(None, None, page_size=1, disable=False)
|
||||
|
||||
tree.insert("Hello")
|
||||
tree.insert("Hello")
|
||||
tree.insert("Hello_L.A.!")
|
||||
# tree.insert("Hello_world! Happy")
|
||||
# tree.insert("I love you!")
|
||||
# Example token id sequences (as lists of ints)
|
||||
tree.insert(RadixKey(token_ids=[1, 2, 3], extra_key=None))
|
||||
tree.insert(RadixKey(token_ids=[1, 2, 3], extra_key=None))
|
||||
tree.insert(RadixKey(token_ids=[1, 2, 4, 5], extra_key=None))
|
||||
tree.insert(RadixKey(token_ids=[1, 2, 4, 5, 6, 7], extra_key=None))
|
||||
tree.insert(RadixKey(token_ids=[8, 9, 10, 11, 12], extra_key=None))
|
||||
tree.pretty_print()
|
||||
|
||||
# print(tree.match_prefix("I love you! aha"))
|
||||
|
||||
# def evict_callback(x):
|
||||
# print("evict", x)
|
||||
# return len(x)
|
||||
|
||||
# tree.evict(5, evict_callback)
|
||||
# tree.evict(10, evict_callback)
|
||||
# tree.pretty_print()
|
||||
print(tree.match_prefix(RadixKey(token_ids=[1, 2, 3, 13, 14], extra_key=None)))
|
||||
|
||||
@@ -13,6 +13,7 @@ from sglang.srt.mem_cache.cpp_radix_tree.radix_tree import (
|
||||
TreeNodeCpp,
|
||||
)
|
||||
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
|
||||
from sglang.srt.mem_cache.radix_cache import RadixKey
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.managers.schedule_batch import Req
|
||||
@@ -93,9 +94,9 @@ class RadixCacheCpp(BasePrefixCache):
|
||||
raise NotImplementedError("Host cache is not supported yet")
|
||||
self.tree.reset()
|
||||
|
||||
def match_prefix(self, key: List[int], **kwargs) -> MatchResult:
|
||||
def match_prefix(self, key: RadixKey, **kwargs) -> MatchResult:
|
||||
device_indices_vec, host_indices_length, node_gpu, node_cpu = (
|
||||
self.tree.match_prefix(key)
|
||||
self.tree.match_prefix(key.token_ids)
|
||||
)
|
||||
return MatchResult(
|
||||
device_indices=self._merge_tensor(device_indices_vec),
|
||||
@@ -104,16 +105,16 @@ class RadixCacheCpp(BasePrefixCache):
|
||||
host_hit_length=host_indices_length,
|
||||
)
|
||||
|
||||
def _insert(self, key: List[int], value: torch.Tensor) -> int:
|
||||
def _insert(self, key: RadixKey, value: torch.Tensor) -> int:
|
||||
"""
|
||||
Insert a key-value pair into the radix tree.
|
||||
Args:
|
||||
key (List[int]): The key to insert, represented as a list of integers.
|
||||
key (RadixKey): The key to insert, represented as a RadixKey.
|
||||
value (torch.Tensor): The value to associate with the key.
|
||||
Returns:
|
||||
int: Number of device indices that were already present in the tree before the insertion.
|
||||
"""
|
||||
ongoing_write, length = self.tree.writing_through(key, value)
|
||||
ongoing_write, length = self.tree.writing_through(key.token_ids, value)
|
||||
if self.cache_controller is None:
|
||||
assert len(ongoing_write) == 0, "Implementation error"
|
||||
return length
|
||||
@@ -160,7 +161,7 @@ class RadixCacheCpp(BasePrefixCache):
|
||||
# NOTE: our C++ implementation doesn't need `token_ids` and `kv_indices` to be page-aligned;
# it will automatically align them, but their lengths should be equal
|
||||
old_prefix_len = len(req.prefix_indices) // self.page_size * self.page_size
|
||||
new_prefix_len = self._insert(token_ids, kv_indices)
|
||||
new_prefix_len = self._insert(RadixKey(token_ids, req.extra_key), kv_indices)
|
||||
|
||||
# NOTE: kv_indices[:old_prefix_len] == req.prefix_indices
|
||||
assert old_prefix_len <= new_prefix_len, "Wrong prefix indices"
|
||||
@@ -191,14 +192,16 @@ class RadixCacheCpp(BasePrefixCache):
|
||||
# NOTE: our C++ implementation doesn't need `token_ids` and `kv_indices` to be page-aligned;
# it will automatically align them, but their lengths should be equal
|
||||
old_prefix_len = len(req.prefix_indices) // self.page_size * self.page_size
|
||||
new_prefix_len = self._insert(token_ids, kv_indices)
|
||||
new_prefix_len = self._insert(RadixKey(token_ids, req.extra_key), kv_indices)
|
||||
|
||||
# NOTE: kv_indices[:old_prefix_len] == req.prefix_indices
|
||||
assert old_prefix_len <= new_prefix_len, "Wrong prefix indices"
|
||||
|
||||
# TODO(dark): optimize the `insert` and `match` (e.g. merge into 1 function)
|
||||
# The prefix indices need to be updated to reuse the kv indices in the pool
|
||||
new_indices_vec, _, new_last_node, _ = self.tree.match_prefix(token_ids)
|
||||
new_indices_vec, _, new_last_node, _ = self.tree.match_prefix(
|
||||
RadixKey(token_ids, req.extra_key).token_ids
|
||||
)
|
||||
new_indices = self._merge_tensor(new_indices_vec)
|
||||
assert new_prefix_len <= len(new_indices)
|
||||
|
||||
|
||||
@@ -9,7 +9,7 @@ import torch
|
||||
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
|
||||
from sglang.srt.mem_cache.base_prefix_cache import MatchResult
|
||||
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
|
||||
from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode
|
||||
from sglang.srt.mem_cache.radix_cache import RadixCache, RadixKey, TreeNode
|
||||
|
||||
try:
|
||||
from lmcache.integration.sglang.sglang_adapter import (
|
||||
@@ -131,7 +131,7 @@ class LMCRadixCache(RadixCache):
|
||||
with self._node_lock:
|
||||
self._in_flight_nodes.clear()
|
||||
|
||||
def match_prefix(self, key: List[int], **kwargs) -> MatchResult: # type: ignore[override]
|
||||
def match_prefix(self, key: RadixKey, **kwargs) -> MatchResult: # type: ignore[override]
|
||||
"""Match cached prefix; if there's a tail miss, prefetch from LMCache.
|
||||
|
||||
Reuses the base matching logic to obtain (value, last_node). If there
|
||||
@@ -178,7 +178,7 @@ class LMCRadixCache(RadixCache):
|
||||
with torch.cuda.stream(self.load_stream):
|
||||
num_retrieved = self.lmcache_connector.start_load_kv(
|
||||
LoadMetadata(
|
||||
token_ids=key, # full page-aligned key
|
||||
token_ids=key.token_ids, # full page-aligned key
|
||||
slot_mapping=slot_mapping,
|
||||
offset=value.numel() - prefix_pad, # LMCache offset convention
|
||||
)
|
||||
@@ -227,7 +227,7 @@ class LMCRadixCache(RadixCache):
|
||||
req.req_pool_idx, : len(token_ids)
|
||||
]
|
||||
|
||||
_, new_last_node, _, _ = self.match_prefix(token_ids)
|
||||
_, new_last_node, _, _ = self.match_prefix(RadixKey(token_ids, req.extra_key))
|
||||
assert new_last_node is not None
|
||||
|
||||
self.inc_lock_ref(new_last_node)
|
||||
@@ -277,6 +277,8 @@ if __name__ == "__main__":
|
||||
rank=0,
|
||||
tp_group=None,
|
||||
)
|
||||
cache.insert([1, 2, 3], torch.tensor([10, 11, 12], dtype=torch.int64))
|
||||
cache.insert([1, 2, 3, 4], torch.tensor([10, 11, 12, 13], dtype=torch.int64))
|
||||
cache.insert(RadixKey([1, 2, 3]), torch.tensor([10, 11, 12], dtype=torch.int64))
|
||||
cache.insert(
|
||||
RadixKey([1, 2, 3, 4]), torch.tensor([10, 11, 12, 13], dtype=torch.int64)
|
||||
)
|
||||
cache.pretty_print()
|
||||
|
||||
@@ -30,6 +30,12 @@ import torch
|
||||
from sglang.srt.mem_cache.allocator import SWATokenToKVPoolAllocator
|
||||
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache, MatchResult
|
||||
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
|
||||
from sglang.srt.mem_cache.radix_cache import (
|
||||
RadixKey,
|
||||
_key_match_page_size1,
|
||||
_key_match_paged,
|
||||
get_child_key,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.managers.schedule_batch import Req
|
||||
@@ -47,7 +53,7 @@ class TreeNode:
|
||||
def __init__(self, id: Optional[int] = None):
|
||||
self.children = defaultdict(TreeNode)
|
||||
self.parent: TreeNode = None
|
||||
self.key: List[int] = None
|
||||
self.key: RadixKey = None
|
||||
self.value: Optional[torch.Tensor] = None
|
||||
# swa_tombstone is used to indicate the kv indices have been freed for swa layers
|
||||
self.swa_tombstone = False
|
||||
@@ -87,27 +93,6 @@ class TreeNode:
|
||||
return self.last_access_time < other.last_access_time
|
||||
|
||||
|
||||
def _key_match_page_size1(key0: List, key1: List):
|
||||
i = 0
|
||||
for k0, k1 in zip(key0, key1):
|
||||
if k0 != k1:
|
||||
break
|
||||
i += 1
|
||||
return i
|
||||
|
||||
|
||||
def _key_match_paged(key0: List, key1: List, page_size: int):
|
||||
min_len = min(len(key0), len(key1))
|
||||
|
||||
i = 0
|
||||
while i < min_len:
|
||||
if key0[i : i + page_size] != key1[i : i + page_size]:
|
||||
break
|
||||
i += page_size
|
||||
|
||||
return i
|
||||
|
||||
|
||||
def gen_swa_uuid() -> int:
|
||||
TreeNode.swa_uuid_counter += 1
|
||||
return TreeNode.swa_uuid_counter
|
||||
@@ -356,10 +341,10 @@ class SWARadixCache(BasePrefixCache):
|
||||
|
||||
if self.page_size == 1:
|
||||
self.key_match_fn = _key_match_page_size1
|
||||
self.get_child_key_fn = lambda key: key[0]
|
||||
self.get_child_key_fn = get_child_key
|
||||
else:
|
||||
self.key_match_fn = partial(_key_match_paged, page_size=page_size)
|
||||
self.get_child_key_fn = lambda key: tuple(key[:page_size])
|
||||
self.get_child_key_fn = partial(get_child_key, page_size=page_size)
|
||||
|
||||
self.sliding_window_size = sliding_window_size
|
||||
self.reset()
|
||||
@@ -380,10 +365,10 @@ class SWARadixCache(BasePrefixCache):
|
||||
self.full_lru_list = LRUList(swa=False)
|
||||
self.swa_lru_list = LRUList(swa=True)
|
||||
|
||||
def match_prefix(self, key: List[int], **kwargs) -> MatchResult:
|
||||
def match_prefix(self, key: RadixKey, **kwargs) -> MatchResult:
|
||||
"""Find the matching prefix from the radix tree.
|
||||
Args:
|
||||
key: A list of token IDs to find a matching prefix.
|
||||
key: A RadixKey containing the token IDs to find a matching prefix.
|
||||
Returns:
|
||||
A tuple of a tensor of matching prefix token IDs and
|
||||
the last node that contains the prefix values. Note that
|
||||
@@ -417,12 +402,12 @@ class SWARadixCache(BasePrefixCache):
|
||||
last_host_node=last_node,
|
||||
)
|
||||
|
||||
def insert(self, key: List, value=None, prev_prefix_len: int = 0) -> int:
|
||||
def insert(self, key: RadixKey, value=None, prev_prefix_len: int = 0) -> int:
|
||||
if self.disable:
|
||||
return 0
|
||||
|
||||
if value is None:
|
||||
value = [x for x in key]
|
||||
value = torch.tensor([x for x in key.token_ids], dtype=torch.int64)
|
||||
return self._insert_helper(self.root_node, key, value, prev_prefix_len)
|
||||
|
||||
def cache_finished_req(self, req: Req) -> None:
|
||||
@@ -453,7 +438,7 @@ class SWARadixCache(BasePrefixCache):
|
||||
# insert the token_ids and kv_indices into the radix tree
|
||||
# Note: the insert function already frees the overlapped kv_indices
|
||||
new_prefix_len = self.insert(
|
||||
token_ids[:page_aligned_len],
|
||||
RadixKey(token_ids[:page_aligned_len], req.extra_key),
|
||||
page_aligned_kv_indices,
|
||||
len(req.prefix_indices),
|
||||
)
|
||||
@@ -489,11 +474,15 @@ class SWARadixCache(BasePrefixCache):
|
||||
# Radix Cache takes one ref in memory pool
|
||||
# Note: the insert function already frees the overlapped kv_indices
|
||||
new_prefix_len = self.insert(
|
||||
page_aligned_token_ids, page_aligned_kv_indices, len(req.prefix_indices)
|
||||
RadixKey(page_aligned_token_ids, req.extra_key),
|
||||
page_aligned_kv_indices,
|
||||
len(req.prefix_indices),
|
||||
)
|
||||
|
||||
# The prefix indices could be updated, reuse it
|
||||
new_indices, new_last_node, _, _ = self.match_prefix(page_aligned_token_ids)
|
||||
new_indices, new_last_node, _, _ = self.match_prefix(
|
||||
RadixKey(page_aligned_token_ids, req.extra_key)
|
||||
)
|
||||
assert len(req.prefix_indices) <= len(
|
||||
new_indices
|
||||
), f"{req.prefix_indices=}, {new_indices=}"
|
||||
@@ -732,7 +721,9 @@ class SWARadixCache(BasePrefixCache):
|
||||
|
||||
##### Internal Helper Functions #####
|
||||
|
||||
def _match_prefix_helper(self, key: List) -> Tuple[List[torch.Tensor], TreeNode]:
|
||||
def _match_prefix_helper(
|
||||
self, key: RadixKey
|
||||
) -> Tuple[List[torch.Tensor], TreeNode]:
|
||||
"""
|
||||
SWA prefix matching helper. It factors in the sliding window size such that
|
||||
the matched node is guaranteed to either 1. connected to root without swa tombstone,
|
||||
@@ -796,7 +787,7 @@ class SWARadixCache(BasePrefixCache):
|
||||
|
||||
return value[:best_value_len], best_last_node
|
||||
|
||||
def _split_node(self, key: List[int], child: TreeNode, split_len: int) -> TreeNode:
|
||||
def _split_node(self, key: RadixKey, child: TreeNode, split_len: int) -> TreeNode:
|
||||
# new_node -> child
|
||||
new_node = TreeNode()
|
||||
new_node.children = {self.get_child_key_fn(key[split_len:]): child}
|
||||
@@ -831,7 +822,7 @@ class SWARadixCache(BasePrefixCache):
|
||||
return new_node
|
||||
|
||||
def _insert_helper(
|
||||
self, node: TreeNode, key: List, value, update_kv_after_len: int
|
||||
self, node: TreeNode, key: RadixKey, value, update_kv_after_len: int
|
||||
) -> int:
|
||||
# Update the last access time from root to leaf, so that
|
||||
# swa will tombstone the node closer to root first
|
||||
|
||||
@@ -99,6 +99,7 @@ suites = {
|
||||
TestFile("test_priority_scheduling.py", 100),
|
||||
TestFile("test_pytorch_sampling_backend.py", 66),
|
||||
TestFile("test_radix_attention.py", 105),
|
||||
TestFile("test_radix_cache_unit.py", 5),
|
||||
TestFile("test_regex_constrained.py", 64),
|
||||
TestFile("test_reasoning_parser.py", 5),
|
||||
TestFile("test_retract_decode.py", 54),
|
||||
|
||||
597 test/srt/test_radix_cache_unit.py Normal file
@@ -0,0 +1,597 @@
|
||||
"""
|
||||
Unit tests for the RadixCache implementation.
|
||||
|
||||
This module tests the core functionality of RadixCache, RadixKey, and TreeNode
|
||||
following SGLang testing patterns.
|
||||
|
||||
Test Coverage:
|
||||
- RadixKey: token ID management, slicing, iteration, representation
|
||||
- TreeNode: node properties, reference counting, hash values
|
||||
- RadixCache: insert/match operations, eviction, page alignment, error handling
|
||||
- Cache events and request handling
|
||||
- Boundary conditions with parameterized testing
|
||||
|
||||
Usage:
|
||||
python test_radix_cache_unit.py
|
||||
python -m pytest test_radix_cache_unit.py -v
|
||||
python -m pytest test_radix_cache_unit.py::TestRadixCache::test_insert_basic
|
||||
"""
|
||||
|
||||
import time
|
||||
import unittest
|
||||
import unittest.mock
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.srt.disaggregation.kv_events import BlockRemoved, BlockStored
|
||||
from sglang.srt.mem_cache.radix_cache import RadixCache, RadixKey, TreeNode
|
||||
|
||||
# Test constants
|
||||
DEFAULT_PAGE_SIZE = 4
|
||||
|
||||
|
||||
class TestRadixKey(unittest.TestCase):
|
||||
"""Test cases for RadixKey class."""
|
||||
|
||||
def test_init_basic(self):
|
||||
"""Test basic initialization of RadixKey."""
|
||||
token_ids = [1, 2, 3, 4]
|
||||
key = RadixKey(token_ids)
|
||||
self.assertEqual(key.token_ids, token_ids)
|
||||
self.assertIsNone(key.extra_key)
|
||||
|
||||
def test_init_with_extra_key(self):
|
||||
"""Test initialization with extra_key."""
|
||||
token_ids = [1, 2, 3]
|
||||
extra_key = "test_key"
|
||||
key = RadixKey(token_ids, extra_key)
|
||||
self.assertEqual(key.token_ids, token_ids)
|
||||
self.assertEqual(key.extra_key, extra_key)
|
||||
|
||||
def test_len(self):
|
||||
"""Test __len__ method."""
|
||||
key = RadixKey([1, 2, 3])
|
||||
self.assertEqual(len(key), 3)
|
||||
|
||||
empty_key = RadixKey([])
|
||||
self.assertEqual(len(empty_key), 0)
|
||||
|
||||
def test_iter(self):
|
||||
"""Test __iter__ method."""
|
||||
token_ids = [1, 2, 3, 4]
|
||||
key = RadixKey(token_ids)
|
||||
self.assertEqual(list(key), token_ids)
|
||||
|
||||
def test_len_and_iter(self):
|
||||
"""Test __len__ and __iter__ methods."""
|
||||
test_cases = [
|
||||
([1, 2, 3], 3),
|
||||
([], 0),
|
||||
([42], 1),
|
||||
]
|
||||
|
||||
for tokens, expected in test_cases:
|
||||
with self.subTest(tokens=tokens):
|
||||
key = RadixKey(tokens)
|
||||
self.assertEqual(len(key), expected)
|
||||
self.assertEqual(list(key), tokens)
|
||||
|
||||
def test_getitem_int(self):
|
||||
"""Test __getitem__ with int index."""
|
||||
test_cases = [
|
||||
([10, 20, 30], 0, [10]),
|
||||
([10, 20, 30], -1, [30]),
|
||||
([10, 20, 30], 2, [30]),
|
||||
]
|
||||
|
||||
for tokens, index, expected in test_cases:
|
||||
with self.subTest(tokens=tokens, index=index):
|
||||
key = RadixKey(tokens)
|
||||
result = key[index]
|
||||
self.assertIsInstance(result, RadixKey)
|
||||
self.assertEqual(result.token_ids, expected)
|
||||
|
||||
def test_getitem_slice(self):
|
||||
"""Test __getitem__ with slice and edge cases."""
|
||||
key = RadixKey([1, 2, 3, 4, 5], "extra")
|
||||
|
||||
# Basic slice
|
||||
sliced = key[1:4]
|
||||
self.assertIsInstance(sliced, RadixKey)
|
||||
self.assertEqual(sliced.token_ids, [2, 3, 4])
|
||||
self.assertEqual(sliced.extra_key, "extra")
|
||||
|
||||
# Edge cases
|
||||
self.assertEqual(key[2:2].token_ids, []) # Empty slice
|
||||
self.assertEqual(key[:].token_ids, [1, 2, 3, 4, 5]) # Full slice
|
||||
|
||||
def test_getitem_invalid_index(self):
|
||||
"""Test __getitem__ with invalid indices."""
|
||||
key = RadixKey([1, 2, 3])
|
||||
with self.assertRaises(IndexError):
|
||||
_ = key[10] # Out of bounds
|
||||
|
||||
def test_repr(self):
|
||||
"""Test __repr__ method."""
|
||||
key = RadixKey([1, 2, 3], "test")
|
||||
repr_str = repr(key)
|
||||
self.assertIn("RadixKey", repr_str)
|
||||
self.assertIn("extra_key='test'", repr_str)
|
||||
self.assertIn("[1, 2, 3]", repr_str)
|
||||
|
||||
def test_repr_long_token_ids(self):
|
||||
"""Test __repr__ with long token_ids."""
|
||||
long_tokens = list(range(15))
|
||||
key = RadixKey(long_tokens)
|
||||
repr_str = repr(key)
|
||||
self.assertIn("...", repr_str) # Should be truncated
|
||||
|
||||
|
||||
class TestTreeNode(unittest.TestCase):
|
||||
"""Test cases for TreeNode class."""
|
||||
|
||||
def setUp(self):
|
||||
"""Reset the counter before each test."""
|
||||
TreeNode.counter = 0
|
||||
|
||||
def test_init_basic(self):
|
||||
"""Test basic initialization of TreeNode."""
|
||||
node = TreeNode()
|
||||
self.assertEqual(node.id, 0)
|
||||
self.assertEqual(len(node.children), 0)
|
||||
self.assertIsNone(node.parent)
|
||||
self.assertIsNone(node.key)
|
||||
self.assertIsNone(node.value)
|
||||
self.assertEqual(node.lock_ref, 0)
|
||||
self.assertEqual(node.hit_count, 0)
|
||||
self.assertEqual(node.host_ref_counter, 0)
|
||||
self.assertIsNone(node.host_value)
|
||||
self.assertIsNone(node.hash_value)
|
||||
|
||||
def test_init_with_id(self):
|
||||
"""Test initialization with custom ID."""
|
||||
node = TreeNode(id=42)
|
||||
self.assertEqual(node.id, 42)
|
||||
node2 = TreeNode()
|
||||
self.assertEqual(node2.id, 1) # Counter was incremented
|
||||
|
||||
def test_counter_increment(self):
|
||||
"""Test that counter increments properly."""
|
||||
node1 = TreeNode()
|
||||
node2 = TreeNode()
|
||||
self.assertEqual(node1.id, 0)
|
||||
self.assertEqual(node2.id, 1)
|
||||
|
||||
def test_evicted_backuped_properties(self):
|
||||
"""Test evicted and backuped properties."""
|
||||
test_cases = [
|
||||
(False, False, True, False),
|
||||
(True, False, False, False),
|
||||
(True, True, False, True),
|
||||
(False, True, True, True),
|
||||
]
|
||||
|
||||
for (
|
||||
has_value,
|
||||
has_host_value,
|
||||
expected_evicted,
|
||||
expected_backuped,
|
||||
) in test_cases:
|
||||
with self.subTest(has_value=has_value, has_host_value=has_host_value):
|
||||
node = TreeNode()
|
||||
|
||||
if has_value:
|
||||
node.value = torch.tensor([1, 2, 3])
|
||||
if has_host_value:
|
||||
node.host_value = torch.tensor([4, 5, 6])
|
||||
|
||||
self.assertEqual(node.evicted, expected_evicted)
|
||||
self.assertEqual(node.backuped, expected_backuped)
|
||||
|
||||
def test_protect_release_host(self):
|
||||
"""Test protect_host and release_host methods."""
|
||||
node = TreeNode()
|
||||
self.assertEqual(node.host_ref_counter, 0)
|
||||
|
||||
node.protect_host()
|
||||
self.assertEqual(node.host_ref_counter, 1)
|
||||
|
||||
node.release_host()
|
||||
self.assertEqual(node.host_ref_counter, 0)
|
||||
|
||||
# Test error case
|
||||
with self.assertRaises(RuntimeError):
|
||||
node.release_host()
|
||||
|
||||
def test_get_last_hash_value(self):
|
||||
"""Test get_last_hash_value method."""
|
||||
node = TreeNode()
|
||||
self.assertIsNone(node.get_last_hash_value())
|
||||
|
||||
node.hash_value = ["hash1", "hash2", "hash3"]
|
||||
self.assertEqual(node.get_last_hash_value(), "hash3")
|
||||
|
||||
def test_lt_comparison(self):
|
||||
"""Test less than comparison based on last_access_time."""
|
||||
node1 = TreeNode()
|
||||
time.sleep(0.001) # Small delay to ensure different timestamps
|
||||
node2 = TreeNode()
|
||||
|
||||
self.assertTrue(node1 < node2)
|
||||
self.assertFalse(node2 < node1)


class TestRadixCache(unittest.TestCase):
    """Test cases for RadixCache class."""

    def setUp(self):
        """Set up test fixtures."""
        TreeNode.counter = 0

    def test_init_variations(self):
        """Test cache initialization with different parameters."""
        test_cases = [
            (1, False, False),
            (4, False, True),
            (1, True, False),
        ]

        for page_size, disable, enable_events in test_cases:
            with self.subTest(
                page_size=page_size, disable=disable, enable_events=enable_events
            ):
                cache = RadixCache(
                    req_to_token_pool=None,
                    token_to_kv_pool_allocator=None,
                    page_size=page_size,
                    disable=disable,
                    enable_kv_cache_events=enable_events,
                )

                self.assertEqual(cache.page_size, page_size)
                self.assertEqual(cache.disable, disable)
                self.assertEqual(cache.enable_kv_cache_events, enable_events)
                self.assertEqual(cache.device, torch.device("cpu"))
                self.assertIsNotNone(cache.root_node)
                self.assertEqual(len(cache.root_node.key), 0)

    def test_reset(self):
        """Test reset method."""
        cache = RadixCache(
            req_to_token_pool=None, token_to_kv_pool_allocator=None, page_size=1
        )

        # Insert some data
        cache.insert(RadixKey([1, 2, 3]), torch.tensor([10, 20, 30], dtype=torch.int64))
        self.assertGreater(cache.total_size(), 0)

        # Reset
        cache.reset()
        self.assertEqual(cache.total_size(), 0)
        self.assertEqual(cache.evictable_size(), 0)
        self.assertEqual(cache.protected_size(), 0)

    def test_insert_and_match_basic(self):
        """Test basic insert and match operations."""
        for disable_cache in [False, True]:
            with self.subTest(disable_cache=disable_cache):
                cache = RadixCache(
                    req_to_token_pool=None,
                    token_to_kv_pool_allocator=None,
                    page_size=1,
                    disable=disable_cache,
                )

                key = RadixKey([1, 2, 3])
                value = torch.tensor([10, 20, 30], dtype=torch.int64)
                prefix_len = cache.insert(key, value)

                if disable_cache:
                    self.assertEqual(prefix_len, 0)
                    self.assertEqual(cache.total_size(), 0)
                    continue

                self.assertEqual(prefix_len, 0)  # No existing prefix
                self.assertEqual(cache.total_size(), 3)
                self.assertEqual(cache.evictable_size(), 3)

                # Test match_prefix
                result = cache.match_prefix(RadixKey([1, 2, 3]))
                self.assertEqual(len(result.device_indices), 3)
                torch.testing.assert_close(result.device_indices, value)

                # Test partial match
                result = cache.match_prefix(RadixKey([1, 2]))
                self.assertEqual(len(result.device_indices), 2)
                torch.testing.assert_close(
                    result.device_indices, torch.tensor([10, 20], dtype=torch.int64)
                )

    def test_insert_with_none_value(self):
        """Test insert with None value (should use token_ids as list)."""
        cache = RadixCache(
            req_to_token_pool=None, token_to_kv_pool_allocator=None, page_size=1
        )

        key = RadixKey([1, 2, 3])
        prefix_len = cache.insert(key, None)

        # When None is passed, it should create value from token_ids
        self.assertEqual(prefix_len, 0)
        self.assertEqual(cache.total_size(), 3)

    def test_total_size(self):
        """Test total_size calculation."""
        cache = RadixCache(
            req_to_token_pool=None, token_to_kv_pool_allocator=None, page_size=1
        )

        self.assertEqual(cache.total_size(), 0)

        cache.insert(RadixKey([1, 2, 3]), torch.tensor([10, 20, 30], dtype=torch.int64))
        self.assertEqual(cache.total_size(), 3)

        cache.insert(RadixKey([4, 5]), torch.tensor([40, 50], dtype=torch.int64))
        self.assertEqual(cache.total_size(), 5)
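
    # Illustrative sketch added for exposition (not part of the original test
    # file): insert() reports how much of the new key was already cached, so
    # extending a previously inserted sequence should count the old tokens as
    # the hit and only grow the tree by the new suffix. This leans on the
    # insert()/total_size() semantics exercised in the tests above; the exact
    # return value is an assumption worth re-checking against the implementation.
    def test_incremental_insert_sketch(self):
        """Sketch: extending a cached sequence reuses the existing prefix."""
        cache = RadixCache(
            req_to_token_pool=None, token_to_kv_pool_allocator=None, page_size=1
        )

        cache.insert(RadixKey([1, 2, 3]), torch.tensor([10, 20, 30], dtype=torch.int64))
        prefix_len = cache.insert(
            RadixKey([1, 2, 3, 4, 5]),
            torch.tensor([10, 20, 30, 40, 50], dtype=torch.int64),
        )

        self.assertEqual(prefix_len, 3)  # the first three tokens were already cached
        self.assertEqual(cache.total_size(), 5)  # only tokens 4 and 5 were added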

    def test_kv_cache_events(self):
        """Test KV cache events functionality."""
        test_cases = [
            (1, True),
            (2, True),
            (1, False),
        ]

        for page_size, enable_events in test_cases:
            with self.subTest(page_size=page_size, enable_events=enable_events):
                cache = RadixCache(
                    req_to_token_pool=None,
                    token_to_kv_pool_allocator=None,
                    page_size=page_size,
                    enable_kv_cache_events=enable_events,
                )

                # Insert data
                cache.insert(RadixKey([1, 2, 3, 4, 5]), None)

                # Take events
                events = cache.take_events()

                if enable_events:
                    self.assertGreater(len(events), 0)
                    # Verify events include BlockStored events (there might be other event types)
                    block_stored_events = [
                        e for e in events if isinstance(e, BlockStored)
                    ]
                    self.assertGreater(len(block_stored_events), 0)
                    for event in block_stored_events:
                        self.assertLessEqual(len(event.token_ids), page_size)
                else:
                    self.assertEqual(len(events), 0)

    def test_kv_cache_events_with_eviction(self):
        """Test KV cache events include removal events."""
        mock_allocator = unittest.mock.Mock()
        mock_allocator.device = torch.device("cpu")

        cache = RadixCache(
            req_to_token_pool=None,
            token_to_kv_pool_allocator=mock_allocator,
            page_size=1,
            enable_kv_cache_events=True,
        )

        # Insert and then evict data
        cache.insert(RadixKey([1, 2, 3]), torch.tensor([10, 20, 30], dtype=torch.int64))
        cache.evict(3)

        # Take events - should include both store and remove events
        events = cache.take_events()
        self.assertGreater(len(events), 0)

        # Check event types
        event_types = [type(event).__name__ for event in events]
        self.assertIn("BlockStored", event_types)

        # Verify BlockRemoved event content
        remove_events = [e for e in events if isinstance(e, BlockRemoved)]
        for event in remove_events:
            self.assertGreater(len(event.block_hashes), 0)

    def test_extra_key_isolation(self):
        """Test that keys with different extra_key values are isolated."""
        cache = RadixCache(
            req_to_token_pool=None, token_to_kv_pool_allocator=None, page_size=1
        )

        # Insert same token sequence with different extra keys
        cache.insert(
            RadixKey([1, 2, 3], "key1"), torch.tensor([10, 20, 30], dtype=torch.int64)
        )
        cache.insert(
            RadixKey([1, 2, 3], "key2"), torch.tensor([40, 50, 60], dtype=torch.int64)
        )
        cache.insert(
            RadixKey([1, 2, 3], None), torch.tensor([70, 80, 90], dtype=torch.int64)
        )

        # Keys with different extra_key should not match each other
        result1 = cache.match_prefix(RadixKey([1, 2, 3], "key1"))
        result2 = cache.match_prefix(RadixKey([1, 2, 3], "key2"))
        result3 = cache.match_prefix(RadixKey([1, 2, 3], None))
        result4 = cache.match_prefix(RadixKey([1, 2, 3], "nonexistent"))

        # Each should match only its own data
        self.assertEqual(len(result1.device_indices), 3)
        torch.testing.assert_close(
            result1.device_indices, torch.tensor([10, 20, 30], dtype=torch.int64)
        )

        self.assertEqual(len(result2.device_indices), 3)
        torch.testing.assert_close(
            result2.device_indices, torch.tensor([40, 50, 60], dtype=torch.int64)
        )

        self.assertEqual(len(result3.device_indices), 3)
        torch.testing.assert_close(
            result3.device_indices, torch.tensor([70, 80, 90], dtype=torch.int64)
        )

        # Non-existent extra_key should not match
        self.assertEqual(len(result4.device_indices), 0)
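
    # Illustrative sketch added for exposition (not part of the original test
    # file): the flip side of the isolation test above. Queries that carry the
    # *same* extra_key should keep hitting the shared prefix, while a different
    # extra_key starts cold. "salt" and "other-salt" are arbitrary placeholder
    # values, not identifiers used elsewhere in the code base.
    def test_extra_key_sharing_sketch(self):
        """Sketch: identical token ids with the same extra_key share a prefix."""
        cache = RadixCache(
            req_to_token_pool=None, token_to_kv_pool_allocator=None, page_size=1
        )

        cache.insert(
            RadixKey([1, 2, 3], "salt"), torch.tensor([10, 20, 30], dtype=torch.int64)
        )

        # A longer query with the same extra_key still reuses the cached prefix.
        hit = cache.match_prefix(RadixKey([1, 2, 3, 4], "salt"))
        self.assertEqual(len(hit.device_indices), 3)

        # The same tokens under a different extra_key do not match anything.
        miss = cache.match_prefix(RadixKey([1, 2, 3, 4], "other-salt"))
        self.assertEqual(len(miss.device_indices), 0)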

    def test_lock_ref_operations(self):
        """Test lock reference counting operations."""
        cache = RadixCache(
            req_to_token_pool=None, token_to_kv_pool_allocator=None, page_size=1
        )

        # Insert sequence
        cache.insert(RadixKey([1, 2, 3]), torch.tensor([10, 20, 30], dtype=torch.int64))

        # Get node
        result = cache.match_prefix(RadixKey([1, 2, 3]))
        node = result.last_device_node

        initial_evictable = cache.evictable_size()
        initial_protected = cache.protected_size()

        # Lock the node
        cache.inc_lock_ref(node)
        self.assertEqual(cache.protected_size(), initial_protected + 3)
        self.assertEqual(cache.evictable_size(), initial_evictable - 3)

        # Unlock the node
        cache.dec_lock_ref(node)
        self.assertEqual(cache.protected_size(), initial_protected)
        self.assertEqual(cache.evictable_size(), initial_evictable)

    def test_evict_functionality(self):
        """Test eviction functionality."""
        mock_allocator = unittest.mock.Mock()
        mock_allocator.device = torch.device("cpu")

        cache = RadixCache(
            req_to_token_pool=None,
            token_to_kv_pool_allocator=mock_allocator,
            page_size=1,
        )

        # Insert sequences
        cache.insert(RadixKey([1, 2]), torch.tensor([10, 20], dtype=torch.int64))
        cache.insert(RadixKey([3, 4]), torch.tensor([30, 40], dtype=torch.int64))

        initial_size = cache.total_size()

        # Evict some tokens
        cache.evict(2)

        # Should have called free and reduced size
        mock_allocator.free.assert_called()
        self.assertLess(cache.total_size(), initial_size)

    def test_page_alignment_boundary(self):
        """Test page alignment with different sizes."""
        test_cases = [
            (1, 5),
            (2, 5),
            (4, 6),
        ]

        for page_size, sequence_length in test_cases:
            with self.subTest(page_size=page_size, sequence_length=sequence_length):
                cache = RadixCache(
                    req_to_token_pool=None,
                    token_to_kv_pool_allocator=None,
                    page_size=page_size,
                )

                tokens = list(range(sequence_length))
                cache.insert(RadixKey(tokens), torch.tensor(tokens, dtype=torch.int64))

                result = cache.match_prefix(RadixKey(tokens))
                self.assertGreater(len(result.device_indices), 0)

                # Match length should be page-aligned
                match_len = len(result.device_indices)
                self.assertEqual(match_len % page_size, 0)

    def test_pretty_print_basic(self):
        """Test pretty_print produces output."""
        cache = RadixCache(
            req_to_token_pool=None, token_to_kv_pool_allocator=None, page_size=1
        )

        cache.insert(RadixKey([1, 2, 3]), torch.tensor([10, 20, 30], dtype=torch.int64))

        # Just test that it doesn't crash
        try:
            cache.pretty_print()
        except Exception as e:
            self.fail(f"pretty_print raised an exception: {e}")

    def test_all_values_flatten(self):
        """Test all_values_flatten method."""
        cache = RadixCache(
            req_to_token_pool=None, token_to_kv_pool_allocator=None, page_size=1
        )

        cache.insert(RadixKey([1, 2]), torch.tensor([10, 20], dtype=torch.int64))
        cache.insert(RadixKey([3, 4]), torch.tensor([30, 40], dtype=torch.int64))

        all_values = cache.all_values_flatten()
        self.assertEqual(len(all_values), 4)
        # Values should contain all inserted values (order may vary)
        values_set = set(all_values.tolist())
        self.assertEqual(values_set, {10, 20, 30, 40})

    def test_advanced_prefix_match_with_node_splits(self):
        """Advanced prefix matching: splits inside nodes and across pages."""
        for page_size in [1, 2]:
            with self.subTest(page_size=page_size):
                cache = RadixCache(
                    req_to_token_pool=None,
                    token_to_kv_pool_allocator=None,
                    page_size=page_size,
                )

                # Insert a long sequence that will be split later.
                seq1 = [1, 2, 3, 4, 5, 6, 7, 8]
                val1 = torch.tensor([x * 10 for x in seq1], dtype=torch.int64)
                cache.insert(RadixKey(seq1), val1)

                # Insert a diverging branch to create an internal node on the path.
                seq2 = [1, 2, 9, 10]
                val2 = torch.tensor([x * 10 for x in seq2], dtype=torch.int64)
                cache.insert(RadixKey(seq2), val2)
                print(cache.pretty_print())

                baseline_total = cache.total_size()
                expected_total = 10  # 8 + 2
                self.assertEqual(baseline_total, expected_total)

                # Match that causes a split inside an existing node:
                # take first 4 tokens of seq1, then diverge.
                query1 = [1, 2, 3, 4, 999, 1000]
                result1 = cache.match_prefix(RadixKey(query1))
                torch.testing.assert_close(result1.device_indices, val1[:4])
                # No data change after structural split during matching.
                self.assertEqual(cache.total_size(), baseline_total)

                # Full match of the long sequence still returns the full indices.
                result_full = cache.match_prefix(RadixKey(seq1))
                torch.testing.assert_close(result_full.device_indices, val1)

                # Another split deeper on the path (after matching 6 tokens, then diverge).
                query2 = [1, 2, 3, 4, 5, 6, 777, 888]
                result2 = cache.match_prefix(RadixKey(query2))
                torch.testing.assert_close(result2.device_indices, val1[:6])
                self.assertEqual(cache.total_size(), baseline_total)

                # Matching the short diverging branch should return exactly its indices.
                result_branch = cache.match_prefix(RadixKey(seq2))
                torch.testing.assert_close(result_branch.device_indices, val2)


if __name__ == "__main__":
    unittest.main()
@@ -4,7 +4,8 @@ import torch

from sglang.srt.mem_cache.allocator import SWAKVPool, SWATokenToKVPoolAllocator
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
from sglang.srt.mem_cache.radix_cache import SWARadixCache
from sglang.srt.mem_cache.radix_cache import RadixKey
from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache


class TestSWA(unittest.TestCase):
@@ -19,7 +20,7 @@ class TestSWA(unittest.TestCase):
    def test_swa_memory_pool(self):
        size = 16
        size_swa = 16
        num_head = 8
        head_num = 8
        head_dim = 128
        num_layers = 48
        global_interval = 4
@@ -34,14 +35,20 @@ class TestSWA(unittest.TestCase):
            size=size,
            size_swa=size_swa,
            dtype=dtype,
            num_head=num_head,
            head_num=head_num,
            head_dim=head_dim,
            swa_attention_layer_ids=swa_attention_layer_ids,
            full_attention_layer_ids=full_attention_layer_ids,
            enable_kvcache_transpose=False,
            device=device,
        )
        alloc = SWATokenToKVPoolAllocator(
            size=size, size_swa=size_swa, dtype=dtype, device=device, kvcache=pool
            size=size,
            size_swa=size_swa,
            dtype=dtype,
            device=device,
            kvcache=pool,
            need_sort=False,
        )
        assert alloc.available_size() == size + size_swa
        index = alloc.alloc(1)
@@ -57,7 +64,7 @@ class TestSWA(unittest.TestCase):
        kv_size = 128
        kv_size_swa = 64
        sliding_window_size = 4
        num_head = 8
        head_num = 8
        head_dim = 128
        num_layers = 48
        global_interval = 4
@@ -80,10 +87,11 @@ class TestSWA(unittest.TestCase):
            size=kv_size,
            size_swa=kv_size_swa,
            dtype=dtype,
            num_head=num_head,
            head_num=head_num,
            head_dim=head_dim,
            swa_attention_layer_ids=swa_attention_layer_ids,
            full_attention_layer_ids=full_attention_layer_ids,
            enable_kvcache_transpose=False,
            device=device,
        )
        # setup token to kv pool allocator
@@ -93,6 +101,7 @@ class TestSWA(unittest.TestCase):
            dtype=dtype,
            device=device,
            kvcache=kv_pool,
            need_sort=False,
        )
        # setup radix cache
        tree = SWARadixCache(
@@ -112,7 +121,7 @@ class TestSWA(unittest.TestCase):
        print(
            f"req1: inserting, req1_token_ids: {req1_token_ids}, req1_kv_indices: {req1_kv_indices}"
        )
        prefix_len = tree.insert(req1_token_ids, req1_kv_indices)
        prefix_len = tree.insert(RadixKey(req1_token_ids), req1_kv_indices)
        print(
            f"req1: prefix_len: {prefix_len}, allocator swa available size: {allocator.swa_available_size()}, full available size: {allocator.full_available_size()}"
        )
@@ -121,7 +130,7 @@ class TestSWA(unittest.TestCase):
        print(
            f"req2: inserting, req2_token_ids: {req2_token_ids}, req2_kv_indices: {req2_kv_indices}"
        )
        prefix_len = tree.insert(req2_token_ids, req2_kv_indices)
        prefix_len = tree.insert(RadixKey(req2_token_ids), req2_kv_indices)
        print(
            f"req2: prefix_len: {prefix_len}, allocator swa available size: {allocator.swa_available_size()}, full available size: {allocator.full_available_size()}"
        )
@@ -130,7 +139,7 @@ class TestSWA(unittest.TestCase):
        print(
            f"req3: inserting, req3_token_ids: {req3_token_ids}, req3_kv_indices: {req3_kv_indices}"
        )
        prefix_len = tree.insert(req3_token_ids, req3_kv_indices)
        prefix_len = tree.insert(RadixKey(req3_token_ids), req3_kv_indices)
        print(
            f"req3: prefix_len: {prefix_len}, allocator swa available size: {allocator.swa_available_size()}, full available size: {allocator.full_available_size()}"
        )
@@ -139,7 +148,7 @@ class TestSWA(unittest.TestCase):
        print(
            f"req4: inserting, req4_token_ids: {req4_token_ids}, req4_kv_indices: {req4_kv_indices}"
        )
        prefix_len = tree.insert(req4_token_ids, req4_kv_indices)
        prefix_len = tree.insert(RadixKey(req4_token_ids), req4_kv_indices)
        print(
            f"req4: prefix_len: {prefix_len}, allocator swa available size: {allocator.swa_available_size()}, full available size: {allocator.full_available_size()}"
        )
@@ -161,21 +170,23 @@ class TestSWA(unittest.TestCase):
        tree.pretty_print()

        req5_token_ids = [1, 2, 3, 4, 5]
        kv_indices, last_node = tree.match_prefix(req5_token_ids)
        result = tree.match_prefix(RadixKey(req5_token_ids))
        kv_indices, last_node = result.device_indices, result.last_device_node
        print(
            f"req5: token_ids: {req5_token_ids}, matched kv_indices: {kv_indices}, last_node.key: {last_node.key}"
        )
        assert len(kv_indices) == 0

        req6_token_ids = [1, 2, 3, 4, 5, 60, 70]
        kv_indices, last_node = tree.match_prefix(req6_token_ids)
        result = tree.match_prefix(RadixKey(req6_token_ids))
        kv_indices, last_node = result.device_indices, result.last_device_node
        print(
            f"req6: token_ids: {req6_token_ids}, matched kv_indices: {kv_indices}, last_node.key: {last_node.key}"
        )
        assert len(kv_indices) == 7
        assert len(last_node.key) == 2
        assert last_node.key[0] == 60
        assert last_node.key[1] == 70
        assert last_node.key.token_ids[0] == 60
        assert last_node.key.token_ids[1] == 70


if __name__ == "__main__":