feat: add priority based scheduling with priority based request acceptance and preemption (#8746)

This commit is contained in:
harrisonlimh
2025-09-16 17:10:10 -07:00
committed by GitHub
parent f949ad5794
commit 14fdd52740
16 changed files with 822 additions and 71 deletions

View File

@@ -18,13 +18,21 @@ class TestSchedulePolicy(CustomTestCase):
def test_init_with_cache_aware_policy(self):
policy = SchedulePolicy(
policy="lpm", tree_cache=self.tree_cache, enable_hierarchical_cache=True
policy="lpm",
tree_cache=self.tree_cache,
enable_hierarchical_cache=True,
enable_priority_scheduling=False,
schedule_low_priority_values_first=False,
)
self.assertEqual(policy.policy, CacheAwarePolicy.LPM)
def test_init_with_cache_agnostic_policy(self):
policy = SchedulePolicy(
policy="fcfs", tree_cache=self.tree_cache, enable_hierarchical_cache=True
policy="fcfs",
tree_cache=self.tree_cache,
enable_hierarchical_cache=True,
enable_priority_scheduling=False,
schedule_low_priority_values_first=False,
)
self.assertEqual(policy.policy, CacheAgnosticPolicy.FCFS)
@@ -34,12 +42,18 @@ class TestSchedulePolicy(CustomTestCase):
policy="invalid",
tree_cache=self.tree_cache,
enable_hierarchical_cache=True,
enable_priority_scheduling=False,
schedule_low_priority_values_first=False,
)
def test_init_with_disabled_cache(self):
disabled_tree_cache = RadixCache(None, None, disable=True, page_size=1)
policy = SchedulePolicy(
policy="lpm", tree_cache=disabled_tree_cache, enable_hierarchical_cache=True
policy="lpm",
tree_cache=disabled_tree_cache,
enable_hierarchical_cache=True,
enable_priority_scheduling=False,
schedule_low_priority_values_first=False,
)
self.assertEqual(policy.policy, CacheAgnosticPolicy.FCFS)
@@ -52,7 +66,11 @@ class TestSchedulePolicy(CustomTestCase):
]
policy = SchedulePolicy(
policy="fcfs", tree_cache=tree_cache, enable_hierarchical_cache=True
policy="fcfs",
tree_cache=tree_cache,
enable_hierarchical_cache=True,
enable_priority_scheduling=False,
schedule_low_priority_values_first=False,
)
policy.calc_priority(waiting_queue)
# Check if FCFS keeps the original order
@@ -60,6 +78,126 @@ class TestSchedulePolicy(CustomTestCase):
self.assertEqual(waiting_queue[1].rid, 3)
self.assertEqual(waiting_queue[2].rid, 2)
def test_calc_priority_priority_enabled_fcfs_scheduling(self):
tree_cache = RadixCache(None, None, False)
waiting_queue = [
Req(1, "a b", [1, 2], SamplingParams()),
Req(3, "a b c", [1, 2, 3], SamplingParams()),
Req(2, "a", [1], SamplingParams()),
]
waiting_queue[0].priority, waiting_queue[0].queue_time_start = 1, 1
waiting_queue[1].priority, waiting_queue[1].queue_time_start = 0, 1
waiting_queue[2].priority, waiting_queue[2].queue_time_start = 0, 0
policy = SchedulePolicy(
policy="fcfs",
tree_cache=tree_cache,
enable_hierarchical_cache=True,
enable_priority_scheduling=True,
schedule_low_priority_values_first=False,
)
policy.calc_priority(waiting_queue)
# Check if priority enabled fcfs ordering is applied.
self.assertEqual(waiting_queue[0].rid, 1)
self.assertEqual(waiting_queue[1].rid, 2)
self.assertEqual(waiting_queue[2].rid, 3)
def test_calc_priority_priority_enabled_fcfs_scheduling_with_low_priority_values_first(
self,
):
tree_cache = RadixCache(None, None, False)
waiting_queue = [
Req(1, "a b", [1, 2], SamplingParams()),
Req(3, "a b c", [1, 2, 3], SamplingParams()),
Req(2, "a", [1], SamplingParams()),
]
waiting_queue[0].priority, waiting_queue[0].queue_time_start = -1, 0
waiting_queue[1].priority, waiting_queue[1].queue_time_start = 0, 1
waiting_queue[2].priority, waiting_queue[2].queue_time_start = 0, 0
policy = SchedulePolicy(
policy="fcfs",
tree_cache=tree_cache,
enable_hierarchical_cache=True,
enable_priority_scheduling=True,
schedule_low_priority_values_first=True,
)
policy.calc_priority(waiting_queue)
# Check if priority enabled fcfs ordering is applied.
self.assertEqual(waiting_queue[0].rid, 1)
self.assertEqual(waiting_queue[1].rid, 2)
self.assertEqual(waiting_queue[2].rid, 3)
def test_calc_priority_longest_output_first_scheduling(self):
tree_cache = RadixCache(None, None, False)
waiting_queue = [
Req(1, "a b", [1, 2], SamplingParams(max_new_tokens=1000)),
Req(3, "a b c", [1, 2, 3], SamplingParams(max_new_tokens=10)),
Req(2, "a", [1], SamplingParams(max_new_tokens=100)),
]
policy = SchedulePolicy(
policy="lof",
tree_cache=tree_cache,
enable_hierarchical_cache=True,
enable_priority_scheduling=False,
schedule_low_priority_values_first=False,
)
policy.calc_priority(waiting_queue)
# Check if priority enabled fcfs ordering is applied.
self.assertEqual(waiting_queue[0].rid, 1)
self.assertEqual(waiting_queue[1].rid, 2)
self.assertEqual(waiting_queue[2].rid, 3)
def test_calc_priority_priority_enabled_longest_output_first_scheduling(self):
tree_cache = RadixCache(None, None, False)
waiting_queue = [
Req(1, "a b", [1, 2], SamplingParams(max_new_tokens=1), priority=1),
Req(3, "a b c", [1, 2, 3], SamplingParams(max_new_tokens=10), priority=0),
Req(2, "a", [1], SamplingParams(max_new_tokens=100), priority=0),
]
policy = SchedulePolicy(
policy="lof",
tree_cache=tree_cache,
enable_hierarchical_cache=True,
enable_priority_scheduling=True,
schedule_low_priority_values_first=False,
)
policy.calc_priority(waiting_queue)
# Check if priority enabled fcfs ordering is applied.
self.assertEqual(waiting_queue[0].rid, 1)
self.assertEqual(waiting_queue[1].rid, 2)
self.assertEqual(waiting_queue[2].rid, 3)
def test_calc_priority_priority_enabled_longest_output_first_scheduling_with_low_priority_values_first(
self,
):
tree_cache = RadixCache(None, None, False)
waiting_queue = [
Req(1, "a b", [1, 2], SamplingParams(max_new_tokens=1), priority=0),
Req(3, "a b c", [1, 2, 3], SamplingParams(max_new_tokens=10), priority=1),
Req(2, "a", [1], SamplingParams(max_new_tokens=100), priority=1),
]
policy = SchedulePolicy(
policy="lof",
tree_cache=tree_cache,
enable_hierarchical_cache=True,
enable_priority_scheduling=True,
schedule_low_priority_values_first=True,
)
policy.calc_priority(waiting_queue)
# Check if priority enabled fcfs ordering is applied.
self.assertEqual(waiting_queue[0].rid, 1)
self.assertEqual(waiting_queue[1].rid, 2)
self.assertEqual(waiting_queue[2].rid, 3)
if __name__ == "__main__":
unittest.main()