# Source: sglang/python/sglang/srt/mem_cache/hiradix_cache.py (snapshot 2025-10-10 00:22:05 -07:00; 950 lines, 35 KiB, Python)

import heapq
import json
import logging
import threading
import time
from typing import List, Optional
import torch
from sglang.srt.managers.cache_controller import HiCacheController, PrefetchOperation
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
from sglang.srt.mem_cache.base_prefix_cache import MatchResult
from sglang.srt.mem_cache.memory_pool import (
MHATokenToKVPool,
MLATokenToKVPool,
ReqToTokenPool,
)
from sglang.srt.mem_cache.memory_pool_host import (
MHATokenToKVPoolHost,
MLATokenToKVPoolHost,
)
from sglang.srt.mem_cache.radix_cache import RadixCache, RadixKey, TreeNode
from sglang.srt.metrics.collector import StorageMetricsCollector
# Module-level logger named after this module; used below for warnings,
# errors and debug traces of the hierarchical cache.
logger = logging.getLogger(__name__)
class HiRadixCache(RadixCache):
def __init__(
    self,
    req_to_token_pool: ReqToTokenPool,
    token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
    tp_cache_group: torch.distributed.ProcessGroup,
    page_size: int,
    hicache_ratio: float,
    hicache_size: int,
    hicache_write_policy: str,
    hicache_io_backend: str,
    hicache_mem_layout: str,
    enable_metrics: bool,
    eviction_policy: str = "lru",
    hicache_storage_backend: Optional[str] = None,
    hicache_storage_prefetch_policy: Optional[str] = "best_effort",
    model_name: Optional[str] = None,
    storage_backend_extra_config: Optional[str] = None,
    is_eagle: bool = False,
):
    """Set up the host KV pool, the cache controller and all bookkeeping.

    The base ``RadixCache`` initializer is invoked last, after every
    attribute of this subclass has been assigned.

    Args:
        req_to_token_pool: Forwarded to the base class.
        token_to_kv_pool_allocator: Device allocator; its KV cache decides
            which host pool flavour (MHA or MLA) is created.
        tp_cache_group: Process group used for cross-rank synchronization.
        page_size: Number of tokens per page.
        hicache_ratio, hicache_size, hicache_mem_layout: Forwarded to the
            host pool constructor.
        hicache_write_policy: Forwarded to the controller; also selects the
            hit-count threshold that triggers a write to host.
        hicache_io_backend: Forwarded to the controller. With ``"direct"``
            the ``"page_first"`` layout is replaced by
            ``"page_first_direct"``.
        enable_metrics: Storage metrics are collected only if a storage
            backend is configured as well.
        eviction_policy, is_eagle: Forwarded to the base class.
        hicache_storage_backend: Storage backend name, or ``None`` to
            disable the storage tier.
        hicache_storage_prefetch_policy: Policy consulted by
            ``can_terminate_prefetch``.
        model_name: Forwarded to the controller.
        storage_backend_extra_config: Optional JSON string, see
            ``_parse_storage_backend_extra_config``.

    Raises:
        ValueError: If the device KV cache is neither MHA nor MLA.
    """
    if hicache_io_backend == "direct" and hicache_mem_layout == "page_first":
        hicache_mem_layout = "page_first_direct"
        logger.warning(
            "Page first layout is not supported with direct IO backend, switching to page first direct layout"
        )

    # Pick the host pool implementation that mirrors the device KV cache.
    self.kv_cache = token_to_kv_pool_allocator.get_kvcache()
    if isinstance(self.kv_cache, MHATokenToKVPool):
        host_pool_cls = MHATokenToKVPoolHost
    elif isinstance(self.kv_cache, MLATokenToKVPool):
        host_pool_cls = MLATokenToKVPoolHost
    else:
        raise ValueError("HiRadixCache only supports MHA and MLA yet")
    self.token_to_kv_pool_host = host_pool_cls(
        self.kv_cache,
        hicache_ratio,
        hicache_size,
        page_size,
        hicache_mem_layout,
    )

    self.tp_group = tp_cache_group
    self.tp_world_size = torch.distributed.get_world_size(group=self.tp_group)
    self.enable_storage = hicache_storage_backend is not None
    self.enable_storage_metrics = self.enable_storage and enable_metrics

    parsed = self._parse_storage_backend_extra_config(storage_backend_extra_config)
    extra_config, threshold, timeout_base, timeout_per_ki_token, pass_prefix_keys = parsed
    self.prefetch_threshold = threshold
    self.prefetch_timeout_base = timeout_base
    # convert "seconds per 1024 tokens" into "seconds per page"
    self.prefetch_timeout_per_page = page_size / 1024 * timeout_per_ki_token
    self.hicache_storage_pass_prefix_keys = pass_prefix_keys
    # TODO: support more timeout check functions
    self.is_prefetch_timeout = self._prefetch_timeout_check_linear_func
    self.prefetch_stop_policy = hicache_storage_prefetch_policy

    self.load_cache_event = threading.Event()
    self.cache_controller = HiCacheController(
        token_to_kv_pool_allocator,
        self.token_to_kv_pool_host,
        page_size,
        self.tp_group,
        load_cache_event=self.load_cache_event,
        write_policy=hicache_write_policy,
        io_backend=hicache_io_backend,
        storage_backend=hicache_storage_backend,
        prefetch_threshold=self.prefetch_threshold,
        model_name=model_name,
        storage_backend_extra_config=extra_config,
    )
    if self.enable_storage_metrics:
        # TODO: support pp
        self.metrics_collector = StorageMetricsCollector(
            labels={
                "storage_backend": hicache_storage_backend,
                "tp_rank": self.cache_controller.tp_rank,
                "dp_rank": self.cache_controller.dp_rank,
            }
        )

    # nodes with an ongoing write through, keyed by node id
    self.ongoing_write_through = {}
    # end nodes of segments with an ongoing load back, keyed by node id
    self.ongoing_load_back = {}
    # ongoing prefetch requests, keyed by request id
    self.ongoing_prefetch = {}
    # nodes with an ongoing backup to storage, keyed by operation id
    self.ongoing_backup = {}
    # todo: dynamically adjust the threshold
    self.write_through_threshold = 1 if hicache_write_policy == "write_through" else 2
    self.load_back_threshold = 10

    super().__init__(
        req_to_token_pool,
        token_to_kv_pool_allocator,
        page_size,
        disable=False,
        eviction_policy=eviction_policy,
        is_eagle=is_eagle,
    )
