feat: add priority-based scheduling with priority-based request acceptance and preemption (#8746)
This commit is contained in:
@@ -95,6 +95,7 @@ suites = {
|
||||
TestFile("test_original_logprobs.py", 200),
|
||||
TestFile("test_penalty.py", 41),
|
||||
TestFile("test_page_size.py", 60),
|
||||
TestFile("test_priority_scheduling.py", 100),
|
||||
TestFile("test_pytorch_sampling_backend.py", 66),
|
||||
TestFile("test_radix_attention.py", 105),
|
||||
TestFile("test_regex_constrained.py", 64),
|
||||
|
||||
339
test/srt/test_priority_scheduling.py
Normal file
339
test/srt/test_priority_scheduling.py
Normal file
@@ -0,0 +1,339 @@
|
||||
import asyncio
|
||||
import os
|
||||
import re
|
||||
import unittest
|
||||
from typing import Any, Awaitable, Callable, List, Optional, Tuple
|
||||
|
||||
from sglang.srt.utils import kill_process_tree
|
||||
from sglang.test.test_utils import (
|
||||
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
|
||||
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
DEFAULT_URL_FOR_TEST,
|
||||
STDERR_FILENAME,
|
||||
STDOUT_FILENAME,
|
||||
CustomTestCase,
|
||||
popen_launch_server,
|
||||
send_concurrent_generate_requests_with_custom_params,
|
||||
)
|
||||
|
||||
|
||||
class TestPriorityScheduling(CustomTestCase):
    """E2E tests for priority-based request ordering, abortion, rejection, and preemption.

    The server is launched with a single running-request slot and a three-deep
    queue so scheduling decisions are deterministic and can be asserted on via
    per-request end-to-end latencies.
    """

    @classmethod
    def setUpClass(cls):
        cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
        # Fix: the original assigned cls.base_url twice; keep a single assignment.
        cls.base_url = DEFAULT_URL_FOR_TEST

        # Capture server output so tearDownClass can validate queue sizes from the logs.
        cls.stdout = open(STDOUT_FILENAME, "w")
        cls.stderr = open(STDERR_FILENAME, "w")

        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=(
                "--max-running-requests",  # Enforce max request concurrency is 1
                "1",
                "--max-queued-requests",  # Enforce max queued request number is 3
                "3",
                "--enable-priority-scheduling",  # Enable priority scheduling
            ),
            return_stdout_stderr=(cls.stdout, cls.stderr),
        )

    @classmethod
    def tearDownClass(cls):
        kill_process_tree(cls.process.pid)
        # Limits must match the launch flags above: 1 running, 3 queued.
        _verify_max_running_requests_and_max_queued_request_validation(1, 3)
        cls.stdout.close()
        cls.stderr.close()
        os.remove(STDOUT_FILENAME)
        os.remove(STDERR_FILENAME)

    def test_priority_scheduling_request_ordering_validation(self):
        """Verify pending requests are ordered by priority and received timestamp."""

        responses = asyncio.run(
            send_concurrent_generate_requests_with_custom_params(
                self.base_url,
                [
                    {
                        "priority": 0,
                        "sampling_params": {"max_new_tokens": 10000},
                    },  # starts being processed first
                    {"priority": 1},  # third
                    {"priority": 1},  # fourth
                    {"priority": 2},  # second
                ],
            )
        )

        expected_status_and_error_messages = [
            (200, None),
            (200, None),
            (200, None),
            (200, None),
        ]

        e2e_latencies = []
        _verify_genereate_responses(
            responses, expected_status_and_error_messages, e2e_latencies
        )
        # Completion order follows priority, with ties broken by arrival order.
        assert e2e_latencies[0] < e2e_latencies[3] < e2e_latencies[1] < e2e_latencies[2]

    def test_priority_scheduling_existing_requests_abortion_validation(self):
        """Verify lower priority requests are aborted when incoming requests have higher priority"""

        responses = asyncio.run(
            send_concurrent_generate_requests_with_custom_params(
                self.base_url,
                [
                    {
                        "priority": 1,
                        "sampling_params": {"max_new_tokens": 10000},
                    },  # starts being processed first and holds the running queue capacity
                    {"priority": 2},  # aborted by request 5
                    {"priority": 3},  # aborted by request 6
                    {"priority": 4},  # aborted by request 7
                    {"priority": 5},  # fourth
                    {"priority": 6},  # third
                    {"priority": 7},  # second
                ],
            )
        )

        expected_status_and_error_messages = [
            (200, None),
            (503, "The request is aborted by a higher priority request."),
            (503, "The request is aborted by a higher priority request."),
            (503, "The request is aborted by a higher priority request."),
            (200, None),
            (200, None),
            (200, None),
        ]

        e2e_latencies = []
        _verify_genereate_responses(
            responses, expected_status_and_error_messages, e2e_latencies
        )
        # Survivors finish in descending-priority order after the running request.
        assert e2e_latencies[0] < e2e_latencies[6] < e2e_latencies[5] < e2e_latencies[4]

    def test_priority_scheduling_incoming_request_rejection_validation(self):
        """Verify incoming requests are rejected when existing requests have higher priority"""

        responses = asyncio.run(
            send_concurrent_generate_requests_with_custom_params(
                self.base_url,
                [
                    {
                        "priority": 7,
                        "sampling_params": {"max_new_tokens": 10000},
                    },  # starts being processed first and holds the running queue capacity
                    {"priority": 6},  # second
                    {"priority": 5},  # third
                    {"priority": 4},  # fourth
                    {"priority": 3},  # rejected
                    {"priority": 2},  # rejected
                    {"priority": 1},  # rejected
                ],
            )
        )

        expected_status_and_error_messages = [
            (200, None),
            (200, None),
            (200, None),
            (200, None),
            (503, "The request queue is full."),
            (503, "The request queue is full."),
            (503, "The request queue is full."),
        ]

        e2e_latencies = []
        _verify_genereate_responses(
            responses, expected_status_and_error_messages, e2e_latencies
        )
        # Accepted requests finish in the order they were queued by priority.
        assert e2e_latencies[0] < e2e_latencies[1] < e2e_latencies[2] < e2e_latencies[3]

    def test_priority_scheduling_preemption_meeting_threshold_validation(self):
        """Verify running requests are preempted by requests with priorities meeting the preemption threshold"""

        responses = asyncio.run(
            send_concurrent_generate_requests_with_custom_params(
                self.base_url,
                [
                    {
                        "priority": 0,
                        "sampling_params": {"max_new_tokens": 10000},
                    },  # starts being processed first then preempted or pushed by later requests, and finishes last.
                    {
                        "priority": 10,
                        "sampling_params": {"max_new_tokens": 10000},
                    },  # scheduled after the third request, and finishes second.
                    {
                        "priority": 20,
                        "sampling_params": {"max_new_tokens": 10000},
                    },  # finishes first.
                ],
            )
        )

        expected_status_and_error_messages = [
            (200, None),
            (200, None),
            (200, None),
        ]

        e2e_latencies = []
        _verify_genereate_responses(
            responses, expected_status_and_error_messages, e2e_latencies
        )

        # Highest priority finishes first despite arriving last.
        assert e2e_latencies[2] < e2e_latencies[1] < e2e_latencies[0]

    def test_priority_scheduling_preemption_below_threshold_validation(self):
        """Verify running requests are not preempted by requests with priorities below preemption threshold"""

        responses = asyncio.run(
            send_concurrent_generate_requests_with_custom_params(
                self.base_url,
                [
                    {
                        "priority": 0,
                        "sampling_params": {"max_new_tokens": 10000},
                    },
                    {
                        "priority": 5,
                        "sampling_params": {"max_new_tokens": 10000},
                    },
                ],
            )
        )

        expected_status_and_error_messages = [
            (200, None),
            (200, None),
        ]

        e2e_latencies = []
        _verify_genereate_responses(
            responses, expected_status_and_error_messages, e2e_latencies
        )

        # The running request keeps its slot; the later request waits its turn.
        assert e2e_latencies[0] < e2e_latencies[1]
|
||||
|
||||
|
||||
class TestPrioritySchedulingMultipleRunningRequests(CustomTestCase):
    """Priority scheduling tests with request concurrency > 1, so that only a
    subset of the running batch is preempted by a higher-priority arrival."""

    @classmethod
    def setUpClass(cls):
        cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
        # Fix: the original assigned cls.base_url twice; keep a single assignment.
        cls.base_url = DEFAULT_URL_FOR_TEST

        # Capture server output so tearDownClass can validate queue sizes from the logs.
        cls.stdout = open(STDOUT_FILENAME, "w")
        cls.stderr = open(STDERR_FILENAME, "w")

        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=(
                "--max-running-requests",  # Enforce max request concurrency is 2
                "2",
                "--max-queued-requests",  # Enforce max queued request number is 3
                "3",
                "--enable-priority-scheduling",  # Enable priority scheduling
            ),
            return_stdout_stderr=(cls.stdout, cls.stderr),
        )

    @classmethod
    def tearDownClass(cls):
        kill_process_tree(cls.process.pid)
        # Limits must match the launch flags above: 2 running, 3 queued.
        _verify_max_running_requests_and_max_queued_request_validation(2, 3)
        cls.stdout.close()
        cls.stderr.close()
        os.remove(STDOUT_FILENAME)
        os.remove(STDERR_FILENAME)

    def test_priority_scheduling_with_multiple_running_requests_preemption(self):
        """Verify preempting a subset of running requests is safe."""

        responses = asyncio.run(
            send_concurrent_generate_requests_with_custom_params(
                self.base_url,
                [
                    {
                        "priority": 10,
                        "sampling_params": {"max_new_tokens": 10000},
                    },  # finishes first
                    {
                        "priority": 5,
                        "sampling_params": {"max_new_tokens": 10000},
                    },  # preempted by the third request, then finishes third
                    {
                        "priority": 15,
                        "sampling_params": {"max_new_tokens": 10000},
                    },  # preempts the lowest-priority running request
                ],
            )
        )

        # Fix: the original listed FOUR expected entries for THREE requests;
        # zip() in the verifier silently ignored the extra entry. Keep exactly
        # one expectation per request.
        expected_status_and_error_messages = [
            (200, None),
            (200, None),
            (200, None),
        ]

        _verify_genereate_responses(responses, expected_status_and_error_messages, [])
