Refactors radix cache for extra key support (#10317)
Signed-off-by: Xinyuan Tong <xinyuantong.cs@gmail.com>
This commit is contained in:
@@ -61,8 +61,8 @@ from sglang.srt.mem_cache.allocator import (
|
||||
)
|
||||
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
|
||||
from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
|
||||
from sglang.srt.mem_cache.lora_radix_cache import LoRAKey, LoRARadixCache
|
||||
from sglang.srt.mem_cache.memory_pool import HybridReqToTokenPool, ReqToTokenPool
|
||||
from sglang.srt.mem_cache.radix_cache import RadixKey
|
||||
from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
|
||||
from sglang.srt.metrics.collector import SchedulerMetricsCollector, TimeStats
|
||||
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
|
||||
@@ -457,6 +457,7 @@ class Req:
|
||||
vocab_size: Optional[int] = None,
|
||||
priority: Optional[int] = None,
|
||||
metrics_collector: Optional[SchedulerMetricsCollector] = None,
|
||||
extra_key: Optional[str] = None,
|
||||
):
|
||||
# Input and output info
|
||||
self.rid = rid
|
||||
@@ -489,6 +490,14 @@ class Req:
|
||||
self.sampling_params = sampling_params
|
||||
self.custom_logit_processor = custom_logit_processor
|
||||
self.return_hidden_states = return_hidden_states
|
||||
|
||||
# extra key for classifying the request (e.g. lora_id, cache_salt)
|
||||
if lora_id is not None:
|
||||
extra_key = (
|
||||
extra_key or ""
|
||||
) + lora_id # lora_id is concatenated to the extra key
|
||||
|
||||
self.extra_key = extra_key
|
||||
self.lora_id = lora_id
|
||||
|
||||
# Memory pool info
|
||||
@@ -679,26 +688,16 @@ class Req:
|
||||
):
|
||||
self.fill_ids = self.origin_input_ids + self.output_ids
|
||||
if tree_cache is not None:
|
||||
if isinstance(tree_cache, LoRARadixCache):
|
||||
(
|
||||
self.prefix_indices,
|
||||
self.last_node,
|
||||
self.last_host_node,
|
||||
self.host_hit_length,
|
||||
) = tree_cache.match_prefix_with_lora_id(
|
||||
key=LoRAKey(
|
||||
lora_id=self.lora_id, token_ids=self.adjust_max_prefix_ids()
|
||||
),
|
||||
)
|
||||
else:
|
||||
(
|
||||
self.prefix_indices,
|
||||
self.last_node,
|
||||
self.last_host_node,
|
||||
self.host_hit_length,
|
||||
) = tree_cache.match_prefix(
|
||||
key=self.adjust_max_prefix_ids(),
|
||||
)
|
||||
(
|
||||
self.prefix_indices,
|
||||
self.last_node,
|
||||
self.last_host_node,
|
||||
self.host_hit_length,
|
||||
) = tree_cache.match_prefix(
|
||||
key=RadixKey(
|
||||
token_ids=self.adjust_max_prefix_ids(), extra_key=self.extra_key
|
||||
),
|
||||
)
|
||||
self.extend_input_len = len(self.fill_ids) - len(self.prefix_indices)
|
||||
|
||||
def adjust_max_prefix_ids(self):
|
||||
|
||||
@@ -27,7 +27,7 @@ import torch
|
||||
from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
|
||||
from sglang.srt.mem_cache.allocator import SWATokenToKVPoolAllocator
|
||||
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
|
||||
from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode
|
||||
from sglang.srt.mem_cache.radix_cache import RadixCache, RadixKey, TreeNode
|
||||
from sglang.srt.server_args import ServerArgs
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -175,10 +175,13 @@ class SchedulePolicy:
|
||||
|
||||
for r in waiting_queue:
|
||||
prefix_ids = r.adjust_max_prefix_ids()
|
||||
extra_key = r.extra_key
|
||||
|
||||
# NOTE: the prefix_indices must always be aligned with last_node
|
||||
r.prefix_indices, r.last_node, r.last_host_node, r.host_hit_length = (
|
||||
self.tree_cache.match_prefix(rid=r.rid, key=prefix_ids)
|
||||
self.tree_cache.match_prefix(
|
||||
rid=r.rid, key=RadixKey(token_ids=prefix_ids, extra_key=extra_key)
|
||||
)
|
||||
)
|
||||
|
||||
# NOTE(sang): This logic is for in-batch prefix caching;
|
||||
@@ -191,7 +194,8 @@ class SchedulePolicy:
|
||||
if len(r.prefix_indices) <= IN_BATCH_PREFIX_CACHING_CHECK_THRESHOLD:
|
||||
in_batch_matching_prefixes, _, _, _ = (
|
||||
self.waiting_queue_radix_tree.match_prefix(
|
||||
rid=r.rid, key=prefix_ids
|
||||
rid=r.rid,
|
||||
key=RadixKey(token_ids=prefix_ids, extra_key=extra_key),
|
||||
)
|
||||
)
|
||||
if (
|
||||
@@ -202,7 +206,8 @@ class SchedulePolicy:
|
||||
else:
|
||||
# Insert with a dummy key
|
||||
self.waiting_queue_radix_tree.insert(
|
||||
prefix_ids, torch.empty(len(prefix_ids), dtype=torch.bool)
|
||||
RadixKey(token_ids=prefix_ids, extra_key=extra_key),
|
||||
torch.empty(len(prefix_ids), dtype=torch.bool),
|
||||
)
|
||||
return temporary_deprioritized
|
||||
|
||||
|
||||
@@ -145,7 +145,6 @@ from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
|
||||
from sglang.srt.managers.utils import DPBalanceMeta, validate_input_length
|
||||
from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
|
||||
from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
|
||||
from sglang.srt.mem_cache.lora_radix_cache import LoRARadixCache
|
||||
from sglang.srt.mem_cache.radix_cache import RadixCache
|
||||
from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardMode, PPProxyTensors
|
||||
@@ -719,19 +718,6 @@ class Scheduler(
|
||||
page_size=self.page_size,
|
||||
disable=server_args.disable_radix_cache,
|
||||
)
|
||||
elif self.enable_lora:
|
||||
assert (
|
||||
not self.enable_hierarchical_cache
|
||||
), "LoRA radix cache doesn't support hierarchical cache"
|
||||
assert (
|
||||
self.schedule_policy == "fcfs"
|
||||
), "LoRA radix cache only supports FCFS policy"
|
||||
self.tree_cache = LoRARadixCache(
|
||||
req_to_token_pool=self.req_to_token_pool,
|
||||
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
|
||||
page_size=self.page_size,
|
||||
disable=server_args.disable_radix_cache,
|
||||
)
|
||||
elif server_args.enable_lmcache:
|
||||
from sglang.srt.mem_cache.storage.lmcache.lmc_radix_cache import (
|
||||
LMCRadixCache,
|
||||
|
||||
@@ -36,7 +36,7 @@ class BasePrefixCache(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def match_prefix(self, key: List[int], **kwargs) -> MatchResult:
|
||||
def match_prefix(self, key: Any, **kwargs) -> MatchResult:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
|
||||
@@ -19,7 +19,7 @@ from sglang.srt.mem_cache.memory_pool_host import (
|
||||
MHATokenToKVPoolHost,
|
||||
MLATokenToKVPoolHost,
|
||||
)
|
||||
from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode
|
||||
from sglang.srt.mem_cache.radix_cache import RadixCache, RadixKey, TreeNode
|
||||
from sglang.srt.metrics.collector import StorageMetricsCollector
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -570,7 +570,9 @@ class HiRadixCache(RadixCache):
|
||||
written_indices = host_indices[:min_completed_tokens]
|
||||
matched_length = self._insert_helper_host(
|
||||
last_host_node,
|
||||
fetched_token_ids,
|
||||
RadixKey(
|
||||
token_ids=fetched_token_ids, extra_key=last_host_node.key.extra_key
|
||||
),
|
||||
written_indices,
|
||||
hash_value[: min_completed_tokens // self.page_size],
|
||||
)
|
||||
@@ -592,7 +594,7 @@ class HiRadixCache(RadixCache):
|
||||
|
||||
return True
|
||||
|
||||
def match_prefix(self, key: List[int], **kwargs):
|
||||
def match_prefix(self, key: RadixKey, **kwargs):
|
||||
empty_value = torch.empty((0,), dtype=torch.int64, device=self.device)
|
||||
if self.disable or len(key) == 0:
|
||||
return MatchResult(
|
||||
@@ -666,7 +668,9 @@ class HiRadixCache(RadixCache):
|
||||
)
|
||||
self.cache_controller.prefetch_tokens_occupied += len(new_input_tokens)
|
||||
|
||||
def _insert_helper_host(self, node: TreeNode, key: List, host_value, hash_value):
|
||||
def _insert_helper_host(
|
||||
self, node: TreeNode, key: RadixKey, host_value, hash_value
|
||||
):
|
||||
node.last_access_time = time.monotonic()
|
||||
if len(key) == 0:
|
||||
return 0
|
||||
@@ -700,7 +704,7 @@ class HiRadixCache(RadixCache):
|
||||
node.children[child_key] = new_node
|
||||
return matched_length
|
||||
|
||||
def _match_prefix_helper(self, node: TreeNode, key: List):
|
||||
def _match_prefix_helper(self, node: TreeNode, key: RadixKey):
|
||||
node.last_access_time = time.monotonic()
|
||||
child_key = self.get_child_key_fn(key)
|
||||
value = []
|
||||
@@ -726,7 +730,7 @@ class HiRadixCache(RadixCache):
|
||||
|
||||
return value, node
|
||||
|
||||
def _split_node(self, key, child: TreeNode, split_len: int):
|
||||
def _split_node(self, key: RadixKey, child: TreeNode, split_len: int):
|
||||
# child node split into new_node -> child
|
||||
new_node = TreeNode()
|
||||
new_node.children = {self.get_child_key_fn(key[split_len:]): child}
|
||||
@@ -753,7 +757,7 @@ class HiRadixCache(RadixCache):
|
||||
new_node.parent.children[self.get_child_key_fn(key)] = new_node
|
||||
return new_node
|
||||
|
||||
def insert(self, key: List, value, chunked=False):
|
||||
def insert(self, key: RadixKey, value=None, chunked=False):
|
||||
if len(key) == 0:
|
||||
return 0
|
||||
|
||||
@@ -811,7 +815,7 @@ class HiRadixCache(RadixCache):
|
||||
for idx in range(0, len(key), self.page_size):
|
||||
new_node.hash_value.append(
|
||||
self.cache_controller.get_hash_str(
|
||||
key[idx : idx + self.page_size],
|
||||
key.token_ids[idx : idx + self.page_size],
|
||||
prior_hash=last_hash,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -1,421 +0,0 @@
|
||||
"""Radix cache for LoRA. It's modified based on RadixCache with lora_id added to the key of nodes."""
|
||||
|
||||
import heapq
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from typing import TYPE_CHECKING, Any, List, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
|
||||
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache, MatchResult
|
||||
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.managers.schedule_batch import Req
|
||||
else:
|
||||
Req = Any # Placeholder for Req type when not type checking
|
||||
|
||||
|
||||
class LoRAKey:
    """Radix-tree key that pairs a token-id sequence with a LoRA adapter id.

    The length of a key is the number of token ids it carries; the adapter
    id does not contribute to the length.
    """

    def __init__(self, lora_id: Optional[str], token_ids: List[int]):
        # lora_id of adaptor, should be hash value of adaptor path
        self.lora_id = lora_id
        # token_ids of the key
        self.token_ids = token_ids

    def __len__(self) -> int:
        return len(self.token_ids)

    def __repr__(self) -> str:
        # Only a short preview of the tokens is shown to keep logs readable.
        preview = self.token_ids[:10]
        suffix = "..." if len(self.token_ids) > 10 else ""
        return f"LoRAKey(lora_id={self.lora_id!r}, token_ids={preview}{suffix})"
|
||||
|
||||
|
||||
def get_child_key(key: "LoRAKey"):
    """Return the dictionary key under which ``key`` is stored in a node's children.

    The child key is derived from the adapter id together with the first
    token id, so two keys land in the same child slot only when both the
    adapter id and the first token agree.

    The pair is hashed as a tuple rather than as a concatenated string:
    string concatenation is ambiguous (``"a1" + "2"`` equals ``"a" + "12"``)
    and would also make ``lora_id=None`` indistinguishable from
    ``lora_id=""``.

    Args:
        key: A non-empty key exposing ``lora_id`` and ``token_ids``.

    Returns:
        An ``int`` hash usable as a key of the ``children`` dict.

    Raises:
        IndexError: If ``key.token_ids`` is empty.
    """
    return hash((key.lora_id, key.token_ids[0]))
|
||||
|
||||
|
||||
class LoRATreeNode:
    """A node of the LoRA radix tree.

    Each node stores the key segment it represents, the value attached to
    that segment, a lock reference count and the last access timestamp.
    Nodes order by ``last_access_time`` so they can be placed in a heap.
    """

    # Class-wide counter used to hand out default node ids.
    counter = 0

    def __init__(self, id: Optional[int] = None):
        self.children = defaultdict(LoRATreeNode)
        self.parent: LoRATreeNode = None
        self.key: LoRAKey = None
        self.value: Optional[torch.Tensor] = None
        self.lock_ref = 0
        self.last_access_time = time.monotonic()

        # The counter advances even when an explicit id is supplied.
        self.id = LoRATreeNode.counter if id is None else id
        LoRATreeNode.counter += 1

    @property
    def evicted(self):
        """Whether this node currently holds no value."""
        return self.value is None

    def __lt__(self, other: "LoRATreeNode"):
        return self.last_access_time < other.last_access_time
|
||||
|
||||
|
||||
def _key_match(key0: LoRAKey, key1: LoRAKey):
|
||||
if key0.lora_id != key1.lora_id:
|
||||
raise ValueError(
|
||||
f"_key_match should be run on the same lora_id, but got key0.lora_id={key0.lora_id} != key1.lora_id={key1.lora_id}"
|
||||
)
|
||||
i = 0
|
||||
for k0, k1 in zip(key0.token_ids, key1.token_ids):
|
||||
if k0 != k1:
|
||||
break
|
||||
i += 1
|
||||
return i
|
||||
|
||||
|
||||
class LoRARadixCache(BasePrefixCache):
    """Radix-tree prefix cache whose keys carry a LoRA adapter id.

    Tree nodes are keyed by ``LoRAKey`` (adapter id plus token ids), and the
    child slot of a node is chosen by ``get_child_key``, which mixes the
    adapter id with the first token. Only ``page_size == 1`` is supported.
    """

    def __init__(
        self,
        req_to_token_pool: ReqToTokenPool,
        token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
        page_size: int,
        disable: bool = False,
    ):
        """Create an empty cache.

        Raises:
            ValueError: If ``page_size > 1`` or no allocator is supplied.
        """
        if page_size > 1:
            raise ValueError("LoRARadixCache currently only supports page_size = 1")

        if token_to_kv_pool_allocator is None:
            raise ValueError(
                "token_to_kv_pool_allocator is required to run LoraRadixCache"
            )

        self.req_to_token_pool = req_to_token_pool
        self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
        self.page_size = page_size
        self.disable = disable
        self.device = self.token_to_kv_pool_allocator.device

        self.key_match_fn = _key_match
        self.get_child_key_fn = get_child_key
        self.reset()

    def reset(self):
        """Drop the whole tree and zero the size counters."""
        self.root_node = LoRATreeNode()
        self.root_node.key = LoRAKey(lora_id="", token_ids=[])
        self.root_node.value = None
        self.evictable_size_ = 0
        self.protected_size_ = 0

    def match_prefix(self, key: List[int], **kwargs) -> MatchResult:
        """Unsupported entry point; always raises ``ValueError``.

        Matching needs an adapter id as well as token ids, so callers must
        use ``match_prefix_with_lora_id``.
        """
        raise ValueError(
            "LoRARadixCache needs both token ids and lora id as inputs for matching. Please use match_prefix_with_lora_id instead."
        )

    def match_prefix_with_lora_id(self, key: LoRAKey, **kwargs) -> MatchResult:
        """Find the matching prefix from the lora radix tree.

        Args:
            key: A LoRAKey to find a matching prefix.

        Returns:
            A ``MatchResult`` holding the concatenated values of the matched
            nodes and the last matched node (used for both the device and
            the host node fields). Note that this API can modify the
            internal state of the radix tree: when the match ends inside a
            stored segment, that node is split.
        """
        if self.disable or len(key) == 0:
            return MatchResult(
                device_indices=torch.empty(
                    (0,),
                    dtype=torch.int64,
                    device=self.device,
                ),
                last_device_node=self.root_node,
                last_host_node=self.root_node,
            )

        value, last_node = self._match_prefix_helper(self.root_node, key)
        if value:
            value = torch.cat(value)
        else:
            value = torch.empty((0,), dtype=torch.int64, device=self.device)
        return MatchResult(
            device_indices=value,
            last_device_node=last_node,
            last_host_node=last_node,
        )

    def insert(self, key: LoRAKey, value=None):
        """Insert ``key`` with ``value`` and return the already-cached prefix length.

        When ``value`` is omitted, a copy of the token ids is stored instead.
        Returns ``0`` without touching the tree when the cache is disabled.
        """
        if self.disable:
            return 0

        if value is None:
            value = list(key.token_ids)
        return self._insert_helper(self.root_node, key, value)

    def cache_finished_req(self, req: Req):
        """Cache request when it finishes."""
        if self.disable:
            # Nothing is cached: release the KV slots and the request slot.
            kv_indices = self.req_to_token_pool.req_to_token[
                req.req_pool_idx, : len(req.origin_input_ids) + len(req.output_ids) - 1
            ]
            self.token_to_kv_pool_allocator.free(kv_indices)
            self.req_to_token_pool.free(req.req_pool_idx)
            return

        token_ids = (req.origin_input_ids + req.output_ids)[:-1]
        kv_indices = self.req_to_token_pool.req_to_token[
            req.req_pool_idx, : len(token_ids)
        ]

        # page_size is always 1 here, so every index is already page aligned.
        page_aligned_len = len(kv_indices)
        page_aligned_kv_indices = kv_indices.to(dtype=torch.int64, copy=True)

        # Radix Cache takes one ref in memory pool
        lora_key = LoRAKey(lora_id=req.lora_id, token_ids=token_ids[:page_aligned_len])
        new_prefix_len = self.insert(lora_key, page_aligned_kv_indices)
        # Free the slice between the request's known prefix and the prefix
        # length reported by insert.
        self.token_to_kv_pool_allocator.free(
            kv_indices[len(req.prefix_indices) : new_prefix_len]
        )

        # Remove req slot release the cache lock
        self.req_to_token_pool.free(req.req_pool_idx)
        self.dec_lock_ref(req.last_node)

    def cache_unfinished_req(self, req: Req, chunked=False):
        """Cache request when it is unfinished."""
        if self.disable:
            return

        token_ids = req.fill_ids
        kv_indices = self.req_to_token_pool.req_to_token[
            req.req_pool_idx, : len(token_ids)
        ]

        page_aligned_len = len(kv_indices)
        page_aligned_kv_indices = kv_indices.to(dtype=torch.int64, copy=True)
        page_aligned_token_ids = token_ids[:page_aligned_len]

        # Radix Cache takes one ref in memory pool
        inserted_key = LoRAKey(lora_id=req.lora_id, token_ids=page_aligned_token_ids)
        new_prefix_len = self.insert(inserted_key, page_aligned_kv_indices)
        self.token_to_kv_pool_allocator.free(
            kv_indices[len(req.prefix_indices) : new_prefix_len]
        )

        # The prefix indices could be updated, reuse it
        new_indices, new_last_node, _, _ = self.match_prefix_with_lora_id(inserted_key)
        self.req_to_token_pool.write(
            (req.req_pool_idx, slice(len(req.prefix_indices), len(new_indices))),
            new_indices[len(req.prefix_indices) :],
        )

        # Move the lock from the previously matched node to the new one.
        self.dec_lock_ref(req.last_node)
        self.inc_lock_ref(new_last_node)

        # `req.prefix_indices` will be used in `PrefillAdder::add_chunked_req` later
        req.prefix_indices = new_indices
        req.last_node = new_last_node

    def pretty_print(self):
        """Print the tree followed by the total number of cached tokens."""
        self._print_helper(self.root_node, 0)
        print(f"#tokens: {self.total_size()}")

    def total_size(self):
        """Return the summed value length of all non-evicted nodes."""
        return self._total_size_helper()

    def evict(self, num_tokens: int):
        """Free unlocked leaves, least recently accessed first.

        Eviction stops once at least ``num_tokens`` values were freed, the
        leaves are exhausted, or the root is reached.
        """
        if self.disable:
            return

        leaves = self._collect_leaves()
        heapq.heapify(leaves)

        num_evicted = 0
        while num_evicted < num_tokens and len(leaves):
            x = heapq.heappop(leaves)

            if x is self.root_node:
                break
            if x.lock_ref > 0:
                continue

            self.token_to_kv_pool_allocator.free(x.value)
            num_evicted += len(x.value)
            self._delete_leaf(x)

            # A parent that lost its last child becomes a leaf itself.
            if len(x.parent.children) == 0:
                heapq.heappush(leaves, x.parent)

    def inc_lock_ref(self, node: LoRATreeNode):
        """Lock ``node`` and all its ancestors below the root.

        Returns:
            The change in evictable size (zero or negative); ``0`` when the
            cache is disabled.
        """
        if self.disable:
            return 0

        delta = 0
        while node is not self.root_node:
            if node.lock_ref == 0:
                # First lock: the node moves from evictable to protected.
                self.evictable_size_ -= len(node.value)
                self.protected_size_ += len(node.value)
                delta -= len(node.value)
            node.lock_ref += 1
            node = node.parent
        return delta

    def dec_lock_ref(self, node: LoRATreeNode):
        """Unlock ``node`` and all its ancestors below the root.

        Returns:
            The change in evictable size (zero or positive); ``0`` when the
            cache is disabled.
        """
        if self.disable:
            return 0

        delta = 0
        while node is not self.root_node:
            if node.lock_ref == 1:
                # Last lock released: the node becomes evictable again.
                self.evictable_size_ += len(node.value)
                self.protected_size_ -= len(node.value)
                delta += len(node.value)
            node.lock_ref -= 1
            node = node.parent
        return delta

    def evictable_size(self):
        return self.evictable_size_

    def protected_size(self):
        # protected size refers to the size of the cache that is locked
        return self.protected_size_

    def all_values_flatten(self):
        """Concatenate the values of every node below the root.

        Returns an empty ``int64`` tensor when the tree holds no nodes, since
        ``torch.cat`` rejects an empty list.
        """
        values = []

        def _dfs_helper(node: LoRATreeNode):
            for child in node.children.values():
                values.append(child.value)
                _dfs_helper(child)

        _dfs_helper(self.root_node)
        if not values:
            return torch.empty((0,), dtype=torch.int64, device=self.device)
        return torch.cat(values)

    ##### Internal Helper Functions #####

    def _match_prefix_helper(self, node: LoRATreeNode, key: LoRAKey):
        """Walk down from ``node`` along ``key``.

        Returns the list of matched node values and the last matched node.
        A partially matched child is split so the match ends on a node
        boundary.
        """
        node.last_access_time = time.monotonic()

        child_key = self.get_child_key_fn(key)

        value = []
        while len(key) > 0 and child_key in node.children.keys():
            child = node.children[child_key]
            child.last_access_time = time.monotonic()
            prefix_len = self.key_match_fn(child.key, key)
            if prefix_len < len(child.key):
                new_node = self._split_node(child.key, child, prefix_len)
                value.append(new_node.value)
                node = new_node
                break
            else:
                value.append(child.value)
                node = child
                key = LoRAKey(lora_id=key.lora_id, token_ids=key.token_ids[prefix_len:])

                if len(key):
                    child_key = self.get_child_key_fn(key)

        return value, node

    def _split_node(self, key: LoRAKey, child: LoRATreeNode, split_len: int):
        """Split ``child`` at ``split_len`` and return the new upper node."""
        # new_node -> child
        new_node = LoRATreeNode()
        key_split_1 = LoRAKey(lora_id=key.lora_id, token_ids=key.token_ids[:split_len])
        key_split_2 = LoRAKey(lora_id=key.lora_id, token_ids=key.token_ids[split_len:])
        new_node.children = {self.get_child_key_fn(key_split_2): child}
        new_node.parent = child.parent
        new_node.lock_ref = child.lock_ref
        new_node.key = key_split_1
        new_node.value = child.value[:split_len]
        child.parent = new_node
        child.key = key_split_2
        child.value = child.value[split_len:]
        # The new node takes over the slot the child occupied in its parent.
        new_node.parent.children[self.get_child_key_fn(key)] = new_node

        return new_node

    def _insert_helper(self, node: LoRATreeNode, key: LoRAKey, value):
        """Insert ``key``/``value`` below ``node``.

        Returns the number of leading tokens that were already present.
        """
        node.last_access_time = time.monotonic()
        if len(key) == 0:
            return 0

        child_key = self.get_child_key_fn(key)

        total_prefix_length = 0
        while len(key) > 0 and child_key in node.children.keys():
            node = node.children[child_key]
            node.last_access_time = time.monotonic()
            prefix_len = self.key_match_fn(node.key, key)
            total_prefix_length += prefix_len
            key = LoRAKey(lora_id=key.lora_id, token_ids=key.token_ids[prefix_len:])
            value = value[prefix_len:]

            if prefix_len < len(node.key):
                new_node = self._split_node(node.key, node, prefix_len)
                node = new_node

            if len(key):
                child_key = self.get_child_key_fn(key)

        if len(key):
            # Whatever is left of the key becomes a fresh leaf.
            new_node = LoRATreeNode()
            new_node.parent = node
            new_node.key = key
            new_node.value = value
            node.children[child_key] = new_node
            self.evictable_size_ += len(value)
        return total_prefix_length

    def _print_helper(self, node: LoRATreeNode, indent: int):
        """Prints the radix tree in a human-readable format."""
        stack = [(node, indent)]
        while stack:
            current_node, current_indent = stack.pop()
            print(
                " " * current_indent,
                len(current_node.key),
                current_node.key.token_ids[:10],
                f"r={current_node.lock_ref}",
            )
            for key, child in current_node.children.items():
                stack.append((child, current_indent + 2))

                assert key == self.get_child_key_fn(
                    child.key
                ), f"{key=}, {self.get_child_key_fn(child.key)=}"

    def _delete_leaf(self, node):
        """Detach ``node`` from its parent and update the evictable size."""
        siblings = node.parent.children
        for child_key, child in siblings.items():
            if child is node:
                # Deleting and leaving the loop at once keeps iteration safe
                # and guarantees only the matching entry is removed.
                del siblings[child_key]
                break
        self.evictable_size_ -= len(node.key)

    def _total_size_helper(self):
        total_size = 0
        stack = [self.root_node]
        while stack:
            current_node = stack.pop()
            # The root is created with value None, which has no length.
            if current_node.value is not None:
                total_size += len(current_node.value)
            for child in current_node.children.values():
                if child.evicted:
                    continue
                stack.append(child)
        return total_size

    def _collect_leaves(self):
        """Return every node without children (the root if the tree is empty)."""
        ret_list = []
        stack = [self.root_node]

        while stack:
            cur_node = stack.pop()
            if len(cur_node.children) == 0:
                ret_list.append(cur_node)
            else:
                stack.extend(cur_node.children.values())

        return ret_list
|
||||
@@ -23,7 +23,7 @@ import heapq
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from functools import partial
|
||||
from typing import TYPE_CHECKING, List, Optional
|
||||
from typing import TYPE_CHECKING, Any, Iterator, List, Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
@@ -41,6 +41,30 @@ if TYPE_CHECKING:
|
||||
from sglang.srt.managers.schedule_batch import Req
|
||||
|
||||
|
||||
class RadixKey:
    """Radix-tree key made of a token-id sequence and an optional extra key.

    Length and iteration operate on the token ids only. Indexing always
    returns another ``RadixKey`` that carries the same ``extra_key``: a slice
    yields the sliced tokens, an integer index yields a one-token key.
    """

    def __init__(self, token_ids: List[int], extra_key: Optional[str] = None):
        # token ids sequence
        self.token_ids = token_ids
        # extra key (e.g. lora_id, cache_salt)
        self.extra_key = extra_key

    def __len__(self) -> int:
        return len(self.token_ids)

    def __iter__(self) -> Iterator[int]:
        return iter(self.token_ids)

    def __getitem__(self, idx: Union[int, slice]) -> "RadixKey":
        if isinstance(idx, slice):
            return RadixKey(self.token_ids[idx], self.extra_key)
        # Wrap the single token so the result is still a RadixKey.
        return RadixKey([self.token_ids[idx]], self.extra_key)

    def __repr__(self) -> str:
        # Show at most ten tokens; mark truncation with an ellipsis.
        preview = self.token_ids[:10]
        return f"RadixKey(extra_key={self.extra_key!r}, token_ids={preview}{'...' if len(self.token_ids) > 10 else ''})"
|
||||
|
||||
|
||||
class TreeNode:
|
||||
|
||||
counter = 0
|
||||
@@ -48,7 +72,7 @@ class TreeNode:
|
||||
def __init__(self, id: Optional[int] = None):
|
||||
self.children = defaultdict(TreeNode)
|
||||
self.parent: TreeNode = None
|
||||
self.key: List[int] = None
|
||||
self.key: RadixKey = None
|
||||
self.value: Optional[torch.Tensor] = None
|
||||
self.lock_ref = 0
|
||||
self.last_access_time = time.monotonic()
|
||||
@@ -94,27 +118,47 @@ class TreeNode:
|
||||
return self.last_access_time < other.last_access_time
|
||||
|
||||
|
||||
def _key_match_page_size1(key0: List, key1: List):
|
||||
def _check_extra_key(key0: RadixKey, key1: RadixKey):
|
||||
if key0.extra_key != key1.extra_key:
|
||||
raise ValueError(
|
||||
f"_key_match should be run on the same extra key, but got key0.extra_key={key0.extra_key} != key1.extra_key={key1.extra_key}"
|
||||
)
|
||||
|
||||
|
||||
def _key_match_page_size1(key0: "RadixKey", key1: "RadixKey"):
    """Return the length of the common token prefix of two keys.

    The keys must share the same ``extra_key``; ``_check_extra_key`` raises
    ``ValueError`` otherwise.
    """
    _check_extra_key(key0, key1)
    matched = 0
    # Compare the raw token lists directly, token by token.
    for tok0, tok1 in zip(key0.token_ids, key1.token_ids):
        if tok0 != tok1:
            break
        matched += 1
    return matched
|
||||
|
||||
|
||||
def _key_match_paged(key0: List, key1: List, page_size: int):
|
||||
def _key_match_paged(key0: "RadixKey", key1: "RadixKey", page_size: int):
    """Return the matched prefix length of two keys, compared page by page.

    The keys must share the same ``extra_key``; ``_check_extra_key`` raises
    ``ValueError`` otherwise. Matching advances in steps of ``page_size`` and
    stops at the first page whose tokens differ.
    """
    _check_extra_key(key0, key1)
    min_len = min(len(key0), len(key1))

    # NOTE(review): when both keys end in an identical partial page the
    # returned value can exceed min_len; callers presumably pass page-aligned
    # keys -- confirm.
    i = 0
    while i < min_len:
        # Compare token lists; slicing the key objects themselves would
        # compare wrapper instances instead of tokens.
        if key0.token_ids[i : i + page_size] != key1.token_ids[i : i + page_size]:
            break
        i += page_size

    return i
|
||||
|
||||
|
||||
def get_child_key(key: "RadixKey", page_size: int = 1):
    """Return the dictionary key under which ``key`` is stored in a node's children.

    Args:
        key: A non-empty key exposing ``token_ids`` and ``extra_key``.
        page_size: Number of leading tokens that identify a child. With
            ``page_size == 1`` the first token itself is used, otherwise a
            tuple of the first ``page_size`` tokens.

    Returns:
        The plain token key when ``extra_key`` is ``None``; otherwise the
        pair ``(extra_key, plain_key)``, so keys with different extra keys
        never share a child slot.
    """
    if page_size == 1:
        plain_key = key.token_ids[0]
    else:
        plain_key = tuple(key.token_ids[:page_size])

    if key.extra_key is None:
        return plain_key
    return (key.extra_key, plain_key)
|
||||
|
||||
|
||||
class RadixCache(BasePrefixCache):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -139,10 +183,10 @@ class RadixCache(BasePrefixCache):
|
||||
|
||||
if self.page_size == 1:
|
||||
self.key_match_fn = _key_match_page_size1
|
||||
self.get_child_key_fn = lambda key: key[0]
|
||||
self.get_child_key_fn = get_child_key
|
||||
else:
|
||||
self.key_match_fn = partial(_key_match_paged, page_size=page_size)
|
||||
self.get_child_key_fn = lambda key: tuple(key[:page_size])
|
||||
self.get_child_key_fn = partial(get_child_key, page_size=page_size)
|
||||
|
||||
if eviction_policy.lower() == "lru":
|
||||
self.eviction_strategy: EvictionStrategy = LRUStrategy()
|
||||
@@ -158,7 +202,7 @@ class RadixCache(BasePrefixCache):
|
||||
|
||||
def reset(self):
|
||||
self.root_node = TreeNode()
|
||||
self.root_node.key = []
|
||||
self.root_node.key = RadixKey(token_ids=[], extra_key=None)
|
||||
self.root_node.value = []
|
||||
self.root_node.host_value = []
|
||||
self.root_node.lock_ref = 1
|
||||
@@ -166,16 +210,43 @@ class RadixCache(BasePrefixCache):
|
||||
self.protected_size_ = 0
|
||||
self._record_all_cleared_event()
|
||||
|
||||
def match_prefix(self, key: List[int], **kwargs) -> MatchResult:
|
||||
"""Find the matching prefix from the radix tree.
|
||||
def match_prefix(self, key: RadixKey, **kwargs) -> MatchResult:
|
||||
"""Find the longest cached prefix of ``key`` in the radix tree.
|
||||
|
||||
The logical namespace for prefix matching is determined by both the
|
||||
token id sequence and the optional ``extra_key`` carried by ``RadixKey``.
|
||||
Entries that share identical leading token ids but have *different*
|
||||
``extra_key`` values are intentionally kept disjoint and never share
|
||||
prefix nodes. This is useful to:
|
||||
|
||||
* Isolate KV cache lines for different LoRA / adapter IDs.
|
||||
* Separate requests that intentionally should not share state (e.g.,
|
||||
different sampling salt, cache version, or retrieval augmentation
|
||||
context) by supplying a distinct ``extra_key``.
|
||||
|
||||
Args:
|
||||
key: A list of token IDs to find a matching prefix.
|
||||
key (RadixKey): The lookup key containing a list of token ids and an
|
||||
optional ``extra_key`` namespace tag. If ``page_size > 1`` the
|
||||
length is internally truncated to a multiple of ``page_size``
|
||||
before matching. Passing an empty key returns an empty result
|
||||
with the root as the last node.
|
||||
**kwargs: Reserved for future extensions (ignored currently).
|
||||
|
||||
Returns:
|
||||
A tuple of a tensor of matching prefix token IDs and
|
||||
the last node that contains the prefix values. Note that
|
||||
this API can modify the internal state of the Radix tree.
|
||||
The last node create a new child if the prefix is shorter
|
||||
than the last node's value.
|
||||
MatchResult: ``device_indices`` is a 1-D ``torch.int64`` tensor of
|
||||
the concatenated KV cache indices corresponding to the longest
|
||||
cached prefix (may be length 0). ``last_device_node`` and
|
||||
``last_host_node`` (currently the same) are the tree node objects
|
||||
representing the terminal node of the matched prefix. This method
|
||||
may mutate internal structure by splitting an existing node if the
|
||||
match ends inside a stored segment.
|
||||
|
||||
Internal updates:
|
||||
* Refreshes access metadata (timestamps) used by the
|
||||
configured eviction strategy.
|
||||
* If the lookup ends inside a stored segment the node is split once
|
||||
to expose a precise boundary; this structural refinement improves
|
||||
subsequent match efficiency and does not duplicate data.
|
||||
"""
|
||||
if self.disable or len(key) == 0:
|
||||
return MatchResult(
|
||||
@@ -203,12 +274,12 @@ class RadixCache(BasePrefixCache):
|
||||
last_host_node=last_node,
|
||||
)
|
||||
|
||||
def insert(self, key: List, value=None, chunked=False):
|
||||
def insert(self, key: RadixKey, value=None, chunked=False):
|
||||
if self.disable:
|
||||
return 0
|
||||
|
||||
if value is None:
|
||||
value = [x for x in key]
|
||||
value = torch.tensor(key.token_ids, dtype=torch.int64)
|
||||
return self._insert_helper(self.root_node, key, value)
|
||||
|
||||
def cache_finished_req(self, req: Req):
|
||||
@@ -238,7 +309,8 @@ class RadixCache(BasePrefixCache):
|
||||
|
||||
# Radix Cache takes one ref in memory pool
|
||||
new_prefix_len = self.insert(
|
||||
token_ids[:page_aligned_len], page_aligned_kv_indices
|
||||
RadixKey(token_ids[:page_aligned_len], req.extra_key),
|
||||
page_aligned_kv_indices,
|
||||
)
|
||||
self.token_to_kv_pool_allocator.free(
|
||||
kv_indices[len(req.prefix_indices) : new_prefix_len]
|
||||
@@ -270,14 +342,18 @@ class RadixCache(BasePrefixCache):
|
||||
|
||||
# Radix Cache takes one ref in memory pool
|
||||
new_prefix_len = self.insert(
|
||||
page_aligned_token_ids, page_aligned_kv_indices, chunked=chunked
|
||||
RadixKey(page_aligned_token_ids, req.extra_key),
|
||||
page_aligned_kv_indices,
|
||||
chunked=chunked,
|
||||
)
|
||||
self.token_to_kv_pool_allocator.free(
|
||||
kv_indices[len(req.prefix_indices) : new_prefix_len]
|
||||
)
|
||||
|
||||
# The prefix indices could be updated, reuse it
|
||||
new_indices, new_last_node, _, _ = self.match_prefix(page_aligned_token_ids)
|
||||
new_indices, new_last_node, _, _ = self.match_prefix(
|
||||
RadixKey(token_ids=page_aligned_token_ids, extra_key=req.extra_key)
|
||||
)
|
||||
self.req_to_token_pool.write(
|
||||
(req.req_pool_idx, slice(len(req.prefix_indices), len(new_indices))),
|
||||
new_indices[len(req.prefix_indices) :],
|
||||
@@ -379,7 +455,7 @@ class RadixCache(BasePrefixCache):
|
||||
|
||||
##### Internal Helper Functions #####
|
||||
|
||||
def _match_prefix_helper(self, node: TreeNode, key: List):
|
||||
def _match_prefix_helper(self, node: TreeNode, key: RadixKey):
|
||||
node.last_access_time = time.monotonic()
|
||||
|
||||
child_key = self.get_child_key_fn(key)
|
||||
@@ -404,7 +480,7 @@ class RadixCache(BasePrefixCache):
|
||||
|
||||
return value, node
|
||||
|
||||
def _split_node(self, key, child: TreeNode, split_len: int):
|
||||
def _split_node(self, key: RadixKey, child: TreeNode, split_len: int):
|
||||
# new_node -> child
|
||||
self._record_remove_event(child)
|
||||
new_node = TreeNode()
|
||||
@@ -423,7 +499,7 @@ class RadixCache(BasePrefixCache):
|
||||
|
||||
return new_node
|
||||
|
||||
def _insert_helper(self, node: TreeNode, key: List, value):
|
||||
def _insert_helper(self, node: TreeNode, key: RadixKey, value):
|
||||
node.last_access_time = time.monotonic()
|
||||
if len(key) == 0:
|
||||
return 0
|
||||
@@ -464,7 +540,7 @@ class RadixCache(BasePrefixCache):
|
||||
print(
|
||||
" " * current_indent,
|
||||
len(current_node.key),
|
||||
current_node.key[:10],
|
||||
current_node.key.token_ids[:10],
|
||||
f"r={current_node.lock_ref}",
|
||||
)
|
||||
for key, child in current_node.children.items():
|
||||
@@ -516,11 +592,11 @@ class RadixCache(BasePrefixCache):
|
||||
last_page_start = (
|
||||
(len(node.parent.key) - 1) // self.page_size
|
||||
) * self.page_size
|
||||
parent_parent_tokens = node.parent.key[last_page_start:]
|
||||
parent_parent_tokens = node.parent.key.token_ids[last_page_start:]
|
||||
parent_block_hash = hash(tuple(parent_parent_tokens))
|
||||
|
||||
for start in range(0, len(node.key), self.page_size):
|
||||
page_tokens = node.key[start : start + self.page_size]
|
||||
page_tokens = node.key.token_ids[start : start + self.page_size]
|
||||
if not page_tokens:
|
||||
continue
|
||||
|
||||
@@ -543,7 +619,7 @@ class RadixCache(BasePrefixCache):
|
||||
# One BlockRemoved per chunk.
|
||||
if self.enable_kv_cache_events:
|
||||
for start in range(0, len(node.key), self.page_size):
|
||||
page_tokens = node.key[start : start + self.page_size]
|
||||
page_tokens = node.key.token_ids[start : start + self.page_size]
|
||||
if not page_tokens:
|
||||
continue
|
||||
block_hash = hash(tuple(page_tokens))
|
||||
@@ -569,19 +645,12 @@ class RadixCache(BasePrefixCache):
|
||||
if __name__ == "__main__":
|
||||
tree = RadixCache(None, None, page_size=1, disable=False)
|
||||
|
||||
tree.insert("Hello")
|
||||
tree.insert("Hello")
|
||||
tree.insert("Hello_L.A.!")
|
||||
# tree.insert("Hello_world! Happy")
|
||||
# tree.insert("I love you!")
|
||||
# Example token id sequences (as lists of ints)
|
||||
tree.insert(RadixKey(token_ids=[1, 2, 3], extra_key=None))
|
||||
tree.insert(RadixKey(token_ids=[1, 2, 3], extra_key=None))
|
||||
tree.insert(RadixKey(token_ids=[1, 2, 4, 5], extra_key=None))
|
||||
tree.insert(RadixKey(token_ids=[1, 2, 4, 5, 6, 7], extra_key=None))
|
||||
tree.insert(RadixKey(token_ids=[8, 9, 10, 11, 12], extra_key=None))
|
||||
tree.pretty_print()
|
||||
|
||||
# print(tree.match_prefix("I love you! aha"))
|
||||
|
||||
# def evict_callback(x):
|
||||
# print("evict", x)
|
||||
# return len(x)
|
||||
|
||||
# tree.evict(5, evict_callback)
|
||||
# tree.evict(10, evict_callback)
|
||||
# tree.pretty_print()
|
||||
print(tree.match_prefix(RadixKey(token_ids=[1, 2, 3, 13, 14], extra_key=None)))
|
||||
|
||||
@@ -13,6 +13,7 @@ from sglang.srt.mem_cache.cpp_radix_tree.radix_tree import (
|
||||
TreeNodeCpp,
|
||||
)
|
||||
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
|
||||
from sglang.srt.mem_cache.radix_cache import RadixKey
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.managers.schedule_batch import Req
|
||||
@@ -93,9 +94,9 @@ class RadixCacheCpp(BasePrefixCache):
|
||||
raise NotImplementedError("Host cache is not supported yet")
|
||||
self.tree.reset()
|
||||
|
||||
def match_prefix(self, key: List[int], **kwargs) -> MatchResult:
|
||||
def match_prefix(self, key: RadixKey, **kwargs) -> MatchResult:
|
||||
device_indices_vec, host_indices_length, node_gpu, node_cpu = (
|
||||
self.tree.match_prefix(key)
|
||||
self.tree.match_prefix(key.token_ids)
|
||||
)
|
||||
return MatchResult(
|
||||
device_indices=self._merge_tensor(device_indices_vec),
|
||||
@@ -104,16 +105,16 @@ class RadixCacheCpp(BasePrefixCache):
|
||||
host_hit_length=host_indices_length,
|
||||
)
|
||||
|
||||
def _insert(self, key: List[int], value: torch.Tensor) -> int:
|
||||
def _insert(self, key: RadixKey, value: torch.Tensor) -> int:
|
||||
"""
|
||||
Insert a key-value pair into the radix tree.
|
||||
Args:
|
||||
key (List[int]): The key to insert, represented as a list of integers.
|
||||
key (RadixKey): The key to insert, represented as a RadixKey.
|
||||
value (torch.Tensor): The value to associate with the key.
|
||||
Returns:
|
||||
int: Number of device indices that were already present in the tree before the insertion.
|
||||
"""
|
||||
ongoing_write, length = self.tree.writing_through(key, value)
|
||||
ongoing_write, length = self.tree.writing_through(key.token_ids, value)
|
||||
if self.cache_controller is None:
|
||||
assert len(ongoing_write) == 0, "Implementation error"
|
||||
return length
|
||||
@@ -160,7 +161,7 @@ class RadixCacheCpp(BasePrefixCache):
|
||||
# NOTE: our C++ implementation don't need `token_ids` and `kv_indices` to be page-aligned
|
||||
# it will automatically align them, but length of them should be equal
|
||||
old_prefix_len = len(req.prefix_indices) // self.page_size * self.page_size
|
||||
new_prefix_len = self._insert(token_ids, kv_indices)
|
||||
new_prefix_len = self._insert(RadixKey(token_ids, req.extra_key), kv_indices)
|
||||
|
||||
# NOTE: kv_indices[:old_prefix_len] == req.prefix_indices
|
||||
assert old_prefix_len <= new_prefix_len, "Wrong prefix indices"
|
||||
@@ -191,14 +192,16 @@ class RadixCacheCpp(BasePrefixCache):
|
||||
# NOTE: our C++ implementation don't need `token_ids` and `kv_indices` to be page-aligned
|
||||
# it will automatically align them, but length of them should be equal
|
||||
old_prefix_len = len(req.prefix_indices) // self.page_size * self.page_size
|
||||
new_prefix_len = self._insert(token_ids, kv_indices)
|
||||
new_prefix_len = self._insert(RadixKey(token_ids, req.extra_key), kv_indices)
|
||||
|
||||
# NOTE: kv_indices[:old_prefix_len] == req.prefix_indices
|
||||
assert old_prefix_len <= new_prefix_len, "Wrong prefix indices"
|
||||
|
||||
# TODO(dark): optimize the `insert` and `match` (e.g. merge into 1 function)
|
||||
# The prefix indices need to updated to reuse the kv indices in the pool
|
||||
new_indices_vec, _, new_last_node, _ = self.tree.match_prefix(token_ids)
|
||||
new_indices_vec, _, new_last_node, _ = self.tree.match_prefix(
|
||||
RadixKey(token_ids, req.extra_key).token_ids
|
||||
)
|
||||
new_indices = self._merge_tensor(new_indices_vec)
|
||||
assert new_prefix_len <= len(new_indices)
|
||||
|
||||
|
||||
@@ -9,7 +9,7 @@ import torch
|
||||
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
|
||||
from sglang.srt.mem_cache.base_prefix_cache import MatchResult
|
||||
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
|
||||
from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode
|
||||
from sglang.srt.mem_cache.radix_cache import RadixCache, RadixKey, TreeNode
|
||||
|
||||
try:
|
||||
from lmcache.integration.sglang.sglang_adapter import (
|
||||
@@ -131,7 +131,7 @@ class LMCRadixCache(RadixCache):
|
||||
with self._node_lock:
|
||||
self._in_flight_nodes.clear()
|
||||
|
||||
def match_prefix(self, key: List[int], **kwargs) -> MatchResult: # type: ignore[override]
|
||||
def match_prefix(self, key: RadixKey, **kwargs) -> MatchResult: # type: ignore[override]
|
||||
"""Match cached prefix; if there's a tail miss, prefetch from LMCache.
|
||||
|
||||
Reuses the base matching logic to obtain (value, last_node). If there
|
||||
@@ -178,7 +178,7 @@ class LMCRadixCache(RadixCache):
|
||||
with torch.cuda.stream(self.load_stream):
|
||||
num_retrieved = self.lmcache_connector.start_load_kv(
|
||||
LoadMetadata(
|
||||
token_ids=key, # full page-aligned key
|
||||
token_ids=key.token_ids, # full page-aligned key
|
||||
slot_mapping=slot_mapping,
|
||||
offset=value.numel() - prefix_pad, # LMCache offset convention
|
||||
)
|
||||
@@ -227,7 +227,7 @@ class LMCRadixCache(RadixCache):
|
||||
req.req_pool_idx, : len(token_ids)
|
||||
]
|
||||
|
||||
_, new_last_node, _, _ = self.match_prefix(token_ids)
|
||||
_, new_last_node, _, _ = self.match_prefix(RadixKey(token_ids, req.extra_key))
|
||||
assert new_last_node is not None
|
||||
|
||||
self.inc_lock_ref(new_last_node)
|
||||
@@ -277,6 +277,8 @@ if __name__ == "__main__":
|
||||
rank=0,
|
||||
tp_group=None,
|
||||
)
|
||||
cache.insert([1, 2, 3], torch.tensor([10, 11, 12], dtype=torch.int64))
|
||||
cache.insert([1, 2, 3, 4], torch.tensor([10, 11, 12, 13], dtype=torch.int64))
|
||||
cache.insert(RadixKey([1, 2, 3]), torch.tensor([10, 11, 12], dtype=torch.int64))
|
||||
cache.insert(
|
||||
RadixKey([1, 2, 3, 4]), torch.tensor([10, 11, 12, 13], dtype=torch.int64)
|
||||
)
|
||||
cache.pretty_print()
|
||||
|
||||
@@ -30,6 +30,12 @@ import torch
|
||||
from sglang.srt.mem_cache.allocator import SWATokenToKVPoolAllocator
|
||||
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache, MatchResult
|
||||
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
|
||||
from sglang.srt.mem_cache.radix_cache import (
|
||||
RadixKey,
|
||||
_key_match_page_size1,
|
||||
_key_match_paged,
|
||||
get_child_key,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.managers.schedule_batch import Req
|
||||
@@ -47,7 +53,7 @@ class TreeNode:
|
||||
def __init__(self, id: Optional[int] = None):
|
||||
self.children = defaultdict(TreeNode)
|
||||
self.parent: TreeNode = None
|
||||
self.key: List[int] = None
|
||||
self.key: RadixKey = None
|
||||
self.value: Optional[torch.Tensor] = None
|
||||
# swa_tombstone is used to indicate the kv indices have been freed for swa layers
|
||||
self.swa_tombstone = False
|
||||
@@ -87,27 +93,6 @@ class TreeNode:
|
||||
return self.last_access_time < other.last_access_time
|
||||
|
||||
|
||||
def _key_match_page_size1(key0: List, key1: List):
|
||||
i = 0
|
||||
for k0, k1 in zip(key0, key1):
|
||||
if k0 != k1:
|
||||
break
|
||||
i += 1
|
||||
return i
|
||||
|
||||
|
||||
def _key_match_paged(key0: List, key1: List, page_size: int):
|
||||
min_len = min(len(key0), len(key1))
|
||||
|
||||
i = 0
|
||||
while i < min_len:
|
||||
if key0[i : i + page_size] != key1[i : i + page_size]:
|
||||
break
|
||||
i += page_size
|
||||
|
||||
return i
|
||||
|
||||
|
||||
def gen_swa_uuid() -> int:
|
||||
TreeNode.swa_uuid_counter += 1
|
||||
return TreeNode.swa_uuid_counter
|
||||
@@ -356,10 +341,10 @@ class SWARadixCache(BasePrefixCache):
|
||||
|
||||
if self.page_size == 1:
|
||||
self.key_match_fn = _key_match_page_size1
|
||||
self.get_child_key_fn = lambda key: key[0]
|
||||
self.get_child_key_fn = get_child_key
|
||||
else:
|
||||
self.key_match_fn = partial(_key_match_paged, page_size=page_size)
|
||||
self.get_child_key_fn = lambda key: tuple(key[:page_size])
|
||||
self.get_child_key_fn = partial(get_child_key, page_size=page_size)
|
||||
|
||||
self.sliding_window_size = sliding_window_size
|
||||
self.reset()
|
||||
@@ -380,10 +365,10 @@ class SWARadixCache(BasePrefixCache):
|
||||
self.full_lru_list = LRUList(swa=False)
|
||||
self.swa_lru_list = LRUList(swa=True)
|
||||
|
||||
def match_prefix(self, key: List[int], **kwargs) -> MatchResult:
|
||||
def match_prefix(self, key: RadixKey, **kwargs) -> MatchResult:
|
||||
"""Find the matching prefix from the radix tree.
|
||||
Args:
|
||||
key: A list of token IDs to find a matching prefix.
|
||||
key: A RadixKey contains token IDs to find a matching prefix.
|
||||
Returns:
|
||||
A tuple of a tensor of matching prefix token IDs and
|
||||
the last node that contains the prefix values. Note that
|
||||
@@ -417,12 +402,12 @@ class SWARadixCache(BasePrefixCache):
|
||||
last_host_node=last_node,
|
||||
)
|
||||
|
||||
def insert(self, key: List, value=None, prev_prefix_len: int = 0) -> int:
|
||||
def insert(self, key: RadixKey, value=None, prev_prefix_len: int = 0) -> int:
|
||||
if self.disable:
|
||||
return 0
|
||||
|
||||
if value is None:
|
||||
value = [x for x in key]
|
||||
value = torch.tensor([x for x in key.token_ids], dtype=torch.int64)
|
||||
return self._insert_helper(self.root_node, key, value, prev_prefix_len)
|
||||
|
||||
def cache_finished_req(self, req: Req) -> None:
|
||||
@@ -453,7 +438,7 @@ class SWARadixCache(BasePrefixCache):
|
||||
# insert the token_ids and kv_indices into the radix tree
|
||||
# Note: the insert function already frees the overlapped kv_indices
|
||||
new_prefix_len = self.insert(
|
||||
token_ids[:page_aligned_len],
|
||||
RadixKey(token_ids[:page_aligned_len], req.extra_key),
|
||||
page_aligned_kv_indices,
|
||||
len(req.prefix_indices),
|
||||
)
|
||||
@@ -489,11 +474,15 @@ class SWARadixCache(BasePrefixCache):
|
||||
# Radix Cache takes one ref in memory pool
|
||||
# Note: the insert function already frees the overlapped kv_indices
|
||||
new_prefix_len = self.insert(
|
||||
page_aligned_token_ids, page_aligned_kv_indices, len(req.prefix_indices)
|
||||
RadixKey(page_aligned_token_ids, req.extra_key),
|
||||
page_aligned_kv_indices,
|
||||
len(req.prefix_indices),
|
||||
)
|
||||
|
||||
# The prefix indices could be updated, reuse it
|
||||
new_indices, new_last_node, _, _ = self.match_prefix(page_aligned_token_ids)
|
||||
new_indices, new_last_node, _, _ = self.match_prefix(
|
||||
RadixKey(page_aligned_token_ids, req.extra_key)
|
||||
)
|
||||
assert len(req.prefix_indices) <= len(
|
||||
new_indices
|
||||
), f"{req.prefix_indices=}, {new_indices=}"
|
||||
@@ -732,7 +721,9 @@ class SWARadixCache(BasePrefixCache):
|
||||
|
||||
##### Internal Helper Functions #####
|
||||
|
||||
def _match_prefix_helper(self, key: List) -> Tuple[List[torch.Tensor], TreeNode]:
|
||||
def _match_prefix_helper(
|
||||
self, key: RadixKey
|
||||
) -> Tuple[List[torch.Tensor], TreeNode]:
|
||||
"""
|
||||
SWA prefix matching helper. It factors in the sliding window size such that
|
||||
the matched node is guaranteed to either 1. connected to root without swa tombstone,
|
||||
@@ -796,7 +787,7 @@ class SWARadixCache(BasePrefixCache):
|
||||
|
||||
return value[:best_value_len], best_last_node
|
||||
|
||||
def _split_node(self, key: List[int], child: TreeNode, split_len: int) -> TreeNode:
|
||||
def _split_node(self, key: RadixKey, child: TreeNode, split_len: int) -> TreeNode:
|
||||
# new_node -> child
|
||||
new_node = TreeNode()
|
||||
new_node.children = {self.get_child_key_fn(key[split_len:]): child}
|
||||
@@ -831,7 +822,7 @@ class SWARadixCache(BasePrefixCache):
|
||||
return new_node
|
||||
|
||||
def _insert_helper(
|
||||
self, node: TreeNode, key: List, value, update_kv_after_len: int
|
||||
self, node: TreeNode, key: RadixKey, value, update_kv_after_len: int
|
||||
) -> int:
|
||||
# Update the last access time from root to leaf, so that
|
||||
# swa will tombstone the node closer to root first
|
||||
|
||||
Reference in New Issue
Block a user