def _parse_storage_backend_extra_config(
    self, storage_backend_extra_config: Optional[str]
):
    """Parse the storage backend extra config JSON and split off known keys.

    The keys ``prefetch_threshold``, ``prefetch_timeout_base``,
    ``prefetch_timeout_per_ki_token`` and ``hicache_storage_pass_prefix_keys``
    are removed from the parsed mapping and returned separately; whatever
    remains is returned as the first tuple element.

    Args:
        storage_backend_extra_config: JSON string holding an object, or a
            falsy value (``None`` / ``""``) to use the defaults only.

    Returns:
        tuple: ``(extra_config_dict, prefetch_threshold,
        prefetch_timeout_base, prefetch_timeout_per_ki_token,
        hicache_storage_pass_prefix_keys)``; the two timeouts are floats.

    Raises:
        ValueError: If the JSON is valid but not an object, or if one of the
            numeric options has the wrong type (booleans are rejected).
        Exception: Whatever ``json.loads`` raises on malformed input is
            logged and re-raised unchanged.
    """
    extra_config = {}
    if storage_backend_extra_config:
        try:
            extra_config = json.loads(storage_backend_extra_config)
        except Exception as e:
            logger.error(f"Invalid backend extra config JSON: {e}")
            raise
        # json.loads also accepts lists, strings and numbers; only an object
        # supports the keyed lookups below.
        if not isinstance(extra_config, dict):
            raise ValueError(
                "storage backend extra config must be a JSON object, "
                f"got {type(extra_config).__name__}"
            )

    prefetch_threshold = extra_config.pop("prefetch_threshold", 256)  # tokens
    prefetch_timeout_base = extra_config.pop("prefetch_timeout_base", 1)  # seconds
    prefetch_timeout_per_ki_token = extra_config.pop(
        "prefetch_timeout_per_ki_token", 0.25
    )  # seconds per 1024 tokens
    hicache_storage_pass_prefix_keys = extra_config.pop(
        "hicache_storage_pass_prefix_keys", False
    )

    # bool is a subclass of int, so it has to be excluded explicitly.
    if isinstance(prefetch_threshold, bool) or not isinstance(prefetch_threshold, int):
        raise ValueError(
            f"prefetch_threshold must be int, got {type(prefetch_threshold).__name__}"
        )
    if isinstance(prefetch_timeout_base, bool) or not isinstance(
        prefetch_timeout_base, (int, float)
    ):
        raise ValueError(
            f"prefetch_timeout_base must be number, got {type(prefetch_timeout_base).__name__}"
        )
    if isinstance(prefetch_timeout_per_ki_token, bool) or not isinstance(
        prefetch_timeout_per_ki_token, (int, float)
    ):
        raise ValueError(
            f"prefetch_timeout_per_ki_token must be number, got {type(prefetch_timeout_per_ki_token).__name__}"
        )

    return (
        extra_config,
        prefetch_threshold,
        float(prefetch_timeout_base),
        float(prefetch_timeout_per_ki_token),
        hicache_storage_pass_prefix_keys,
    )
def reset(self):
    """Reset ``TreeNode.counter``, the cache controller, the host pool and the base cache."""
    controller = self.cache_controller
    host_pool = self.token_to_kv_pool_host
    TreeNode.counter = 0
    controller.reset()
    host_pool.clear()
    super().reset()
def get_height(self, node: TreeNode):
    """Return the number of parent links between ``node`` and the root node."""
    root = self.root_node
    depth = 0
    current = node
    while current != root:
        current = current.parent
        depth += 1
    return depth