|
||||
|
||||
|
||||
def _verify_genereate_responses(
|
||||
responses: Tuple[int, Any, float],
|
||||
expected_code_and_error_message: Tuple[int, Any],
|
||||
e2e_latencies: List[Optional[float]],
|
||||
):
|
||||
"""
|
||||
Verify generate response results are as expected based on status code and response json object content.
|
||||
In addition, collects e2e latency info to verify scheduling and processing ordering.
|
||||
"""
|
||||
for got, expected in zip(responses, expected_code_and_error_message):
|
||||
got_status, got_json = got
|
||||
expected_status, expected_err_msg = expected
|
||||
|
||||
# Check status code is as expected
|
||||
assert got_status == expected_status
|
||||
|
||||
# Check error message content or fields' existence based on status code
|
||||
if got_status != 200:
|
||||
assert got_json["object"] == "error"
|
||||
assert got_json["message"] == expected_err_msg
|
||||
else:
|
||||
assert "object" not in got_json
|
||||
assert "message" not in got_json
|
||||
|
||||
# Collect e2e latencies for scheduling validation
|
||||
e2e_latencies.append(
|
||||
got_json["meta_info"]["e2e_latency"] if got_status == 200 else None
|
||||
)
|
||||
|
||||
|
||||
def _verify_max_running_requests_and_max_queued_request_validation(
    max_running_requests: int, max_queued_requests: int
):
    """Scan the captured server stderr log and assert that the reported
    ``#running-req`` and ``#queue-req`` counters never exceed the configured
    limits."""
    running_pattern = re.compile(r"#running-req:\s*(\d+)")
    queued_pattern = re.compile(r"#queue-req:\s*(\d+)")

    with open(STDERR_FILENAME) as log:
        for log_line in log:
            running = running_pattern.search(log_line)
            if running:
                assert int(running.group(1)) <= max_running_requests
            queued = queued_pattern.search(log_line)
            if queued:
                assert int(queued.group(1)) <= max_queued_requests
