feat: add priority based scheduling with priority based request acceptance and preemption (#8746)

This commit is contained in:
harrisonlimh
2025-09-16 17:10:10 -07:00
committed by GitHub
parent f949ad5794
commit 14fdd52740
16 changed files with 822 additions and 71 deletions

View File

@@ -95,6 +95,7 @@ suites = {
TestFile("test_original_logprobs.py", 200),
TestFile("test_penalty.py", 41),
TestFile("test_page_size.py", 60),
TestFile("test_priority_scheduling.py", 100),
TestFile("test_pytorch_sampling_backend.py", 66),
TestFile("test_radix_attention.py", 105),
TestFile("test_regex_constrained.py", 64),

View File

@@ -0,0 +1,339 @@
import asyncio
import os
import re
import unittest
from typing import Any, Awaitable, Callable, List, Optional, Tuple
from sglang.srt.utils import kill_process_tree
from sglang.test.test_utils import (
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
STDERR_FILENAME,
STDOUT_FILENAME,
CustomTestCase,
popen_launch_server,
send_concurrent_generate_requests_with_custom_params,
)
class TestPriorityScheduling(CustomTestCase):
    """E2E tests for priority scheduling: request ordering, abortion of lower
    priority requests, rejection of incoming requests, and preemption, on a
    server with one running-request slot and a queue capacity of 3."""

    @classmethod
    def setUpClass(cls):
        cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        # Capture server output so tearDownClass can validate the scheduler logs.
        cls.stdout = open(STDOUT_FILENAME, "w")
        cls.stderr = open(STDERR_FILENAME, "w")
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=(
                "--max-running-requests",  # Enforce max request concurrency is 1
                "1",
                "--max-queued-requests",  # Enforce max queued request number is 3
                "3",
                "--enable-priority-scheduling",  # Enable priority scheduling
            ),
            return_stdout_stderr=(cls.stdout, cls.stderr),
        )

    @classmethod
    def tearDownClass(cls):
        kill_process_tree(cls.process.pid)
        # Confirm the configured limits were never exceeded before the logs are removed.
        _verify_max_running_requests_and_max_queued_request_validation(1, 3)
        cls.stdout.close()
        cls.stderr.close()
        os.remove(STDOUT_FILENAME)
        os.remove(STDERR_FILENAME)

    def test_priority_scheduling_request_ordering_validation(self):
        """Verify pending requests are ordered by priority and received timestamp."""
        responses = asyncio.run(
            send_concurrent_generate_requests_with_custom_params(
                self.base_url,
                [
                    {
                        "priority": 0,
                        "sampling_params": {"max_new_tokens": 10000},
                    },  # starts being processed first
                    {"priority": 1},  # third
                    {"priority": 1},  # fourth
                    {"priority": 2},  # second
                ],
            )
        )
        expected_status_and_error_messages = [
            (200, None),
            (200, None),
            (200, None),
            (200, None),
        ]
        e2e_latencies = []
        _verify_genereate_responses(
            responses, expected_status_and_error_messages, e2e_latencies
        )
        # Completion order: the long-running first request, then by descending
        # priority, with arrival time breaking the tie between the two priority-1s.
        assert e2e_latencies[0] < e2e_latencies[3] < e2e_latencies[1] < e2e_latencies[2]

    def test_priority_scheduling_existing_requests_abortion_validation(self):
        """Verify lower priority requests are aborted when incoming requests have higher priority"""
        responses = asyncio.run(
            send_concurrent_generate_requests_with_custom_params(
                self.base_url,
                [
                    {
                        "priority": 1,
                        "sampling_params": {"max_new_tokens": 10000},
                    },  # starts being processed first and holds the running queue capacity
                    {"priority": 2},  # aborted by request 5
                    {"priority": 3},  # aborted by request 6
                    {"priority": 4},  # aborted by request 7
                    {"priority": 5},  # fourth
                    {"priority": 6},  # third
                    {"priority": 7},  # second
                ],
            )
        )
        expected_status_and_error_messages = [
            (200, None),
            (503, "The request is aborted by a higher priority request."),
            (503, "The request is aborted by a higher priority request."),
            (503, "The request is aborted by a higher priority request."),
            (200, None),
            (200, None),
            (200, None),
        ]
        e2e_latencies = []
        _verify_genereate_responses(
            responses, expected_status_and_error_messages, e2e_latencies
        )
        # The surviving queued requests complete in descending priority order.
        assert e2e_latencies[0] < e2e_latencies[6] < e2e_latencies[5] < e2e_latencies[4]

    def test_priority_scheduling_incoming_request_rejection_validation(self):
        """Verify incoming requests are rejected when existing requests have higher priority"""
        responses = asyncio.run(
            send_concurrent_generate_requests_with_custom_params(
                self.base_url,
                [
                    {
                        "priority": 7,
                        "sampling_params": {"max_new_tokens": 10000},
                    },  # starts being processed first and holds the running queue capacity
                    {"priority": 6},  # second
                    {"priority": 5},  # third
                    {"priority": 4},  # fourth
                    {"priority": 3},  # rejected
                    {"priority": 2},  # rejected
                    {"priority": 1},  # rejected
                ],
            )
        )
        expected_status_and_error_messages = [
            (200, None),
            (200, None),
            (200, None),
            (200, None),
            (503, "The request queue is full."),
            (503, "The request queue is full."),
            (503, "The request queue is full."),
        ]
        e2e_latencies = []
        _verify_genereate_responses(
            responses, expected_status_and_error_messages, e2e_latencies
        )
        # Accepted requests complete in descending priority order.
        assert e2e_latencies[0] < e2e_latencies[1] < e2e_latencies[2] < e2e_latencies[3]

    def test_priority_scheduling_preemption_meeting_threshold_validation(self):
        """Verify running requests are preempted by requests with priorities meeting the preemption threshold"""
        responses = asyncio.run(
            send_concurrent_generate_requests_with_custom_params(
                self.base_url,
                [
                    {
                        "priority": 0,
                        "sampling_params": {"max_new_tokens": 10000},
                    },  # starts being processed first then preempted or pushed by later requests, and finishes last.
                    {
                        "priority": 10,
                        "sampling_params": {"max_new_tokens": 10000},
                    },  # scheduled after the third request, and finishes second.
                    {
                        "priority": 20,
                        "sampling_params": {"max_new_tokens": 10000},
                    },  # finishes first.
                ],
            )
        )
        expected_status_and_error_messages = [
            (200, None),
            (200, None),
            (200, None),
        ]
        e2e_latencies = []
        _verify_genereate_responses(
            responses, expected_status_and_error_messages, e2e_latencies
        )
        assert e2e_latencies[2] < e2e_latencies[1] < e2e_latencies[0]

    def test_priority_scheduling_preemption_below_threshold_validation(self):
        """Verify running requests are not preempted by requests with priorities below preemption threshold"""
        responses = asyncio.run(
            send_concurrent_generate_requests_with_custom_params(
                self.base_url,
                [
                    {
                        "priority": 0,
                        "sampling_params": {"max_new_tokens": 10000},
                    },
                    {
                        "priority": 5,
                        "sampling_params": {"max_new_tokens": 10000},
                    },
                ],
            )
        )
        expected_status_and_error_messages = [
            (200, None),
            (200, None),
        ]
        e2e_latencies = []
        _verify_genereate_responses(
            responses, expected_status_and_error_messages, e2e_latencies
        )
        # Priority 5 is below the preemption threshold, so FCFS order holds.
        assert e2e_latencies[0] < e2e_latencies[1]
class TestPrioritySchedulingMultipleRunningRequests(CustomTestCase):
    """E2E test that preempting only a subset of multiple running requests is safe
    (two running-request slots, queue capacity of 3)."""

    @classmethod
    def setUpClass(cls):
        cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        # Capture server output so tearDownClass can validate the scheduler logs.
        cls.stdout = open(STDOUT_FILENAME, "w")
        cls.stderr = open(STDERR_FILENAME, "w")
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=(
                "--max-running-requests",  # Enforce max request concurrency is 2
                "2",
                "--max-queued-requests",  # Enforce max queued request number is 3
                "3",
                "--enable-priority-scheduling",  # Enable priority scheduling
            ),
            return_stdout_stderr=(cls.stdout, cls.stderr),
        )

    @classmethod
    def tearDownClass(cls):
        kill_process_tree(cls.process.pid)
        # Confirm the configured limits were never exceeded before the logs are removed.
        _verify_max_running_requests_and_max_queued_request_validation(2, 3)
        cls.stdout.close()
        cls.stderr.close()
        os.remove(STDOUT_FILENAME)
        os.remove(STDERR_FILENAME)

    def test_priority_scheduling_with_multiple_running_requests_preemption(self):
        """Verify preempting a subset of running requests is safe."""
        responses = asyncio.run(
            send_concurrent_generate_requests_with_custom_params(
                self.base_url,
                [
                    {
                        "priority": 10,
                        "sampling_params": {"max_new_tokens": 10000},
                    },  # keeps running
                    {
                        "priority": 5,
                        "sampling_params": {"max_new_tokens": 10000},
                    },  # preempted by the third request, then resumed
                    {
                        "priority": 15,
                        "sampling_params": {"max_new_tokens": 10000},
                    },  # preempts the lowest-priority running request
                ],
            )
        )
        # One expectation per request; the original had a spurious fourth
        # entry that zip() silently ignored.
        expected_status_and_error_messages = [
            (200, None),
            (200, None),
            (200, None),
        ]
        _verify_genereate_responses(responses, expected_status_and_error_messages, [])