def clear_storage_backend(self) -> bool:
    """Clear the storage backend if it is enabled and exposes ``clear()``.

    Returns:
        ``True`` when the backend was cleared; ``False`` when storage is
        disabled, the backend has no ``clear`` attribute, or clearing raised.
    """
    if not self.enable_storage:
        logger.warning("Hierarchical cache storage backend is not enabled.")
        return False

    try:
        backend = self.cache_controller.storage_backend
        # Check if the storage backend has a clear method (for nixl backends)
        if not hasattr(backend, "clear"):
            logger.warning(
                f"Storage backend {type(backend).__name__} does not support clear operation."
            )
            return False
        backend.clear()
        logger.info("Hierarchical cache storage backend cleared successfully!")
        return True
    except Exception as e:
        logger.error(f"Failed to clear hierarchical cache storage backend: {e}")
        return False
def write_backup(self, node: TreeNode, write_back=False):
    """Start writing ``node``'s device KV indices to host memory.

    If the first write attempt returns ``None``, host memory is evicted and
    the write is retried once.

    Args:
        node: Node whose ``value`` is written.
        write_back: When true the node is not locked for the duration of the
            write.

    Returns:
        Number of host indices obtained, or 0 if both attempts failed.
    """
    controller = self.cache_controller
    host_indices = controller.write(device_indices=node.value, node_id=node.id)
    if host_indices is None:
        # make room on the host side, then try once more
        self.evict_host(len(node.value))
        host_indices = controller.write(device_indices=node.value, node_id=node.id)
    if host_indices is None:
        return 0

    node.host_value = host_indices
    assert len(node.host_value) > 0
    self.ongoing_write_through[node.id] = node
    if not write_back:
        # no need to lock nodes if write back
        self.inc_lock_ref(node)
    return len(host_indices)
def write_backup_storage(self, node: TreeNode):
    """Submit ``node``'s host copy to the storage backend and protect it meanwhile.

    The node is recorded in ``ongoing_backup`` under the returned operation id
    and its host value is protected via ``protect_host()``.
    """
    if self.hicache_storage_pass_prefix_keys:
        prefix_keys = node.get_prefix_hash_values(node.parent)
    else:
        prefix_keys = None

    operation_id = self.cache_controller.write_storage(
        node.host_value, node.key, node.hash_value, prefix_keys
    )
    self.ongoing_backup[operation_id] = node
    node.protect_host()
def _inc_hit_count(self, node: TreeNode, chunked=False):
    """Bump ``node.hit_count`` and start a host backup once the threshold is reached.

    Nothing happens for chunked requests or under the ``write_back`` policy.
    """
    # skip the hit count update for chunked requests
    if chunked or self.cache_controller.write_policy == "write_back":
        return

    node.hit_count += 1
    if not node.backuped and node.hit_count >= self.write_through_threshold:
        # write to host if the node is not backuped
        self.write_backup(node)
def writing_check(self, write_back=False):
    """Acknowledge finished host writes and update the tracked nodes.

    Args:
        write_back: When true, synchronize on every queued write event,
            drop the acknowledged nodes from ``ongoing_write_through`` and
            clear the ack queue. Otherwise only the leading events that
            already finished are processed, using the minimum count across
            TP ranks so every rank applies the same update.
    """
    controller = self.cache_controller

    if write_back:
        # blocking till all write back complete
        while len(self.ongoing_write_through) > 0:
            for _, finish_event, ack_list in controller.ack_write_queue:
                finish_event.synchronize()
                for ack_id in ack_list:
                    del self.ongoing_write_through[ack_id]
            controller.ack_write_queue.clear()
            assert len(self.ongoing_write_through) == 0
        return

    # NOTE: all ranks has the same ongoing_write_through, can skip sync if empty
    if not self.ongoing_write_through:
        return

    # count the finished events at the head of the queue
    num_finished = 0
    for _, finish_event, _acks in controller.ack_write_queue:
        if not finish_event.query():
            break
        num_finished += 1

    queue_size = torch.tensor(num_finished, dtype=torch.int, device="cpu")
    if self.tp_world_size > 1:
        # synchronize TP workers to make the same update to radix cache
        torch.distributed.all_reduce(
            queue_size,
            op=torch.distributed.ReduceOp.MIN,
            group=self.tp_group,
        )

    for _ in range(int(queue_size.item())):
        _, finish_event, ack_list = controller.ack_write_queue.pop(0)
        finish_event.synchronize()
        for ack_id in ack_list:
            backuped_node = self.ongoing_write_through.pop(ack_id)
            self.dec_lock_ref(backuped_node)
            if self.enable_storage:
                self.write_backup_storage(backuped_node)
def loading_check(self):
    """Release the locks of load-back operations whose events have finished.

    Only the leading finished entries of the controller's load ack queue are
    processed; they are removed from the queue afterwards.
    """
    ack_queue = self.cache_controller.ack_load_queue
    num_finished = 0
    for _, finish_event, ack_list in ack_queue:
        if not finish_event.query():
            # the KV cache loading is still ongoing
            break
        num_finished += 1
        # no need to sync across TP workers as batch forwarding is synced
        for ack_id in ack_list:
            self.dec_lock_ref(self.ongoing_load_back.pop(ack_id))

    # ACK until all events are processed
    del ack_queue[:num_finished]
