Hicache Storage Layer Prototype (#7704)
This commit is contained in:
152
python/sglang/srt/mem_cache/hicache_storage.py
Normal file
152
python/sglang/srt/mem_cache/hicache_storage.py
Normal file
@@ -0,0 +1,152 @@
|
||||
import hashlib
|
||||
import logging
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_hash_str(token_ids: List[int], prior_hash: Optional[str] = None) -> str:
|
||||
hasher = hashlib.sha256()
|
||||
|
||||
if prior_hash:
|
||||
hasher.update(bytes.fromhex(prior_hash))
|
||||
|
||||
for t in token_ids:
|
||||
hasher.update(t.to_bytes(4, byteorder="little", signed=False))
|
||||
|
||||
return hasher.hexdigest()
|
||||
|
||||
|
||||
class HiCacheStorage(ABC):
|
||||
"""
|
||||
HiCacheStorage is a class that provides a generic key-value interface for storing and retrieving KV cache.
|
||||
It abstracts the underlying storage mechanism, allowing different implementations to be used.
|
||||
"""
|
||||
|
||||
# todo, translate tensor object access for different TP ranks
|
||||
# potentially pass model and TP configs into storage backend
|
||||
# todo, the page size of storage backend does not have to be the same as the same as host memory pool
|
||||
|
||||
@abstractmethod
|
||||
def get(
|
||||
self, key: str, target_location: Optional[torch.Tensor] = None
|
||||
) -> torch.Tensor | None:
|
||||
"""
|
||||
Retrieve the value associated with the given key.
|
||||
Returns None if the key does not exist.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def batch_get(
|
||||
self, keys: List[str], target_locations: Optional[List[torch.Tensor]] = None
|
||||
) -> List[torch.Tensor | None]:
|
||||
"""
|
||||
Retrieve values for multiple keys.
|
||||
Returns a list of tensors or None for each key.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def set(self, key, value) -> bool:
|
||||
"""
|
||||
Store the value associated with the given key.
|
||||
Returns True if the operation was successful, False otherwise.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def batch_set(self, keys: List[str], values: List[torch.Tensor]) -> bool:
|
||||
"""
|
||||
Store multiple key-value pairs.
|
||||
Returns True if all operations were successful, False otherwise.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def exists(self, key: str) -> bool:
|
||||
"""
|
||||
Check if the key exists in the storage.
|
||||
Returns True if the key exists, False otherwise.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class HiCacheFile(HiCacheStorage):
|
||||
|
||||
def __init__(self, file_path: str = "/tmp/hicache"):
|
||||
self.file_path = file_path
|
||||
if not os.path.exists(self.file_path):
|
||||
os.makedirs(self.file_path)
|
||||
logger.info(f"Created HiCacheFile storage directory at {self.file_path}")
|
||||
|
||||
def get(
|
||||
self, key: str, target_location: Optional[torch.Tensor] = None
|
||||
) -> torch.Tensor | None:
|
||||
tensor_path = os.path.join(self.file_path, f"{key}.bin")
|
||||
try:
|
||||
# todo: fixing the target_location logic to enable in-place loading
|
||||
loaded_tensor = torch.load(tensor_path)
|
||||
if isinstance(loaded_tensor, torch.Tensor):
|
||||
return loaded_tensor
|
||||
else:
|
||||
logger.error(f"Loaded data for key {key} is not a tensor.")
|
||||
return None
|
||||
except FileNotFoundError:
|
||||
return None
|
||||
|
||||
def batch_get(
|
||||
self,
|
||||
keys: List[str],
|
||||
target_locations: Optional[List[torch.Tensor]] = None,
|
||||
) -> List[torch.Tensor | None]:
|
||||
return [
|
||||
self.get(key, target_location)
|
||||
for key, target_location in zip(
|
||||
keys, target_locations or [None] * len(keys)
|
||||
)
|
||||
]
|
||||
|
||||
def set(self, key: str, value: torch.Tensor) -> bool:
|
||||
tensor_path = os.path.join(self.file_path, f"{key}.bin")
|
||||
if self.exists(key):
|
||||
logger.debug(f"Key {key} already exists. Skipped.")
|
||||
return True
|
||||
try:
|
||||
torch.save(value, tensor_path)
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to save tensor {key}: {e}")
|
||||
return False
|
||||
|
||||
def batch_set(self, keys: List[str], values: List[torch.Tensor]) -> bool:
|
||||
for key, value in zip(keys, values):
|
||||
if not self.set(key, value):
|
||||
return False
|
||||
return True
|
||||
|
||||
def exists(self, key: str) -> bool:
|
||||
tensor_path = os.path.join(self.file_path, f"{key}.bin")
|
||||
return os.path.exists(tensor_path)
|
||||
|
||||
def delete(self, key: str) -> None:
|
||||
tensor_path = os.path.join(self.file_path, f"{key}.bin")
|
||||
try:
|
||||
os.remove(tensor_path)
|
||||
except FileNotFoundError:
|
||||
logger.warning(f"Key {key} does not exist. Cannot delete.")
|
||||
return
|
||||
|
||||
def clear(self) -> None:
|
||||
try:
|
||||
for filename in os.listdir(self.file_path):
|
||||
file_path = os.path.join(self.file_path, filename)
|
||||
if os.path.isfile(file_path):
|
||||
os.remove(file_path)
|
||||
logger.info("Cleared all entries in HiCacheFile storage.")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to clear HiCacheFile storage: {e}")
|
||||
@@ -35,6 +35,7 @@ class HiRadixCache(RadixCache):
|
||||
hicache_size: int,
|
||||
hicache_write_policy: str,
|
||||
hicache_io_backend: str,
|
||||
hicache_storage_backend: Optional[str] = None,
|
||||
):
|
||||
self.kv_cache = token_to_kv_pool_allocator.get_kvcache()
|
||||
if isinstance(self.kv_cache, MHATokenToKVPool):
|
||||
@@ -49,6 +50,9 @@ class HiRadixCache(RadixCache):
|
||||
raise ValueError(f"HiRadixCache only supports MHA and MLA yet")
|
||||
|
||||
self.tp_group = tp_cache_group
|
||||
self.enable_storage = hicache_storage_backend is not None
|
||||
# todo: customizable storage prefetch threshold
|
||||
self.prefetch_threshold = 256
|
||||
|
||||
self.load_cache_event = threading.Event()
|
||||
self.cache_controller = HiCacheController(
|
||||
@@ -58,16 +62,22 @@ class HiRadixCache(RadixCache):
|
||||
load_cache_event=self.load_cache_event,
|
||||
write_policy=hicache_write_policy,
|
||||
io_backend=hicache_io_backend,
|
||||
storage_backend=hicache_storage_backend,
|
||||
prefetch_threshold=self.prefetch_threshold,
|
||||
)
|
||||
|
||||
# record the nodes with ongoing write through
|
||||
self.ongoing_write_through = {}
|
||||
# record the node segments with ongoing load back
|
||||
self.ongoing_load_back = {}
|
||||
# record the ongoing prefetch requests
|
||||
self.ongoing_prefetch = {}
|
||||
self.ongoing_backup = {}
|
||||
# todo: dynamically adjust the threshold
|
||||
self.write_through_threshold = (
|
||||
1 if hicache_write_policy == "write_through" else 3
|
||||
)
|
||||
self.write_through_threshold_storage = 3
|
||||
self.load_back_threshold = 10
|
||||
super().__init__(
|
||||
req_to_token_pool, token_to_kv_pool_allocator, page_size, disable=False
|
||||
@@ -108,13 +118,30 @@ class HiRadixCache(RadixCache):
|
||||
|
||||
return len(host_indices)
|
||||
|
||||
def write_backup_storage(self, node: TreeNode):
|
||||
operation_id = self.cache_controller.write_storage(
|
||||
node.host_value, node.key, node.parent.get_last_hash_value()
|
||||
)
|
||||
self.ongoing_backup[operation_id] = node
|
||||
node.protect_host()
|
||||
|
||||
def inc_hit_count(self, node: TreeNode):
|
||||
if node.backuped or self.cache_controller.write_policy == "write_back":
|
||||
if self.cache_controller.write_policy == "write_back":
|
||||
return
|
||||
node.hit_count += 1
|
||||
if node.hit_count >= self.write_through_threshold:
|
||||
self.write_backup(node)
|
||||
node.hit_count = 0
|
||||
|
||||
if not node.backuped:
|
||||
if node.hit_count >= self.write_through_threshold:
|
||||
# write to host if the node is not backuped
|
||||
self.write_backup(node)
|
||||
else:
|
||||
if (
|
||||
self.enable_storage
|
||||
and (not node.backuped_storage)
|
||||
and node.hit_count >= self.write_through_threshold_storage
|
||||
):
|
||||
# if the node is backuped on host memory but not on storage
|
||||
self.write_backup_storage(node)
|
||||
|
||||
def writing_check(self, write_back=False):
|
||||
if write_back:
|
||||
@@ -221,6 +248,10 @@ class HiRadixCache(RadixCache):
|
||||
if not x.evicted:
|
||||
continue
|
||||
|
||||
# node is protected from eviction as it has ongoing prefetch or backup to storage
|
||||
if x.host_ref_counter > 0:
|
||||
continue
|
||||
|
||||
num_evicted += self.cache_controller.evict_host(x.host_value)
|
||||
|
||||
for k, v in x.parent.children.items():
|
||||
@@ -314,6 +345,85 @@ class HiRadixCache(RadixCache):
|
||||
def check_hicache_events(self):
|
||||
self.writing_check()
|
||||
self.loading_check()
|
||||
if self.enable_storage:
|
||||
self.check_revoked_prefetch()
|
||||
self.check_backup_progress()
|
||||
|
||||
def check_revoked_prefetch(self):
|
||||
queue_size = torch.tensor(
|
||||
self.cache_controller.prefetch_revoke_queue.qsize(), dtype=torch.int
|
||||
)
|
||||
if torch.distributed.get_world_size(group=self.tp_group) > 1:
|
||||
# synchrnoize TP workers to make the same update to hiradix cache
|
||||
torch.distributed.all_reduce(
|
||||
queue_size,
|
||||
op=torch.distributed.ReduceOp.MIN,
|
||||
group=self.tp_group,
|
||||
)
|
||||
for _ in range(queue_size.item()):
|
||||
req_id = self.cache_controller.prefetch_revoke_queue.get()
|
||||
if req_id in self.ongoing_prefetch:
|
||||
last_host_node, _, host_indices, _ = self.ongoing_prefetch[req_id]
|
||||
last_host_node.release_host()
|
||||
self.cache_controller.mem_pool_host.free(host_indices)
|
||||
del self.ongoing_prefetch[req_id]
|
||||
|
||||
def check_backup_progress(self):
|
||||
queue_size = torch.tensor(
|
||||
self.cache_controller.ack_backup_queue.qsize(), dtype=torch.int
|
||||
)
|
||||
if torch.distributed.get_world_size(group=self.tp_group) > 1:
|
||||
# synchrnoize TP workers to make the same update to hiradix cache
|
||||
torch.distributed.all_reduce(
|
||||
queue_size,
|
||||
op=torch.distributed.ReduceOp.MIN,
|
||||
group=self.tp_group,
|
||||
)
|
||||
for _ in range(queue_size.item()):
|
||||
ack_id, hash_value = self.cache_controller.ack_backup_queue.get()
|
||||
self.ongoing_backup[ack_id].hash_value = hash_value
|
||||
self.ongoing_backup[ack_id].release_host()
|
||||
del self.ongoing_backup[ack_id]
|
||||
|
||||
def check_prefetch_progress(self, req_id: str):
|
||||
if req_id not in self.ongoing_prefetch:
|
||||
# there is no ongoing prefetch for this request or it has been revoked
|
||||
return
|
||||
|
||||
# todo: more policies for prefetch progress such as timeout
|
||||
# the current policy is to prefetch with best effort and terminate when queuing is over
|
||||
last_host_node, token_ids, host_indices, operation = self.ongoing_prefetch[
|
||||
req_id
|
||||
]
|
||||
completed_tokens, hash_value = self.cache_controller.terminate_prefetch(
|
||||
operation
|
||||
)
|
||||
logger.debug(f"Prefetch {req_id} completed with {completed_tokens} tokens")
|
||||
|
||||
min_completed_tokens = torch.tensor(completed_tokens, dtype=torch.int)
|
||||
if torch.distributed.get_world_size(group=self.tp_group) > 1:
|
||||
# synchrnoize TP workers to make the same update to hiradix cache
|
||||
torch.distributed.all_reduce(
|
||||
min_completed_tokens,
|
||||
op=torch.distributed.ReduceOp.MIN,
|
||||
group=self.tp_group,
|
||||
)
|
||||
min_completed_tokens = min_completed_tokens.item()
|
||||
fetched_token_ids = token_ids[:min_completed_tokens]
|
||||
written_indices = host_indices[:min_completed_tokens]
|
||||
matched_length = self._insert_helper_host(
|
||||
last_host_node,
|
||||
fetched_token_ids,
|
||||
written_indices,
|
||||
hash_value[:min_completed_tokens],
|
||||
)
|
||||
|
||||
self.cache_controller.mem_pool_host.free(host_indices[:matched_length])
|
||||
self.cache_controller.mem_pool_host.free(
|
||||
host_indices[min_completed_tokens:completed_tokens]
|
||||
)
|
||||
last_host_node.release_host()
|
||||
del self.ongoing_prefetch[req_id]
|
||||
|
||||
def match_prefix(self, key: List[int], **kwargs):
|
||||
empty_value = torch.empty((0,), dtype=torch.int64, device=self.device)
|
||||
@@ -348,6 +458,71 @@ class HiRadixCache(RadixCache):
|
||||
host_hit_length=host_hit_length,
|
||||
)
|
||||
|
||||
def prefetch_from_storage(
|
||||
self,
|
||||
req_id: str,
|
||||
last_host_node: TreeNode,
|
||||
new_input_tokens: List[int],
|
||||
last_hash: Optional[str] = None,
|
||||
):
|
||||
if not self.enable_storage or len(new_input_tokens) < self.prefetch_threshold:
|
||||
return
|
||||
|
||||
last_host_node.protect_host()
|
||||
host_indices = self.cache_controller.mem_pool_host.alloc(len(new_input_tokens))
|
||||
if host_indices is None:
|
||||
self.evict_host(len(new_input_tokens))
|
||||
host_indices = self.cache_controller.mem_pool_host.alloc(
|
||||
len(new_input_tokens)
|
||||
)
|
||||
if host_indices is None:
|
||||
last_host_node.release_host()
|
||||
# no sufficient host memory to prefetch
|
||||
return
|
||||
operation = self.cache_controller.prefetch(
|
||||
req_id, host_indices, new_input_tokens, last_hash
|
||||
)
|
||||
self.ongoing_prefetch[req_id] = (
|
||||
last_host_node,
|
||||
new_input_tokens,
|
||||
host_indices,
|
||||
operation,
|
||||
)
|
||||
|
||||
def _insert_helper_host(self, node: TreeNode, key: List, host_value, hash_value):
|
||||
node.last_access_time = time.monotonic()
|
||||
if len(key) == 0:
|
||||
return 0
|
||||
|
||||
child_key = self.get_child_key_fn(key)
|
||||
|
||||
matched_length = 0
|
||||
while len(key) > 0 and child_key in node.children.keys():
|
||||
node = node.children[child_key]
|
||||
node.last_access_time = time.monotonic()
|
||||
prefix_len = self.key_match_fn(node.key, key)
|
||||
key = key[prefix_len:]
|
||||
host_value = host_value[prefix_len:]
|
||||
hash_value = hash_value[prefix_len:]
|
||||
matched_length += prefix_len
|
||||
|
||||
if prefix_len < len(node.key):
|
||||
new_node = self._split_node(node.key, node, prefix_len)
|
||||
node = new_node
|
||||
|
||||
if len(key):
|
||||
child_key = self.get_child_key_fn(key)
|
||||
|
||||
if len(key):
|
||||
new_node = TreeNode()
|
||||
new_node.parent = node
|
||||
new_node.key = key
|
||||
new_node.value = None
|
||||
new_node.host_value = host_value
|
||||
new_node.hash_value = hash_value
|
||||
node.children[child_key] = new_node
|
||||
return matched_length
|
||||
|
||||
def _match_prefix_helper(self, node: TreeNode, key: List):
|
||||
node.last_access_time = time.monotonic()
|
||||
child_key = self.get_child_key_fn(key)
|
||||
|
||||
@@ -99,6 +99,20 @@ class HostKVCache(abc.ABC):
|
||||
def init_kv_buffer(self):
|
||||
raise NotImplementedError()
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_flat_data_page(self, index) -> torch.Tensor:
|
||||
"""
|
||||
Get a flat data page from the host memory pool.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@abc.abstractmethod
|
||||
def set_from_flat_data_page(self, index: int, data_page: torch.Tensor) -> None:
|
||||
"""
|
||||
Set a flat data page to the host memory pool.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@synchronized()
|
||||
def clear(self):
|
||||
# Initialize memory states and tracking structures.
|
||||
@@ -227,6 +241,19 @@ class MHATokenToKVPoolHost(HostKVCache):
|
||||
pin_memory=self.pin_memory,
|
||||
)
|
||||
|
||||
# todo, page first memory layout
|
||||
def get_flat_data_page(self, index) -> torch.Tensor:
|
||||
return self.kv_buffer[:, :, index : index + self.page_size, :, :].flatten()
|
||||
|
||||
def set_from_flat_data_page(self, index: int, data_page: torch.Tensor) -> None:
|
||||
self.kv_buffer[:, :, index : index + self.page_size, :, :] = data_page.reshape(
|
||||
2,
|
||||
self.layer_num,
|
||||
self.page_size,
|
||||
self.head_num,
|
||||
self.head_dim,
|
||||
)
|
||||
|
||||
@property
|
||||
def k_buffer(self):
|
||||
return self.kv_buffer[0]
|
||||
@@ -276,3 +303,14 @@ class MLATokenToKVPoolHost(HostKVCache):
|
||||
device=self.device,
|
||||
pin_memory=self.pin_memory,
|
||||
)
|
||||
|
||||
def get_flat_data_page(self, index) -> torch.Tensor:
|
||||
return self.kv_buffer[:, index : index + self.page_size, :, :].flatten()
|
||||
|
||||
def set_from_flat_data_page(self, index: int, data_page: torch.Tensor) -> None:
|
||||
self.kv_buffer[:, index : index + self.page_size, :, :] = data_page.reshape(
|
||||
self.layer_num,
|
||||
self.page_size,
|
||||
1,
|
||||
self.kv_lora_rank + self.qk_rope_head_dim,
|
||||
)
|
||||
|
||||
@@ -55,8 +55,13 @@ class TreeNode:
|
||||
self.hit_count = 0
|
||||
# indicating the node is loading KV cache from host
|
||||
self.loading = False
|
||||
# indicating the node is locked to protect from eviction
|
||||
# incremented when the node is referenced by a storage operation
|
||||
self.host_ref_counter = 0
|
||||
# store the host indices of KV cache
|
||||
self.host_value: Optional[torch.Tensor] = None
|
||||
# store hash values of each pages
|
||||
self.hash_value: Optional[List[str]] = None
|
||||
|
||||
self.id = TreeNode.counter if id is None else id
|
||||
TreeNode.counter += 1
|
||||
@@ -69,6 +74,27 @@ class TreeNode:
|
||||
def backuped(self):
|
||||
return self.host_value is not None
|
||||
|
||||
@property
|
||||
def backuped_storage(self):
|
||||
return self.hash_value is not None and len(self.hash_value) > 0
|
||||
|
||||
def protect_host(self):
|
||||
"""Protect the host value from eviction."""
|
||||
self.host_ref_counter += 1
|
||||
|
||||
def release_host(self):
|
||||
"""Release the host value, allowing it to be evicted."""
|
||||
if self.host_ref_counter > 0:
|
||||
self.host_ref_counter -= 1
|
||||
else:
|
||||
raise RuntimeError("Host reference counter is already zero.")
|
||||
|
||||
def get_last_hash_value(self) -> Optional[str]:
|
||||
"""Returns the hash value of the last page in this node."""
|
||||
if self.hash_value is None or len(self.hash_value) == 0:
|
||||
return None
|
||||
return self.hash_value[-1]
|
||||
|
||||
def __lt__(self, other: "TreeNode"):
|
||||
return self.last_access_time < other.last_access_time
|
||||
|
||||
|
||||
Reference in New Issue
Block a user