Support page size > 1 (#4356)

This commit is contained in:
Lianmin Zheng
2025-03-12 22:22:39 -07:00
committed by GitHub
parent 2f6bacee03
commit c76040e31b
23 changed files with 877 additions and 284 deletions

View File

@@ -77,7 +77,7 @@ class SchedulePolicy:
self,
policy: str,
tree_cache: BasePrefixCache,
enable_hierarchical_cache: bool = False,
enable_hierarchical_cache: bool,
):
self.policy = self._validate_and_adjust_policy(policy, tree_cache)
self.tree_cache = tree_cache
@@ -85,10 +85,17 @@ class SchedulePolicy:
# It is used to find the matching prefix for in-batch prefix caching.
self.waiting_queue_radix_tree = RadixCache(
req_to_token_pool=None, token_to_kv_pool_allocator=None, disable=False
req_to_token_pool=None,
token_to_kv_pool_allocator=None,
page_size=1,
disable=False,
)
def calc_priority(self, waiting_queue: List[Req]) -> bool:
if self.policy == CacheAgnosticPolicy.FCFS:
# A shortcut for FCFS
return
policy = self._determine_active_policy(waiting_queue)
prefix_computed = False
@@ -118,7 +125,7 @@ class SchedulePolicy:
return prefix_computed
def _determine_active_policy(self, waiting_queue: List[Req]) -> Policy:
if len(waiting_queue) > 128 and self.policy == CacheAwarePolicy.LPM:
if self.policy == CacheAwarePolicy.LPM and len(waiting_queue) > 128:
# Turn off the expensive prefix matching and sorting when the #queue is large.
return CacheAgnosticPolicy.FCFS
return self.policy
@@ -442,7 +449,7 @@ class PrefillAdder:
def add_one_req(
self, req: Req, has_chunked_req: bool, enable_hierarchical_cache: bool = False
):
if req.sampling_params.ignore_eos and self.tree_cache.disable:
if req.sampling_params.ignore_eos and getattr(self.tree_cache, "disable", True):
return self.add_one_req_ignore_eos(req, has_chunked_req)
total_tokens = req.extend_input_len + min(