Upstreaming hicache bug fixes (#7267)

This commit is contained in:
Zhiqiang Xie
2025-06-17 17:44:57 -07:00
committed by GitHub
parent c26d7349d3
commit e56685ac1b
7 changed files with 76 additions and 24 deletions

View File

@@ -565,6 +565,10 @@ class Scheduler(
hicache_size=server_args.hicache_size,
hicache_write_policy=server_args.hicache_write_policy,
)
self.tp_worker.register_hicache_layer_transfer_counter(
self.tree_cache.cache_controller.layer_done_counter
)
else:
self.tree_cache = RadixCache(
req_to_token_pool=self.req_to_token_pool,
@@ -1514,8 +1518,13 @@ class Scheduler(
self.running_batch.batch_is_full = True
break
# When hierarchical cache is enabled, ignore prefix_computed and always pass the tree cache
req.init_next_round_input(
None if prefix_computed else self.tree_cache,
(
None
if (prefix_computed and not self.enable_hierarchical_cache)
else self.tree_cache
),
self.enable_hierarchical_cache,
)
@@ -1548,9 +1557,6 @@ class Scheduler(
x for x in self.waiting_queue if x not in set(can_run_list)
]
if self.enable_hierarchical_cache:
self.tree_cache.ready_to_load_cache()
if adder.new_chunked_req is not None:
assert self.chunked_req is None
self.chunked_req = adder.new_chunked_req
@@ -1574,6 +1580,10 @@ class Scheduler(
self.server_args.enable_custom_logit_processor,
chunked_req=self.chunked_req,
)
if self.enable_hierarchical_cache:
# TODO(zhiqiang): disable CUDA graph execution if hicache loading is triggered
new_batch.hicache_consumer_index = self.tree_cache.ready_to_load_cache()
new_batch.prepare_for_extend()
# Mixed-style chunked prefill
@@ -1649,6 +1659,11 @@ class Scheduler(
if self.is_generation:
if self.spec_algorithm.is_none():
model_worker_batch = batch.get_model_worker_batch()
# Set the hicache consumer index on the worker to that of the batch about to run
self.tp_worker.set_hicache_consumer(
model_worker_batch.hicache_consumer_index
)
if self.pp_group.is_last_rank:
logits_output, next_token_ids, can_run_cuda_graph = (
self.tp_worker.forward_batch_generation(model_worker_batch)