Refactors radix cache for extra key support (#10317)

Signed-off-by: Xinyuan Tong <xinyuantong.cs@gmail.com>
This commit is contained in:
Xinyuan Tong
2025-09-21 11:16:16 -07:00
committed by GitHub
parent fc3e542009
commit 12d6cf18f0
13 changed files with 821 additions and 574 deletions

View File

@@ -19,7 +19,7 @@ from sglang.srt.mem_cache.memory_pool_host import (
MHATokenToKVPoolHost,
MLATokenToKVPoolHost,
)
from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode
from sglang.srt.mem_cache.radix_cache import RadixCache, RadixKey, TreeNode
from sglang.srt.metrics.collector import StorageMetricsCollector
logger = logging.getLogger(__name__)
@@ -570,7 +570,9 @@ class HiRadixCache(RadixCache):
written_indices = host_indices[:min_completed_tokens]
matched_length = self._insert_helper_host(
last_host_node,
fetched_token_ids,
RadixKey(
token_ids=fetched_token_ids, extra_key=last_host_node.key.extra_key
),
written_indices,
hash_value[: min_completed_tokens // self.page_size],
)
@@ -592,7 +594,7 @@ class HiRadixCache(RadixCache):
return True
def match_prefix(self, key: List[int], **kwargs):
def match_prefix(self, key: RadixKey, **kwargs):
empty_value = torch.empty((0,), dtype=torch.int64, device=self.device)
if self.disable or len(key) == 0:
return MatchResult(
@@ -666,7 +668,9 @@ class HiRadixCache(RadixCache):
)
self.cache_controller.prefetch_tokens_occupied += len(new_input_tokens)
def _insert_helper_host(self, node: TreeNode, key: List, host_value, hash_value):
def _insert_helper_host(
self, node: TreeNode, key: RadixKey, host_value, hash_value
):
node.last_access_time = time.monotonic()
if len(key) == 0:
return 0
@@ -700,7 +704,7 @@ class HiRadixCache(RadixCache):
node.children[child_key] = new_node
return matched_length
def _match_prefix_helper(self, node: TreeNode, key: List):
def _match_prefix_helper(self, node: TreeNode, key: RadixKey):
node.last_access_time = time.monotonic()
child_key = self.get_child_key_fn(key)
value = []
@@ -726,7 +730,7 @@ class HiRadixCache(RadixCache):
return value, node
def _split_node(self, key, child: TreeNode, split_len: int):
def _split_node(self, key: RadixKey, child: TreeNode, split_len: int):
# child node split into new_node -> child
new_node = TreeNode()
new_node.children = {self.get_child_key_fn(key[split_len:]): child}
@@ -753,7 +757,7 @@ class HiRadixCache(RadixCache):
new_node.parent.children[self.get_child_key_fn(key)] = new_node
return new_node
def insert(self, key: List, value, chunked=False):
def insert(self, key: RadixKey, value=None, chunked=False):
if len(key) == 0:
return 0
@@ -811,7 +815,7 @@ class HiRadixCache(RadixCache):
for idx in range(0, len(key), self.page_size):
new_node.hash_value.append(
self.cache_controller.get_hash_str(
key[idx : idx + self.page_size],
key.token_ids[idx : idx + self.page_size],
prior_hash=last_hash,
)
)