Refactor SchedulePolicy to improve code organization (#2571)

This commit is contained in:
libra
2025-01-04 00:05:16 +08:00
committed by GitHub
parent f5d0865b25
commit bdb3929dbb
2 changed files with 214 additions and 93 deletions

View File

@@ -18,7 +18,7 @@ import random
from collections import defaultdict
from contextlib import contextmanager
from enum import Enum, auto
from typing import Dict, List, Optional
from typing import Dict, List, Optional, Set, Union
import torch
@@ -50,13 +50,26 @@ IN_BATCH_PREFIX_CACHING_DEPRIORITIZE_THRESHOLD = int(
)
class SchedulePolicy:
def __init__(self, policy: str, tree_cache: BasePrefixCache):
if tree_cache.disable and policy in ["lpm", "dfs-weight"]:
# LPM and DFS-weight is meaningless when the tree cache is disabled.
policy = "fcfs"
class CacheAwarePolicy(Enum):
"""Scheduling policies that are aware of the tree cache."""
self.policy = policy
LPM = "lpm" # longest prefix match
DFS_WEIGHT = "dfs-weight" # depth-first search weighting
class CacheAgnosticPolicy(Enum):
"""Scheduling policies that are not aware of the tree cache."""
FCFS = "fcfs" # first come first serve
LOF = "lof" # longest output first
RANDOM = "random"
class SchedulePolicy:
Policy = Union[CacheAwarePolicy, CacheAgnosticPolicy]
def __init__(self, policy: str, tree_cache: BasePrefixCache):
self.policy = self._validate_and_adjust_policy(policy, tree_cache)
self.tree_cache = tree_cache
# It is used to find the matching prefix for in-batch prefix caching.
@@ -64,110 +77,166 @@ class SchedulePolicy:
req_to_token_pool=None, token_to_kv_pool=None, disable=False
)
def calc_priority(self, waiting_queue: List[Req]):
if len(waiting_queue) > 128 and self.policy == "lpm":
# Turn off the expensive prefix matching and sorting when the #queue is large.
policy = "fcfs"
else:
policy = self.policy
def calc_priority(self, waiting_queue: List[Req]) -> bool:
policy = self._determine_active_policy(waiting_queue)
# Compute matched prefix length
prefix_computed = False
if policy == "lpm" or policy == "dfs-weight":
# rid to deprioritize in the current run for in-batch prefix caching.
temporary_deprioritized = set()
self.waiting_queue_radix_tree.reset()
for r in waiting_queue:
prefix_ids = r.adjust_max_prefix_ids()
# NOTE: the prefix_indices must always be aligned with last_node
r.prefix_indices, r.last_node = self.tree_cache.match_prefix(
rid=r.rid, key=prefix_ids
)
# NOTE(sang): This logic is for in-batch prefix caching;
# If there are more than 1 request that have small matching prefix from
# existing cache, but all those requests share the same prefix, we prefer
# to schedule only one of them so that we can increase the cache hit rate.
# We prefer to set IN_BATCH_PREFIX_CACHING_CHECK_THRESHOLD > 0 because too small
# threshold means we cannot use in-batch prefix caching for short prefixes.
# It is kind of common when the engine is long running (e.g., imagine the prefix "the").
if len(r.prefix_indices) <= IN_BATCH_PREFIX_CACHING_CHECK_THRESHOLD:
in_batch_matching_prefixes, _ = (
self.waiting_queue_radix_tree.match_prefix(
rid=r.rid, key=prefix_ids
)
)
if (
len(in_batch_matching_prefixes)
>= IN_BATCH_PREFIX_CACHING_DEPRIORITIZE_THRESHOLD
):
temporary_deprioritized.add(r.rid)
else:
# Insert with a dummy key
self.waiting_queue_radix_tree.insert(
prefix_ids, torch.empty(len(prefix_ids), dtype=torch.bool)
)
if isinstance(policy, CacheAwarePolicy):
prefix_computed = True
if policy == "lpm":
# Longest Prefix Match
waiting_queue.sort(
key=lambda r: (
-len(r.prefix_indices)
if r.rid not in temporary_deprioritized
else float("inf")
temporary_deprioritized = self._compute_prefix_matches(
waiting_queue, policy
)
if policy == CacheAwarePolicy.LPM:
SchedulePolicy._sort_by_longest_prefix(
waiting_queue, temporary_deprioritized
)
)
elif policy == "fcfs":
# first come first serve
pass
elif policy == "lof":
# longest output first
waiting_queue.sort(key=lambda x: -x.sampling_params.max_new_tokens)
elif policy == "random":
random.shuffle(waiting_queue)
elif policy == "dfs-weight":
# Experimental policy based on custom weights
last_node_to_reqs = defaultdict(list)
for req in waiting_queue:
last_node_to_reqs[req.last_node].append(req)
node_to_weight = defaultdict(int)
for node in last_node_to_reqs:
node_to_weight[node] = len(last_node_to_reqs[node])
self.calc_weight(self.tree_cache.root_node, node_to_weight)
waiting_queue.clear()
self.get_dfs_priority(
self.tree_cache.root_node,
node_to_weight,
last_node_to_reqs,
waiting_queue,
)
elif policy == CacheAwarePolicy.DFS_WEIGHT:
SchedulePolicy._sort_by_dfs_weight(waiting_queue, self.tree_cache)
else:
raise ValueError(f"Unknown CacheAware Policy: {policy=}")
else:
raise ValueError(f"Unknown schedule_policy: {policy=}")
if policy == CacheAgnosticPolicy.FCFS:
pass
elif policy == CacheAgnosticPolicy.LOF:
SchedulePolicy._sort_by_longest_output(waiting_queue)
elif policy == CacheAgnosticPolicy.RANDOM:
SchedulePolicy._sort_randomly(waiting_queue)
else:
raise ValueError(f"Unknown CacheAgnostic Policy: {policy=}")
return prefix_computed
def calc_weight(self, cur_node: TreeNode, node_to_weight: Dict):
def _determine_active_policy(self, waiting_queue: List[Req]) -> Policy:
if len(waiting_queue) > 128 and self.policy == CacheAwarePolicy.LPM:
# Turn off the expensive prefix matching and sorting when the #queue is large.
return CacheAgnosticPolicy.FCFS
return self.policy
def _validate_and_adjust_policy(
self, policy: str, tree_cache: BasePrefixCache
) -> Policy:
"""
Validates the policy and adjusts it if necessary based on tree cache settings.
"""
try:
policy_enum = CacheAwarePolicy(policy)
if tree_cache.disable:
# If tree_cache is disabled, using CacheAgnosticPolicy policy
return CacheAgnosticPolicy.FCFS
return policy_enum
except ValueError:
try:
return CacheAgnosticPolicy(policy)
except ValueError:
raise ValueError(f"Unknown schedule_policy: {policy=}")
def _compute_prefix_matches(
self, waiting_queue: List[Req], policy: CacheAwarePolicy
) -> Set[int]:
"""
Computes and caches the matching prefixes for requests in the waiting queue,
and handles in-batch prefix caching logic.
"""
temporary_deprioritized: Set[int] = set()
self.waiting_queue_radix_tree.reset()
for r in waiting_queue:
prefix_ids = r.adjust_max_prefix_ids()
# NOTE: the prefix_indices must always be aligned with last_node
r.prefix_indices, r.last_node = self.tree_cache.match_prefix(
rid=r.rid, key=prefix_ids
)
# NOTE(sang): This logic is for in-batch prefix caching;
# If there are more than 1 request that have small matching prefix from
# existing cache, but all those requests share the same prefix, we prefer
# to schedule only one of them so that we can increase the cache hit rate.
# We prefer to set IN_BATCH_PREFIX_CACHING_CHECK_THRESHOLD > 0 because too small
# threshold means we cannot use in-batch prefix caching for short prefixes.
# It is kind of common when the engine is long running (e.g., imagine the prefix "the").
if len(r.prefix_indices) <= IN_BATCH_PREFIX_CACHING_CHECK_THRESHOLD:
in_batch_matching_prefixes, _ = (
self.waiting_queue_radix_tree.match_prefix(
rid=r.rid, key=prefix_ids
)
)
if (
len(in_batch_matching_prefixes)
>= IN_BATCH_PREFIX_CACHING_DEPRIORITIZE_THRESHOLD
):
temporary_deprioritized.add(r.rid)
else:
# Insert with a dummy key
self.waiting_queue_radix_tree.insert(
prefix_ids, torch.empty(len(prefix_ids), dtype=torch.bool)
)
return temporary_deprioritized
@staticmethod
def _sort_by_longest_prefix(
waiting_queue: List[Req], temporary_deprioritized: Set[int]
) -> None:
"""Sorts the waiting queue based on the longest prefix match."""
waiting_queue.sort(
key=lambda r: (
-len(r.prefix_indices)
if r.rid not in temporary_deprioritized
else float("inf")
)
)
@staticmethod
def _sort_by_dfs_weight(
waiting_queue: List[Req], tree_cache: BasePrefixCache
) -> None:
"""Sorts the waiting queue based on a depth-first search weighting."""
last_node_to_reqs = defaultdict(list)
for req in waiting_queue:
last_node_to_reqs[req.last_node].append(req)
node_to_weight = defaultdict(int)
for node in last_node_to_reqs:
node_to_weight[node] = len(last_node_to_reqs[node])
SchedulePolicy._calc_weight(tree_cache.root_node, node_to_weight)
waiting_queue.clear()
SchedulePolicy._get_dfs_priority(
tree_cache.root_node,
node_to_weight,
last_node_to_reqs,
waiting_queue,
)
@staticmethod
def _sort_by_longest_output(waiting_queue: List[Req]) -> None:
"""Sorts the waiting queue based on the longest output (max_new_tokens)."""
waiting_queue.sort(key=lambda x: -x.sampling_params.max_new_tokens)
@staticmethod
def _sort_randomly(waiting_queue: List[Req]) -> None:
"""Shuffles the waiting queue randomly."""
random.shuffle(waiting_queue)
@staticmethod
def _calc_weight(cur_node: TreeNode, node_to_weight: Dict[TreeNode, int]) -> None:
for child in cur_node.children.values():
self.calc_weight(child, node_to_weight)
SchedulePolicy._calc_weight(child, node_to_weight)
node_to_weight[cur_node] += node_to_weight[child]
def get_dfs_priority(
self,
@staticmethod
def _get_dfs_priority(
cur_node: TreeNode,
node_to_priority: Dict[TreeNode, int],
last_node_to_reqs: Dict[TreeNode, List[Req]],
q: List,
):
) -> None:
childs = [child for child in cur_node.children.values()]
childs.sort(key=lambda x: -node_to_priority[x])
for child in childs:
self.get_dfs_priority(child, node_to_priority, last_node_to_reqs, q)
SchedulePolicy._get_dfs_priority(
child, node_to_priority, last_node_to_reqs, q
)
q.extend(last_node_to_reqs[cur_node])