Hierarchical Caching Refactoring and Fixing TP issue (#4082)
This commit is contained in:
@@ -265,12 +265,10 @@ class Scheduler:
|
||||
f"context_len={self.model_config.context_len}"
|
||||
)
|
||||
|
||||
# Init memory pool and cache
|
||||
self.init_memory_pool_and_cache()
|
||||
|
||||
# Init running status
|
||||
self.waiting_queue: List[Req] = []
|
||||
self.staging_reqs = {}
|
||||
# The running decoding batch for continuous batching
|
||||
self.running_batch: Optional[ScheduleBatch] = None
|
||||
# The current forward batch
|
||||
@@ -308,7 +306,9 @@ class Scheduler:
|
||||
self.grammar_backend = None
|
||||
|
||||
# Init schedule policy and new token estimation
|
||||
self.policy = SchedulePolicy(self.schedule_policy, self.tree_cache)
|
||||
self.policy = SchedulePolicy(
|
||||
self.schedule_policy, self.tree_cache, self.enable_hierarchical_cache
|
||||
)
|
||||
assert (
|
||||
server_args.schedule_conservativeness >= 0
|
||||
), "Invalid schedule_conservativeness"
|
||||
@@ -431,6 +431,7 @@ class Scheduler:
|
||||
self.tree_cache = HiRadixCache(
|
||||
req_to_token_pool=self.req_to_token_pool,
|
||||
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
|
||||
tp_cache_group=self.tp_worker.get_tp_cpu_group(),
|
||||
)
|
||||
else:
|
||||
self.tree_cache = RadixCache(
|
||||
@@ -1005,6 +1006,11 @@ class Scheduler:
|
||||
self.batch_is_full = True
|
||||
return None
|
||||
|
||||
if self.enable_hierarchical_cache:
|
||||
# Check whether pending hierarchical cache writes and loads have completed, so that their memory can be released
|
||||
self.tree_cache.writing_check()
|
||||
self.tree_cache.loading_check()
|
||||
|
||||
# Get priority queue
|
||||
prefix_computed = self.policy.calc_priority(self.waiting_queue)
|
||||
|
||||
@@ -1048,32 +1054,14 @@ class Scheduler:
|
||||
self.batch_is_full = True
|
||||
break
|
||||
|
||||
req.init_next_round_input(None if prefix_computed else self.tree_cache)
|
||||
req.init_next_round_input(
|
||||
None if prefix_computed else self.tree_cache,
|
||||
self.enable_hierarchical_cache,
|
||||
)
|
||||
|
||||
if self.enable_hierarchical_cache and req.last_node is not None:
|
||||
if req.last_node.evicted:
|
||||
# loading KV cache for the request
|
||||
req.last_node, req.prefix_indices = self.tree_cache.init_load_back(
|
||||
req.last_node,
|
||||
req.prefix_indices,
|
||||
adder.rem_total_tokens,
|
||||
)
|
||||
if req.last_node.loading:
|
||||
# Hold a lock reference on the node while it is loading, to prevent frequent cache invalidation
|
||||
if req.rid in self.staging_reqs:
|
||||
self.tree_cache.dec_lock_ref(self.staging_reqs[req.rid])
|
||||
self.tree_cache.inc_lock_ref(req.last_node)
|
||||
self.staging_reqs[req.rid] = req.last_node
|
||||
continue
|
||||
elif req.last_node.loading:
|
||||
if not self.tree_cache.loading_complete(req.last_node):
|
||||
continue
|
||||
|
||||
if req.rid in self.staging_reqs:
|
||||
self.tree_cache.dec_lock_ref(self.staging_reqs[req.rid])
|
||||
del self.staging_reqs[req.rid]
|
||||
|
||||
res = adder.add_one_req(req, self.chunked_req)
|
||||
res = adder.add_one_req(
|
||||
req, self.chunked_req, self.enable_hierarchical_cache
|
||||
)
|
||||
if res != AddReqResult.CONTINUE:
|
||||
if res == AddReqResult.NO_TOKEN:
|
||||
if self.enable_hierarchical_cache:
|
||||
@@ -1094,6 +1082,9 @@ class Scheduler:
|
||||
x for x in self.waiting_queue if x not in set(can_run_list)
|
||||
]
|
||||
|
||||
if self.enable_hierarchical_cache:
|
||||
self.tree_cache.read_to_load_cache()
|
||||
|
||||
if adder.new_chunked_req is not None:
|
||||
assert self.chunked_req is None
|
||||
self.chunked_req = adder.new_chunked_req
|
||||
|
||||
Reference in New Issue
Block a user