Reduce the overhead when cache is disabled (#1010)
This commit is contained in:
@@ -18,44 +18,40 @@ limitations under the License.
|
||||
import random
|
||||
from collections import defaultdict
|
||||
from contextlib import contextmanager
|
||||
from typing import List
|
||||
|
||||
from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
|
||||
|
||||
|
||||
class PolicyScheduler:
|
||||
def __init__(
|
||||
self,
|
||||
policy,
|
||||
max_running_seqs,
|
||||
max_prefill_num_tokens,
|
||||
max_total_num_tokens,
|
||||
tree_cache,
|
||||
):
|
||||
if tree_cache.disable and policy == "lpm":
|
||||
# LPM is meaningless when the tree cache is disabled.
|
||||
def __init__(self, policy, tree_cache):
|
||||
if tree_cache.disable and policy in ["lpm", "dfs-weight"]:
|
||||
# LPM and DFS-weight are meaningless when the tree cache is disabled.
|
||||
policy = "fcfs"
|
||||
|
||||
self.policy = policy
|
||||
self.max_running_seqs = max_running_seqs
|
||||
self.max_prefill_num_tokens = max_prefill_num_tokens
|
||||
self.max_total_num_tokens = max_total_num_tokens
|
||||
self.tree_cache = tree_cache
|
||||
|
||||
def get_priority_queue(self, waiting_queue):
|
||||
def calc_priority(self, waiting_queue: List[Req]):
|
||||
if self.policy in ["lpm", "dfs-weight"]:
|
||||
# Compute matched prefix length
|
||||
for r in waiting_queue:
|
||||
# NOTE: the prefix_indices must always be aligned with last_node
|
||||
r.prefix_indices, r.last_node = self.tree_cache.match_prefix(
|
||||
rid=r.rid, key=r.adjust_max_prefix_ids()
|
||||
)
|
||||
|
||||
if self.policy == "lpm":
|
||||
# longest prefix match
|
||||
# Longest Prefix Match
|
||||
waiting_queue.sort(key=lambda x: -len(x.prefix_indices))
|
||||
return waiting_queue
|
||||
elif self.policy == "fcfs":
|
||||
# first come first serve
|
||||
return waiting_queue
|
||||
pass
|
||||
elif self.policy == "lof":
|
||||
# longest output first
|
||||
waiting_queue.sort(key=lambda x: -x.sampling_params.max_new_tokens)
|
||||
return waiting_queue
|
||||
elif self.policy == "random":
|
||||
random.shuffle(waiting_queue)
|
||||
return waiting_queue
|
||||
elif self.policy == "dfs-weight":
|
||||
last_node_to_reqs = defaultdict(list)
|
||||
for req in waiting_queue:
|
||||
@@ -66,12 +62,13 @@ class PolicyScheduler:
|
||||
node_to_weight[node] = len(last_node_to_reqs[node])
|
||||
self.calc_weight(self.tree_cache.root_node, node_to_weight)
|
||||
|
||||
q = []
|
||||
waiting_queue.clear()
|
||||
self.get_dfs_priority(
|
||||
self.tree_cache.root_node, node_to_weight, last_node_to_reqs, q
|
||||
self.tree_cache.root_node,
|
||||
node_to_weight,
|
||||
last_node_to_reqs,
|
||||
waiting_queue,
|
||||
)
|
||||
assert len(q) == len(waiting_queue)
|
||||
return q
|
||||
else:
|
||||
raise ValueError(f"Unknown schedule_policy: {self.policy}")
|
||||
|
||||
@@ -139,8 +136,6 @@ class PrefillAdder:
|
||||
self.log_input_tokens += extend_input_len
|
||||
|
||||
def add_inflight_req(self, req: Req):
|
||||
req.input_ids = req.origin_input_ids + req.output_ids
|
||||
req.extend_input_len = len(req.input_ids) - len(req.prefix_indices)
|
||||
truncated = req.extend_input_len > self.rem_chunk_tokens
|
||||
req.extend_input_len = min(req.extend_input_len, self.rem_chunk_tokens)
|
||||
req.input_ids = req.input_ids[: len(req.prefix_indices) + req.extend_input_len]
|
||||
|
||||
@@ -164,7 +164,12 @@ class Req:
|
||||
def finished(self) -> bool:
|
||||
return self.finished_reason is not None
|
||||
|
||||
def init_next_round_input(self):
|
||||
self.input_ids = self.origin_input_ids + self.output_ids
|
||||
self.extend_input_len = len(self.input_ids) - len(self.prefix_indices)
|
||||
|
||||
def adjust_max_prefix_ids(self):
|
||||
self.input_ids = self.origin_input_ids + self.output_ids
|
||||
input_len = len(self.input_ids)
|
||||
max_prefix_len = input_len
|
||||
|
||||
|
||||
@@ -165,13 +165,7 @@ class ModelTpServer:
|
||||
disable=server_args.disable_radix_cache,
|
||||
)
|
||||
self.tree_cache_metrics = {"total": 0, "hit": 0}
|
||||
self.scheduler = PolicyScheduler(
|
||||
self.schedule_policy,
|
||||
self.max_running_requests,
|
||||
self.max_prefill_tokens,
|
||||
self.max_total_num_tokens,
|
||||
self.tree_cache,
|
||||
)
|
||||
self.scheduler = PolicyScheduler(self.schedule_policy, self.tree_cache)
|
||||
self.req_to_token_pool = self.model_runner.req_to_token_pool
|
||||
self.token_to_kv_pool = self.model_runner.token_to_kv_pool
|
||||
|
||||
@@ -373,17 +367,8 @@ class ModelTpServer:
|
||||
if running_bs >= self.max_running_requests:
|
||||
return None
|
||||
|
||||
# Compute matched prefix length
|
||||
for req in self.waiting_queue:
|
||||
req.input_ids = req.origin_input_ids + req.output_ids
|
||||
# NOTE: the prefix_indices must always be aligned with last_node
|
||||
req.prefix_indices, req.last_node = self.tree_cache.match_prefix(
|
||||
rid=req.rid, key=req.adjust_max_prefix_ids()
|
||||
)
|
||||
req.extend_input_len = len(req.input_ids) - len(req.prefix_indices)
|
||||
|
||||
# Get priority queue
|
||||
self.waiting_queue = self.scheduler.get_priority_queue(self.waiting_queue)
|
||||
self.scheduler.calc_priority(self.waiting_queue)
|
||||
|
||||
adder = PrefillAdder(
|
||||
self.tree_cache,
|
||||
@@ -397,12 +382,13 @@ class ModelTpServer:
|
||||
|
||||
has_inflight = self.current_inflight_req is not None
|
||||
if self.current_inflight_req is not None:
|
||||
self.current_inflight_req.init_next_round_input()
|
||||
self.current_inflight_req = adder.add_inflight_req(
|
||||
self.current_inflight_req
|
||||
)
|
||||
|
||||
for req in self.waiting_queue:
|
||||
|
||||
req.init_next_round_input()
|
||||
res = adder.add_one_req(req)
|
||||
if (
|
||||
not res
|
||||
|
||||
@@ -169,6 +169,9 @@ class RadixCache(BasePrefixCache):
|
||||
heapq.heappush(leaves, x.parent)
|
||||
|
||||
def inc_lock_ref(self, node: TreeNode):
|
||||
if self.disable:
|
||||
return 0
|
||||
|
||||
delta = 0
|
||||
while node != self.root_node:
|
||||
if node.lock_ref == 0:
|
||||
@@ -179,6 +182,9 @@ class RadixCache(BasePrefixCache):
|
||||
return delta
|
||||
|
||||
def dec_lock_ref(self, node: TreeNode):
|
||||
if self.disable:
|
||||
return 0
|
||||
|
||||
delta = 0
|
||||
while node != self.root_node:
|
||||
if node.lock_ref == 1:
|
||||
|
||||
Reference in New Issue
Block a user