def evictable_size(self):
    """Return the current value of the ``evictable_size_`` counter."""
    size = self.evictable_size_
    return size
def evict(self, num_tokens: int):
    """Evict device KV entries until roughly ``num_tokens`` tokens are freed.

    Device leaves are visited in eviction-priority order; locked nodes are
    skipped. A node that already has a host backup only loses its device
    value. A node without backup is either deleted (non ``write_back``
    policies) or, under ``write_back``, first written to host and evicted
    from the device once all writes have completed.

    A ``write_back`` node whose host write could not be started
    (``write_backup`` returned 0) is left untouched on the device instead of
    being queued for eviction, because it has no host copy to fall back on.

    Args:
        num_tokens: Target number of tokens to free; fewer may be freed if
            the candidates run out.
    """
    write_back = self.cache_controller.write_policy == "write_back"
    leaves = self._collect_leaves_device()
    eviction_heap = [
        (self.eviction_strategy.get_priority(node), node) for node in leaves
    ]
    heapq.heapify(eviction_heap)

    num_evicted = 0
    write_back_nodes = []
    while num_evicted < num_tokens and eviction_heap:
        _priority, x = heapq.heappop(eviction_heap)
        if x.lock_ref > 0:
            continue

        if x.backuped:
            num_evicted += self._evict_backuped(x)
        elif write_back:
            # write to host if the node is not backuped
            written = self.write_backup(x, write_back=True)
            if written == 0:
                # host write failed: the node is still only on device, so it
                # must neither be counted nor evicted later
                continue
            num_evicted += written
            write_back_nodes.append(x)
        else:
            num_evicted += self._evict_regular(x)

        # the parent becomes a candidate once every child is evicted or
        # pending write-back eviction
        for child in x.parent.children.values():
            if child in write_back_nodes:
                continue
            if not child.evicted:
                break
        else:
            # all children are evicted or no children
            new_priority = self.eviction_strategy.get_priority(x.parent)
            heapq.heappush(eviction_heap, (new_priority, x.parent))

    if write_back:
        self.writing_check(write_back=True)
        for node in write_back_nodes:
            assert node.backuped
            self._evict_backuped(node)
def _evict_backuped(self, node: TreeNode):
    """Drop the device value of a node already written to host; return the freed count."""
    freed = self.cache_controller.evict_device(node.value)
    assert freed > 0
    self.evictable_size_ -= freed
    # the node stays in the tree; only its device value is gone
    node.value = None
    return freed
def _evict_regular(self, node: TreeNode):
    """Free the device indices of a node with no host write initiated and delete the leaf.

    Returns:
        Number of device indices that were freed.
    """
    device_indices = node.value
    self.cache_controller.mem_pool_device_allocator.free(device_indices)
    freed = len(device_indices)
    self._delete_leaf(node)
    return freed
def evict_host(self, num_tokens: int):
    """Free host memory of evicted leaves until ``num_tokens`` tokens are released.

    Leaves are processed in eviction-priority order. A leaf is skipped when
    it is not evicted from the device or when its ``host_ref_counter`` is
    positive. Reaching the root node stops the scan.
    """
    candidates = [
        (self.eviction_strategy.get_priority(leaf), leaf)
        for leaf in self._collect_leaves()
    ]
    heapq.heapify(candidates)

    freed = 0
    while freed < num_tokens and candidates:
        _priority, victim = heapq.heappop(candidates)
        if victim == self.root_node:
            break
        # only evict the host value of evicted nodes; nodes with an ongoing
        # prefetch or backup to storage are protected from eviction
        if not victim.evicted or victim.host_ref_counter > 0:
            continue

        freed += self.cache_controller.evict_host(victim.host_value)

        # unlink the victim from its parent
        parent = victim.parent
        for k, v in parent.children.items():
            if v == victim:
                break
        del parent.children[k]

        if len(parent.children) == 0 and parent.evicted:
            heapq.heappush(
                candidates, (self.eviction_strategy.get_priority(parent), parent)
            )
