Co-authored-by: pansicheng <sicheng.pan.chn@gmail.com> Co-authored-by: Zhiqiang Xie <xiezhq@stanford.edu>
950 lines
35 KiB
Python
950 lines
35 KiB
Python
import heapq
|
|
import json
|
|
import logging
|
|
import threading
|
|
import time
|
|
from typing import List, Optional
|
|
|
|
import torch
|
|
|
|
from sglang.srt.managers.cache_controller import HiCacheController, PrefetchOperation
|
|
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
|
|
from sglang.srt.mem_cache.base_prefix_cache import MatchResult
|
|
from sglang.srt.mem_cache.memory_pool import (
|
|
MHATokenToKVPool,
|
|
MLATokenToKVPool,
|
|
ReqToTokenPool,
|
|
)
|
|
from sglang.srt.mem_cache.memory_pool_host import (
|
|
MHATokenToKVPoolHost,
|
|
MLATokenToKVPoolHost,
|
|
)
|
|
from sglang.srt.mem_cache.radix_cache import RadixCache, RadixKey, TreeNode
|
|
from sglang.srt.metrics.collector import StorageMetricsCollector
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class HiRadixCache(RadixCache):
|
|
|
|
def __init__(
    self,
    req_to_token_pool: ReqToTokenPool,
    token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
    tp_cache_group: torch.distributed.ProcessGroup,
    page_size: int,
    hicache_ratio: float,
    hicache_size: int,
    hicache_write_policy: str,
    hicache_io_backend: str,
    hicache_mem_layout: str,
    enable_metrics: bool,
    eviction_policy: str = "lru",
    hicache_storage_backend: Optional[str] = None,
    hicache_storage_prefetch_policy: Optional[str] = "best_effort",
    model_name: Optional[str] = None,
    storage_backend_extra_config: Optional[str] = None,
    is_eagle: bool = False,
):
    """
    Build the hierarchical radix cache: host KV pool, controller, bookkeeping.

    Args:
        req_to_token_pool: forwarded to the ``RadixCache`` initializer.
        token_to_kv_pool_allocator: device allocator; its KV cache decides
            which host pool class (MHA or MLA) is created.
        tp_cache_group: process group used for the TP all-reduces/barriers.
        page_size: page size in tokens.
        hicache_ratio, hicache_size, hicache_mem_layout: forwarded to the
            host pool constructor.
        hicache_write_policy: forwarded to the controller; "write_through"
            lowers the write-through hit threshold to 1 (otherwise 2).
        hicache_io_backend: forwarded to the controller; "direct" remaps the
            "page_first" layout to "page_first_direct".
        enable_metrics: storage metrics are collected only when a storage
            backend is configured as well.
        eviction_policy, is_eagle: forwarded to the ``RadixCache`` initializer.
        hicache_storage_backend: storage backend name, ``None`` disables it.
        hicache_storage_prefetch_policy: stored as the prefetch stop policy.
        model_name: forwarded to the controller.
        storage_backend_extra_config: JSON string, see
            ``_parse_storage_backend_extra_config``.

    Raises:
        ValueError: if the device KV cache is neither an MHA nor an MLA pool.
    """
    # The direct IO backend does not take the plain page-first layout.
    if hicache_io_backend == "direct" and hicache_mem_layout == "page_first":
        hicache_mem_layout = "page_first_direct"
        logger.warning(
            "Page first layout is not supported with direct IO backend, switching to page first direct layout"
        )

    # Pick the host pool class that mirrors the device pool type.
    self.kv_cache = token_to_kv_pool_allocator.get_kvcache()
    if isinstance(self.kv_cache, MHATokenToKVPool):
        host_pool_cls = MHATokenToKVPoolHost
    elif isinstance(self.kv_cache, MLATokenToKVPool):
        host_pool_cls = MLATokenToKVPoolHost
    else:
        raise ValueError(f"HiRadixCache only supports MHA and MLA yet")
    self.token_to_kv_pool_host = host_pool_cls(
        self.kv_cache,
        hicache_ratio,
        hicache_size,
        page_size,
        hicache_mem_layout,
    )

    self.tp_group = tp_cache_group
    self.tp_world_size = torch.distributed.get_world_size(group=self.tp_group)
    self.enable_storage = hicache_storage_backend is not None
    self.enable_storage_metrics = self.enable_storage and enable_metrics

    (
        extra_config,
        self.prefetch_threshold,
        self.prefetch_timeout_base,
        timeout_per_ki_token,
        self.hicache_storage_pass_prefix_keys,
    ) = self._parse_storage_backend_extra_config(storage_backend_extra_config)
    # Convert the "seconds per 1024 tokens" setting into seconds per page.
    self.prefetch_timeout_per_page = page_size / 1024 * timeout_per_ki_token
    # TODO: support more timeout check functions
    self.is_prefetch_timeout = self._prefetch_timeout_check_linear_func
    self.prefetch_stop_policy = hicache_storage_prefetch_policy

    self.load_cache_event = threading.Event()
    self.cache_controller = HiCacheController(
        token_to_kv_pool_allocator,
        self.token_to_kv_pool_host,
        page_size,
        self.tp_group,
        load_cache_event=self.load_cache_event,
        write_policy=hicache_write_policy,
        io_backend=hicache_io_backend,
        storage_backend=hicache_storage_backend,
        prefetch_threshold=self.prefetch_threshold,
        model_name=model_name,
        storage_backend_extra_config=extra_config,
    )
    if self.enable_storage_metrics:
        # TODO: support pp
        self.metrics_collector = StorageMetricsCollector(
            labels={
                "storage_backend": hicache_storage_backend,
                "tp_rank": self.cache_controller.tp_rank,
                "dp_rank": self.cache_controller.dp_rank,
            }
        )

    # node id -> node with an ongoing write through
    self.ongoing_write_through = {}
    # node id -> last node of a segment with an ongoing load back
    self.ongoing_load_back = {}
    # request id -> (last_host_node, token_ids, host_indices, operation)
    self.ongoing_prefetch = {}
    # storage operation id -> node being backed up to storage
    self.ongoing_backup = {}
    # todo: dynamically adjust the threshold
    if hicache_write_policy == "write_through":
        self.write_through_threshold = 1
    else:
        self.write_through_threshold = 2
    self.load_back_threshold = 10

    # NOTE(review): the base initializer runs last, presumably because it
    # relies on the attributes set above -- confirm against RadixCache.
    super().__init__(
        req_to_token_pool,
        token_to_kv_pool_allocator,
        page_size,
        disable=False,
        eviction_policy=eviction_policy,
        is_eagle=is_eagle,
    )
