From bdb3929dbb8f14d22c4a27ee2d8840751752658c Mon Sep 17 00:00:00 2001
From: libra
Date: Sat, 4 Jan 2025 00:05:16 +0800
Subject: [PATCH] Refactor SchedulePolicy to improve code organization (#2571)

---
 python/sglang/srt/managers/schedule_policy.py | 255 +++++++++++-------
 test/srt/test_schedule_policy.py              |  52 ++++
 2 files changed, 214 insertions(+), 93 deletions(-)
 create mode 100644 test/srt/test_schedule_policy.py

diff --git a/python/sglang/srt/managers/schedule_policy.py b/python/sglang/srt/managers/schedule_policy.py
index e92b1ddce..d2083d092 100644
--- a/python/sglang/srt/managers/schedule_policy.py
+++ b/python/sglang/srt/managers/schedule_policy.py
@@ -18,7 +18,7 @@ import random
 from collections import defaultdict
 from contextlib import contextmanager
 from enum import Enum, auto
-from typing import Dict, List, Optional
+from typing import Dict, List, Optional, Set, Union
 
 import torch
 
@@ -50,13 +50,26 @@ IN_BATCH_PREFIX_CACHING_DEPRIORITIZE_THRESHOLD = int(
 )
 
 
-class SchedulePolicy:
-    def __init__(self, policy: str, tree_cache: BasePrefixCache):
-        if tree_cache.disable and policy in ["lpm", "dfs-weight"]:
-            # LPM and DFS-weight is meaningless when the tree cache is disabled.
-            policy = "fcfs"
+class CacheAwarePolicy(Enum):
+    """Scheduling policies that are aware of the tree cache."""
 
-        self.policy = policy
+    LPM = "lpm"  # longest prefix match
+    DFS_WEIGHT = "dfs-weight"  # depth-first search weighting
+
+
+class CacheAgnosticPolicy(Enum):
+    """Scheduling policies that are not aware of the tree cache."""
+
+    FCFS = "fcfs"  # first come first serve
+    LOF = "lof"  # longest output first
+    RANDOM = "random"
+
+
+class SchedulePolicy:
+    Policy = Union[CacheAwarePolicy, CacheAgnosticPolicy]
+
+    def __init__(self, policy: str, tree_cache: BasePrefixCache):
+        self.policy = self._validate_and_adjust_policy(policy, tree_cache)
         self.tree_cache = tree_cache
 
         # It is used to find the matching prefix for in-batch prefix caching.
@@ -64,110 +77,166 @@ class SchedulePolicy:
             req_to_token_pool=None, token_to_kv_pool=None, disable=False
         )
 
-    def calc_priority(self, waiting_queue: List[Req]):
-        if len(waiting_queue) > 128 and self.policy == "lpm":
-            # Turn off the expensive prefix matching and sorting when the #queue is large.
-            policy = "fcfs"
-        else:
-            policy = self.policy
+    def calc_priority(self, waiting_queue: List[Req]) -> bool:
+        policy = self._determine_active_policy(waiting_queue)
 
-        # Compute matched prefix length
         prefix_computed = False
-        if policy == "lpm" or policy == "dfs-weight":
-            # rid to deprioritize in the current run for in-batch prefix caching.
-            temporary_deprioritized = set()
-            self.waiting_queue_radix_tree.reset()
-
-            for r in waiting_queue:
-                prefix_ids = r.adjust_max_prefix_ids()
-
-                # NOTE: the prefix_indices must always be aligned with last_node
-                r.prefix_indices, r.last_node = self.tree_cache.match_prefix(
-                    rid=r.rid, key=prefix_ids
-                )
-
-                # NOTE(sang): This logic is for in-batch prefix caching;
-                # If there are more than 1 request that have small matching prefix from
-                # existing cache, but all those requests share the same prefix, we prefer
-                # to schedule only one of them so that we can increase the cache hit rate.
-                # We prefer to set IN_BATCH_PREFIX_CACHING_CHECK_THRESHOLD > 0 because too small
-                # threshold means we cannot use in-batch prefix caching for short prefixes.
-                # It is kind of common when the engine is long running (e.g., imagine the prefix "the").
-                if len(r.prefix_indices) <= IN_BATCH_PREFIX_CACHING_CHECK_THRESHOLD:
-                    in_batch_matching_prefixes, _ = (
-                        self.waiting_queue_radix_tree.match_prefix(
-                            rid=r.rid, key=prefix_ids
-                        )
-                    )
-                    if (
-                        len(in_batch_matching_prefixes)
-                        >= IN_BATCH_PREFIX_CACHING_DEPRIORITIZE_THRESHOLD
-                    ):
-                        temporary_deprioritized.add(r.rid)
-                    else:
-                        # Insert with a dummy key
-                        self.waiting_queue_radix_tree.insert(
-                            prefix_ids, torch.empty(len(prefix_ids), dtype=torch.bool)
-                        )
-
+        if isinstance(policy, CacheAwarePolicy):
             prefix_computed = True
-
-        if policy == "lpm":
-            # Longest Prefix Match
-            waiting_queue.sort(
-                key=lambda r: (
-                    -len(r.prefix_indices)
-                    if r.rid not in temporary_deprioritized
-                    else float("inf")
+            temporary_deprioritized = self._compute_prefix_matches(
+                waiting_queue, policy
+            )
+            if policy == CacheAwarePolicy.LPM:
+                SchedulePolicy._sort_by_longest_prefix(
+                    waiting_queue, temporary_deprioritized
                 )
-            )
-        elif policy == "fcfs":
-            # first come first serve
-            pass
-        elif policy == "lof":
-            # longest output first
-            waiting_queue.sort(key=lambda x: -x.sampling_params.max_new_tokens)
-        elif policy == "random":
-            random.shuffle(waiting_queue)
-        elif policy == "dfs-weight":
-            # Experimental policy based on custom weights
-            last_node_to_reqs = defaultdict(list)
-            for req in waiting_queue:
-                last_node_to_reqs[req.last_node].append(req)
-
-            node_to_weight = defaultdict(int)
-            for node in last_node_to_reqs:
-                node_to_weight[node] = len(last_node_to_reqs[node])
-            self.calc_weight(self.tree_cache.root_node, node_to_weight)
-
-            waiting_queue.clear()
-            self.get_dfs_priority(
-                self.tree_cache.root_node,
-                node_to_weight,
-                last_node_to_reqs,
-                waiting_queue,
-            )
+            elif policy == CacheAwarePolicy.DFS_WEIGHT:
+                SchedulePolicy._sort_by_dfs_weight(waiting_queue, self.tree_cache)
+            else:
+                raise ValueError(f"Unknown CacheAwarePolicy: {policy=}")
         else:
-            raise ValueError(f"Unknown schedule_policy: {policy=}")
+            if policy == CacheAgnosticPolicy.FCFS:
+                pass
+            elif policy == CacheAgnosticPolicy.LOF:
+                SchedulePolicy._sort_by_longest_output(waiting_queue)
+            elif policy == CacheAgnosticPolicy.RANDOM:
+                SchedulePolicy._sort_randomly(waiting_queue)
+            else:
+                raise ValueError(f"Unknown CacheAgnosticPolicy: {policy=}")
 
         return prefix_computed
 
-    def calc_weight(self, cur_node: TreeNode, node_to_weight: Dict):
+    def _determine_active_policy(self, waiting_queue: List[Req]) -> Policy:
+        if len(waiting_queue) > 128 and self.policy == CacheAwarePolicy.LPM:
+            # Skip the expensive prefix matching and sorting when the queue is large.
+            return CacheAgnosticPolicy.FCFS
+        return self.policy
+
+    def _validate_and_adjust_policy(
+        self, policy: str, tree_cache: BasePrefixCache
+    ) -> Policy:
+        """
+        Validates the policy and adjusts it if necessary based on tree cache settings.
+        """
+        try:
+            policy_enum = CacheAwarePolicy(policy)
+            if tree_cache.disable:
+                # The tree cache is disabled, so fall back to the cache-agnostic FCFS policy.
+                return CacheAgnosticPolicy.FCFS
+            return policy_enum
+        except ValueError:
+            try:
+                return CacheAgnosticPolicy(policy)
+            except ValueError:
+                raise ValueError(f"Unknown schedule_policy: {policy=}")
+
+    def _compute_prefix_matches(
+        self, waiting_queue: List[Req], policy: CacheAwarePolicy
+    ) -> Set[int]:
+        """
+        Computes and caches the matching prefixes for requests in the waiting queue,
+        and handles the in-batch prefix caching logic.
+ """ + temporary_deprioritized: Set[int] = set() + self.waiting_queue_radix_tree.reset() + + for r in waiting_queue: + prefix_ids = r.adjust_max_prefix_ids() + + # NOTE: the prefix_indices must always be aligned with last_node + r.prefix_indices, r.last_node = self.tree_cache.match_prefix( + rid=r.rid, key=prefix_ids + ) + + # NOTE(sang): This logic is for in-batch prefix caching; + # If there are more than 1 request that have small matching prefix from + # existing cache, but all those requests share the same prefix, we prefer + # to schedule only one of them so that we can increase the cache hit rate. + # We prefer to set IN_BATCH_PREFIX_CACHING_CHECK_THRESHOLD > 0 because too small + # threshold means we cannot use in-batch prefix caching for short prefixes. + # It is kind of common when the engine is long running (e.g., imagine the prefix "the"). + if len(r.prefix_indices) <= IN_BATCH_PREFIX_CACHING_CHECK_THRESHOLD: + in_batch_matching_prefixes, _ = ( + self.waiting_queue_radix_tree.match_prefix( + rid=r.rid, key=prefix_ids + ) + ) + if ( + len(in_batch_matching_prefixes) + >= IN_BATCH_PREFIX_CACHING_DEPRIORITIZE_THRESHOLD + ): + temporary_deprioritized.add(r.rid) + else: + # Insert with a dummy key + self.waiting_queue_radix_tree.insert( + prefix_ids, torch.empty(len(prefix_ids), dtype=torch.bool) + ) + return temporary_deprioritized + + @staticmethod + def _sort_by_longest_prefix( + waiting_queue: List[Req], temporary_deprioritized: Set[int] + ) -> None: + """Sorts the waiting queue based on the longest prefix match.""" + waiting_queue.sort( + key=lambda r: ( + -len(r.prefix_indices) + if r.rid not in temporary_deprioritized + else float("inf") + ) + ) + + @staticmethod + def _sort_by_dfs_weight( + waiting_queue: List[Req], tree_cache: BasePrefixCache + ) -> None: + """Sorts the waiting queue based on a depth-first search weighting.""" + last_node_to_reqs = defaultdict(list) + for req in waiting_queue: + last_node_to_reqs[req.last_node].append(req) + + node_to_weight = defaultdict(int) + for node in last_node_to_reqs: + node_to_weight[node] = len(last_node_to_reqs[node]) + SchedulePolicy._calc_weight(tree_cache.root_node, node_to_weight) + + waiting_queue.clear() + SchedulePolicy._get_dfs_priority( + tree_cache.root_node, + node_to_weight, + last_node_to_reqs, + waiting_queue, + ) + + @staticmethod + def _sort_by_longest_output(waiting_queue: List[Req]) -> None: + """Sorts the waiting queue based on the longest output (max_new_tokens).""" + waiting_queue.sort(key=lambda x: -x.sampling_params.max_new_tokens) + + @staticmethod + def _sort_randomly(waiting_queue: List[Req]) -> None: + """Shuffles the waiting queue randomly.""" + random.shuffle(waiting_queue) + + @staticmethod + def _calc_weight(cur_node: TreeNode, node_to_weight: Dict[TreeNode, int]) -> None: for child in cur_node.children.values(): - self.calc_weight(child, node_to_weight) + SchedulePolicy._calc_weight(child, node_to_weight) node_to_weight[cur_node] += node_to_weight[child] - def get_dfs_priority( - self, + @staticmethod + def _get_dfs_priority( cur_node: TreeNode, node_to_priority: Dict[TreeNode, int], last_node_to_reqs: Dict[TreeNode, List[Req]], q: List, - ): + ) -> None: childs = [child for child in cur_node.children.values()] childs.sort(key=lambda x: -node_to_priority[x]) for child in childs: - self.get_dfs_priority(child, node_to_priority, last_node_to_reqs, q) + SchedulePolicy._get_dfs_priority( + child, node_to_priority, last_node_to_reqs, q + ) q.extend(last_node_to_reqs[cur_node]) diff --git 
a/test/srt/test_schedule_policy.py b/test/srt/test_schedule_policy.py new file mode 100644 index 000000000..52c5b8289 --- /dev/null +++ b/test/srt/test_schedule_policy.py @@ -0,0 +1,52 @@ +import unittest + +from sglang.srt.managers.schedule_batch import Req +from sglang.srt.managers.schedule_policy import ( + CacheAgnosticPolicy, + CacheAwarePolicy, + SchedulePolicy, +) +from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode +from sglang.srt.sampling.sampling_params import SamplingParams + + +class TestSchedulePolicy(unittest.TestCase): + + def setUp(self): + self.tree_cache = RadixCache(None, None, False) + + def test_init_with_cache_aware_policy(self): + policy = SchedulePolicy(policy="lpm", tree_cache=self.tree_cache) + self.assertEqual(policy.policy, CacheAwarePolicy.LPM) + + def test_init_with_cache_agnostic_policy(self): + policy = SchedulePolicy(policy="fcfs", tree_cache=self.tree_cache) + self.assertEqual(policy.policy, CacheAgnosticPolicy.FCFS) + + def test_init_with_unknown_policy(self): + with self.assertRaises(ValueError): + SchedulePolicy(policy="invalid", tree_cache=self.tree_cache) + + def test_init_with_disabled_cache(self): + disabled_tree_cache = RadixCache(None, None, disable=True) + policy = SchedulePolicy(policy="lpm", tree_cache=disabled_tree_cache) + self.assertEqual(policy.policy, CacheAgnosticPolicy.FCFS) + + def test_calc_priority_fcfs(self): + tree_cache = RadixCache(None, None, False) + waiting_queue = [ + Req(1, "a b", [1, 2], SamplingParams()), + Req(3, "a b c", [1, 2, 3], SamplingParams()), + Req(2, "a", [1], SamplingParams()), + ] + + policy = SchedulePolicy(policy="fcfs", tree_cache=tree_cache) + policy.calc_priority(waiting_queue) + # Check if FCFS keeps the original order + self.assertEqual(waiting_queue[0].rid, 1) + self.assertEqual(waiting_queue[1].rid, 3) + self.assertEqual(waiting_queue[2].rid, 2) + + +if __name__ == "__main__": + unittest.main()
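
Note for reviewers: the sort key in _sort_by_longest_prefix can be non-obvious at first glance. Below is a minimal standalone sketch of the same idea in plain Python, with no sglang dependencies; FakeReq is a hypothetical stand-in for Req that models only the two fields the key reads. Requests with longer cached prefixes sort to the front, while temporarily deprioritized rids sink to the back because float("inf") compares greater than any negated prefix length. This sketch is illustrative only and is not part of the patch.

    # lpm_sort_sketch.py -- illustrative only, not part of the patch
    from dataclasses import dataclass
    from typing import List, Set


    @dataclass
    class FakeReq:  # hypothetical stand-in for sglang's Req
        rid: str
        prefix_indices: List[int]  # tokens already cached for this request


    def sort_by_longest_prefix(queue: List[FakeReq], deprioritized: Set[str]) -> None:
        # Same key as SchedulePolicy._sort_by_longest_prefix: longer cached
        # prefixes come first; deprioritized requests go last.
        queue.sort(
            key=lambda r: (
                -len(r.prefix_indices) if r.rid not in deprioritized else float("inf")
            )
        )


    if __name__ == "__main__":
        queue = [FakeReq("a", [1]), FakeReq("b", [1, 2, 3]), FakeReq("c", [1, 2])]
        sort_by_longest_prefix(queue, deprioritized={"b"})
        print([r.rid for r in queue])  # ['c', 'a', 'b']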