def load_back(
    self, node: TreeNode, mem_quota: Optional[int] = None
) -> Optional[torch.Tensor]:
    """Load the host copies of ``node`` and its evicted ancestors back to the device.

    All evicted nodes on the path are loaded together or not at all.

    Args:
        node: Last hit node; the walk goes up from here until a node that is
            not evicted is found.
        mem_quota: Optional cap on the number of tokens to load; the value
            returned by locking the ancestor is added to it.

    Returns:
        The device indices that were allocated, or ``None`` when loading is
        skipped (too few tokens, quota exceeded) or device memory is
        insufficient even after evicting.
    """
    # todo: more loading policies
    last_hit_node = node

    # collect the evicted chain bottom-up, then put it in top-down order
    nodes_to_load = []
    cursor = node
    while cursor.evicted:
        assert (
            cursor.backuped
        ), "No backup available on evicted nodes, should not happen"
        nodes_to_load.append(cursor)
        cursor = cursor.parent
    nodes_to_load.reverse()
    ancestor_node = cursor

    # protect the ancestor nodes from eviction
    delta = self.inc_lock_ref(ancestor_node)

    # load it all or not at all
    host_indices = torch.cat([n.host_value for n in nodes_to_load])
    if len(host_indices) < self.load_back_threshold or (
        mem_quota is not None and len(host_indices) > mem_quota + delta
    ):
        # skip loading back if the total size is too small or exceeding the memory quota
        self.dec_lock_ref(ancestor_node)
        return None

    controller = self.cache_controller
    device_indices = controller.load(
        host_indices=host_indices, node_id=last_hit_node.id
    )
    if device_indices is None:
        self.evict(len(host_indices))
        device_indices = controller.load(
            host_indices=host_indices, node_id=last_hit_node.id
        )
    self.dec_lock_ref(ancestor_node)
    if device_indices is None:
        # no sufficient GPU memory to load back KV caches
        return None

    self.ongoing_load_back[last_hit_node.id] = last_hit_node
    # hand each node its slice of the freshly allocated device indices
    offset = 0
    for loaded in nodes_to_load:
        span = len(loaded.host_value)
        loaded.value = device_indices[offset : offset + span]
        offset += span
    self.evictable_size_ += len(device_indices)
    self.inc_lock_ref(last_hit_node)

    return device_indices
def init_load_back(
    self,
    last_node: TreeNode,
    host_hit_length: int,
    mem_quota: Optional[int] = None,
):
    """Try to load an evicted ``last_node`` back to the device.

    Args:
        last_node: Node to load; untouched if it is not evicted.
        host_hit_length: Unused, kept for compatibility.
        mem_quota: Forwarded to ``load_back``.

    Returns:
        ``(loaded_indices, last_node)`` when loading started; otherwise an
        empty int64 tensor together with the closest ancestor that is not
        evicted.
    """
    _ = host_hit_length  # unused, but kept for compatibility

    if last_node.evicted:
        loading_values = self.load_back(last_node, mem_quota)
        if loading_values is not None:
            logger.debug(
                f"loading back {len(loading_values)} tokens for node {last_node.id}"
            )
            return loading_values, last_node

        # loading did not happen: fall back to the nearest on-device ancestor
        while last_node.evicted:
            last_node = last_node.parent

    no_indices = torch.empty((0,), dtype=torch.int64, device=self.device)
    return no_indices, last_node
def ready_to_load_host_cache(self) -> int:
    """
    Tell the cache controller to begin loading the KV cache.
    The controller's return value (the consumer index tracked by the
    schedule batch manager) is passed through.
    """
    consumer_index = self.cache_controller.start_loading()
    return consumer_index
def check_hicache_events(self):
    """Poll write and load completions, then storage queues and storage metrics if enabled."""
    self.writing_check()
    self.loading_check()

    if self.enable_storage:
        self.drain_storage_control_queues()

    if self.enable_storage_metrics:
        stats = self.cache_controller.storage_backend.get_stats()
        self.metrics_collector.log_storage_metrics(stats)
def drain_storage_control_queues(self):
    """
    Handle prefetch revokes, backup acks and host memory releases in one
    pass, so that a single TP synchronization covers all three queues.

    The number of items taken from each queue is the minimum queue size
    across TP ranks.
    """
    controller = self.cache_controller
    queues = (
        controller.prefetch_revoke_queue,
        controller.ack_backup_queue,
        controller.host_mem_release_queue,
    )
    qsizes = torch.tensor([q.qsize() for q in queues], dtype=torch.int)
    if self.tp_world_size > 1:
        torch.distributed.all_reduce(
            qsizes, op=torch.distributed.ReduceOp.MIN, group=self.tp_group
        )
    n_revoke, n_backup, n_release = (int(size) for size in qsizes.tolist())

    # process prefetch revokes
    for _ in range(n_revoke):
        req_id = controller.prefetch_revoke_queue.get()
        info = self.ongoing_prefetch.pop(req_id, None)
        if info is None:
            # the revoked operation already got terminated, nothing to do
            continue
        last_host_node, token_ids, _, _ = info
        last_host_node.release_host()
        controller.prefetch_tokens_occupied -= len(token_ids)

    # process backup acks
    for _ in range(n_backup):
        operation = controller.ack_backup_queue.get()
        backed_up_node = self.ongoing_backup.pop(operation.id, None)
        if backed_up_node is not None:
            backed_up_node.release_host()
        if self.enable_storage_metrics:
            self.metrics_collector.log_backuped_tokens(operation.completed_tokens)

    # release host memory
    released = [controller.host_mem_release_queue.get() for _ in range(n_release)]
    if released:
        controller.mem_pool_host.free(torch.cat(released, dim=0))
# Timeout is linearly increasing with the number of pages
def _prefetch_timeout_check_linear_func(self, operation: PrefetchOperation):
    """Return whether ``operation`` has run longer than its allowed time.

    The allowance is ``prefetch_timeout_base`` plus
    ``prefetch_timeout_per_page`` for every entry of ``operation.hash_value``,
    so with no hash values yet only the base timeout applies.
    """
    elapsed = time.monotonic() - operation.start_time
    allowance = (
        self.prefetch_timeout_base
        + len(operation.hash_value) * self.prefetch_timeout_per_page
    )
    return elapsed > allowance
