feat: add priority-based scheduling with priority-based request acceptance and preemption (#8746)

This commit is contained in:
harrisonlimh
2025-09-16 17:10:10 -07:00
committed by GitHub
parent f949ad5794
commit 14fdd52740
16 changed files with 822 additions and 71 deletions

View File

@@ -228,6 +228,8 @@ class CompletionRequest(BaseModel):
# For request id
rid: Optional[Union[List[str], str]] = None
# Priority for the request
priority: Optional[int] = None
# For customer metric labels
customer_labels: Optional[Dict[str, str]] = None
@@ -543,6 +545,8 @@ class ChatCompletionRequest(BaseModel):
# For request id
rid: Optional[Union[List[str], str]] = None
# Priority for the request
priority: Optional[int] = None
# For PD disaggregation
bootstrap_host: Optional[Union[List[str], str]] = None
@@ -644,6 +648,8 @@ class EmbeddingRequest(BaseModel):
# The request id.
rid: Optional[Union[List[str], str]] = None
# Priority for the request
priority: Optional[int] = None
class EmbeddingObject(BaseModel):

View File

@@ -149,6 +149,7 @@ class OpenAIServingChat(OpenAIServingBase):
bootstrap_room=request.bootstrap_room,
return_hidden_states=request.return_hidden_states,
rid=request.rid,
priority=request.priority,
customer_labels=customer_labels,
)

View File

@@ -107,6 +107,7 @@ class OpenAIServingCompletion(OpenAIServingBase):
bootstrap_room=request.bootstrap_room,
return_hidden_states=request.return_hidden_states,
rid=request.rid,
priority=request.priority,
customer_labels=customer_labels,
)

View File

@@ -125,6 +125,7 @@ class OpenAIServingEmbedding(OpenAIServingBase):
adapted_request = EmbeddingReqInput(
**prompt_kwargs,
rid=request.rid,
priority=request.priority,
)
return adapted_request, request

View File

@@ -570,6 +570,7 @@ class TokenizedGenerateReqInput:
token_ids_logprob: List[int]
# Whether to stream output
stream: bool
# Whether to return hidden states
return_hidden_states: bool = False
@@ -656,6 +657,8 @@ class EmbeddingReqInput:
modalities: Optional[List[str]] = None
# For cross-encoder requests
is_cross_encoder_request: bool = False
# Priority for the request
priority: Optional[int] = None
# For background responses (OpenAI responses API)
background: bool = False
@@ -763,6 +766,8 @@ class TokenizedEmbeddingReqInput:
data_parallel_rank: Optional[int] = None
# For dp balance
dp_balance_id: int = -1
# Priority for the request
priority: Optional[int] = None
@dataclass

View File