def _verify_genereate_responses(
responses: Tuple[int, Any, float],
expected_code_and_error_message: Tuple[int, Any],
e2e_latencies: List[Optional[float]],
):
"""
Verify generate response results are as expected based on status code and response json object content.
In addition, collects e2e latency info to verify scheduling and processing ordering.
"""
for got, expected in zip(responses, expected_code_and_error_message):
got_status, got_json = got
expected_status, expected_err_msg = expected
# Check status code is as expected
assert got_status == expected_status
# Check error message content or fields' existence based on status code
if got_status != 200:
assert got_json["object"] == "error"
assert got_json["message"] == expected_err_msg
else:
assert "object" not in got_json
assert "message" not in got_json
# Collect e2e latencies for scheduling validation
e2e_latencies.append(
got_json["meta_info"]["e2e_latency"] if got_status == 200 else None
)
def _verify_max_running_requests_and_max_queued_request_validation(
    max_running_requests: int, max_queued_requests: int
):
    """Verify running request and queued request numbers based on server logs.

    Scans the captured stderr log and asserts that every reported
    ``#running-req`` and ``#queue-req`` count stays within the given limits.
    """
    running_pattern = re.compile(r"#running-req:\s*(\d+)")
    queued_pattern = re.compile(r"#queue-req:\s*(\d+)")
    with open(STDERR_FILENAME) as log_file:
        for log_line in log_file:
            running_match = running_pattern.search(log_line)
            if running_match:
                assert int(running_match.group(1)) <= max_running_requests
            queued_match = queued_pattern.search(log_line)
            if queued_match:
                assert int(queued_match.group(1)) <= max_queued_requests
