HiCache Storage TP Refinement (#8307)

Co-authored-by: pansicheng <sicheng.pan.chn@gmail.com>
This commit is contained in:
Zhiqiang Xie
2025-07-24 17:31:47 -07:00
committed by GitHub
parent 39fe1e880d
commit 145482f422
4 changed files with 103 additions and 24 deletions

View File

@@ -219,6 +219,7 @@ class HiCacheController:
token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
mem_pool_host: HostKVCache,
page_size: int,
tp_group: torch.distributed.ProcessGroup,
load_cache_event: threading.Event = None,
write_policy: str = "write_through_selective",
io_backend: str = "",
@@ -244,11 +245,17 @@ class HiCacheController:
self.enable_storage = False
# todo: move backend initialization to storage backend module
if storage_backend is not None:
# create a new communication group for synchronizing storage operations across TP workers
self.tp_world_size = torch.distributed.get_world_size(group=tp_group)
if self.tp_world_size > 1:
group_ranks = torch.distributed.get_process_group_ranks(tp_group)
self.tp_group = torch.distributed.new_group(group_ranks, backend="gloo")
if storage_backend == "file":
self.storage_backend = HiCacheFile()
self.enable_storage = True
# todo: threshold policy for prefetching
self.prefetch_threshold = prefetch_threshold
self.prefetch_threshold = max(prefetch_threshold, self.page_size)
else:
raise NotImplementedError(
f"Unsupported storage backend: {storage_backend}"
@@ -568,13 +575,32 @@ class HiCacheController:
else:
break
if self.tp_world_size > 1:
storage_hit_count_tensor = torch.tensor(
storage_hit_count, dtype=torch.int
)
torch.distributed.all_reduce(
storage_hit_count_tensor,
op=torch.distributed.ReduceOp.MIN,
group=self.tp_group,
)
storage_hit_count = storage_hit_count_tensor.item()
if storage_hit_count < self.prefetch_threshold:
# not to prefetch if not enough benefits
self.prefetch_revoke_queue.put(operation.request_id)
else:
operation.hash_value = hash_value
logger.debug(
f"Prefetching {len(hash_value)} pages for request {operation.request_id}."
f"Revoking prefetch for request {operation.request_id} due to insufficient hits ({storage_hit_count})."
)
else:
operation.hash_value = hash_value[
: (storage_hit_count // self.page_size)
]
# free the pre-allocated memory for pages that are not hit
self.mem_pool_host.free(operation.host_indices[storage_hit_count:])
operation.host_indices = operation.host_indices[:storage_hit_count]
logger.debug(
f"Prefetching {len(operation.hash_value)} pages for request {operation.request_id}."
)
self.prefetch_buffer.put(operation)
@@ -611,17 +637,37 @@ class HiCacheController:
last_hash = get_hash_str(
tokens_to_backup[i : i + self.page_size], last_hash
)
# todo, handle failures in storage backend
self.storage_backend.set(
success = self.storage_backend.set(
last_hash,
self.mem_pool_host.get_flat_data_page(
operation.host_indices[i]
),
)
if not success:
logger.warning(f"Failed to write page {last_hash} to storage.")
break
operation.completed_tokens += self.page_size
operation.hash_value.append(last_hash)
self.ack_backup_queue.put((operation.id, operation.hash_value))
min_completed_tokens = operation.completed_tokens
if self.tp_world_size > 1:
completed_tokens_tensor = torch.tensor(
min_completed_tokens, dtype=torch.int
)
torch.distributed.all_reduce(
completed_tokens_tensor,
op=torch.distributed.ReduceOp.MIN,
group=self.tp_group,
)
min_completed_tokens = completed_tokens_tensor.item()
self.ack_backup_queue.put(
(
operation.id,
operation.hash_value[: min_completed_tokens // self.page_size],
min_completed_tokens,
)
)
except Empty:
continue