Hierarchical Caching Refactoring and Fixing TP issue (#4082)

This commit is contained in:
Zhiqiang Xie
2025-03-12 11:22:35 -07:00
committed by GitHub
parent 01090e8ac3
commit 10b544ae9b
6 changed files with 194 additions and 56 deletions

View File

@@ -265,12 +265,10 @@ class Scheduler:
f"context_len={self.model_config.context_len}"
)
# Init memory pool and cache
self.init_memory_pool_and_cache()
# Init running status
self.waiting_queue: List[Req] = []
self.staging_reqs = {}
# The running decoding batch for continuous batching
self.running_batch: Optional[ScheduleBatch] = None
# The current forward batch
@@ -308,7 +306,9 @@ class Scheduler:
self.grammar_backend = None
# Init schedule policy and new token estimation
self.policy = SchedulePolicy(self.schedule_policy, self.tree_cache)
self.policy = SchedulePolicy(
self.schedule_policy, self.tree_cache, self.enable_hierarchical_cache
)
assert (
server_args.schedule_conservativeness >= 0
), "Invalid schedule_conservativeness"
@@ -431,6 +431,7 @@ class Scheduler:
self.tree_cache = HiRadixCache(
req_to_token_pool=self.req_to_token_pool,
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
tp_cache_group=self.tp_worker.get_tp_cpu_group(),
)
else:
self.tree_cache = RadixCache(
@@ -1005,6 +1006,11 @@ class Scheduler:
self.batch_is_full = True
return None
if self.enable_hierarchical_cache:
# Check whether pending hierarchical cache writes and loads have completed, so memory can be released
self.tree_cache.writing_check()
self.tree_cache.loading_check()
# Get priority queue
prefix_computed = self.policy.calc_priority(self.waiting_queue)
@@ -1048,32 +1054,14 @@ class Scheduler:
self.batch_is_full = True
break
req.init_next_round_input(None if prefix_computed else self.tree_cache)
req.init_next_round_input(
None if prefix_computed else self.tree_cache,
self.enable_hierarchical_cache,
)
if self.enable_hierarchical_cache and req.last_node is not None:
if req.last_node.evicted:
# loading KV cache for the request
req.last_node, req.prefix_indices = self.tree_cache.init_load_back(
req.last_node,
req.prefix_indices,
adder.rem_total_tokens,
)
if req.last_node.loading:
# to prevent frequent cache invalidation
if req.rid in self.staging_reqs:
self.tree_cache.dec_lock_ref(self.staging_reqs[req.rid])
self.tree_cache.inc_lock_ref(req.last_node)
self.staging_reqs[req.rid] = req.last_node
continue
elif req.last_node.loading:
if not self.tree_cache.loading_complete(req.last_node):
continue
if req.rid in self.staging_reqs:
self.tree_cache.dec_lock_ref(self.staging_reqs[req.rid])
del self.staging_reqs[req.rid]
res = adder.add_one_req(req, self.chunked_req)
res = adder.add_one_req(
req, self.chunked_req, self.enable_hierarchical_cache
)
if res != AddReqResult.CONTINUE:
if res == AddReqResult.NO_TOKEN:
if self.enable_hierarchical_cache:
@@ -1094,6 +1082,9 @@ class Scheduler:
x for x in self.waiting_queue if x not in set(can_run_list)
]
if self.enable_hierarchical_cache:
self.tree_cache.read_to_load_cache()
if adder.new_chunked_req is not None:
assert self.chunked_req is None
self.chunked_req = adder.new_chunked_req