|
|
|
|
def _parse_storage_backend_extra_config(
|
|
self, storage_backend_extra_config: Optional[str]
|
|
):
|
|
"""
|
|
Parse storage backend extra config JSON and extract specific parameters.
|
|
|
|
Args:
|
|
storage_backend_extra_config: JSON string containing extra configuration
|
|
|
|
Returns:
|
|
tuple: (extra_config_dict, prefetch_threshold, prefetch_timeout_base, prefetch_timeout_per_ki_token, hicache_storage_pass_prefix_keys)
|
|
"""
|
|
# Parse extra config JSON if provided
|
|
extra_config = {}
|
|
if storage_backend_extra_config:
|
|
try:
|
|
extra_config = json.loads(storage_backend_extra_config)
|
|
except Exception as e:
|
|
logger.error(f"Invalid backend extra config JSON: {e}")
|
|
raise e
|
|
|
|
prefetch_threshold = extra_config.pop("prefetch_threshold", 256) # tokens
|
|
prefetch_timeout_base = extra_config.pop("prefetch_timeout_base", 1) # seconds
|
|
prefetch_timeout_per_ki_token = extra_config.pop(
|
|
"prefetch_timeout_per_ki_token", 0.25
|
|
) # seconds per 1024 tokens
|
|
hicache_storage_pass_prefix_keys = extra_config.pop(
|
|
"hicache_storage_pass_prefix_keys", False
|
|
)
|
|
|
|
if not isinstance(prefetch_threshold, int):
|
|
raise ValueError(
|
|
f"prefetch_threshold must be int, got {type(prefetch_threshold).__name__}"
|
|
)
|
|
if not isinstance(prefetch_timeout_base, (int, float)):
|
|
raise ValueError(
|
|
f"prefetch_timeout_base must be number, got {type(prefetch_timeout_base).__name__}"
|
|
)
|
|
if not isinstance(prefetch_timeout_per_ki_token, (int, float)):
|
|
raise ValueError(
|
|
f"prefetch_timeout_per_ki_token must be number, got {type(prefetch_timeout_per_ki_token).__name__}"
|
|
)
|
|
|
|
return (
|
|
extra_config,
|
|
prefetch_threshold,
|
|
float(prefetch_timeout_base),
|
|
float(prefetch_timeout_per_ki_token),
|
|
hicache_storage_pass_prefix_keys,
|
|
)
|
|
|
|
def reset(self):
    """Reset the node id counter, the cache controller, the host pool and the base cache."""
    # Restart node numbering before the tree is rebuilt.
    TreeNode.counter = 0
    # Controller first, then the host pool it works on, then the base cache.
    self.cache_controller.reset()
    self.token_to_kv_pool_host.clear()
    super().reset()
