Files
sglang/python/sglang/srt/managers/schedule_policy.py

675 lines
26 KiB
Python

from __future__ import annotations
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Request scheduler policy"""
import os
import random
from collections import defaultdict
from contextlib import contextmanager
from enum import Enum, auto
from typing import TYPE_CHECKING, Dict, List, Optional, Set, Union
import torch
from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
from sglang.srt.mem_cache.allocator import SWATokenToKVPoolAllocator
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode
from sglang.srt.server_args import ServerArgs
if TYPE_CHECKING:
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
# Upper bound used when the scheduler *estimates* how many tokens a request may
# still generate. Requests with a very large max_new_tokens would otherwise make
# the scheduler overly conservative. Only the scheduler's estimate is clipped:
# the stop condition is untouched, so a request can still generate tokens until
# it reaches its unclipped max_new_tokens.
CLIP_MAX_NEW_TOKENS = int(os.getenv("SGLANG_CLIP_MAX_NEW_TOKENS_ESTIMATION", "4096"))

# In-batch prefix caching, check threshold.
# The in-batch prefix check runs for a request when its matched prefix length
# against the existing cache is at most this value. A value of -1 disables
# in-batch prefix caching, since a prefix length is never negative.
IN_BATCH_PREFIX_CACHING_CHECK_THRESHOLD = int(
    os.getenv("IN_BATCH_PREFIX_CACHING_CHECK_THRESHOLD", "32")
)

# In-batch prefix caching, deprioritize threshold.
# A request is deprioritized when its matched prefix length within the waiting
# queue is at least this value.
IN_BATCH_PREFIX_CACHING_DEPRIORITIZE_THRESHOLD = int(
    os.getenv("IN_BATCH_PREFIX_CACHING_DEPRIORITIZE_THRESHOLD", "32")
)

# Safety margin (in tokens, per remaining request) kept free by the ignore-EOS
# admission check.
IGNORE_EOS_RESERVE_TOKENS = 1
class CacheAwarePolicy(Enum):
    """Scheduling policies whose ordering depends on the tree (prefix) cache."""

    # Longest prefix match.
    LPM = "lpm"
    # Depth-first search weighting.
    DFS_WEIGHT = "dfs-weight"
class CacheAgnosticPolicy(Enum):
    """Scheduling policies that order requests without consulting the tree cache."""

    # First come, first serve.
    FCFS = "fcfs"
    # Longest output first.
    LOF = "lof"
    # Random order.
    RANDOM = "random"
class SchedulePolicy:
    """Reorders the waiting queue in place according to a scheduling policy.

    A policy is either cache-aware (it uses prefix matches against
    ``tree_cache``) or cache-agnostic (it ignores the cache entirely).
    """

    Policy = Union[CacheAwarePolicy, CacheAgnosticPolicy]

    def __init__(
        self,
        policy: str,
        tree_cache: BasePrefixCache,
        enable_hierarchical_cache: bool,
        enable_priority_scheduling: bool,
        schedule_low_priority_values_first: bool,
    ):
        """Store the configuration and resolve ``policy`` to an enum member.

        Args:
            policy: Policy name; must be a value of ``CacheAwarePolicy`` or
                ``CacheAgnosticPolicy``.
            tree_cache: Prefix cache used for prefix matching.
            enable_hierarchical_cache: Stored on the instance as given.
            enable_priority_scheduling: Whether request priority takes part
                in the FCFS and LOF orderings.
            schedule_low_priority_values_first: If true, smaller priority
                values are scheduled first; otherwise larger values are.

        Raises:
            ValueError: If ``policy`` names no known policy.
        """
        self.policy = self._validate_and_adjust_policy(policy, tree_cache)
        self.tree_cache = tree_cache
        self.enable_hierarchical_cache = enable_hierarchical_cache
        self.enable_priority_scheduling = enable_priority_scheduling
        self.schedule_low_priority_values_first = schedule_low_priority_values_first

        # Scratch radix tree used to find prefixes shared between waiting
        # requests (in-batch prefix caching).
        self.waiting_queue_radix_tree = RadixCache(
            req_to_token_pool=None,
            token_to_kv_pool_allocator=None,
            page_size=1,
            disable=False,
        )

    def calc_priority(self, waiting_queue: List[Req]) -> bool:
        """Sort ``waiting_queue`` in place according to the active policy.

        Returns:
            True if prefix matches were computed for the queued requests
            during this call (a cache-aware policy ran), False otherwise.

        Raises:
            ValueError: If the active policy is not a recognised member.
        """
        # Configured FCFS: arrival order is kept unless priorities are enabled.
        if self.policy == CacheAgnosticPolicy.FCFS:
            if self.enable_priority_scheduling:
                SchedulePolicy._sort_by_priority_and_fcfs(
                    waiting_queue, self.schedule_low_priority_values_first
                )
            return False

        # NOTE: this local must stay named ``policy``; the f-strings below use
        # the ``{policy=}`` form, which embeds the variable name in the message.
        policy = self._determine_active_policy(waiting_queue)

        if not isinstance(policy, CacheAwarePolicy):
            if policy == CacheAgnosticPolicy.LOF:
                SchedulePolicy._sort_by_longest_output(
                    waiting_queue,
                    self.enable_priority_scheduling,
                    self.schedule_low_priority_values_first,
                )
            elif policy == CacheAgnosticPolicy.RANDOM:
                SchedulePolicy._sort_randomly(waiting_queue)
            elif policy != CacheAgnosticPolicy.FCFS:
                raise ValueError(f"Unknown CacheAgnostic Policy: {policy=}")
            # FCFS reached here (LPM fallback) leaves the queue untouched.
            return False

        # Cache-aware policies need every request's prefix match first.
        temporary_deprioritized = self._compute_prefix_matches(waiting_queue, policy)
        if policy == CacheAwarePolicy.LPM:
            SchedulePolicy._sort_by_longest_prefix(
                waiting_queue, temporary_deprioritized
            )
        elif policy == CacheAwarePolicy.DFS_WEIGHT:
            SchedulePolicy._sort_by_dfs_weight(waiting_queue, self.tree_cache)
        else:
            raise ValueError(f"Unknown CacheAware Policy: {policy=}")
        return True

    def _determine_active_policy(self, waiting_queue: List[Req]) -> Policy:
        """Return the policy to apply for this call.

        LPM falls back to FCFS when more than 128 requests are waiting, to
        avoid the expensive prefix matching and sorting on a large queue.
        """
        queue_is_large = len(waiting_queue) > 128
        if queue_is_large and self.policy == CacheAwarePolicy.LPM:
            return CacheAgnosticPolicy.FCFS
        return self.policy

    def _validate_and_adjust_policy(
        self, policy: str, tree_cache: BasePrefixCache
    ) -> Policy:
        """Resolve ``policy`` to an enum member, adjusting for a disabled cache.

        A cache-aware name resolves to FCFS when ``tree_cache.disable`` is
        truthy; a cache without a ``disable`` attribute is treated as disabled.

        Raises:
            ValueError: If ``policy`` matches neither policy enum.
        """
        try:
            cache_aware = CacheAwarePolicy(policy)
            cache_disabled = getattr(tree_cache, "disable", True)
            return CacheAgnosticPolicy.FCFS if cache_disabled else cache_aware
        except ValueError:
            try:
                return CacheAgnosticPolicy(policy)
            except ValueError:
                raise ValueError(f"Unknown schedule_policy: {policy=}")

    def _compute_prefix_matches(
        self, waiting_queue: List[Req], policy: CacheAwarePolicy
    ) -> Set[int]:
        """Match every waiting request against the cache and flag duplicates.

        Each request gets ``prefix_indices``, ``last_node``, ``last_host_node``
        and ``host_hit_length`` refreshed from ``tree_cache.match_prefix``.

        Returns:
            The ``rid`` values of requests that should be temporarily
            deprioritized because an earlier waiting request shares a long
            enough prefix with them.
        """
        deprioritized: Set[int] = set()
        in_batch_tree = self.waiting_queue_radix_tree
        in_batch_tree.reset()

        for req in waiting_queue:
            prefix_ids = req.adjust_max_prefix_ids()

            # NOTE: the prefix_indices must always be aligned with last_node
            match = self.tree_cache.match_prefix(rid=req.rid, key=prefix_ids)
            (
                req.prefix_indices,
                req.last_node,
                req.last_host_node,
                req.host_hit_length,
            ) = match

            # In-batch prefix caching: when several requests have only a small
            # hit in the existing cache but share a prefix with each other, it
            # is better to schedule one of them first so the rest can hit the
            # cache afterwards. A threshold > 0 is preferred because short
            # prefixes (e.g. "the") are almost always cached on a long-running
            # engine and would otherwise bypass this check.
            if len(req.prefix_indices) > IN_BATCH_PREFIX_CACHING_CHECK_THRESHOLD:
                continue

            in_batch_prefix, _, _, _ = in_batch_tree.match_prefix(
                rid=req.rid, key=prefix_ids
            )
            if len(in_batch_prefix) >= IN_BATCH_PREFIX_CACHING_DEPRIORITIZE_THRESHOLD:
                deprioritized.add(req.rid)
            else:
                # First request with this prefix: record it under a dummy value.
                in_batch_tree.insert(
                    prefix_ids, torch.empty(len(prefix_ids), dtype=torch.bool)
                )
        return deprioritized

    @staticmethod
    def _sort_by_longest_prefix(
        waiting_queue: List[Req], temporary_deprioritized: Set[int]
    ) -> None:
        """Sort in place by matched prefix length, longest first.

        Requests whose ``rid`` is in ``temporary_deprioritized`` sort last.
        """

        def sort_key(req):
            if req.rid in temporary_deprioritized:
                return float("inf")
            return -len(req.prefix_indices)

        waiting_queue.sort(key=sort_key)

    @staticmethod
    def _sort_by_dfs_weight(
        waiting_queue: List[Req], tree_cache: BasePrefixCache
    ) -> None:
        """Reorder in place by a weighted depth-first walk of the cache tree.

        A node's weight is the number of waiting requests whose ``last_node``
        lies in its subtree; heavier subtrees are visited first.
        """
        reqs_by_node = defaultdict(list)
        for req in waiting_queue:
            reqs_by_node[req.last_node].append(req)

        # Seed each node with its own request count, then accumulate upwards.
        weights = defaultdict(int)
        for node, reqs in reqs_by_node.items():
            weights[node] = len(reqs)
        SchedulePolicy._calc_weight(tree_cache.root_node, weights)

        # Rebuild the queue in traversal order.
        waiting_queue.clear()
        SchedulePolicy._get_dfs_priority(
            tree_cache.root_node,
            weights,
            reqs_by_node,
            waiting_queue,
        )

    @staticmethod
    def _sort_by_longest_output(
        waiting_queue: List[Req],
        enable_priority_scheduling: bool,
        schedule_low_priority_values_first: bool,
    ) -> None:
        """Sort in place by ``max_new_tokens``, largest first.

        With priority scheduling enabled, priority is the primary key and
        ``max_new_tokens`` breaks ties.
        """
        if not enable_priority_scheduling:

            def sort_key(req):
                return -req.sampling_params.max_new_tokens

        elif schedule_low_priority_values_first:

            def sort_key(req):
                return (req.priority, -req.sampling_params.max_new_tokens)

        else:

            def sort_key(req):
                return (-req.priority, -req.sampling_params.max_new_tokens)

        waiting_queue.sort(key=sort_key)

    @staticmethod
    def _sort_randomly(waiting_queue: List[Req]) -> None:
        """Shuffle the waiting queue in place."""
        random.shuffle(waiting_queue)

    @staticmethod
    def _sort_by_priority_and_fcfs(
        waiting_queue: List[Req], schedule_low_priority_values_first: bool
    ) -> None:
        """Sort in place by request priority, then by ``queue_time_start``."""
        if schedule_low_priority_values_first:

            def sort_key(req):
                return (req.priority, req.queue_time_start)

        else:

            def sort_key(req):
                return (-req.priority, req.queue_time_start)

        waiting_queue.sort(key=sort_key)

    @staticmethod
    def _calc_weight(cur_node: TreeNode, node_to_weight: Dict[TreeNode, int]) -> None:
        """Add every child's subtree weight into ``cur_node`` (post-order)."""
        for child_node in cur_node.children.values():
            SchedulePolicy._calc_weight(child_node, node_to_weight)
            node_to_weight[cur_node] += node_to_weight[child_node]

    @staticmethod
    def _get_dfs_priority(
        cur_node: TreeNode,
        node_to_priority: Dict[TreeNode, int],
        last_node_to_reqs: Dict[TreeNode, List[Req]],
        q: List,
    ) -> None:
        """Append requests to ``q`` in weighted depth-first order.

        Children are visited heaviest first; the requests attached to
        ``cur_node`` itself are appended after all of its descendants.
        """
        ordered_children = sorted(
            cur_node.children.values(), key=lambda node: -node_to_priority[node]
        )
        for child_node in ordered_children:
            SchedulePolicy._get_dfs_priority(
                child_node, node_to_priority, last_node_to_reqs, q
            )
        q.extend(last_node_to_reqs[cur_node])
