feat: add priority based scheduling with priority based request acceptance and preemption (#8746)
This commit is contained in:
@@ -228,6 +228,8 @@ class CompletionRequest(BaseModel):
|
|||||||
|
|
||||||
# For request id
|
# For request id
|
||||||
rid: Optional[Union[List[str], str]] = None
|
rid: Optional[Union[List[str], str]] = None
|
||||||
|
# Priority for the request
|
||||||
|
priority: Optional[int] = None
|
||||||
|
|
||||||
# For customer metric labels
|
# For customer metric labels
|
||||||
customer_labels: Optional[Dict[str, str]] = None
|
customer_labels: Optional[Dict[str, str]] = None
|
||||||
@@ -543,6 +545,8 @@ class ChatCompletionRequest(BaseModel):
|
|||||||
|
|
||||||
# For request id
|
# For request id
|
||||||
rid: Optional[Union[List[str], str]] = None
|
rid: Optional[Union[List[str], str]] = None
|
||||||
|
# Priority for the request
|
||||||
|
priority: Optional[int] = None
|
||||||
|
|
||||||
# For PD disaggregation
|
# For PD disaggregation
|
||||||
bootstrap_host: Optional[Union[List[str], str]] = None
|
bootstrap_host: Optional[Union[List[str], str]] = None
|
||||||
@@ -644,6 +648,8 @@ class EmbeddingRequest(BaseModel):
|
|||||||
|
|
||||||
# The request id.
|
# The request id.
|
||||||
rid: Optional[Union[List[str], str]] = None
|
rid: Optional[Union[List[str], str]] = None
|
||||||
|
# Priority for the request
|
||||||
|
priority: Optional[int] = None
|
||||||
|
|
||||||
|
|
||||||
class EmbeddingObject(BaseModel):
|
class EmbeddingObject(BaseModel):
|
||||||
|
|||||||
@@ -149,6 +149,7 @@ class OpenAIServingChat(OpenAIServingBase):
|
|||||||
bootstrap_room=request.bootstrap_room,
|
bootstrap_room=request.bootstrap_room,
|
||||||
return_hidden_states=request.return_hidden_states,
|
return_hidden_states=request.return_hidden_states,
|
||||||
rid=request.rid,
|
rid=request.rid,
|
||||||
|
priority=request.priority,
|
||||||
customer_labels=customer_labels,
|
customer_labels=customer_labels,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -107,6 +107,7 @@ class OpenAIServingCompletion(OpenAIServingBase):
|
|||||||
bootstrap_room=request.bootstrap_room,
|
bootstrap_room=request.bootstrap_room,
|
||||||
return_hidden_states=request.return_hidden_states,
|
return_hidden_states=request.return_hidden_states,
|
||||||
rid=request.rid,
|
rid=request.rid,
|
||||||
|
priority=request.priority,
|
||||||
customer_labels=customer_labels,
|
customer_labels=customer_labels,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -125,6 +125,7 @@ class OpenAIServingEmbedding(OpenAIServingBase):
|
|||||||
adapted_request = EmbeddingReqInput(
|
adapted_request = EmbeddingReqInput(
|
||||||
**prompt_kwargs,
|
**prompt_kwargs,
|
||||||
rid=request.rid,
|
rid=request.rid,
|
||||||
|
priority=request.priority,
|
||||||
)
|
)
|
||||||
|
|
||||||
return adapted_request, request
|
return adapted_request, request
|
||||||
|
|||||||
@@ -570,6 +570,7 @@ class TokenizedGenerateReqInput:
|
|||||||
token_ids_logprob: List[int]
|
token_ids_logprob: List[int]
|
||||||
# Whether to stream output
|
# Whether to stream output
|
||||||
stream: bool
|
stream: bool
|
||||||
|
|
||||||
# Whether to return hidden states
|
# Whether to return hidden states
|
||||||
return_hidden_states: bool = False
|
return_hidden_states: bool = False
|
||||||
|
|
||||||
@@ -656,6 +657,8 @@ class EmbeddingReqInput:
|
|||||||
modalities: Optional[List[str]] = None
|
modalities: Optional[List[str]] = None
|
||||||
# For cross-encoder requests
|
# For cross-encoder requests
|
||||||
is_cross_encoder_request: bool = False
|
is_cross_encoder_request: bool = False
|
||||||
|
# Priority for the request
|
||||||
|
priority: Optional[int] = None
|
||||||
|
|
||||||
# For background responses (OpenAI responses API)
|
# For background responses (OpenAI responses API)
|
||||||
background: bool = False
|
background: bool = False
|
||||||
@@ -763,6 +766,8 @@ class TokenizedEmbeddingReqInput:
|
|||||||
data_parallel_rank: Optional[int] = None
|
data_parallel_rank: Optional[int] = None
|
||||||
# For dp balance
|
# For dp balance
|
||||||
dp_balance_id: int = -1
|
dp_balance_id: int = -1
|
||||||
|
# Priority for the request
|
||||||
|
priority: Optional[int] = None
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|||||||
@@ -453,6 +453,7 @@ class Req:
|
|||||||
bootstrap_room: Optional[int] = None,
|
bootstrap_room: Optional[int] = None,
|
||||||
data_parallel_rank: Optional[int] = None,
|
data_parallel_rank: Optional[int] = None,
|
||||||
vocab_size: Optional[int] = None,
|
vocab_size: Optional[int] = None,
|
||||||
|
priority: Optional[int] = None,
|
||||||
metrics_collector: Optional[SchedulerMetricsCollector] = None,
|
metrics_collector: Optional[SchedulerMetricsCollector] = None,
|
||||||
):
|
):
|
||||||
# Input and output info
|
# Input and output info
|
||||||
@@ -504,6 +505,7 @@ class Req:
|
|||||||
self.stream = stream
|
self.stream = stream
|
||||||
self.eos_token_ids = eos_token_ids
|
self.eos_token_ids = eos_token_ids
|
||||||
self.vocab_size = vocab_size
|
self.vocab_size = vocab_size
|
||||||
|
self.priority = priority
|
||||||
|
|
||||||
# For incremental decoding
|
# For incremental decoding
|
||||||
# ----- | --------- read_ids -------|
|
# ----- | --------- read_ids -------|
|
||||||
@@ -1517,37 +1519,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
|||||||
idx = sorted_indices.pop()
|
idx = sorted_indices.pop()
|
||||||
req = self.reqs[idx]
|
req = self.reqs[idx]
|
||||||
retracted_reqs.append(req)
|
retracted_reqs.append(req)
|
||||||
|
self.release_req(idx, len(sorted_indices), server_args)
|
||||||
if server_args.disaggregation_mode == "decode":
|
|
||||||
req.offload_kv_cache(
|
|
||||||
self.req_to_token_pool, self.token_to_kv_pool_allocator
|
|
||||||
)
|
|
||||||
|
|
||||||
if isinstance(self.tree_cache, ChunkCache):
|
|
||||||
# ChunkCache does not have eviction
|
|
||||||
token_indices = self.req_to_token_pool.req_to_token[
|
|
||||||
req.req_pool_idx, : seq_lens_cpu[idx]
|
|
||||||
]
|
|
||||||
self.token_to_kv_pool_allocator.free(token_indices)
|
|
||||||
self.req_to_token_pool.free(req.req_pool_idx)
|
|
||||||
else:
|
|
||||||
# TODO: apply more fine-grained retraction
|
|
||||||
last_uncached_pos = (
|
|
||||||
len(req.prefix_indices) // server_args.page_size
|
|
||||||
) * server_args.page_size
|
|
||||||
token_indices = self.req_to_token_pool.req_to_token[
|
|
||||||
req.req_pool_idx, last_uncached_pos : seq_lens_cpu[idx]
|
|
||||||
]
|
|
||||||
self.token_to_kv_pool_allocator.free(token_indices)
|
|
||||||
self.req_to_token_pool.free(req.req_pool_idx)
|
|
||||||
|
|
||||||
# release the last node
|
|
||||||
if self.is_hybrid:
|
|
||||||
self.tree_cache.dec_lock_ref(req.last_node, req.swa_uuid_for_lock)
|
|
||||||
else:
|
|
||||||
self.tree_cache.dec_lock_ref(req.last_node)
|
|
||||||
|
|
||||||
req.reset_for_retract()
|
|
||||||
|
|
||||||
if len(retracted_reqs) == 0:
|
if len(retracted_reqs) == 0:
|
||||||
# Corner case: only one request left
|
# Corner case: only one request left
|
||||||
@@ -1568,6 +1540,44 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
|||||||
|
|
||||||
return retracted_reqs, new_estimate_ratio
|
return retracted_reqs, new_estimate_ratio
|
||||||
|
|
||||||
|
def release_req(self, idx: int, remaing_req_count: int, server_args: ServerArgs):
|
||||||
|
req = self.reqs[idx]
|
||||||
|
seq_lens_cpu = self.seq_lens.cpu().numpy()
|
||||||
|
|
||||||
|
if server_args.disaggregation_mode == "decode":
|
||||||
|
req.offload_kv_cache(
|
||||||
|
self.req_to_token_pool, self.token_to_kv_pool_allocator
|
||||||
|
)
|
||||||
|
if isinstance(self.tree_cache, ChunkCache):
|
||||||
|
# ChunkCache does not have eviction
|
||||||
|
token_indices = self.req_to_token_pool.req_to_token[
|
||||||
|
req.req_pool_idx, : seq_lens_cpu[idx]
|
||||||
|
]
|
||||||
|
self.token_to_kv_pool_allocator.free(token_indices)
|
||||||
|
self.req_to_token_pool.free(req.req_pool_idx)
|
||||||
|
else:
|
||||||
|
# TODO: apply more fine-grained retraction
|
||||||
|
last_uncached_pos = (
|
||||||
|
len(req.prefix_indices) // server_args.page_size
|
||||||
|
) * server_args.page_size
|
||||||
|
token_indices = self.req_to_token_pool.req_to_token[
|
||||||
|
req.req_pool_idx, last_uncached_pos : seq_lens_cpu[idx]
|
||||||
|
]
|
||||||
|
self.token_to_kv_pool_allocator.free(token_indices)
|
||||||
|
self.req_to_token_pool.free(req.req_pool_idx)
|
||||||
|
|
||||||
|
# release the last node
|
||||||
|
if self.is_hybrid:
|
||||||
|
self.tree_cache.dec_lock_ref(req.last_node, req.swa_uuid_for_lock)
|
||||||
|
else:
|
||||||
|
self.tree_cache.dec_lock_ref(req.last_node)
|
||||||
|
|
||||||
|
# NOTE(lsyin): we should use the newly evictable memory instantly.
|
||||||
|
num_tokens = remaing_req_count * global_config.retract_decode_steps
|
||||||
|
self._evict_tree_cache_if_needed(num_tokens)
|
||||||
|
|
||||||
|
req.reset_for_retract()
|
||||||
|
|
||||||
def prepare_encoder_info_decode(self):
|
def prepare_encoder_info_decode(self):
|
||||||
# Reset the encoder cached status
|
# Reset the encoder cached status
|
||||||
self.encoder_cached = [True] * len(self.reqs)
|
self.encoder_cached = [True] * len(self.reqs)
|
||||||
|
|||||||
@@ -28,6 +28,7 @@ from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
|
|||||||
from sglang.srt.mem_cache.allocator import SWATokenToKVPoolAllocator
|
from sglang.srt.mem_cache.allocator import SWATokenToKVPoolAllocator
|
||||||
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
|
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
|
||||||
from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode
|
from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode
|
||||||
|
from sglang.srt.server_args import ServerArgs
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
|
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
|
||||||
@@ -82,10 +83,14 @@ class SchedulePolicy:
|
|||||||
policy: str,
|
policy: str,
|
||||||
tree_cache: BasePrefixCache,
|
tree_cache: BasePrefixCache,
|
||||||
enable_hierarchical_cache: bool,
|
enable_hierarchical_cache: bool,
|
||||||
|
enable_priority_scheduling: bool,
|
||||||
|
schedule_low_priority_values_first: bool,
|
||||||
):
|
):
|
||||||
self.policy = self._validate_and_adjust_policy(policy, tree_cache)
|
self.policy = self._validate_and_adjust_policy(policy, tree_cache)
|
||||||
self.tree_cache = tree_cache
|
self.tree_cache = tree_cache
|
||||||
self.enable_hierarchical_cache = enable_hierarchical_cache
|
self.enable_hierarchical_cache = enable_hierarchical_cache
|
||||||
|
self.enable_priority_scheduling = enable_priority_scheduling
|
||||||
|
self.schedule_low_priority_values_first = schedule_low_priority_values_first
|
||||||
|
|
||||||
# It is used to find the matching prefix for in-batch prefix caching.
|
# It is used to find the matching prefix for in-batch prefix caching.
|
||||||
self.waiting_queue_radix_tree = RadixCache(
|
self.waiting_queue_radix_tree = RadixCache(
|
||||||
@@ -97,7 +102,10 @@ class SchedulePolicy:
|
|||||||
|
|
||||||
def calc_priority(self, waiting_queue: List[Req]) -> bool:
|
def calc_priority(self, waiting_queue: List[Req]) -> bool:
|
||||||
if self.policy == CacheAgnosticPolicy.FCFS:
|
if self.policy == CacheAgnosticPolicy.FCFS:
|
||||||
# A shortcut for FCFS
|
if self.enable_priority_scheduling:
|
||||||
|
SchedulePolicy._sort_by_priority_and_fcfs(
|
||||||
|
waiting_queue, self.schedule_low_priority_values_first
|
||||||
|
)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
policy = self._determine_active_policy(waiting_queue)
|
policy = self._determine_active_policy(waiting_queue)
|
||||||
@@ -120,12 +128,15 @@ class SchedulePolicy:
|
|||||||
if policy == CacheAgnosticPolicy.FCFS:
|
if policy == CacheAgnosticPolicy.FCFS:
|
||||||
pass
|
pass
|
||||||
elif policy == CacheAgnosticPolicy.LOF:
|
elif policy == CacheAgnosticPolicy.LOF:
|
||||||
SchedulePolicy._sort_by_longest_output(waiting_queue)
|
SchedulePolicy._sort_by_longest_output(
|
||||||
|
waiting_queue,
|
||||||
|
self.enable_priority_scheduling,
|
||||||
|
self.schedule_low_priority_values_first,
|
||||||
|
)
|
||||||
elif policy == CacheAgnosticPolicy.RANDOM:
|
elif policy == CacheAgnosticPolicy.RANDOM:
|
||||||
SchedulePolicy._sort_randomly(waiting_queue)
|
SchedulePolicy._sort_randomly(waiting_queue)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unknown CacheAgnostic Policy: {policy=}")
|
raise ValueError(f"Unknown CacheAgnostic Policy: {policy=}")
|
||||||
|
|
||||||
return prefix_computed
|
return prefix_computed
|
||||||
|
|
||||||
def _determine_active_policy(self, waiting_queue: List[Req]) -> Policy:
|
def _determine_active_policy(self, waiting_queue: List[Req]) -> Policy:
|
||||||
@@ -231,15 +242,39 @@ class SchedulePolicy:
|
|||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _sort_by_longest_output(waiting_queue: List[Req]) -> None:
|
def _sort_by_longest_output(
|
||||||
"""Sorts the waiting queue based on the longest output (max_new_tokens)."""
|
waiting_queue: List[Req],
|
||||||
waiting_queue.sort(key=lambda x: -x.sampling_params.max_new_tokens)
|
enable_priority_scheduling: bool,
|
||||||
|
schedule_low_priority_values_first: bool,
|
||||||
|
) -> None:
|
||||||
|
"""Sorts the waiting queue based on the longest output (max_new_tokens). If using priority scheduling, sort by priority first."""
|
||||||
|
if enable_priority_scheduling:
|
||||||
|
if schedule_low_priority_values_first:
|
||||||
|
waiting_queue.sort(
|
||||||
|
key=lambda x: (x.priority, -x.sampling_params.max_new_tokens)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
waiting_queue.sort(
|
||||||
|
key=lambda x: (-x.priority, -x.sampling_params.max_new_tokens)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
waiting_queue.sort(key=lambda x: -x.sampling_params.max_new_tokens)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _sort_randomly(waiting_queue: List[Req]) -> None:
|
def _sort_randomly(waiting_queue: List[Req]) -> None:
|
||||||
"""Shuffles the waiting queue randomly."""
|
"""Shuffles the waiting queue randomly."""
|
||||||
random.shuffle(waiting_queue)
|
random.shuffle(waiting_queue)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _sort_by_priority_and_fcfs(
|
||||||
|
waiting_queue: List[Req], schedule_low_priority_values_first: bool
|
||||||
|
) -> None:
|
||||||
|
"""Sorts the waiting queue based on the request priority then received titmestamp."""
|
||||||
|
if schedule_low_priority_values_first:
|
||||||
|
waiting_queue.sort(key=lambda x: (x.priority, x.queue_time_start))
|
||||||
|
else:
|
||||||
|
waiting_queue.sort(key=lambda x: (-x.priority, x.queue_time_start))
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _calc_weight(cur_node: TreeNode, node_to_weight: Dict[TreeNode, int]) -> None:
|
def _calc_weight(cur_node: TreeNode, node_to_weight: Dict[TreeNode, int]) -> None:
|
||||||
for child in cur_node.children.values():
|
for child in cur_node.children.values():
|
||||||
@@ -279,6 +314,7 @@ class PrefillAdder:
|
|||||||
rem_input_tokens: int,
|
rem_input_tokens: int,
|
||||||
rem_chunk_tokens: Optional[int],
|
rem_chunk_tokens: Optional[int],
|
||||||
mixed_with_decode_tokens: int = 0,
|
mixed_with_decode_tokens: int = 0,
|
||||||
|
priority_scheduling_preemption_threshold: int = 0,
|
||||||
):
|
):
|
||||||
self.page_size = page_size
|
self.page_size = page_size
|
||||||
self.tree_cache = tree_cache
|
self.tree_cache = tree_cache
|
||||||
@@ -295,6 +331,7 @@ class PrefillAdder:
|
|||||||
|
|
||||||
self.req_states = None
|
self.req_states = None
|
||||||
self.can_run_list = []
|
self.can_run_list = []
|
||||||
|
self.preempt_list = []
|
||||||
self.new_chunked_req = None
|
self.new_chunked_req = None
|
||||||
self.log_hit_tokens = 0
|
self.log_hit_tokens = 0
|
||||||
# TODO(lsyin): report the real input tokens excluding page alignment
|
# TODO(lsyin): report the real input tokens excluding page alignment
|
||||||
@@ -303,11 +340,7 @@ class PrefillAdder:
|
|||||||
if running_batch is not None:
|
if running_batch is not None:
|
||||||
self.rem_total_token_offset += sum(
|
self.rem_total_token_offset += sum(
|
||||||
[
|
[
|
||||||
min(
|
self._get_running_request_total_token_offset(r)
|
||||||
(r.sampling_params.max_new_tokens - len(r.output_ids)),
|
|
||||||
CLIP_MAX_NEW_TOKENS,
|
|
||||||
)
|
|
||||||
* self.new_token_ratio
|
|
||||||
for r in running_batch.reqs
|
for r in running_batch.reqs
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
@@ -316,6 +349,19 @@ class PrefillAdder:
|
|||||||
self.token_to_kv_pool_allocator, SWATokenToKVPoolAllocator
|
self.token_to_kv_pool_allocator, SWATokenToKVPoolAllocator
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.priority_scheduling_preemption_threshold = (
|
||||||
|
priority_scheduling_preemption_threshold
|
||||||
|
)
|
||||||
|
|
||||||
|
def _get_running_request_total_token_offset(self, req: Req) -> int:
|
||||||
|
return (
|
||||||
|
min(
|
||||||
|
(req.sampling_params.max_new_tokens - len(req.output_ids)),
|
||||||
|
CLIP_MAX_NEW_TOKENS,
|
||||||
|
)
|
||||||
|
* self.new_token_ratio
|
||||||
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def rem_total_tokens(self):
|
def rem_total_tokens(self):
|
||||||
if self.is_hybrid:
|
if self.is_hybrid:
|
||||||
@@ -568,3 +614,61 @@ class PrefillAdder:
|
|||||||
self._update_prefill_budget(prefix_len, trunc_len, 0)
|
self._update_prefill_budget(prefix_len, trunc_len, 0)
|
||||||
|
|
||||||
return self.budget_state()
|
return self.budget_state()
|
||||||
|
|
||||||
|
def preempt_to_schedule(self, req: Req, server_args: ServerArgs) -> bool:
|
||||||
|
"""
|
||||||
|
Preempt running requests to serve the new request if the priority threshold is met and token count sum is verified.
|
||||||
|
Returns True if preemption was committed, and the new request can be scheduled.
|
||||||
|
"""
|
||||||
|
# Iterate running requests to find preemptible requests
|
||||||
|
if server_args.schedule_low_priority_values_first:
|
||||||
|
sorted_running_reqs = sorted(
|
||||||
|
self.running_batch.reqs,
|
||||||
|
key=lambda x: (-x.priority, -x.queue_time_start),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
sorted_running_reqs = sorted(
|
||||||
|
self.running_batch.reqs,
|
||||||
|
key=lambda x: (x.priority, -x.queue_time_start),
|
||||||
|
)
|
||||||
|
preemptible_reqs = []
|
||||||
|
min_tokens_to_remove = (
|
||||||
|
req.extend_input_len
|
||||||
|
+ min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS)
|
||||||
|
- self.rem_total_tokens
|
||||||
|
)
|
||||||
|
for running_req in sorted_running_reqs:
|
||||||
|
if running_req in self.preempt_list:
|
||||||
|
continue
|
||||||
|
# Priority difference needs to meet the threshold to be preemptible.
|
||||||
|
priority_diff = req.priority - running_req.priority
|
||||||
|
if server_args.schedule_low_priority_values_first:
|
||||||
|
priority_diff *= -1
|
||||||
|
if priority_diff > self.priority_scheduling_preemption_threshold:
|
||||||
|
preemptible_reqs.append(running_req)
|
||||||
|
min_tokens_to_remove -= self._get_running_request_total_token_offset(
|
||||||
|
running_req
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check max token count limit can be met
|
||||||
|
if len(preemptible_reqs) == 0 or min_tokens_to_remove > 0:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Preempt running requests. Release allocated resources for immediate usage.
|
||||||
|
preemptible_reqs = set(preemptible_reqs)
|
||||||
|
keep_indices = []
|
||||||
|
release_counter = 0
|
||||||
|
for i, running_req in enumerate(self.running_batch.reqs):
|
||||||
|
if running_req in preemptible_reqs:
|
||||||
|
self.rem_total_token_offset -= (
|
||||||
|
self._get_running_request_total_token_offset(req)
|
||||||
|
)
|
||||||
|
release_counter += 1
|
||||||
|
self.running_batch.release_req(
|
||||||
|
i, len(self.running_batch.reqs) - release_counter, server_args
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
keep_indices.append(i)
|
||||||
|
self.running_batch.filter_batch(keep_indices=keep_indices)
|
||||||
|
self.preempt_list.extend(preemptible_reqs)
|
||||||
|
return True
|
||||||
|
|||||||
@@ -243,6 +243,13 @@ class Scheduler(
|
|||||||
self.pp_size = server_args.pp_size
|
self.pp_size = server_args.pp_size
|
||||||
self.dp_size = server_args.dp_size
|
self.dp_size = server_args.dp_size
|
||||||
self.schedule_policy = server_args.schedule_policy
|
self.schedule_policy = server_args.schedule_policy
|
||||||
|
self.enable_priority_scheduling = server_args.enable_priority_scheduling
|
||||||
|
self.schedule_low_priority_values_first = (
|
||||||
|
server_args.schedule_low_priority_values_first
|
||||||
|
)
|
||||||
|
self.priority_scheduling_preemption_threshold = (
|
||||||
|
server_args.priority_scheduling_preemption_threshold
|
||||||
|
)
|
||||||
self.enable_lora = server_args.enable_lora
|
self.enable_lora = server_args.enable_lora
|
||||||
self.max_loras_per_batch = server_args.max_loras_per_batch
|
self.max_loras_per_batch = server_args.max_loras_per_batch
|
||||||
self.enable_overlap = not server_args.disable_overlap_schedule
|
self.enable_overlap = not server_args.disable_overlap_schedule
|
||||||
@@ -487,7 +494,12 @@ class Scheduler(
|
|||||||
self.schedule_policy,
|
self.schedule_policy,
|
||||||
self.tree_cache,
|
self.tree_cache,
|
||||||
self.enable_hierarchical_cache,
|
self.enable_hierarchical_cache,
|
||||||
|
self.enable_priority_scheduling,
|
||||||
|
self.schedule_low_priority_values_first,
|
||||||
)
|
)
|
||||||
|
# Enable preemption for priority scheduling.
|
||||||
|
self.try_preemption = self.enable_priority_scheduling
|
||||||
|
|
||||||
assert (
|
assert (
|
||||||
server_args.schedule_conservativeness >= 0
|
server_args.schedule_conservativeness >= 0
|
||||||
), "Invalid schedule_conservativeness"
|
), "Invalid schedule_conservativeness"
|
||||||
@@ -1150,20 +1162,6 @@ class Scheduler(
|
|||||||
self.return_health_check_ct += 1
|
self.return_health_check_ct += 1
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# If it is a work request, accept or reject the request based on the request queue size.
|
|
||||||
if is_work_request(recv_req):
|
|
||||||
if len(self.waiting_queue) + 1 > self.max_queued_requests:
|
|
||||||
abort_req = AbortReq(
|
|
||||||
recv_req.rid,
|
|
||||||
finished_reason={
|
|
||||||
"type": "abort",
|
|
||||||
"status_code": HTTPStatus.SERVICE_UNAVAILABLE,
|
|
||||||
"message": "The request queue is full.",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
self.send_to_tokenizer.send_pyobj(abort_req)
|
|
||||||
continue
|
|
||||||
|
|
||||||
# If it is a MultiTokenizerWrapper, unwrap it and handle the inner request.
|
# If it is a MultiTokenizerWrapper, unwrap it and handle the inner request.
|
||||||
if isinstance(recv_req, MultiTokenizerWrapper):
|
if isinstance(recv_req, MultiTokenizerWrapper):
|
||||||
worker_id = recv_req.worker_id
|
worker_id = recv_req.worker_id
|
||||||
@@ -1233,6 +1231,7 @@ class Scheduler(
|
|||||||
bootstrap_room=recv_req.bootstrap_room,
|
bootstrap_room=recv_req.bootstrap_room,
|
||||||
data_parallel_rank=recv_req.data_parallel_rank,
|
data_parallel_rank=recv_req.data_parallel_rank,
|
||||||
vocab_size=self.model_config.vocab_size,
|
vocab_size=self.model_config.vocab_size,
|
||||||
|
priority=recv_req.priority,
|
||||||
metrics_collector=(
|
metrics_collector=(
|
||||||
self.metrics_collector if self.enable_metrics else None
|
self.metrics_collector if self.enable_metrics else None
|
||||||
),
|
),
|
||||||
@@ -1382,6 +1381,9 @@ class Scheduler(
|
|||||||
elif self.disaggregation_mode == DisaggregationMode.DECODE:
|
elif self.disaggregation_mode == DisaggregationMode.DECODE:
|
||||||
self.disagg_decode_prealloc_queue.add(req)
|
self.disagg_decode_prealloc_queue.add(req)
|
||||||
else:
|
else:
|
||||||
|
self._set_or_validate_priority(req)
|
||||||
|
if self._abort_on_queued_limit(req):
|
||||||
|
return
|
||||||
self._prefetch_kvcache(req)
|
self._prefetch_kvcache(req)
|
||||||
self.waiting_queue.append(req)
|
self.waiting_queue.append(req)
|
||||||
trace_slice_end("process req", req.rid, auto_next_anon=True)
|
trace_slice_end("process req", req.rid, auto_next_anon=True)
|
||||||
@@ -1408,7 +1410,70 @@ class Scheduler(
|
|||||||
# If this is a decode server, we put the request to the decode pending prealloc queue
|
# If this is a decode server, we put the request to the decode pending prealloc queue
|
||||||
self.disagg_decode_prealloc_queue.extend(reqs, is_retracted)
|
self.disagg_decode_prealloc_queue.extend(reqs, is_retracted)
|
||||||
else:
|
else:
|
||||||
self.waiting_queue.extend(reqs)
|
for req in reqs:
|
||||||
|
self._set_or_validate_priority(req)
|
||||||
|
if not self._abort_on_queued_limit(req):
|
||||||
|
self.waiting_queue.append(req)
|
||||||
|
|
||||||
|
def _set_or_validate_priority(self, req: Req):
|
||||||
|
"""Set the default priority value, or abort the request based on the priority scheduling mode."""
|
||||||
|
if self.enable_priority_scheduling and req.priority is None:
|
||||||
|
if self.schedule_low_priority_values_first:
|
||||||
|
req.priority = sys.maxsize
|
||||||
|
else:
|
||||||
|
req.priority = -sys.maxsize - 1
|
||||||
|
elif not self.enable_priority_scheduling and req.priority is not None:
|
||||||
|
abort_req = AbortReq(
|
||||||
|
req.rid,
|
||||||
|
finished_reason={
|
||||||
|
"type": "abort",
|
||||||
|
"status_code": HTTPStatus.SERVICE_UNAVAILABLE,
|
||||||
|
"message": "Using priority is disabled for this server. Please send a new request without a priority.",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
self.send_to_tokenizer.send_pyobj(abort_req)
|
||||||
|
|
||||||
|
def _abort_on_queued_limit(self, recv_req: Req) -> bool:
|
||||||
|
"""Abort an incoming or existing request if the waiting queue is full. Returns True if the incoming request is aborted."""
|
||||||
|
if (
|
||||||
|
self.max_queued_requests is None
|
||||||
|
or len(self.waiting_queue) + 1 <= self.max_queued_requests
|
||||||
|
):
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Reject the incoming request by default.
|
||||||
|
req_to_abort = recv_req
|
||||||
|
message = "The request queue is full."
|
||||||
|
if self.enable_priority_scheduling:
|
||||||
|
# With priority scheduling, consider aboritng an existing request based on the priority.
|
||||||
|
# direction = 1 => smaller number = higher priority; -1 => larger number = higher priority.
|
||||||
|
# max(...) + (direction * priority, queue_time_start) picks the least-preferred request.
|
||||||
|
# Tie: later queue_time_start (newer) is evicted first. Preempt only if strictly better.
|
||||||
|
direction = 1 if self.schedule_low_priority_values_first else -1
|
||||||
|
key_fn = lambda item: (
|
||||||
|
direction * item[1].priority,
|
||||||
|
item[1].queue_time_start,
|
||||||
|
)
|
||||||
|
idx, candidate_req = max(enumerate(self.waiting_queue), key=key_fn)
|
||||||
|
abort_existing_req = (
|
||||||
|
direction * recv_req.priority < direction * candidate_req.priority
|
||||||
|
)
|
||||||
|
if abort_existing_req:
|
||||||
|
self.waiting_queue.pop(idx)
|
||||||
|
req_to_abort = candidate_req
|
||||||
|
message = "The request is aborted by a higher priority request."
|
||||||
|
|
||||||
|
self.send_to_tokenizer.send_pyobj(
|
||||||
|
AbortReq(
|
||||||
|
req_to_abort.rid,
|
||||||
|
finished_reason={
|
||||||
|
"type": "abort",
|
||||||
|
"status_code": HTTPStatus.SERVICE_UNAVAILABLE,
|
||||||
|
"message": message,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return req_to_abort.rid == recv_req.rid
|
||||||
|
|
||||||
def handle_embedding_request(
|
def handle_embedding_request(
|
||||||
self,
|
self,
|
||||||
@@ -1420,6 +1485,7 @@ class Scheduler(
|
|||||||
recv_req.input_ids,
|
recv_req.input_ids,
|
||||||
recv_req.sampling_params,
|
recv_req.sampling_params,
|
||||||
token_type_ids=recv_req.token_type_ids,
|
token_type_ids=recv_req.token_type_ids,
|
||||||
|
priority=recv_req.priority,
|
||||||
)
|
)
|
||||||
req.tokenizer = self.tokenizer
|
req.tokenizer = self.tokenizer
|
||||||
|
|
||||||
@@ -1680,6 +1746,10 @@ class Scheduler(
|
|||||||
if self.grammar_queue:
|
if self.grammar_queue:
|
||||||
self.move_ready_grammar_requests()
|
self.move_ready_grammar_requests()
|
||||||
|
|
||||||
|
if self.try_preemption:
|
||||||
|
# Reset batch_is_full to try preemption with a prefill adder.
|
||||||
|
self.running_batch.batch_is_full = False
|
||||||
|
|
||||||
# Handle the cases where prefill is not allowed
|
# Handle the cases where prefill is not allowed
|
||||||
if (
|
if (
|
||||||
self.running_batch.batch_is_full or len(self.waiting_queue) == 0
|
self.running_batch.batch_is_full or len(self.waiting_queue) == 0
|
||||||
@@ -1692,7 +1762,11 @@ class Scheduler(
|
|||||||
# as the space for the chunked request has just been released.
|
# as the space for the chunked request has just been released.
|
||||||
# In PP case, a chunked req can start in one microbatch and end in another microbatch, so the max_running_requests per microbatch should not be strict.
|
# In PP case, a chunked req can start in one microbatch and end in another microbatch, so the max_running_requests per microbatch should not be strict.
|
||||||
# Instead, we should always allow chunked request to be added, otherwise, there will be a memory leak.
|
# Instead, we should always allow chunked request to be added, otherwise, there will be a memory leak.
|
||||||
if self.get_num_allocatable_reqs(running_bs) <= 0 and not self.chunked_req:
|
if (
|
||||||
|
self.get_num_allocatable_reqs(running_bs) <= 0
|
||||||
|
and not self.chunked_req
|
||||||
|
and not self.try_preemption
|
||||||
|
):
|
||||||
self.running_batch.batch_is_full = True
|
self.running_batch.batch_is_full = True
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@@ -1712,6 +1786,7 @@ class Scheduler(
|
|||||||
self.max_prefill_tokens,
|
self.max_prefill_tokens,
|
||||||
self.chunked_prefill_size,
|
self.chunked_prefill_size,
|
||||||
running_bs if self.is_mixed_chunk else 0,
|
running_bs if self.is_mixed_chunk else 0,
|
||||||
|
self.priority_scheduling_preemption_threshold,
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.chunked_req is not None:
|
if self.chunked_req is not None:
|
||||||
@@ -1732,15 +1807,19 @@ class Scheduler(
|
|||||||
self.running_batch.batch_is_full = True
|
self.running_batch.batch_is_full = True
|
||||||
break
|
break
|
||||||
|
|
||||||
|
running_bs = len(self.running_batch.reqs) - len(adder.preempt_list)
|
||||||
if len(adder.can_run_list) >= self.get_num_allocatable_reqs(running_bs):
|
if len(adder.can_run_list) >= self.get_num_allocatable_reqs(running_bs):
|
||||||
self.running_batch.batch_is_full = True
|
self.running_batch.batch_is_full = True
|
||||||
break
|
|
||||||
|
|
||||||
if self.disaggregation_mode == DisaggregationMode.PREFILL:
|
if self.disaggregation_mode == DisaggregationMode.PREFILL:
|
||||||
# In prefill mode, prealloc queue and transfer queue can also take memory,
|
# In prefill mode, prealloc queue and transfer queue can also take memory,
|
||||||
# so we need to check if the available size for the actual available size.
|
# so we need to check if the available size for the actual available size.
|
||||||
if len(adder.can_run_list) >= self.req_to_token_pool.available_size():
|
if len(adder.can_run_list) >= self.req_to_token_pool.available_size():
|
||||||
self.running_batch.batch_is_full = True
|
self.running_batch.batch_is_full = True
|
||||||
|
|
||||||
|
if self.running_batch.batch_is_full:
|
||||||
|
if not self.try_preemption:
|
||||||
|
break
|
||||||
|
if not adder.preempt_to_schedule(req, self.server_args):
|
||||||
break
|
break
|
||||||
|
|
||||||
if self.enable_hicache_storage:
|
if self.enable_hicache_storage:
|
||||||
@@ -1777,6 +1856,8 @@ class Scheduler(
|
|||||||
self.waiting_queue = [
|
self.waiting_queue = [
|
||||||
x for x in self.waiting_queue if x not in set(can_run_list)
|
x for x in self.waiting_queue if x not in set(can_run_list)
|
||||||
]
|
]
|
||||||
|
if adder.preempt_list:
|
||||||
|
self._extend_requests_to_queue(adder.preempt_list)
|
||||||
|
|
||||||
if adder.new_chunked_req is not None:
|
if adder.new_chunked_req is not None:
|
||||||
assert self.chunked_req is None
|
assert self.chunked_req is None
|
||||||
|
|||||||
@@ -738,6 +738,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
|
|||||||
custom_logit_processor=obj.custom_logit_processor,
|
custom_logit_processor=obj.custom_logit_processor,
|
||||||
return_hidden_states=obj.return_hidden_states,
|
return_hidden_states=obj.return_hidden_states,
|
||||||
data_parallel_rank=obj.data_parallel_rank,
|
data_parallel_rank=obj.data_parallel_rank,
|
||||||
|
priority=obj.priority,
|
||||||
)
|
)
|
||||||
elif isinstance(obj, EmbeddingReqInput):
|
elif isinstance(obj, EmbeddingReqInput):
|
||||||
tokenized_obj = TokenizedEmbeddingReqInput(
|
tokenized_obj = TokenizedEmbeddingReqInput(
|
||||||
@@ -747,6 +748,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
|
|||||||
mm_inputs,
|
mm_inputs,
|
||||||
token_type_ids,
|
token_type_ids,
|
||||||
sampling_params,
|
sampling_params,
|
||||||
|
priority=obj.priority,
|
||||||
)
|
)
|
||||||
|
|
||||||
return tokenized_obj
|
return tokenized_obj
|
||||||
|
|||||||
@@ -149,8 +149,8 @@ class TpModelWorker:
|
|||||||
assert self.max_running_requests > 0, "max_running_request is zero"
|
assert self.max_running_requests > 0, "max_running_request is zero"
|
||||||
self.max_queued_requests = server_args.max_queued_requests
|
self.max_queued_requests = server_args.max_queued_requests
|
||||||
assert (
|
assert (
|
||||||
self.max_queued_requests > 0
|
self.max_queued_requests is None or self.max_queued_requests >= 1
|
||||||
), "max_queued_requests is zero. We need to be at least 1 to schedule a request."
|
), "If configured, max_queued_requests must be at least 1 for any work to be scheduled."
|
||||||
self.max_req_len = min(
|
self.max_req_len = min(
|
||||||
self.model_config.context_len - 1,
|
self.model_config.context_len - 1,
|
||||||
self.max_total_num_tokens - 1,
|
self.max_total_num_tokens - 1,
|
||||||
|
|||||||
@@ -172,11 +172,14 @@ class ServerArgs:
|
|||||||
# Memory and scheduling
|
# Memory and scheduling
|
||||||
mem_fraction_static: Optional[float] = None
|
mem_fraction_static: Optional[float] = None
|
||||||
max_running_requests: Optional[int] = None
|
max_running_requests: Optional[int] = None
|
||||||
max_queued_requests: Optional[int] = sys.maxsize
|
max_queued_requests: Optional[int] = None
|
||||||
max_total_tokens: Optional[int] = None
|
max_total_tokens: Optional[int] = None
|
||||||
chunked_prefill_size: Optional[int] = None
|
chunked_prefill_size: Optional[int] = None
|
||||||
max_prefill_tokens: int = 16384
|
max_prefill_tokens: int = 16384
|
||||||
schedule_policy: str = "fcfs"
|
schedule_policy: str = "fcfs"
|
||||||
|
enable_priority_scheduling: bool = False
|
||||||
|
schedule_low_priority_values_first: bool = False
|
||||||
|
priority_scheduling_preemption_threshold: int = 10
|
||||||
schedule_conservativeness: float = 1.0
|
schedule_conservativeness: float = 1.0
|
||||||
page_size: Optional[int] = None
|
page_size: Optional[int] = None
|
||||||
hybrid_kvcache_ratio: Optional[float] = None
|
hybrid_kvcache_ratio: Optional[float] = None
|
||||||
@@ -1166,6 +1169,24 @@ class ServerArgs:
|
|||||||
choices=["lpm", "random", "fcfs", "dfs-weight", "lof", "priority"],
|
choices=["lpm", "random", "fcfs", "dfs-weight", "lof", "priority"],
|
||||||
help="The scheduling policy of the requests.",
|
help="The scheduling policy of the requests.",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--enable-priority-scheduling",
|
||||||
|
action="store_true",
|
||||||
|
default=ServerArgs.enable_priority_scheduling,
|
||||||
|
help="Enable priority scheduling. Requests with higher priority integer values will be scheduled first by default.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--schedule-low-priority-values-first",
|
||||||
|
action="store_true",
|
||||||
|
default=ServerArgs.schedule_low_priority_values_first,
|
||||||
|
help="If specified with --enable-priority-scheduling, the scheduler will schedule requests with lower priority integer values first.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--priority-scheduling-preemption-threshold",
|
||||||
|
type=int,
|
||||||
|
default=ServerArgs.priority_scheduling_preemption_threshold,
|
||||||
|
help="Minimum difference in priorities for an incoming request to have to preempt running request(s).",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--schedule-conservativeness",
|
"--schedule-conservativeness",
|
||||||
type=float,
|
type=float,
|
||||||
@@ -2455,6 +2476,13 @@ class ServerArgs:
|
|||||||
"--generation-tokens-buckets", self.generation_tokens_buckets
|
"--generation-tokens-buckets", self.generation_tokens_buckets
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Check scheduling policy
|
||||||
|
if self.enable_priority_scheduling:
|
||||||
|
assert self.schedule_policy in [
|
||||||
|
"fcfs",
|
||||||
|
"lof",
|
||||||
|
], f"To use priority scheduling, schedule_policy must be 'fcfs' or 'lof'. '{self.schedule_policy}' is not supported."
|
||||||
|
|
||||||
def check_lora_server_args(self):
|
def check_lora_server_args(self):
|
||||||
assert self.max_loras_per_batch > 0, "max_loras_per_batch must be positive"
|
assert self.max_loras_per_batch > 0, "max_loras_per_batch must be positive"
|
||||||
|
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ from dataclasses import dataclass
|
|||||||
from functools import partial
|
from functools import partial
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from types import SimpleNamespace
|
from types import SimpleNamespace
|
||||||
from typing import Awaitable, Callable, List, Optional, Tuple
|
from typing import Any, Awaitable, Callable, List, Optional, Tuple
|
||||||
|
|
||||||
import aiohttp
|
import aiohttp
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -1390,6 +1390,41 @@ async def send_concurrent_generate_requests(
|
|||||||
return await asyncio.gather(*tasks)
|
return await asyncio.gather(*tasks)
|
||||||
|
|
||||||
|
|
||||||
|
async def send_concurrent_generate_requests_with_custom_params(
|
||||||
|
base_url: str,
|
||||||
|
custom_params: List[dict[str, Any]],
|
||||||
|
) -> Tuple[int, Any]:
|
||||||
|
"""Sends generate request concurrently with custom parameters and returns status code and response json tuple. Max concurrency is num_requests."""
|
||||||
|
|
||||||
|
base_payload = {
|
||||||
|
"text": """
|
||||||
|
System: You are a helpful assistant.
|
||||||
|
User: What is the capital of France?
|
||||||
|
Assistant: The capital of France is
|
||||||
|
""",
|
||||||
|
"sampling_params": {
|
||||||
|
"temperature": 0,
|
||||||
|
"max_new_tokens": 50,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
async def async_generate_with_priority(req):
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.post(
|
||||||
|
f"{base_url}/generate",
|
||||||
|
json=req,
|
||||||
|
) as response:
|
||||||
|
resp_json = await response.json()
|
||||||
|
return (response.status, resp_json)
|
||||||
|
|
||||||
|
tasks = []
|
||||||
|
for c in custom_params:
|
||||||
|
req = base_payload.copy()
|
||||||
|
req.update(c)
|
||||||
|
tasks.append(asyncio.create_task(async_generate_with_priority(req)))
|
||||||
|
return await asyncio.gather(*tasks)
|
||||||
|
|
||||||
|
|
||||||
class CustomTestCase(unittest.TestCase):
|
class CustomTestCase(unittest.TestCase):
|
||||||
def _callTestMethod(self, method):
|
def _callTestMethod(self, method):
|
||||||
max_retry = int(
|
max_retry = int(
|
||||||
|
|||||||
@@ -95,6 +95,7 @@ suites = {
|
|||||||
TestFile("test_original_logprobs.py", 200),
|
TestFile("test_original_logprobs.py", 200),
|
||||||
TestFile("test_penalty.py", 41),
|
TestFile("test_penalty.py", 41),
|
||||||
TestFile("test_page_size.py", 60),
|
TestFile("test_page_size.py", 60),
|
||||||
|
TestFile("test_priority_scheduling.py", 100),
|
||||||
TestFile("test_pytorch_sampling_backend.py", 66),
|
TestFile("test_pytorch_sampling_backend.py", 66),
|
||||||
TestFile("test_radix_attention.py", 105),
|
TestFile("test_radix_attention.py", 105),
|
||||||
TestFile("test_regex_constrained.py", 64),
|
TestFile("test_regex_constrained.py", 64),
|
||||||
|
|||||||
339
test/srt/test_priority_scheduling.py
Normal file
339
test/srt/test_priority_scheduling.py
Normal file
@@ -0,0 +1,339 @@
|
|||||||
|
import asyncio
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
import unittest
|
||||||
|
from typing import Any, Awaitable, Callable, List, Optional, Tuple
|
||||||
|
|
||||||
|
from sglang.srt.utils import kill_process_tree
|
||||||
|
from sglang.test.test_utils import (
|
||||||
|
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
|
||||||
|
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||||
|
DEFAULT_URL_FOR_TEST,
|
||||||
|
STDERR_FILENAME,
|
||||||
|
STDOUT_FILENAME,
|
||||||
|
CustomTestCase,
|
||||||
|
popen_launch_server,
|
||||||
|
send_concurrent_generate_requests_with_custom_params,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestPriorityScheduling(CustomTestCase):
|
||||||
|
@classmethod
|
||||||
|
def setUpClass(cls):
|
||||||
|
cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
|
||||||
|
cls.base_url = DEFAULT_URL_FOR_TEST
|
||||||
|
|
||||||
|
cls.stdout = open(STDOUT_FILENAME, "w")
|
||||||
|
cls.stderr = open(STDERR_FILENAME, "w")
|
||||||
|
|
||||||
|
cls.base_url = DEFAULT_URL_FOR_TEST
|
||||||
|
cls.process = popen_launch_server(
|
||||||
|
cls.model,
|
||||||
|
cls.base_url,
|
||||||
|
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||||
|
other_args=(
|
||||||
|
"--max-running-requests", # Enforce max request concurrency is 1
|
||||||
|
"1",
|
||||||
|
"--max-queued-requests", # Enforce max queued request number is 3
|
||||||
|
"3",
|
||||||
|
"--enable-priority-scheduling", # Enable priority scheduling
|
||||||
|
),
|
||||||
|
return_stdout_stderr=(cls.stdout, cls.stderr),
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def tearDownClass(cls):
|
||||||
|
kill_process_tree(cls.process.pid)
|
||||||
|
_verify_max_running_requests_and_max_queued_request_validation(1, 3)
|
||||||
|
cls.stdout.close()
|
||||||
|
cls.stderr.close()
|
||||||
|
os.remove(STDOUT_FILENAME)
|
||||||
|
os.remove(STDERR_FILENAME)
|
||||||
|
|
||||||
|
def test_priority_scheduling_request_ordering_validation(self):
|
||||||
|
"""Verify pending requests are ordered by priority and received timestamp."""
|
||||||
|
|
||||||
|
responses = asyncio.run(
|
||||||
|
send_concurrent_generate_requests_with_custom_params(
|
||||||
|
self.base_url,
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"priority": 0,
|
||||||
|
"sampling_params": {"max_new_tokens": 10000},
|
||||||
|
}, # starts being processed first
|
||||||
|
{"priority": 1}, # third
|
||||||
|
{"priority": 1}, # fourth
|
||||||
|
{"priority": 2}, # second
|
||||||
|
],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
expected_status_and_error_messages = [
|
||||||
|
(200, None),
|
||||||
|
(200, None),
|
||||||
|
(200, None),
|
||||||
|
(200, None),
|
||||||
|
]
|
||||||
|
|
||||||
|
e2e_latencies = []
|
||||||
|
_verify_genereate_responses(
|
||||||
|
responses, expected_status_and_error_messages, e2e_latencies
|
||||||
|
)
|
||||||
|
assert e2e_latencies[0] < e2e_latencies[3] < e2e_latencies[1] < e2e_latencies[2]
|
||||||
|
|
||||||
|
def test_priority_scheduling_existing_requests_abortion_validation(self):
|
||||||
|
"""Verify lower priority requests are aborted when incoming requests have higher priority"""
|
||||||
|
|
||||||
|
responses = asyncio.run(
|
||||||
|
send_concurrent_generate_requests_with_custom_params(
|
||||||
|
self.base_url,
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"priority": 1,
|
||||||
|
"sampling_params": {"max_new_tokens": 10000},
|
||||||
|
}, # starts being processed first and holds the running queue capacity
|
||||||
|
{"priority": 2}, # aborted by request 5
|
||||||
|
{"priority": 3}, # aborted by request 6
|
||||||
|
{"priority": 4}, # aborted by request 7
|
||||||
|
{"priority": 5}, # fourth
|
||||||
|
{"priority": 6}, # third
|
||||||
|
{"priority": 7}, # second
|
||||||
|
],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
expected_status_and_error_messages = [
|
||||||
|
(200, None),
|
||||||
|
(503, "The request is aborted by a higher priority request."),
|
||||||
|
(503, "The request is aborted by a higher priority request."),
|
||||||
|
(503, "The request is aborted by a higher priority request."),
|
||||||
|
(200, None),
|
||||||
|
(200, None),
|
||||||
|
(200, None),
|
||||||
|
]
|
||||||
|
|
||||||
|
e2e_latencies = []
|
||||||
|
_verify_genereate_responses(
|
||||||
|
responses, expected_status_and_error_messages, e2e_latencies
|
||||||
|
)
|
||||||
|
assert e2e_latencies[0] < e2e_latencies[6] < e2e_latencies[5] < e2e_latencies[4]
|
||||||
|
|
||||||
|
def test_priority_scheduling_incoming_request_rejection_validation(self):
|
||||||
|
"""Verify incoming requests are rejected when existing requests have higher priority"""
|
||||||
|
|
||||||
|
responses = asyncio.run(
|
||||||
|
send_concurrent_generate_requests_with_custom_params(
|
||||||
|
self.base_url,
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"priority": 7,
|
||||||
|
"sampling_params": {"max_new_tokens": 10000},
|
||||||
|
}, # starts being processed first and holds the running queue capacity
|
||||||
|
{"priority": 6}, # second
|
||||||
|
{"priority": 5}, # third
|
||||||
|
{"priority": 4}, # fourth
|
||||||
|
{"priority": 3}, # rejected
|
||||||
|
{"priority": 2}, # rejected
|
||||||
|
{"priority": 1}, # rejected
|
||||||
|
],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
expected_status_and_error_messages = [
|
||||||
|
(200, None),
|
||||||
|
(200, None),
|
||||||
|
(200, None),
|
||||||
|
(200, None),
|
||||||
|
(503, "The request queue is full."),
|
||||||
|
(503, "The request queue is full."),
|
||||||
|
(503, "The request queue is full."),
|
||||||
|
]
|
||||||
|
|
||||||
|
e2e_latencies = []
|
||||||
|
_verify_genereate_responses(
|
||||||
|
responses, expected_status_and_error_messages, e2e_latencies
|
||||||
|
)
|
||||||
|
assert e2e_latencies[0] < e2e_latencies[1] < e2e_latencies[2] < e2e_latencies[3]
|
||||||
|
|
||||||
|
def test_priority_scheduling_preemption_meeting_threshold_validation(self):
|
||||||
|
"""Verify running requests are preempted by requests with priorities meeting the preemption threshold"""
|
||||||
|
|
||||||
|
responses = asyncio.run(
|
||||||
|
send_concurrent_generate_requests_with_custom_params(
|
||||||
|
self.base_url,
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"priority": 0,
|
||||||
|
"sampling_params": {"max_new_tokens": 10000},
|
||||||
|
}, # starts being processed first then preempted or pushed by later requests, and finishes last.
|
||||||
|
{
|
||||||
|
"priority": 10,
|
||||||
|
"sampling_params": {"max_new_tokens": 10000},
|
||||||
|
}, # scheduled after the third request, and finishes second.
|
||||||
|
{
|
||||||
|
"priority": 20,
|
||||||
|
"sampling_params": {"max_new_tokens": 10000},
|
||||||
|
}, # finishes first.
|
||||||
|
],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
expected_status_and_error_messages = [
|
||||||
|
(200, None),
|
||||||
|
(200, None),
|
||||||
|
(200, None),
|
||||||
|
]
|
||||||
|
|
||||||
|
e2e_latencies = []
|
||||||
|
_verify_genereate_responses(
|
||||||
|
responses, expected_status_and_error_messages, e2e_latencies
|
||||||
|
)
|
||||||
|
|
||||||
|
assert e2e_latencies[2] < e2e_latencies[1] < e2e_latencies[0]
|
||||||
|
|
||||||
|
def test_priority_scheduling_preemption_below_threshold_validation(self):
|
||||||
|
"""Verify running requests are not preempted by requests with priorities below preemption threshold"""
|
||||||
|
|
||||||
|
responses = asyncio.run(
|
||||||
|
send_concurrent_generate_requests_with_custom_params(
|
||||||
|
self.base_url,
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"priority": 0,
|
||||||
|
"sampling_params": {"max_new_tokens": 10000},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"priority": 5,
|
||||||
|
"sampling_params": {"max_new_tokens": 10000},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
expected_status_and_error_messages = [
|
||||||
|
(200, None),
|
||||||
|
(200, None),
|
||||||
|
]
|
||||||
|
|
||||||
|
e2e_latencies = []
|
||||||
|
_verify_genereate_responses(
|
||||||
|
responses, expected_status_and_error_messages, e2e_latencies
|
||||||
|
)
|
||||||
|
|
||||||
|
assert e2e_latencies[0] < e2e_latencies[1]
|
||||||
|
|
||||||
|
|
||||||
|
class TestPrioritySchedulingMultipleRunningRequests(CustomTestCase):
|
||||||
|
@classmethod
|
||||||
|
def setUpClass(cls):
|
||||||
|
cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
|
||||||
|
cls.base_url = DEFAULT_URL_FOR_TEST
|
||||||
|
|
||||||
|
cls.stdout = open(STDOUT_FILENAME, "w")
|
||||||
|
cls.stderr = open(STDERR_FILENAME, "w")
|
||||||
|
|
||||||
|
cls.base_url = DEFAULT_URL_FOR_TEST
|
||||||
|
cls.process = popen_launch_server(
|
||||||
|
cls.model,
|
||||||
|
cls.base_url,
|
||||||
|
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||||
|
other_args=(
|
||||||
|
"--max-running-requests", # Enforce max request concurrency is 2
|
||||||
|
"2",
|
||||||
|
"--max-queued-requests", # Enforce max queued request number is 3
|
||||||
|
"3",
|
||||||
|
"--enable-priority-scheduling", # Enable priority scheduling
|
||||||
|
),
|
||||||
|
return_stdout_stderr=(cls.stdout, cls.stderr),
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def tearDownClass(cls):
|
||||||
|
kill_process_tree(cls.process.pid)
|
||||||
|
_verify_max_running_requests_and_max_queued_request_validation(2, 3)
|
||||||
|
cls.stdout.close()
|
||||||
|
cls.stderr.close()
|
||||||
|
os.remove(STDOUT_FILENAME)
|
||||||
|
os.remove(STDERR_FILENAME)
|
||||||
|
|
||||||
|
def test_priority_scheduling_with_multiple_running_requests_preemption(self):
|
||||||
|
"""Verify preempting a subset of running requests is safe."""
|
||||||
|
|
||||||
|
responses = asyncio.run(
|
||||||
|
send_concurrent_generate_requests_with_custom_params(
|
||||||
|
self.base_url,
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"priority": 10,
|
||||||
|
"sampling_params": {"max_new_tokens": 10000},
|
||||||
|
}, # finishes first
|
||||||
|
{
|
||||||
|
"priority": 5,
|
||||||
|
"sampling_params": {"max_new_tokens": 10000},
|
||||||
|
}, # preempted by fourth request, then finishes third
|
||||||
|
{
|
||||||
|
"priority": 15,
|
||||||
|
"sampling_params": {"max_new_tokens": 10000},
|
||||||
|
}, # preempt the first request
|
||||||
|
],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
expected_status_and_error_messages = [
|
||||||
|
(200, None),
|
||||||
|
(200, None),
|
||||||
|
(200, None),
|
||||||
|
(200, None),
|
||||||
|
]
|
||||||
|
|
||||||
|
_verify_genereate_responses(responses, expected_status_and_error_messages, [])
|
||||||
|
|
||||||
|
|
||||||
|
def _verify_genereate_responses(
|
||||||
|
responses: Tuple[int, Any, float],
|
||||||
|
expected_code_and_error_message: Tuple[int, Any],
|
||||||
|
e2e_latencies: List[Optional[float]],
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Verify generate response results are as expected based on status code and response json object content.
|
||||||
|
In addition, collects e2e latency info to verify scheduling and processing ordering.
|
||||||
|
"""
|
||||||
|
for got, expected in zip(responses, expected_code_and_error_message):
|
||||||
|
got_status, got_json = got
|
||||||
|
expected_status, expected_err_msg = expected
|
||||||
|
|
||||||
|
# Check status code is as expected
|
||||||
|
assert got_status == expected_status
|
||||||
|
|
||||||
|
# Check error message content or fields' existence based on status code
|
||||||
|
if got_status != 200:
|
||||||
|
assert got_json["object"] == "error"
|
||||||
|
assert got_json["message"] == expected_err_msg
|
||||||
|
else:
|
||||||
|
assert "object" not in got_json
|
||||||
|
assert "message" not in got_json
|
||||||
|
|
||||||
|
# Collect e2e latencies for scheduling validation
|
||||||
|
e2e_latencies.append(
|
||||||
|
got_json["meta_info"]["e2e_latency"] if got_status == 200 else None
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _verify_max_running_requests_and_max_queued_request_validation(
|
||||||
|
max_running_requests: int, max_queued_requests: int
|
||||||
|
):
|
||||||
|
"""Verify running request and queued request numbers based on server logs."""
|
||||||
|
rr_pattern = re.compile(r"#running-req:\s*(\d+)")
|
||||||
|
qr_pattern = re.compile(r"#queue-req:\s*(\d+)")
|
||||||
|
|
||||||
|
with open(STDERR_FILENAME) as lines:
|
||||||
|
for line in lines:
|
||||||
|
rr_match, qr_match = rr_pattern.search(line), qr_pattern.search(line)
|
||||||
|
if rr_match:
|
||||||
|
assert int(rr_match.group(1)) <= max_running_requests
|
||||||
|
if qr_match:
|
||||||
|
assert int(qr_match.group(1)) <= max_queued_requests
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
@@ -65,9 +65,8 @@ class TestMaxQueuedRequests(CustomTestCase):
|
|||||||
send_concurrent_generate_requests(self.base_url, num_requests=10)
|
send_concurrent_generate_requests(self.base_url, num_requests=10)
|
||||||
)
|
)
|
||||||
|
|
||||||
assert 200 in status_codes
|
expected_status_codes = [200, 200, 503, 503, 503, 503, 503, 503, 503, 503]
|
||||||
assert 503 in status_codes
|
assert status_codes == expected_status_codes
|
||||||
assert all(status_code in [200, 503] for status_code in status_codes)
|
|
||||||
|
|
||||||
def test_max_running_requests_and_max_queued_request_validation(self):
|
def test_max_running_requests_and_max_queued_request_validation(self):
|
||||||
"""Verify running request and queued request numbers based on server logs."""
|
"""Verify running request and queued request numbers based on server logs."""
|
||||||
|
|||||||
@@ -18,13 +18,21 @@ class TestSchedulePolicy(CustomTestCase):
|
|||||||
|
|
||||||
def test_init_with_cache_aware_policy(self):
|
def test_init_with_cache_aware_policy(self):
|
||||||
policy = SchedulePolicy(
|
policy = SchedulePolicy(
|
||||||
policy="lpm", tree_cache=self.tree_cache, enable_hierarchical_cache=True
|
policy="lpm",
|
||||||
|
tree_cache=self.tree_cache,
|
||||||
|
enable_hierarchical_cache=True,
|
||||||
|
enable_priority_scheduling=False,
|
||||||
|
schedule_low_priority_values_first=False,
|
||||||
)
|
)
|
||||||
self.assertEqual(policy.policy, CacheAwarePolicy.LPM)
|
self.assertEqual(policy.policy, CacheAwarePolicy.LPM)
|
||||||
|
|
||||||
def test_init_with_cache_agnostic_policy(self):
|
def test_init_with_cache_agnostic_policy(self):
|
||||||
policy = SchedulePolicy(
|
policy = SchedulePolicy(
|
||||||
policy="fcfs", tree_cache=self.tree_cache, enable_hierarchical_cache=True
|
policy="fcfs",
|
||||||
|
tree_cache=self.tree_cache,
|
||||||
|
enable_hierarchical_cache=True,
|
||||||
|
enable_priority_scheduling=False,
|
||||||
|
schedule_low_priority_values_first=False,
|
||||||
)
|
)
|
||||||
self.assertEqual(policy.policy, CacheAgnosticPolicy.FCFS)
|
self.assertEqual(policy.policy, CacheAgnosticPolicy.FCFS)
|
||||||
|
|
||||||
@@ -34,12 +42,18 @@ class TestSchedulePolicy(CustomTestCase):
|
|||||||
policy="invalid",
|
policy="invalid",
|
||||||
tree_cache=self.tree_cache,
|
tree_cache=self.tree_cache,
|
||||||
enable_hierarchical_cache=True,
|
enable_hierarchical_cache=True,
|
||||||
|
enable_priority_scheduling=False,
|
||||||
|
schedule_low_priority_values_first=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_init_with_disabled_cache(self):
|
def test_init_with_disabled_cache(self):
|
||||||
disabled_tree_cache = RadixCache(None, None, disable=True, page_size=1)
|
disabled_tree_cache = RadixCache(None, None, disable=True, page_size=1)
|
||||||
policy = SchedulePolicy(
|
policy = SchedulePolicy(
|
||||||
policy="lpm", tree_cache=disabled_tree_cache, enable_hierarchical_cache=True
|
policy="lpm",
|
||||||
|
tree_cache=disabled_tree_cache,
|
||||||
|
enable_hierarchical_cache=True,
|
||||||
|
enable_priority_scheduling=False,
|
||||||
|
schedule_low_priority_values_first=False,
|
||||||
)
|
)
|
||||||
self.assertEqual(policy.policy, CacheAgnosticPolicy.FCFS)
|
self.assertEqual(policy.policy, CacheAgnosticPolicy.FCFS)
|
||||||
|
|
||||||
@@ -52,7 +66,11 @@ class TestSchedulePolicy(CustomTestCase):
|
|||||||
]
|
]
|
||||||
|
|
||||||
policy = SchedulePolicy(
|
policy = SchedulePolicy(
|
||||||
policy="fcfs", tree_cache=tree_cache, enable_hierarchical_cache=True
|
policy="fcfs",
|
||||||
|
tree_cache=tree_cache,
|
||||||
|
enable_hierarchical_cache=True,
|
||||||
|
enable_priority_scheduling=False,
|
||||||
|
schedule_low_priority_values_first=False,
|
||||||
)
|
)
|
||||||
policy.calc_priority(waiting_queue)
|
policy.calc_priority(waiting_queue)
|
||||||
# Check if FCFS keeps the original order
|
# Check if FCFS keeps the original order
|
||||||
@@ -60,6 +78,126 @@ class TestSchedulePolicy(CustomTestCase):
|
|||||||
self.assertEqual(waiting_queue[1].rid, 3)
|
self.assertEqual(waiting_queue[1].rid, 3)
|
||||||
self.assertEqual(waiting_queue[2].rid, 2)
|
self.assertEqual(waiting_queue[2].rid, 2)
|
||||||
|
|
||||||
|
def test_calc_priority_priority_enabled_fcfs_scheduling(self):
|
||||||
|
tree_cache = RadixCache(None, None, False)
|
||||||
|
|
||||||
|
waiting_queue = [
|
||||||
|
Req(1, "a b", [1, 2], SamplingParams()),
|
||||||
|
Req(3, "a b c", [1, 2, 3], SamplingParams()),
|
||||||
|
Req(2, "a", [1], SamplingParams()),
|
||||||
|
]
|
||||||
|
waiting_queue[0].priority, waiting_queue[0].queue_time_start = 1, 1
|
||||||
|
waiting_queue[1].priority, waiting_queue[1].queue_time_start = 0, 1
|
||||||
|
waiting_queue[2].priority, waiting_queue[2].queue_time_start = 0, 0
|
||||||
|
|
||||||
|
policy = SchedulePolicy(
|
||||||
|
policy="fcfs",
|
||||||
|
tree_cache=tree_cache,
|
||||||
|
enable_hierarchical_cache=True,
|
||||||
|
enable_priority_scheduling=True,
|
||||||
|
schedule_low_priority_values_first=False,
|
||||||
|
)
|
||||||
|
policy.calc_priority(waiting_queue)
|
||||||
|
# Check if priority enabled fcfs ordering is applied.
|
||||||
|
self.assertEqual(waiting_queue[0].rid, 1)
|
||||||
|
self.assertEqual(waiting_queue[1].rid, 2)
|
||||||
|
self.assertEqual(waiting_queue[2].rid, 3)
|
||||||
|
|
||||||
|
def test_calc_priority_priority_enabled_fcfs_scheduling_with_low_priority_values_first(
|
||||||
|
self,
|
||||||
|
):
|
||||||
|
tree_cache = RadixCache(None, None, False)
|
||||||
|
|
||||||
|
waiting_queue = [
|
||||||
|
Req(1, "a b", [1, 2], SamplingParams()),
|
||||||
|
Req(3, "a b c", [1, 2, 3], SamplingParams()),
|
||||||
|
Req(2, "a", [1], SamplingParams()),
|
||||||
|
]
|
||||||
|
waiting_queue[0].priority, waiting_queue[0].queue_time_start = -1, 0
|
||||||
|
waiting_queue[1].priority, waiting_queue[1].queue_time_start = 0, 1
|
||||||
|
waiting_queue[2].priority, waiting_queue[2].queue_time_start = 0, 0
|
||||||
|
|
||||||
|
policy = SchedulePolicy(
|
||||||
|
policy="fcfs",
|
||||||
|
tree_cache=tree_cache,
|
||||||
|
enable_hierarchical_cache=True,
|
||||||
|
enable_priority_scheduling=True,
|
||||||
|
schedule_low_priority_values_first=True,
|
||||||
|
)
|
||||||
|
policy.calc_priority(waiting_queue)
|
||||||
|
# Check if priority enabled fcfs ordering is applied.
|
||||||
|
self.assertEqual(waiting_queue[0].rid, 1)
|
||||||
|
self.assertEqual(waiting_queue[1].rid, 2)
|
||||||
|
self.assertEqual(waiting_queue[2].rid, 3)
|
||||||
|
|
||||||
|
def test_calc_priority_longest_output_first_scheduling(self):
|
||||||
|
tree_cache = RadixCache(None, None, False)
|
||||||
|
|
||||||
|
waiting_queue = [
|
||||||
|
Req(1, "a b", [1, 2], SamplingParams(max_new_tokens=1000)),
|
||||||
|
Req(3, "a b c", [1, 2, 3], SamplingParams(max_new_tokens=10)),
|
||||||
|
Req(2, "a", [1], SamplingParams(max_new_tokens=100)),
|
||||||
|
]
|
||||||
|
|
||||||
|
policy = SchedulePolicy(
|
||||||
|
policy="lof",
|
||||||
|
tree_cache=tree_cache,
|
||||||
|
enable_hierarchical_cache=True,
|
||||||
|
enable_priority_scheduling=False,
|
||||||
|
schedule_low_priority_values_first=False,
|
||||||
|
)
|
||||||
|
policy.calc_priority(waiting_queue)
|
||||||
|
# Check if priority enabled fcfs ordering is applied.
|
||||||
|
self.assertEqual(waiting_queue[0].rid, 1)
|
||||||
|
self.assertEqual(waiting_queue[1].rid, 2)
|
||||||
|
self.assertEqual(waiting_queue[2].rid, 3)
|
||||||
|
|
||||||
|
def test_calc_priority_priority_enabled_longest_output_first_scheduling(self):
|
||||||
|
tree_cache = RadixCache(None, None, False)
|
||||||
|
|
||||||
|
waiting_queue = [
|
||||||
|
Req(1, "a b", [1, 2], SamplingParams(max_new_tokens=1), priority=1),
|
||||||
|
Req(3, "a b c", [1, 2, 3], SamplingParams(max_new_tokens=10), priority=0),
|
||||||
|
Req(2, "a", [1], SamplingParams(max_new_tokens=100), priority=0),
|
||||||
|
]
|
||||||
|
|
||||||
|
policy = SchedulePolicy(
|
||||||
|
policy="lof",
|
||||||
|
tree_cache=tree_cache,
|
||||||
|
enable_hierarchical_cache=True,
|
||||||
|
enable_priority_scheduling=True,
|
||||||
|
schedule_low_priority_values_first=False,
|
||||||
|
)
|
||||||
|
policy.calc_priority(waiting_queue)
|
||||||
|
# Check if priority enabled fcfs ordering is applied.
|
||||||
|
self.assertEqual(waiting_queue[0].rid, 1)
|
||||||
|
self.assertEqual(waiting_queue[1].rid, 2)
|
||||||
|
self.assertEqual(waiting_queue[2].rid, 3)
|
||||||
|
|
||||||
|
def test_calc_priority_priority_enabled_longest_output_first_scheduling_with_low_priority_values_first(
|
||||||
|
self,
|
||||||
|
):
|
||||||
|
tree_cache = RadixCache(None, None, False)
|
||||||
|
|
||||||
|
waiting_queue = [
|
||||||
|
Req(1, "a b", [1, 2], SamplingParams(max_new_tokens=1), priority=0),
|
||||||
|
Req(3, "a b c", [1, 2, 3], SamplingParams(max_new_tokens=10), priority=1),
|
||||||
|
Req(2, "a", [1], SamplingParams(max_new_tokens=100), priority=1),
|
||||||
|
]
|
||||||
|
|
||||||
|
policy = SchedulePolicy(
|
||||||
|
policy="lof",
|
||||||
|
tree_cache=tree_cache,
|
||||||
|
enable_hierarchical_cache=True,
|
||||||
|
enable_priority_scheduling=True,
|
||||||
|
schedule_low_priority_values_first=True,
|
||||||
|
)
|
||||||
|
policy.calc_priority(waiting_queue)
|
||||||
|
# Check if priority enabled fcfs ordering is applied.
|
||||||
|
self.assertEqual(waiting_queue[0].rid, 1)
|
||||||
|
self.assertEqual(waiting_queue[1].rid, 2)
|
||||||
|
self.assertEqual(waiting_queue[2].rid, 3)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|||||||
Reference in New Issue
Block a user