# Allow running this test module directly.
if __name__ == "__main__":
    unittest.main()

View File

@@ -65,9 +65,8 @@ class TestMaxQueuedRequests(CustomTestCase):
send_concurrent_generate_requests(self.base_url, num_requests=10)
)
assert 200 in status_codes
assert 503 in status_codes
assert all(status_code in [200, 503] for status_code in status_codes)
expected_status_codes = [200, 200, 503, 503, 503, 503, 503, 503, 503, 503]
assert status_codes == expected_status_codes
def test_max_running_requests_and_max_queued_request_validation(self):
"""Verify running request and queued request numbers based on server logs."""

View File

@@ -18,13 +18,21 @@ class TestSchedulePolicy(CustomTestCase):
def test_init_with_cache_aware_policy(self):
policy = SchedulePolicy(
policy="lpm", tree_cache=self.tree_cache, enable_hierarchical_cache=True
policy="lpm",
tree_cache=self.tree_cache,
enable_hierarchical_cache=True,
enable_priority_scheduling=False,
schedule_low_priority_values_first=False,
)
self.assertEqual(policy.policy, CacheAwarePolicy.LPM)
def test_init_with_cache_agnostic_policy(self):
policy = SchedulePolicy(
policy="fcfs", tree_cache=self.tree_cache, enable_hierarchical_cache=True
policy="fcfs",
tree_cache=self.tree_cache,
enable_hierarchical_cache=True,
enable_priority_scheduling=False,
schedule_low_priority_values_first=False,
)
self.assertEqual(policy.policy, CacheAgnosticPolicy.FCFS)
@@ -34,12 +42,18 @@ class TestSchedulePolicy(CustomTestCase):
policy="invalid",
tree_cache=self.tree_cache,
enable_hierarchical_cache=True,
enable_priority_scheduling=False,
schedule_low_priority_values_first=False,
)
def test_init_with_disabled_cache(self):
disabled_tree_cache = RadixCache(None, None, disable=True, page_size=1)
policy = SchedulePolicy(
policy="lpm", tree_cache=disabled_tree_cache, enable_hierarchical_cache=True
policy="lpm",
tree_cache=disabled_tree_cache,
enable_hierarchical_cache=True,
enable_priority_scheduling=False,
schedule_low_priority_values_first=False,
)
self.assertEqual(policy.policy, CacheAgnosticPolicy.FCFS)
@@ -52,7 +66,11 @@ class TestSchedulePolicy(CustomTestCase):
]
policy = SchedulePolicy(
policy="fcfs", tree_cache=tree_cache, enable_hierarchical_cache=True
policy="fcfs",
tree_cache=tree_cache,
enable_hierarchical_cache=True,
enable_priority_scheduling=False,
schedule_low_priority_values_first=False,
)
policy.calc_priority(waiting_queue)
# Check if FCFS keeps the original order
@@ -60,6 +78,126 @@ class TestSchedulePolicy(CustomTestCase):
self.assertEqual(waiting_queue[1].rid, 3)
self.assertEqual(waiting_queue[2].rid, 2)
def test_calc_priority_priority_enabled_fcfs_scheduling(self):
    """Priority-enabled FCFS: higher priority values first (low-values-first
    disabled), ties broken by earlier queue_time_start."""
    tree_cache = RadixCache(None, None, False)
    waiting_queue = [
        Req(1, "a b", [1, 2], SamplingParams()),
        Req(3, "a b c", [1, 2, 3], SamplingParams()),
        Req(2, "a", [1], SamplingParams()),
    ]
    # (priority, queue_time_start) per request: rid 1 has the highest priority;
    # rids 3 and 2 tie at priority 0, where rid 2 arrived earlier.
    waiting_queue[0].priority, waiting_queue[0].queue_time_start = 1, 1
    waiting_queue[1].priority, waiting_queue[1].queue_time_start = 0, 1
    waiting_queue[2].priority, waiting_queue[2].queue_time_start = 0, 0
    policy = SchedulePolicy(
        policy="fcfs",
        tree_cache=tree_cache,
        enable_hierarchical_cache=True,
        enable_priority_scheduling=True,
        schedule_low_priority_values_first=False,
    )
    policy.calc_priority(waiting_queue)
    # Check if priority enabled fcfs ordering is applied.
    self.assertEqual(waiting_queue[0].rid, 1)
    self.assertEqual(waiting_queue[1].rid, 2)
    self.assertEqual(waiting_queue[2].rid, 3)