|
||||
|
||||
|
||||
# Allow running this test module directly: `python test_priority_scheduling.py`.
if __name__ == "__main__":
    unittest.main()
|
||||
@@ -65,9 +65,8 @@ class TestMaxQueuedRequests(CustomTestCase):
|
||||
send_concurrent_generate_requests(self.base_url, num_requests=10)
|
||||
)
|
||||
|
||||
assert 200 in status_codes
|
||||
assert 503 in status_codes
|
||||
assert all(status_code in [200, 503] for status_code in status_codes)
|
||||
expected_status_codes = [200, 200, 503, 503, 503, 503, 503, 503, 503, 503]
|
||||
assert status_codes == expected_status_codes
|
||||
|
||||
def test_max_running_requests_and_max_queued_request_validation(self):
|
||||
"""Verify running request and queued request numbers based on server logs."""
|
||||
|
||||
@@ -18,13 +18,21 @@ class TestSchedulePolicy(CustomTestCase):
|
||||
|
||||
def test_init_with_cache_aware_policy(self):
    """The cache-aware 'lpm' policy string resolves to CacheAwarePolicy.LPM."""
    # Fix: the scraped diff left both the old single-line argument list and the
    # new expanded one in place, yielding duplicate keyword arguments (invalid
    # Python); keep only the expanded form.
    policy = SchedulePolicy(
        policy="lpm",
        tree_cache=self.tree_cache,
        enable_hierarchical_cache=True,
        enable_priority_scheduling=False,
        schedule_low_priority_values_first=False,
    )
    self.assertEqual(policy.policy, CacheAwarePolicy.LPM)
