feat: add priority-based scheduling with priority-based request acceptance and preemption (#8746)

@@ -228,6 +228,8 @@ class CompletionRequest(BaseModel):

     # For request id
     rid: Optional[Union[List[str], str]] = None
+    # Priority for the request
+    priority: Optional[int] = None

     # For customer metric labels
     customer_labels: Optional[Dict[str, str]] = None

@@ -543,6 +545,8 @@ class ChatCompletionRequest(BaseModel):

     # For request id
     rid: Optional[Union[List[str], str]] = None
+    # Priority for the request
+    priority: Optional[int] = None

     # For PD disaggregation
     bootstrap_host: Optional[Union[List[str], str]] = None

@@ -644,6 +648,8 @@ class EmbeddingRequest(BaseModel):

     # The request id.
     rid: Optional[Union[List[str], str]] = None
+    # Priority for the request
+    priority: Optional[int] = None


 class EmbeddingObject(BaseModel):

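With these schema changes, clients can attach a priority to completion, chat, and embedding requests. A minimal sketch of what a prioritized call might look like (hypothetical values; assumes a server listening on localhost:30000 with priority scheduling enabled):

```python
# Hypothetical usage sketch: send a completion request carrying a priority value.
# The "priority" field is the one added by this commit; host/port and model name
# are placeholders.
import requests

resp = requests.post(
    "http://localhost:30000/v1/completions",
    json={
        "model": "default",
        "prompt": "The capital of France is",
        "max_tokens": 16,
        "priority": 10,  # new field added by this commit
    },
)
print(resp.status_code, resp.json())
```
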
@@ -149,6 +149,7 @@ class OpenAIServingChat(OpenAIServingBase):
             bootstrap_room=request.bootstrap_room,
             return_hidden_states=request.return_hidden_states,
             rid=request.rid,
+            priority=request.priority,
             customer_labels=customer_labels,
         )

@@ -107,6 +107,7 @@ class OpenAIServingCompletion(OpenAIServingBase):
             bootstrap_room=request.bootstrap_room,
             return_hidden_states=request.return_hidden_states,
             rid=request.rid,
+            priority=request.priority,
             customer_labels=customer_labels,
         )

@@ -125,6 +125,7 @@ class OpenAIServingEmbedding(OpenAIServingBase):
         adapted_request = EmbeddingReqInput(
             **prompt_kwargs,
             rid=request.rid,
+            priority=request.priority,
         )

         return adapted_request, request

@@ -570,6 +570,7 @@ class TokenizedGenerateReqInput:
     token_ids_logprob: List[int]
     # Whether to stream output
     stream: bool

     # Whether to return hidden states
     return_hidden_states: bool = False

@@ -656,6 +657,8 @@ class EmbeddingReqInput:
     modalities: Optional[List[str]] = None
     # For cross-encoder requests
     is_cross_encoder_request: bool = False
+    # Priority for the request
+    priority: Optional[int] = None

     # For background responses (OpenAI responses API)
     background: bool = False

@@ -763,6 +766,8 @@ class TokenizedEmbeddingReqInput:
     data_parallel_rank: Optional[int] = None
     # For dp balance
     dp_balance_id: int = -1
+    # Priority for the request
+    priority: Optional[int] = None


 @dataclass

@@ -453,6 +453,7 @@ class Req:
         bootstrap_room: Optional[int] = None,
         data_parallel_rank: Optional[int] = None,
         vocab_size: Optional[int] = None,
+        priority: Optional[int] = None,
         metrics_collector: Optional[SchedulerMetricsCollector] = None,
     ):
         # Input and output info

@@ -504,6 +505,7 @@ class Req:
         self.stream = stream
         self.eos_token_ids = eos_token_ids
         self.vocab_size = vocab_size
+        self.priority = priority

         # For incremental decoding
         # ----- | --------- read_ids -------|

@@ -1517,37 +1519,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             idx = sorted_indices.pop()
             req = self.reqs[idx]
             retracted_reqs.append(req)

-            if server_args.disaggregation_mode == "decode":
-                req.offload_kv_cache(
-                    self.req_to_token_pool, self.token_to_kv_pool_allocator
-                )
-
-            if isinstance(self.tree_cache, ChunkCache):
-                # ChunkCache does not have eviction
-                token_indices = self.req_to_token_pool.req_to_token[
-                    req.req_pool_idx, : seq_lens_cpu[idx]
-                ]
-                self.token_to_kv_pool_allocator.free(token_indices)
-                self.req_to_token_pool.free(req.req_pool_idx)
-            else:
-                # TODO: apply more fine-grained retraction
-                last_uncached_pos = (
-                    len(req.prefix_indices) // server_args.page_size
-                ) * server_args.page_size
-                token_indices = self.req_to_token_pool.req_to_token[
-                    req.req_pool_idx, last_uncached_pos : seq_lens_cpu[idx]
-                ]
-                self.token_to_kv_pool_allocator.free(token_indices)
-                self.req_to_token_pool.free(req.req_pool_idx)
-
-            # release the last node
-            if self.is_hybrid:
-                self.tree_cache.dec_lock_ref(req.last_node, req.swa_uuid_for_lock)
-            else:
-                self.tree_cache.dec_lock_ref(req.last_node)
-
-            req.reset_for_retract()
+            self.release_req(idx, len(sorted_indices), server_args)

         if len(retracted_reqs) == 0:
             # Corner case: only one request left

@@ -1568,6 +1540,44 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):

         return retracted_reqs, new_estimate_ratio

+    def release_req(self, idx: int, remaing_req_count: int, server_args: ServerArgs):
+        req = self.reqs[idx]
+        seq_lens_cpu = self.seq_lens.cpu().numpy()
+
+        if server_args.disaggregation_mode == "decode":
+            req.offload_kv_cache(
+                self.req_to_token_pool, self.token_to_kv_pool_allocator
+            )
+        if isinstance(self.tree_cache, ChunkCache):
+            # ChunkCache does not have eviction
+            token_indices = self.req_to_token_pool.req_to_token[
+                req.req_pool_idx, : seq_lens_cpu[idx]
+            ]
+            self.token_to_kv_pool_allocator.free(token_indices)
+            self.req_to_token_pool.free(req.req_pool_idx)
+        else:
+            # TODO: apply more fine-grained retraction
+            last_uncached_pos = (
+                len(req.prefix_indices) // server_args.page_size
+            ) * server_args.page_size
+            token_indices = self.req_to_token_pool.req_to_token[
+                req.req_pool_idx, last_uncached_pos : seq_lens_cpu[idx]
+            ]
+            self.token_to_kv_pool_allocator.free(token_indices)
+            self.req_to_token_pool.free(req.req_pool_idx)
+
+        # release the last node
+        if self.is_hybrid:
+            self.tree_cache.dec_lock_ref(req.last_node, req.swa_uuid_for_lock)
+        else:
+            self.tree_cache.dec_lock_ref(req.last_node)
+
+        # NOTE(lsyin): we should use the newly evictable memory instantly.
+        num_tokens = remaing_req_count * global_config.retract_decode_steps
+        self._evict_tree_cache_if_needed(num_tokens)
+
+        req.reset_for_retract()
+
     def prepare_encoder_info_decode(self):
         # Reset the encoder cached status
         self.encoder_cached = [True] * len(self.reqs)

@@ -28,6 +28,7 @@ from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
 from sglang.srt.mem_cache.allocator import SWATokenToKVPoolAllocator
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
 from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode
+from sglang.srt.server_args import ServerArgs

 if TYPE_CHECKING:
     from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator

@@ -82,10 +83,14 @@ class SchedulePolicy:
         policy: str,
         tree_cache: BasePrefixCache,
         enable_hierarchical_cache: bool,
+        enable_priority_scheduling: bool,
+        schedule_low_priority_values_first: bool,
     ):
         self.policy = self._validate_and_adjust_policy(policy, tree_cache)
         self.tree_cache = tree_cache
         self.enable_hierarchical_cache = enable_hierarchical_cache
+        self.enable_priority_scheduling = enable_priority_scheduling
+        self.schedule_low_priority_values_first = schedule_low_priority_values_first

         # It is used to find the matching prefix for in-batch prefix caching.
         self.waiting_queue_radix_tree = RadixCache(

@@ -97,7 +102,10 @@ class SchedulePolicy:

     def calc_priority(self, waiting_queue: List[Req]) -> bool:
         if self.policy == CacheAgnosticPolicy.FCFS:
             # A shortcut for FCFS
+            if self.enable_priority_scheduling:
+                SchedulePolicy._sort_by_priority_and_fcfs(
+                    waiting_queue, self.schedule_low_priority_values_first
+                )
             return False

         policy = self._determine_active_policy(waiting_queue)

@@ -120,12 +128,15 @@ class SchedulePolicy:
         if policy == CacheAgnosticPolicy.FCFS:
             pass
         elif policy == CacheAgnosticPolicy.LOF:
-            SchedulePolicy._sort_by_longest_output(waiting_queue)
+            SchedulePolicy._sort_by_longest_output(
+                waiting_queue,
+                self.enable_priority_scheduling,
+                self.schedule_low_priority_values_first,
+            )
         elif policy == CacheAgnosticPolicy.RANDOM:
             SchedulePolicy._sort_randomly(waiting_queue)
         else:
             raise ValueError(f"Unknown CacheAgnostic Policy: {policy=}")

         return prefix_computed

     def _determine_active_policy(self, waiting_queue: List[Req]) -> Policy:

@@ -231,15 +242,39 @@ class SchedulePolicy:
         )

     @staticmethod
-    def _sort_by_longest_output(waiting_queue: List[Req]) -> None:
-        """Sorts the waiting queue based on the longest output (max_new_tokens)."""
-        waiting_queue.sort(key=lambda x: -x.sampling_params.max_new_tokens)
+    def _sort_by_longest_output(
+        waiting_queue: List[Req],
+        enable_priority_scheduling: bool,
+        schedule_low_priority_values_first: bool,
+    ) -> None:
+        """Sorts the waiting queue based on the longest output (max_new_tokens). If using priority scheduling, sort by priority first."""
+        if enable_priority_scheduling:
+            if schedule_low_priority_values_first:
+                waiting_queue.sort(
+                    key=lambda x: (x.priority, -x.sampling_params.max_new_tokens)
+                )
+            else:
+                waiting_queue.sort(
+                    key=lambda x: (-x.priority, -x.sampling_params.max_new_tokens)
+                )
+        else:
+            waiting_queue.sort(key=lambda x: -x.sampling_params.max_new_tokens)

     @staticmethod
     def _sort_randomly(waiting_queue: List[Req]) -> None:
         """Shuffles the waiting queue randomly."""
         random.shuffle(waiting_queue)

+    @staticmethod
+    def _sort_by_priority_and_fcfs(
+        waiting_queue: List[Req], schedule_low_priority_values_first: bool
+    ) -> None:
+        """Sorts the waiting queue based on the request priority, then the received timestamp."""
+        if schedule_low_priority_values_first:
+            waiting_queue.sort(key=lambda x: (x.priority, x.queue_time_start))
+        else:
+            waiting_queue.sort(key=lambda x: (-x.priority, x.queue_time_start))
+
     @staticmethod
     def _calc_weight(cur_node: TreeNode, node_to_weight: Dict[TreeNode, int]) -> None:
         for child in cur_node.children.values():

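The sort keys above are ordinary tuple comparisons, so the ordering can be sanity-checked in isolation. A toy sketch of the priority-then-FCFS ordering (FakeReq is a stand-in for the real Req class, not part of the commit):

```python
# Toy illustration of the _sort_by_priority_and_fcfs ordering.
from dataclasses import dataclass

@dataclass
class FakeReq:
    priority: int
    queue_time_start: float  # earlier value = arrived first

queue = [FakeReq(5, 2.0), FakeReq(1, 3.0), FakeReq(5, 1.0)]

# schedule_low_priority_values_first=True: lower priority value runs first,
# ties broken by arrival time (FCFS).
queue.sort(key=lambda x: (x.priority, x.queue_time_start))
print([(r.priority, r.queue_time_start) for r in queue])
# [(1, 3.0), (5, 1.0), (5, 2.0)]

# Default direction: higher priority value runs first.
queue.sort(key=lambda x: (-x.priority, x.queue_time_start))
print([(r.priority, r.queue_time_start) for r in queue])
# [(5, 1.0), (5, 2.0), (1, 3.0)]
```
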
@@ -279,6 +314,7 @@ class PrefillAdder:
         rem_input_tokens: int,
         rem_chunk_tokens: Optional[int],
         mixed_with_decode_tokens: int = 0,
+        priority_scheduling_preemption_threshold: int = 0,
     ):
         self.page_size = page_size
         self.tree_cache = tree_cache

@@ -295,6 +331,7 @@ class PrefillAdder:

         self.req_states = None
         self.can_run_list = []
+        self.preempt_list = []
         self.new_chunked_req = None
         self.log_hit_tokens = 0
         # TODO(lsyin): report the real input tokens excluding page alignment

@@ -303,11 +340,7 @@ class PrefillAdder:
         if running_batch is not None:
             self.rem_total_token_offset += sum(
                 [
-                    min(
-                        (r.sampling_params.max_new_tokens - len(r.output_ids)),
-                        CLIP_MAX_NEW_TOKENS,
-                    )
-                    * self.new_token_ratio
+                    self._get_running_request_total_token_offset(r)
                     for r in running_batch.reqs
                 ]
             )

@@ -316,6 +349,19 @@ class PrefillAdder:
             self.token_to_kv_pool_allocator, SWATokenToKVPoolAllocator
         )

+        self.priority_scheduling_preemption_threshold = (
+            priority_scheduling_preemption_threshold
+        )
+
+    def _get_running_request_total_token_offset(self, req: Req) -> int:
+        return (
+            min(
+                (req.sampling_params.max_new_tokens - len(req.output_ids)),
+                CLIP_MAX_NEW_TOKENS,
+            )
+            * self.new_token_ratio
+        )
+
     @property
     def rem_total_tokens(self):
         if self.is_hybrid:

@@ -568,3 +614,61 @@ class PrefillAdder:
             self._update_prefill_budget(prefix_len, trunc_len, 0)

         return self.budget_state()

+    def preempt_to_schedule(self, req: Req, server_args: ServerArgs) -> bool:
+        """
+        Preempt running requests to serve the new request if the priority threshold is met and the token count sum is verified.
+        Returns True if preemption was committed and the new request can be scheduled.
+        """
+        # Iterate over running requests to find preemptible ones
+        if server_args.schedule_low_priority_values_first:
+            sorted_running_reqs = sorted(
+                self.running_batch.reqs,
+                key=lambda x: (-x.priority, -x.queue_time_start),
+            )
+        else:
+            sorted_running_reqs = sorted(
+                self.running_batch.reqs,
+                key=lambda x: (x.priority, -x.queue_time_start),
+            )
+        preemptible_reqs = []
+        min_tokens_to_remove = (
+            req.extend_input_len
+            + min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS)
+            - self.rem_total_tokens
+        )
+        for running_req in sorted_running_reqs:
+            if running_req in self.preempt_list:
+                continue
+            # The priority difference needs to meet the threshold to be preemptible.
+            priority_diff = req.priority - running_req.priority
+            if server_args.schedule_low_priority_values_first:
+                priority_diff *= -1
+            if priority_diff > self.priority_scheduling_preemption_threshold:
+                preemptible_reqs.append(running_req)
+                min_tokens_to_remove -= self._get_running_request_total_token_offset(
+                    running_req
+                )
+
+        # Check that the max token count limit can be met
+        if len(preemptible_reqs) == 0 or min_tokens_to_remove > 0:
+            return False
+
+        # Preempt running requests. Release allocated resources for immediate use.
+        preemptible_reqs = set(preemptible_reqs)
+        keep_indices = []
+        release_counter = 0
+        for i, running_req in enumerate(self.running_batch.reqs):
+            if running_req in preemptible_reqs:
+                self.rem_total_token_offset -= (
+                    self._get_running_request_total_token_offset(running_req)
+                )
+                release_counter += 1
+                self.running_batch.release_req(
+                    i, len(self.running_batch.reqs) - release_counter, server_args
+                )
+            else:
+                keep_indices.append(i)
+        self.running_batch.filter_batch(keep_indices=keep_indices)
+        self.preempt_list.extend(preemptible_reqs)
+        return True

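The gate in preempt_to_schedule is a strict comparison: the (sign-adjusted) priority difference must exceed priority_scheduling_preemption_threshold. A standalone sketch of that arithmetic (the helper name here is illustrative, not part of the commit):

```python
# Worked check of the preemption threshold arithmetic (standalone sketch).
def is_preemptible(incoming_priority, running_priority, threshold, low_values_first):
    priority_diff = incoming_priority - running_priority
    if low_values_first:
        # With low-values-first scheduling, a lower value means higher priority,
        # so the sign of the difference is flipped before comparing.
        priority_diff *= -1
    return priority_diff > threshold

# With the default threshold of 10 and high-values-first scheduling:
print(is_preemptible(25, 10, 10, False))  # True: 25 - 10 = 15 > 10
print(is_preemptible(20, 10, 10, False))  # False: a diff of exactly 10 is not strictly greater
# With --schedule-low-priority-values-first, a *lower* value wins:
print(is_preemptible(1, 20, 10, True))    # True: (1 - 20) * -1 = 19 > 10
```
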
@@ -243,6 +243,13 @@ class Scheduler(
         self.pp_size = server_args.pp_size
         self.dp_size = server_args.dp_size
         self.schedule_policy = server_args.schedule_policy
+        self.enable_priority_scheduling = server_args.enable_priority_scheduling
+        self.schedule_low_priority_values_first = (
+            server_args.schedule_low_priority_values_first
+        )
+        self.priority_scheduling_preemption_threshold = (
+            server_args.priority_scheduling_preemption_threshold
+        )
         self.enable_lora = server_args.enable_lora
         self.max_loras_per_batch = server_args.max_loras_per_batch
         self.enable_overlap = not server_args.disable_overlap_schedule

@@ -487,7 +494,12 @@ class Scheduler(
             self.schedule_policy,
             self.tree_cache,
             self.enable_hierarchical_cache,
+            self.enable_priority_scheduling,
+            self.schedule_low_priority_values_first,
         )
+        # Enable preemption for priority scheduling.
+        self.try_preemption = self.enable_priority_scheduling

         assert (
             server_args.schedule_conservativeness >= 0
         ), "Invalid schedule_conservativeness"

@@ -1150,20 +1162,6 @@ class Scheduler(
                     self.return_health_check_ct += 1
                 continue

-            # If it is a work request, accept or reject the request based on the request queue size.
-            if is_work_request(recv_req):
-                if len(self.waiting_queue) + 1 > self.max_queued_requests:
-                    abort_req = AbortReq(
-                        recv_req.rid,
-                        finished_reason={
-                            "type": "abort",
-                            "status_code": HTTPStatus.SERVICE_UNAVAILABLE,
-                            "message": "The request queue is full.",
-                        },
-                    )
-                    self.send_to_tokenizer.send_pyobj(abort_req)
-                    continue
-
             # If it is a MultiTokenizerWrapper, unwrap it and handle the inner request.
             if isinstance(recv_req, MultiTokenizerWrapper):
                 worker_id = recv_req.worker_id

@@ -1233,6 +1231,7 @@ class Scheduler(
             bootstrap_room=recv_req.bootstrap_room,
             data_parallel_rank=recv_req.data_parallel_rank,
             vocab_size=self.model_config.vocab_size,
+            priority=recv_req.priority,
             metrics_collector=(
                 self.metrics_collector if self.enable_metrics else None
             ),

@@ -1382,6 +1381,9 @@ class Scheduler(
         elif self.disaggregation_mode == DisaggregationMode.DECODE:
             self.disagg_decode_prealloc_queue.add(req)
         else:
+            self._set_or_validate_priority(req)
+            if self._abort_on_queued_limit(req):
+                return
             self._prefetch_kvcache(req)
             self.waiting_queue.append(req)
             trace_slice_end("process req", req.rid, auto_next_anon=True)

@@ -1408,7 +1410,70 @@ class Scheduler(
             # If this is a decode server, we put the request to the decode pending prealloc queue
             self.disagg_decode_prealloc_queue.extend(reqs, is_retracted)
         else:
-            self.waiting_queue.extend(reqs)
+            for req in reqs:
+                self._set_or_validate_priority(req)
+                if not self._abort_on_queued_limit(req):
+                    self.waiting_queue.append(req)
+
+    def _set_or_validate_priority(self, req: Req):
+        """Set the default priority value, or abort the request based on the priority scheduling mode."""
+        if self.enable_priority_scheduling and req.priority is None:
+            if self.schedule_low_priority_values_first:
+                req.priority = sys.maxsize
+            else:
+                req.priority = -sys.maxsize - 1
+        elif not self.enable_priority_scheduling and req.priority is not None:
+            abort_req = AbortReq(
+                req.rid,
+                finished_reason={
+                    "type": "abort",
+                    "status_code": HTTPStatus.SERVICE_UNAVAILABLE,
+                    "message": "Using priority is disabled for this server. Please send a new request without a priority.",
+                },
+            )
+            self.send_to_tokenizer.send_pyobj(abort_req)
+
+    def _abort_on_queued_limit(self, recv_req: Req) -> bool:
+        """Abort an incoming or existing request if the waiting queue is full. Returns True if the incoming request is aborted."""
+        if (
+            self.max_queued_requests is None
+            or len(self.waiting_queue) + 1 <= self.max_queued_requests
+        ):
+            return False
+
+        # Reject the incoming request by default.
+        req_to_abort = recv_req
+        message = "The request queue is full."
+        if self.enable_priority_scheduling:
+            # With priority scheduling, consider aborting an existing request based on priority.
+            # direction = 1 => smaller number = higher priority; -1 => larger number = higher priority.
+            # max(...) over (direction * priority, queue_time_start) picks the least-preferred request.
+            # Tie: the later queue_time_start (newer) is evicted first. Preempt only if strictly better.
+            direction = 1 if self.schedule_low_priority_values_first else -1
+            key_fn = lambda item: (
+                direction * item[1].priority,
+                item[1].queue_time_start,
+            )
+            idx, candidate_req = max(enumerate(self.waiting_queue), key=key_fn)
+            abort_existing_req = (
+                direction * recv_req.priority < direction * candidate_req.priority
+            )
+            if abort_existing_req:
+                self.waiting_queue.pop(idx)
+                req_to_abort = candidate_req
+                message = "The request is aborted by a higher priority request."
+
+        self.send_to_tokenizer.send_pyobj(
+            AbortReq(
+                req_to_abort.rid,
+                finished_reason={
+                    "type": "abort",
+                    "status_code": HTTPStatus.SERVICE_UNAVAILABLE,
+                    "message": message,
+                },
+            )
+        )
+        return req_to_abort.rid == recv_req.rid
+
     def handle_embedding_request(
         self,

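The eviction pick in _abort_on_queued_limit reduces to one max() over (direction * priority, queue_time_start). A toy illustration of the candidate selection with stand-in objects (not the real Req class):

```python
# Toy illustration of the least-preferred-request pick in _abort_on_queued_limit.
from types import SimpleNamespace

waiting_queue = [
    SimpleNamespace(rid="a", priority=3, queue_time_start=1.0),
    SimpleNamespace(rid="b", priority=7, queue_time_start=2.0),
    SimpleNamespace(rid="c", priority=7, queue_time_start=3.0),
]

# low-values-first => direction = 1, so the largest priority value is the least
# preferred; ties are broken toward the newest arrival.
direction = 1
idx, candidate = max(
    enumerate(waiting_queue),
    key=lambda item: (direction * item[1].priority, item[1].queue_time_start),
)
print(candidate.rid)  # "c": priority 7, newest of the tie

# An incoming request with priority 2 is strictly better (2 < 7), so "c" would be
# evicted in its favor; an incoming priority of 7 would be rejected instead.
```
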
@@ -1420,6 +1485,7 @@ class Scheduler(
             recv_req.input_ids,
             recv_req.sampling_params,
             token_type_ids=recv_req.token_type_ids,
+            priority=recv_req.priority,
         )
         req.tokenizer = self.tokenizer

@@ -1680,6 +1746,10 @@ class Scheduler(
         if self.grammar_queue:
             self.move_ready_grammar_requests()

+        if self.try_preemption:
+            # Reset batch_is_full to try preemption with a prefill adder.
+            self.running_batch.batch_is_full = False
+
         # Handle the cases where prefill is not allowed
         if (
             self.running_batch.batch_is_full or len(self.waiting_queue) == 0

@@ -1692,7 +1762,11 @@ class Scheduler(
         # as the space for the chunked request has just been released.
         # In PP case, a chunked req can start in one microbatch and end in another microbatch, so the max_running_requests per microbatch should not be strict.
         # Instead, we should always allow chunked request to be added, otherwise, there will be a memory leak.
-        if self.get_num_allocatable_reqs(running_bs) <= 0 and not self.chunked_req:
+        if (
+            self.get_num_allocatable_reqs(running_bs) <= 0
+            and not self.chunked_req
+            and not self.try_preemption
+        ):
             self.running_batch.batch_is_full = True
             return None

@@ -1712,6 +1786,7 @@ class Scheduler(
             self.max_prefill_tokens,
             self.chunked_prefill_size,
             running_bs if self.is_mixed_chunk else 0,
+            self.priority_scheduling_preemption_threshold,
         )

         if self.chunked_req is not None:

@@ -1732,15 +1807,19 @@ class Scheduler(
                     self.running_batch.batch_is_full = True
                     break

-            running_bs = len(self.running_batch.reqs)
+            running_bs = len(self.running_batch.reqs) - len(adder.preempt_list)
             if len(adder.can_run_list) >= self.get_num_allocatable_reqs(running_bs):
                 self.running_batch.batch_is_full = True
                 break

             if self.disaggregation_mode == DisaggregationMode.PREFILL:
                 # In prefill mode, the prealloc queue and transfer queue can also take memory,
                 # so we need to check the actual available size.
                 if len(adder.can_run_list) >= self.req_to_token_pool.available_size():
                     self.running_batch.batch_is_full = True

             if self.running_batch.batch_is_full:
-                break
+                if not self.try_preemption:
+                    break
+                if not adder.preempt_to_schedule(req, self.server_args):
+                    break

             if self.enable_hicache_storage:

@@ -1777,6 +1856,8 @@ class Scheduler(
         self.waiting_queue = [
             x for x in self.waiting_queue if x not in set(can_run_list)
         ]
+        if adder.preempt_list:
+            self._extend_requests_to_queue(adder.preempt_list)

         if adder.new_chunked_req is not None:
             assert self.chunked_req is None

@@ -738,6 +738,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
                 custom_logit_processor=obj.custom_logit_processor,
                 return_hidden_states=obj.return_hidden_states,
                 data_parallel_rank=obj.data_parallel_rank,
+                priority=obj.priority,
             )
         elif isinstance(obj, EmbeddingReqInput):
             tokenized_obj = TokenizedEmbeddingReqInput(

@@ -747,6 +748,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
                 mm_inputs,
                 token_type_ids,
                 sampling_params,
+                priority=obj.priority,
             )

         return tokenized_obj

@@ -149,8 +149,8 @@ class TpModelWorker:
         assert self.max_running_requests > 0, "max_running_request is zero"
         self.max_queued_requests = server_args.max_queued_requests
         assert (
-            self.max_queued_requests > 0
-        ), "max_queued_requests is zero. We need to be at least 1 to schedule a request."
+            self.max_queued_requests is None or self.max_queued_requests >= 1
+        ), "If configured, max_queued_requests must be at least 1 for any work to be scheduled."
         self.max_req_len = min(
             self.model_config.context_len - 1,
             self.max_total_num_tokens - 1,

@@ -172,11 +172,14 @@ class ServerArgs:
     # Memory and scheduling
     mem_fraction_static: Optional[float] = None
     max_running_requests: Optional[int] = None
-    max_queued_requests: Optional[int] = sys.maxsize
+    max_queued_requests: Optional[int] = None
     max_total_tokens: Optional[int] = None
     chunked_prefill_size: Optional[int] = None
     max_prefill_tokens: int = 16384
     schedule_policy: str = "fcfs"
+    enable_priority_scheduling: bool = False
+    schedule_low_priority_values_first: bool = False
+    priority_scheduling_preemption_threshold: int = 10
     schedule_conservativeness: float = 1.0
     page_size: Optional[int] = None
     hybrid_kvcache_ratio: Optional[float] = None

@@ -1166,6 +1169,24 @@ class ServerArgs:
             choices=["lpm", "random", "fcfs", "dfs-weight", "lof", "priority"],
             help="The scheduling policy of the requests.",
         )
+        parser.add_argument(
+            "--enable-priority-scheduling",
+            action="store_true",
+            default=ServerArgs.enable_priority_scheduling,
+            help="Enable priority scheduling. Requests with higher priority integer values will be scheduled first by default.",
+        )
+        parser.add_argument(
+            "--schedule-low-priority-values-first",
+            action="store_true",
+            default=ServerArgs.schedule_low_priority_values_first,
+            help="If specified with --enable-priority-scheduling, the scheduler will schedule requests with lower priority integer values first.",
+        )
+        parser.add_argument(
+            "--priority-scheduling-preemption-threshold",
+            type=int,
+            default=ServerArgs.priority_scheduling_preemption_threshold,
+            help="The minimum priority difference an incoming request must have over a running request in order to preempt it.",
+        )
         parser.add_argument(
             "--schedule-conservativeness",
             type=float,

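Taken together, the new knobs might be exercised as follows. A hedged launch sketch: the priority flag spellings come straight from the argparse definitions above, while the model path is a placeholder and the --max-queued-requests flag is assumed to follow the usual field-to-flag naming for the existing max_queued_requests setting.

```python
# Sketch: launching a server with the new priority-scheduling flags enabled.
import subprocess

subprocess.run([
    "python", "-m", "sglang.launch_server",
    "--model-path", "meta-llama/Llama-3.1-8B-Instruct",  # placeholder model
    "--schedule-policy", "fcfs",                   # priority scheduling requires fcfs or lof
    "--enable-priority-scheduling",
    "--schedule-low-priority-values-first",        # optional: low values run first
    "--priority-scheduling-preemption-threshold", "10",
    "--max-queued-requests", "64",                 # a bounded queue enables priority-based rejection
])
```
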
@@ -2455,6 +2476,13 @@ class ServerArgs:
             "--generation-tokens-buckets", self.generation_tokens_buckets
         )

+        # Check the scheduling policy
+        if self.enable_priority_scheduling:
+            assert self.schedule_policy in [
+                "fcfs",
+                "lof",
+            ], f"To use priority scheduling, schedule_policy must be 'fcfs' or 'lof'. '{self.schedule_policy}' is not supported."
+
     def check_lora_server_args(self):
         assert self.max_loras_per_batch > 0, "max_loras_per_batch must be positive"

@@ -17,7 +17,7 @@ from dataclasses import dataclass
 from functools import partial
 from pathlib import Path
 from types import SimpleNamespace
-from typing import Awaitable, Callable, List, Optional, Tuple
+from typing import Any, Awaitable, Callable, List, Optional, Tuple

 import aiohttp
 import numpy as np

@@ -1390,6 +1390,41 @@ async def send_concurrent_generate_requests(
     return await asyncio.gather(*tasks)


+async def send_concurrent_generate_requests_with_custom_params(
+    base_url: str,
+    custom_params: List[dict[str, Any]],
+) -> Tuple[int, Any]:
+    """Sends generate requests concurrently, one per entry in custom_params, and returns (status code, response JSON) tuples."""
+
+    base_payload = {
+        "text": """
+        System: You are a helpful assistant.
+        User: What is the capital of France?
+        Assistant: The capital of France is
+        """,
+        "sampling_params": {
+            "temperature": 0,
+            "max_new_tokens": 50,
+        },
+    }
+
+    async def async_generate_with_priority(req):
+        async with aiohttp.ClientSession() as session:
+            async with session.post(
+                f"{base_url}/generate",
+                json=req,
+            ) as response:
+                resp_json = await response.json()
+                return (response.status, resp_json)
+
+    tasks = []
+    for c in custom_params:
+        req = base_payload.copy()
+        req.update(c)
+        tasks.append(asyncio.create_task(async_generate_with_priority(req)))
+    return await asyncio.gather(*tasks)

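A hypothetical use of this helper in a test, assuming a server launched with --enable-priority-scheduling on localhost:30000; the rid value and the meta_info/finish_reason fields inspected here are illustrative:

```python
# Hypothetical usage: send three generate requests with different priorities
# and inspect which ones were accepted.
import asyncio

async def main():
    results = await send_concurrent_generate_requests_with_custom_params(
        "http://localhost:30000",
        custom_params=[
            {"priority": 0},
            {"priority": 5},
            {"priority": 10, "rid": "high-priority-req"},  # illustrative request id
        ],
    )
    for status, resp_json in results:
        # Aborted requests come back with a non-200 status and an abort reason.
        print(status, resp_json.get("meta_info", {}).get("finish_reason"))

asyncio.run(main())
```
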

 class CustomTestCase(unittest.TestCase):
     def _callTestMethod(self, method):
         max_retry = int(