@@ -453,6 +453,7 @@ class Req:
bootstrap_room: Optional[int] = None,
data_parallel_rank: Optional[int] = None,
vocab_size: Optional[int] = None,
priority: Optional[int] = None,
metrics_collector: Optional[SchedulerMetricsCollector] = None,
):
# Input and output info
@@ -504,6 +505,7 @@ class Req:
self.stream = stream
self.eos_token_ids = eos_token_ids
self.vocab_size = vocab_size
self.priority = priority
# For incremental decoding
# ----- | --------- read_ids -------|
@@ -1517,37 +1519,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
idx = sorted_indices.pop()
req = self.reqs[idx]
retracted_reqs.append(req)
if server_args.disaggregation_mode == "decode":
req.offload_kv_cache(
self.req_to_token_pool, self.token_to_kv_pool_allocator
)
if isinstance(self.tree_cache, ChunkCache):
# ChunkCache does not have eviction
token_indices = self.req_to_token_pool.req_to_token[
req.req_pool_idx, : seq_lens_cpu[idx]
]
self.token_to_kv_pool_allocator.free(token_indices)
self.req_to_token_pool.free(req.req_pool_idx)
else:
# TODO: apply more fine-grained retraction
last_uncached_pos = (
len(req.prefix_indices) // server_args.page_size
) * server_args.page_size
token_indices = self.req_to_token_pool.req_to_token[
req.req_pool_idx, last_uncached_pos : seq_lens_cpu[idx]
]
self.token_to_kv_pool_allocator.free(token_indices)
self.req_to_token_pool.free(req.req_pool_idx)
# release the last node
if self.is_hybrid:
self.tree_cache.dec_lock_ref(req.last_node, req.swa_uuid_for_lock)
else:
self.tree_cache.dec_lock_ref(req.last_node)
req.reset_for_retract()
self.release_req(idx, len(sorted_indices), server_args)
if len(retracted_reqs) == 0:
# Corner case: only one request left
@@ -1568,6 +1540,44 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
return retracted_reqs, new_estimate_ratio
def release_req(self, idx: int, remaing_req_count: int, server_args: ServerArgs) -> None:
    """Free the KV-cache and request-pool resources held by ``self.reqs[idx]``
    so a retracted/preempted request's memory becomes usable immediately.

    Args:
        idx: Index into ``self.reqs`` of the request to release.
        remaing_req_count: Number of requests that will keep running; used to
            size the proactive tree-cache eviction below. (The name's typo is
            preserved for call-site compatibility.)
        server_args: Server configuration (disaggregation mode, page size).
    """
    req = self.reqs[idx]
    # NOTE(review): this syncs the full seq_lens tensor to CPU on every call;
    # if release_req is invoked once per retracted request, confirm the
    # repeated device-to-host copy is acceptable.
    seq_lens_cpu = self.seq_lens.cpu().numpy()
    if server_args.disaggregation_mode == "decode":
        # Decode-side disaggregation: move the KV cache off-device first.
        req.offload_kv_cache(
            self.req_to_token_pool, self.token_to_kv_pool_allocator
        )
    if isinstance(self.tree_cache, ChunkCache):
        # ChunkCache does not have eviction, so free the request's whole
        # token range directly.
        token_indices = self.req_to_token_pool.req_to_token[
            req.req_pool_idx, : seq_lens_cpu[idx]
        ]
        self.token_to_kv_pool_allocator.free(token_indices)
        self.req_to_token_pool.free(req.req_pool_idx)
    else:
        # TODO: apply more fine-grained retraction
        # Keep the page-aligned cached prefix; free only the uncached suffix.
        last_uncached_pos = (
            len(req.prefix_indices) // server_args.page_size
        ) * server_args.page_size
        token_indices = self.req_to_token_pool.req_to_token[
            req.req_pool_idx, last_uncached_pos : seq_lens_cpu[idx]
        ]
        self.token_to_kv_pool_allocator.free(token_indices)
        self.req_to_token_pool.free(req.req_pool_idx)
        # release the last node so the radix-tree path becomes evictable
        if self.is_hybrid:
            self.tree_cache.dec_lock_ref(req.last_node, req.swa_uuid_for_lock)
        else:
            self.tree_cache.dec_lock_ref(req.last_node)
        # NOTE(lsyin): we should use the newly evictable memory instantly.
        num_tokens = remaing_req_count * global_config.retract_decode_steps
        self._evict_tree_cache_if_needed(num_tokens)
    req.reset_for_retract()
def prepare_encoder_info_decode(self):
# Reset the encoder cached status
self.encoder_cached = [True] * len(self.reqs)

View File

