Hicache Storage Layer Prototype (#7704)
This commit is contained in:
@@ -25,6 +25,8 @@ if TYPE_CHECKING:
|
||||
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
|
||||
from sglang.srt.mem_cache.memory_pool_host import HostKVCache
|
||||
|
||||
from sglang.srt.mem_cache.hicache_storage import HiCacheFile, get_hash_str
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -159,6 +161,57 @@ class TransferBuffer:
|
||||
self.buffers.queue.clear()
|
||||
|
||||
|
||||
class StorageOperation:
|
||||
counter = 0
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
host_indices: torch.Tensor,
|
||||
token_ids: List[int],
|
||||
last_hash: Optional[str] = None,
|
||||
):
|
||||
self.host_indices = host_indices
|
||||
self.token_ids = token_ids
|
||||
self.last_hash = last_hash
|
||||
self.completed_tokens = 0
|
||||
self.hash_value = []
|
||||
|
||||
self.id = StorageOperation.counter
|
||||
StorageOperation.counter += 1
|
||||
|
||||
def __lt__(self, other: "StorageOperation"):
|
||||
return self.id < other.id
|
||||
|
||||
|
||||
class PrefetchOperation(StorageOperation):
|
||||
def __init__(
|
||||
self,
|
||||
request_id: str,
|
||||
host_indices: torch.Tensor,
|
||||
token_ids: List[int],
|
||||
last_hash: Optional[str] = None,
|
||||
):
|
||||
self.request_id = request_id
|
||||
|
||||
self._done_flag = False
|
||||
self._lock = threading.Lock()
|
||||
|
||||
super().__init__(host_indices, token_ids, last_hash)
|
||||
|
||||
def increment(self, num_tokens: int):
|
||||
with self._lock:
|
||||
if self._done_flag:
|
||||
return
|
||||
self.completed_tokens += num_tokens
|
||||
|
||||
def mark_done(self):
|
||||
with self._lock:
|
||||
self._done_flag = True
|
||||
|
||||
def is_done(self) -> bool:
|
||||
return self._done_flag
|
||||
|
||||
|
||||
class HiCacheController:
|
||||
|
||||
def __init__(
|
||||
@@ -169,6 +222,8 @@ class HiCacheController:
|
||||
load_cache_event: threading.Event = None,
|
||||
write_policy: str = "write_through_selective",
|
||||
io_backend: str = "",
|
||||
storage_backend: Optional[str] = None,
|
||||
prefetch_threshold: int = 256,
|
||||
):
|
||||
self.mem_pool_device_allocator = token_to_kv_pool_allocator
|
||||
self.mem_pool_device = token_to_kv_pool_allocator.get_kvcache()
|
||||
@@ -186,6 +241,19 @@ class HiCacheController:
|
||||
else:
|
||||
self.io_backend = io_backend
|
||||
|
||||
self.enable_storage = False
|
||||
# todo: move backend initialization to storage backend module
|
||||
if storage_backend is not None:
|
||||
if storage_backend == "file":
|
||||
self.storage_backend = HiCacheFile()
|
||||
self.enable_storage = True
|
||||
# todo: threshold policy for prefetching
|
||||
self.prefetch_threshold = prefetch_threshold
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"Unsupported storage backend: {storage_backend}"
|
||||
)
|
||||
|
||||
self.load_cache_event = load_cache_event
|
||||
self.layer_done_counter = LayerDoneCounter(self.mem_pool_device.layer_num)
|
||||
self.mem_pool_device.register_layer_transfer_counter(self.layer_done_counter)
|
||||
@@ -218,9 +286,26 @@ class HiCacheController:
|
||||
self.load_thread = threading.Thread(
|
||||
target=self.load_thread_func_layer_by_layer, daemon=True
|
||||
)
|
||||
|
||||
self.write_thread.start()
|
||||
self.load_thread.start()
|
||||
|
||||
if self.enable_storage:
|
||||
self.prefetch_thread = threading.Thread(
|
||||
target=self.prefetch_thread_func, daemon=True
|
||||
)
|
||||
self.backup_thread = threading.Thread(
|
||||
target=self.backup_thread_func, daemon=True
|
||||
)
|
||||
self.prefetch_queue = Queue()
|
||||
self.backup_queue = Queue()
|
||||
|
||||
self.prefetch_revoke_queue = Queue()
|
||||
self.ack_backup_queue = Queue()
|
||||
|
||||
self.prefetch_thread.start()
|
||||
self.backup_thread.start()
|
||||
|
||||
def reset(self):
|
||||
self.stop_event.set()
|
||||
self.write_thread.join()
|
||||
@@ -232,6 +317,13 @@ class HiCacheController:
|
||||
self.load_buffer.clear()
|
||||
self.ack_write_queue.queue.clear()
|
||||
self.ack_load_queue.queue.clear()
|
||||
if self.enable_storage:
|
||||
self.prefetch_thread.join()
|
||||
self.backup_thread.join()
|
||||
self.prefetch_queue.queue.clear()
|
||||
self.backup_queue.queue.clear()
|
||||
self.prefetch_revoke_queue.queue.clear()
|
||||
self.ack_backup_queue.queue.clear()
|
||||
|
||||
self.write_thread = threading.Thread(
|
||||
target=self.write_thread_func_direct, daemon=True
|
||||
@@ -243,6 +335,16 @@ class HiCacheController:
|
||||
self.write_thread.start()
|
||||
self.load_thread.start()
|
||||
|
||||
if self.enable_storage:
|
||||
self.prefetch_thread = threading.Thread(
|
||||
target=self.prefetch_thread_func, daemon=True
|
||||
)
|
||||
self.backup_thread = threading.Thread(
|
||||
target=self.backup_thread_func, daemon=True
|
||||
)
|
||||
self.prefetch_thread.start()
|
||||
self.backup_thread.start()
|
||||
|
||||
def write(
|
||||
self,
|
||||
device_indices: torch.Tensor,
|
||||
@@ -383,3 +485,142 @@ class HiCacheController:
|
||||
raise ValueError(
|
||||
f"Inconsistent states: {self.mem_pool_host.get_state(host_indices)}"
|
||||
)
|
||||
|
||||
def prefetch(
|
||||
self,
|
||||
request_id: str,
|
||||
host_indices: torch.Tensor,
|
||||
new_input_tokens: List[int],
|
||||
last_hash: Optional[str] = None,
|
||||
) -> int:
|
||||
"""
|
||||
Prefetch KV caches from storage backend to host memory.
|
||||
"""
|
||||
operation = PrefetchOperation(
|
||||
request_id, host_indices, new_input_tokens, last_hash
|
||||
)
|
||||
self.prefetch_queue.put(operation)
|
||||
return operation
|
||||
|
||||
def terminate_prefetch(self, operation):
|
||||
operation.mark_done()
|
||||
return operation.completed_tokens, operation.hash_value
|
||||
|
||||
def prefetch_io_aux_func(self):
|
||||
"""
|
||||
Auxiliary function conducting IO operations for prefetching.
|
||||
"""
|
||||
while not self.stop_event.is_set():
|
||||
try:
|
||||
operation = self.prefetch_buffer.get(block=True, timeout=1)
|
||||
for h in operation.hash_value:
|
||||
page_data = self.storage_backend.get(h)
|
||||
if page_data is None:
|
||||
logger.warning(
|
||||
f"Prefetch operation {operation.request_id} failed to retrieve page {h}."
|
||||
)
|
||||
break
|
||||
self.mem_pool_host.set_from_flat_data_page(
|
||||
operation.host_indices[operation.completed_tokens],
|
||||
page_data,
|
||||
)
|
||||
operation.increment(self.page_size)
|
||||
if operation.is_done():
|
||||
# operation terminated by controller, release pre-allocated memory
|
||||
self.mem_pool_host.free(
|
||||
operation.host_indices[operation.completed_tokens :]
|
||||
)
|
||||
break
|
||||
except Empty:
|
||||
continue
|
||||
|
||||
def prefetch_thread_func(self):
|
||||
"""
|
||||
Manage prefetching operations from storage backend to host memory.
|
||||
"""
|
||||
self.prefetch_buffer = Queue()
|
||||
aux_thread = threading.Thread(target=self.prefetch_io_aux_func, daemon=True)
|
||||
aux_thread.start()
|
||||
while (not self.stop_event.is_set()) or not self.prefetch_queue.empty():
|
||||
try:
|
||||
operation = self.prefetch_queue.get(block=True, timeout=1)
|
||||
if operation is None:
|
||||
continue
|
||||
|
||||
last_hash = operation.last_hash
|
||||
tokens_to_fetch = operation.token_ids
|
||||
|
||||
storage_hit_count = 0
|
||||
remaining_tokens = len(tokens_to_fetch)
|
||||
hash_value = []
|
||||
while remaining_tokens >= self.page_size:
|
||||
last_hash = get_hash_str(
|
||||
tokens_to_fetch[
|
||||
storage_hit_count : storage_hit_count + self.page_size
|
||||
],
|
||||
last_hash,
|
||||
)
|
||||
if self.storage_backend.exists(last_hash):
|
||||
storage_hit_count += self.page_size
|
||||
hash_value.append(last_hash)
|
||||
remaining_tokens -= self.page_size
|
||||
else:
|
||||
break
|
||||
|
||||
if storage_hit_count < self.prefetch_threshold:
|
||||
# not to prefetch if not enough benefits
|
||||
self.prefetch_revoke_queue.put(operation.request_id)
|
||||
else:
|
||||
operation.hash_value = hash_value
|
||||
logger.debug(
|
||||
f"Prefetching {len(hash_value)} pages for request {operation.request_id}."
|
||||
)
|
||||
self.prefetch_buffer.put(operation)
|
||||
|
||||
except Empty:
|
||||
continue
|
||||
|
||||
def write_storage(
|
||||
self,
|
||||
host_indices: torch.Tensor,
|
||||
token_ids: List[int],
|
||||
last_hash: Optional[str] = None,
|
||||
) -> int:
|
||||
"""
|
||||
Write KV caches from host memory to storage backend.
|
||||
"""
|
||||
operation = StorageOperation(host_indices, token_ids, last_hash)
|
||||
self.backup_queue.put(operation)
|
||||
return operation.id
|
||||
|
||||
def backup_thread_func(self):
|
||||
"""
|
||||
Manage backup operations from host memory to storage backend.
|
||||
"""
|
||||
while not self.stop_event.is_set():
|
||||
try:
|
||||
operation = self.backup_queue.get(block=True, timeout=1)
|
||||
if operation is None:
|
||||
continue
|
||||
|
||||
last_hash = operation.last_hash
|
||||
tokens_to_backup = operation.token_ids
|
||||
|
||||
for i in range(0, len(tokens_to_backup), self.page_size):
|
||||
last_hash = get_hash_str(
|
||||
tokens_to_backup[i : i + self.page_size], last_hash
|
||||
)
|
||||
# todo, handle failures in storage backend
|
||||
self.storage_backend.set(
|
||||
last_hash,
|
||||
self.mem_pool_host.get_flat_data_page(
|
||||
operation.host_indices[i]
|
||||
),
|
||||
)
|
||||
operation.completed_tokens += self.page_size
|
||||
operation.hash_value.append(last_hash)
|
||||
|
||||
self.ack_backup_queue.put((operation.id, operation.hash_value))
|
||||
|
||||
except Empty:
|
||||
continue
|
||||
|
||||
@@ -262,6 +262,7 @@ class Scheduler(
|
||||
)
|
||||
self.gpu_id = gpu_id
|
||||
self.enable_hierarchical_cache = server_args.enable_hierarchical_cache
|
||||
self.enable_hicache_storage = server_args.hicache_storage_backend is not None
|
||||
self.page_size = server_args.page_size
|
||||
self.dp_size = server_args.dp_size
|
||||
self.attn_tp_rank, self.attn_tp_size, self.attn_dp_rank = (
|
||||
@@ -614,6 +615,7 @@ class Scheduler(
|
||||
== "fa3" # hot fix for incompatibility
|
||||
else server_args.hicache_io_backend
|
||||
),
|
||||
hicache_storage_backend=server_args.hicache_storage_backend,
|
||||
)
|
||||
self.tp_worker.register_hicache_layer_transfer_counter(
|
||||
self.tree_cache.cache_controller.layer_done_counter
|
||||
@@ -1258,6 +1260,15 @@ class Scheduler(
|
||||
elif self.disaggregation_mode == DisaggregationMode.DECODE:
|
||||
self.disagg_decode_prealloc_queue.add(req)
|
||||
else:
|
||||
if self.enable_hicache_storage:
|
||||
req.init_next_round_input(self.tree_cache)
|
||||
last_hash = req.last_host_node.get_last_hash_value()
|
||||
matched_len = len(req.prefix_indices) + req.host_hit_length
|
||||
if (matched_len > 0 and last_hash is not None) or matched_len == 0:
|
||||
new_input_tokens = req.fill_ids[matched_len:]
|
||||
self.tree_cache.prefetch_from_storage(
|
||||
req.rid, req.last_host_node, new_input_tokens, last_hash
|
||||
)
|
||||
self.waiting_queue.append(req)
|
||||
|
||||
def _extend_requests_to_queue(self, reqs: List[Req], is_retracted: bool = False):
|
||||
@@ -1731,6 +1742,9 @@ class Scheduler(
|
||||
self.running_batch.batch_is_full = True
|
||||
break
|
||||
|
||||
if self.enable_hicache_storage:
|
||||
self.tree_cache.check_prefetch_progress(req.rid)
|
||||
|
||||
req.init_next_round_input(self.tree_cache)
|
||||
res = adder.add_one_req(req, has_chunked_req=(self.chunked_req is not None))
|
||||
|
||||
|
||||
152
python/sglang/srt/mem_cache/hicache_storage.py
Normal file
152
python/sglang/srt/mem_cache/hicache_storage.py
Normal file
@@ -0,0 +1,152 @@
|
||||
import hashlib
|
||||
import logging
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_hash_str(token_ids: List[int], prior_hash: Optional[str] = None) -> str:
|
||||
hasher = hashlib.sha256()
|
||||
|
||||
if prior_hash:
|
||||
hasher.update(bytes.fromhex(prior_hash))
|
||||
|
||||
for t in token_ids:
|
||||
hasher.update(t.to_bytes(4, byteorder="little", signed=False))
|
||||
|
||||
return hasher.hexdigest()
|
||||
|
||||
|
||||
class HiCacheStorage(ABC):
|
||||
"""
|
||||
HiCacheStorage is a class that provides a generic key-value interface for storing and retrieving KV cache.
|
||||
It abstracts the underlying storage mechanism, allowing different implementations to be used.
|
||||
"""
|
||||
|
||||
# todo, translate tensor object access for different TP ranks
|
||||
# potentially pass model and TP configs into storage backend
|
||||
# todo, the page size of storage backend does not have to be the same as the same as host memory pool
|
||||
|
||||
@abstractmethod
|
||||
def get(
|
||||
self, key: str, target_location: Optional[torch.Tensor] = None
|
||||
) -> torch.Tensor | None:
|
||||
"""
|
||||
Retrieve the value associated with the given key.
|
||||
Returns None if the key does not exist.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def batch_get(
|
||||
self, keys: List[str], target_locations: Optional[List[torch.Tensor]] = None
|
||||
) -> List[torch.Tensor | None]:
|
||||
"""
|
||||
Retrieve values for multiple keys.
|
||||
Returns a list of tensors or None for each key.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def set(self, key, value) -> bool:
|
||||
"""
|
||||
Store the value associated with the given key.
|
||||
Returns True if the operation was successful, False otherwise.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def batch_set(self, keys: List[str], values: List[torch.Tensor]) -> bool:
|
||||
"""
|
||||
Store multiple key-value pairs.
|
||||
Returns True if all operations were successful, False otherwise.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def exists(self, key: str) -> bool:
|
||||
"""
|
||||
Check if the key exists in the storage.
|
||||
Returns True if the key exists, False otherwise.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class HiCacheFile(HiCacheStorage):
|
||||
|
||||
def __init__(self, file_path: str = "/tmp/hicache"):
|
||||
self.file_path = file_path
|
||||
if not os.path.exists(self.file_path):
|
||||
os.makedirs(self.file_path)
|
||||
logger.info(f"Created HiCacheFile storage directory at {self.file_path}")
|
||||
|
||||
def get(
|
||||
self, key: str, target_location: Optional[torch.Tensor] = None
|
||||
) -> torch.Tensor | None:
|
||||
tensor_path = os.path.join(self.file_path, f"{key}.bin")
|
||||
try:
|
||||
# todo: fixing the target_location logic to enable in-place loading
|
||||
loaded_tensor = torch.load(tensor_path)
|
||||
if isinstance(loaded_tensor, torch.Tensor):
|
||||
return loaded_tensor
|
||||
else:
|
||||
logger.error(f"Loaded data for key {key} is not a tensor.")
|
||||
return None
|
||||
except FileNotFoundError:
|
||||
return None
|
||||
|
||||
def batch_get(
|
||||
self,
|
||||
keys: List[str],
|
||||
target_locations: Optional[List[torch.Tensor]] = None,
|
||||
) -> List[torch.Tensor | None]:
|
||||
return [
|
||||
self.get(key, target_location)
|
||||
for key, target_location in zip(
|
||||
keys, target_locations or [None] * len(keys)
|
||||
)
|
||||
]
|
||||
|
||||
def set(self, key: str, value: torch.Tensor) -> bool:
|
||||
tensor_path = os.path.join(self.file_path, f"{key}.bin")
|
||||
if self.exists(key):
|
||||
logger.debug(f"Key {key} already exists. Skipped.")
|
||||
return True
|
||||
try:
|
||||
torch.save(value, tensor_path)
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to save tensor {key}: {e}")
|
||||
return False
|
||||
|
||||
def batch_set(self, keys: List[str], values: List[torch.Tensor]) -> bool:
|
||||
for key, value in zip(keys, values):
|
||||
if not self.set(key, value):
|
||||
return False
|
||||
return True
|
||||
|
||||
def exists(self, key: str) -> bool:
|
||||
tensor_path = os.path.join(self.file_path, f"{key}.bin")
|
||||
return os.path.exists(tensor_path)
|
||||
|
||||
def delete(self, key: str) -> None:
|
||||
tensor_path = os.path.join(self.file_path, f"{key}.bin")
|
||||
try:
|
||||
os.remove(tensor_path)
|
||||
except FileNotFoundError:
|
||||
logger.warning(f"Key {key} does not exist. Cannot delete.")
|
||||
return
|
||||
|
||||
def clear(self) -> None:
|
||||
try:
|
||||
for filename in os.listdir(self.file_path):
|
||||
file_path = os.path.join(self.file_path, filename)
|
||||
if os.path.isfile(file_path):
|
||||
os.remove(file_path)
|
||||
logger.info("Cleared all entries in HiCacheFile storage.")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to clear HiCacheFile storage: {e}")
|
||||
@@ -35,6 +35,7 @@ class HiRadixCache(RadixCache):
|
||||
hicache_size: int,
|
||||
hicache_write_policy: str,
|
||||
hicache_io_backend: str,
|
||||
hicache_storage_backend: Optional[str] = None,
|
||||
):
|
||||
self.kv_cache = token_to_kv_pool_allocator.get_kvcache()
|
||||
if isinstance(self.kv_cache, MHATokenToKVPool):
|
||||
@@ -49,6 +50,9 @@ class HiRadixCache(RadixCache):
|
||||
raise ValueError(f"HiRadixCache only supports MHA and MLA yet")
|
||||
|
||||
self.tp_group = tp_cache_group
|
||||
self.enable_storage = hicache_storage_backend is not None
|
||||
# todo: customizable storage prefetch threshold
|
||||
self.prefetch_threshold = 256
|
||||
|
||||
self.load_cache_event = threading.Event()
|
||||
self.cache_controller = HiCacheController(
|
||||
@@ -58,16 +62,22 @@ class HiRadixCache(RadixCache):
|
||||
load_cache_event=self.load_cache_event,
|
||||
write_policy=hicache_write_policy,
|
||||
io_backend=hicache_io_backend,
|
||||
storage_backend=hicache_storage_backend,
|
||||
prefetch_threshold=self.prefetch_threshold,
|
||||
)
|
||||
|
||||
# record the nodes with ongoing write through
|
||||
self.ongoing_write_through = {}
|
||||
# record the node segments with ongoing load back
|
||||
self.ongoing_load_back = {}
|
||||
# record the ongoing prefetch requests
|
||||
self.ongoing_prefetch = {}
|
||||
self.ongoing_backup = {}
|
||||
# todo: dynamically adjust the threshold
|
||||
self.write_through_threshold = (
|
||||
1 if hicache_write_policy == "write_through" else 3
|
||||
)
|
||||
self.write_through_threshold_storage = 3
|
||||
self.load_back_threshold = 10
|
||||
super().__init__(
|
||||
req_to_token_pool, token_to_kv_pool_allocator, page_size, disable=False
|
||||
@@ -108,13 +118,30 @@ class HiRadixCache(RadixCache):
|
||||
|
||||
return len(host_indices)
|
||||
|
||||
def write_backup_storage(self, node: TreeNode):
|
||||
operation_id = self.cache_controller.write_storage(
|
||||
node.host_value, node.key, node.parent.get_last_hash_value()
|
||||
)
|
||||
self.ongoing_backup[operation_id] = node
|
||||
node.protect_host()
|
||||
|
||||
def inc_hit_count(self, node: TreeNode):
|
||||
if node.backuped or self.cache_controller.write_policy == "write_back":
|
||||
if self.cache_controller.write_policy == "write_back":
|
||||
return
|
||||
node.hit_count += 1
|
||||
if node.hit_count >= self.write_through_threshold:
|
||||
self.write_backup(node)
|
||||
node.hit_count = 0
|
||||
|
||||
if not node.backuped:
|
||||
if node.hit_count >= self.write_through_threshold:
|
||||
# write to host if the node is not backuped
|
||||
self.write_backup(node)
|
||||
else:
|
||||
if (
|
||||
self.enable_storage
|
||||
and (not node.backuped_storage)
|
||||
and node.hit_count >= self.write_through_threshold_storage
|
||||
):
|
||||
# if the node is backuped on host memory but not on storage
|
||||
self.write_backup_storage(node)
|
||||
|
||||
def writing_check(self, write_back=False):
|
||||
if write_back:
|
||||
@@ -221,6 +248,10 @@ class HiRadixCache(RadixCache):
|
||||
if not x.evicted:
|
||||
continue
|
||||
|
||||
# node is protected from eviction as it has ongoing prefetch or backup to storage
|
||||
if x.host_ref_counter > 0:
|
||||
continue
|
||||
|
||||
num_evicted += self.cache_controller.evict_host(x.host_value)
|
||||
|
||||
for k, v in x.parent.children.items():
|
||||
@@ -314,6 +345,85 @@ class HiRadixCache(RadixCache):
|
||||
def check_hicache_events(self):
|
||||
self.writing_check()
|
||||
self.loading_check()
|
||||
if self.enable_storage:
|
||||
self.check_revoked_prefetch()
|
||||
self.check_backup_progress()
|
||||
|
||||
def check_revoked_prefetch(self):
|
||||
queue_size = torch.tensor(
|
||||
self.cache_controller.prefetch_revoke_queue.qsize(), dtype=torch.int
|
||||
)
|
||||
if torch.distributed.get_world_size(group=self.tp_group) > 1:
|
||||
# synchrnoize TP workers to make the same update to hiradix cache
|
||||
torch.distributed.all_reduce(
|
||||
queue_size,
|
||||
op=torch.distributed.ReduceOp.MIN,
|
||||
group=self.tp_group,
|
||||
)
|
||||
for _ in range(queue_size.item()):
|
||||
req_id = self.cache_controller.prefetch_revoke_queue.get()
|
||||
if req_id in self.ongoing_prefetch:
|
||||
last_host_node, _, host_indices, _ = self.ongoing_prefetch[req_id]
|
||||
last_host_node.release_host()
|
||||
self.cache_controller.mem_pool_host.free(host_indices)
|
||||
del self.ongoing_prefetch[req_id]
|
||||
|
||||
def check_backup_progress(self):
|
||||
queue_size = torch.tensor(
|
||||
self.cache_controller.ack_backup_queue.qsize(), dtype=torch.int
|
||||
)
|
||||
if torch.distributed.get_world_size(group=self.tp_group) > 1:
|
||||
# synchrnoize TP workers to make the same update to hiradix cache
|
||||
torch.distributed.all_reduce(
|
||||
queue_size,
|
||||
op=torch.distributed.ReduceOp.MIN,
|
||||
group=self.tp_group,
|
||||
)
|
||||
for _ in range(queue_size.item()):
|
||||
ack_id, hash_value = self.cache_controller.ack_backup_queue.get()
|
||||
self.ongoing_backup[ack_id].hash_value = hash_value
|
||||
self.ongoing_backup[ack_id].release_host()
|
||||
del self.ongoing_backup[ack_id]
|
||||
|
||||
def check_prefetch_progress(self, req_id: str):
|
||||
if req_id not in self.ongoing_prefetch:
|
||||
# there is no ongoing prefetch for this request or it has been revoked
|
||||
return
|
||||
|
||||
# todo: more policies for prefetch progress such as timeout
|
||||
# the current policy is to prefetch with best effort and terminate when queuing is over
|
||||
last_host_node, token_ids, host_indices, operation = self.ongoing_prefetch[
|
||||
req_id
|
||||
]
|
||||
completed_tokens, hash_value = self.cache_controller.terminate_prefetch(
|
||||
operation
|
||||
)
|
||||
logger.debug(f"Prefetch {req_id} completed with {completed_tokens} tokens")
|
||||
|
||||
min_completed_tokens = torch.tensor(completed_tokens, dtype=torch.int)
|
||||
if torch.distributed.get_world_size(group=self.tp_group) > 1:
|
||||
# synchrnoize TP workers to make the same update to hiradix cache
|
||||
torch.distributed.all_reduce(
|
||||
min_completed_tokens,
|
||||
op=torch.distributed.ReduceOp.MIN,
|
||||
group=self.tp_group,
|
||||
)
|
||||
min_completed_tokens = min_completed_tokens.item()
|
||||
fetched_token_ids = token_ids[:min_completed_tokens]
|
||||
written_indices = host_indices[:min_completed_tokens]
|
||||
matched_length = self._insert_helper_host(
|
||||
last_host_node,
|
||||
fetched_token_ids,
|
||||
written_indices,
|
||||
hash_value[:min_completed_tokens],
|
||||
)
|
||||
|
||||
self.cache_controller.mem_pool_host.free(host_indices[:matched_length])
|
||||
self.cache_controller.mem_pool_host.free(
|
||||
host_indices[min_completed_tokens:completed_tokens]
|
||||
)
|
||||
last_host_node.release_host()
|
||||
del self.ongoing_prefetch[req_id]
|
||||
|
||||
def match_prefix(self, key: List[int], **kwargs):
|
||||
empty_value = torch.empty((0,), dtype=torch.int64, device=self.device)
|
||||
@@ -348,6 +458,71 @@ class HiRadixCache(RadixCache):
|
||||
host_hit_length=host_hit_length,
|
||||
)
|
||||
|
||||
def prefetch_from_storage(
|
||||
self,
|
||||
req_id: str,
|
||||
last_host_node: TreeNode,
|
||||
new_input_tokens: List[int],
|
||||
last_hash: Optional[str] = None,
|
||||
):
|
||||
if not self.enable_storage or len(new_input_tokens) < self.prefetch_threshold:
|
||||
return
|
||||
|
||||
last_host_node.protect_host()
|
||||
host_indices = self.cache_controller.mem_pool_host.alloc(len(new_input_tokens))
|
||||
if host_indices is None:
|
||||
self.evict_host(len(new_input_tokens))
|
||||
host_indices = self.cache_controller.mem_pool_host.alloc(
|
||||
len(new_input_tokens)
|
||||
)
|
||||
if host_indices is None:
|
||||
last_host_node.release_host()
|
||||
# no sufficient host memory to prefetch
|
||||
return
|
||||
operation = self.cache_controller.prefetch(
|
||||
req_id, host_indices, new_input_tokens, last_hash
|
||||
)
|
||||
self.ongoing_prefetch[req_id] = (
|
||||
last_host_node,
|
||||
new_input_tokens,
|
||||
host_indices,
|
||||
operation,
|
||||
)
|
||||
|
||||
def _insert_helper_host(self, node: TreeNode, key: List, host_value, hash_value):
|
||||
node.last_access_time = time.monotonic()
|
||||
if len(key) == 0:
|
||||
return 0
|
||||
|
||||
child_key = self.get_child_key_fn(key)
|
||||
|
||||
matched_length = 0
|
||||
while len(key) > 0 and child_key in node.children.keys():
|
||||
node = node.children[child_key]
|
||||
node.last_access_time = time.monotonic()
|
||||
prefix_len = self.key_match_fn(node.key, key)
|
||||
key = key[prefix_len:]
|
||||
host_value = host_value[prefix_len:]
|
||||
hash_value = hash_value[prefix_len:]
|
||||
matched_length += prefix_len
|
||||
|
||||
if prefix_len < len(node.key):
|
||||
new_node = self._split_node(node.key, node, prefix_len)
|
||||
node = new_node
|
||||
|
||||
if len(key):
|
||||
child_key = self.get_child_key_fn(key)
|
||||
|
||||
if len(key):
|
||||
new_node = TreeNode()
|
||||
new_node.parent = node
|
||||
new_node.key = key
|
||||
new_node.value = None
|
||||
new_node.host_value = host_value
|
||||
new_node.hash_value = hash_value
|
||||
node.children[child_key] = new_node
|
||||
return matched_length
|
||||
|
||||
def _match_prefix_helper(self, node: TreeNode, key: List):
|
||||
node.last_access_time = time.monotonic()
|
||||
child_key = self.get_child_key_fn(key)
|
||||
|
||||
@@ -99,6 +99,20 @@ class HostKVCache(abc.ABC):
|
||||
def init_kv_buffer(self):
|
||||
raise NotImplementedError()
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_flat_data_page(self, index) -> torch.Tensor:
|
||||
"""
|
||||
Get a flat data page from the host memory pool.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@abc.abstractmethod
|
||||
def set_from_flat_data_page(self, index: int, data_page: torch.Tensor) -> None:
|
||||
"""
|
||||
Set a flat data page to the host memory pool.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@synchronized()
|
||||
def clear(self):
|
||||
# Initialize memory states and tracking structures.
|
||||
@@ -227,6 +241,19 @@ class MHATokenToKVPoolHost(HostKVCache):
|
||||
pin_memory=self.pin_memory,
|
||||
)
|
||||
|
||||
# todo, page first memory layout
|
||||
def get_flat_data_page(self, index) -> torch.Tensor:
|
||||
return self.kv_buffer[:, :, index : index + self.page_size, :, :].flatten()
|
||||
|
||||
def set_from_flat_data_page(self, index: int, data_page: torch.Tensor) -> None:
|
||||
self.kv_buffer[:, :, index : index + self.page_size, :, :] = data_page.reshape(
|
||||
2,
|
||||
self.layer_num,
|
||||
self.page_size,
|
||||
self.head_num,
|
||||
self.head_dim,
|
||||
)
|
||||
|
||||
@property
|
||||
def k_buffer(self):
|
||||
return self.kv_buffer[0]
|
||||
@@ -276,3 +303,14 @@ class MLATokenToKVPoolHost(HostKVCache):
|
||||
device=self.device,
|
||||
pin_memory=self.pin_memory,
|
||||
)
|
||||
|
||||
def get_flat_data_page(self, index) -> torch.Tensor:
|
||||
return self.kv_buffer[:, index : index + self.page_size, :, :].flatten()
|
||||
|
||||
def set_from_flat_data_page(self, index: int, data_page: torch.Tensor) -> None:
|
||||
self.kv_buffer[:, index : index + self.page_size, :, :] = data_page.reshape(
|
||||
self.layer_num,
|
||||
self.page_size,
|
||||
1,
|
||||
self.kv_lora_rank + self.qk_rope_head_dim,
|
||||
)
|
||||
|
||||
@@ -55,8 +55,13 @@ class TreeNode:
|
||||
self.hit_count = 0
|
||||
# indicating the node is loading KV cache from host
|
||||
self.loading = False
|
||||
# indicating the node is locked to protect from eviction
|
||||
# incremented when the node is referenced by a storage operation
|
||||
self.host_ref_counter = 0
|
||||
# store the host indices of KV cache
|
||||
self.host_value: Optional[torch.Tensor] = None
|
||||
# store hash values of each pages
|
||||
self.hash_value: Optional[List[str]] = None
|
||||
|
||||
self.id = TreeNode.counter if id is None else id
|
||||
TreeNode.counter += 1
|
||||
@@ -69,6 +74,27 @@ class TreeNode:
|
||||
def backuped(self):
|
||||
return self.host_value is not None
|
||||
|
||||
@property
|
||||
def backuped_storage(self):
|
||||
return self.hash_value is not None and len(self.hash_value) > 0
|
||||
|
||||
def protect_host(self):
|
||||
"""Protect the host value from eviction."""
|
||||
self.host_ref_counter += 1
|
||||
|
||||
def release_host(self):
|
||||
"""Release the host value, allowing it to be evicted."""
|
||||
if self.host_ref_counter > 0:
|
||||
self.host_ref_counter -= 1
|
||||
else:
|
||||
raise RuntimeError("Host reference counter is already zero.")
|
||||
|
||||
def get_last_hash_value(self) -> Optional[str]:
|
||||
"""Returns the hash value of the last page in this node."""
|
||||
if self.hash_value is None or len(self.hash_value) == 0:
|
||||
return None
|
||||
return self.hash_value[-1]
|
||||
|
||||
def __lt__(self, other: "TreeNode"):
|
||||
return self.last_access_time < other.last_access_time
|
||||
|
||||
|
||||
@@ -222,6 +222,7 @@ class ServerArgs:
|
||||
hicache_size: int = 0
|
||||
hicache_write_policy: str = "write_through_selective"
|
||||
hicache_io_backend: str = ""
|
||||
hicache_storage_backend: Optional[str] = None
|
||||
flashinfer_mla_disable_ragged: bool = False
|
||||
disable_shared_experts_fusion: bool = False
|
||||
disable_chunked_prefix_cache: bool = False
|
||||
@@ -1604,6 +1605,13 @@ class ServerArgs:
|
||||
default=ServerArgs.hicache_io_backend,
|
||||
help="The IO backend for KV cache transfer between CPU and GPU",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--hicache-storage-backend",
|
||||
type=str,
|
||||
choices=["file"], # todo, mooncacke
|
||||
default=ServerArgs.hicache_storage_backend,
|
||||
help="The storage backend for hierarchical KV cache.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--flashinfer-mla-disable-ragged",
|
||||
action="store_true",
|
||||
|
||||
Reference in New Issue
Block a user