def test_calc_priority_priority_enabled_fcfs_scheduling_with_low_priority_values_first(
    self,
):
    """Priority-enabled FCFS with low-values-first: smaller priority values win,
    ties broken by earlier queue_time_start."""
    tree_cache = RadixCache(None, None, False)
    waiting_queue = [
        Req(1, "a b", [1, 2], SamplingParams()),
        Req(3, "a b c", [1, 2, 3], SamplingParams()),
        Req(2, "a", [1], SamplingParams()),
    ]
    # rid 1 has the smallest priority value (-1); rids 3 and 2 tie at 0,
    # where rid 2 arrived earlier.
    waiting_queue[0].priority, waiting_queue[0].queue_time_start = -1, 0
    waiting_queue[1].priority, waiting_queue[1].queue_time_start = 0, 1
    waiting_queue[2].priority, waiting_queue[2].queue_time_start = 0, 0
    policy = SchedulePolicy(
        policy="fcfs",
        tree_cache=tree_cache,
        enable_hierarchical_cache=True,
        enable_priority_scheduling=True,
        schedule_low_priority_values_first=True,
    )
    policy.calc_priority(waiting_queue)
    # Check if priority enabled fcfs ordering is applied.
    self.assertEqual(waiting_queue[0].rid, 1)
    self.assertEqual(waiting_queue[1].rid, 2)
    self.assertEqual(waiting_queue[2].rid, 3)