|
||||
|
||||
def test_init_with_cache_agnostic_policy(self):
    """The cache-agnostic 'fcfs' policy string resolves to CacheAgnosticPolicy.FCFS."""
    # Fix: drop the stale single-line argument list the diff scrape left behind
    # (it duplicated the keyword arguments below).
    policy = SchedulePolicy(
        policy="fcfs",
        tree_cache=self.tree_cache,
        enable_hierarchical_cache=True,
        enable_priority_scheduling=False,
        schedule_low_priority_values_first=False,
    )
    self.assertEqual(policy.policy, CacheAgnosticPolicy.FCFS)
|
||||
|
||||
@@ -34,12 +42,18 @@ class TestSchedulePolicy(CustomTestCase):
|
||||
policy="invalid",
|
||||
tree_cache=self.tree_cache,
|
||||
enable_hierarchical_cache=True,
|
||||
enable_priority_scheduling=False,
|
||||
schedule_low_priority_values_first=False,
|
||||
)
|
||||
|
||||
def test_init_with_disabled_cache(self):
    """With a disabled tree cache, a cache-aware policy falls back to FCFS."""
    disabled_tree_cache = RadixCache(None, None, disable=True, page_size=1)
    # Fix: drop the stale single-line argument list the diff scrape left behind
    # (it duplicated the keyword arguments below).
    policy = SchedulePolicy(
        policy="lpm",
        tree_cache=disabled_tree_cache,
        enable_hierarchical_cache=True,
        enable_priority_scheduling=False,
        schedule_low_priority_values_first=False,
    )
    self.assertEqual(policy.policy, CacheAgnosticPolicy.FCFS)
|
||||
|
||||
@@ -52,7 +66,11 @@ class TestSchedulePolicy(CustomTestCase):
|
||||
]
|
||||
|
||||
policy = SchedulePolicy(
|
||||
policy="fcfs", tree_cache=tree_cache, enable_hierarchical_cache=True
|
||||
policy="fcfs",
|
||||
tree_cache=tree_cache,
|
||||
enable_hierarchical_cache=True,
|
||||
enable_priority_scheduling=False,
|
||||
schedule_low_priority_values_first=False,
|
||||
)
|
||||
policy.calc_priority(waiting_queue)
|
||||
# Check if FCFS keeps the original order
|
||||
@@ -60,6 +78,126 @@ class TestSchedulePolicy(CustomTestCase):
|
||||
self.assertEqual(waiting_queue[1].rid, 3)
|
||||
self.assertEqual(waiting_queue[2].rid, 2)
|
||||
|
||||
def test_calc_priority_priority_enabled_fcfs_scheduling(self):
    """Priority-enabled FCFS: higher priority values are scheduled first and
    ties are broken by the earlier queue_time_start."""
    cache = RadixCache(None, None, False)

    queue = [
        Req(1, "a b", [1, 2], SamplingParams()),
        Req(3, "a b c", [1, 2, 3], SamplingParams()),
        Req(2, "a", [1], SamplingParams()),
    ]
    for req, (priority, start) in zip(queue, [(1, 1), (0, 1), (0, 0)]):
        req.priority, req.queue_time_start = priority, start

    policy = SchedulePolicy(
        policy="fcfs",
        tree_cache=cache,
        enable_hierarchical_cache=True,
        enable_priority_scheduling=True,
        schedule_low_priority_values_first=False,
    )
    policy.calc_priority(queue)

    # rid 1 (priority 1) first, then the priority-0 tie resolves by earlier
    # queue_time_start: rid 2 (t=0) before rid 3 (t=1).
    self.assertEqual([req.rid for req in queue], [1, 2, 3])