def can_terminate_prefetch(self, operation: PrefetchOperation):
    """Decide whether the prefetch ``operation`` may be terminated now.

    ``best_effort`` and unknown policies always allow termination.
    ``wait_complete`` requires all pages to be fetched; ``timeout``
    additionally allows termination once ``is_prefetch_timeout`` fires.
    With more than one TP rank, termination needs the condition to hold on
    every rank, or the operation to be already terminated on any rank.
    """
    policy = self.prefetch_stop_policy
    if policy == "best_effort":
        return True

    num_pages = len(operation.hash_value)
    completed = (
        num_pages != 0
        and operation.completed_tokens == num_pages * self.page_size
    )

    if policy == "wait_complete":
        can_terminate = completed
    elif policy == "timeout":
        can_terminate = completed or self.is_prefetch_timeout(operation)
    else:
        # unknown prefetch stop policy, just return True
        return True

    operation_terminated = operation.is_terminated()
    if self.tp_world_size > 1:
        # slot 0 carries "not ready" so that MAX yields 0 only if all ranks are ready
        states = torch.tensor(
            [1 - int(can_terminate), int(operation_terminated)],
            dtype=torch.int,
        )
        torch.distributed.all_reduce(
            states,
            op=torch.distributed.ReduceOp.MAX,
            group=self.tp_group,
        )
        can_terminate = states[0].item() == 0
        operation_terminated = states[1].item() == 1

    # the operation should be terminated if it is already terminated on any TP worker
    # or it meets the termination condition on all TP workers
    return can_terminate or operation_terminated
