Upstreaming hicache bug fixes (#7267)
This commit is contained in:
@@ -565,6 +565,10 @@ class Scheduler(
|
||||
hicache_size=server_args.hicache_size,
|
||||
hicache_write_policy=server_args.hicache_write_policy,
|
||||
)
|
||||
self.tp_worker.register_hicache_layer_transfer_counter(
|
||||
self.tree_cache.cache_controller.layer_done_counter
|
||||
)
|
||||
|
||||
else:
|
||||
self.tree_cache = RadixCache(
|
||||
req_to_token_pool=self.req_to_token_pool,
|
||||
@@ -1514,8 +1518,13 @@ class Scheduler(
|
||||
self.running_batch.batch_is_full = True
|
||||
break
|
||||
|
||||
# When hierarchical cache is enabled, ignore prefix_computed and always pass the tree cache
|
||||
req.init_next_round_input(
|
||||
None if prefix_computed else self.tree_cache,
|
||||
(
|
||||
None
|
||||
if (prefix_computed and not self.enable_hierarchical_cache)
|
||||
else self.tree_cache
|
||||
),
|
||||
self.enable_hierarchical_cache,
|
||||
)
|
||||
|
||||
@@ -1548,9 +1557,6 @@ class Scheduler(
|
||||
x for x in self.waiting_queue if x not in set(can_run_list)
|
||||
]
|
||||
|
||||
if self.enable_hierarchical_cache:
|
||||
self.tree_cache.ready_to_load_cache()
|
||||
|
||||
if adder.new_chunked_req is not None:
|
||||
assert self.chunked_req is None
|
||||
self.chunked_req = adder.new_chunked_req
|
||||
@@ -1574,6 +1580,10 @@ class Scheduler(
|
||||
self.server_args.enable_custom_logit_processor,
|
||||
chunked_req=self.chunked_req,
|
||||
)
|
||||
if self.enable_hierarchical_cache:
|
||||
# TODO(zhiqiang): disable CUDA graph execution if hicache loading is triggered
|
||||
new_batch.hicache_consumer_index = self.tree_cache.ready_to_load_cache()
|
||||
|
||||
new_batch.prepare_for_extend()
|
||||
|
||||
# Mixed-style chunked prefill
|
||||
@@ -1649,6 +1659,11 @@ class Scheduler(
|
||||
if self.is_generation:
|
||||
if self.spec_algorithm.is_none():
|
||||
model_worker_batch = batch.get_model_worker_batch()
|
||||
|
||||
# Pass this batch's hicache consumer index to tp_worker before running the forward pass
|
||||
self.tp_worker.set_hicache_consumer(
|
||||
model_worker_batch.hicache_consumer_index
|
||||
)
|
||||
if self.pp_group.is_last_rank:
|
||||
logits_output, next_token_ids, can_run_cuda_graph = (
|
||||
self.tp_worker.forward_batch_generation(model_worker_batch)
|
||||
|
||||
Reference in New Issue
Block a user