Refactors radix cache for extra key support (#10317)
Signed-off-by: Xinyuan Tong <xinyuantong.cs@gmail.com>
This commit is contained in:
@@ -19,7 +19,7 @@ from sglang.srt.mem_cache.memory_pool_host import (
|
||||
MHATokenToKVPoolHost,
|
||||
MLATokenToKVPoolHost,
|
||||
)
|
||||
from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode
|
||||
from sglang.srt.mem_cache.radix_cache import RadixCache, RadixKey, TreeNode
|
||||
from sglang.srt.metrics.collector import StorageMetricsCollector
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -570,7 +570,9 @@ class HiRadixCache(RadixCache):
|
||||
written_indices = host_indices[:min_completed_tokens]
|
||||
matched_length = self._insert_helper_host(
|
||||
last_host_node,
|
||||
fetched_token_ids,
|
||||
RadixKey(
|
||||
token_ids=fetched_token_ids, extra_key=last_host_node.key.extra_key
|
||||
),
|
||||
written_indices,
|
||||
hash_value[: min_completed_tokens // self.page_size],
|
||||
)
|
||||
@@ -592,7 +594,7 @@ class HiRadixCache(RadixCache):
|
||||
|
||||
return True
|
||||
|
||||
def match_prefix(self, key: List[int], **kwargs):
|
||||
def match_prefix(self, key: RadixKey, **kwargs):
|
||||
empty_value = torch.empty((0,), dtype=torch.int64, device=self.device)
|
||||
if self.disable or len(key) == 0:
|
||||
return MatchResult(
|
||||
@@ -666,7 +668,9 @@ class HiRadixCache(RadixCache):
|
||||
)
|
||||
self.cache_controller.prefetch_tokens_occupied += len(new_input_tokens)
|
||||
|
||||
def _insert_helper_host(self, node: TreeNode, key: List, host_value, hash_value):
|
||||
def _insert_helper_host(
|
||||
self, node: TreeNode, key: RadixKey, host_value, hash_value
|
||||
):
|
||||
node.last_access_time = time.monotonic()
|
||||
if len(key) == 0:
|
||||
return 0
|
||||
@@ -700,7 +704,7 @@ class HiRadixCache(RadixCache):
|
||||
node.children[child_key] = new_node
|
||||
return matched_length
|
||||
|
||||
def _match_prefix_helper(self, node: TreeNode, key: List):
|
||||
def _match_prefix_helper(self, node: TreeNode, key: RadixKey):
|
||||
node.last_access_time = time.monotonic()
|
||||
child_key = self.get_child_key_fn(key)
|
||||
value = []
|
||||
@@ -726,7 +730,7 @@ class HiRadixCache(RadixCache):
|
||||
|
||||
return value, node
|
||||
|
||||
def _split_node(self, key, child: TreeNode, split_len: int):
|
||||
def _split_node(self, key: RadixKey, child: TreeNode, split_len: int):
|
||||
# child node split into new_node -> child
|
||||
new_node = TreeNode()
|
||||
new_node.children = {self.get_child_key_fn(key[split_len:]): child}
|
||||
@@ -753,7 +757,7 @@ class HiRadixCache(RadixCache):
|
||||
new_node.parent.children[self.get_child_key_fn(key)] = new_node
|
||||
return new_node
|
||||
|
||||
def insert(self, key: List, value, chunked=False):
|
||||
def insert(self, key: RadixKey, value=None, chunked=False):
|
||||
if len(key) == 0:
|
||||
return 0
|
||||
|
||||
@@ -811,7 +815,7 @@ class HiRadixCache(RadixCache):
|
||||
for idx in range(0, len(key), self.page_size):
|
||||
new_node.hash_value.append(
|
||||
self.cache_controller.get_hash_str(
|
||||
key[idx : idx + self.page_size],
|
||||
key.token_ids[idx : idx + self.page_size],
|
||||
prior_hash=last_hash,
|
||||
)
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user