def check_prefetch_progress(self, req_id: str) -> bool:
    """Finish the prefetch of ``req_id`` if it may be terminated.

    Returns:
        ``True`` when there is nothing (more) to wait for: no tracked
        prefetch, a prefetch without host indices, or a prefetch that was
        terminated and inserted into the tree by this call. ``False`` when
        the prefetch has to keep running.
    """
    if req_id not in self.ongoing_prefetch:
        # there is no ongoing prefetch for this request or it has been revoked
        return True

    # todo: more policies for prefetch progress such as timeout
    # the current policy is to prefetch with best effort and terminate when queuing is over
    last_host_node, token_ids, host_indices, operation = self.ongoing_prefetch[req_id]

    if operation.host_indices is None:
        # prefetch has not been issued due to insufficient host memory
        return True
    if not self.can_terminate_prefetch(operation):
        return False

    controller = self.cache_controller
    completed_tokens, hash_value = controller.terminate_prefetch(operation)
    logger.debug(f"Prefetch {req_id} completed with {completed_tokens} tokens")

    min_completed_tokens = completed_tokens
    if self.tp_world_size > 1:
        # synchronize TP workers to make the same update to hiradix cache
        completed_tokens_tensor = torch.tensor(min_completed_tokens, dtype=torch.int)
        torch.distributed.all_reduce(
            completed_tokens_tensor,
            op=torch.distributed.ReduceOp.MIN,
            group=self.tp_group,
        )
        min_completed_tokens = completed_tokens_tensor.item()

    fetched_key = RadixKey(
        token_ids=token_ids[:min_completed_tokens],
        extra_key=last_host_node.key.extra_key,
    )
    matched_length = self._insert_helper_host(
        last_host_node,
        fetched_key,
        host_indices[:min_completed_tokens],
        hash_value[: min_completed_tokens // self.page_size],
    )

    # the part that matched existing nodes is freed right away
    controller.mem_pool_host.free(host_indices[:matched_length])
    # the part completed locally beyond the cross-rank minimum is queued for release
    controller.append_host_mem_release(
        host_indices[min_completed_tokens:completed_tokens]
    )
    last_host_node.release_host()
    del self.ongoing_prefetch[req_id]
    controller.prefetch_tokens_occupied -= len(token_ids)

    if self.enable_storage_metrics:
        self.metrics_collector.log_prefetched_tokens(
            min_completed_tokens - matched_length
        )
    return True
def match_prefix(self, key: RadixKey, **kwargs):
    """Match ``key`` against the tree and report device and host hits.

    The key is truncated to a multiple of ``page_size`` before matching.

    Returns:
        A ``MatchResult`` holding the concatenated device indices of the
        matched nodes, the deepest matched node that is not evicted, the
        deepest matched node that is backed up, and the summed host value
        length of the evicted nodes at the tail of the match.
    """
    empty_value = torch.empty((0,), dtype=torch.int64, device=self.device)
    key.token_ids = self.key_convert_fn(key.token_ids)
    if self.disable or len(key) == 0:
        root = self.root_node
        return MatchResult(
            device_indices=empty_value,
            last_device_node=root,
            last_host_node=root,
            host_hit_length=0,
        )

    if self.page_size != 1:
        key = key[: len(key) // self.page_size * self.page_size]

    matched_values, last_node = self._match_prefix_helper(self.root_node, key)
    value = torch.cat(matched_values) if matched_values else empty_value

    last_host_node = last_node
    # climb past evicted nodes, summing up what is available on host
    host_hit_length = 0
    while last_node.evicted:
        host_hit_length += len(last_node.host_value)
        last_node = last_node.parent
    # climb until a node with a backup is found
    while not last_host_node.backuped:
        last_host_node = last_host_node.parent

    return MatchResult(
        device_indices=value,
        last_device_node=last_node,
        last_host_node=last_host_node,
        host_hit_length=host_hit_length,
    )
def prefetch_from_storage(
    self,
    req_id: str,
    last_host_node: TreeNode,
    new_input_tokens: List[int],
    last_hash: Optional[str] = None,
    prefix_keys: Optional[List[str]] = None,
):
    """Start prefetching the page-aligned part of ``new_input_tokens`` from storage.

    Nothing is started when storage is disabled, the aligned length is below
    ``prefetch_threshold``, the controller reports rate limiting, or host
    memory cannot be allocated even after a host eviction. On success the
    request is recorded in ``ongoing_prefetch`` and ``last_host_node`` stays
    protected.
    """
    page = self.page_size
    # align the number of fetching tokens to the page size
    prefetch_length = len(new_input_tokens) // page * page
    new_input_tokens = new_input_tokens[:prefetch_length]

    controller = self.cache_controller
    if (
        not self.enable_storage
        or prefetch_length < self.prefetch_threshold
        or controller.prefetch_rate_limited()
    ):
        return

    last_host_node.protect_host()
    host_pool = controller.mem_pool_host
    host_indices = host_pool.alloc(prefetch_length)
    if host_indices is None:
        self.evict_host(prefetch_length)
        host_indices = host_pool.alloc(prefetch_length)
    if host_indices is None:
        # no sufficient host memory for prefetch
        last_host_node.release_host()
        return

    operation = controller.prefetch(
        req_id, host_indices, new_input_tokens, last_hash, prefix_keys
    )
    self.ongoing_prefetch[req_id] = (
        last_host_node,
        new_input_tokens,
        host_indices,
        operation,
    )
    controller.prefetch_tokens_occupied += len(new_input_tokens)
def _insert_helper_host(
    self, node: TreeNode, key: RadixKey, host_value, hash_value
):
    """Insert host-only data below ``node`` and return the already-present length.

    The walk consumes the part of ``key`` that matches existing nodes
    (splitting a node on a partial match). Any remainder becomes a new child
    that has no device value, only ``host_value`` and ``hash_value``.

    Returns:
        Number of leading key elements that matched existing nodes.
    """
    node.last_access_time = time.monotonic()
    if len(key) == 0:
        return 0

    matched_length = 0
    child_key = self.get_child_key_fn(key)
    while len(key) > 0 and child_key in node.children:
        node = node.children[child_key]
        node.last_access_time = time.monotonic()
        prefix_len = self.key_match_fn(node.key, key)

        # drop the matched prefix from all three parallel sequences
        key = key[prefix_len:]
        host_value = host_value[prefix_len:]
        hash_value = hash_value[prefix_len // self.page_size :]
        matched_length += prefix_len

        if prefix_len < len(node.key):
            node = self._split_node(node.key, node, prefix_len)

        if len(key):
            child_key = self.get_child_key_fn(key)

    if len(key):
        leaf = TreeNode()
        leaf.parent = node
        leaf.key = key
        leaf.value = None
        leaf.host_value = host_value
        leaf.hash_value = hash_value
        node.children[child_key] = leaf
    return matched_length
def _match_prefix_helper(self, node: TreeNode, key: RadixKey):
    """Walk down from ``node`` along ``key`` and gather device values.

    Access times are refreshed on the way. A partially matched child is
    split and the walk stops at the new upper node.

    Returns:
        ``(values, last_node)`` where ``values`` lists the ``value`` of every
        matched node that is not evicted, in walk order.
    """
    node.last_access_time = time.monotonic()
    collected = []
    child_key = self.get_child_key_fn(key)

    while len(key) > 0 and child_key in node.children:
        child = node.children[child_key]
        child.last_access_time = time.monotonic()
        prefix_len = self.key_match_fn(child.key, key)

        if prefix_len < len(child.key):
            # partial match: split and finish at the split point
            node = self._split_node(child.key, child, prefix_len)
            if not node.evicted:
                collected.append(node.value)
            break

        if not child.evicted:
            collected.append(child.value)
        node = child
        key = key[prefix_len:]
        if len(key):
            child_key = self.get_child_key_fn(key)

    return collected, node
def _split_node(self, key: RadixKey, child: TreeNode, split_len: int):
    """Split ``child`` at ``split_len`` and return the new upper node.

    The new node takes the first ``split_len`` key elements (and the matching
    part of the device value, host value and hash values where present) and
    becomes the parent of ``child``, which keeps the remainder.
    """
    # child node split into new_node -> child
    parent = child.parent
    page = self.page_size

    upper = TreeNode()
    upper.children = {self.get_child_key_fn(key[split_len:]): child}
    upper.parent = parent
    upper.lock_ref = child.lock_ref
    upper.key = child.key[:split_len]
    upper.hit_count = child.hit_count

    # split value and host value if exists
    if child.evicted:
        upper.value = None
    else:
        upper.value = child.value[:split_len]
        child.value = child.value[split_len:]
    if child.backuped:
        upper.host_value = child.host_value[:split_len]
        child.host_value = child.host_value[split_len:]
    if child.hash_value:
        upper.hash_value = child.hash_value[: split_len // page]
        child.hash_value = child.hash_value[split_len // page :]

    child.parent = upper
    child.key = child.key[split_len:]
    # the new node takes over the child's slot in the parent
    parent.children[self.get_child_key_fn(key)] = upper
    return upper
def insert(self, key: RadixKey, value=None, chunked=False):
    """Insert ``key`` with its device KV indices ``value`` into the tree.

    Matching nodes that were evicted get their device value re-attached from
    ``value``; matching nodes that are still on device get their hit count
    updated. Any unmatched remainder of the key becomes a new leaf, with
    per-page hash values when storage is enabled.

    Args:
        key: Key to insert; its ``token_ids`` are converted in place by
            ``key_convert_fn``.
        value: Device indices aligned with ``key``. Required for a non-empty
            key.
        chunked: Forwarded to ``_inc_hit_count``; hit counts are not updated
            for chunked requests.

    Returns:
        Total length of the matched prefix that was already present on
        device (0 for an empty key).

    Raises:
        TypeError: If ``value`` is ``None`` for a non-empty key. This is
            raised before the tree is touched.
    """
    key.token_ids = self.key_convert_fn(key.token_ids)

    if len(key) == 0:
        return 0

    if value is None:
        # value is sliced and measured below; failing here avoids leaving a
        # half-inserted node (value=None) in the tree
        raise TypeError(
            "HiRadixCache.insert() requires `value` for a non-empty key, got None"
        )

    if self.is_eagle:
        # Make sure the value len equal to the EAGLE bigram key len
        value = value[: len(key)]

    node = self.root_node
    child_key = self.get_child_key_fn(key)
    total_prefix_length = 0

    while len(key) > 0 and child_key in node.children:
        node = node.children[child_key]
        node.last_access_time = time.monotonic()
        prefix_len = self.key_match_fn(node.key, key)

        if prefix_len != len(node.key):
            # partial match, split the node
            node = self._split_node(node.key, node, prefix_len)

        if node.evicted:
            # change the reference if the node is evicted
            # this often happens in the case of KV cache recomputation
            node.value = value[:prefix_len]
            self.evictable_size_ += len(node.value)
        else:
            self._inc_hit_count(node, chunked)
            total_prefix_length += prefix_len

        key = key[prefix_len:]
        value = value[prefix_len:]

        if len(key):
            child_key = self.get_child_key_fn(key)

    if len(key):
        new_node = TreeNode()
        new_node.parent = node
        new_node.key = key
        new_node.value = value
        node.children[child_key] = new_node
        self.evictable_size_ += len(value)

        if self.enable_storage:
            last_hash = node.get_last_hash_value()
            assert (node == self.root_node) or (
                last_hash is not None
            ), "Parent node must have a hash value with storage enabled"
            # each page hash is chained to the hash of the page before it
            new_node.hash_value = []
            for idx in range(0, len(key), self.page_size):
                new_node.hash_value.append(
                    self.cache_controller.get_hash_str(
                        key.token_ids[idx : idx + self.page_size],
                        prior_hash=last_hash,
                    )
                )
                last_hash = new_node.hash_value[-1]

        if self.cache_controller.write_policy != "write_back":
            self._inc_hit_count(new_node, chunked)
    return total_prefix_length
def _collect_leaves_device(self):
    """Return the device-side leaves of the tree.

    A device leaf is a node other than the root that is not evicted and
    whose children (if any) are all evicted. Evicted subtrees are not
    descended into.
    """
    root = self.root_node

    def is_device_leaf(candidate):
        if candidate.evicted or candidate == root:
            return False
        # also true for a node without children
        return all(child.evicted for child in candidate.children.values())

    device_leaves = []
    pending = [root]
    while pending:
        current = pending.pop()
        if is_device_leaf(current):
            device_leaves.append(current)
            continue
        pending.extend(
            child for child in current.children.values() if not child.evicted
        )
    return device_leaves
def release_aborted_request(self, rid: str):
    """Terminate and clean up the prefetch of an aborted request, if any.

    Does nothing when ``rid`` has no tracked prefetch or when the tracked
    operation has no host indices. Otherwise the prefetch is terminated, TP
    ranks are synchronized with a barrier, the host node protection is
    released and the completed host indices are queued for release.
    """
    entry = self.ongoing_prefetch[rid] if rid in self.ongoing_prefetch else None
    if entry is None:
        return

    last_host_node, token_ids, host_indices, operation = entry
    if operation.host_indices is None:
        return

    controller = self.cache_controller
    completed_tokens, _ = controller.terminate_prefetch(operation)
    if self.tp_world_size > 1:
        torch.distributed.barrier(group=self.tp_group)

    last_host_node.release_host()
    del self.ongoing_prefetch[rid]
    controller.append_host_mem_release(host_indices[:completed_tokens])
    controller.prefetch_tokens_occupied -= len(token_ids)