@@ -28,6 +28,7 @@ from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
from sglang.srt.mem_cache.allocator import SWATokenToKVPoolAllocator
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode
from sglang.srt.server_args import ServerArgs
if TYPE_CHECKING:
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
@@ -82,10 +83,14 @@ class SchedulePolicy:
policy: str,
tree_cache: BasePrefixCache,
enable_hierarchical_cache: bool,
enable_priority_scheduling: bool,
schedule_low_priority_values_first: bool,
):
self.policy = self._validate_and_adjust_policy(policy, tree_cache)
self.tree_cache = tree_cache
self.enable_hierarchical_cache = enable_hierarchical_cache
self.enable_priority_scheduling = enable_priority_scheduling
self.schedule_low_priority_values_first = schedule_low_priority_values_first
# It is used to find the matching prefix for in-batch prefix caching.
self.waiting_queue_radix_tree = RadixCache(
@@ -97,7 +102,10 @@ class SchedulePolicy:
def calc_priority(self, waiting_queue: List[Req]) -> bool:
if self.policy == CacheAgnosticPolicy.FCFS:
# A shortcut for FCFS
if self.enable_priority_scheduling:
SchedulePolicy._sort_by_priority_and_fcfs(
waiting_queue, self.schedule_low_priority_values_first
)
return False
policy = self._determine_active_policy(waiting_queue)
@@ -120,12 +128,15 @@ class SchedulePolicy:
if policy == CacheAgnosticPolicy.FCFS:
pass
elif policy == CacheAgnosticPolicy.LOF:
SchedulePolicy._sort_by_longest_output(waiting_queue)
SchedulePolicy._sort_by_longest_output(
waiting_queue,
self.enable_priority_scheduling,
self.schedule_low_priority_values_first,
)
elif policy == CacheAgnosticPolicy.RANDOM:
SchedulePolicy._sort_randomly(waiting_queue)
else:
raise ValueError(f"Unknown CacheAgnostic Policy: {policy=}")
return prefix_computed
def _determine_active_policy(self, waiting_queue: List[Req]) -> Policy:
@@ -231,15 +242,39 @@ class SchedulePolicy:
)
@staticmethod
def _sort_by_longest_output(waiting_queue: List[Req]) -> None:
"""Sorts the waiting queue based on the longest output (max_new_tokens)."""
waiting_queue.sort(key=lambda x: -x.sampling_params.max_new_tokens)
def _sort_by_longest_output(
    waiting_queue: List[Req],
    enable_priority_scheduling: bool,
    schedule_low_priority_values_first: bool,
) -> None:
    """Sort the waiting queue in place so the longest expected outputs
    (largest ``max_new_tokens``) come first.

    With priority scheduling enabled, requests are ordered by priority first
    and longest output is only the tie-breaker within a priority level.
    """
    if not enable_priority_scheduling:
        waiting_queue.sort(key=lambda r: -r.sampling_params.max_new_tokens)
        return
    # Flip the sign so the preferred end of the priority range sorts first.
    direction = 1 if schedule_low_priority_values_first else -1
    waiting_queue.sort(
        key=lambda r: (direction * r.priority, -r.sampling_params.max_new_tokens)
    )
@staticmethod
def _sort_randomly(waiting_queue: List[Req]) -> None:
"""Shuffles the waiting queue randomly."""
random.shuffle(waiting_queue)
@staticmethod
def _sort_by_priority_and_fcfs(
    waiting_queue: List[Req], schedule_low_priority_values_first: bool
) -> None:
    """Sort the waiting queue in place by request priority, then by received
    timestamp (first come, first served).

    Args:
        waiting_queue: Requests to reorder in place.
        schedule_low_priority_values_first: If True, smaller priority values
            are scheduled first; otherwise larger values are scheduled first.
            Ties are broken by earlier ``queue_time_start``.
    """
    # One sort with a sign factor replaces the duplicated if/else branches.
    direction = 1 if schedule_low_priority_values_first else -1
    waiting_queue.sort(key=lambda r: (direction * r.priority, r.queue_time_start))
@staticmethod
def _calc_weight(cur_node: TreeNode, node_to_weight: Dict[TreeNode, int]) -> None:
for child in cur_node.children.values():
@@ -279,6 +314,7 @@ class PrefillAdder:
rem_input_tokens: int,
rem_chunk_tokens: Optional[int],
mixed_with_decode_tokens: int = 0,
priority_scheduling_preemption_threshold: int = 0,
):
self.page_size = page_size
self.tree_cache = tree_cache
@@ -295,6 +331,7 @@ class PrefillAdder:
self.req_states = None
self.can_run_list = []
self.preempt_list = []
self.new_chunked_req = None
self.log_hit_tokens = 0
# TODO(lsyin): report the real input tokens excluding page alignment
@@ -303,11 +340,7 @@ class PrefillAdder:
if running_batch is not None:
self.rem_total_token_offset += sum(
[
min(
(r.sampling_params.max_new_tokens - len(r.output_ids)),
CLIP_MAX_NEW_TOKENS,
)
* self.new_token_ratio
self._get_running_request_total_token_offset(r)
for r in running_batch.reqs
]
)
@@ -316,6 +349,19 @@ class PrefillAdder:
self.token_to_kv_pool_allocator, SWATokenToKVPoolAllocator
)
self.priority_scheduling_preemption_threshold = (
priority_scheduling_preemption_threshold
)
def _get_running_request_total_token_offset(self, req: Req) -> int:
    """Estimate the token budget still reserved by a running request.

    The remaining decode length (``max_new_tokens`` minus tokens already
    produced) is clamped to ``CLIP_MAX_NEW_TOKENS`` and scaled by the
    scheduler's current new-token ratio.

    NOTE(review): ``self.new_token_ratio`` is presumably a float, so the
    product is a float despite the ``int`` annotation — confirm callers
    tolerate this.
    """
    remaining_decode_tokens = req.sampling_params.max_new_tokens - len(req.output_ids)
    clamped = min(remaining_decode_tokens, CLIP_MAX_NEW_TOKENS)
    return clamped * self.new_token_ratio