class AddReqResult(Enum):
    """Outcome of trying to add a request to the prefill batch."""

    # More requests may still be added.
    CONTINUE = auto()
    # The token budget is exhausted.
    NO_TOKEN = auto()
    # Adding must stop for some other reason.
    OTHER = auto()
class PrefillAdder:
    """Selects requests for the next prefill batch while tracking token budgets.

    Budgets tracked on the instance:

    * ``rem_total_tokens``: available plus evictable KV tokens minus
      ``rem_total_token_offset`` (mixed decode tokens, the scaled remaining
      output budget of running requests, and the input plus estimated output
      of every request added so far).
    * ``cur_rem_tokens``: available plus evictable KV tokens minus
      ``cur_rem_token_offset`` (mixed decode tokens and the page-aligned input
      of every request added so far).
    * ``rem_input_tokens`` / ``rem_chunk_tokens``: remaining input-token and
      chunk budgets; ``rem_chunk_tokens`` is ``None`` when chunked prefill is
      disabled.
    """

    def __init__(
        self,
        page_size: int,
        tree_cache: BasePrefixCache,
        token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
        running_batch: ScheduleBatch,
        new_token_ratio: float,
        rem_input_tokens: int,
        rem_chunk_tokens: Optional[int],
        mixed_with_decode_tokens: int = 0,
        priority_scheduling_preemption_threshold: int = 0,
    ):
        """Initialise the budgets.

        Args:
            page_size: Page granularity used to round input lengths up.
            tree_cache: Prefix cache queried for evictable size and node locks.
            token_to_kv_pool_allocator: Allocator queried for available size.
            running_batch: Currently running batch, or ``None``.
            new_token_ratio: Multiplier applied to a running request's
                remaining output budget when reserving tokens for it.
            rem_input_tokens: Input-token budget for this prefill batch.
            rem_chunk_tokens: Chunk budget, or ``None`` when chunked prefill
                is disabled.
            mixed_with_decode_tokens: Tokens subtracted from the input and
                chunk budgets and added to both offsets.
            priority_scheduling_preemption_threshold: A running request is
                preemptible only if the priority gap exceeds this value.
        """
        self.page_size = page_size
        self.tree_cache = tree_cache
        self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
        self.running_batch = running_batch
        self.new_token_ratio = new_token_ratio
        self.rem_input_tokens = rem_input_tokens - mixed_with_decode_tokens
        self.rem_chunk_tokens = rem_chunk_tokens
        if self.rem_chunk_tokens is not None:
            self.rem_chunk_tokens -= mixed_with_decode_tokens

        self.rem_total_token_offset = mixed_with_decode_tokens
        self.cur_rem_token_offset = mixed_with_decode_tokens

        # Sorted (tokens_left, tokens_occupied) pairs; built lazily by
        # add_one_req_ignore_eos.
        self.req_states = None
        self.can_run_list = []
        self.preempt_list = []
        self.new_chunked_req = None

        self.log_hit_tokens = 0
        # TODO(lsyin): report the real input tokens excluding page alignment
        self.log_input_tokens = 0

        if running_batch is not None:
            # Reserve the estimated remaining output of every running request.
            self.rem_total_token_offset += sum(
                self._get_running_request_total_token_offset(r)
                for r in running_batch.reqs
            )

        self.is_hybrid = isinstance(
            self.token_to_kv_pool_allocator, SWATokenToKVPoolAllocator
        )
        self.priority_scheduling_preemption_threshold = (
            priority_scheduling_preemption_threshold
        )

    def _get_running_request_total_token_offset(self, req: Req) -> int:
        """Return the tokens reserved for a running request's remaining output.

        The remaining output budget is clipped to ``CLIP_MAX_NEW_TOKENS`` and
        scaled by ``new_token_ratio``.
        """
        remaining_output = req.sampling_params.max_new_tokens - len(req.output_ids)
        return min(remaining_output, CLIP_MAX_NEW_TOKENS) * self.new_token_ratio

    def _available_and_evictable_tokens(self):
        """Return free plus evictable KV tokens.

        For a hybrid (SWA) allocator this is the smaller of the full and SWA
        totals.
        """
        allocator = self.token_to_kv_pool_allocator
        cache = self.tree_cache
        if self.is_hybrid:
            return min(
                allocator.full_available_size() + cache.full_evictable_size(),
                allocator.swa_available_size() + cache.swa_evictable_size(),
            )
        return allocator.available_size() + cache.evictable_size()

    @property
    def rem_total_tokens(self):
        """Free plus evictable tokens minus ``rem_total_token_offset``."""
        return self._available_and_evictable_tokens() - self.rem_total_token_offset

    @property
    def cur_rem_tokens(self):
        """Free plus evictable tokens minus ``cur_rem_token_offset``."""
        return self._available_and_evictable_tokens() - self.cur_rem_token_offset

    def ceil_paged_tokens(self, tokens: int) -> int:
        """Round ``tokens`` up to the next multiple of ``page_size``."""
        return -(-tokens // self.page_size) * self.page_size

    def budget_state(self):
        """Report whether more requests can be added.

        Returns:
            ``NO_TOKEN`` if either KV budget is exhausted, ``OTHER`` if the
            input or chunk budget is exhausted, ``CONTINUE`` otherwise.
        """
        if self.rem_total_tokens <= 0 or self.cur_rem_tokens <= 0:
            return AddReqResult.NO_TOKEN

        if self.rem_input_tokens <= 0 or (
            self.rem_chunk_tokens is not None and self.rem_chunk_tokens <= 0
        ):
            return AddReqResult.OTHER

        return AddReqResult.CONTINUE

    def _update_prefill_budget(
        self, prefix_len: int, extend_input_len: int, max_new_tokens: int
    ):
        """Charge an admitted request against every budget.

        ``extend_input_len`` is rounded up to a whole number of pages first.
        """
        # TODO(lsyin): check this workaround logic, which only ensures the prefill will not out of memory, and may be too conservative
        extend_input_len = self.ceil_paged_tokens(extend_input_len)

        self.rem_total_token_offset += extend_input_len + max_new_tokens
        self.cur_rem_token_offset += extend_input_len
        self.rem_input_tokens -= extend_input_len
        if self.rem_chunk_tokens is not None:
            self.rem_chunk_tokens -= extend_input_len

        self.log_hit_tokens += prefix_len
        self.log_input_tokens += extend_input_len

    def add_chunked_req(self, req: Req):
        """Add the next chunk of a request that is already being chunked.

        The chunk is truncated to the smaller of the chunk budget and the
        total token budget.

        Returns:
            ``req`` if it was truncated (its prefill is not finished yet),
            otherwise ``None``.
        """
        _rem_tokens = min(self.rem_chunk_tokens, int(self.rem_total_tokens))
        truncated = req.extend_input_len > _rem_tokens
        req.extend_input_len = min(req.extend_input_len, _rem_tokens)
        req.fill_ids = req.fill_ids[: len(req.prefix_indices) + req.extend_input_len]
        self.can_run_list.append(req)
        # Output tokens are only reserved once the final chunk is scheduled.
        self._update_prefill_budget(
            0,
            req.extend_input_len,
            (
                min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS)
                if not truncated
                else 0
            ),
        )

        # Return if chunked prefill not finished
        return req if truncated else None

    @contextmanager
    def _lock_node(self, last_node: TreeNode):
        """Hold a lock reference on ``last_node`` for the ``with`` body.

        The reference is taken *before* entering the ``try`` so that a failing
        ``inc_lock_ref`` propagates its own error and is never followed by an
        unmatched ``dec_lock_ref``.
        """
        if self.is_hybrid:
            swa_uuid_for_lock = self.tree_cache.inc_lock_ref(last_node)
            try:
                yield None
            finally:
                self.tree_cache.dec_lock_ref(last_node, swa_uuid_for_lock)
        else:
            self.tree_cache.inc_lock_ref(last_node)
            try:
                yield None
            finally:
                self.tree_cache.dec_lock_ref(last_node)

    @staticmethod
    def _insert_req_state_sorted(req_states, state):
        """Insert ``state`` into ``req_states``, keeping ascending ``tokens_left``.

        ``state`` is a ``(tokens_left, tokens_occupied)`` pair. It goes before
        the first entry whose ``tokens_left`` is not smaller, or at the end
        when every existing entry is smaller.
        """
        tokens_left = state[0]
        for idx, (other_tokens_left, _) in enumerate(req_states):
            if tokens_left <= other_tokens_left:
                req_states.insert(idx, state)
                return
        req_states.append(state)

    def _pin_last_node(self, req: Req):
        """Take a lock reference on ``req.last_node`` that outlives this call.

        For a hybrid allocator the value returned by ``inc_lock_ref`` is stored
        on ``req.swa_uuid_for_lock``.
        """
        if self.is_hybrid:
            swa_uuid_for_lock = self.tree_cache.inc_lock_ref(req.last_node)
            req.swa_uuid_for_lock = swa_uuid_for_lock
        else:
            self.tree_cache.inc_lock_ref(req.last_node)

    def add_one_req_ignore_eos(self, req: Req, has_chunked_req: bool):
        """Try to admit ``req`` using the ignore-EOS memory estimate.

        Returns:
            An ``AddReqResult``: ``NO_TOKEN`` if the request does not fit,
            ``OTHER`` if the chunk budget is already zero, otherwise the
            result of ``budget_state()`` after admission.
        """
        # Early exit if no enough tokens for the input tokens
        if self.ceil_paged_tokens(req.extend_input_len) > min(
            self.cur_rem_tokens, self.rem_total_tokens
        ):
            return AddReqResult.NO_TOKEN

        def add_req_state(r, insert_sort=False):
            new_token_ratio = (
                1.0 if r.sampling_params.ignore_eos else self.new_token_ratio
            )
            tokens_left = r.sampling_params.max_new_tokens * new_token_ratio - len(
                r.output_ids
            )
            tokens_occupied = len(r.origin_input_ids) + len(r.output_ids)

            # Requests with nothing left to generate need no reservation.
            if tokens_left <= 0:
                return

            state = (tokens_left, tokens_occupied)
            if insert_sort:
                PrefillAdder._insert_req_state_sorted(self.req_states, state)
            else:
                self.req_states.append(state)

        if self.req_states is None:
            # First call: collect every known request, then sort once.
            self.req_states = []
            add_req_state(req)
            if self.running_batch is not None:
                for r in self.running_batch.reqs:
                    add_req_state(r)
            for r in self.can_run_list:
                add_req_state(r)
            self.req_states.sort(key=lambda x: x[0])
        else:
            add_req_state(req, insert_sort=True)

        if not self.is_hybrid:
            # Skip this logic for swa. The SWA has different memory management, and
            # this mechanism is underestimating the memory usage.
            cur_rem_tokens = self.cur_rem_tokens - self.ceil_paged_tokens(
                req.extend_input_len
            )
            tokens_freed = 0
            # Walk requests in the order they will finish (fewest tokens left
            # first) and check that memory never runs out along the way.
            for i, (tokens_left, tokens_occupied) in enumerate(self.req_states):
                # tokens_left gives a reservative calculation as the last token is not stored
                bs = len(self.req_states) - i
                min_free_tokens = cur_rem_tokens + tokens_freed - tokens_left * bs
                # reserve tokens for corner cases
                if min_free_tokens <= IGNORE_EOS_RESERVE_TOKENS * bs:
                    return AddReqResult.NO_TOKEN
                tokens_freed += tokens_occupied

        if (
            self.rem_chunk_tokens is None  # chunked prefill is disabled
            or req.extend_input_len <= self.rem_chunk_tokens  # it is the last chunk
        ):
            # Non-chunked prefill
            self.can_run_list.append(req)
            self._update_prefill_budget(
                0,
                req.extend_input_len,
                min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS),
            )
        else:
            if self.rem_chunk_tokens == 0:
                return AddReqResult.OTHER

            # Chunked prefill
            trunc_len = self.rem_chunk_tokens

            req.extend_input_len = trunc_len
            req.fill_ids = req.fill_ids[:trunc_len]
            self.can_run_list.append(req)
            self.new_chunked_req = req
            self._update_prefill_budget(0, trunc_len, 0)

        return self.budget_state()

    def add_one_req(self, req: Req, has_chunked_req: bool):
        """Try to admit ``req`` into the prefill batch.

        Requests with ``ignore_eos`` are delegated to
        ``add_one_req_ignore_eos`` when the tree cache is disabled.

        Returns:
            An ``AddReqResult``: ``NO_TOKEN`` / ``OTHER`` if the request was
            rejected, otherwise the result of ``budget_state()``.
        """
        if req.sampling_params.ignore_eos and getattr(self.tree_cache, "disable", True):
            return self.add_one_req_ignore_eos(req, has_chunked_req)

        total_tokens = req.extend_input_len + min(
            req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS
        )

        # adjusting the input_tokens based on host_hit_length and page_size
        real_input_tokens = req.extend_input_len - req.host_hit_length
        real_input_tokens = self.ceil_paged_tokens(real_input_tokens)
        prefix_len = len(req.prefix_indices)

        if total_tokens >= self.rem_total_tokens:
            return AddReqResult.NO_TOKEN

        # The first request of a batch is allowed to exceed the input budget.
        if real_input_tokens >= self.rem_input_tokens and len(self.can_run_list) != 0:
            return AddReqResult.OTHER

        with self._lock_node(req.last_node):
            # self.rem_total_tokens may decrease after the lock acquisition
            if total_tokens >= self.rem_total_tokens:
                return AddReqResult.NO_TOKEN

            if req.host_hit_length > 0:
                # Load the host-cached part back and extend the matched prefix.
                new_indices, req.last_node = self.tree_cache.init_load_back(
                    req.last_host_node, req.host_hit_length
                )
                req.prefix_indices = torch.cat([req.prefix_indices, new_indices])
                req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
                prefix_len = len(req.prefix_indices)

            input_tokens = self.ceil_paged_tokens(req.extend_input_len)

            if input_tokens >= self.rem_input_tokens and len(self.can_run_list) != 0:
                return AddReqResult.OTHER

            if self.rem_chunk_tokens is None or input_tokens <= self.rem_chunk_tokens:
                # Non-chunked prefill
                self.can_run_list.append(req)
                self._pin_last_node(req)
                self._update_prefill_budget(
                    prefix_len,
                    input_tokens,
                    min(
                        req.sampling_params.max_new_tokens,
                        CLIP_MAX_NEW_TOKENS,
                    ),
                )
            else:
                # Make sure at least one page is available
                trunc_len = self.rem_chunk_tokens // self.page_size * self.page_size
                if trunc_len <= 0:
                    return AddReqResult.OTHER

                # Chunked prefill
                req.extend_input_len = trunc_len
                req.fill_ids = req.fill_ids[: len(req.prefix_indices) + trunc_len]

                self.can_run_list.append(req)
                self.new_chunked_req = req
                self._pin_last_node(req)
                self._update_prefill_budget(prefix_len, trunc_len, 0)

        return self.budget_state()

    def preempt_to_schedule(self, req: Req, server_args: ServerArgs) -> bool:
        """
        Preempt running requests to serve the new request if the priority threshold is met and token count sum is verified.
        Returns True if preemption was committed, and the new request can be scheduled.
        """
        # Order running requests so the least important (and, among equals,
        # the most recently queued) are considered for preemption first.
        if server_args.schedule_low_priority_values_first:
            sorted_running_reqs = sorted(
                self.running_batch.reqs,
                key=lambda x: (-x.priority, -x.queue_time_start),
            )
        else:
            sorted_running_reqs = sorted(
                self.running_batch.reqs,
                key=lambda x: (x.priority, -x.queue_time_start),
            )

        preemptible_reqs = []
        # Tokens that must be released before the new request fits.
        min_tokens_to_remove = (
            req.extend_input_len
            + min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS)
            - self.rem_total_tokens
        )
        for running_req in sorted_running_reqs:
            if running_req in self.preempt_list:
                continue
            # Priority difference needs to meet the threshold to be preemptible.
            priority_diff = req.priority - running_req.priority
            if server_args.schedule_low_priority_values_first:
                priority_diff *= -1
            if priority_diff > self.priority_scheduling_preemption_threshold:
                preemptible_reqs.append(running_req)
                min_tokens_to_remove -= self._get_running_request_total_token_offset(
                    running_req
                )

        # Check max token count limit can be met
        if len(preemptible_reqs) == 0 or min_tokens_to_remove > 0:
            return False

        # Preempt running requests. Release allocated resources for immediate usage.
        preemptible_set = set(preemptible_reqs)
        keep_indices = []
        release_counter = 0
        for i, running_req in enumerate(self.running_batch.reqs):
            if running_req in preemptible_set:
                # Give back the reservation made for the *preempted* request;
                # this mirrors what __init__ added for it.
                self.rem_total_token_offset -= (
                    self._get_running_request_total_token_offset(running_req)
                )
                release_counter += 1
                self.running_batch.release_req(
                    i, len(self.running_batch.reqs) - release_counter, server_args
                )
            else:
                keep_indices.append(i)
        self.running_batch.filter_batch(keep_indices=keep_indices)
        # Extend from the list (not the set) to keep a deterministic order.
        self.preempt_list.extend(preemptible_reqs)
        return True