feat(hicache): Support passing prefix keys for l3 store. (#9045)
Co-authored-by: pansicheng <sicheng.pan.chn@gmail.com> Co-authored-by: Zhiqiang Xie <xiezhq@stanford.edu>
This commit is contained in:
@@ -22,7 +22,10 @@ from typing import TYPE_CHECKING, List, NamedTuple, Optional, Set, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.srt.mem_cache.hicache_storage import HiCacheStorageConfig
|
||||
from sglang.srt.mem_cache.hicache_storage import (
|
||||
HiCacheStorageConfig,
|
||||
HiCacheStorageExtraInfo,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
|
||||
@@ -191,12 +194,14 @@ class StorageOperation:
|
||||
token_ids: List[int],
|
||||
last_hash: Optional[str] = None,
|
||||
hash_value: Optional[List[str]] = None,
|
||||
prefix_keys: Optional[List[str]] = None,
|
||||
):
|
||||
self.host_indices = host_indices
|
||||
self.token_ids = token_ids
|
||||
self.last_hash = last_hash
|
||||
self.completed_tokens = 0
|
||||
self.hash_value = hash_value if hash_value is not None else []
|
||||
self.prefix_keys = prefix_keys
|
||||
|
||||
self.id = StorageOperation.counter
|
||||
StorageOperation.counter += 1
|
||||
@@ -212,6 +217,7 @@ class PrefetchOperation(StorageOperation):
|
||||
host_indices: torch.Tensor,
|
||||
token_ids: List[int],
|
||||
last_hash: Optional[str] = None,
|
||||
prefix_keys: Optional[List[str]] = None,
|
||||
):
|
||||
self.request_id = request_id
|
||||
|
||||
@@ -219,7 +225,7 @@ class PrefetchOperation(StorageOperation):
|
||||
self._terminated_flag = False
|
||||
self.start_time = time.monotonic()
|
||||
|
||||
super().__init__(host_indices, token_ids, last_hash)
|
||||
super().__init__(host_indices, token_ids, last_hash, prefix_keys=prefix_keys)
|
||||
|
||||
def increment(self, num_tokens: int):
|
||||
with self._lock:
|
||||
@@ -550,12 +556,13 @@ class HiCacheController:
|
||||
host_indices: torch.Tensor,
|
||||
new_input_tokens: List[int],
|
||||
last_hash: Optional[str] = None,
|
||||
prefix_keys: Optional[List[str]] = None,
|
||||
) -> PrefetchOperation:
|
||||
"""
|
||||
Prefetch KV caches from storage backend to host memory.
|
||||
"""
|
||||
operation = PrefetchOperation(
|
||||
request_id, host_indices, new_input_tokens, last_hash
|
||||
request_id, host_indices, new_input_tokens, last_hash, prefix_keys
|
||||
)
|
||||
self.prefetch_queue.put(operation)
|
||||
return operation
|
||||
@@ -571,8 +578,12 @@ class HiCacheController:
|
||||
for page in pages:
|
||||
self.host_mem_release_queue.put(page)
|
||||
|
||||
def _page_get_zero_copy(self, operation, hash_values, host_indices):
|
||||
results = self.storage_backend.batch_get_v1(hash_values, host_indices)
|
||||
def _page_get_zero_copy(
|
||||
self, operation, hash_values, host_indices, extra_info=None
|
||||
):
|
||||
results = self.storage_backend.batch_get_v1(
|
||||
hash_values, host_indices, extra_info
|
||||
)
|
||||
inc = 0
|
||||
for i in range(len(hash_values)):
|
||||
if not results[i]:
|
||||
@@ -584,7 +595,7 @@ class HiCacheController:
|
||||
operation.increment(inc)
|
||||
|
||||
# todo: deprecate
|
||||
def _generic_page_get(self, operation, hash_values, host_indices):
|
||||
def _generic_page_get(self, operation, hash_values, host_indices, extra_info=None):
|
||||
dummy_page_dst = [
|
||||
self.mem_pool_host.get_dummy_flat_data_page() for _ in hash_values
|
||||
]
|
||||
@@ -608,6 +619,7 @@ class HiCacheController:
|
||||
|
||||
def _page_transfer(self, operation):
|
||||
# Transfer batch by batch
|
||||
prefix_keys = operation.prefix_keys
|
||||
for i in range(0, len(operation.hash_value), self.storage_batch_size):
|
||||
batch_hashes = operation.hash_value[i : i + self.storage_batch_size]
|
||||
batch_host_indices = operation.host_indices[
|
||||
@@ -615,7 +627,8 @@ class HiCacheController:
|
||||
]
|
||||
prev_completed_tokens = operation.completed_tokens
|
||||
# Get one batch token, and update the completed_tokens if succeed
|
||||
self.page_get_func(operation, batch_hashes, batch_host_indices)
|
||||
extra_info = HiCacheStorageExtraInfo(prefix_keys=prefix_keys)
|
||||
self.page_get_func(operation, batch_hashes, batch_host_indices, extra_info)
|
||||
# Check termination
|
||||
if (
|
||||
operation.completed_tokens
|
||||
@@ -623,6 +636,10 @@ class HiCacheController:
|
||||
):
|
||||
operation.mark_terminate()
|
||||
break # Some operations fail or operation terminated by controller
|
||||
|
||||
if prefix_keys and len(prefix_keys) > 0:
|
||||
prefix_keys += batch_hashes
|
||||
|
||||
# release pre-allocated memory
|
||||
self.append_host_mem_release(
|
||||
operation.host_indices[operation.completed_tokens :]
|
||||
@@ -656,6 +673,7 @@ class HiCacheController:
|
||||
def _storage_hit_query(self, operation) -> tuple[list[str], int]:
|
||||
last_hash = operation.last_hash
|
||||
tokens_to_fetch = operation.token_ids
|
||||
prefix_keys = operation.prefix_keys.copy() if operation.prefix_keys else None
|
||||
|
||||
storage_query_count = 0
|
||||
hash_value = []
|
||||
@@ -673,11 +691,15 @@ class HiCacheController:
|
||||
batch_tokens[i : i + self.page_size], last_hash
|
||||
)
|
||||
batch_hashes.append(last_hash)
|
||||
hit_page_num = self.storage_backend.batch_exists(batch_hashes)
|
||||
extra_info = HiCacheStorageExtraInfo(prefix_keys=prefix_keys)
|
||||
hit_page_num = self.storage_backend.batch_exists(batch_hashes, extra_info)
|
||||
hash_value.extend(batch_hashes[:hit_page_num])
|
||||
storage_query_count += hit_page_num * self.page_size
|
||||
if hit_page_num < len(batch_hashes):
|
||||
break
|
||||
if prefix_keys and len(prefix_keys) > 0:
|
||||
prefix_keys += batch_hashes
|
||||
|
||||
return hash_value, storage_query_count
|
||||
|
||||
def prefetch_thread_func(self):
|
||||
@@ -734,28 +756,34 @@ class HiCacheController:
|
||||
host_indices: torch.Tensor,
|
||||
token_ids: List[int],
|
||||
hash_value: Optional[List[str]] = None,
|
||||
prefix_keys: Optional[List[str]] = None,
|
||||
) -> int:
|
||||
"""
|
||||
Write KV caches from host memory to storage backend.
|
||||
"""
|
||||
operation = StorageOperation(host_indices, token_ids, hash_value=hash_value)
|
||||
operation = StorageOperation(
|
||||
host_indices, token_ids, hash_value=hash_value, prefix_keys=prefix_keys
|
||||
)
|
||||
self.backup_queue.put(operation)
|
||||
return operation.id
|
||||
|
||||
# todo: deprecate
|
||||
def _generic_page_set(self, hash_values, host_indices) -> bool:
|
||||
def _generic_page_set(self, hash_values, host_indices, extra_info=None) -> bool:
|
||||
data = [
|
||||
self.mem_pool_host.get_data_page(host_indices[i * self.page_size])
|
||||
for i in range(len(hash_values))
|
||||
]
|
||||
return self.storage_backend.batch_set(hash_values, data)
|
||||
|
||||
def _page_set_zero_copy(self, hash_values, host_indices) -> bool:
|
||||
return all(self.storage_backend.batch_set_v1(hash_values, host_indices))
|
||||
def _page_set_zero_copy(self, hash_values, host_indices, extra_info=None) -> bool:
|
||||
return all(
|
||||
self.storage_backend.batch_set_v1(hash_values, host_indices, extra_info)
|
||||
)
|
||||
|
||||
# Backup batch by batch
|
||||
def _page_backup(self, operation):
|
||||
# Backup batch by batch
|
||||
prefix_keys = operation.prefix_keys
|
||||
for i in range(0, len(operation.hash_value), self.storage_batch_size):
|
||||
batch_hashes = operation.hash_value[i : i + self.storage_batch_size]
|
||||
batch_host_indices = operation.host_indices[
|
||||
@@ -763,12 +791,16 @@ class HiCacheController:
|
||||
]
|
||||
# Set one batch token, and record if success.
|
||||
# todo: allow partial success
|
||||
success = self.page_set_func(batch_hashes, batch_host_indices)
|
||||
extra_info = HiCacheStorageExtraInfo(prefix_keys=prefix_keys)
|
||||
success = self.page_set_func(batch_hashes, batch_host_indices, extra_info)
|
||||
if not success:
|
||||
logger.warning(
|
||||
f"Write page to storage: {len(batch_hashes)} pages failed."
|
||||
)
|
||||
break
|
||||
|
||||
if prefix_keys and len(prefix_keys) > 0:
|
||||
prefix_keys += batch_hashes
|
||||
operation.completed_tokens += self.page_size * len(batch_hashes)
|
||||
|
||||
def backup_thread_func(self):
|
||||
|
||||
@@ -1491,8 +1491,18 @@ class Scheduler(
|
||||
last_hash = req.last_host_node.get_last_hash_value()
|
||||
matched_len = len(req.prefix_indices) + req.host_hit_length
|
||||
new_input_tokens = req.fill_ids[matched_len:]
|
||||
|
||||
prefix_keys = (
|
||||
req.last_node.get_prefix_hash_values(req.last_node.parent)
|
||||
if self.tree_cache.hicache_storage_pass_prefix_keys
|
||||
else None
|
||||
)
|
||||
self.tree_cache.prefetch_from_storage(
|
||||
req.rid, req.last_host_node, new_input_tokens, last_hash
|
||||
req.rid,
|
||||
req.last_host_node,
|
||||
new_input_tokens,
|
||||
last_hash,
|
||||
prefix_keys,
|
||||
)
|
||||
|
||||
def _add_request_to_queue(self, req: Req, is_retracted: bool = False):
|
||||
|
||||
@@ -36,6 +36,7 @@ class HiCacheStorageConfig:
|
||||
|
||||
@dataclass
|
||||
class HiCacheStorageExtraInfo:
|
||||
prefix_keys: Optional[List[str]] = (None,)
|
||||
extra_info: Optional[dict] = None
|
||||
|
||||
|
||||
@@ -139,7 +140,9 @@ class HiCacheStorage(ABC):
|
||||
pass
|
||||
|
||||
# TODO: Use a finer-grained return type (e.g., List[bool])
|
||||
def batch_exists(self, keys: List[str]) -> int:
|
||||
def batch_exists(
|
||||
self, keys: List[str], extra_info: Optional[HiCacheStorageExtraInfo] = None
|
||||
) -> int:
|
||||
"""
|
||||
Check if the keys exist in the storage.
|
||||
return the number of consecutive existing keys from the start.
|
||||
|
||||
@@ -84,12 +84,14 @@ class HiRadixCache(RadixCache):
|
||||
prefetch_threshold,
|
||||
prefetch_timeout_base,
|
||||
prefetch_timeout_per_ki_token,
|
||||
hicache_storage_pass_prefix_keys,
|
||||
) = self._parse_storage_backend_extra_config(storage_backend_extra_config)
|
||||
self.prefetch_threshold = prefetch_threshold
|
||||
self.prefetch_timeout_base = prefetch_timeout_base
|
||||
self.prefetch_timeout_per_page = (
|
||||
page_size / 1024 * prefetch_timeout_per_ki_token
|
||||
)
|
||||
self.hicache_storage_pass_prefix_keys = hicache_storage_pass_prefix_keys
|
||||
# TODO: support more timeout check functions
|
||||
self.is_prefetch_timeout = self._prefetch_timeout_check_linear_func
|
||||
self.prefetch_stop_policy = hicache_storage_prefetch_policy
|
||||
@@ -149,7 +151,7 @@ class HiRadixCache(RadixCache):
|
||||
storage_backend_extra_config: JSON string containing extra configuration
|
||||
|
||||
Returns:
|
||||
tuple: (extra_config_dict, prefetch_threshold, prefetch_timeout_base, prefetch_timeout_per_ki_token)
|
||||
tuple: (extra_config_dict, prefetch_threshold, prefetch_timeout_base, prefetch_timeout_per_ki_token, hicache_storage_pass_prefix_keys)
|
||||
"""
|
||||
# Parse extra config JSON if provided
|
||||
extra_config = {}
|
||||
@@ -165,6 +167,9 @@ class HiRadixCache(RadixCache):
|
||||
prefetch_timeout_per_ki_token = extra_config.pop(
|
||||
"prefetch_timeout_per_ki_token", 0.25
|
||||
) # seconds per 1024 tokens
|
||||
hicache_storage_pass_prefix_keys = extra_config.pop(
|
||||
"hicache_storage_pass_prefix_keys", False
|
||||
)
|
||||
|
||||
if not isinstance(prefetch_threshold, int):
|
||||
raise ValueError(
|
||||
@@ -184,6 +189,7 @@ class HiRadixCache(RadixCache):
|
||||
prefetch_threshold,
|
||||
float(prefetch_timeout_base),
|
||||
float(prefetch_timeout_per_ki_token),
|
||||
hicache_storage_pass_prefix_keys,
|
||||
)
|
||||
|
||||
def reset(self):
|
||||
@@ -245,8 +251,14 @@ class HiRadixCache(RadixCache):
|
||||
return len(host_indices)
|
||||
|
||||
def write_backup_storage(self, node: TreeNode):
|
||||
prefix_keys = (
|
||||
node.get_prefix_hash_values(node.parent)
|
||||
if self.hicache_storage_pass_prefix_keys
|
||||
else None
|
||||
)
|
||||
|
||||
operation_id = self.cache_controller.write_storage(
|
||||
node.host_value, node.key, node.hash_value
|
||||
node.host_value, node.key, node.hash_value, prefix_keys
|
||||
)
|
||||
self.ongoing_backup[operation_id] = node
|
||||
node.protect_host()
|
||||
@@ -700,6 +712,7 @@ class HiRadixCache(RadixCache):
|
||||
last_host_node: TreeNode,
|
||||
new_input_tokens: List[int],
|
||||
last_hash: Optional[str] = None,
|
||||
prefix_keys: Optional[List[str]] = None,
|
||||
):
|
||||
# align the number of fetching tokens to the page size
|
||||
prefetch_length = len(new_input_tokens) - (
|
||||
@@ -723,7 +736,7 @@ class HiRadixCache(RadixCache):
|
||||
# no sufficient host memory for prefetch
|
||||
return
|
||||
operation = self.cache_controller.prefetch(
|
||||
req_id, host_indices, new_input_tokens, last_hash
|
||||
req_id, host_indices, new_input_tokens, last_hash, prefix_keys
|
||||
)
|
||||
self.ongoing_prefetch[req_id] = (
|
||||
last_host_node,
|
||||
|
||||
@@ -22,7 +22,7 @@ The radix tree data structure for managing the KV cache.
|
||||
import heapq
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from functools import partial
|
||||
from functools import lru_cache, partial
|
||||
from typing import TYPE_CHECKING, Any, Iterator, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
@@ -114,6 +114,13 @@ class TreeNode:
|
||||
return None
|
||||
return self.hash_value[-1]
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def get_prefix_hash_values(self, node: TreeNode) -> List[str]:
|
||||
if node is None or node.hash_value is None:
|
||||
return []
|
||||
|
||||
return node.get_prefix_hash_values(node.parent) + node.hash_value
|
||||
|
||||
def __lt__(self, other: "TreeNode"):
|
||||
return self.last_access_time < other.last_access_time
|
||||
|
||||
|
||||
@@ -13,7 +13,11 @@ from aibrix_kvcache import (
|
||||
)
|
||||
from aibrix_kvcache.common.absl_logging import log_every_n_seconds
|
||||
|
||||
from sglang.srt.mem_cache.hicache_storage import HiCacheStorage, HiCacheStorageConfig
|
||||
from sglang.srt.mem_cache.hicache_storage import (
|
||||
HiCacheStorage,
|
||||
HiCacheStorageConfig,
|
||||
HiCacheStorageExtraInfo,
|
||||
)
|
||||
from sglang.srt.mem_cache.memory_pool_host import HostKVCache
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -140,7 +144,9 @@ class AibrixKVCacheStorage(HiCacheStorage):
|
||||
) -> bool:
|
||||
return self.batch_set([key], [value], [target_location], [target_size])
|
||||
|
||||
def batch_exists(self, keys: List[str]) -> int:
|
||||
def batch_exists(
|
||||
self, keys: List[str], extra_info: Optional[HiCacheStorageExtraInfo] = None
|
||||
) -> int:
|
||||
block_hash = BlockHashes(keys, self.page_size)
|
||||
status = self.kv_cache_manager.exists(None, block_hash)
|
||||
if status.is_ok():
|
||||
|
||||
@@ -408,7 +408,9 @@ class EICStorage(HiCacheStorage):
|
||||
exist_num = self.batch_exists([key])
|
||||
return exist_num == 1
|
||||
|
||||
def batch_exists(self, keys) -> int:
|
||||
def batch_exists(
|
||||
self, keys, extra_info: Optional[HiCacheStorageExtraInfo] = None
|
||||
) -> int:
|
||||
if len(keys) == 0:
|
||||
return 0
|
||||
if self.use_zero_copy and not self.is_mla_model:
|
||||
|
||||
@@ -454,7 +454,9 @@ class HiCacheHF3FS(HiCacheStorage):
|
||||
result = self.metadata_client.exists(self.rank, [key])
|
||||
return result[0] if result else False
|
||||
|
||||
def batch_exists(self, keys: List[str]) -> int:
|
||||
def batch_exists(
|
||||
self, keys: List[str], extra_info: Optional[HiCacheStorageExtraInfo] = None
|
||||
) -> int:
|
||||
factor = 1
|
||||
if self.is_zero_copy and not self.is_mla_model:
|
||||
keys = self._get_mha_zero_copy_keys(keys)
|
||||
|
||||
@@ -399,7 +399,9 @@ class MooncakeStore(HiCacheStorage):
|
||||
exist_result = self._batch_exist([key])
|
||||
return exist_result[0] == 1
|
||||
|
||||
def batch_exists(self, keys) -> int:
|
||||
def batch_exists(
|
||||
self, keys, extra_info: Optional[HiCacheStorageExtraInfo] = None
|
||||
) -> int:
|
||||
if self.is_mla_backend:
|
||||
query_keys = [f"{key}_k" for key in keys]
|
||||
key_multiplier = 1
|
||||
|
||||
Reference in New Issue
Block a user