HiCache Storage TP Refinement (#8307)

Co-authored-by: pansicheng <sicheng.pan.chn@gmail.com>
This commit is contained in:
Zhiqiang Xie
2025-07-24 17:31:47 -07:00
committed by GitHub
parent 39fe1e880d
commit 145482f422
4 changed files with 103 additions and 24 deletions

View File

@@ -50,6 +50,7 @@ class HiRadixCache(RadixCache):
raise ValueError(f"HiRadixCache only supports MHA and MLA yet")
self.tp_group = tp_cache_group
self.tp_world_size = torch.distributed.get_world_size(group=self.tp_group)
self.enable_storage = hicache_storage_backend is not None
# todo: customizable storage prefetch threshold
self.prefetch_threshold = 256
@@ -59,6 +60,7 @@ class HiRadixCache(RadixCache):
token_to_kv_pool_allocator,
self.token_to_kv_pool_host,
page_size,
self.tp_group,
load_cache_event=self.load_cache_event,
write_policy=hicache_write_policy,
io_backend=hicache_io_backend,
@@ -153,7 +155,7 @@ class HiRadixCache(RadixCache):
queue_size = torch.tensor(
self.cache_controller.ack_write_queue.qsize(), dtype=torch.int
)
if torch.distributed.get_world_size(group=self.tp_group) > 1:
if self.tp_world_size > 1:
# synchronize TP workers to make the same update to radix cache
torch.distributed.all_reduce(
queue_size,
@@ -353,7 +355,7 @@ class HiRadixCache(RadixCache):
queue_size = torch.tensor(
self.cache_controller.prefetch_revoke_queue.qsize(), dtype=torch.int
)
if torch.distributed.get_world_size(group=self.tp_group) > 1:
if self.tp_world_size > 1:
# synchronize TP workers to make the same update to hiradix cache
torch.distributed.all_reduce(
queue_size,
@@ -372,7 +374,7 @@ class HiRadixCache(RadixCache):
queue_size = torch.tensor(
self.cache_controller.ack_backup_queue.qsize(), dtype=torch.int
)
if torch.distributed.get_world_size(group=self.tp_group) > 1:
if self.tp_world_size > 1:
# synchronize TP workers to make the same update to hiradix cache
torch.distributed.all_reduce(
queue_size,
@@ -380,9 +382,15 @@ class HiRadixCache(RadixCache):
group=self.tp_group,
)
for _ in range(queue_size.item()):
ack_id, hash_value = self.cache_controller.ack_backup_queue.get()
self.ongoing_backup[ack_id].hash_value = hash_value
self.ongoing_backup[ack_id].release_host()
ack_id, hash_value, completed_tokens = (
self.cache_controller.ack_backup_queue.get()
)
host_node = self.ongoing_backup[ack_id]
if completed_tokens < len(host_node.key):
# backup is only partially successful, split the node
new_node = self._split_node(host_node.key, host_node, completed_tokens)
new_node.hash_value = hash_value
host_node.release_host()
del self.ongoing_backup[ack_id]
def check_prefetch_progress(self, req_id: str):
@@ -400,15 +408,18 @@ class HiRadixCache(RadixCache):
)
logger.debug(f"Prefetch {req_id} completed with {completed_tokens} tokens")
min_completed_tokens = torch.tensor(completed_tokens, dtype=torch.int)
if torch.distributed.get_world_size(group=self.tp_group) > 1:
min_completed_tokens = completed_tokens
if self.tp_world_size > 1:
# synchronize TP workers to make the same update to hiradix cache
completed_tokens_tensor = torch.tensor(
min_completed_tokens, dtype=torch.int
)
torch.distributed.all_reduce(
min_completed_tokens,
completed_tokens_tensor,
op=torch.distributed.ReduceOp.MIN,
group=self.tp_group,
)
min_completed_tokens = min_completed_tokens.item()
min_completed_tokens = completed_tokens_tensor.item()
fetched_token_ids = token_ids[:min_completed_tokens]
written_indices = host_indices[:min_completed_tokens]
matched_length = self._insert_helper_host(
@@ -465,16 +476,19 @@ class HiRadixCache(RadixCache):
new_input_tokens: List[int],
last_hash: Optional[str] = None,
):
if not self.enable_storage or len(new_input_tokens) < self.prefetch_threshold:
# align the number of fetching tokens to the page size
prefetch_length = len(new_input_tokens) - (
len(new_input_tokens) % self.page_size
)
new_input_tokens = new_input_tokens[:prefetch_length]
if not self.enable_storage or prefetch_length < self.prefetch_threshold:
return
last_host_node.protect_host()
host_indices = self.cache_controller.mem_pool_host.alloc(len(new_input_tokens))
host_indices = self.cache_controller.mem_pool_host.alloc(prefetch_length)
if host_indices is None:
self.evict_host(len(new_input_tokens))
host_indices = self.cache_controller.mem_pool_host.alloc(
len(new_input_tokens)
)
self.evict_host(prefetch_length)
host_indices = self.cache_controller.mem_pool_host.alloc(prefetch_length)
if host_indices is None:
last_host_node.release_host()
# no sufficient host memory to prefetch