[HiCache] Storage Refactoring (#9797)

Co-authored-by: pansicheng <27603155+pansicheng@users.noreply.github.com>
This commit is contained in:
Zhiqiang Xie
2025-08-31 07:58:21 -07:00
committed by GitHub
parent a391f73adc
commit 8b6966d020
3 changed files with 114 additions and 154 deletions

View File

@@ -250,26 +250,21 @@ class HiCacheController:
self.write_policy = write_policy
self.page_size = page_size
self.io_backend = io_backend
self.enable_storage = False
# todo: move backend initialization to storage backend module
if storage_backend is not None:
self.storage_backend_type = storage_backend
from sglang.srt.mem_cache.hicache_storage import get_hash_str
self.get_hash_str = get_hash_str
self.storage_config = self._generate_storage_config(
model_name, storage_backend_extra_config
)
# In MLA backend, only one rank needs to backup the KV cache
# for MLA models, only one rank needs to backup the KV cache
self.backup_skip = (
self.storage_config.is_mla_model
# todo: for load balancing, decide which rank to backup the KV cache by hash value
# todo: load balancing
and self.storage_config.tp_rank != 0
# todo: support other storage backends
and self.storage_backend_type in ["file", "mooncake"]
)
if storage_backend == "file":
@@ -309,12 +304,15 @@ class HiCacheController:
raise NotImplementedError(
f"Unsupported storage backend: {storage_backend}"
)
self.enable_storage = True
# todo: threshold policy for prefetching
self.prefetch_threshold = max(prefetch_threshold, self.page_size)
self.prefetch_capacity_limit = int(
0.8 * (self.mem_pool_host.size - self.mem_pool_device.size)
)
# granularity of batch storage IO operations, in number of pages
self.storage_batch_size = 128
# tracking the number of tokens locked in prefetching, updated by the main scheduler thread
self.prefetch_tokens_occupied = 0
@@ -325,12 +323,6 @@ class HiCacheController:
self.prefetch_tp_group = torch.distributed.new_group(
group_ranks, backend="gloo"
)
self.prefetch_io_tp_group = torch.distributed.new_group(
group_ranks, backend="gloo"
)
self.backup_tp_group = torch.distributed.new_group(
group_ranks, backend="gloo"
)
self.load_cache_event = load_cache_event
self.layer_done_counter = LayerDoneCounter(self.mem_pool_device.layer_num)
@@ -380,6 +372,7 @@ class HiCacheController:
self.prefetch_revoke_queue = Queue()
self.ack_backup_queue = Queue()
self.host_mem_release_queue = Queue()
self.prefetch_thread.start()
self.backup_thread.start()
@@ -618,7 +611,11 @@ class HiCacheController:
operation.mark_done()
return operation.completed_tokens, operation.hash_value
# zero copy
def append_host_mem_release(self, host_indices: torch.Tensor):
    """Queue host-memory indices for release, one page-sized chunk at a time.

    The indices are split into consecutive chunks of
    ``self.mem_pool_host.page_size`` elements (the last chunk may be shorter)
    and each chunk is put on ``self.host_mem_release_queue``. Nothing is
    freed here; the queue consumer is outside this block
    (NOTE(review): presumably the scheduler thread drains it and calls
    ``mem_pool_host.free`` -- confirm against the consumer).

    Args:
        host_indices: 1-D tensor of host memory indices to release. ``None``
            or an empty tensor is accepted and results in no queue entries.
    """
    # Callers pass slices such as ``host_indices[completed_tokens:]`` which
    # are empty when nothing is left to release. ``Tensor.split`` would still
    # return a single empty chunk for those, so bail out early instead of
    # enqueueing a no-op entry (and tolerate a missing allocation).
    if host_indices is None or host_indices.numel() == 0:
        return
    for chunk in host_indices.split(self.mem_pool_host.page_size):
        self.host_mem_release_queue.put(chunk)
def _3fs_zero_copy_page_get(self, operation, hash_values, host_indices):
hashes, dsts = self.mem_pool_host.get_buffer_with_hash(
hash_values, host_indices
@@ -631,7 +628,6 @@ class HiCacheController:
f"Prefetch operation {operation.request_id} failed to retrieve page {hashes}."
)
# zero copy
def _mooncake_page_get(self, operation, hash_values, host_indices):
key_strs, buffer_ptrs, buffer_sizes = self.mem_pool_host.get_buffer_meta(
hash_values,
@@ -650,9 +646,7 @@ class HiCacheController:
if get_result != 0:
operation.increment(get_result * self.page_size)
# non-zero copy
def _generic_page_get(self, operation, hash_values, host_indices):
# todo: zero copy
dummy_page_dst = [self.mem_pool_host.get_dummy_flat_data_page()] * len(
hash_values
)
@@ -675,22 +669,19 @@ class HiCacheController:
def _page_transfer(self, operation):
# Select the get function and batch size
if self.is_mooncake_backend():
if self.storage_backend_type == "mooncake":
get_func = self._mooncake_page_get
batch_size = 128
elif self.storage_backend_type == "hf3fs":
if self.mem_pool_host.layout == "page_first":
get_func = self._3fs_zero_copy_page_get
elif self.mem_pool_host.layout == "layer_first":
get_func = self._generic_page_get
batch_size = 128
elif (
self.storage_backend_type == "hf3fs"
and self.mem_pool_host.layout == "page_first"
):
get_func = self._3fs_zero_copy_page_get
else:
get_func = self._generic_page_get
batch_size = 8
# Transfer batch by batch
for i in range(0, len(operation.hash_value), batch_size):
batch_hashes = operation.hash_value[i : i + batch_size]
for i in range(0, len(operation.hash_value), self.storage_batch_size):
batch_hashes = operation.hash_value[i : i + self.storage_batch_size]
batch_host_indices = operation.host_indices[
i * self.page_size : (i + len(batch_hashes)) * self.page_size
]
@@ -704,10 +695,9 @@ class HiCacheController:
):
break # Some operations fail or operation terminated by controller
# release pre-allocated memory
self.mem_pool_host.free(operation.host_indices[operation.completed_tokens :])
def is_mooncake_backend(self):
return self.storage_backend_type == "mooncake"
self.append_host_mem_release(
operation.host_indices[operation.completed_tokens :]
)
def prefetch_io_aux_func(self):
"""
@@ -717,47 +707,49 @@ class HiCacheController:
try:
operation = self.prefetch_buffer.get(block=True, timeout=1)
self._page_transfer(operation)
if self.tp_world_size > 1:
# to ensure all TP workers release the host memory at the same time
torch.distributed.barrier(group=self.prefetch_io_tp_group)
# operation terminated by controller, release pre-allocated memory
self.mem_pool_host.free(
self.append_host_mem_release(
operation.host_indices[operation.completed_tokens :]
)
except Empty:
continue
def prefetch_rate_limit_check(self) -> bool:
def prefetch_rate_limited(self) -> bool:
"""
Rate limit the prefetching operations to avoid overwhelming the storage backend.
"""
# cancel prefetch if too much memory is occupied
if self.prefetch_tokens_occupied >= self.prefetch_capacity_limit:
return False
return True
# todo: more sophisticated rate limiting based on storage backend performance
return True
return False
def _generic_storage_hit_query(self, operation) -> tuple[list[str], int]:
def _storage_hit_query(self, operation) -> tuple[list[str], int]:
last_hash = operation.last_hash
tokens_to_fetch = operation.token_ids
storage_query_count = 0
remaining_tokens = len(tokens_to_fetch)
hash_value = []
while remaining_tokens >= self.page_size:
last_hash = self.get_hash_str(
tokens_to_fetch[
storage_query_count : storage_query_count + self.page_size
],
last_hash,
for start in range(
0, len(tokens_to_fetch), self.page_size * self.storage_batch_size
):
end = min(
start + self.page_size * self.storage_batch_size, len(tokens_to_fetch)
)
hash_value.append(last_hash)
storage_query_count += self.page_size
remaining_tokens -= self.page_size
# deferring to batch exists
hit_page_num = self.storage_backend.batch_exists(hash_value)
return hash_value[:hit_page_num], hit_page_num * self.page_size
batch_tokens = tokens_to_fetch[start:end]
batch_hashes = []
for i in range(0, len(batch_tokens), self.page_size):
last_hash = self.get_hash_str(
batch_tokens[i : i + self.page_size], last_hash
)
batch_hashes.append(last_hash)
hit_page_num = self.storage_backend.batch_exists(batch_hashes)
hash_value.extend(batch_hashes[:hit_page_num])
storage_query_count += hit_page_num * self.page_size
if hit_page_num < len(batch_hashes):
break
return hash_value, storage_query_count
def prefetch_thread_func(self):
"""
@@ -772,13 +764,7 @@ class HiCacheController:
if operation is None:
continue
if (
operation.host_indices is not None
) and self.prefetch_rate_limit_check():
hash_value, storage_hit_count = self._generic_storage_hit_query(
operation
)
hash_value, storage_hit_count = self._storage_hit_query(operation)
if self.tp_world_size > 1:
storage_hit_count_tensor = torch.tensor(
storage_hit_count, dtype=torch.int
@@ -793,8 +779,7 @@ class HiCacheController:
if storage_hit_count < self.prefetch_threshold:
# not to prefetch if not enough benefits
self.prefetch_revoke_queue.put(operation.request_id)
if operation.host_indices is not None:
self.mem_pool_host.free(operation.host_indices)
self.append_host_mem_release(operation.host_indices)
logger.debug(
f"Revoking prefetch for request {operation.request_id} due to insufficient hits ({storage_hit_count})."
)
@@ -803,7 +788,9 @@ class HiCacheController:
: (storage_hit_count // self.page_size)
]
# free the pre-allocated memory for pages that are not hit
self.mem_pool_host.free(operation.host_indices[storage_hit_count:])
self.append_host_mem_release(
operation.host_indices[storage_hit_count:]
)
operation.host_indices = operation.host_indices[:storage_hit_count]
logger.debug(
f"Prefetching {len(operation.hash_value)} pages for request {operation.request_id}."
@@ -858,21 +845,18 @@ class HiCacheController:
# Backup batch by batch
def _page_backup(self, operation):
# Select the set function and batch size
if self.is_mooncake_backend():
if self.storage_backend_type == "mooncake":
backup_set_func = self._mooncake_page_set
batch_size = 128
elif self.storage_backend_type == "hf3fs":
if self.mem_pool_host.layout == "page_first":
backup_set_func = self._3fs_zero_copy_page_set
elif self.mem_pool_host.layout == "layer_first":
backup_set_func = self._generic_page_set
batch_size = 128
elif (
self.storage_backend_type == "hf3fs"
and self.mem_pool_host.layout == "page_first"
):
backup_set_func = self._3fs_zero_copy_page_set
else:
backup_set_func = self._generic_page_set
batch_size = 8
# Backup batch by batch
for i in range(0, len(operation.hash_value), batch_size):
batch_hashes = operation.hash_value[i : i + batch_size]
for i in range(0, len(operation.hash_value), self.storage_batch_size):
batch_hashes = operation.hash_value[i : i + self.storage_batch_size]
batch_host_indices = operation.host_indices[
i * self.page_size : (i + len(batch_hashes)) * self.page_size
]
@@ -898,27 +882,7 @@ class HiCacheController:
if not self.backup_skip:
self._page_backup(operation)
min_completed_tokens = operation.completed_tokens
else:
min_completed_tokens = len(operation.token_ids)
if self.tp_world_size > 1:
completed_tokens_tensor = torch.tensor(
min_completed_tokens, dtype=torch.int
)
torch.distributed.all_reduce(
completed_tokens_tensor,
op=torch.distributed.ReduceOp.MIN,
group=self.backup_tp_group,
)
min_completed_tokens = completed_tokens_tensor.item()
self.ack_backup_queue.put(
(
operation.id,
min_completed_tokens,
)
)
self.ack_backup_queue.put(operation.id)
except Empty:
continue