@property
def rem_total_tokens(self):
if self.is_hybrid:
@@ -568,3 +614,61 @@ class PrefillAdder:
self._update_prefill_budget(prefix_len, trunc_len, 0)
return self.budget_state()
def preempt_to_schedule(self, req: Req, server_args: ServerArgs) -> bool:
    """Preempt running requests to serve the new request if the priority
    threshold is met and the token-count requirement is verified.

    Running requests whose priority is worse than ``req.priority`` by more
    than ``self.priority_scheduling_preemption_threshold`` are candidates.
    If the candidates free enough token budget, they are released and moved
    to ``self.preempt_list``.

    Returns:
        True if preemption was committed and the new request can be
        scheduled; False if nothing was preempted.
    """
    # Iterate running requests from least- to most-preferred so the
    # least-preferred are preempted first. Among equal priorities, newer
    # requests (larger queue_time_start) come first.
    if server_args.schedule_low_priority_values_first:
        sorted_running_reqs = sorted(
            self.running_batch.reqs,
            key=lambda x: (-x.priority, -x.queue_time_start),
        )
    else:
        sorted_running_reqs = sorted(
            self.running_batch.reqs,
            key=lambda x: (x.priority, -x.queue_time_start),
        )

    # Token budget that must be reclaimed before `req` fits.
    preemptible_reqs = []
    min_tokens_to_remove = (
        req.extend_input_len
        + min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS)
        - self.rem_total_tokens
    )
    for running_req in sorted_running_reqs:
        if running_req in self.preempt_list:
            continue
        # Priority difference needs to meet the threshold to be preemptible.
        priority_diff = req.priority - running_req.priority
        if server_args.schedule_low_priority_values_first:
            priority_diff *= -1
        if priority_diff > self.priority_scheduling_preemption_threshold:
            preemptible_reqs.append(running_req)
            min_tokens_to_remove -= self._get_running_request_total_token_offset(
                running_req
            )
    # NOTE(review): every eligible request is preempted even when a subset
    # would already free enough tokens — confirm this is intended.

    # Check that the max token count limit can be met.
    if len(preemptible_reqs) == 0 or min_tokens_to_remove > 0:
        return False

    # Preempt running requests. Release allocated resources for immediate usage.
    preemptible_reqs = set(preemptible_reqs)
    keep_indices = []
    release_counter = 0
    for i, running_req in enumerate(self.running_batch.reqs):
        if running_req in preemptible_reqs:
            # BUG FIX: subtract the offset of the preempted `running_req`,
            # not the incoming `req`, so running-token accounting stays correct.
            self.rem_total_token_offset -= (
                self._get_running_request_total_token_offset(running_req)
            )
            release_counter += 1
            self.running_batch.release_req(
                i, len(self.running_batch.reqs) - release_counter, server_args
            )
        else:
            keep_indices.append(i)
    self.running_batch.filter_batch(keep_indices=keep_indices)
    self.preempt_list.extend(preemptible_reqs)
    return True

View File