|
|
|
|
def get_height(self, node: TreeNode):
|
|
height = 0
|
|
while node != self.root_node:
|
|
node = node.parent
|
|
height += 1
|
|
return height
|
|
|
|
def clear_storage_backend(self) -> bool:
    """
    Clear the storage backend if one is enabled and it offers ``clear``.

    Returns:
        True when the backend was cleared; False when storage is disabled,
        the backend has no ``clear`` method, or clearing raised.
    """
    if not self.enable_storage:
        logger.warning("Hierarchical cache storage backend is not enabled.")
        return False

    try:
        backend = self.cache_controller.storage_backend
        # Check if the storage backend has a clear method (for nixl backends)
        if not hasattr(backend, "clear"):
            logger.warning(
                f"Storage backend {type(backend).__name__} does not support clear operation."
            )
            return False
        backend.clear()
        logger.info("Hierarchical cache storage backend cleared successfully!")
        return True
    except Exception as e:
        # Best effort: report the failure to the caller instead of raising.
        logger.error(f"Failed to clear hierarchical cache storage backend: {e}")
        return False
|
|
|
|
def write_backup(self, node: TreeNode, write_back=False):
|
|
host_indices = self.cache_controller.write(
|
|
device_indices=node.value,
|
|
node_id=node.id,
|
|
)
|
|
if host_indices is None:
|
|
self.evict_host(len(node.value))
|
|
host_indices = self.cache_controller.write(
|
|
device_indices=node.value,
|
|
node_id=node.id,
|
|
)
|
|
if host_indices is not None:
|
|
node.host_value = host_indices
|
|
assert len(node.host_value) > 0
|
|
self.ongoing_write_through[node.id] = node
|
|
if not write_back:
|
|
# no need to lock nodes if write back
|
|
self.inc_lock_ref(node)
|
|
else:
|
|
return 0
|
|
|
|
return len(host_indices)
|
|
|
|
def write_backup_storage(self, node: TreeNode):
|
|
prefix_keys = (
|
|
node.get_prefix_hash_values(node.parent)
|
|
if self.hicache_storage_pass_prefix_keys
|
|
else None
|
|
)
|
|
|
|
operation_id = self.cache_controller.write_storage(
|
|
node.host_value, node.key, node.hash_value, prefix_keys
|
|
)
|
|
self.ongoing_backup[operation_id] = node
|
|
node.protect_host()
|
|
|
|
def _inc_hit_count(self, node: TreeNode, chunked=False):
|
|
# skip the hit count update for chunked requests
|
|
if self.cache_controller.write_policy == "write_back" or chunked:
|
|
return
|
|
node.hit_count += 1
|
|
|
|
if not node.backuped:
|
|
if node.hit_count >= self.write_through_threshold:
|
|
# write to host if the node is not backuped
|
|
self.write_backup(node)
|
|
|
|
def writing_check(self, write_back=False):
|
|
if write_back:
|
|
# blocking till all write back complete
|
|
while len(self.ongoing_write_through) > 0:
|
|
for _, finish_event, ack_list in self.cache_controller.ack_write_queue:
|
|
finish_event.synchronize()
|
|
for ack_id in ack_list:
|
|
del self.ongoing_write_through[ack_id]
|
|
self.cache_controller.ack_write_queue.clear()
|
|
assert len(self.ongoing_write_through) == 0
|
|
return
|
|
|
|
# NOTE: all ranks has the same ongoing_write_through, can skip sync if empty
|
|
if len(self.ongoing_write_through) == 0:
|
|
return
|
|
|
|
finish_count = 0
|
|
for _, finish_event, ack_list in self.cache_controller.ack_write_queue:
|
|
if not finish_event.query():
|
|
break
|
|
finish_count += 1
|
|
queue_size = torch.tensor(finish_count, dtype=torch.int, device="cpu")
|
|
if self.tp_world_size > 1:
|
|
# synchronize TP workers to make the same update to radix cache
|
|
torch.distributed.all_reduce(
|
|
queue_size,
|
|
op=torch.distributed.ReduceOp.MIN,
|
|
group=self.tp_group,
|
|
)
|
|
|
|
finish_count = int(queue_size.item())
|
|
while finish_count > 0:
|
|
_, finish_event, ack_list = self.cache_controller.ack_write_queue.pop(0)
|
|
finish_event.synchronize()
|
|
for ack_id in ack_list:
|
|
backuped_node = self.ongoing_write_through.pop(ack_id)
|
|
self.dec_lock_ref(backuped_node)
|
|
if self.enable_storage:
|
|
self.write_backup_storage(backuped_node)
|
|
finish_count -= 1
|
|
|
|
def loading_check(self):
|
|
finish_count = 0
|
|
for _, finish_event, ack_list in self.cache_controller.ack_load_queue:
|
|
if not finish_event.query():
|
|
# the KV cache loading is still ongoing
|
|
break
|
|
finish_count += 1
|
|
# no need to sync across TP workers as batch forwarding is synced
|
|
for ack_id in ack_list:
|
|
end_node = self.ongoing_load_back.pop(ack_id)
|
|
self.dec_lock_ref(end_node)
|
|
|
|
# ACK until all events are processed
|
|
del self.cache_controller.ack_load_queue[:finish_count]
|
|
|
|
def evictable_size(self):
|
|
return self.evictable_size_
|
|
|
|
def evict(self, num_tokens: int):
|
|
leaves = self._collect_leaves_device()
|
|
eviction_heap = [
|
|
(self.eviction_strategy.get_priority(node), node) for node in leaves
|
|
]
|
|
heapq.heapify(eviction_heap)
|
|
|
|
num_evicted = 0
|
|
write_back_nodes = []
|
|
while num_evicted < num_tokens and len(eviction_heap):
|
|
_priority, x = heapq.heappop(eviction_heap)
|
|
|
|
if x.lock_ref > 0:
|
|
continue
|
|
|
|
if not x.backuped:
|
|
if self.cache_controller.write_policy == "write_back":
|
|
# write to host if the node is not backuped
|
|
num_evicted += self.write_backup(x, write_back=True)
|
|
write_back_nodes.append(x)
|
|
else:
|
|
num_evicted += self._evict_regular(x)
|
|
else:
|
|
num_evicted += self._evict_backuped(x)
|
|
|
|
for child in x.parent.children.values():
|
|
if child in write_back_nodes:
|
|
continue
|
|
if not child.evicted:
|
|
break
|
|
else:
|
|
# all children are evicted or no children
|
|
new_priority = self.eviction_strategy.get_priority(x.parent)
|
|
heapq.heappush(eviction_heap, (new_priority, x.parent))
|
|
|
|
if self.cache_controller.write_policy == "write_back":
|
|
self.writing_check(write_back=True)
|
|
for node in write_back_nodes:
|
|
assert node.backuped
|
|
self._evict_backuped(node)
|
|
|
|
def _evict_backuped(self, node: TreeNode):
|
|
# evict a node already written to host
|
|
num_evicted = self.cache_controller.evict_device(node.value)
|
|
assert num_evicted > 0
|
|
self.evictable_size_ -= num_evicted
|
|
node.value = None
|
|
return num_evicted
|
|
|
|
def _evict_regular(self, node: TreeNode):
|
|
# evict a node not initiated write to host
|
|
self.cache_controller.mem_pool_device_allocator.free(node.value)
|
|
num_evicted = len(node.value)
|
|
self._delete_leaf(node)
|
|
return num_evicted
|
|
|
|
def evict_host(self, num_tokens: int):
|
|
leaves = self._collect_leaves()
|
|
eviction_heap = [
|
|
(self.eviction_strategy.get_priority(node), node) for node in leaves
|
|
]
|
|
heapq.heapify(eviction_heap)
|
|
|
|
num_evicted = 0
|
|
while num_evicted < num_tokens and len(eviction_heap):
|
|
_priority, x = heapq.heappop(eviction_heap)
|
|
if x == self.root_node:
|
|
break
|
|
# only evict the host value of evicted nodes
|
|
if not x.evicted:
|
|
continue
|
|
|
|
# node is protected from eviction as it has ongoing prefetch or backup to storage
|
|
if x.host_ref_counter > 0:
|
|
continue
|
|
|
|
num_evicted += self.cache_controller.evict_host(x.host_value)
|
|
|
|
for k, v in x.parent.children.items():
|
|
if v == x:
|
|
break
|
|
del x.parent.children[k]
|
|
|
|
if len(x.parent.children) == 0 and x.parent.evicted:
|
|
new_priority = self.eviction_strategy.get_priority(x.parent)
|
|
heapq.heappush(eviction_heap, (new_priority, x.parent))
|
|
|
|
def load_back(
|
|
self, node: TreeNode, mem_quota: Optional[int] = None
|
|
) -> Optional[torch.Tensor]:
|
|
# todo: more loading policies
|
|
|
|
last_hit_node = node
|
|
nodes_to_load = []
|
|
while node.evicted:
|
|
assert (
|
|
node.backuped
|
|
), "No backup available on evicted nodes, should not happen"
|
|
nodes_to_load.insert(0, node)
|
|
node = node.parent
|
|
else:
|
|
ancester_node = node
|
|
|
|
# protect the ancestor nodes from eviction
|
|
delta = self.inc_lock_ref(ancester_node)
|
|
|
|
# load it all or not at all
|
|
host_indices = torch.cat([n.host_value for n in nodes_to_load])
|
|
if len(host_indices) < self.load_back_threshold or (
|
|
len(host_indices) > mem_quota + delta if mem_quota is not None else False
|
|
):
|
|
# skip loading back if the total size is too small or exceeding the memory quota
|
|
self.dec_lock_ref(ancester_node)
|
|
return None
|
|
|
|
device_indices = self.cache_controller.load(
|
|
host_indices=host_indices, node_id=last_hit_node.id
|
|
)
|
|
if device_indices is None:
|
|
self.evict(len(host_indices))
|
|
device_indices = self.cache_controller.load(
|
|
host_indices=host_indices, node_id=last_hit_node.id
|
|
)
|
|
self.dec_lock_ref(ancester_node)
|
|
if device_indices is None:
|
|
# no sufficient GPU memory to load back KV caches
|
|
return None
|
|
|
|
self.ongoing_load_back[last_hit_node.id] = last_hit_node
|
|
offset = 0
|
|
for node in nodes_to_load:
|
|
node.value = device_indices[offset : offset + len(node.host_value)]
|
|
offset += len(node.host_value)
|
|
self.evictable_size_ += len(device_indices)
|
|
self.inc_lock_ref(last_hit_node)
|
|
|
|
return device_indices
|
|
|
|
def init_load_back(
|
|
self,
|
|
last_node: TreeNode,
|
|
host_hit_length: int,
|
|
mem_quota: Optional[int] = None,
|
|
):
|
|
_ = host_hit_length # unused, but kept for compatibility
|
|
if last_node.evicted:
|
|
loading_values = self.load_back(last_node, mem_quota)
|
|
if loading_values is not None:
|
|
logger.debug(
|
|
f"loading back {len(loading_values)} tokens for node {last_node.id}"
|
|
)
|
|
return loading_values, last_node
|
|
|
|
while last_node.evicted:
|
|
last_node = last_node.parent
|
|
|
|
return (
|
|
torch.empty((0,), dtype=torch.int64, device=self.device),
|
|
last_node,
|
|
)
|
|
|
|
def ready_to_load_host_cache(self) -> int:
|
|
"""
|
|
Notify the cache controller to start the KV cache loading.
|
|
Return the consumer index for the schedule batch manager to track.
|
|
"""
|
|
return self.cache_controller.start_loading()
|
|
|
|
def check_hicache_events(self):
|
|
self.writing_check()
|
|
self.loading_check()
|
|
if self.enable_storage:
|
|
self.drain_storage_control_queues()
|
|
if self.enable_storage_metrics:
|
|
self.metrics_collector.log_storage_metrics(
|
|
self.cache_controller.storage_backend.get_stats()
|
|
)
|
|
|
|
def drain_storage_control_queues(self):
|
|
"""
|
|
Combine prefetch revoke, backup ack, and host mem release checks
|
|
to minimize TP synchronization and Python overhead.
|
|
"""
|
|
cc = self.cache_controller
|
|
|
|
qsizes = torch.tensor(
|
|
[
|
|
cc.prefetch_revoke_queue.qsize(),
|
|
cc.ack_backup_queue.qsize(),
|
|
cc.host_mem_release_queue.qsize(),
|
|
],
|
|
dtype=torch.int,
|
|
)
|
|
if self.tp_world_size > 1:
|
|
torch.distributed.all_reduce(
|
|
qsizes, op=torch.distributed.ReduceOp.MIN, group=self.tp_group
|
|
)
|
|
|
|
n_revoke, n_backup, n_release = map(int, qsizes.tolist())
|
|
|
|
# process prefetch revokes
|
|
for _ in range(n_revoke):
|
|
req_id = cc.prefetch_revoke_queue.get()
|
|
info = self.ongoing_prefetch.pop(req_id, None)
|
|
if info is not None:
|
|
last_host_node, token_ids, _, _ = info
|
|
last_host_node.release_host()
|
|
cc.prefetch_tokens_occupied -= len(token_ids)
|
|
# else: the revoked operation already got terminated, nothing to do
|
|
|
|
# process backup acks
|
|
for _ in range(n_backup):
|
|
operation = cc.ack_backup_queue.get()
|
|
ack_id = operation.id
|
|
entry = self.ongoing_backup.pop(ack_id, None)
|
|
if entry is not None:
|
|
entry.release_host()
|
|
if self.enable_storage_metrics:
|
|
self.metrics_collector.log_backuped_tokens(operation.completed_tokens)
|
|
|
|
# release host memory
|
|
host_indices_list = []
|
|
for _ in range(n_release):
|
|
host_indices_list.append(cc.host_mem_release_queue.get())
|
|
if host_indices_list:
|
|
host_indices = torch.cat(host_indices_list, dim=0)
|
|
cc.mem_pool_host.free(host_indices)
|
|
|
|
# Timeout is linearly increasing with the number of pages
|
|
def _prefetch_timeout_check_linear_func(self, operation: PrefetchOperation):
|
|
# If hash_value has not been computed in timeout_base seconds, terminate it.
|
|
return (
|
|
time.monotonic() - operation.start_time
|
|
> self.prefetch_timeout_base
|
|
+ len(operation.hash_value) * self.prefetch_timeout_per_page
|
|
)
|
|
|
|
def can_terminate_prefetch(self, operation: PrefetchOperation):
|
|
can_terminate = True
|
|
|
|
if self.prefetch_stop_policy == "best_effort":
|
|
return can_terminate
|
|
|
|
if len(operation.hash_value) == 0:
|
|
completed = False
|
|
else:
|
|
completed = (
|
|
operation.completed_tokens == len(operation.hash_value) * self.page_size
|
|
)
|
|
|
|
if self.prefetch_stop_policy == "wait_complete":
|
|
can_terminate = completed
|
|
elif self.prefetch_stop_policy == "timeout":
|
|
can_terminate = completed or self.is_prefetch_timeout(operation)
|
|
else:
|
|
# unknown prefetch stop policy, just return True
|
|
return True
|
|
|
|
operation_terminated = operation.is_terminated()
|
|
if self.tp_world_size > 1:
|
|
states = torch.tensor(
|
|
[1 - int(can_terminate), int(operation_terminated)],
|
|
dtype=torch.int,
|
|
)
|
|
torch.distributed.all_reduce(
|
|
states,
|
|
op=torch.distributed.ReduceOp.MAX,
|
|
group=self.tp_group,
|
|
)
|
|
can_terminate = states[0].item() == 0
|
|
operation_terminated = states[1].item() == 1
|
|
# the operation should be terminated if it is already terminated on any TP worker
|
|
# or it meets the termination condition on all TP workers
|
|
can_terminate = can_terminate or operation_terminated
|
|
return can_terminate
|
|
|
|
def check_prefetch_progress(self, req_id: str) -> bool:
|
|
if req_id not in self.ongoing_prefetch:
|
|
# there is no ongoing prefetch for this request or it has been revoked
|
|
return True
|
|
|
|
# todo: more policies for prefetch progress such as timeout
|
|
# the current policy is to prefetch with best effort and terminate when queuing is over
|
|
last_host_node, token_ids, host_indices, operation = self.ongoing_prefetch[
|
|
req_id
|
|
]
|
|
|
|
if operation.host_indices is None:
|
|
# prefetch has not been issued due to insufficient host memory
|
|
return True
|
|
|
|
if not self.can_terminate_prefetch(operation):
|
|
return False
|
|
|
|
completed_tokens, hash_value = self.cache_controller.terminate_prefetch(
|
|
operation
|
|
)
|
|
logger.debug(f"Prefetch {req_id} completed with {completed_tokens} tokens")
|
|
|
|
min_completed_tokens = completed_tokens
|
|
if self.tp_world_size > 1:
|
|
# synchrnoize TP workers to make the same update to hiradix cache
|
|
completed_tokens_tensor = torch.tensor(
|
|
min_completed_tokens, dtype=torch.int
|
|
)
|
|
torch.distributed.all_reduce(
|
|
completed_tokens_tensor,
|
|
op=torch.distributed.ReduceOp.MIN,
|
|
group=self.tp_group,
|
|
)
|
|
min_completed_tokens = completed_tokens_tensor.item()
|
|
fetched_token_ids = token_ids[:min_completed_tokens]
|
|
written_indices = host_indices[:min_completed_tokens]
|
|
matched_length = self._insert_helper_host(
|
|
last_host_node,
|
|
RadixKey(
|
|
token_ids=fetched_token_ids, extra_key=last_host_node.key.extra_key
|
|
),
|
|
written_indices,
|
|
hash_value[: min_completed_tokens // self.page_size],
|
|
)
|
|
|
|
self.cache_controller.mem_pool_host.free(host_indices[:matched_length])
|
|
self.cache_controller.append_host_mem_release(
|
|
host_indices[min_completed_tokens:completed_tokens]
|
|
)
|
|
last_host_node.release_host()
|
|
del self.ongoing_prefetch[req_id]
|
|
self.cache_controller.prefetch_tokens_occupied -= len(token_ids)
|
|
|
|
if self.enable_storage_metrics:
|
|
self.metrics_collector.log_prefetched_tokens(
|
|
min_completed_tokens - matched_length
|
|
)
|
|
|
|
return True
|
|
|
|
def match_prefix(self, key: RadixKey, **kwargs):
    """
    Match ``key`` against the tree and report device and host hits.

    ``key.token_ids`` is converted in place and, for page sizes other than 1,
    the key is truncated to a whole number of pages.

    Returns:
        A ``MatchResult`` with the concatenated device indices of the matched
        resident nodes, the last device-resident node, the last node that has
        a host backup, and the number of host indices held by the evicted
        nodes at the end of the match.
    """
    no_indices = torch.empty((0,), dtype=torch.int64, device=self.device)
    key.token_ids = self.key_convert_fn(key.token_ids)
    if self.disable or len(key) == 0:
        return MatchResult(
            device_indices=no_indices,
            last_device_node=self.root_node,
            last_host_node=self.root_node,
            host_hit_length=0,
        )

    if self.page_size != 1:
        page_aligned_len = len(key) // self.page_size * self.page_size
        key = key[:page_aligned_len]

    matched_values, matched_node = self._match_prefix_helper(self.root_node, key)
    device_indices = torch.cat(matched_values) if matched_values else no_indices

    # Walk up past evicted nodes to find the last device-resident node.
    host_hit_length = 0
    device_node = matched_node
    while device_node.evicted:
        host_hit_length += len(device_node.host_value)
        device_node = device_node.parent

    # Walk up to the closest node that has a host backup.
    host_node = matched_node
    while not host_node.backuped:
        host_node = host_node.parent

    return MatchResult(
        device_indices=device_indices,
        last_device_node=device_node,
        last_host_node=host_node,
        host_hit_length=host_hit_length,
    )
|
|
|
|
def prefetch_from_storage(
|
|
self,
|
|
req_id: str,
|
|
last_host_node: TreeNode,
|
|
new_input_tokens: List[int],
|
|
last_hash: Optional[str] = None,
|
|
prefix_keys: Optional[List[str]] = None,
|
|
):
|
|
# align the number of fetching tokens to the page size
|
|
prefetch_length = len(new_input_tokens) - (
|
|
len(new_input_tokens) % self.page_size
|
|
)
|
|
new_input_tokens = new_input_tokens[:prefetch_length]
|
|
if (
|
|
not self.enable_storage
|
|
or prefetch_length < self.prefetch_threshold
|
|
or self.cache_controller.prefetch_rate_limited()
|
|
):
|
|
return
|
|
|
|
last_host_node.protect_host()
|
|
host_indices = self.cache_controller.mem_pool_host.alloc(prefetch_length)
|
|
if host_indices is None:
|
|
self.evict_host(prefetch_length)
|
|
host_indices = self.cache_controller.mem_pool_host.alloc(prefetch_length)
|
|
if host_indices is None:
|
|
last_host_node.release_host()
|
|
# no sufficient host memory for prefetch
|
|
return
|
|
operation = self.cache_controller.prefetch(
|
|
req_id, host_indices, new_input_tokens, last_hash, prefix_keys
|
|
)
|
|
self.ongoing_prefetch[req_id] = (
|
|
last_host_node,
|
|
new_input_tokens,
|
|
host_indices,
|
|
operation,
|
|
)
|
|
self.cache_controller.prefetch_tokens_occupied += len(new_input_tokens)
|
|
|
|
def _insert_helper_host(
|
|
self, node: TreeNode, key: RadixKey, host_value, hash_value
|
|
):
|
|
node.last_access_time = time.monotonic()
|
|
if len(key) == 0:
|
|
return 0
|
|
|
|
child_key = self.get_child_key_fn(key)
|
|
|
|
matched_length = 0
|
|
while len(key) > 0 and child_key in node.children.keys():
|
|
node = node.children[child_key]
|
|
node.last_access_time = time.monotonic()
|
|
prefix_len = self.key_match_fn(node.key, key)
|
|
key = key[prefix_len:]
|
|
host_value = host_value[prefix_len:]
|
|
hash_value = hash_value[prefix_len // self.page_size :]
|
|
matched_length += prefix_len
|
|
|
|
if prefix_len < len(node.key):
|
|
new_node = self._split_node(node.key, node, prefix_len)
|
|
node = new_node
|
|
|
|
if len(key):
|
|
child_key = self.get_child_key_fn(key)
|
|
|
|
if len(key):
|
|
new_node = TreeNode()
|
|
new_node.parent = node
|
|
new_node.key = key
|
|
new_node.value = None
|
|
new_node.host_value = host_value
|
|
new_node.hash_value = hash_value
|
|
node.children[child_key] = new_node
|
|
return matched_length
|
|
|
|
def _match_prefix_helper(self, node: TreeNode, key: RadixKey):
|
|
node.last_access_time = time.monotonic()
|
|
child_key = self.get_child_key_fn(key)
|
|
value = []
|
|
|
|
while len(key) > 0 and child_key in node.children.keys():
|
|
child = node.children[child_key]
|
|
child.last_access_time = time.monotonic()
|
|
prefix_len = self.key_match_fn(child.key, key)
|
|
if prefix_len < len(child.key):
|
|
new_node = self._split_node(child.key, child, prefix_len)
|
|
if not new_node.evicted:
|
|
value.append(new_node.value)
|
|
node = new_node
|
|
break
|
|
else:
|
|
if not child.evicted:
|
|
value.append(child.value)
|
|
node = child
|
|
key = key[prefix_len:]
|
|
|
|
if len(key):
|
|
child_key = self.get_child_key_fn(key)
|
|
|
|
return value, node
|
|
|
|
def _split_node(self, key: RadixKey, child: TreeNode, split_len: int):
    """
    Split ``child`` at ``split_len`` and return the new upper node.

    The new node takes the first ``split_len`` key elements together with the
    matching slices of the device value, host value and page hashes (where
    present); ``child`` keeps the rest and is re-parented under the new node.
    """
    # child node split into new_node -> child
    upper = TreeNode()
    upper.children = {self.get_child_key_fn(key[split_len:]): child}
    upper.parent = child.parent
    upper.lock_ref = child.lock_ref
    upper.key = child.key[:split_len]
    upper.hit_count = child.hit_count

    # split value and host value if exists
    if child.evicted:
        upper.value = None
    else:
        upper.value = child.value[:split_len]
        child.value = child.value[split_len:]
    if child.backuped:
        upper.host_value = child.host_value[:split_len]
        child.host_value = child.host_value[split_len:]

    if child.hash_value:
        # Hashes are kept per page, so the split point is scaled down.
        upper.hash_value = child.hash_value[: split_len // self.page_size]
        child.hash_value = child.hash_value[split_len // self.page_size :]

    child.parent = upper
    child.key = child.key[split_len:]
    # Replace the child by the new upper node in the original parent.
    upper.parent.children[self.get_child_key_fn(key)] = upper
    return upper
|
|
|
|
def insert(self, key: RadixKey, value=None, chunked=False):
|
|
key.token_ids = self.key_convert_fn(key.token_ids)
|
|
|
|
if len(key) == 0:
|
|
return 0
|
|
|
|
if self.is_eagle and value is not None:
|
|
# Make sure the value len equal to the EAGLE bigram key len
|
|
value = value[: len(key)]
|
|
|
|
node = self.root_node
|
|
child_key = self.get_child_key_fn(key)
|
|
total_prefix_length = 0
|
|
|
|
while len(key) > 0 and child_key in node.children.keys():
|
|
node = node.children[child_key]
|
|
node.last_access_time = time.monotonic()
|
|
prefix_len = self.key_match_fn(node.key, key)
|
|
|
|
if prefix_len == len(node.key):
|
|
if node.evicted:
|
|
# change the reference if the node is evicted
|
|
# this often happens in the case of KV cache recomputation
|
|
node.value = value[:prefix_len]
|
|
self.evictable_size_ += len(node.value)
|
|
else:
|
|
self._inc_hit_count(node, chunked)
|
|
total_prefix_length += prefix_len
|
|
else:
|
|
# partial match, split the node
|
|
new_node = self._split_node(node.key, node, prefix_len)
|
|
if new_node.evicted:
|
|
new_node.value = value[:prefix_len]
|
|
self.evictable_size_ += len(new_node.value)
|
|
else:
|
|
self._inc_hit_count(new_node, chunked)
|
|
total_prefix_length += prefix_len
|
|
node = new_node
|
|
|
|
key = key[prefix_len:]
|
|
value = value[prefix_len:]
|
|
|
|
if len(key):
|
|
child_key = self.get_child_key_fn(key)
|
|
|
|
if len(key):
|
|
new_node = TreeNode()
|
|
new_node.parent = node
|
|
new_node.key = key
|
|
new_node.value = value
|
|
node.children[child_key] = new_node
|
|
self.evictable_size_ += len(value)
|
|
|
|
if self.enable_storage:
|
|
last_hash = node.get_last_hash_value()
|
|
assert (node == self.root_node) or (
|
|
last_hash is not None
|
|
), "Parent node must have a hash value with storage enabled"
|
|
new_node.hash_value = []
|
|
for idx in range(0, len(key), self.page_size):
|
|
new_node.hash_value.append(
|
|
self.cache_controller.get_hash_str(
|
|
key.token_ids[idx : idx + self.page_size],
|
|
prior_hash=last_hash,
|
|
)
|
|
)
|
|
last_hash = new_node.hash_value[-1]
|
|
|
|
if self.cache_controller.write_policy != "write_back":
|
|
self._inc_hit_count(new_node, chunked)
|
|
return total_prefix_length
|
|
|
|
def _collect_leaves_device(self):
|
|
def is_leaf(node):
|
|
if node.evicted:
|
|
return False
|
|
if node == self.root_node:
|
|
return False
|
|
if len(node.children) == 0:
|
|
return True
|
|
for child in node.children.values():
|
|
if not child.evicted:
|
|
return False
|
|
return True
|
|
|
|
ret_list = []
|
|
stack = [self.root_node]
|
|
while stack:
|
|
cur_node = stack.pop()
|
|
if is_leaf(cur_node):
|
|
ret_list.append(cur_node)
|
|
else:
|
|
for cur_child in cur_node.children.values():
|
|
if not cur_child.evicted:
|
|
stack.append(cur_child)
|
|
return ret_list
|
|
|
|
def release_aborted_request(self, rid: str):
|
|
if rid not in self.ongoing_prefetch:
|
|
return
|
|
|
|
last_host_node, token_ids, host_indices, operation = self.ongoing_prefetch[rid]
|
|
if operation.host_indices is None:
|
|
return
|
|
|
|
completed_tokens, _ = self.cache_controller.terminate_prefetch(operation)
|
|
if self.tp_world_size > 1:
|
|
torch.distributed.barrier(group=self.tp_group)
|
|
last_host_node.release_host()
|
|
del self.ongoing_prefetch[rid]
|
|
self.cache_controller.append_host_mem_release(host_indices[:completed_tokens])
|
|
self.cache_controller.prefetch_tokens_occupied -= len(token_ids)
|