Hicache Storage Layer Prototype (#7704)

This commit is contained in:
Zhiqiang Xie
2025-07-18 00:20:19 -07:00
committed by GitHub
parent 7891bac16b
commit 9d33fcfb8e
9 changed files with 714 additions and 4 deletions

View File

@@ -25,6 +25,8 @@ if TYPE_CHECKING:
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
from sglang.srt.mem_cache.memory_pool_host import HostKVCache
from sglang.srt.mem_cache.hicache_storage import HiCacheFile, get_hash_str
logger = logging.getLogger(__name__)
@@ -159,6 +161,57 @@ class TransferBuffer:
self.buffers.queue.clear()
class StorageOperation:
    """A unit of KV-cache transfer work between host memory and the storage backend.

    Every instance draws a monotonically increasing id from the class-level
    counter, so operations can be ordered by creation time (see __lt__).
    """

    # Class-wide id source; incremented once per constructed operation.
    counter = 0

    def __init__(
        self,
        host_indices: torch.Tensor,
        token_ids: List[int],
        last_hash: Optional[str] = None,
    ):
        self.host_indices = host_indices  # host memory slots backing the tokens
        self.token_ids = token_ids  # token ids covered by this operation
        self.last_hash = last_hash  # hash-chain prefix, if continuing a sequence
        self.completed_tokens = 0  # tokens transferred so far
        self.hash_value = []  # per-page hashes accumulated so far

        # Hand out a creation-ordered id.
        self.id = StorageOperation.counter
        StorageOperation.counter += 1

    def __lt__(self, other: "StorageOperation") -> bool:
        # Earlier-created operations (smaller id) sort first.
        return self.id < other.id
class PrefetchOperation(StorageOperation):
    """Storage-to-host prefetch request tied to a single inference request.

    Extends StorageOperation with cooperative cancellation: once the
    controller marks the operation done, further progress accounting via
    increment() becomes a no-op.
    """

    def __init__(
        self,
        request_id: str,
        host_indices: torch.Tensor,
        token_ids: List[int],
        last_hash: Optional[str] = None,
    ):
        self.request_id = request_id
        # Protects _done_flag / completed_tokens across the controller
        # thread and the prefetch IO worker thread.
        self._done_flag = False
        self._lock = threading.Lock()
        super().__init__(host_indices, token_ids, last_hash)

    def increment(self, num_tokens: int):
        """Account num_tokens more prefetched tokens, unless already terminated."""
        with self._lock:
            if not self._done_flag:
                self.completed_tokens += num_tokens

    def mark_done(self):
        """Terminate the operation; later increment() calls are ignored."""
        with self._lock:
            self._done_flag = True

    def is_done(self) -> bool:
        # Lock-free read: a single boolean flip, polled opportunistically.
        return self._done_flag
class HiCacheController:
def __init__(
@@ -169,6 +222,8 @@ class HiCacheController:
load_cache_event: threading.Event = None,
write_policy: str = "write_through_selective",
io_backend: str = "",
storage_backend: Optional[str] = None,
prefetch_threshold: int = 256,
):
self.mem_pool_device_allocator = token_to_kv_pool_allocator
self.mem_pool_device = token_to_kv_pool_allocator.get_kvcache()
@@ -186,6 +241,19 @@ class HiCacheController:
else:
self.io_backend = io_backend
self.enable_storage = False
# todo: move backend initialization to storage backend module
if storage_backend is not None:
if storage_backend == "file":
self.storage_backend = HiCacheFile()
self.enable_storage = True
# todo: threshold policy for prefetching
self.prefetch_threshold = prefetch_threshold
else:
raise NotImplementedError(
f"Unsupported storage backend: {storage_backend}"
)
self.load_cache_event = load_cache_event
self.layer_done_counter = LayerDoneCounter(self.mem_pool_device.layer_num)
self.mem_pool_device.register_layer_transfer_counter(self.layer_done_counter)
@@ -218,9 +286,26 @@ class HiCacheController:
self.load_thread = threading.Thread(
target=self.load_thread_func_layer_by_layer, daemon=True
)
self.write_thread.start()
self.load_thread.start()
if self.enable_storage:
self.prefetch_thread = threading.Thread(
target=self.prefetch_thread_func, daemon=True
)
self.backup_thread = threading.Thread(
target=self.backup_thread_func, daemon=True
)
self.prefetch_queue = Queue()
self.backup_queue = Queue()
self.prefetch_revoke_queue = Queue()
self.ack_backup_queue = Queue()
self.prefetch_thread.start()
self.backup_thread.start()
def reset(self):
self.stop_event.set()
self.write_thread.join()
@@ -232,6 +317,13 @@ class HiCacheController:
self.load_buffer.clear()
self.ack_write_queue.queue.clear()
self.ack_load_queue.queue.clear()
if self.enable_storage:
self.prefetch_thread.join()
self.backup_thread.join()
self.prefetch_queue.queue.clear()
self.backup_queue.queue.clear()
self.prefetch_revoke_queue.queue.clear()
self.ack_backup_queue.queue.clear()
self.write_thread = threading.Thread(
target=self.write_thread_func_direct, daemon=True
@@ -243,6 +335,16 @@ class HiCacheController:
self.write_thread.start()
self.load_thread.start()
if self.enable_storage:
self.prefetch_thread = threading.Thread(
target=self.prefetch_thread_func, daemon=True
)
self.backup_thread = threading.Thread(
target=self.backup_thread_func, daemon=True
)
self.prefetch_thread.start()
self.backup_thread.start()
def write(
self,
device_indices: torch.Tensor,
@@ -383,3 +485,142 @@ class HiCacheController:
raise ValueError(
f"Inconsistent states: {self.mem_pool_host.get_state(host_indices)}"
)
def prefetch(
    self,
    request_id: str,
    host_indices: torch.Tensor,
    new_input_tokens: List[int],
    last_hash: Optional[str] = None,
) -> PrefetchOperation:
    """
    Schedule a prefetch of KV caches from the storage backend to host memory.

    The operation is queued for the prefetch thread; its progress can be
    observed on the returned handle and stopped via terminate_prefetch().

    Args:
        request_id: id of the inference request this prefetch serves.
        host_indices: pre-allocated host memory slots to fill.
        new_input_tokens: token ids whose KV pages should be fetched.
        last_hash: hash-chain prefix of the tokens preceding
            new_input_tokens, if any.

    Returns:
        The queued PrefetchOperation handle. (The original annotation said
        ``int``, but the operation object itself is what is returned.)
    """
    operation = PrefetchOperation(
        request_id, host_indices, new_input_tokens, last_hash
    )
    self.prefetch_queue.put(operation)
    return operation
def terminate_prefetch(self, operation):
    """
    Stop an in-flight prefetch operation and report its progress.

    Marks the operation done first, so the IO worker stops filling host
    memory, then returns what was completed up to that point.

    Returns:
        Tuple of (number of completed tokens, list of fetched page hashes).
    """
    # Terminate before reading progress so the counters cannot advance
    # underneath us.
    operation.mark_done()
    return (operation.completed_tokens, operation.hash_value)
def prefetch_io_aux_func(self):
    """
    Auxiliary function conducting IO operations for prefetching.

    Runs on a dedicated daemon thread: pulls operations whose page hashes
    were confirmed present by prefetch_thread_func, then copies each page
    from the storage backend into the operation's pre-allocated host slots,
    one page_size chunk at a time.
    """
    while not self.stop_event.is_set():
        try:
            # 1-second timeout keeps the loop responsive to stop_event.
            operation = self.prefetch_buffer.get(block=True, timeout=1)
            for h in operation.hash_value:
                page_data = self.storage_backend.get(h)
                if page_data is None:
                    # Backend no longer has the page (it existed when
                    # probed); abandon the rest of this operation.
                    # NOTE(review): the not-yet-filled host_indices are not
                    # freed on this path — confirm the controller reclaims
                    # them via terminate_prefetch or elsewhere.
                    logger.warning(
                        f"Prefetch operation {operation.request_id} failed to retrieve page {h}."
                    )
                    break
                # completed_tokens doubles as the write cursor into
                # host_indices.
                self.mem_pool_host.set_from_flat_data_page(
                    operation.host_indices[operation.completed_tokens],
                    page_data,
                )
                # No-op if the controller already marked the operation done.
                operation.increment(self.page_size)
                if operation.is_done():
                    # operation terminated by controller, release pre-allocated memory
                    self.mem_pool_host.free(
                        operation.host_indices[operation.completed_tokens :]
                    )
                    break
        except Empty:
            continue
def prefetch_thread_func(self):
    """
    Manage prefetching operations from storage backend to host memory.

    Runs on a dedicated daemon thread. For each queued operation it walks
    the token stream page by page, extending the hash chain and probing the
    storage backend until the first miss, then either revokes the prefetch
    (hit count below the threshold) or hands the operation to the IO worker
    via prefetch_buffer.
    """
    self.prefetch_buffer = Queue()
    aux_thread = threading.Thread(target=self.prefetch_io_aux_func, daemon=True)
    aux_thread.start()
    # Keep draining after stop_event is set so queued work is not dropped.
    while (not self.stop_event.is_set()) or not self.prefetch_queue.empty():
        try:
            operation = self.prefetch_queue.get(block=True, timeout=1)
            if operation is None:
                continue
            last_hash = operation.last_hash
            tokens_to_fetch = operation.token_ids
            storage_hit_count = 0
            remaining_tokens = len(tokens_to_fetch)
            hash_value = []
            # Probe whole pages only; a trailing partial page is never
            # considered for prefetch.
            while remaining_tokens >= self.page_size:
                # Each page hash chains off the previous one, so hits form
                # a contiguous prefix of the request's tokens.
                last_hash = get_hash_str(
                    tokens_to_fetch[
                        storage_hit_count : storage_hit_count + self.page_size
                    ],
                    last_hash,
                )
                if self.storage_backend.exists(last_hash):
                    storage_hit_count += self.page_size
                    hash_value.append(last_hash)
                    remaining_tokens -= self.page_size
                else:
                    # First miss ends the usable prefix.
                    break

            if storage_hit_count < self.prefetch_threshold:
                # not to prefetch if not enough benefits
                # NOTE(review): host_indices reserved for this operation are
                # not freed here — presumably the revoke-queue consumer does;
                # verify against the scheduler side.
                self.prefetch_revoke_queue.put(operation.request_id)
            else:
                operation.hash_value = hash_value
                logger.debug(
                    f"Prefetching {len(hash_value)} pages for request {operation.request_id}."
                )
                self.prefetch_buffer.put(operation)
        except Empty:
            continue
def write_storage(
    self,
    host_indices: torch.Tensor,
    token_ids: List[int],
    last_hash: Optional[str] = None,
) -> int:
    """
    Write KV caches from host memory to storage backend.

    Enqueues a backup operation for the backup thread and returns its id;
    the same id later appears on ack_backup_queue once the backup finishes.
    """
    backup_op = StorageOperation(host_indices, token_ids, last_hash)
    self.backup_queue.put(backup_op)
    return backup_op.id
def backup_thread_func(self):
    """
    Manage backup operations from host memory to storage backend.

    Runs on a dedicated daemon thread: for each queued operation it hashes
    the tokens page by page (continuing the operation's hash chain) and
    writes the corresponding host pages to the storage backend, then
    acknowledges completion with the collected page hashes.
    """
    while not self.stop_event.is_set():
        try:
            # 1-second timeout keeps the loop responsive to stop_event.
            operation = self.backup_queue.get(block=True, timeout=1)
            if operation is None:
                continue
            last_hash = operation.last_hash
            tokens_to_backup = operation.token_ids
            # NOTE(review): a trailing partial page is hashed and stored
            # like a full page — confirm callers always pass page-aligned
            # token counts.
            for i in range(0, len(tokens_to_backup), self.page_size):
                # Chain the hash across pages so each key encodes its prefix.
                last_hash = get_hash_str(
                    tokens_to_backup[i : i + self.page_size], last_hash
                )
                # todo, handle failures in storage backend
                self.storage_backend.set(
                    last_hash,
                    self.mem_pool_host.get_flat_data_page(
                        operation.host_indices[i]
                    ),
                )
                operation.completed_tokens += self.page_size
                operation.hash_value.append(last_hash)
            # Signal completion to the controller with the page hashes.
            self.ack_backup_queue.put((operation.id, operation.hash_value))
        except Empty:
            continue