Add a benchmark script for in-batch prefix caching (#2494)
This commit is contained in:
@@ -34,11 +34,19 @@ CLIP_MAX_NEW_TOKENS_ESTIMATION = int(
|
||||
os.environ.get("SGLANG_CLIP_MAX_NEW_TOKENS_ESTIMATION", "4096")
|
||||
)
|
||||
|
||||
# The threshold to apply in-batch prefix caching.
|
||||
# If we use too small value, in-batch prefix caching cannot be used. E.g.,
|
||||
# imagine "the" prefix.
|
||||
IN_BATCH_PREFIX_CACHING_THRESHOLD = int(
|
||||
os.environ.get("SGLANG_IN_BATCH_PREFIX_CACHING_THRESHOLD", "32")
|
||||
# Threshold for in-batch prefix cache.
|
||||
# If a request has a matched prefix length (against existing cache) less than this value,
|
||||
# the scheduler runs the in-batch prefix caching check for this request.
|
||||
# If we set it to -1, it means we disable in-batch prefix caching.
|
||||
IN_BATCH_PREFIX_CACHING_CHECK_THRESHOLD = int(
|
||||
os.environ.get("IN_BATCH_PREFIX_CACHING_CHECK_THRESHOLD", "32")
|
||||
)
|
||||
|
||||
# Threshold for in-batch prefix cache.
|
||||
# If a request has a matched prefix length (within the waiting queue) larger than this value,
|
||||
# the scheduler deprioritizes this request
|
||||
IN_BATCH_PREFIX_CACHING_DEPRIORITIZE_THRESHOLD = int(
|
||||
os.environ.get("IN_BATCH_PREFIX_CACHING_DEPRIORITIZE_THRESHOLD", "32")
|
||||
)
|
||||
|
||||
|
||||
@@ -51,6 +59,11 @@ class SchedulePolicy:
|
||||
self.policy = policy
|
||||
self.tree_cache = tree_cache
|
||||
|
||||
# It is used to find the matching prefix for in-batch prefix caching.
|
||||
self.waiting_queue_radix_tree = RadixCache(
|
||||
req_to_token_pool=None, token_to_kv_pool=None, disable=False
|
||||
)
|
||||
|
||||
def calc_priority(self, waiting_queue: List[Req]):
|
||||
if len(waiting_queue) > 128 and self.policy == "lpm":
|
||||
# Turn off the expensive prefix matching and sorting when the #queue is large.
|
||||
@@ -60,50 +73,54 @@ class SchedulePolicy:
|
||||
|
||||
# Compute matched prefix length
|
||||
prefix_computed = False
|
||||
# rid to deprioritize in the current run.
|
||||
temporary_deprioritized = {}
|
||||
if policy == "lpm" or policy == "dfs-weight":
|
||||
# It is used to find the matching prefix for in-batch prefix caching.
|
||||
temp_radix = RadixCache(None, None, False)
|
||||
# rid to deprioritize in the current run for in-batch prefix caching.
|
||||
temporary_deprioritized = set()
|
||||
self.waiting_queue_radix_tree.reset()
|
||||
|
||||
for r in waiting_queue:
|
||||
prefix_ids = r.adjust_max_prefix_ids()
|
||||
|
||||
# NOTE: the prefix_indices must always be aligned with last_node
|
||||
r.prefix_indices, r.last_node = self.tree_cache.match_prefix(
|
||||
rid=r.rid, key=prefix_ids
|
||||
)
|
||||
|
||||
# NOTE(sang): This logic is for In-batch prefix caching;
|
||||
# NOTE(sang): This logic is for in-batch prefix caching;
|
||||
# If there are more than 1 request that have small matching prefix from
|
||||
# existing cache, but all those requests share the same prefix, we prefer
|
||||
# to schedule only one of them so that we can increase the cache hit rate.
|
||||
# We prefer to set IN_BATCH_PREFIX_CACHING_THRESHOLD > 0 because too small
|
||||
# We prefer to set IN_BATCH_PREFIX_CACHING_CHECK_THRESHOLD > 0 because too small
|
||||
# threshold means we cannot use in-batch prefix caching for short prefixes.
|
||||
# It is kind of common when the engine is long running (e.g., imagine "the").
|
||||
if len(r.prefix_indices) <= IN_BATCH_PREFIX_CACHING_THRESHOLD:
|
||||
in_batch_matching_prefixes, _ = temp_radix.match_prefix(
|
||||
rid=r.rid, key=prefix_ids
|
||||
# It is kind of common when the engine is long running (e.g., imagine the prefix "the").
|
||||
if len(r.prefix_indices) <= IN_BATCH_PREFIX_CACHING_CHECK_THRESHOLD:
|
||||
in_batch_matching_prefixes, _ = (
|
||||
self.waiting_queue_radix_tree.match_prefix(
|
||||
rid=r.rid, key=prefix_ids
|
||||
)
|
||||
)
|
||||
if (
|
||||
len(in_batch_matching_prefixes)
|
||||
>= IN_BATCH_PREFIX_CACHING_THRESHOLD
|
||||
>= IN_BATCH_PREFIX_CACHING_DEPRIORITIZE_THRESHOLD
|
||||
):
|
||||
temporary_deprioritized[r.rid] = r
|
||||
temporary_deprioritized.add(r.rid)
|
||||
else:
|
||||
temp_radix.insert(prefix_ids, torch.tensor(prefix_ids))
|
||||
# Insert with a dummy key
|
||||
self.waiting_queue_radix_tree.insert(
|
||||
prefix_ids, torch.empty(len(prefix_ids), dtype=torch.bool)
|
||||
)
|
||||
|
||||
prefix_computed = True
|
||||
|
||||
if policy == "lpm":
|
||||
# Longest Prefix Match
|
||||
def get_priority(r: Req):
|
||||
score = 0
|
||||
if r.rid in temporary_deprioritized:
|
||||
score = float("inf")
|
||||
else:
|
||||
score = -len(r.prefix_indices)
|
||||
return score
|
||||
|
||||
waiting_queue.sort(key=get_priority)
|
||||
waiting_queue.sort(
|
||||
key=lambda r: (
|
||||
-len(r.prefix_indices)
|
||||
if r.rid not in temporary_deprioritized
|
||||
else float("inf")
|
||||
)
|
||||
)
|
||||
elif policy == "fcfs":
|
||||
# first come first serve
|
||||
pass
|
||||
@@ -113,11 +130,11 @@ class SchedulePolicy:
|
||||
elif policy == "random":
|
||||
random.shuffle(waiting_queue)
|
||||
elif policy == "dfs-weight":
|
||||
# Experimental policy based on custom weights
|
||||
last_node_to_reqs = defaultdict(list)
|
||||
for req in waiting_queue:
|
||||
last_node_to_reqs[req.last_node].append(req)
|
||||
|
||||
# node -> # of requests for that node.
|
||||
node_to_weight = defaultdict(int)
|
||||
for node in last_node_to_reqs:
|
||||
node_to_weight[node] = len(last_node_to_reqs[node])
|
||||
@@ -129,9 +146,7 @@ class SchedulePolicy:
|
||||
node_to_weight,
|
||||
last_node_to_reqs,
|
||||
waiting_queue,
|
||||
temporary_deprioritized,
|
||||
)
|
||||
waiting_queue.extend(temporary_deprioritized.values())
|
||||
else:
|
||||
raise ValueError(f"Unknown schedule_policy: {policy=}")
|
||||
|
||||
@@ -148,19 +163,12 @@ class SchedulePolicy:
|
||||
node_to_priority: Dict[TreeNode, int],
|
||||
last_node_to_reqs: Dict[TreeNode, List[Req]],
|
||||
q: List,
|
||||
temporary_deprioritized: Dict[str, Req],
|
||||
):
|
||||
childs = [child for child in cur_node.children.values()]
|
||||
childs.sort(key=lambda x: -node_to_priority[x])
|
||||
for child in childs:
|
||||
self.get_dfs_priority(
|
||||
child, node_to_priority, last_node_to_reqs, q, temporary_deprioritized
|
||||
)
|
||||
|
||||
for req in last_node_to_reqs[cur_node]:
|
||||
if req.rid in temporary_deprioritized:
|
||||
continue
|
||||
q.append(req)
|
||||
self.get_dfs_priority(child, node_to_priority, last_node_to_reqs, q)
|
||||
q.extend(last_node_to_reqs[cur_node])
|
||||
|
||||
|
||||
class AddReqResult(Enum):
|
||||
|
||||
Reference in New Issue
Block a user