|
||||
|
||||
def test_calc_priority_priority_enabled_fcfs_scheduling_with_low_priority_values_first(
    self,
):
    """Priority-enabled FCFS with schedule_low_priority_values_first: lower
    priority values are scheduled first; ties are broken by earlier
    queue_time_start."""
    cache = RadixCache(None, None, False)

    queue = [
        Req(1, "a b", [1, 2], SamplingParams()),
        Req(3, "a b c", [1, 2, 3], SamplingParams()),
        Req(2, "a", [1], SamplingParams()),
    ]
    for req, (priority, start) in zip(queue, [(-1, 0), (0, 1), (0, 0)]):
        req.priority, req.queue_time_start = priority, start

    policy = SchedulePolicy(
        policy="fcfs",
        tree_cache=cache,
        enable_hierarchical_cache=True,
        enable_priority_scheduling=True,
        schedule_low_priority_values_first=True,
    )
    policy.calc_priority(queue)

    # rid 1 (priority -1) first, then the priority-0 tie resolves by earlier
    # queue_time_start: rid 2 (t=0) before rid 3 (t=1).
    self.assertEqual([req.rid for req in queue], [1, 2, 3])
|
||||
|
||||
def test_calc_priority_longest_output_first_scheduling(self):
    """With priority scheduling disabled, the 'lof' policy orders requests by
    descending max_new_tokens (longest output first)."""
    tree_cache = RadixCache(None, None, False)

    waiting_queue = [
        Req(1, "a b", [1, 2], SamplingParams(max_new_tokens=1000)),
        Req(3, "a b c", [1, 2, 3], SamplingParams(max_new_tokens=10)),
        Req(2, "a", [1], SamplingParams(max_new_tokens=100)),
    ]

    policy = SchedulePolicy(
        policy="lof",
        tree_cache=tree_cache,
        enable_hierarchical_cache=True,
        enable_priority_scheduling=False,
        schedule_low_priority_values_first=False,
    )
    policy.calc_priority(waiting_queue)
    # Check longest-output-first ordering: 1000 > 100 > 10 max_new_tokens.
    # (The original comment mentioned fcfs; it was copy-pasted from the fcfs tests.)
    self.assertEqual(waiting_queue[0].rid, 1)
    self.assertEqual(waiting_queue[1].rid, 2)
    self.assertEqual(waiting_queue[2].rid, 3)
|
||||
|
||||
def test_calc_priority_priority_enabled_longest_output_first_scheduling(self):
    """Priority-enabled 'lof': requests are ordered by priority (higher value
    first), then by descending max_new_tokens within equal priorities."""
    tree_cache = RadixCache(None, None, False)

    waiting_queue = [
        Req(1, "a b", [1, 2], SamplingParams(max_new_tokens=1), priority=1),
        Req(3, "a b c", [1, 2, 3], SamplingParams(max_new_tokens=10), priority=0),
        Req(2, "a", [1], SamplingParams(max_new_tokens=100), priority=0),
    ]

    policy = SchedulePolicy(
        policy="lof",
        tree_cache=tree_cache,
        enable_hierarchical_cache=True,
        enable_priority_scheduling=True,
        schedule_low_priority_values_first=False,
    )
    policy.calc_priority(waiting_queue)
    # rid 1 leads on priority (1 > 0); the priority-0 tie is broken by longer
    # output: rid 2 (100) before rid 3 (10).
    # (The original comment mentioned fcfs; it was copy-pasted from the fcfs tests.)
    self.assertEqual(waiting_queue[0].rid, 1)
    self.assertEqual(waiting_queue[1].rid, 2)
    self.assertEqual(waiting_queue[2].rid, 3)
|
||||
|
||||
def test_calc_priority_priority_enabled_longest_output_first_scheduling_with_low_priority_values_first(
    self,
):
    """Priority-enabled 'lof' with schedule_low_priority_values_first: lower
    priority values are scheduled first, then descending max_new_tokens within
    equal priorities."""
    tree_cache = RadixCache(None, None, False)

    waiting_queue = [
        Req(1, "a b", [1, 2], SamplingParams(max_new_tokens=1), priority=0),
        Req(3, "a b c", [1, 2, 3], SamplingParams(max_new_tokens=10), priority=1),
        Req(2, "a", [1], SamplingParams(max_new_tokens=100), priority=1),
    ]

    policy = SchedulePolicy(
        policy="lof",
        tree_cache=tree_cache,
        enable_hierarchical_cache=True,
        enable_priority_scheduling=True,
        schedule_low_priority_values_first=True,
    )
    policy.calc_priority(waiting_queue)
    # rid 1 leads on low priority (0 < 1); the priority-1 tie is broken by
    # longer output: rid 2 (100) before rid 3 (10).
    # (The original comment mentioned fcfs; it was copy-pasted from the fcfs tests.)
    self.assertEqual(waiting_queue[0].rid, 1)
    self.assertEqual(waiting_queue[1].rid, 2)
    self.assertEqual(waiting_queue[2].rid, 3)
|
||||
|
||||
|
||||
# Allow running this test module directly via the Python interpreter.
if __name__ == "__main__":
    unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user