From 14fdd52740fb19056ebecaf20863c0885f5ad26b Mon Sep 17 00:00:00 2001 From: harrisonlimh <97203667+harrisonlimh@users.noreply.github.com> Date: Tue, 16 Sep 2025 17:10:10 -0700 Subject: [PATCH] feat: add priority based scheduling with priority based request acceptance and preemption (#8746) --- .../sglang/srt/entrypoints/openai/protocol.py | 6 + .../srt/entrypoints/openai/serving_chat.py | 1 + .../entrypoints/openai/serving_completions.py | 1 + .../entrypoints/openai/serving_embedding.py | 1 + python/sglang/srt/managers/io_struct.py | 5 + python/sglang/srt/managers/schedule_batch.py | 72 ++-- python/sglang/srt/managers/schedule_policy.py | 126 ++++++- python/sglang/srt/managers/scheduler.py | 117 +++++- .../sglang/srt/managers/tokenizer_manager.py | 2 + python/sglang/srt/managers/tp_worker.py | 4 +- python/sglang/srt/server_args.py | 30 +- python/sglang/test/test_utils.py | 37 +- test/srt/run_suite.py | 1 + test/srt/test_priority_scheduling.py | 339 ++++++++++++++++++ test/srt/test_request_queue_validation.py | 5 +- test/srt/test_schedule_policy.py | 146 +++++++- 16 files changed, 822 insertions(+), 71 deletions(-) create mode 100644 test/srt/test_priority_scheduling.py diff --git a/python/sglang/srt/entrypoints/openai/protocol.py b/python/sglang/srt/entrypoints/openai/protocol.py index 8111f1939..23830d86c 100644 --- a/python/sglang/srt/entrypoints/openai/protocol.py +++ b/python/sglang/srt/entrypoints/openai/protocol.py @@ -228,6 +228,8 @@ class CompletionRequest(BaseModel): # For request id rid: Optional[Union[List[str], str]] = None + # Priority for the request + priority: Optional[int] = None # For customer metric labels customer_labels: Optional[Dict[str, str]] = None @@ -543,6 +545,8 @@ class ChatCompletionRequest(BaseModel): # For request id rid: Optional[Union[List[str], str]] = None + # Priority for the request + priority: Optional[int] = None # For PD disaggregation bootstrap_host: Optional[Union[List[str], str]] = None @@ -644,6 +648,8 @@ class EmbeddingRequest(BaseModel): # The request id. 
rid: Optional[Union[List[str], str]] = None + # Priority for the request + priority: Optional[int] = None class EmbeddingObject(BaseModel): diff --git a/python/sglang/srt/entrypoints/openai/serving_chat.py b/python/sglang/srt/entrypoints/openai/serving_chat.py index d132c7bed..8bd57fc9e 100644 --- a/python/sglang/srt/entrypoints/openai/serving_chat.py +++ b/python/sglang/srt/entrypoints/openai/serving_chat.py @@ -149,6 +149,7 @@ class OpenAIServingChat(OpenAIServingBase): bootstrap_room=request.bootstrap_room, return_hidden_states=request.return_hidden_states, rid=request.rid, + priority=request.priority, customer_labels=customer_labels, ) diff --git a/python/sglang/srt/entrypoints/openai/serving_completions.py b/python/sglang/srt/entrypoints/openai/serving_completions.py index 68b4f97b4..6aa4fe19e 100644 --- a/python/sglang/srt/entrypoints/openai/serving_completions.py +++ b/python/sglang/srt/entrypoints/openai/serving_completions.py @@ -107,6 +107,7 @@ class OpenAIServingCompletion(OpenAIServingBase): bootstrap_room=request.bootstrap_room, return_hidden_states=request.return_hidden_states, rid=request.rid, + priority=request.priority, customer_labels=customer_labels, ) diff --git a/python/sglang/srt/entrypoints/openai/serving_embedding.py b/python/sglang/srt/entrypoints/openai/serving_embedding.py index 6500915c1..7340a72f2 100644 --- a/python/sglang/srt/entrypoints/openai/serving_embedding.py +++ b/python/sglang/srt/entrypoints/openai/serving_embedding.py @@ -125,6 +125,7 @@ class OpenAIServingEmbedding(OpenAIServingBase): adapted_request = EmbeddingReqInput( **prompt_kwargs, rid=request.rid, + priority=request.priority, ) return adapted_request, request diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py index 16b87e164..c479f6d54 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py @@ -570,6 +570,7 @@ class TokenizedGenerateReqInput: token_ids_logprob: List[int] # Whether to stream output stream: bool + # Whether to return hidden states return_hidden_states: bool = False @@ -656,6 +657,8 @@ class EmbeddingReqInput: modalities: Optional[List[str]] = None # For cross-encoder requests is_cross_encoder_request: bool = False + # Priority for the request + priority: Optional[int] = None # For background responses (OpenAI responses API) background: bool = False @@ -763,6 +766,8 @@ class TokenizedEmbeddingReqInput: data_parallel_rank: Optional[int] = None # For dp balance dp_balance_id: int = -1 + # Priority for the request + priority: Optional[int] = None @dataclass diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 9402e723f..c60864d64 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -453,6 +453,7 @@ class Req: bootstrap_room: Optional[int] = None, data_parallel_rank: Optional[int] = None, vocab_size: Optional[int] = None, + priority: Optional[int] = None, metrics_collector: Optional[SchedulerMetricsCollector] = None, ): # Input and output info @@ -504,6 +505,7 @@ class Req: self.stream = stream self.eos_token_ids = eos_token_ids self.vocab_size = vocab_size + self.priority = priority # For incremental decoding # ----- | --------- read_ids -------| @@ -1517,37 +1519,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): idx = sorted_indices.pop() req = self.reqs[idx] retracted_reqs.append(req) - - if server_args.disaggregation_mode == "decode": - req.offload_kv_cache( - 
self.req_to_token_pool, self.token_to_kv_pool_allocator - ) - - if isinstance(self.tree_cache, ChunkCache): - # ChunkCache does not have eviction - token_indices = self.req_to_token_pool.req_to_token[ - req.req_pool_idx, : seq_lens_cpu[idx] - ] - self.token_to_kv_pool_allocator.free(token_indices) - self.req_to_token_pool.free(req.req_pool_idx) - else: - # TODO: apply more fine-grained retraction - last_uncached_pos = ( - len(req.prefix_indices) // server_args.page_size - ) * server_args.page_size - token_indices = self.req_to_token_pool.req_to_token[ - req.req_pool_idx, last_uncached_pos : seq_lens_cpu[idx] - ] - self.token_to_kv_pool_allocator.free(token_indices) - self.req_to_token_pool.free(req.req_pool_idx) - - # release the last node - if self.is_hybrid: - self.tree_cache.dec_lock_ref(req.last_node, req.swa_uuid_for_lock) - else: - self.tree_cache.dec_lock_ref(req.last_node) - - req.reset_for_retract() + self.release_req(idx, len(sorted_indices), server_args) if len(retracted_reqs) == 0: # Corner case: only one request left @@ -1568,6 +1540,44 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): return retracted_reqs, new_estimate_ratio + def release_req(self, idx: int, remaing_req_count: int, server_args: ServerArgs): + req = self.reqs[idx] + seq_lens_cpu = self.seq_lens.cpu().numpy() + + if server_args.disaggregation_mode == "decode": + req.offload_kv_cache( + self.req_to_token_pool, self.token_to_kv_pool_allocator + ) + if isinstance(self.tree_cache, ChunkCache): + # ChunkCache does not have eviction + token_indices = self.req_to_token_pool.req_to_token[ + req.req_pool_idx, : seq_lens_cpu[idx] + ] + self.token_to_kv_pool_allocator.free(token_indices) + self.req_to_token_pool.free(req.req_pool_idx) + else: + # TODO: apply more fine-grained retraction + last_uncached_pos = ( + len(req.prefix_indices) // server_args.page_size + ) * server_args.page_size + token_indices = self.req_to_token_pool.req_to_token[ + req.req_pool_idx, last_uncached_pos : seq_lens_cpu[idx] + ] + self.token_to_kv_pool_allocator.free(token_indices) + self.req_to_token_pool.free(req.req_pool_idx) + + # release the last node + if self.is_hybrid: + self.tree_cache.dec_lock_ref(req.last_node, req.swa_uuid_for_lock) + else: + self.tree_cache.dec_lock_ref(req.last_node) + + # NOTE(lsyin): we should use the newly evictable memory instantly. 
+ num_tokens = remaing_req_count * global_config.retract_decode_steps + self._evict_tree_cache_if_needed(num_tokens) + + req.reset_for_retract() + def prepare_encoder_info_decode(self): # Reset the encoder cached status self.encoder_cached = [True] * len(self.reqs) diff --git a/python/sglang/srt/managers/schedule_policy.py b/python/sglang/srt/managers/schedule_policy.py index 0a3723e0b..3e8877faf 100644 --- a/python/sglang/srt/managers/schedule_policy.py +++ b/python/sglang/srt/managers/schedule_policy.py @@ -28,6 +28,7 @@ from sglang.srt.managers.schedule_batch import Req, ScheduleBatch from sglang.srt.mem_cache.allocator import SWATokenToKVPoolAllocator from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode +from sglang.srt.server_args import ServerArgs if TYPE_CHECKING: from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator @@ -82,10 +83,14 @@ class SchedulePolicy: policy: str, tree_cache: BasePrefixCache, enable_hierarchical_cache: bool, + enable_priority_scheduling: bool, + schedule_low_priority_values_first: bool, ): self.policy = self._validate_and_adjust_policy(policy, tree_cache) self.tree_cache = tree_cache self.enable_hierarchical_cache = enable_hierarchical_cache + self.enable_priority_scheduling = enable_priority_scheduling + self.schedule_low_priority_values_first = schedule_low_priority_values_first # It is used to find the matching prefix for in-batch prefix caching. self.waiting_queue_radix_tree = RadixCache( @@ -97,7 +102,10 @@ class SchedulePolicy: def calc_priority(self, waiting_queue: List[Req]) -> bool: if self.policy == CacheAgnosticPolicy.FCFS: - # A shortcut for FCFS + if self.enable_priority_scheduling: + SchedulePolicy._sort_by_priority_and_fcfs( + waiting_queue, self.schedule_low_priority_values_first + ) return False policy = self._determine_active_policy(waiting_queue) @@ -120,12 +128,15 @@ class SchedulePolicy: if policy == CacheAgnosticPolicy.FCFS: pass elif policy == CacheAgnosticPolicy.LOF: - SchedulePolicy._sort_by_longest_output(waiting_queue) + SchedulePolicy._sort_by_longest_output( + waiting_queue, + self.enable_priority_scheduling, + self.schedule_low_priority_values_first, + ) elif policy == CacheAgnosticPolicy.RANDOM: SchedulePolicy._sort_randomly(waiting_queue) else: raise ValueError(f"Unknown CacheAgnostic Policy: {policy=}") - return prefix_computed def _determine_active_policy(self, waiting_queue: List[Req]) -> Policy: @@ -231,15 +242,39 @@ class SchedulePolicy: ) @staticmethod - def _sort_by_longest_output(waiting_queue: List[Req]) -> None: - """Sorts the waiting queue based on the longest output (max_new_tokens).""" - waiting_queue.sort(key=lambda x: -x.sampling_params.max_new_tokens) + def _sort_by_longest_output( + waiting_queue: List[Req], + enable_priority_scheduling: bool, + schedule_low_priority_values_first: bool, + ) -> None: + """Sorts the waiting queue based on the longest output (max_new_tokens). 
If using priority scheduling, sort by priority first.""" + if enable_priority_scheduling: + if schedule_low_priority_values_first: + waiting_queue.sort( + key=lambda x: (x.priority, -x.sampling_params.max_new_tokens) + ) + else: + waiting_queue.sort( + key=lambda x: (-x.priority, -x.sampling_params.max_new_tokens) + ) + else: + waiting_queue.sort(key=lambda x: -x.sampling_params.max_new_tokens) @staticmethod def _sort_randomly(waiting_queue: List[Req]) -> None: """Shuffles the waiting queue randomly.""" random.shuffle(waiting_queue) + @staticmethod + def _sort_by_priority_and_fcfs( + waiting_queue: List[Req], schedule_low_priority_values_first: bool + ) -> None: + """Sorts the waiting queue based on the request priority then received titmestamp.""" + if schedule_low_priority_values_first: + waiting_queue.sort(key=lambda x: (x.priority, x.queue_time_start)) + else: + waiting_queue.sort(key=lambda x: (-x.priority, x.queue_time_start)) + @staticmethod def _calc_weight(cur_node: TreeNode, node_to_weight: Dict[TreeNode, int]) -> None: for child in cur_node.children.values(): @@ -279,6 +314,7 @@ class PrefillAdder: rem_input_tokens: int, rem_chunk_tokens: Optional[int], mixed_with_decode_tokens: int = 0, + priority_scheduling_preemption_threshold: int = 0, ): self.page_size = page_size self.tree_cache = tree_cache @@ -295,6 +331,7 @@ class PrefillAdder: self.req_states = None self.can_run_list = [] + self.preempt_list = [] self.new_chunked_req = None self.log_hit_tokens = 0 # TODO(lsyin): report the real input tokens excluding page alignment @@ -303,11 +340,7 @@ class PrefillAdder: if running_batch is not None: self.rem_total_token_offset += sum( [ - min( - (r.sampling_params.max_new_tokens - len(r.output_ids)), - CLIP_MAX_NEW_TOKENS, - ) - * self.new_token_ratio + self._get_running_request_total_token_offset(r) for r in running_batch.reqs ] ) @@ -316,6 +349,19 @@ class PrefillAdder: self.token_to_kv_pool_allocator, SWATokenToKVPoolAllocator ) + self.priority_scheduling_preemption_threshold = ( + priority_scheduling_preemption_threshold + ) + + def _get_running_request_total_token_offset(self, req: Req) -> int: + return ( + min( + (req.sampling_params.max_new_tokens - len(req.output_ids)), + CLIP_MAX_NEW_TOKENS, + ) + * self.new_token_ratio + ) + @property def rem_total_tokens(self): if self.is_hybrid: @@ -568,3 +614,61 @@ class PrefillAdder: self._update_prefill_budget(prefix_len, trunc_len, 0) return self.budget_state() + + def preempt_to_schedule(self, req: Req, server_args: ServerArgs) -> bool: + """ + Preempt running requests to serve the new request if the priority threshold is met and token count sum is verified. + Returns True if preemption was committed, and the new request can be scheduled. + """ + # Iterate running requests to find preemptible requests + if server_args.schedule_low_priority_values_first: + sorted_running_reqs = sorted( + self.running_batch.reqs, + key=lambda x: (-x.priority, -x.queue_time_start), + ) + else: + sorted_running_reqs = sorted( + self.running_batch.reqs, + key=lambda x: (x.priority, -x.queue_time_start), + ) + preemptible_reqs = [] + min_tokens_to_remove = ( + req.extend_input_len + + min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS) + - self.rem_total_tokens + ) + for running_req in sorted_running_reqs: + if running_req in self.preempt_list: + continue + # Priority difference needs to meet the threshold to be preemptible. 
+ priority_diff = req.priority - running_req.priority + if server_args.schedule_low_priority_values_first: + priority_diff *= -1 + if priority_diff > self.priority_scheduling_preemption_threshold: + preemptible_reqs.append(running_req) + min_tokens_to_remove -= self._get_running_request_total_token_offset( + running_req + ) + + # Check max token count limit can be met + if len(preemptible_reqs) == 0 or min_tokens_to_remove > 0: + return False + + # Preempt running requests. Release allocated resources for immediate usage. + preemptible_reqs = set(preemptible_reqs) + keep_indices = [] + release_counter = 0 + for i, running_req in enumerate(self.running_batch.reqs): + if running_req in preemptible_reqs: + self.rem_total_token_offset -= ( + self._get_running_request_total_token_offset(req) + ) + release_counter += 1 + self.running_batch.release_req( + i, len(self.running_batch.reqs) - release_counter, server_args + ) + else: + keep_indices.append(i) + self.running_batch.filter_batch(keep_indices=keep_indices) + self.preempt_list.extend(preemptible_reqs) + return True diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index f2697e75f..83e6b45cb 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -243,6 +243,13 @@ class Scheduler( self.pp_size = server_args.pp_size self.dp_size = server_args.dp_size self.schedule_policy = server_args.schedule_policy + self.enable_priority_scheduling = server_args.enable_priority_scheduling + self.schedule_low_priority_values_first = ( + server_args.schedule_low_priority_values_first + ) + self.priority_scheduling_preemption_threshold = ( + server_args.priority_scheduling_preemption_threshold + ) self.enable_lora = server_args.enable_lora self.max_loras_per_batch = server_args.max_loras_per_batch self.enable_overlap = not server_args.disable_overlap_schedule @@ -487,7 +494,12 @@ class Scheduler( self.schedule_policy, self.tree_cache, self.enable_hierarchical_cache, + self.enable_priority_scheduling, + self.schedule_low_priority_values_first, ) + # Enable preemption for priority scheduling. + self.try_preemption = self.enable_priority_scheduling + assert ( server_args.schedule_conservativeness >= 0 ), "Invalid schedule_conservativeness" @@ -1150,20 +1162,6 @@ class Scheduler( self.return_health_check_ct += 1 continue - # If it is a work request, accept or reject the request based on the request queue size. - if is_work_request(recv_req): - if len(self.waiting_queue) + 1 > self.max_queued_requests: - abort_req = AbortReq( - recv_req.rid, - finished_reason={ - "type": "abort", - "status_code": HTTPStatus.SERVICE_UNAVAILABLE, - "message": "The request queue is full.", - }, - ) - self.send_to_tokenizer.send_pyobj(abort_req) - continue - # If it is a MultiTokenizerWrapper, unwrap it and handle the inner request. 
if isinstance(recv_req, MultiTokenizerWrapper): worker_id = recv_req.worker_id @@ -1233,6 +1231,7 @@ class Scheduler( bootstrap_room=recv_req.bootstrap_room, data_parallel_rank=recv_req.data_parallel_rank, vocab_size=self.model_config.vocab_size, + priority=recv_req.priority, metrics_collector=( self.metrics_collector if self.enable_metrics else None ), @@ -1382,6 +1381,9 @@ class Scheduler( elif self.disaggregation_mode == DisaggregationMode.DECODE: self.disagg_decode_prealloc_queue.add(req) else: + self._set_or_validate_priority(req) + if self._abort_on_queued_limit(req): + return self._prefetch_kvcache(req) self.waiting_queue.append(req) trace_slice_end("process req", req.rid, auto_next_anon=True) @@ -1408,7 +1410,70 @@ class Scheduler( # If this is a decode server, we put the request to the decode pending prealloc queue self.disagg_decode_prealloc_queue.extend(reqs, is_retracted) else: - self.waiting_queue.extend(reqs) + for req in reqs: + self._set_or_validate_priority(req) + if not self._abort_on_queued_limit(req): + self.waiting_queue.append(req) + + def _set_or_validate_priority(self, req: Req): + """Set the default priority value, or abort the request based on the priority scheduling mode.""" + if self.enable_priority_scheduling and req.priority is None: + if self.schedule_low_priority_values_first: + req.priority = sys.maxsize + else: + req.priority = -sys.maxsize - 1 + elif not self.enable_priority_scheduling and req.priority is not None: + abort_req = AbortReq( + req.rid, + finished_reason={ + "type": "abort", + "status_code": HTTPStatus.SERVICE_UNAVAILABLE, + "message": "Using priority is disabled for this server. Please send a new request without a priority.", + }, + ) + self.send_to_tokenizer.send_pyobj(abort_req) + + def _abort_on_queued_limit(self, recv_req: Req) -> bool: + """Abort an incoming or existing request if the waiting queue is full. Returns True if the incoming request is aborted.""" + if ( + self.max_queued_requests is None + or len(self.waiting_queue) + 1 <= self.max_queued_requests + ): + return False + + # Reject the incoming request by default. + req_to_abort = recv_req + message = "The request queue is full." + if self.enable_priority_scheduling: + # With priority scheduling, consider aboritng an existing request based on the priority. + # direction = 1 => smaller number = higher priority; -1 => larger number = higher priority. + # max(...) + (direction * priority, queue_time_start) picks the least-preferred request. + # Tie: later queue_time_start (newer) is evicted first. Preempt only if strictly better. + direction = 1 if self.schedule_low_priority_values_first else -1 + key_fn = lambda item: ( + direction * item[1].priority, + item[1].queue_time_start, + ) + idx, candidate_req = max(enumerate(self.waiting_queue), key=key_fn) + abort_existing_req = ( + direction * recv_req.priority < direction * candidate_req.priority + ) + if abort_existing_req: + self.waiting_queue.pop(idx) + req_to_abort = candidate_req + message = "The request is aborted by a higher priority request." 
+ + self.send_to_tokenizer.send_pyobj( + AbortReq( + req_to_abort.rid, + finished_reason={ + "type": "abort", + "status_code": HTTPStatus.SERVICE_UNAVAILABLE, + "message": message, + }, + ) + ) + return req_to_abort.rid == recv_req.rid def handle_embedding_request( self, @@ -1420,6 +1485,7 @@ class Scheduler( recv_req.input_ids, recv_req.sampling_params, token_type_ids=recv_req.token_type_ids, + priority=recv_req.priority, ) req.tokenizer = self.tokenizer @@ -1680,6 +1746,10 @@ class Scheduler( if self.grammar_queue: self.move_ready_grammar_requests() + if self.try_preemption: + # Reset batch_is_full to try preemption with a prefill adder. + self.running_batch.batch_is_full = False + # Handle the cases where prefill is not allowed if ( self.running_batch.batch_is_full or len(self.waiting_queue) == 0 @@ -1692,7 +1762,11 @@ class Scheduler( # as the space for the chunked request has just been released. # In PP case, a chunked req can start in one microbatch and end in another microbatch, so the max_running_requests per microbatch should not be strict. # Instead, we should always allow chunked request to be added, otherwise, there will be a memory leak. - if self.get_num_allocatable_reqs(running_bs) <= 0 and not self.chunked_req: + if ( + self.get_num_allocatable_reqs(running_bs) <= 0 + and not self.chunked_req + and not self.try_preemption + ): self.running_batch.batch_is_full = True return None @@ -1712,6 +1786,7 @@ class Scheduler( self.max_prefill_tokens, self.chunked_prefill_size, running_bs if self.is_mixed_chunk else 0, + self.priority_scheduling_preemption_threshold, ) if self.chunked_req is not None: @@ -1732,15 +1807,19 @@ class Scheduler( self.running_batch.batch_is_full = True break + running_bs = len(self.running_batch.reqs) - len(adder.preempt_list) if len(adder.can_run_list) >= self.get_num_allocatable_reqs(running_bs): self.running_batch.batch_is_full = True - break - if self.disaggregation_mode == DisaggregationMode.PREFILL: # In prefill mode, prealloc queue and transfer queue can also take memory, # so we need to check if the available size for the actual available size. 
if len(adder.can_run_list) >= self.req_to_token_pool.available_size(): self.running_batch.batch_is_full = True + + if self.running_batch.batch_is_full: + if not self.try_preemption: + break + if not adder.preempt_to_schedule(req, self.server_args): break if self.enable_hicache_storage: @@ -1777,6 +1856,8 @@ class Scheduler( self.waiting_queue = [ x for x in self.waiting_queue if x not in set(can_run_list) ] + if adder.preempt_list: + self._extend_requests_to_queue(adder.preempt_list) if adder.new_chunked_req is not None: assert self.chunked_req is None diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 5674bb475..6e9884dc3 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -738,6 +738,7 @@ class TokenizerManager(TokenizerCommunicatorMixin): custom_logit_processor=obj.custom_logit_processor, return_hidden_states=obj.return_hidden_states, data_parallel_rank=obj.data_parallel_rank, + priority=obj.priority, ) elif isinstance(obj, EmbeddingReqInput): tokenized_obj = TokenizedEmbeddingReqInput( @@ -747,6 +748,7 @@ class TokenizerManager(TokenizerCommunicatorMixin): mm_inputs, token_type_ids, sampling_params, + priority=obj.priority, ) return tokenized_obj diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py index 059813f83..7bed87592 100644 --- a/python/sglang/srt/managers/tp_worker.py +++ b/python/sglang/srt/managers/tp_worker.py @@ -149,8 +149,8 @@ class TpModelWorker: assert self.max_running_requests > 0, "max_running_request is zero" self.max_queued_requests = server_args.max_queued_requests assert ( - self.max_queued_requests > 0 - ), "max_queued_requests is zero. We need to be at least 1 to schedule a request." + self.max_queued_requests is None or self.max_queued_requests >= 1 + ), "If configured, max_queued_requests must be at least 1 for any work to be scheduled." self.max_req_len = min( self.model_config.context_len - 1, self.max_total_num_tokens - 1, diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 32853b386..8ee1e8f27 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -172,11 +172,14 @@ class ServerArgs: # Memory and scheduling mem_fraction_static: Optional[float] = None max_running_requests: Optional[int] = None - max_queued_requests: Optional[int] = sys.maxsize + max_queued_requests: Optional[int] = None max_total_tokens: Optional[int] = None chunked_prefill_size: Optional[int] = None max_prefill_tokens: int = 16384 schedule_policy: str = "fcfs" + enable_priority_scheduling: bool = False + schedule_low_priority_values_first: bool = False + priority_scheduling_preemption_threshold: int = 10 schedule_conservativeness: float = 1.0 page_size: Optional[int] = None hybrid_kvcache_ratio: Optional[float] = None @@ -1166,6 +1169,24 @@ class ServerArgs: choices=["lpm", "random", "fcfs", "dfs-weight", "lof", "priority"], help="The scheduling policy of the requests.", ) + parser.add_argument( + "--enable-priority-scheduling", + action="store_true", + default=ServerArgs.enable_priority_scheduling, + help="Enable priority scheduling. 
Requests with higher priority integer values will be scheduled first by default.", + ) + parser.add_argument( + "--schedule-low-priority-values-first", + action="store_true", + default=ServerArgs.schedule_low_priority_values_first, + help="If specified with --enable-priority-scheduling, the scheduler will schedule requests with lower priority integer values first.", + ) + parser.add_argument( + "--priority-scheduling-preemption-threshold", + type=int, + default=ServerArgs.priority_scheduling_preemption_threshold, + help="Minimum difference in priorities for an incoming request to have to preempt running request(s).", + ) parser.add_argument( "--schedule-conservativeness", type=float, @@ -2455,6 +2476,13 @@ class ServerArgs: "--generation-tokens-buckets", self.generation_tokens_buckets ) + # Check scheduling policy + if self.enable_priority_scheduling: + assert self.schedule_policy in [ + "fcfs", + "lof", + ], f"To use priority scheduling, schedule_policy must be 'fcfs' or 'lof'. '{self.schedule_policy}' is not supported." + def check_lora_server_args(self): assert self.max_loras_per_batch > 0, "max_loras_per_batch must be positive" diff --git a/python/sglang/test/test_utils.py b/python/sglang/test/test_utils.py index 90a8ef5e1..6a69fed00 100644 --- a/python/sglang/test/test_utils.py +++ b/python/sglang/test/test_utils.py @@ -17,7 +17,7 @@ from dataclasses import dataclass from functools import partial from pathlib import Path from types import SimpleNamespace -from typing import Awaitable, Callable, List, Optional, Tuple +from typing import Any, Awaitable, Callable, List, Optional, Tuple import aiohttp import numpy as np @@ -1390,6 +1390,41 @@ async def send_concurrent_generate_requests( return await asyncio.gather(*tasks) +async def send_concurrent_generate_requests_with_custom_params( + base_url: str, + custom_params: List[dict[str, Any]], +) -> Tuple[int, Any]: + """Sends generate request concurrently with custom parameters and returns status code and response json tuple. Max concurrency is num_requests.""" + + base_payload = { + "text": """ + System: You are a helpful assistant. + User: What is the capital of France? 
+ Assistant: The capital of France is + """, + "sampling_params": { + "temperature": 0, + "max_new_tokens": 50, + }, + } + + async def async_generate_with_priority(req): + async with aiohttp.ClientSession() as session: + async with session.post( + f"{base_url}/generate", + json=req, + ) as response: + resp_json = await response.json() + return (response.status, resp_json) + + tasks = [] + for c in custom_params: + req = base_payload.copy() + req.update(c) + tasks.append(asyncio.create_task(async_generate_with_priority(req))) + return await asyncio.gather(*tasks) + + class CustomTestCase(unittest.TestCase): def _callTestMethod(self, method): max_retry = int( diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py index 7b3632bc9..912d32801 100644 --- a/test/srt/run_suite.py +++ b/test/srt/run_suite.py @@ -95,6 +95,7 @@ suites = { TestFile("test_original_logprobs.py", 200), TestFile("test_penalty.py", 41), TestFile("test_page_size.py", 60), + TestFile("test_priority_scheduling.py", 100), TestFile("test_pytorch_sampling_backend.py", 66), TestFile("test_radix_attention.py", 105), TestFile("test_regex_constrained.py", 64), diff --git a/test/srt/test_priority_scheduling.py b/test/srt/test_priority_scheduling.py new file mode 100644 index 000000000..befde130e --- /dev/null +++ b/test/srt/test_priority_scheduling.py @@ -0,0 +1,339 @@ +import asyncio +import os +import re +import unittest +from typing import Any, Awaitable, Callable, List, Optional, Tuple + +from sglang.srt.utils import kill_process_tree +from sglang.test.test_utils import ( + DEFAULT_SMALL_MODEL_NAME_FOR_TEST, + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, + STDERR_FILENAME, + STDOUT_FILENAME, + CustomTestCase, + popen_launch_server, + send_concurrent_generate_requests_with_custom_params, +) + + +class TestPriorityScheduling(CustomTestCase): + @classmethod + def setUpClass(cls): + cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST + cls.base_url = DEFAULT_URL_FOR_TEST + + cls.stdout = open(STDOUT_FILENAME, "w") + cls.stderr = open(STDERR_FILENAME, "w") + + cls.base_url = DEFAULT_URL_FOR_TEST + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=( + "--max-running-requests", # Enforce max request concurrency is 1 + "1", + "--max-queued-requests", # Enforce max queued request number is 3 + "3", + "--enable-priority-scheduling", # Enable priority scheduling + ), + return_stdout_stderr=(cls.stdout, cls.stderr), + ) + + @classmethod + def tearDownClass(cls): + kill_process_tree(cls.process.pid) + _verify_max_running_requests_and_max_queued_request_validation(1, 3) + cls.stdout.close() + cls.stderr.close() + os.remove(STDOUT_FILENAME) + os.remove(STDERR_FILENAME) + + def test_priority_scheduling_request_ordering_validation(self): + """Verify pending requests are ordered by priority and received timestamp.""" + + responses = asyncio.run( + send_concurrent_generate_requests_with_custom_params( + self.base_url, + [ + { + "priority": 0, + "sampling_params": {"max_new_tokens": 10000}, + }, # starts being processed first + {"priority": 1}, # third + {"priority": 1}, # fourth + {"priority": 2}, # second + ], + ) + ) + + expected_status_and_error_messages = [ + (200, None), + (200, None), + (200, None), + (200, None), + ] + + e2e_latencies = [] + _verify_genereate_responses( + responses, expected_status_and_error_messages, e2e_latencies + ) + assert e2e_latencies[0] < e2e_latencies[3] < e2e_latencies[1] < e2e_latencies[2] + + def 
test_priority_scheduling_existing_requests_abortion_validation(self): + """Verify lower priority requests are aborted when incoming requests have higher priority""" + + responses = asyncio.run( + send_concurrent_generate_requests_with_custom_params( + self.base_url, + [ + { + "priority": 1, + "sampling_params": {"max_new_tokens": 10000}, + }, # starts being processed first and holds the running queue capacity + {"priority": 2}, # aborted by request 5 + {"priority": 3}, # aborted by request 6 + {"priority": 4}, # aborted by request 7 + {"priority": 5}, # fourth + {"priority": 6}, # third + {"priority": 7}, # second + ], + ) + ) + + expected_status_and_error_messages = [ + (200, None), + (503, "The request is aborted by a higher priority request."), + (503, "The request is aborted by a higher priority request."), + (503, "The request is aborted by a higher priority request."), + (200, None), + (200, None), + (200, None), + ] + + e2e_latencies = [] + _verify_genereate_responses( + responses, expected_status_and_error_messages, e2e_latencies + ) + assert e2e_latencies[0] < e2e_latencies[6] < e2e_latencies[5] < e2e_latencies[4] + + def test_priority_scheduling_incoming_request_rejection_validation(self): + """Verify incoming requests are rejected when existing requests have higher priority""" + + responses = asyncio.run( + send_concurrent_generate_requests_with_custom_params( + self.base_url, + [ + { + "priority": 7, + "sampling_params": {"max_new_tokens": 10000}, + }, # starts being processed first and holds the running queue capacity + {"priority": 6}, # second + {"priority": 5}, # third + {"priority": 4}, # fourth + {"priority": 3}, # rejected + {"priority": 2}, # rejected + {"priority": 1}, # rejected + ], + ) + ) + + expected_status_and_error_messages = [ + (200, None), + (200, None), + (200, None), + (200, None), + (503, "The request queue is full."), + (503, "The request queue is full."), + (503, "The request queue is full."), + ] + + e2e_latencies = [] + _verify_genereate_responses( + responses, expected_status_and_error_messages, e2e_latencies + ) + assert e2e_latencies[0] < e2e_latencies[1] < e2e_latencies[2] < e2e_latencies[3] + + def test_priority_scheduling_preemption_meeting_threshold_validation(self): + """Verify running requests are preempted by requests with priorities meeting the preemption threshold""" + + responses = asyncio.run( + send_concurrent_generate_requests_with_custom_params( + self.base_url, + [ + { + "priority": 0, + "sampling_params": {"max_new_tokens": 10000}, + }, # starts being processed first then preempted or pushed by later requests, and finishes last. + { + "priority": 10, + "sampling_params": {"max_new_tokens": 10000}, + }, # scheduled after the third request, and finishes second. + { + "priority": 20, + "sampling_params": {"max_new_tokens": 10000}, + }, # finishes first. 
+ ], + ) + ) + + expected_status_and_error_messages = [ + (200, None), + (200, None), + (200, None), + ] + + e2e_latencies = [] + _verify_genereate_responses( + responses, expected_status_and_error_messages, e2e_latencies + ) + + assert e2e_latencies[2] < e2e_latencies[1] < e2e_latencies[0] + + def test_priority_scheduling_preemption_below_threshold_validation(self): + """Verify running requests are not preempted by requests with priorities below preemption threshold""" + + responses = asyncio.run( + send_concurrent_generate_requests_with_custom_params( + self.base_url, + [ + { + "priority": 0, + "sampling_params": {"max_new_tokens": 10000}, + }, + { + "priority": 5, + "sampling_params": {"max_new_tokens": 10000}, + }, + ], + ) + ) + + expected_status_and_error_messages = [ + (200, None), + (200, None), + ] + + e2e_latencies = [] + _verify_genereate_responses( + responses, expected_status_and_error_messages, e2e_latencies + ) + + assert e2e_latencies[0] < e2e_latencies[1] + + +class TestPrioritySchedulingMultipleRunningRequests(CustomTestCase): + @classmethod + def setUpClass(cls): + cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST + cls.base_url = DEFAULT_URL_FOR_TEST + + cls.stdout = open(STDOUT_FILENAME, "w") + cls.stderr = open(STDERR_FILENAME, "w") + + cls.base_url = DEFAULT_URL_FOR_TEST + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=( + "--max-running-requests", # Enforce max request concurrency is 2 + "2", + "--max-queued-requests", # Enforce max queued request number is 3 + "3", + "--enable-priority-scheduling", # Enable priority scheduling + ), + return_stdout_stderr=(cls.stdout, cls.stderr), + ) + + @classmethod + def tearDownClass(cls): + kill_process_tree(cls.process.pid) + _verify_max_running_requests_and_max_queued_request_validation(2, 3) + cls.stdout.close() + cls.stderr.close() + os.remove(STDOUT_FILENAME) + os.remove(STDERR_FILENAME) + + def test_priority_scheduling_with_multiple_running_requests_preemption(self): + """Verify preempting a subset of running requests is safe.""" + + responses = asyncio.run( + send_concurrent_generate_requests_with_custom_params( + self.base_url, + [ + { + "priority": 10, + "sampling_params": {"max_new_tokens": 10000}, + }, # finishes first + { + "priority": 5, + "sampling_params": {"max_new_tokens": 10000}, + }, # preempted by fourth request, then finishes third + { + "priority": 15, + "sampling_params": {"max_new_tokens": 10000}, + }, # preempt the first request + ], + ) + ) + + expected_status_and_error_messages = [ + (200, None), + (200, None), + (200, None), + (200, None), + ] + + _verify_genereate_responses(responses, expected_status_and_error_messages, []) + + +def _verify_genereate_responses( + responses: Tuple[int, Any, float], + expected_code_and_error_message: Tuple[int, Any], + e2e_latencies: List[Optional[float]], +): + """ + Verify generate response results are as expected based on status code and response json object content. + In addition, collects e2e latency info to verify scheduling and processing ordering. 
+ """ + for got, expected in zip(responses, expected_code_and_error_message): + got_status, got_json = got + expected_status, expected_err_msg = expected + + # Check status code is as expected + assert got_status == expected_status + + # Check error message content or fields' existence based on status code + if got_status != 200: + assert got_json["object"] == "error" + assert got_json["message"] == expected_err_msg + else: + assert "object" not in got_json + assert "message" not in got_json + + # Collect e2e latencies for scheduling validation + e2e_latencies.append( + got_json["meta_info"]["e2e_latency"] if got_status == 200 else None + ) + + +def _verify_max_running_requests_and_max_queued_request_validation( + max_running_requests: int, max_queued_requests: int +): + """Verify running request and queued request numbers based on server logs.""" + rr_pattern = re.compile(r"#running-req:\s*(\d+)") + qr_pattern = re.compile(r"#queue-req:\s*(\d+)") + + with open(STDERR_FILENAME) as lines: + for line in lines: + rr_match, qr_match = rr_pattern.search(line), qr_pattern.search(line) + if rr_match: + assert int(rr_match.group(1)) <= max_running_requests + if qr_match: + assert int(qr_match.group(1)) <= max_queued_requests + + +if __name__ == "__main__": + unittest.main() diff --git a/test/srt/test_request_queue_validation.py b/test/srt/test_request_queue_validation.py index 2a9739a1c..7574f9059 100644 --- a/test/srt/test_request_queue_validation.py +++ b/test/srt/test_request_queue_validation.py @@ -65,9 +65,8 @@ class TestMaxQueuedRequests(CustomTestCase): send_concurrent_generate_requests(self.base_url, num_requests=10) ) - assert 200 in status_codes - assert 503 in status_codes - assert all(status_code in [200, 503] for status_code in status_codes) + expected_status_codes = [200, 200, 503, 503, 503, 503, 503, 503, 503, 503] + assert status_codes == expected_status_codes def test_max_running_requests_and_max_queued_request_validation(self): """Verify running request and queued request numbers based on server logs.""" diff --git a/test/srt/test_schedule_policy.py b/test/srt/test_schedule_policy.py index 4a4f57b35..0e33b6b25 100644 --- a/test/srt/test_schedule_policy.py +++ b/test/srt/test_schedule_policy.py @@ -18,13 +18,21 @@ class TestSchedulePolicy(CustomTestCase): def test_init_with_cache_aware_policy(self): policy = SchedulePolicy( - policy="lpm", tree_cache=self.tree_cache, enable_hierarchical_cache=True + policy="lpm", + tree_cache=self.tree_cache, + enable_hierarchical_cache=True, + enable_priority_scheduling=False, + schedule_low_priority_values_first=False, ) self.assertEqual(policy.policy, CacheAwarePolicy.LPM) def test_init_with_cache_agnostic_policy(self): policy = SchedulePolicy( - policy="fcfs", tree_cache=self.tree_cache, enable_hierarchical_cache=True + policy="fcfs", + tree_cache=self.tree_cache, + enable_hierarchical_cache=True, + enable_priority_scheduling=False, + schedule_low_priority_values_first=False, ) self.assertEqual(policy.policy, CacheAgnosticPolicy.FCFS) @@ -34,12 +42,18 @@ class TestSchedulePolicy(CustomTestCase): policy="invalid", tree_cache=self.tree_cache, enable_hierarchical_cache=True, + enable_priority_scheduling=False, + schedule_low_priority_values_first=False, ) def test_init_with_disabled_cache(self): disabled_tree_cache = RadixCache(None, None, disable=True, page_size=1) policy = SchedulePolicy( - policy="lpm", tree_cache=disabled_tree_cache, enable_hierarchical_cache=True + policy="lpm", + tree_cache=disabled_tree_cache, + 
enable_hierarchical_cache=True, + enable_priority_scheduling=False, + schedule_low_priority_values_first=False, ) self.assertEqual(policy.policy, CacheAgnosticPolicy.FCFS) @@ -52,7 +66,11 @@ class TestSchedulePolicy(CustomTestCase): ] policy = SchedulePolicy( - policy="fcfs", tree_cache=tree_cache, enable_hierarchical_cache=True + policy="fcfs", + tree_cache=tree_cache, + enable_hierarchical_cache=True, + enable_priority_scheduling=False, + schedule_low_priority_values_first=False, ) policy.calc_priority(waiting_queue) # Check if FCFS keeps the original order @@ -60,6 +78,126 @@ class TestSchedulePolicy(CustomTestCase): self.assertEqual(waiting_queue[1].rid, 3) self.assertEqual(waiting_queue[2].rid, 2) + def test_calc_priority_priority_enabled_fcfs_scheduling(self): + tree_cache = RadixCache(None, None, False) + + waiting_queue = [ + Req(1, "a b", [1, 2], SamplingParams()), + Req(3, "a b c", [1, 2, 3], SamplingParams()), + Req(2, "a", [1], SamplingParams()), + ] + waiting_queue[0].priority, waiting_queue[0].queue_time_start = 1, 1 + waiting_queue[1].priority, waiting_queue[1].queue_time_start = 0, 1 + waiting_queue[2].priority, waiting_queue[2].queue_time_start = 0, 0 + + policy = SchedulePolicy( + policy="fcfs", + tree_cache=tree_cache, + enable_hierarchical_cache=True, + enable_priority_scheduling=True, + schedule_low_priority_values_first=False, + ) + policy.calc_priority(waiting_queue) + # Check if priority enabled fcfs ordering is applied. + self.assertEqual(waiting_queue[0].rid, 1) + self.assertEqual(waiting_queue[1].rid, 2) + self.assertEqual(waiting_queue[2].rid, 3) + + def test_calc_priority_priority_enabled_fcfs_scheduling_with_low_priority_values_first( + self, + ): + tree_cache = RadixCache(None, None, False) + + waiting_queue = [ + Req(1, "a b", [1, 2], SamplingParams()), + Req(3, "a b c", [1, 2, 3], SamplingParams()), + Req(2, "a", [1], SamplingParams()), + ] + waiting_queue[0].priority, waiting_queue[0].queue_time_start = -1, 0 + waiting_queue[1].priority, waiting_queue[1].queue_time_start = 0, 1 + waiting_queue[2].priority, waiting_queue[2].queue_time_start = 0, 0 + + policy = SchedulePolicy( + policy="fcfs", + tree_cache=tree_cache, + enable_hierarchical_cache=True, + enable_priority_scheduling=True, + schedule_low_priority_values_first=True, + ) + policy.calc_priority(waiting_queue) + # Check if priority enabled fcfs ordering is applied. + self.assertEqual(waiting_queue[0].rid, 1) + self.assertEqual(waiting_queue[1].rid, 2) + self.assertEqual(waiting_queue[2].rid, 3) + + def test_calc_priority_longest_output_first_scheduling(self): + tree_cache = RadixCache(None, None, False) + + waiting_queue = [ + Req(1, "a b", [1, 2], SamplingParams(max_new_tokens=1000)), + Req(3, "a b c", [1, 2, 3], SamplingParams(max_new_tokens=10)), + Req(2, "a", [1], SamplingParams(max_new_tokens=100)), + ] + + policy = SchedulePolicy( + policy="lof", + tree_cache=tree_cache, + enable_hierarchical_cache=True, + enable_priority_scheduling=False, + schedule_low_priority_values_first=False, + ) + policy.calc_priority(waiting_queue) + # Check if priority enabled fcfs ordering is applied. 
+ self.assertEqual(waiting_queue[0].rid, 1) + self.assertEqual(waiting_queue[1].rid, 2) + self.assertEqual(waiting_queue[2].rid, 3) + + def test_calc_priority_priority_enabled_longest_output_first_scheduling(self): + tree_cache = RadixCache(None, None, False) + + waiting_queue = [ + Req(1, "a b", [1, 2], SamplingParams(max_new_tokens=1), priority=1), + Req(3, "a b c", [1, 2, 3], SamplingParams(max_new_tokens=10), priority=0), + Req(2, "a", [1], SamplingParams(max_new_tokens=100), priority=0), + ] + + policy = SchedulePolicy( + policy="lof", + tree_cache=tree_cache, + enable_hierarchical_cache=True, + enable_priority_scheduling=True, + schedule_low_priority_values_first=False, + ) + policy.calc_priority(waiting_queue) + # Check if priority enabled fcfs ordering is applied. + self.assertEqual(waiting_queue[0].rid, 1) + self.assertEqual(waiting_queue[1].rid, 2) + self.assertEqual(waiting_queue[2].rid, 3) + + def test_calc_priority_priority_enabled_longest_output_first_scheduling_with_low_priority_values_first( + self, + ): + tree_cache = RadixCache(None, None, False) + + waiting_queue = [ + Req(1, "a b", [1, 2], SamplingParams(max_new_tokens=1), priority=0), + Req(3, "a b c", [1, 2, 3], SamplingParams(max_new_tokens=10), priority=1), + Req(2, "a", [1], SamplingParams(max_new_tokens=100), priority=1), + ] + + policy = SchedulePolicy( + policy="lof", + tree_cache=tree_cache, + enable_hierarchical_cache=True, + enable_priority_scheduling=True, + schedule_low_priority_values_first=True, + ) + policy.calc_priority(waiting_queue) + # Check if priority enabled fcfs ordering is applied. + self.assertEqual(waiting_queue[0].rid, 1) + self.assertEqual(waiting_queue[1].rid, 2) + self.assertEqual(waiting_queue[2].rid, 3) + if __name__ == "__main__": unittest.main()
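
A minimal usage sketch of the feature this patch adds, placed after the diff for reference. It assumes a locally launched server; the model path, port, and the use of the `requests` library are placeholders and not part of the patch. The flags (`--enable-priority-scheduling`, `--max-running-requests`, `--max-queued-requests`, `--priority-scheduling-preemption-threshold`) and the top-level `priority` field on `/generate` (and on the OpenAI-compatible chat/completions/embedding requests) come from the changes above; by default a larger priority value is scheduled first, unless `--schedule-low-priority-values-first` is set.

```python
# Hypothetical client sketch (not part of the patch).
# Assumed server launch, mirroring the flags used in test_priority_scheduling.py:
#
#   python -m sglang.launch_server \
#       --model-path <your-model> --port 30000 \
#       --max-running-requests 1 --max-queued-requests 3 \
#       --enable-priority-scheduling \
#       --priority-scheduling-preemption-threshold 10

import requests

BASE_URL = "http://127.0.0.1:30000"  # assumed local server address


def generate(text: str, priority: int, max_new_tokens: int = 32):
    """Send a /generate request tagged with a scheduler priority."""
    payload = {
        "text": text,
        "priority": priority,  # higher value = higher priority by default
        "sampling_params": {"temperature": 0, "max_new_tokens": max_new_tokens},
    }
    resp = requests.post(f"{BASE_URL}/generate", json=payload)
    return resp.status_code, resp.json()


# With priority scheduling enabled, the waiting queue is ordered by
# (priority, queue_time_start); when the queue is full, the least-preferred
# request is aborted with HTTP 503 ("The request queue is full." or
# "The request is aborted by a higher priority request.").
status_hi, body_hi = generate("What is the capital of France?", priority=10)
status_lo, body_lo = generate("Write a haiku about queues.", priority=0)

# Successful responses carry end-to-end latency in meta_info, which the new
# tests use to verify scheduling order.
if status_hi == 200:
    print(body_hi["meta_info"]["e2e_latency"])
```

If the server is launched without `--enable-priority-scheduling`, sending a `priority` field is rejected with HTTP 503 per the new `_set_or_validate_priority` check, so clients should only attach the field when the feature is enabled.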