321 lines
12 KiB
Python
321 lines
12 KiB
Python
"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
|
|
|
|
"""Request scheduler policy"""
|
|
|
|
import os
|
|
import random
|
|
from collections import defaultdict
|
|
from contextlib import contextmanager
|
|
from enum import Enum, auto
|
|
from typing import Dict, List, Optional
|
|
|
|
from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
|
|
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
|
|
from sglang.srt.mem_cache.radix_cache import TreeNode
|
|
|
|
# Upper bound used when the scheduler estimates how many tokens a request may
# still generate. Requests with a very large max_new_tokens would otherwise make
# the server overly conservative about admitting new work.
# Only the scheduler's estimate is clipped; the stop condition is untouched, so a
# request can still generate tokens until it reaches its unclipped max_new_tokens.
# Override with the SGLANG_CLIP_MAX_NEW_TOKENS environment variable.
CLIP_MAX_NEW_TOKENS = int(os.getenv("SGLANG_CLIP_MAX_NEW_TOKENS", "4096"))
|
|
|
|
|
|
class SchedulePolicy:
    """Reorder the waiting queue according to a named scheduling policy.

    Recognized policy names are ``"lpm"`` (longest prefix match), ``"fcfs"``
    (first come first serve), ``"lof"`` (longest output first), ``"random"``
    and ``"dfs-weight"``.
    """

    def __init__(self, policy: str, tree_cache: "BasePrefixCache"):
        """Store the policy name and the prefix cache used for prefix matching.

        Args:
            policy: Name of the scheduling policy.
            tree_cache: Prefix cache; its ``disable`` flag, ``match_prefix``
                method and ``root_node`` attribute are used by this class.
        """
        if tree_cache.disable and policy in ["lpm", "dfs-weight"]:
            # LPM and DFS-weight are meaningless when the tree cache is disabled.
            policy = "fcfs"

        self.policy = policy
        self.tree_cache = tree_cache

    def calc_priority(self, waiting_queue: List["Req"]) -> bool:
        """Reorder ``waiting_queue`` in place according to ``self.policy``.

        For the ``"lpm"`` and ``"dfs-weight"`` policies every request first gets
        its ``prefix_indices`` and ``last_node`` refreshed from the tree cache.

        Args:
            waiting_queue: Requests waiting to be scheduled; mutated in place.

        Returns:
            True if the prefix match was computed for the requests (policies
            ``"lpm"`` and ``"dfs-weight"``), False otherwise.

        Raises:
            ValueError: If ``self.policy`` is not a recognized policy name.
        """
        # Compute matched prefix length
        prefix_computed = False
        if self.policy in ("lpm", "dfs-weight"):
            for r in waiting_queue:
                # NOTE: the prefix_indices must always be aligned with last_node
                r.prefix_indices, r.last_node = self.tree_cache.match_prefix(
                    rid=r.rid, key=r.adjust_max_prefix_ids()
                )

            prefix_computed = True

        if self.policy == "lpm":
            # Longest Prefix Match
            waiting_queue.sort(key=lambda x: -len(x.prefix_indices))
        elif self.policy == "fcfs":
            # first come first serve
            pass
        elif self.policy == "lof":
            # longest output first
            waiting_queue.sort(key=lambda x: -x.sampling_params.max_new_tokens)
        elif self.policy == "random":
            random.shuffle(waiting_queue)
        elif self.policy == "dfs-weight":
            # Group the requests by the tree node their matched prefix ends at.
            last_node_to_reqs = defaultdict(list)
            for req in waiting_queue:
                last_node_to_reqs[req.last_node].append(req)

            # A node's initial weight is the number of requests ending at it;
            # calc_weight then adds the weights of all descendants.
            node_to_weight = defaultdict(int)
            for node, reqs in last_node_to_reqs.items():
                node_to_weight[node] = len(reqs)
            self.calc_weight(self.tree_cache.root_node, node_to_weight)

            # Build the new order in a scratch list and only then overwrite the
            # queue, so a failure during traversal does not lose the requests.
            ordered = []
            self.get_dfs_priority(
                self.tree_cache.root_node,
                node_to_weight,
                last_node_to_reqs,
                ordered,
            )
            waiting_queue[:] = ordered
        else:
            raise ValueError(f"Unknown schedule_policy: {self.policy}")

        return prefix_computed

    def calc_weight(self, cur_node: "TreeNode", node_to_weight: Dict):
        """Add the weights of all descendants to every node under ``cur_node``.

        The traversal is iterative so that deep trees do not hit the
        interpreter's recursion limit.

        Args:
            cur_node: Root of the subtree to process; nodes expose ``children``
                as a mapping whose values are child nodes.
            node_to_weight: Mapping from node to weight, updated in place.
        """
        # Pre-order listing: every node appears before all of its descendants.
        visit_order = []
        pending = [cur_node]
        while pending:
            node = pending.pop()
            visit_order.append(node)
            pending.extend(node.children.values())

        # Walking the listing backwards finishes every child before its parent,
        # so each child's weight is final when it is added to the parent.
        for node in reversed(visit_order):
            for child in node.children.values():
                node_to_weight[node] += node_to_weight[child]

    def get_dfs_priority(
        self,
        cur_node: "TreeNode",
        node_to_priority: Dict,
        last_node_to_reqs: Dict,
        q: List,
    ):
        """Append requests to ``q`` in depth-first, heaviest-subtree-first order.

        Children of each node are visited in descending priority (ties keep the
        order of ``children``); a node's own requests are appended after all of
        its descendants' requests. The traversal is iterative so that deep
        trees do not hit the interpreter's recursion limit.

        Args:
            cur_node: Root of the subtree to traverse.
            node_to_priority: Mapping from node to its priority (weight).
            last_node_to_reqs: Mapping from node to the requests ending at it;
                nodes without an entry contribute nothing and are not inserted.
            q: Output list, extended in place.
        """
        # Each stack entry is (node, children_already_scheduled).
        stack = [(cur_node, False)]
        while stack:
            node, expanded = stack.pop()
            if expanded:
                # All descendants are done; emit this node's own requests.
                q.extend(last_node_to_reqs.get(node, ()))
                continue

            stack.append((node, True))
            childs = list(node.children.values())
            childs.sort(key=lambda x: -node_to_priority[x])
            # Push in reverse so the highest-priority child is popped first.
            stack.extend((child, False) for child in reversed(childs))
