diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 06dea43d4..d870f969a 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -61,8 +61,8 @@ from sglang.srt.mem_cache.allocator import ( ) from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache -from sglang.srt.mem_cache.lora_radix_cache import LoRAKey, LoRARadixCache from sglang.srt.mem_cache.memory_pool import HybridReqToTokenPool, ReqToTokenPool +from sglang.srt.mem_cache.radix_cache import RadixKey from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache from sglang.srt.metrics.collector import SchedulerMetricsCollector, TimeStats from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode @@ -457,6 +457,7 @@ class Req: vocab_size: Optional[int] = None, priority: Optional[int] = None, metrics_collector: Optional[SchedulerMetricsCollector] = None, + extra_key: Optional[str] = None, ): # Input and output info self.rid = rid @@ -489,6 +490,14 @@ class Req: self.sampling_params = sampling_params self.custom_logit_processor = custom_logit_processor self.return_hidden_states = return_hidden_states + + # extra key for classifying the request (e.g. lora_id, cache_salt) + if lora_id is not None: + extra_key = ( + extra_key or "" + ) + lora_id # lora_id is concatenated to the extra key + + self.extra_key = extra_key self.lora_id = lora_id # Memory pool info @@ -679,26 +688,16 @@ class Req: ): self.fill_ids = self.origin_input_ids + self.output_ids if tree_cache is not None: - if isinstance(tree_cache, LoRARadixCache): - ( - self.prefix_indices, - self.last_node, - self.last_host_node, - self.host_hit_length, - ) = tree_cache.match_prefix_with_lora_id( - key=LoRAKey( - lora_id=self.lora_id, token_ids=self.adjust_max_prefix_ids() - ), - ) - else: - ( - self.prefix_indices, - self.last_node, - self.last_host_node, - self.host_hit_length, - ) = tree_cache.match_prefix( - key=self.adjust_max_prefix_ids(), - ) + ( + self.prefix_indices, + self.last_node, + self.last_host_node, + self.host_hit_length, + ) = tree_cache.match_prefix( + key=RadixKey( + token_ids=self.adjust_max_prefix_ids(), extra_key=self.extra_key + ), + ) self.extend_input_len = len(self.fill_ids) - len(self.prefix_indices) def adjust_max_prefix_ids(self): diff --git a/python/sglang/srt/managers/schedule_policy.py b/python/sglang/srt/managers/schedule_policy.py index a59dffd75..755ac29c8 100644 --- a/python/sglang/srt/managers/schedule_policy.py +++ b/python/sglang/srt/managers/schedule_policy.py @@ -27,7 +27,7 @@ import torch from sglang.srt.managers.schedule_batch import Req, ScheduleBatch from sglang.srt.mem_cache.allocator import SWATokenToKVPoolAllocator from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache -from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode +from sglang.srt.mem_cache.radix_cache import RadixCache, RadixKey, TreeNode from sglang.srt.server_args import ServerArgs if TYPE_CHECKING: @@ -175,10 +175,13 @@ class SchedulePolicy: for r in waiting_queue: prefix_ids = r.adjust_max_prefix_ids() + extra_key = r.extra_key # NOTE: the prefix_indices must always be aligned with last_node r.prefix_indices, r.last_node, r.last_host_node, r.host_hit_length = ( - self.tree_cache.match_prefix(rid=r.rid, key=prefix_ids) + self.tree_cache.match_prefix( + rid=r.rid, key=RadixKey(token_ids=prefix_ids, extra_key=extra_key) + ) ) # NOTE(sang): This logic is for in-batch prefix caching; @@ -191,7 +194,8 @@ class SchedulePolicy: if len(r.prefix_indices) <= IN_BATCH_PREFIX_CACHING_CHECK_THRESHOLD: in_batch_matching_prefixes, _, _, _ = ( self.waiting_queue_radix_tree.match_prefix( - rid=r.rid, key=prefix_ids + rid=r.rid, + key=RadixKey(token_ids=prefix_ids, extra_key=extra_key), ) ) if ( @@ -202,7 +206,8 @@ class SchedulePolicy: else: # Insert with a dummy key self.waiting_queue_radix_tree.insert( - prefix_ids, torch.empty(len(prefix_ids), dtype=torch.bool) + RadixKey(token_ids=prefix_ids, extra_key=extra_key), + torch.empty(len(prefix_ids), dtype=torch.bool), ) return temporary_deprioritized diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 44dbc7d54..c83f43122 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -145,7 +145,6 @@ from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient from sglang.srt.managers.utils import DPBalanceMeta, validate_input_length from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache from sglang.srt.mem_cache.hiradix_cache import HiRadixCache -from sglang.srt.mem_cache.lora_radix_cache import LoRARadixCache from sglang.srt.mem_cache.radix_cache import RadixCache from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache from sglang.srt.model_executor.forward_batch_info import ForwardMode, PPProxyTensors @@ -719,19 +718,6 @@ class Scheduler( page_size=self.page_size, disable=server_args.disable_radix_cache, ) - elif self.enable_lora: - assert ( - not self.enable_hierarchical_cache - ), "LoRA radix cache doesn't support hierarchical cache" - assert ( - self.schedule_policy == "fcfs" - ), "LoRA radix cache only supports FCFS policy" - self.tree_cache = LoRARadixCache( - req_to_token_pool=self.req_to_token_pool, - token_to_kv_pool_allocator=self.token_to_kv_pool_allocator, - page_size=self.page_size, - disable=server_args.disable_radix_cache, - ) elif server_args.enable_lmcache: from sglang.srt.mem_cache.storage.lmcache.lmc_radix_cache import ( LMCRadixCache, diff --git a/python/sglang/srt/mem_cache/base_prefix_cache.py b/python/sglang/srt/mem_cache/base_prefix_cache.py index 4fdd04b72..7c5c7246e 100644 --- a/python/sglang/srt/mem_cache/base_prefix_cache.py +++ b/python/sglang/srt/mem_cache/base_prefix_cache.py @@ -36,7 +36,7 @@ class BasePrefixCache(ABC): pass @abstractmethod - def match_prefix(self, key: List[int], **kwargs) -> MatchResult: + def match_prefix(self, key: Any, **kwargs) -> MatchResult: pass @abstractmethod diff --git a/python/sglang/srt/mem_cache/hiradix_cache.py b/python/sglang/srt/mem_cache/hiradix_cache.py index 538c2a450..9dfe9aca0 100644 --- a/python/sglang/srt/mem_cache/hiradix_cache.py +++ b/python/sglang/srt/mem_cache/hiradix_cache.py @@ -19,7 +19,7 @@ from sglang.srt.mem_cache.memory_pool_host import ( MHATokenToKVPoolHost, MLATokenToKVPoolHost, ) -from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode +from sglang.srt.mem_cache.radix_cache import RadixCache, RadixKey, TreeNode from sglang.srt.metrics.collector import StorageMetricsCollector logger = logging.getLogger(__name__) @@ -570,7 +570,9 @@ class HiRadixCache(RadixCache): written_indices = host_indices[:min_completed_tokens] matched_length = self._insert_helper_host( last_host_node, - fetched_token_ids, + RadixKey( + token_ids=fetched_token_ids, extra_key=last_host_node.key.extra_key + ), written_indices, hash_value[: min_completed_tokens // self.page_size], ) @@ -592,7 +594,7 @@ class HiRadixCache(RadixCache): return True - def match_prefix(self, key: List[int], **kwargs): + def match_prefix(self, key: RadixKey, **kwargs): empty_value = torch.empty((0,), dtype=torch.int64, device=self.device) if self.disable or len(key) == 0: return MatchResult( @@ -666,7 +668,9 @@ class HiRadixCache(RadixCache): ) self.cache_controller.prefetch_tokens_occupied += len(new_input_tokens) - def _insert_helper_host(self, node: TreeNode, key: List, host_value, hash_value): + def _insert_helper_host( + self, node: TreeNode, key: RadixKey, host_value, hash_value + ): node.last_access_time = time.monotonic() if len(key) == 0: return 0 @@ -700,7 +704,7 @@ class HiRadixCache(RadixCache): node.children[child_key] = new_node return matched_length - def _match_prefix_helper(self, node: TreeNode, key: List): + def _match_prefix_helper(self, node: TreeNode, key: RadixKey): node.last_access_time = time.monotonic() child_key = self.get_child_key_fn(key) value = [] @@ -726,7 +730,7 @@ class HiRadixCache(RadixCache): return value, node - def _split_node(self, key, child: TreeNode, split_len: int): + def _split_node(self, key: RadixKey, child: TreeNode, split_len: int): # child node split into new_node -> child new_node = TreeNode() new_node.children = {self.get_child_key_fn(key[split_len:]): child} @@ -753,7 +757,7 @@ class HiRadixCache(RadixCache): new_node.parent.children[self.get_child_key_fn(key)] = new_node return new_node - def insert(self, key: List, value, chunked=False): + def insert(self, key: RadixKey, value=None, chunked=False): if len(key) == 0: return 0 @@ -811,7 +815,7 @@ class HiRadixCache(RadixCache): for idx in range(0, len(key), self.page_size): new_node.hash_value.append( self.cache_controller.get_hash_str( - key[idx : idx + self.page_size], + key.token_ids[idx : idx + self.page_size], prior_hash=last_hash, ) ) diff --git a/python/sglang/srt/mem_cache/lora_radix_cache.py b/python/sglang/srt/mem_cache/lora_radix_cache.py deleted file mode 100644 index 32b115cb6..000000000 --- a/python/sglang/srt/mem_cache/lora_radix_cache.py +++ /dev/null @@ -1,421 +0,0 @@ -"""Radix cache for LoRA. It's modified based on RadixCache with lora_id added to the key of nodes.""" - -import heapq -import time -from collections import defaultdict -from typing import TYPE_CHECKING, Any, List, Optional - -import torch - -from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator -from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache, MatchResult -from sglang.srt.mem_cache.memory_pool import ReqToTokenPool - -if TYPE_CHECKING: - from sglang.srt.managers.schedule_batch import Req -else: - Req = Any # Placeholder for Req type when not type checking - - -class LoRAKey: - - def __init__(self, lora_id: str, token_ids: List[int]): - self.lora_id = ( - lora_id # lora_id of adaptor, should be hash value of adaptor path - ) - self.token_ids = token_ids # token_ids of the key - - def __len__(self): - return len(self.token_ids) - - -def get_child_key(key: LoRAKey): - # Here the key of children dict is the hash of lora_id + str(token_ids[0]) - # So the child key can be matched only when lora_id and token_ids[0] are the same - if key.lora_id is None: - return hash(str(key.token_ids[0])) - else: - return hash(key.lora_id + str(key.token_ids[0])) - - -class LoRATreeNode: - - counter = 0 - - def __init__(self, id: Optional[int] = None): - self.children = defaultdict(LoRATreeNode) - self.parent: LoRATreeNode = None - self.key: LoRAKey = None - self.value: Optional[torch.Tensor] = None - self.lock_ref = 0 - self.last_access_time = time.monotonic() - - self.id = LoRATreeNode.counter if id is None else id - LoRATreeNode.counter += 1 - - @property - def evicted(self): - return self.value is None - - def __lt__(self, other: "LoRATreeNode"): - return self.last_access_time < other.last_access_time - - -def _key_match(key0: LoRAKey, key1: LoRAKey): - if key0.lora_id != key1.lora_id: - raise ValueError( - f"_key_match should be run on the same lora_id, but got key0.lora_id={key0.lora_id} != key1.lora_id={key1.lora_id}" - ) - i = 0 - for k0, k1 in zip(key0.token_ids, key1.token_ids): - if k0 != k1: - break - i += 1 - return i - - -class LoRARadixCache(BasePrefixCache): - - def __init__( - self, - req_to_token_pool: ReqToTokenPool, - token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator, - page_size: int, - disable: bool = False, - ): - if page_size > 1: - raise ValueError("LoRARadixCache currently only supports page_size = 1") - - if token_to_kv_pool_allocator is None: - raise ValueError( - "token_to_kv_pool_allocator is required to run LoraRadixCache" - ) - - self.req_to_token_pool = req_to_token_pool - self.token_to_kv_pool_allocator = token_to_kv_pool_allocator - self.page_size = page_size - self.disable = disable - self.device = self.token_to_kv_pool_allocator.device - - self.key_match_fn = _key_match - self.get_child_key_fn = get_child_key - self.reset() - - def reset(self): - self.root_node = LoRATreeNode() - self.root_node.key = LoRAKey(lora_id="", token_ids=[]) - self.root_node.value = None - self.evictable_size_ = 0 - self.protected_size_ = 0 - - def match_prefix(self, key: List[int], **kwargs) -> MatchResult: - raise ValueError( - "LoRARadixCache needs both token ids and lora id as inputs for matching. Please use match_prefix_with_lora_id instead." - ) - - def match_prefix_with_lora_id(self, key: LoRAKey, **kwargs) -> MatchResult: - """Find the matching prefix from the lora radix tree. - Args: - key: A LoRAKey to find a matching prefix. - Returns: - A tuple of a tensor of matching prefix token IDs and - the last node that contains the prefix values. Note that - this API can modify the internal state of the Radix tree. - The last node create a new child if the prefix is shorter - than the last node's value. - """ - if self.disable or len(key) == 0: - return MatchResult( - device_indices=torch.empty( - (0,), - dtype=torch.int64, - device=self.device, - ), - last_device_node=self.root_node, - last_host_node=self.root_node, - ) - - value, last_node = self._match_prefix_helper(self.root_node, key) - if value: - value = torch.cat(value) - else: - value = torch.empty((0,), dtype=torch.int64, device=self.device) - return MatchResult( - device_indices=value, - last_device_node=last_node, - last_host_node=last_node, - ) - - def insert(self, key: LoRAKey, value=None): - if self.disable: - return 0 - - if value is None: - value = [x for x in key.token_ids] - return self._insert_helper(self.root_node, key, value) - - def cache_finished_req(self, req: Req): - """Cache request when it finishes.""" - if self.disable: - kv_indices = self.req_to_token_pool.req_to_token[ - req.req_pool_idx, : len(req.origin_input_ids) + len(req.output_ids) - 1 - ] - self.token_to_kv_pool_allocator.free(kv_indices) - self.req_to_token_pool.free(req.req_pool_idx) - return - - token_ids = (req.origin_input_ids + req.output_ids)[:-1] - kv_indices = self.req_to_token_pool.req_to_token[ - req.req_pool_idx, : len(token_ids) - ] - - page_aligned_len = len(kv_indices) - page_aligned_kv_indices = kv_indices.to(dtype=torch.int64, copy=True) - - # Radix Cache takes one ref in memory pool - lora_key = LoRAKey(lora_id=req.lora_id, token_ids=token_ids[:page_aligned_len]) - new_prefix_len = self.insert(lora_key, page_aligned_kv_indices) - self.token_to_kv_pool_allocator.free( - kv_indices[len(req.prefix_indices) : new_prefix_len] - ) - - # Remove req slot release the cache lock - self.req_to_token_pool.free(req.req_pool_idx) - self.dec_lock_ref(req.last_node) - - def cache_unfinished_req(self, req: Req, chunked=False): - """Cache request when it is unfinished.""" - if self.disable: - return - - token_ids = req.fill_ids - kv_indices = self.req_to_token_pool.req_to_token[ - req.req_pool_idx, : len(token_ids) - ] - - page_aligned_len = len(kv_indices) - page_aligned_kv_indices = kv_indices.to(dtype=torch.int64, copy=True) - page_aligned_token_ids = token_ids[:page_aligned_len] - - # Radix Cache takes one ref in memory pool - inserted_key = LoRAKey(lora_id=req.lora_id, token_ids=page_aligned_token_ids) - new_prefix_len = self.insert(inserted_key, page_aligned_kv_indices) - self.token_to_kv_pool_allocator.free( - kv_indices[len(req.prefix_indices) : new_prefix_len] - ) - - # The prefix indices could be updated, reuse it - new_indices, new_last_node, _, _ = self.match_prefix_with_lora_id(inserted_key) - self.req_to_token_pool.write( - (req.req_pool_idx, slice(len(req.prefix_indices), len(new_indices))), - new_indices[len(req.prefix_indices) :], - ) - - self.dec_lock_ref(req.last_node) - self.inc_lock_ref(new_last_node) - - # `req.prefix_indices` will be used in `PrefillAdder::add_chunked_req` later - req.prefix_indices = new_indices - req.last_node = new_last_node - - def pretty_print(self): - self._print_helper(self.root_node, 0) - print(f"#tokens: {self.total_size()}") - - def total_size(self): - return self._total_size_helper() - - def evict(self, num_tokens: int): - if self.disable: - return - - leaves = self._collect_leaves() - heapq.heapify(leaves) - - num_evicted = 0 - while num_evicted < num_tokens and len(leaves): - x = heapq.heappop(leaves) - - if x == self.root_node: - break - if x.lock_ref > 0: - continue - - self.token_to_kv_pool_allocator.free(x.value) - num_evicted += len(x.value) - self._delete_leaf(x) - - if len(x.parent.children) == 0: - heapq.heappush(leaves, x.parent) - - def inc_lock_ref(self, node: LoRATreeNode): - if self.disable: - return 0 - - delta = 0 - while node != self.root_node: - if node.lock_ref == 0: - self.evictable_size_ -= len(node.value) - self.protected_size_ += len(node.value) - delta -= len(node.value) - node.lock_ref += 1 - node = node.parent - return delta - - def dec_lock_ref(self, node: LoRATreeNode): - if self.disable: - return 0 - - delta = 0 - while node != self.root_node: - if node.lock_ref == 1: - self.evictable_size_ += len(node.value) - self.protected_size_ -= len(node.value) - delta += len(node.value) - node.lock_ref -= 1 - node = node.parent - return delta - - def evictable_size(self): - return self.evictable_size_ - - def protected_size(self): - # protected size refers to the size of the cache that is locked - return self.protected_size_ - - def all_values_flatten(self): - values = [] - - def _dfs_helper(node: LoRATreeNode): - for _, child in node.children.items(): - values.append(child.value) - _dfs_helper(child) - - _dfs_helper(self.root_node) - return torch.cat(values) - - ##### Internal Helper Functions ##### - - def _match_prefix_helper(self, node: LoRATreeNode, key: LoRAKey): - node.last_access_time = time.monotonic() - - child_key = self.get_child_key_fn(key) - - value = [] - while len(key) > 0 and child_key in node.children.keys(): - child = node.children[child_key] - child.last_access_time = time.monotonic() - prefix_len = self.key_match_fn(child.key, key) - if prefix_len < len(child.key): - new_node = self._split_node(child.key, child, prefix_len) - value.append(new_node.value) - node = new_node - break - else: - value.append(child.value) - node = child - key = LoRAKey(lora_id=key.lora_id, token_ids=key.token_ids[prefix_len:]) - - if len(key): - child_key = self.get_child_key_fn(key) - - return value, node - - def _split_node(self, key: LoRAKey, child: LoRATreeNode, split_len: int): - # new_node -> child - new_node = LoRATreeNode() - key_split_1 = LoRAKey(lora_id=key.lora_id, token_ids=key.token_ids[:split_len]) - key_split_2 = LoRAKey(lora_id=key.lora_id, token_ids=key.token_ids[split_len:]) - new_node.children = {self.get_child_key_fn(key_split_2): child} - new_node.parent = child.parent - new_node.lock_ref = child.lock_ref - new_node.key = key_split_1 - new_node.value = child.value[:split_len] - child.parent = new_node - child.key = key_split_2 - child.value = child.value[split_len:] - new_node.parent.children[self.get_child_key_fn(key)] = new_node - - return new_node - - def _insert_helper(self, node: LoRATreeNode, key: LoRAKey, value): - node.last_access_time = time.monotonic() - if len(key) == 0: - return 0 - - child_key = self.get_child_key_fn(key) - - total_prefix_length = 0 - while len(key) > 0 and child_key in node.children.keys(): - node = node.children[child_key] - node.last_access_time = time.monotonic() - prefix_len = self.key_match_fn(node.key, key) - total_prefix_length += prefix_len - key = LoRAKey(lora_id=key.lora_id, token_ids=key.token_ids[prefix_len:]) - value = value[prefix_len:] - - if prefix_len < len(node.key): - new_node = self._split_node(node.key, node, prefix_len) - node = new_node - - if len(key): - child_key = self.get_child_key_fn(key) - - if len(key): - new_node = LoRATreeNode() - new_node.parent = node - new_node.key = key - new_node.value = value - node.children[child_key] = new_node - self.evictable_size_ += len(value) - return total_prefix_length - - def _print_helper(self, node: LoRATreeNode, indent: int): - """Prints the radix tree in a human-readable format.""" - stack = [(node, indent)] - while stack: - current_node, current_indent = stack.pop() - print( - " " * current_indent, - len(current_node.key), - current_node.key.token_ids[:10], - f"r={current_node.lock_ref}", - ) - for key, child in current_node.children.items(): - stack.append((child, current_indent + 2)) - - assert key == self.get_child_key_fn( - child.key - ), f"{key=}, {self.get_child_key_fn(child.key)=}" - - def _delete_leaf(self, node): - for k, v in node.parent.children.items(): - if v == node: - break - del node.parent.children[k] - self.evictable_size_ -= len(node.key) - - def _total_size_helper(self): - total_size = 0 - stack = [self.root_node] - while stack: - current_node = stack.pop() - total_size += len(current_node.value) - for child in current_node.children.values(): - if child.evicted: - continue - stack.append(child) - return total_size - - def _collect_leaves(self): - ret_list = [] - stack = [self.root_node] - - while stack: - cur_node = stack.pop() - if len(cur_node.children) == 0: - ret_list.append(cur_node) - else: - stack.extend(cur_node.children.values()) - - return ret_list diff --git a/python/sglang/srt/mem_cache/radix_cache.py b/python/sglang/srt/mem_cache/radix_cache.py index 4745811de..abb9445f8 100644 --- a/python/sglang/srt/mem_cache/radix_cache.py +++ b/python/sglang/srt/mem_cache/radix_cache.py @@ -23,7 +23,7 @@ import heapq import time from collections import defaultdict from functools import partial -from typing import TYPE_CHECKING, List, Optional +from typing import TYPE_CHECKING, Any, Iterator, List, Optional, Union import torch @@ -41,6 +41,30 @@ if TYPE_CHECKING: from sglang.srt.managers.schedule_batch import Req +class RadixKey: + + def __init__(self, token_ids: List[int], extra_key: Optional[str] = None): + # token ids sequence + self.token_ids = token_ids + # extra key (e.g. lora_id, cache_salt) + self.extra_key = extra_key + + def __len__(self) -> int: + return len(self.token_ids) + + def __iter__(self) -> Iterator[int]: + return iter(self.token_ids) + + def __getitem__(self, idx: Union[int, slice]) -> "RadixKey": + if isinstance(idx, slice): + return RadixKey(self.token_ids[idx], self.extra_key) + return RadixKey([self.token_ids[idx]], self.extra_key) + + def __repr__(self) -> str: + preview = self.token_ids[:10] + return f"RadixKey(extra_key={self.extra_key!r}, token_ids={preview}{'...' if len(self.token_ids) > 10 else ''})" + + class TreeNode: counter = 0 @@ -48,7 +72,7 @@ class TreeNode: def __init__(self, id: Optional[int] = None): self.children = defaultdict(TreeNode) self.parent: TreeNode = None - self.key: List[int] = None + self.key: RadixKey = None self.value: Optional[torch.Tensor] = None self.lock_ref = 0 self.last_access_time = time.monotonic() @@ -94,27 +118,47 @@ class TreeNode: return self.last_access_time < other.last_access_time -def _key_match_page_size1(key0: List, key1: List): +def _check_extra_key(key0: RadixKey, key1: RadixKey): + if key0.extra_key != key1.extra_key: + raise ValueError( + f"_key_match should be run on the same extra key, but got key0.extra_key={key0.extra_key} != key1.extra_key={key1.extra_key}" + ) + + +def _key_match_page_size1(key0: RadixKey, key1: RadixKey): + _check_extra_key(key0, key1) i = 0 - for k0, k1 in zip(key0, key1): + for k0, k1 in zip(key0.token_ids, key1.token_ids): if k0 != k1: break i += 1 return i -def _key_match_paged(key0: List, key1: List, page_size: int): +def _key_match_paged(key0: RadixKey, key1: RadixKey, page_size: int): + _check_extra_key(key0, key1) min_len = min(len(key0), len(key1)) i = 0 while i < min_len: - if key0[i : i + page_size] != key1[i : i + page_size]: + if key0.token_ids[i : i + page_size] != key1.token_ids[i : i + page_size]: break i += page_size return i +def get_child_key(key: RadixKey, page_size: int = 1): + if page_size == 1: + plain_key = key.token_ids[0] + else: + plain_key = tuple(key.token_ids[:page_size]) + if key.extra_key is None: + return plain_key + else: + return (key.extra_key, plain_key) + + class RadixCache(BasePrefixCache): def __init__( self, @@ -139,10 +183,10 @@ class RadixCache(BasePrefixCache): if self.page_size == 1: self.key_match_fn = _key_match_page_size1 - self.get_child_key_fn = lambda key: key[0] + self.get_child_key_fn = get_child_key else: self.key_match_fn = partial(_key_match_paged, page_size=page_size) - self.get_child_key_fn = lambda key: tuple(key[:page_size]) + self.get_child_key_fn = partial(get_child_key, page_size=page_size) if eviction_policy.lower() == "lru": self.eviction_strategy: EvictionStrategy = LRUStrategy() @@ -158,7 +202,7 @@ class RadixCache(BasePrefixCache): def reset(self): self.root_node = TreeNode() - self.root_node.key = [] + self.root_node.key = RadixKey(token_ids=[], extra_key=None) self.root_node.value = [] self.root_node.host_value = [] self.root_node.lock_ref = 1 @@ -166,16 +210,43 @@ class RadixCache(BasePrefixCache): self.protected_size_ = 0 self._record_all_cleared_event() - def match_prefix(self, key: List[int], **kwargs) -> MatchResult: - """Find the matching prefix from the radix tree. + def match_prefix(self, key: RadixKey, **kwargs) -> MatchResult: + """Find the longest cached prefix of ``key`` in the radix tree. + + The logical namespace for prefix matching is determined by both the + token id sequence and the optional ``extra_key`` carried by ``RadixKey``. + Entries that share identical leading token ids but have *different* + ``extra_key`` values are intentionally kept disjoint and never share + prefix nodes. This is useful to: + + * Isolate KV cache lines for different LoRA / adapter IDs. + * Separate requests that intentionally should not share state (e.g., + different sampling salt, cache version, or retrieval augmentation + context) by supplying a distinct ``extra_key``. + Args: - key: A list of token IDs to find a matching prefix. + key (RadixKey): The lookup key containing a list of token ids and an + optional ``extra_key`` namespace tag. If ``page_size > 1`` the + length is internally truncated to a multiple of ``page_size`` + before matching. Passing an empty key returns an empty result + with the root as the last node. + **kwargs: Reserved for future extensions (ignored currently). + Returns: - A tuple of a tensor of matching prefix token IDs and - the last node that contains the prefix values. Note that - this API can modify the internal state of the Radix tree. - The last node create a new child if the prefix is shorter - than the last node's value. + MatchResult: ``device_indices`` is a 1-D ``torch.int64`` tensor of + the concatenated KV cache indices corresponding to the longest + cached prefix (may be length 0). ``last_device_node`` and + ``last_host_node`` (currently the same) are the tree node objects + representing the terminal node of the matched prefix. This method + may mutate internal structure by splitting an existing node if the + match ends inside a stored segment. + + Internal updates: + * Refreshes access metadata (timestamps) used by the + configured eviction strategy. + * If the lookup ends inside a stored segment the node is split once + to expose a precise boundary; this structural refinement improves + subsequent match efficiency and does not duplicate data. """ if self.disable or len(key) == 0: return MatchResult( @@ -203,12 +274,12 @@ class RadixCache(BasePrefixCache): last_host_node=last_node, ) - def insert(self, key: List, value=None, chunked=False): + def insert(self, key: RadixKey, value=None, chunked=False): if self.disable: return 0 if value is None: - value = [x for x in key] + value = torch.tensor(key.token_ids, dtype=torch.int64) return self._insert_helper(self.root_node, key, value) def cache_finished_req(self, req: Req): @@ -238,7 +309,8 @@ class RadixCache(BasePrefixCache): # Radix Cache takes one ref in memory pool new_prefix_len = self.insert( - token_ids[:page_aligned_len], page_aligned_kv_indices + RadixKey(token_ids[:page_aligned_len], req.extra_key), + page_aligned_kv_indices, ) self.token_to_kv_pool_allocator.free( kv_indices[len(req.prefix_indices) : new_prefix_len] @@ -270,14 +342,18 @@ class RadixCache(BasePrefixCache): # Radix Cache takes one ref in memory pool new_prefix_len = self.insert( - page_aligned_token_ids, page_aligned_kv_indices, chunked=chunked + RadixKey(page_aligned_token_ids, req.extra_key), + page_aligned_kv_indices, + chunked=chunked, ) self.token_to_kv_pool_allocator.free( kv_indices[len(req.prefix_indices) : new_prefix_len] ) # The prefix indices could be updated, reuse it - new_indices, new_last_node, _, _ = self.match_prefix(page_aligned_token_ids) + new_indices, new_last_node, _, _ = self.match_prefix( + RadixKey(token_ids=page_aligned_token_ids, extra_key=req.extra_key) + ) self.req_to_token_pool.write( (req.req_pool_idx, slice(len(req.prefix_indices), len(new_indices))), new_indices[len(req.prefix_indices) :], @@ -379,7 +455,7 @@ class RadixCache(BasePrefixCache): ##### Internal Helper Functions ##### - def _match_prefix_helper(self, node: TreeNode, key: List): + def _match_prefix_helper(self, node: TreeNode, key: RadixKey): node.last_access_time = time.monotonic() child_key = self.get_child_key_fn(key) @@ -404,7 +480,7 @@ class RadixCache(BasePrefixCache): return value, node - def _split_node(self, key, child: TreeNode, split_len: int): + def _split_node(self, key: RadixKey, child: TreeNode, split_len: int): # new_node -> child self._record_remove_event(child) new_node = TreeNode() @@ -423,7 +499,7 @@ class RadixCache(BasePrefixCache): return new_node - def _insert_helper(self, node: TreeNode, key: List, value): + def _insert_helper(self, node: TreeNode, key: RadixKey, value): node.last_access_time = time.monotonic() if len(key) == 0: return 0 @@ -464,7 +540,7 @@ class RadixCache(BasePrefixCache): print( " " * current_indent, len(current_node.key), - current_node.key[:10], + current_node.key.token_ids[:10], f"r={current_node.lock_ref}", ) for key, child in current_node.children.items(): @@ -516,11 +592,11 @@ class RadixCache(BasePrefixCache): last_page_start = ( (len(node.parent.key) - 1) // self.page_size ) * self.page_size - parent_parent_tokens = node.parent.key[last_page_start:] + parent_parent_tokens = node.parent.key.token_ids[last_page_start:] parent_block_hash = hash(tuple(parent_parent_tokens)) for start in range(0, len(node.key), self.page_size): - page_tokens = node.key[start : start + self.page_size] + page_tokens = node.key.token_ids[start : start + self.page_size] if not page_tokens: continue @@ -543,7 +619,7 @@ class RadixCache(BasePrefixCache): # One BlockRemoved per chunk. if self.enable_kv_cache_events: for start in range(0, len(node.key), self.page_size): - page_tokens = node.key[start : start + self.page_size] + page_tokens = node.key.token_ids[start : start + self.page_size] if not page_tokens: continue block_hash = hash(tuple(page_tokens)) @@ -569,19 +645,12 @@ class RadixCache(BasePrefixCache): if __name__ == "__main__": tree = RadixCache(None, None, page_size=1, disable=False) - tree.insert("Hello") - tree.insert("Hello") - tree.insert("Hello_L.A.!") - # tree.insert("Hello_world! Happy") - # tree.insert("I love you!") + # Example token id sequences (as lists of ints) + tree.insert(RadixKey(token_ids=[1, 2, 3], extra_key=None)) + tree.insert(RadixKey(token_ids=[1, 2, 3], extra_key=None)) + tree.insert(RadixKey(token_ids=[1, 2, 4, 5], extra_key=None)) + tree.insert(RadixKey(token_ids=[1, 2, 4, 5, 6, 7], extra_key=None)) + tree.insert(RadixKey(token_ids=[8, 9, 10, 11, 12], extra_key=None)) tree.pretty_print() - # print(tree.match_prefix("I love you! aha")) - - # def evict_callback(x): - # print("evict", x) - # return len(x) - - # tree.evict(5, evict_callback) - # tree.evict(10, evict_callback) - # tree.pretty_print() + print(tree.match_prefix(RadixKey(token_ids=[1, 2, 3, 13, 14], extra_key=None))) diff --git a/python/sglang/srt/mem_cache/radix_cache_cpp.py b/python/sglang/srt/mem_cache/radix_cache_cpp.py index e9512e83f..a16b989fb 100644 --- a/python/sglang/srt/mem_cache/radix_cache_cpp.py +++ b/python/sglang/srt/mem_cache/radix_cache_cpp.py @@ -13,6 +13,7 @@ from sglang.srt.mem_cache.cpp_radix_tree.radix_tree import ( TreeNodeCpp, ) from sglang.srt.mem_cache.memory_pool import ReqToTokenPool +from sglang.srt.mem_cache.radix_cache import RadixKey if TYPE_CHECKING: from sglang.srt.managers.schedule_batch import Req @@ -93,9 +94,9 @@ class RadixCacheCpp(BasePrefixCache): raise NotImplementedError("Host cache is not supported yet") self.tree.reset() - def match_prefix(self, key: List[int], **kwargs) -> MatchResult: + def match_prefix(self, key: RadixKey, **kwargs) -> MatchResult: device_indices_vec, host_indices_length, node_gpu, node_cpu = ( - self.tree.match_prefix(key) + self.tree.match_prefix(key.token_ids) ) return MatchResult( device_indices=self._merge_tensor(device_indices_vec), @@ -104,16 +105,16 @@ class RadixCacheCpp(BasePrefixCache): host_hit_length=host_indices_length, ) - def _insert(self, key: List[int], value: torch.Tensor) -> int: + def _insert(self, key: RadixKey, value: torch.Tensor) -> int: """ Insert a key-value pair into the radix tree. Args: - key (List[int]): The key to insert, represented as a list of integers. + key (RadixKey): The key to insert, represented as a RadixKey. value (torch.Tensor): The value to associate with the key. Returns: int: Number of device indices that were already present in the tree before the insertion. """ - ongoing_write, length = self.tree.writing_through(key, value) + ongoing_write, length = self.tree.writing_through(key.token_ids, value) if self.cache_controller is None: assert len(ongoing_write) == 0, "Implementation error" return length @@ -160,7 +161,7 @@ class RadixCacheCpp(BasePrefixCache): # NOTE: our C++ implementation don't need `token_ids` and `kv_indices` to be page-aligned # it will automatically align them, but length of them should be equal old_prefix_len = len(req.prefix_indices) // self.page_size * self.page_size - new_prefix_len = self._insert(token_ids, kv_indices) + new_prefix_len = self._insert(RadixKey(token_ids, req.extra_key), kv_indices) # NOTE: kv_indices[:old_prefix_len] == req.prefix_indices assert old_prefix_len <= new_prefix_len, "Wrong prefix indices" @@ -191,14 +192,16 @@ class RadixCacheCpp(BasePrefixCache): # NOTE: our C++ implementation don't need `token_ids` and `kv_indices` to be page-aligned # it will automatically align them, but length of them should be equal old_prefix_len = len(req.prefix_indices) // self.page_size * self.page_size - new_prefix_len = self._insert(token_ids, kv_indices) + new_prefix_len = self._insert(RadixKey(token_ids, req.extra_key), kv_indices) # NOTE: kv_indices[:old_prefix_len] == req.prefix_indices assert old_prefix_len <= new_prefix_len, "Wrong prefix indices" # TODO(dark): optimize the `insert` and `match` (e.g. merge into 1 function) # The prefix indices need to updated to reuse the kv indices in the pool - new_indices_vec, _, new_last_node, _ = self.tree.match_prefix(token_ids) + new_indices_vec, _, new_last_node, _ = self.tree.match_prefix( + RadixKey(token_ids, req.extra_key).token_ids + ) new_indices = self._merge_tensor(new_indices_vec) assert new_prefix_len <= len(new_indices) diff --git a/python/sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py b/python/sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py index 99537135e..36061ac14 100644 --- a/python/sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +++ b/python/sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py @@ -9,7 +9,7 @@ import torch from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator from sglang.srt.mem_cache.base_prefix_cache import MatchResult from sglang.srt.mem_cache.memory_pool import ReqToTokenPool -from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode +from sglang.srt.mem_cache.radix_cache import RadixCache, RadixKey, TreeNode try: from lmcache.integration.sglang.sglang_adapter import ( @@ -131,7 +131,7 @@ class LMCRadixCache(RadixCache): with self._node_lock: self._in_flight_nodes.clear() - def match_prefix(self, key: List[int], **kwargs) -> MatchResult: # type: ignore[override] + def match_prefix(self, key: RadixKey, **kwargs) -> MatchResult: # type: ignore[override] """Match cached prefix; if there's a tail miss, prefetch from LMCache. Reuses the base matching logic to obtain (value, last_node). If there @@ -178,7 +178,7 @@ class LMCRadixCache(RadixCache): with torch.cuda.stream(self.load_stream): num_retrieved = self.lmcache_connector.start_load_kv( LoadMetadata( - token_ids=key, # full page-aligned key + token_ids=key.token_ids, # full page-aligned key slot_mapping=slot_mapping, offset=value.numel() - prefix_pad, # LMCache offset convention ) @@ -227,7 +227,7 @@ class LMCRadixCache(RadixCache): req.req_pool_idx, : len(token_ids) ] - _, new_last_node, _, _ = self.match_prefix(token_ids) + _, new_last_node, _, _ = self.match_prefix(RadixKey(token_ids, req.extra_key)) assert new_last_node is not None self.inc_lock_ref(new_last_node) @@ -277,6 +277,8 @@ if __name__ == "__main__": rank=0, tp_group=None, ) - cache.insert([1, 2, 3], torch.tensor([10, 11, 12], dtype=torch.int64)) - cache.insert([1, 2, 3, 4], torch.tensor([10, 11, 12, 13], dtype=torch.int64)) + cache.insert(RadixKey([1, 2, 3]), torch.tensor([10, 11, 12], dtype=torch.int64)) + cache.insert( + RadixKey([1, 2, 3, 4]), torch.tensor([10, 11, 12, 13], dtype=torch.int64) + ) cache.pretty_print() diff --git a/python/sglang/srt/mem_cache/swa_radix_cache.py b/python/sglang/srt/mem_cache/swa_radix_cache.py index 686fc6ab0..592f1198f 100644 --- a/python/sglang/srt/mem_cache/swa_radix_cache.py +++ b/python/sglang/srt/mem_cache/swa_radix_cache.py @@ -30,6 +30,12 @@ import torch from sglang.srt.mem_cache.allocator import SWATokenToKVPoolAllocator from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache, MatchResult from sglang.srt.mem_cache.memory_pool import ReqToTokenPool +from sglang.srt.mem_cache.radix_cache import ( + RadixKey, + _key_match_page_size1, + _key_match_paged, + get_child_key, +) if TYPE_CHECKING: from sglang.srt.managers.schedule_batch import Req @@ -47,7 +53,7 @@ class TreeNode: def __init__(self, id: Optional[int] = None): self.children = defaultdict(TreeNode) self.parent: TreeNode = None - self.key: List[int] = None + self.key: RadixKey = None self.value: Optional[torch.Tensor] = None # swa_tombstone is used to indicate the kv indices have been freed for swa layers self.swa_tombstone = False @@ -87,27 +93,6 @@ class TreeNode: return self.last_access_time < other.last_access_time -def _key_match_page_size1(key0: List, key1: List): - i = 0 - for k0, k1 in zip(key0, key1): - if k0 != k1: - break - i += 1 - return i - - -def _key_match_paged(key0: List, key1: List, page_size: int): - min_len = min(len(key0), len(key1)) - - i = 0 - while i < min_len: - if key0[i : i + page_size] != key1[i : i + page_size]: - break - i += page_size - - return i - - def gen_swa_uuid() -> int: TreeNode.swa_uuid_counter += 1 return TreeNode.swa_uuid_counter @@ -356,10 +341,10 @@ class SWARadixCache(BasePrefixCache): if self.page_size == 1: self.key_match_fn = _key_match_page_size1 - self.get_child_key_fn = lambda key: key[0] + self.get_child_key_fn = get_child_key else: self.key_match_fn = partial(_key_match_paged, page_size=page_size) - self.get_child_key_fn = lambda key: tuple(key[:page_size]) + self.get_child_key_fn = partial(get_child_key, page_size=page_size) self.sliding_window_size = sliding_window_size self.reset() @@ -380,10 +365,10 @@ class SWARadixCache(BasePrefixCache): self.full_lru_list = LRUList(swa=False) self.swa_lru_list = LRUList(swa=True) - def match_prefix(self, key: List[int], **kwargs) -> MatchResult: + def match_prefix(self, key: RadixKey, **kwargs) -> MatchResult: """Find the matching prefix from the radix tree. Args: - key: A list of token IDs to find a matching prefix. + key: A RadixKey contains token IDs to find a matching prefix. Returns: A tuple of a tensor of matching prefix token IDs and the last node that contains the prefix values. Note that @@ -417,12 +402,12 @@ class SWARadixCache(BasePrefixCache): last_host_node=last_node, ) - def insert(self, key: List, value=None, prev_prefix_len: int = 0) -> int: + def insert(self, key: RadixKey, value=None, prev_prefix_len: int = 0) -> int: if self.disable: return 0 if value is None: - value = [x for x in key] + value = torch.tensor([x for x in key.token_ids], dtype=torch.int64) return self._insert_helper(self.root_node, key, value, prev_prefix_len) def cache_finished_req(self, req: Req) -> None: @@ -453,7 +438,7 @@ class SWARadixCache(BasePrefixCache): # insert the token_ids and kv_indices into the radix tree # Note: the insert function already frees the overlapped kv_indices new_prefix_len = self.insert( - token_ids[:page_aligned_len], + RadixKey(token_ids[:page_aligned_len], req.extra_key), page_aligned_kv_indices, len(req.prefix_indices), ) @@ -489,11 +474,15 @@ class SWARadixCache(BasePrefixCache): # Radix Cache takes one ref in memory pool # Note: the insert function already frees the overlapped kv_indices new_prefix_len = self.insert( - page_aligned_token_ids, page_aligned_kv_indices, len(req.prefix_indices) + RadixKey(page_aligned_token_ids, req.extra_key), + page_aligned_kv_indices, + len(req.prefix_indices), ) # The prefix indices could be updated, reuse it - new_indices, new_last_node, _, _ = self.match_prefix(page_aligned_token_ids) + new_indices, new_last_node, _, _ = self.match_prefix( + RadixKey(page_aligned_token_ids, req.extra_key) + ) assert len(req.prefix_indices) <= len( new_indices ), f"{req.prefix_indices=}, {new_indices=}" @@ -732,7 +721,9 @@ class SWARadixCache(BasePrefixCache): ##### Internal Helper Functions ##### - def _match_prefix_helper(self, key: List) -> Tuple[List[torch.Tensor], TreeNode]: + def _match_prefix_helper( + self, key: RadixKey + ) -> Tuple[List[torch.Tensor], TreeNode]: """ SWA prefix matching helper. It factors in the sliding window size such that the matched node is guaranteed to either 1. connected to root without swa tombstone, @@ -796,7 +787,7 @@ class SWARadixCache(BasePrefixCache): return value[:best_value_len], best_last_node - def _split_node(self, key: List[int], child: TreeNode, split_len: int) -> TreeNode: + def _split_node(self, key: RadixKey, child: TreeNode, split_len: int) -> TreeNode: # new_node -> child new_node = TreeNode() new_node.children = {self.get_child_key_fn(key[split_len:]): child} @@ -831,7 +822,7 @@ class SWARadixCache(BasePrefixCache): return new_node def _insert_helper( - self, node: TreeNode, key: List, value, update_kv_after_len: int + self, node: TreeNode, key: RadixKey, value, update_kv_after_len: int ) -> int: # Update the last access time from root to leaf, so that # swa will tombstone the node closer to root first diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py index ab0c17fa9..2794bfe3f 100644 --- a/test/srt/run_suite.py +++ b/test/srt/run_suite.py @@ -99,6 +99,7 @@ suites = { TestFile("test_priority_scheduling.py", 100), TestFile("test_pytorch_sampling_backend.py", 66), TestFile("test_radix_attention.py", 105), + TestFile("test_radix_cache_unit.py", 5), TestFile("test_regex_constrained.py", 64), TestFile("test_reasoning_parser.py", 5), TestFile("test_retract_decode.py", 54), diff --git a/test/srt/test_radix_cache_unit.py b/test/srt/test_radix_cache_unit.py new file mode 100644 index 000000000..8cb75fb0b --- /dev/null +++ b/test/srt/test_radix_cache_unit.py @@ -0,0 +1,597 @@ +""" +Unit tests for the RadixCache implementation. + +This module tests the core functionality of RadixCache, RadixKey, and TreeNode +following SGLang testing patterns. + +Test Coverage: +- RadixKey: token ID management, slicing, iteration, representation +- TreeNode: node properties, reference counting, hash values +- RadixCache: insert/match operations, eviction, page alignment, error handling +- Cache events and request handling +- Boundary conditions with parameterized testing + +Usage: + python test_radix_cache_unit.py + python -m pytest test_radix_cache_unit.py -v + python -m pytest test_radix_cache_unit.py::TestRadixCache::test_insert_basic +""" + +import time +import unittest +import unittest.mock + +import torch + +from sglang.srt.disaggregation.kv_events import BlockRemoved, BlockStored +from sglang.srt.mem_cache.radix_cache import RadixCache, RadixKey, TreeNode + +# Test constants +DEFAULT_PAGE_SIZE = 4 + + +class TestRadixKey(unittest.TestCase): + """Test cases for RadixKey class.""" + + def test_init_basic(self): + """Test basic initialization of RadixKey.""" + token_ids = [1, 2, 3, 4] + key = RadixKey(token_ids) + self.assertEqual(key.token_ids, token_ids) + self.assertIsNone(key.extra_key) + + def test_init_with_extra_key(self): + """Test initialization with extra_key.""" + token_ids = [1, 2, 3] + extra_key = "test_key" + key = RadixKey(token_ids, extra_key) + self.assertEqual(key.token_ids, token_ids) + self.assertEqual(key.extra_key, extra_key) + + def test_len(self): + """Test __len__ method.""" + key = RadixKey([1, 2, 3]) + self.assertEqual(len(key), 3) + + empty_key = RadixKey([]) + self.assertEqual(len(empty_key), 0) + + def test_iter(self): + """Test __iter__ method.""" + token_ids = [1, 2, 3, 4] + key = RadixKey(token_ids) + self.assertEqual(list(key), token_ids) + + def test_len_and_iter(self): + """Test __len__ and __iter__ methods.""" + test_cases = [ + ([1, 2, 3], 3), + ([], 0), + ([42], 1), + ] + + for tokens, expected in test_cases: + with self.subTest(tokens=tokens): + key = RadixKey(tokens) + self.assertEqual(len(key), expected) + self.assertEqual(list(key), tokens) + + def test_getitem_int(self): + """Test __getitem__ with int index.""" + test_cases = [ + ([10, 20, 30], 0, [10]), + ([10, 20, 30], -1, [30]), + ([10, 20, 30], 2, [30]), + ] + + for tokens, index, expected in test_cases: + with self.subTest(tokens=tokens, index=index): + key = RadixKey(tokens) + result = key[index] + self.assertIsInstance(result, RadixKey) + self.assertEqual(result.token_ids, expected) + + def test_getitem_slice(self): + """Test __getitem__ with slice and edge cases.""" + key = RadixKey([1, 2, 3, 4, 5], "extra") + + # Basic slice + sliced = key[1:4] + self.assertIsInstance(sliced, RadixKey) + self.assertEqual(sliced.token_ids, [2, 3, 4]) + self.assertEqual(sliced.extra_key, "extra") + + # Edge cases + self.assertEqual(key[2:2].token_ids, []) # Empty slice + self.assertEqual(key[:].token_ids, [1, 2, 3, 4, 5]) # Full slice + + def test_getitem_invalid_index(self): + """Test __getitem__ with invalid indices.""" + key = RadixKey([1, 2, 3]) + with self.assertRaises(IndexError): + _ = key[10] # Out of bounds + + def test_repr(self): + """Test __repr__ method.""" + key = RadixKey([1, 2, 3], "test") + repr_str = repr(key) + self.assertIn("RadixKey", repr_str) + self.assertIn("extra_key='test'", repr_str) + self.assertIn("[1, 2, 3]", repr_str) + + def test_repr_long_token_ids(self): + """Test __repr__ with long token_ids.""" + long_tokens = list(range(15)) + key = RadixKey(long_tokens) + repr_str = repr(key) + self.assertIn("...", repr_str) # Should be truncated + + +class TestTreeNode(unittest.TestCase): + """Test cases for TreeNode class.""" + + def setUp(self): + """Reset the counter before each test.""" + TreeNode.counter = 0 + + def test_init_basic(self): + """Test basic initialization of TreeNode.""" + node = TreeNode() + self.assertEqual(node.id, 0) + self.assertEqual(len(node.children), 0) + self.assertIsNone(node.parent) + self.assertIsNone(node.key) + self.assertIsNone(node.value) + self.assertEqual(node.lock_ref, 0) + self.assertEqual(node.hit_count, 0) + self.assertEqual(node.host_ref_counter, 0) + self.assertIsNone(node.host_value) + self.assertIsNone(node.hash_value) + + def test_init_with_id(self): + """Test initialization with custom ID.""" + node = TreeNode(id=42) + self.assertEqual(node.id, 42) + node2 = TreeNode() + self.assertEqual(node2.id, 1) # Counter was incremented + + def test_counter_increment(self): + """Test that counter increments properly.""" + node1 = TreeNode() + node2 = TreeNode() + self.assertEqual(node1.id, 0) + self.assertEqual(node2.id, 1) + + def test_evicted_backuped_properties(self): + """Test evicted and backuped properties.""" + test_cases = [ + (False, False, True, False), + (True, False, False, False), + (True, True, False, True), + (False, True, True, True), + ] + + for ( + has_value, + has_host_value, + expected_evicted, + expected_backuped, + ) in test_cases: + with self.subTest(has_value=has_value, has_host_value=has_host_value): + node = TreeNode() + + if has_value: + node.value = torch.tensor([1, 2, 3]) + if has_host_value: + node.host_value = torch.tensor([4, 5, 6]) + + self.assertEqual(node.evicted, expected_evicted) + self.assertEqual(node.backuped, expected_backuped) + + def test_protect_release_host(self): + """Test protect_host and release_host methods.""" + node = TreeNode() + self.assertEqual(node.host_ref_counter, 0) + + node.protect_host() + self.assertEqual(node.host_ref_counter, 1) + + node.release_host() + self.assertEqual(node.host_ref_counter, 0) + + # Test error case + with self.assertRaises(RuntimeError): + node.release_host() + + def test_get_last_hash_value(self): + """Test get_last_hash_value method.""" + node = TreeNode() + self.assertIsNone(node.get_last_hash_value()) + + node.hash_value = ["hash1", "hash2", "hash3"] + self.assertEqual(node.get_last_hash_value(), "hash3") + + def test_lt_comparison(self): + """Test less than comparison based on last_access_time.""" + node1 = TreeNode() + time.sleep(0.001) # Small delay to ensure different timestamps + node2 = TreeNode() + + self.assertTrue(node1 < node2) + self.assertFalse(node2 < node1) + + +class TestRadixCache(unittest.TestCase): + """Test cases for RadixCache class.""" + + def setUp(self): + """Set up test fixtures.""" + TreeNode.counter = 0 + + def test_init_variations(self): + """Test cache initialization with different parameters.""" + test_cases = [ + (1, False, False), + (4, False, True), + (1, True, False), + ] + + for page_size, disable, enable_events in test_cases: + with self.subTest( + page_size=page_size, disable=disable, enable_events=enable_events + ): + cache = RadixCache( + req_to_token_pool=None, + token_to_kv_pool_allocator=None, + page_size=page_size, + disable=disable, + enable_kv_cache_events=enable_events, + ) + + self.assertEqual(cache.page_size, page_size) + self.assertEqual(cache.disable, disable) + self.assertEqual(cache.enable_kv_cache_events, enable_events) + self.assertEqual(cache.device, torch.device("cpu")) + self.assertIsNotNone(cache.root_node) + self.assertEqual(len(cache.root_node.key), 0) + + def test_reset(self): + """Test reset method.""" + cache = RadixCache( + req_to_token_pool=None, token_to_kv_pool_allocator=None, page_size=1 + ) + + # Insert some data + cache.insert(RadixKey([1, 2, 3]), torch.tensor([10, 20, 30], dtype=torch.int64)) + self.assertGreater(cache.total_size(), 0) + + # Reset + cache.reset() + self.assertEqual(cache.total_size(), 0) + self.assertEqual(cache.evictable_size(), 0) + self.assertEqual(cache.protected_size(), 0) + + def test_insert_and_match_basic(self): + """Test basic insert and match operations.""" + for disable_cache in [False, True]: + with self.subTest(disable_cache=disable_cache): + cache = RadixCache( + req_to_token_pool=None, + token_to_kv_pool_allocator=None, + page_size=1, + disable=disable_cache, + ) + + key = RadixKey([1, 2, 3]) + value = torch.tensor([10, 20, 30], dtype=torch.int64) + prefix_len = cache.insert(key, value) + + if disable_cache: + self.assertEqual(prefix_len, 0) + self.assertEqual(cache.total_size(), 0) + continue + + self.assertEqual(prefix_len, 0) # No existing prefix + self.assertEqual(cache.total_size(), 3) + self.assertEqual(cache.evictable_size(), 3) + + # Test match_prefix + result = cache.match_prefix(RadixKey([1, 2, 3])) + self.assertEqual(len(result.device_indices), 3) + torch.testing.assert_close(result.device_indices, value) + + # Test partial match + result = cache.match_prefix(RadixKey([1, 2])) + self.assertEqual(len(result.device_indices), 2) + torch.testing.assert_close( + result.device_indices, torch.tensor([10, 20], dtype=torch.int64) + ) + + def test_insert_with_none_value(self): + """Test insert with None value (should use token_ids as list).""" + cache = RadixCache( + req_to_token_pool=None, token_to_kv_pool_allocator=None, page_size=1 + ) + + key = RadixKey([1, 2, 3]) + prefix_len = cache.insert(key, None) + + # When None is passed, it should create value from token_ids + self.assertEqual(prefix_len, 0) + self.assertEqual(cache.total_size(), 3) + + def test_total_size(self): + """Test total_size calculation.""" + cache = RadixCache( + req_to_token_pool=None, token_to_kv_pool_allocator=None, page_size=1 + ) + + self.assertEqual(cache.total_size(), 0) + + cache.insert(RadixKey([1, 2, 3]), torch.tensor([10, 20, 30], dtype=torch.int64)) + self.assertEqual(cache.total_size(), 3) + + cache.insert(RadixKey([4, 5]), torch.tensor([40, 50], dtype=torch.int64)) + self.assertEqual(cache.total_size(), 5) + + def test_kv_cache_events(self): + """Test KV cache events functionality.""" + test_cases = [ + (1, True), + (2, True), + (1, False), + ] + + for page_size, enable_events in test_cases: + with self.subTest(page_size=page_size, enable_events=enable_events): + cache = RadixCache( + req_to_token_pool=None, + token_to_kv_pool_allocator=None, + page_size=page_size, + enable_kv_cache_events=enable_events, + ) + + # Insert data + cache.insert(RadixKey([1, 2, 3, 4, 5]), None) + + # Take events + events = cache.take_events() + + if enable_events: + self.assertGreater(len(events), 0) + # Verify events include BlockStored events (there might be other event types) + block_stored_events = [ + e for e in events if isinstance(e, BlockStored) + ] + self.assertGreater(len(block_stored_events), 0) + for event in block_stored_events: + self.assertLessEqual(len(event.token_ids), page_size) + else: + self.assertEqual(len(events), 0) + + def test_kv_cache_events_with_eviction(self): + """Test KV cache events include removal events.""" + mock_allocator = unittest.mock.Mock() + mock_allocator.device = torch.device("cpu") + + cache = RadixCache( + req_to_token_pool=None, + token_to_kv_pool_allocator=mock_allocator, + page_size=1, + enable_kv_cache_events=True, + ) + + # Insert and then evict data + cache.insert(RadixKey([1, 2, 3]), torch.tensor([10, 20, 30], dtype=torch.int64)) + cache.evict(3) + + # Take events - should include both store and remove events + events = cache.take_events() + self.assertGreater(len(events), 0) + + # Check event types + event_types = [type(event).__name__ for event in events] + self.assertIn("BlockStored", event_types) + + # Verify BlockRemoved event content + remove_events = [e for e in events if isinstance(e, BlockRemoved)] + for event in remove_events: + self.assertGreater(len(event.block_hashes), 0) + + def test_extra_key_isolation(self): + """Test that keys with different extra_key values are isolated.""" + cache = RadixCache( + req_to_token_pool=None, token_to_kv_pool_allocator=None, page_size=1 + ) + + # Insert same token sequence with different extra keys + cache.insert( + RadixKey([1, 2, 3], "key1"), torch.tensor([10, 20, 30], dtype=torch.int64) + ) + cache.insert( + RadixKey([1, 2, 3], "key2"), torch.tensor([40, 50, 60], dtype=torch.int64) + ) + cache.insert( + RadixKey([1, 2, 3], None), torch.tensor([70, 80, 90], dtype=torch.int64) + ) + + # Keys with different extra_key should not match each other + result1 = cache.match_prefix(RadixKey([1, 2, 3], "key1")) + result2 = cache.match_prefix(RadixKey([1, 2, 3], "key2")) + result3 = cache.match_prefix(RadixKey([1, 2, 3], None)) + result4 = cache.match_prefix(RadixKey([1, 2, 3], "nonexistent")) + + # Each should match only its own data + self.assertEqual(len(result1.device_indices), 3) + torch.testing.assert_close( + result1.device_indices, torch.tensor([10, 20, 30], dtype=torch.int64) + ) + + self.assertEqual(len(result2.device_indices), 3) + torch.testing.assert_close( + result2.device_indices, torch.tensor([40, 50, 60], dtype=torch.int64) + ) + + self.assertEqual(len(result3.device_indices), 3) + torch.testing.assert_close( + result3.device_indices, torch.tensor([70, 80, 90], dtype=torch.int64) + ) + + # Non-existent extra_key should not match + self.assertEqual(len(result4.device_indices), 0) + + def test_lock_ref_operations(self): + """Test lock reference counting operations.""" + cache = RadixCache( + req_to_token_pool=None, token_to_kv_pool_allocator=None, page_size=1 + ) + + # Insert sequence + cache.insert(RadixKey([1, 2, 3]), torch.tensor([10, 20, 30], dtype=torch.int64)) + + # Get node + result = cache.match_prefix(RadixKey([1, 2, 3])) + node = result.last_device_node + + initial_evictable = cache.evictable_size() + initial_protected = cache.protected_size() + + # Lock the node + cache.inc_lock_ref(node) + self.assertEqual(cache.protected_size(), initial_protected + 3) + self.assertEqual(cache.evictable_size(), initial_evictable - 3) + + # Unlock the node + cache.dec_lock_ref(node) + self.assertEqual(cache.protected_size(), initial_protected) + self.assertEqual(cache.evictable_size(), initial_evictable) + + def test_evict_functionality(self): + """Test eviction functionality.""" + mock_allocator = unittest.mock.Mock() + mock_allocator.device = torch.device("cpu") + + cache = RadixCache( + req_to_token_pool=None, + token_to_kv_pool_allocator=mock_allocator, + page_size=1, + ) + + # Insert sequences + cache.insert(RadixKey([1, 2]), torch.tensor([10, 20], dtype=torch.int64)) + cache.insert(RadixKey([3, 4]), torch.tensor([30, 40], dtype=torch.int64)) + + initial_size = cache.total_size() + + # Evict some tokens + cache.evict(2) + + # Should have called free and reduced size + mock_allocator.free.assert_called() + self.assertLess(cache.total_size(), initial_size) + + def test_page_alignment_boundary(self): + """Test page alignment with different sizes.""" + test_cases = [ + (1, 5), + (2, 5), + (4, 6), + ] + + for page_size, sequence_length in test_cases: + with self.subTest(page_size=page_size, sequence_length=sequence_length): + cache = RadixCache( + req_to_token_pool=None, + token_to_kv_pool_allocator=None, + page_size=page_size, + ) + + tokens = list(range(sequence_length)) + cache.insert(RadixKey(tokens), torch.tensor(tokens, dtype=torch.int64)) + + result = cache.match_prefix(RadixKey(tokens)) + self.assertGreater(len(result.device_indices), 0) + + # Match length should be page-aligned + match_len = len(result.device_indices) + self.assertEqual(match_len % page_size, 0) + + def test_pretty_print_basic(self): + """Test pretty_print produces output.""" + cache = RadixCache( + req_to_token_pool=None, token_to_kv_pool_allocator=None, page_size=1 + ) + + cache.insert(RadixKey([1, 2, 3]), torch.tensor([10, 20, 30], dtype=torch.int64)) + + # Just test that it doesn't crash + try: + cache.pretty_print() + except Exception as e: + self.fail(f"pretty_print raised an exception: {e}") + + def test_all_values_flatten(self): + """Test all_values_flatten method.""" + cache = RadixCache( + req_to_token_pool=None, token_to_kv_pool_allocator=None, page_size=1 + ) + + cache.insert(RadixKey([1, 2]), torch.tensor([10, 20], dtype=torch.int64)) + cache.insert(RadixKey([3, 4]), torch.tensor([30, 40], dtype=torch.int64)) + + all_values = cache.all_values_flatten() + self.assertEqual(len(all_values), 4) + # Values should contain all inserted values (order may vary) + values_set = set(all_values.tolist()) + self.assertEqual(values_set, {10, 20, 30, 40}) + + def test_advanced_prefix_match_with_node_splits(self): + """Advanced prefix matching: splits inside nodes and across pages.""" + for page_size in [1, 2]: + with self.subTest(page_size=page_size): + cache = RadixCache( + req_to_token_pool=None, + token_to_kv_pool_allocator=None, + page_size=page_size, + ) + + # Insert a long sequence that will be split later. + seq1 = [1, 2, 3, 4, 5, 6, 7, 8] + val1 = torch.tensor([x * 10 for x in seq1], dtype=torch.int64) + cache.insert(RadixKey(seq1), val1) + + # Insert a diverging branch to create an internal node on the path. + seq2 = [1, 2, 9, 10] + val2 = torch.tensor([x * 10 for x in seq2], dtype=torch.int64) + cache.insert(RadixKey(seq2), val2) + print(cache.pretty_print()) + + baseline_total = cache.total_size() + expected_total = 10 # 8 + 2 + self.assertEqual(baseline_total, expected_total) + + # Match that causes a split inside an existing node: + # take first 4 tokens of seq1, then diverge. + query1 = [1, 2, 3, 4, 999, 1000] + result1 = cache.match_prefix(RadixKey(query1)) + torch.testing.assert_close(result1.device_indices, val1[:4]) + # No data change after structural split during matching. + self.assertEqual(cache.total_size(), baseline_total) + + # Full match of the long sequence still returns the full indices. + result_full = cache.match_prefix(RadixKey(seq1)) + torch.testing.assert_close(result_full.device_indices, val1) + + # Another split deeper on the path (after matching 6 tokens, then diverge). + query2 = [1, 2, 3, 4, 5, 6, 777, 888] + result2 = cache.match_prefix(RadixKey(query2)) + torch.testing.assert_close(result2.device_indices, val1[:6]) + self.assertEqual(cache.total_size(), baseline_total) + + # Matching the short diverging branch should return exactly its indices. + result_branch = cache.match_prefix(RadixKey(seq2)) + torch.testing.assert_close(result_branch.device_indices, val2) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/srt/test_swa_unittest.py b/test/srt/test_swa_unittest.py index 128462029..68c76e1f5 100644 --- a/test/srt/test_swa_unittest.py +++ b/test/srt/test_swa_unittest.py @@ -4,7 +4,8 @@ import torch from sglang.srt.mem_cache.allocator import SWAKVPool, SWATokenToKVPoolAllocator from sglang.srt.mem_cache.memory_pool import ReqToTokenPool -from sglang.srt.mem_cache.radix_cache import SWARadixCache +from sglang.srt.mem_cache.radix_cache import RadixKey +from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache class TestSWA(unittest.TestCase): @@ -19,7 +20,7 @@ class TestSWA(unittest.TestCase): def test_swa_memory_pool(self): size = 16 size_swa = 16 - num_head = 8 + head_num = 8 head_dim = 128 num_layers = 48 global_interval = 4 @@ -34,14 +35,20 @@ class TestSWA(unittest.TestCase): size=size, size_swa=size_swa, dtype=dtype, - num_head=num_head, + head_num=head_num, head_dim=head_dim, swa_attention_layer_ids=swa_attention_layer_ids, full_attention_layer_ids=full_attention_layer_ids, + enable_kvcache_transpose=False, device=device, ) alloc = SWATokenToKVPoolAllocator( - size=size, size_swa=size_swa, dtype=dtype, device=device, kvcache=pool + size=size, + size_swa=size_swa, + dtype=dtype, + device=device, + kvcache=pool, + need_sort=False, ) assert alloc.available_size() == size + size_swa index = alloc.alloc(1) @@ -57,7 +64,7 @@ class TestSWA(unittest.TestCase): kv_size = 128 kv_size_swa = 64 sliding_window_size = 4 - num_head = 8 + head_num = 8 head_dim = 128 num_layers = 48 global_interval = 4 @@ -80,10 +87,11 @@ class TestSWA(unittest.TestCase): size=kv_size, size_swa=kv_size_swa, dtype=dtype, - num_head=num_head, + head_num=head_num, head_dim=head_dim, swa_attention_layer_ids=swa_attention_layer_ids, full_attention_layer_ids=full_attention_layer_ids, + enable_kvcache_transpose=False, device=device, ) # setup token to kv pool allocator @@ -93,6 +101,7 @@ class TestSWA(unittest.TestCase): dtype=dtype, device=device, kvcache=kv_pool, + need_sort=False, ) # setup radix cache tree = SWARadixCache( @@ -112,7 +121,7 @@ class TestSWA(unittest.TestCase): print( f"req1: inserting, req1_token_ids: {req1_token_ids}, req1_kv_indices: {req1_kv_indices}" ) - prefix_len = tree.insert(req1_token_ids, req1_kv_indices) + prefix_len = tree.insert(RadixKey(req1_token_ids), req1_kv_indices) print( f"req1: prefix_len: {prefix_len}, allocator swa available size: {allocator.swa_available_size()}, full available size: {allocator.full_available_size()}" ) @@ -121,7 +130,7 @@ class TestSWA(unittest.TestCase): print( f"req2: inserting, req2_token_ids: {req2_token_ids}, req2_kv_indices: {req2_kv_indices}" ) - prefix_len = tree.insert(req2_token_ids, req2_kv_indices) + prefix_len = tree.insert(RadixKey(req2_token_ids), req2_kv_indices) print( f"req2: prefix_len: {prefix_len}, allocator swa available size: {allocator.swa_available_size()}, full available size: {allocator.full_available_size()}" ) @@ -130,7 +139,7 @@ class TestSWA(unittest.TestCase): print( f"req3: inserting, req3_token_ids: {req3_token_ids}, req3_kv_indices: {req3_kv_indices}" ) - prefix_len = tree.insert(req3_token_ids, req3_kv_indices) + prefix_len = tree.insert(RadixKey(req3_token_ids), req3_kv_indices) print( f"req3: prefix_len: {prefix_len}, allocator swa available size: {allocator.swa_available_size()}, full available size: {allocator.full_available_size()}" ) @@ -139,7 +148,7 @@ class TestSWA(unittest.TestCase): print( f"req4: inserting, req4_token_ids: {req4_token_ids}, req4_kv_indices: {req4_kv_indices}" ) - prefix_len = tree.insert(req4_token_ids, req4_kv_indices) + prefix_len = tree.insert(RadixKey(req4_token_ids), req4_kv_indices) print( f"req4: prefix_len: {prefix_len}, allocator swa available size: {allocator.swa_available_size()}, full available size: {allocator.full_available_size()}" ) @@ -161,21 +170,23 @@ class TestSWA(unittest.TestCase): tree.pretty_print() req5_token_ids = [1, 2, 3, 4, 5] - kv_indices, last_node = tree.match_prefix(req5_token_ids) + result = tree.match_prefix(RadixKey(req5_token_ids)) + kv_indices, last_node = result.device_indices, result.last_device_node print( f"req5: token_ids: {req5_token_ids}, matched kv_indices: {kv_indices}, last_node.key: {last_node.key}" ) assert len(kv_indices) == 0 req6_token_ids = [1, 2, 3, 4, 5, 60, 70] - kv_indices, last_node = tree.match_prefix(req6_token_ids) + result = tree.match_prefix(RadixKey(req6_token_ids)) + kv_indices, last_node = result.device_indices, result.last_device_node print( f"req6: token_ids: {req6_token_ids}, matched kv_indices: {kv_indices}, last_node.key: {last_node.key}" ) assert len(kv_indices) == 7 assert len(last_node.key) == 2 - assert last_node.key[0] == 60 - assert last_node.key[1] == 70 + assert last_node.key.token_ids[0] == 60 + assert last_node.key.token_ids[1] == 70 if __name__ == "__main__":