feat: add priority-based scheduling with priority-based request acceptance and preemption (#8746)
This commit is contained in:
@@ -243,6 +243,13 @@ class Scheduler(
|
||||
self.pp_size = server_args.pp_size
|
||||
self.dp_size = server_args.dp_size
|
||||
self.schedule_policy = server_args.schedule_policy
|
||||
self.enable_priority_scheduling = server_args.enable_priority_scheduling
|
||||
self.schedule_low_priority_values_first = (
|
||||
server_args.schedule_low_priority_values_first
|
||||
)
|
||||
self.priority_scheduling_preemption_threshold = (
|
||||
server_args.priority_scheduling_preemption_threshold
|
||||
)
|
||||
self.enable_lora = server_args.enable_lora
|
||||
self.max_loras_per_batch = server_args.max_loras_per_batch
|
||||
self.enable_overlap = not server_args.disable_overlap_schedule
|
||||
@@ -487,7 +494,12 @@ class Scheduler(
|
||||
self.schedule_policy,
|
||||
self.tree_cache,
|
||||
self.enable_hierarchical_cache,
|
||||
self.enable_priority_scheduling,
|
||||
self.schedule_low_priority_values_first,
|
||||
)
|
||||
# Enable preemption for priority scheduling.
|
||||
self.try_preemption = self.enable_priority_scheduling
|
||||
|
||||
assert (
|
||||
server_args.schedule_conservativeness >= 0
|
||||
), "Invalid schedule_conservativeness"
|
||||
@@ -1150,20 +1162,6 @@ class Scheduler(
|
||||
self.return_health_check_ct += 1
|
||||
continue
|
||||
|
||||
# If it is a work request, accept or reject the request based on the request queue size.
|
||||
if is_work_request(recv_req):
|
||||
if len(self.waiting_queue) + 1 > self.max_queued_requests:
|
||||
abort_req = AbortReq(
|
||||
recv_req.rid,
|
||||
finished_reason={
|
||||
"type": "abort",
|
||||
"status_code": HTTPStatus.SERVICE_UNAVAILABLE,
|
||||
"message": "The request queue is full.",
|
||||
},
|
||||
)
|
||||
self.send_to_tokenizer.send_pyobj(abort_req)
|
||||
continue
|
||||
|
||||
# If it is a MultiTokenizerWrapper, unwrap it and handle the inner request.
|
||||
if isinstance(recv_req, MultiTokenizerWrapper):
|
||||
worker_id = recv_req.worker_id
|
||||
@@ -1233,6 +1231,7 @@ class Scheduler(
|
||||
bootstrap_room=recv_req.bootstrap_room,
|
||||
data_parallel_rank=recv_req.data_parallel_rank,
|
||||
vocab_size=self.model_config.vocab_size,
|
||||
priority=recv_req.priority,
|
||||
metrics_collector=(
|
||||
self.metrics_collector if self.enable_metrics else None
|
||||
),
|
||||
@@ -1382,6 +1381,9 @@ class Scheduler(
|
||||
elif self.disaggregation_mode == DisaggregationMode.DECODE:
|
||||
self.disagg_decode_prealloc_queue.add(req)
|
||||
else:
|
||||
self._set_or_validate_priority(req)
|
||||
if self._abort_on_queued_limit(req):
|
||||
return
|
||||
self._prefetch_kvcache(req)
|
||||
self.waiting_queue.append(req)
|
||||
trace_slice_end("process req", req.rid, auto_next_anon=True)
|
||||
@@ -1408,7 +1410,70 @@ class Scheduler(
|
||||
# If this is a decode server, we put the request to the decode pending prealloc queue
|
||||
self.disagg_decode_prealloc_queue.extend(reqs, is_retracted)
|
||||
else:
|
||||
self.waiting_queue.extend(reqs)
|
||||
for req in reqs:
|
||||
self._set_or_validate_priority(req)
|
||||
if not self._abort_on_queued_limit(req):
|
||||
self.waiting_queue.append(req)
|
||||
|
||||
def _set_or_validate_priority(self, req: Req):
    """Assign a default priority to *req*, or abort it when priorities are unsupported.

    With priority scheduling enabled, a request that carries no priority is
    given the least-preferred sentinel value for the configured ordering, so
    explicitly prioritized requests always rank ahead of it. With priority
    scheduling disabled, a request that nevertheless carries a priority is
    rejected via an AbortReq sent back to the tokenizer.
    """
    has_priority = req.priority is not None
    if self.enable_priority_scheduling and not has_priority:
        # Least-preferred sentinel: largest int when low values win,
        # smallest int when high values win.
        req.priority = (
            sys.maxsize
            if self.schedule_low_priority_values_first
            else -sys.maxsize - 1
        )
    elif has_priority and not self.enable_priority_scheduling:
        self.send_to_tokenizer.send_pyobj(
            AbortReq(
                req.rid,
                finished_reason={
                    "type": "abort",
                    "status_code": HTTPStatus.SERVICE_UNAVAILABLE,
                    "message": "Using priority is disabled for this server. Please send a new request without a priority.",
                },
            )
        )
|
||||
|
||||
def _abort_on_queued_limit(self, recv_req: Req) -> bool:
    """Abort an incoming or existing request if the waiting queue is full.

    Returns:
        True if the *incoming* request was aborted (the caller must not
        enqueue it). False if there is room in the queue, or if an existing
        lower-priority request was evicted to make room for the incoming one.
    """
    if (
        self.max_queued_requests is None
        or len(self.waiting_queue) + 1 <= self.max_queued_requests
    ):
        return False

    # Reject the incoming request by default.
    req_to_abort = recv_req
    message = "The request queue is full."
    if self.enable_priority_scheduling:
        # With priority scheduling, consider aborting an existing request
        # based on priority instead.
        # direction = 1 => smaller number = higher priority;
        # direction = -1 => larger number = higher priority.
        # max(...) over (direction * priority, queue_time_start) picks the
        # least-preferred queued request; on a priority tie, the one with the
        # later queue_time_start (newer) is evicted first. An existing
        # request is evicted only if the incoming one is strictly preferred.
        direction = 1 if self.schedule_low_priority_values_first else -1

        # PEP 8: a named key function (def), not a lambda bound to a name.
        def eviction_key(item):
            _, queued_req = item
            return (direction * queued_req.priority, queued_req.queue_time_start)

        idx, candidate_req = max(enumerate(self.waiting_queue), key=eviction_key)
        if direction * recv_req.priority < direction * candidate_req.priority:
            self.waiting_queue.pop(idx)
            req_to_abort = candidate_req
            message = "The request is aborted by a higher priority request."

    self.send_to_tokenizer.send_pyobj(
        AbortReq(
            req_to_abort.rid,
            finished_reason={
                "type": "abort",
                "status_code": HTTPStatus.SERVICE_UNAVAILABLE,
                "message": message,
            },
        )
    )
    return req_to_abort.rid == recv_req.rid
|
||||
|
||||
def handle_embedding_request(
|
||||
self,
|
||||
@@ -1420,6 +1485,7 @@ class Scheduler(
|
||||
recv_req.input_ids,
|
||||
recv_req.sampling_params,
|
||||
token_type_ids=recv_req.token_type_ids,
|
||||
priority=recv_req.priority,
|
||||
)
|
||||
req.tokenizer = self.tokenizer
|
||||
|
||||
@@ -1680,6 +1746,10 @@ class Scheduler(
|
||||
if self.grammar_queue:
|
||||
self.move_ready_grammar_requests()
|
||||
|
||||
if self.try_preemption:
|
||||
# Reset batch_is_full to try preemption with a prefill adder.
|
||||
self.running_batch.batch_is_full = False
|
||||
|
||||
# Handle the cases where prefill is not allowed
|
||||
if (
|
||||
self.running_batch.batch_is_full or len(self.waiting_queue) == 0
|
||||
@@ -1692,7 +1762,11 @@ class Scheduler(
|
||||
# as the space for the chunked request has just been released.
|
||||
# In PP case, a chunked req can start in one microbatch and end in another microbatch, so the max_running_requests per microbatch should not be strict.
|
||||
# Instead, we should always allow chunked request to be added, otherwise, there will be a memory leak.
|
||||
if self.get_num_allocatable_reqs(running_bs) <= 0 and not self.chunked_req:
|
||||
if (
|
||||
self.get_num_allocatable_reqs(running_bs) <= 0
|
||||
and not self.chunked_req
|
||||
and not self.try_preemption
|
||||
):
|
||||
self.running_batch.batch_is_full = True
|
||||
return None
|
||||
|
||||
@@ -1712,6 +1786,7 @@ class Scheduler(
|
||||
self.max_prefill_tokens,
|
||||
self.chunked_prefill_size,
|
||||
running_bs if self.is_mixed_chunk else 0,
|
||||
self.priority_scheduling_preemption_threshold,
|
||||
)
|
||||
|
||||
if self.chunked_req is not None:
|
||||
@@ -1732,15 +1807,19 @@ class Scheduler(
|
||||
self.running_batch.batch_is_full = True
|
||||
break
|
||||
|
||||
running_bs = len(self.running_batch.reqs) - len(adder.preempt_list)
|
||||
if len(adder.can_run_list) >= self.get_num_allocatable_reqs(running_bs):
|
||||
self.running_batch.batch_is_full = True
|
||||
break
|
||||
|
||||
if self.disaggregation_mode == DisaggregationMode.PREFILL:
|
||||
# In prefill mode, prealloc queue and transfer queue can also take memory,
|
||||
# so we need to check if the available size for the actual available size.
|
||||
if len(adder.can_run_list) >= self.req_to_token_pool.available_size():
|
||||
self.running_batch.batch_is_full = True
|
||||
|
||||
if self.running_batch.batch_is_full:
|
||||
if not self.try_preemption:
|
||||
break
|
||||
if not adder.preempt_to_schedule(req, self.server_args):
|
||||
break
|
||||
|
||||
if self.enable_hicache_storage:
|
||||
@@ -1777,6 +1856,8 @@ class Scheduler(
|
||||
self.waiting_queue = [
|
||||
x for x in self.waiting_queue if x not in set(can_run_list)
|
||||
]
|
||||
if adder.preempt_list:
|
||||
self._extend_requests_to_queue(adder.preempt_list)
|
||||
|
||||
if adder.new_chunked_req is not None:
|
||||
assert self.chunked_req is None
|
||||
|
||||
Reference in New Issue
Block a user