|
|
|
|
|
|
class AddReqResult(Enum):
    """Outcome reported after trying to add a request to the prefill batch."""

    # Continue to add requests
    CONTINUE = auto()
    # No token left
    NO_TOKEN = auto()
    # Other reasons to stop adding requests
    OTHER = auto()
|
|
|
|
|
|
class PrefillAdder:
    """Track token budgets while picking requests for the next prefill batch.

    Accepted requests are collected in ``can_run_list``. A request whose input
    had to be truncated to the chunk budget is additionally stored in
    ``new_inflight_req``.
    """

    def __init__(
        self,
        tree_cache: "BasePrefixCache",
        running_batch: "ScheduleBatch",
        new_token_ratio: float,
        rem_total_tokens: int,
        rem_input_tokens: int,
        rem_chunk_tokens: Optional[int],
        mixed_with_decode_tokens: int = 0,
    ):
        """Initialize the budgets.

        Args:
            tree_cache: Prefix cache used for node locking.
            running_batch: Currently running batch, or None.
            new_token_ratio: Factor applied to the remaining output length of
                running requests when reserving tokens for them.
            rem_total_tokens: Total token budget.
            rem_input_tokens: Input (prefill) token budget.
            rem_chunk_tokens: Chunked-prefill budget, or None for no chunk limit.
            mixed_with_decode_tokens: Tokens subtracted up front from every
                budget.
        """
        self.tree_cache = tree_cache
        self.running_batch = running_batch
        self.new_token_ratio = new_token_ratio
        self.rem_total_tokens = rem_total_tokens - mixed_with_decode_tokens
        self.rem_input_tokens = rem_input_tokens - mixed_with_decode_tokens
        self.rem_chunk_tokens = rem_chunk_tokens
        if self.rem_chunk_tokens is not None:
            self.rem_chunk_tokens -= mixed_with_decode_tokens

        # Unlike rem_total_tokens, this budget is not reduced by the estimated
        # output of running requests below.
        self.cur_rem_tokens = rem_total_tokens - mixed_with_decode_tokens

        # Sorted (tokens_left, tokens_occupied) pairs; built lazily by
        # add_one_req_ignore_eos.
        self.req_states = None
        self.can_run_list = []
        self.new_inflight_req = None
        self.log_hit_tokens = 0
        self.log_input_tokens = 0

        if running_batch is not None:
            # Pre-remove the tokens which will be occupied by the running requests
            self.rem_total_tokens -= sum(
                min(
                    (r.sampling_params.max_new_tokens - len(r.output_ids)),
                    CLIP_MAX_NEW_TOKENS,
                )
                * self.new_token_ratio
                for r in running_batch.reqs
            )

    def budget_state(self):
        """Report whether more requests can be added given the current budgets.

        Returns:
            ``AddReqResult.NO_TOKEN`` if the total or current token budget is
            exhausted, ``AddReqResult.OTHER`` if the input or chunk budget is
            exhausted, otherwise ``AddReqResult.CONTINUE``.
        """
        if self.rem_total_tokens <= 0 or self.cur_rem_tokens <= 0:
            return AddReqResult.NO_TOKEN

        if self.rem_input_tokens <= 0 or (
            self.rem_chunk_tokens is not None and self.rem_chunk_tokens <= 0
        ):
            return AddReqResult.OTHER

        return AddReqResult.CONTINUE

    def _prefill_one_req(
        self, prefix_len: int, extend_input_len: int, max_new_tokens: int
    ):
        """Charge one accepted request against the budgets and log counters.

        Args:
            prefix_len: Number of prefix tokens, added to ``log_hit_tokens``.
            extend_input_len: Number of input tokens to prefill.
            max_new_tokens: Output tokens to reserve in the total budget.
        """
        self.rem_total_tokens -= extend_input_len + max_new_tokens
        self.cur_rem_tokens -= extend_input_len
        self.rem_input_tokens -= extend_input_len
        if self.rem_chunk_tokens is not None:
            self.rem_chunk_tokens -= extend_input_len

        self.log_hit_tokens += prefix_len
        self.log_input_tokens += extend_input_len

    def add_inflight_req(self, req: "Req"):
        """Add the next chunk of a partially prefilled request.

        ``rem_chunk_tokens`` is compared with an int here, so it must not be
        None when this method is called.

        Returns:
            ``req`` if its input was truncated again (chunked prefill not
            finished), otherwise None.
        """
        truncated = req.extend_input_len > self.rem_chunk_tokens
        req.extend_input_len = min(req.extend_input_len, self.rem_chunk_tokens)
        req.fill_ids = req.fill_ids[: len(req.prefix_indices) + req.extend_input_len]
        self.can_run_list.append(req)

        self._prefill_one_req(
            len(req.prefix_indices),
            req.extend_input_len,
            (
                # Output tokens are only reserved once the last chunk is added.
                min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS)
                if not truncated
                else 0
            ),
        )

        # Return if chunked prefill not finished
        return req if truncated else None

    @contextmanager
    def _lock_node(self, last_node: "TreeNode"):
        """Hold a lock reference on ``last_node`` for the duration of the block.

        The deltas returned by ``inc_lock_ref`` and ``dec_lock_ref`` are added
        to ``rem_total_tokens``.
        """
        # Acquire outside the try block: if acquiring fails there is nothing to
        # release, so dec_lock_ref must not run.
        delta = self.tree_cache.inc_lock_ref(last_node)
        try:
            self.rem_total_tokens += delta
            yield None
        finally:
            delta = self.tree_cache.dec_lock_ref(last_node)
            self.rem_total_tokens += delta

    def _add_req_state(self, r, insert_sort=False):
        """Record ``(tokens_left, tokens_occupied)`` for ``r`` in ``req_states``.

        Requests with no tokens left are skipped. With ``insert_sort`` the pair
        is inserted before the first entry whose ``tokens_left`` is not smaller
        (or at the end if there is none), keeping ``req_states`` sorted;
        otherwise it is appended and the caller sorts afterwards.
        """
        new_token_ratio = (
            1.0 if r.sampling_params.ignore_eos else self.new_token_ratio
        )
        tokens_left = r.sampling_params.max_new_tokens * new_token_ratio - len(
            r.output_ids
        )
        tokens_occupied = len(r.origin_input_ids) + len(r.output_ids)

        if tokens_left <= 0:
            return

        if not insert_sort:
            self.req_states.append((tokens_left, tokens_occupied))
            return

        # Default to len(req_states) so that a new maximum goes to the end
        # rather than before the current last element.
        insert_at = next(
            (
                idx
                for idx, (left, _) in enumerate(self.req_states)
                if tokens_left <= left
            ),
            len(self.req_states),
        )
        self.req_states.insert(insert_at, (tokens_left, tokens_occupied))

    def add_one_req_ignore_eos(self, req: "Req"):
        """Try to add ``req`` using the sorted per-request state simulation.

        Called from ``add_one_req`` when the request ignores EOS and the tree
        cache is disabled.

        Returns:
            ``AddReqResult.NO_TOKEN`` if the simulated budget runs out,
            ``AddReqResult.OTHER`` if the request needs chunking but no chunk
            budget is left, otherwise the result of ``budget_state()`` after the
            request has been added.
        """
        if self.req_states is None:
            # First call: collect the states of the new request, the running
            # requests and the already accepted ones, then sort once.
            self.req_states = []
            self._add_req_state(req)
            if self.running_batch is not None:
                for r in self.running_batch.reqs:
                    self._add_req_state(r)
            for r in self.can_run_list:
                self._add_req_state(r)
            self.req_states.sort(key=lambda x: x[0])
        else:
            self._add_req_state(req, insert_sort=True)

        cur_rem_tokens = self.cur_rem_tokens - len(req.origin_input_ids)
        tokens_freed = 0
        for i, (tokens_left, tokens_occupied) in enumerate(self.req_states):
            decode_steps = (
                self.req_states[i + 1][0]
                if i + 1 < len(self.req_states)
                else tokens_left
            )
            # Number of entries from position i to the end of the sorted list.
            bs = len(self.req_states) - i
            if cur_rem_tokens + tokens_freed - decode_steps * bs <= 0:
                return AddReqResult.NO_TOKEN
            tokens_freed += tokens_occupied

        if (
            self.rem_chunk_tokens is None
            or req.extend_input_len <= self.rem_chunk_tokens
        ):
            self.can_run_list.append(req)
            self._prefill_one_req(
                0,
                req.extend_input_len,
                min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS),
            )
        else:
            # Chunked prefill
            trunc_len = self.rem_chunk_tokens
            if trunc_len <= 0:
                # No chunk budget left: a zero or negative length would produce
                # an empty or wrongly sliced fill_ids.
                return AddReqResult.OTHER

            req.extend_input_len = trunc_len
            req.fill_ids = req.fill_ids[:trunc_len]
            self.can_run_list.append(req)
            self.new_inflight_req = req
            self._prefill_one_req(0, trunc_len, 0)

        return self.budget_state()

    def add_one_req(self, req: "Req"):
        """Try to add ``req`` to the prefill batch.

        Returns:
            ``AddReqResult.NO_TOKEN`` if the request does not fit into the total
            token budget, ``AddReqResult.OTHER`` if it exceeds the input budget
            while other requests are already accepted or if no chunk budget is
            left, otherwise the result of ``budget_state()`` after adding it.
        """
        if req.sampling_params.ignore_eos and self.tree_cache.disable:
            return self.add_one_req_ignore_eos(req)

        total_tokens = req.extend_input_len + min(
            req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS
        )
        input_tokens = req.extend_input_len
        prefix_len = len(req.prefix_indices)

        if total_tokens >= self.rem_total_tokens:
            return AddReqResult.NO_TOKEN

        if input_tokens > self.rem_input_tokens and len(self.can_run_list) != 0:
            return AddReqResult.OTHER

        with self._lock_node(req.last_node):
            # Re-check: entering _lock_node adjusted rem_total_tokens by the
            # delta returned from inc_lock_ref.
            if total_tokens > self.rem_total_tokens:
                return AddReqResult.NO_TOKEN

            if (
                self.rem_chunk_tokens is None
                or input_tokens <= self.rem_chunk_tokens
                or (req.return_logprob and req.normalized_prompt_logprob is None)
            ):
                # Non-chunked prefill
                self.can_run_list.append(req)
                # This reference is not released here, so it outlives the
                # temporary one taken by _lock_node.
                self.tree_cache.inc_lock_ref(req.last_node)
                self._prefill_one_req(
                    prefix_len,
                    input_tokens,
                    min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS),
                )
            else:
                # Chunked prefill
                trunc_len = self.rem_chunk_tokens
                if trunc_len <= 0:
                    # Also covers a negative chunk budget, which would yield a
                    # negative extend_input_len and a wrong fill_ids slice.
                    return AddReqResult.OTHER

                req.extend_input_len = trunc_len
                req.fill_ids = req.fill_ids[: len(req.prefix_indices) + trunc_len]
                self.can_run_list.append(req)
                self.new_inflight_req = req
                self.tree_cache.inc_lock_ref(req.last_node)
                self._prefill_one_req(prefix_len, trunc_len, 0)

        return self.budget_state()
|