def test_calc_priority_longest_output_first_scheduling(self):
    """LOF without priority scheduling: requests are ordered by descending
    max_new_tokens (1000 > 100 > 10)."""
    tree_cache = RadixCache(None, None, False)
    waiting_queue = [
        Req(1, "a b", [1, 2], SamplingParams(max_new_tokens=1000)),
        Req(3, "a b c", [1, 2, 3], SamplingParams(max_new_tokens=10)),
        Req(2, "a", [1], SamplingParams(max_new_tokens=100)),
    ]
    policy = SchedulePolicy(
        policy="lof",
        tree_cache=tree_cache,
        enable_hierarchical_cache=True,
        enable_priority_scheduling=False,
        schedule_low_priority_values_first=False,
    )
    policy.calc_priority(waiting_queue)
    # Check if longest-output-first ordering is applied.
    self.assertEqual(waiting_queue[0].rid, 1)
    self.assertEqual(waiting_queue[1].rid, 2)
    self.assertEqual(waiting_queue[2].rid, 3)
def test_calc_priority_priority_enabled_longest_output_first_scheduling(self):
    """Priority-enabled LOF: higher priority values first, ties broken by
    descending max_new_tokens."""
    tree_cache = RadixCache(None, None, False)
    waiting_queue = [
        Req(1, "a b", [1, 2], SamplingParams(max_new_tokens=1), priority=1),
        Req(3, "a b c", [1, 2, 3], SamplingParams(max_new_tokens=10), priority=0),
        Req(2, "a", [1], SamplingParams(max_new_tokens=100), priority=0),
    ]
    policy = SchedulePolicy(
        policy="lof",
        tree_cache=tree_cache,
        enable_hierarchical_cache=True,
        enable_priority_scheduling=True,
        schedule_low_priority_values_first=False,
    )
    policy.calc_priority(waiting_queue)
    # Check if priority enabled longest-output-first ordering is applied.
    self.assertEqual(waiting_queue[0].rid, 1)
    self.assertEqual(waiting_queue[1].rid, 2)
    self.assertEqual(waiting_queue[2].rid, 3)
def test_calc_priority_priority_enabled_longest_output_first_scheduling_with_low_priority_values_first(
    self,
):
    """Priority-enabled LOF with low-values-first: smaller priority values win,
    ties broken by descending max_new_tokens."""
    tree_cache = RadixCache(None, None, False)
    waiting_queue = [
        Req(1, "a b", [1, 2], SamplingParams(max_new_tokens=1), priority=0),
        Req(3, "a b c", [1, 2, 3], SamplingParams(max_new_tokens=10), priority=1),
        Req(2, "a", [1], SamplingParams(max_new_tokens=100), priority=1),
    ]
    policy = SchedulePolicy(
        policy="lof",
        tree_cache=tree_cache,
        enable_hierarchical_cache=True,
        enable_priority_scheduling=True,
        schedule_low_priority_values_first=True,
    )
    policy.calc_priority(waiting_queue)
    # Check if priority enabled longest-output-first ordering is applied.
    self.assertEqual(waiting_queue[0].rid, 1)
    self.assertEqual(waiting_queue[1].rid, 2)
    self.assertEqual(waiting_queue[2].rid, 3)
# Allow running this test module directly.
if __name__ == "__main__":
    unittest.main()