@@ -243,6 +243,13 @@ class Scheduler(
self.pp_size = server_args.pp_size
self.dp_size = server_args.dp_size
self.schedule_policy = server_args.schedule_policy
self.enable_priority_scheduling = server_args.enable_priority_scheduling
self.schedule_low_priority_values_first = (
server_args.schedule_low_priority_values_first
)
self.priority_scheduling_preemption_threshold = (
server_args.priority_scheduling_preemption_threshold
)
self.enable_lora = server_args.enable_lora
self.max_loras_per_batch = server_args.max_loras_per_batch
self.enable_overlap = not server_args.disable_overlap_schedule
@@ -487,7 +494,12 @@ class Scheduler(
self.schedule_policy,
self.tree_cache,
self.enable_hierarchical_cache,
self.enable_priority_scheduling,
self.schedule_low_priority_values_first,
)
# Enable preemption for priority scheduling.
self.try_preemption = self.enable_priority_scheduling
assert (
server_args.schedule_conservativeness >= 0
), "Invalid schedule_conservativeness"
@@ -1150,20 +1162,6 @@ class Scheduler(
self.return_health_check_ct += 1
continue
# If it is a work request, accept or reject the request based on the request queue size.
if is_work_request(recv_req):
if len(self.waiting_queue) + 1 > self.max_queued_requests:
abort_req = AbortReq(
recv_req.rid,
finished_reason={
"type": "abort",
"status_code": HTTPStatus.SERVICE_UNAVAILABLE,
"message": "The request queue is full.",
},
)
self.send_to_tokenizer.send_pyobj(abort_req)
continue
# If it is a MultiTokenizerWrapper, unwrap it and handle the inner request.
if isinstance(recv_req, MultiTokenizerWrapper):
worker_id = recv_req.worker_id
@@ -1233,6 +1231,7 @@ class Scheduler(
bootstrap_room=recv_req.bootstrap_room,
data_parallel_rank=recv_req.data_parallel_rank,
vocab_size=self.model_config.vocab_size,
priority=recv_req.priority,
metrics_collector=(
self.metrics_collector if self.enable_metrics else None
),
@@ -1382,6 +1381,9 @@ class Scheduler(
elif self.disaggregation_mode == DisaggregationMode.DECODE:
self.disagg_decode_prealloc_queue.add(req)
else:
self._set_or_validate_priority(req)
if self._abort_on_queued_limit(req):
return
self._prefetch_kvcache(req)
self.waiting_queue.append(req)
trace_slice_end("process req", req.rid, auto_next_anon=True)
@@ -1408,7 +1410,70 @@ class Scheduler(
# If this is a decode server, we put the request to the decode pending prealloc queue
self.disagg_decode_prealloc_queue.extend(reqs, is_retracted)
else:
self.waiting_queue.extend(reqs)
for req in reqs:
self._set_or_validate_priority(req)
if not self._abort_on_queued_limit(req):
self.waiting_queue.append(req)
def _set_or_validate_priority(self, req: Req):
    """Fill in a default priority when priority scheduling is enabled, or ask
    the tokenizer to abort the request when it carries a priority while
    priority scheduling is disabled."""
    if self.enable_priority_scheduling:
        if req.priority is None:
            # Default to the worst possible value so explicit priorities win.
            req.priority = (
                sys.maxsize
                if self.schedule_low_priority_values_first
                else -sys.maxsize - 1
            )
        return
    if req.priority is not None:
        self.send_to_tokenizer.send_pyobj(
            AbortReq(
                req.rid,
                finished_reason={
                    "type": "abort",
                    "status_code": HTTPStatus.SERVICE_UNAVAILABLE,
                    "message": "Using priority is disabled for this server. Please send a new request without a priority.",
                },
            )
        )
def _abort_on_queued_limit(self, recv_req: Req) -> bool:
    """Handle a full waiting queue by aborting either the incoming request or
    the least-preferred queued request.

    Returns:
        True when the incoming request itself was aborted, False otherwise.
    """
    if self.max_queued_requests is None:
        return False
    if len(self.waiting_queue) + 1 <= self.max_queued_requests:
        return False

    # By default the incoming request is the one rejected.
    req_to_abort = recv_req
    message = "The request queue is full."

    if self.enable_priority_scheduling:
        # With priority scheduling, consider aborting an existing request
        # instead. direction = 1 => smaller value = higher priority;
        # -1 => larger value = higher priority. The max() below picks the
        # least-preferred queued request; among equal priorities, the newest
        # (latest queue_time_start) is evicted first. An existing request is
        # evicted only when the incoming one is strictly better.
        direction = 1 if self.schedule_low_priority_values_first else -1

        def eviction_key(item):
            queued = item[1]
            return (direction * queued.priority, queued.queue_time_start)

        idx, candidate_req = max(enumerate(self.waiting_queue), key=eviction_key)
        if direction * recv_req.priority < direction * candidate_req.priority:
            self.waiting_queue.pop(idx)
            req_to_abort = candidate_req
            message = "The request is aborted by a higher priority request."

    self.send_to_tokenizer.send_pyobj(
        AbortReq(
            req_to_abort.rid,
            finished_reason={
                "type": "abort",
                "status_code": HTTPStatus.SERVICE_UNAVAILABLE,
                "message": message,
            },
        )
    )
    return req_to_abort.rid == recv_req.rid
def handle_embedding_request(
self,
@@ -1420,6 +1485,7 @@ class Scheduler(
recv_req.input_ids,
recv_req.sampling_params,
token_type_ids=recv_req.token_type_ids,
priority=recv_req.priority,
)
req.tokenizer = self.tokenizer
@@ -1680,6 +1746,10 @@ class Scheduler(
if self.grammar_queue:
self.move_ready_grammar_requests()
if self.try_preemption:
# Reset batch_is_full to try preemption with a prefill adder.
self.running_batch.batch_is_full = False
# Handle the cases where prefill is not allowed
if (
self.running_batch.batch_is_full or len(self.waiting_queue) == 0
@@ -1692,7 +1762,11 @@ class Scheduler(
# as the space for the chunked request has just been released.
# In PP case, a chunked req can start in one microbatch and end in another microbatch, so the max_running_requests per microbatch should not be strict.
# Instead, we should always allow chunked request to be added, otherwise, there will be a memory leak.
if self.get_num_allocatable_reqs(running_bs) <= 0 and not self.chunked_req:
if (
self.get_num_allocatable_reqs(running_bs) <= 0
and not self.chunked_req
and not self.try_preemption
):
self.running_batch.batch_is_full = True
return None
@@ -1712,6 +1786,7 @@ class Scheduler(
self.max_prefill_tokens,
self.chunked_prefill_size,
running_bs if self.is_mixed_chunk else 0,
self.priority_scheduling_preemption_threshold,
)
if self.chunked_req is not None:
@@ -1732,15 +1807,19 @@ class Scheduler(
self.running_batch.batch_is_full = True
break
running_bs = len(self.running_batch.reqs) - len(adder.preempt_list)
if len(adder.can_run_list) >= self.get_num_allocatable_reqs(running_bs):
self.running_batch.batch_is_full = True
break
if self.disaggregation_mode == DisaggregationMode.PREFILL:
# In prefill mode, prealloc queue and transfer queue can also take memory,
# so we need to check if the available size for the actual available size.
if len(adder.can_run_list) >= self.req_to_token_pool.available_size():
self.running_batch.batch_is_full = True
if self.running_batch.batch_is_full:
if not self.try_preemption:
break
if not adder.preempt_to_schedule(req, self.server_args):
break
if self.enable_hicache_storage:
@@ -1777,6 +1856,8 @@ class Scheduler(
self.waiting_queue = [
x for x in self.waiting_queue if x not in set(can_run_list)
]
if adder.preempt_list:
self._extend_requests_to_queue(adder.preempt_list)
if adder.new_chunked_req is not None:
assert self.chunked_req is None

View File

@@ -738,6 +738,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
custom_logit_processor=obj.custom_logit_processor,
return_hidden_states=obj.return_hidden_states,
data_parallel_rank=obj.data_parallel_rank,
priority=obj.priority,
)
elif isinstance(obj, EmbeddingReqInput):
tokenized_obj = TokenizedEmbeddingReqInput(
@@ -747,6 +748,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
mm_inputs,
token_type_ids,
sampling_params,
priority=obj.priority,
)
return tokenized_obj

View File

@@ -149,8 +149,8 @@ class TpModelWorker:
assert self.max_running_requests > 0, "max_running_request is zero"
self.max_queued_requests = server_args.max_queued_requests
assert (
self.max_queued_requests > 0
), "max_queued_requests is zero. We need to be at least 1 to schedule a request."
self.max_queued_requests is None or self.max_queued_requests >= 1
), "If configured, max_queued_requests must be at least 1 for any work to be scheduled."
self.max_req_len = min(
self.model_config.context_len - 1,
self.max_total_num_tokens - 1,

View File

@@ -172,11 +172,14 @@ class ServerArgs:
# Memory and scheduling
mem_fraction_static: Optional[float] = None
max_running_requests: Optional[int] = None
max_queued_requests: Optional[int] = sys.maxsize
max_queued_requests: Optional[int] = None
max_total_tokens: Optional[int] = None
chunked_prefill_size: Optional[int] = None
max_prefill_tokens: int = 16384
schedule_policy: str = "fcfs"
enable_priority_scheduling: bool = False
schedule_low_priority_values_first: bool = False
priority_scheduling_preemption_threshold: int = 10
schedule_conservativeness: float = 1.0
page_size: Optional[int] = None
hybrid_kvcache_ratio: Optional[float] = None
@@ -1166,6 +1169,24 @@ class ServerArgs:
choices=["lpm", "random", "fcfs", "dfs-weight", "lof", "priority"],
help="The scheduling policy of the requests.",
)
parser.add_argument(
"--enable-priority-scheduling",
action="store_true",
default=ServerArgs.enable_priority_scheduling,
help="Enable priority scheduling. Requests with higher priority integer values will be scheduled first by default.",
)
parser.add_argument(
"--schedule-low-priority-values-first",
action="store_true",
default=ServerArgs.schedule_low_priority_values_first,
help="If specified with --enable-priority-scheduling, the scheduler will schedule requests with lower priority integer values first.",
)
parser.add_argument(
"--priority-scheduling-preemption-threshold",
type=int,
default=ServerArgs.priority_scheduling_preemption_threshold,
help="Minimum difference in priorities for an incoming request to have to preempt running request(s).",
)
parser.add_argument(
"--schedule-conservativeness",
type=float,
@@ -2455,6 +2476,13 @@ class ServerArgs:
"--generation-tokens-buckets", self.generation_tokens_buckets
)
# Check scheduling policy
if self.enable_priority_scheduling:
assert self.schedule_policy in [
"fcfs",
"lof",
], f"To use priority scheduling, schedule_policy must be 'fcfs' or 'lof'. '{self.schedule_policy}' is not supported."
def check_lora_server_args(self):
assert self.max_loras_per_batch > 0, "max_loras_per_batch must be positive"

View File

@@ -17,7 +17,7 @@ from dataclasses import dataclass
from functools import partial
from pathlib import Path
from types import SimpleNamespace
from typing import Awaitable, Callable, List, Optional, Tuple
from typing import Any, Awaitable, Callable, List, Optional, Tuple
import aiohttp
import numpy as np
@@ -1390,6 +1390,41 @@ async def send_concurrent_generate_requests(
return await asyncio.gather(*tasks)
async def send_concurrent_generate_requests_with_custom_params(
    base_url: str,
    custom_params: List[dict[str, Any]],
) -> List[Tuple[int, Any]]:
    """Send one ``/generate`` request per entry in ``custom_params``, all
    concurrently, and return a list of ``(status_code, response_json)`` tuples
    in the same order as ``custom_params``.

    Each entry is merged shallowly over a shared base payload, so a custom
    top-level key (e.g. ``sampling_params``) replaces the default entirely.
    """
    base_payload = {
        "text": """
    System: You are a helpful assistant.
    User: What is the capital of France?
    Assistant: The capital of France is
    """,
        "sampling_params": {
            "temperature": 0,
            "max_new_tokens": 50,
        },
    }

    async def async_generate_with_priority(req):
        # A fresh session per request keeps the calls fully independent.
        async with aiohttp.ClientSession() as session:
            async with session.post(
                f"{base_url}/generate",
                json=req,
            ) as response:
                resp_json = await response.json()
                return (response.status, resp_json)

    tasks = []
    for c in custom_params:
        # Shallow copy: top-level keys from `c` override the base payload.
        req = base_payload.copy()
        req.update(c)
        tasks.append(asyncio.create_task(async_generate_with_priority(req)))
    return await asyncio.gather(*tasks)
class CustomTestCase(unittest.TestCase):
def _callTestMethod(self, method):
max_retry = int(