feat: add priority-based scheduling with priority-based request acceptance and preemption (#8746)

This commit is contained in:
harrisonlimh
2025-09-16 17:10:10 -07:00
committed by GitHub
parent f949ad5794
commit 14fdd52740
16 changed files with 822 additions and 71 deletions

View File

@@ -243,6 +243,13 @@ class Scheduler(
self.pp_size = server_args.pp_size
self.dp_size = server_args.dp_size
self.schedule_policy = server_args.schedule_policy
self.enable_priority_scheduling = server_args.enable_priority_scheduling
self.schedule_low_priority_values_first = (
server_args.schedule_low_priority_values_first
)
self.priority_scheduling_preemption_threshold = (
server_args.priority_scheduling_preemption_threshold
)
self.enable_lora = server_args.enable_lora
self.max_loras_per_batch = server_args.max_loras_per_batch
self.enable_overlap = not server_args.disable_overlap_schedule
@@ -487,7 +494,12 @@ class Scheduler(
self.schedule_policy,
self.tree_cache,
self.enable_hierarchical_cache,
self.enable_priority_scheduling,
self.schedule_low_priority_values_first,
)
# Enable preemption for priority scheduling.
self.try_preemption = self.enable_priority_scheduling
assert (
server_args.schedule_conservativeness >= 0
), "Invalid schedule_conservativeness"
@@ -1150,20 +1162,6 @@ class Scheduler(
self.return_health_check_ct += 1
continue
# If it is a work request, accept or reject the request based on the request queue size.
if is_work_request(recv_req):
if len(self.waiting_queue) + 1 > self.max_queued_requests:
abort_req = AbortReq(
recv_req.rid,
finished_reason={
"type": "abort",
"status_code": HTTPStatus.SERVICE_UNAVAILABLE,
"message": "The request queue is full.",
},
)
self.send_to_tokenizer.send_pyobj(abort_req)
continue
# If it is a MultiTokenizerWrapper, unwrap it and handle the inner request.
if isinstance(recv_req, MultiTokenizerWrapper):
worker_id = recv_req.worker_id
@@ -1233,6 +1231,7 @@ class Scheduler(
bootstrap_room=recv_req.bootstrap_room,
data_parallel_rank=recv_req.data_parallel_rank,
vocab_size=self.model_config.vocab_size,
priority=recv_req.priority,
metrics_collector=(
self.metrics_collector if self.enable_metrics else None
),
@@ -1382,6 +1381,9 @@ class Scheduler(
elif self.disaggregation_mode == DisaggregationMode.DECODE:
self.disagg_decode_prealloc_queue.add(req)
else:
self._set_or_validate_priority(req)
if self._abort_on_queued_limit(req):
return
self._prefetch_kvcache(req)
self.waiting_queue.append(req)
trace_slice_end("process req", req.rid, auto_next_anon=True)
@@ -1408,7 +1410,70 @@ class Scheduler(
# If this is a decode server, we put the request to the decode pending prealloc queue
self.disagg_decode_prealloc_queue.extend(reqs, is_retracted)
else:
self.waiting_queue.extend(reqs)
for req in reqs:
self._set_or_validate_priority(req)
if not self._abort_on_queued_limit(req):
self.waiting_queue.append(req)
def _set_or_validate_priority(self, req: "Req") -> bool:
    """Assign a default priority, or abort a request whose priority cannot be honored.

    When priority scheduling is enabled and the request carries no priority,
    the least-preferred value under the active ordering is assigned
    (``sys.maxsize`` when low values are scheduled first, otherwise the most
    negative int) so the request sorts behind every explicitly prioritized
    one. When priority scheduling is disabled but a priority was supplied,
    the request is aborted via the tokenizer channel.

    Returns:
        True if the request was aborted (callers should not enqueue it),
        False otherwise.
    """
    if self.enable_priority_scheduling and req.priority is None:
        # Default to the least-preferred priority under the active ordering.
        if self.schedule_low_priority_values_first:
            req.priority = sys.maxsize
        else:
            req.priority = -sys.maxsize - 1
    elif not self.enable_priority_scheduling and req.priority is not None:
        # Priorities are not accepted on this server; reject explicitly so
        # the client gets a clear error instead of silent mis-scheduling.
        abort_req = AbortReq(
            req.rid,
            finished_reason={
                "type": "abort",
                "status_code": HTTPStatus.SERVICE_UNAVAILABLE,
                "message": "Using priority is disabled for this server. Please send a new request without a priority.",
            },
        )
        self.send_to_tokenizer.send_pyobj(abort_req)
        # NOTE(review): the original returned None here, so call sites could
        # not detect the abort and might still enqueue the request. Returning
        # True lets callers skip it; existing callers that ignore the return
        # value are unaffected.
        return True
    return False
def _abort_on_queued_limit(self, recv_req: "Req") -> bool:
    """Abort a request when the waiting queue is at capacity.

    Without priority scheduling the incoming request is always the one
    rejected. With priority scheduling, the least-preferred queued request
    is evicted instead, but only when the incoming request is strictly
    preferred.

    Returns:
        True if the *incoming* request was aborted, False otherwise.
    """
    if (
        self.max_queued_requests is None
        or len(self.waiting_queue) + 1 <= self.max_queued_requests
    ):
        return False

    # Reject the incoming request by default.
    req_to_abort = recv_req
    message = "The request queue is full."
    # Guard on a non-empty queue: with max_queued_requests == 0 the limit
    # trips even when nothing is queued, and max() over an empty sequence
    # would raise ValueError.
    if self.enable_priority_scheduling and self.waiting_queue:
        # With priority scheduling, consider aborting an existing request
        # based on priority instead.
        # direction = 1 => smaller number = higher priority;
        # direction = -1 => larger number = higher priority.
        # max over (direction * priority, queue_time_start) picks the
        # least-preferred request; on a priority tie the later (newer)
        # queue_time_start is evicted first. An existing request is evicted
        # only if the incoming one is strictly preferred.
        direction = 1 if self.schedule_low_priority_values_first else -1

        def _eviction_key(item):
            return (direction * item[1].priority, item[1].queue_time_start)

        idx, candidate_req = max(enumerate(self.waiting_queue), key=_eviction_key)
        if direction * recv_req.priority < direction * candidate_req.priority:
            self.waiting_queue.pop(idx)
            req_to_abort = candidate_req
            message = "The request is aborted by a higher priority request."

    self.send_to_tokenizer.send_pyobj(
        AbortReq(
            req_to_abort.rid,
            finished_reason={
                "type": "abort",
                "status_code": HTTPStatus.SERVICE_UNAVAILABLE,
                "message": message,
            },
        )
    )
    return req_to_abort.rid == recv_req.rid
def handle_embedding_request(
self,
@@ -1420,6 +1485,7 @@ class Scheduler(
recv_req.input_ids,
recv_req.sampling_params,
token_type_ids=recv_req.token_type_ids,
priority=recv_req.priority,
)
req.tokenizer = self.tokenizer
@@ -1680,6 +1746,10 @@ class Scheduler(
if self.grammar_queue:
self.move_ready_grammar_requests()
if self.try_preemption:
# Reset batch_is_full to try preemption with a prefill adder.
self.running_batch.batch_is_full = False
# Handle the cases where prefill is not allowed
if (
self.running_batch.batch_is_full or len(self.waiting_queue) == 0
@@ -1692,7 +1762,11 @@ class Scheduler(
# as the space for the chunked request has just been released.
# In PP case, a chunked req can start in one microbatch and end in another microbatch, so the max_running_requests per microbatch should not be strict.
# Instead, we should always allow chunked request to be added, otherwise, there will be a memory leak.
if self.get_num_allocatable_reqs(running_bs) <= 0 and not self.chunked_req:
if (
self.get_num_allocatable_reqs(running_bs) <= 0
and not self.chunked_req
and not self.try_preemption
):
self.running_batch.batch_is_full = True
return None
@@ -1712,6 +1786,7 @@ class Scheduler(
self.max_prefill_tokens,
self.chunked_prefill_size,
running_bs if self.is_mixed_chunk else 0,
self.priority_scheduling_preemption_threshold,
)
if self.chunked_req is not None:
@@ -1732,15 +1807,19 @@ class Scheduler(
self.running_batch.batch_is_full = True
break
running_bs = len(self.running_batch.reqs) - len(adder.preempt_list)
if len(adder.can_run_list) >= self.get_num_allocatable_reqs(running_bs):
self.running_batch.batch_is_full = True
break
if self.disaggregation_mode == DisaggregationMode.PREFILL:
# In prefill mode, prealloc queue and transfer queue can also take memory,
# so we need to check if the available size for the actual available size.
if len(adder.can_run_list) >= self.req_to_token_pool.available_size():
self.running_batch.batch_is_full = True
if self.running_batch.batch_is_full:
if not self.try_preemption:
break
if not adder.preempt_to_schedule(req, self.server_args):
break
if self.enable_hicache_storage:
@@ -1777,6 +1856,8 @@ class Scheduler(
self.waiting_queue = [
x for x in self.waiting_queue if x not in set(can_run_list)
]
if adder.preempt_list:
self._extend_requests_to_queue(adder.preempt_list)
if adder.new_chunked_req is not None:
assert self.chunked_req is None