make radix cache deterministic (#10721)
Signed-off-by: Alex Chi Z <iskyzh@gmail.com>
This commit is contained in:
@@ -163,6 +163,7 @@ from sglang.srt.tracing.trace import (
|
||||
)
|
||||
from sglang.srt.two_batch_overlap import TboDPAttentionPreparer
|
||||
from sglang.srt.utils import (
|
||||
DEFAULT_DETERMINISTIC_INFERENCE_BACKEND_SIZE_CONFIG,
|
||||
DynamicGradMode,
|
||||
broadcast_pyobj,
|
||||
configure_gc_logger,
|
||||
@@ -705,11 +706,7 @@ class Scheduler(
|
||||
self.truncation_align_size = None
|
||||
return
|
||||
|
||||
backend_sizes = {
|
||||
"flashinfer": ("SGLANG_FLASHINFER_PREFILL_SPLIT_TILE_SIZE", 4096),
|
||||
"triton": ("SGLANG_TRITON_PREFILL_TRUNCATION_ALIGN_SIZE", 4096),
|
||||
}
|
||||
env_var, default_size = backend_sizes.get(
|
||||
env_var, default_size = DEFAULT_DETERMINISTIC_INFERENCE_BACKEND_SIZE_CONFIG.get(
|
||||
self.server_args.attention_backend, (None, None)
|
||||
)
|
||||
self.truncation_align_size = (
|
||||
@@ -849,6 +846,7 @@ class Scheduler(
|
||||
disable=server_args.disable_radix_cache,
|
||||
enable_kv_cache_events=self.enable_kv_cache_events,
|
||||
eviction_policy=server_args.radix_eviction_policy,
|
||||
enable_deterministic_inference=server_args.enable_deterministic_inference,
|
||||
is_eagle=self.spec_algorithm.is_eagle(),
|
||||
)
|
||||
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from sglang.srt.utils import DEFAULT_DETERMINISTIC_INFERENCE_BACKEND_SIZE
|
||||
|
||||
"""
|
||||
Copyright 2023-2024 SGLang Team
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@@ -185,6 +187,7 @@ class RadixCache(BasePrefixCache):
|
||||
disable: bool = False,
|
||||
enable_kv_cache_events: bool = False,
|
||||
eviction_policy: str = "lru",
|
||||
enable_deterministic_inference: bool = False,
|
||||
is_eagle: bool = False,
|
||||
):
|
||||
self.req_to_token_pool = req_to_token_pool
|
||||
@@ -193,6 +196,8 @@ class RadixCache(BasePrefixCache):
|
||||
self.disable = disable
|
||||
self.enable_kv_cache_events = enable_kv_cache_events
|
||||
self.kv_event_queue = []
|
||||
self.enable_deterministic_inference = enable_deterministic_inference
|
||||
self.split_size = DEFAULT_DETERMINISTIC_INFERENCE_BACKEND_SIZE
|
||||
self.is_eagle = is_eagle
|
||||
|
||||
if self.token_to_kv_pool_allocator:
|
||||
@@ -234,7 +239,9 @@ class RadixCache(BasePrefixCache):
|
||||
self.protected_size_ = 0
|
||||
self._record_all_cleared_event()
|
||||
|
||||
def match_prefix(self, key: RadixKey, **kwargs) -> MatchResult:
|
||||
def match_prefix(
|
||||
self, key: RadixKey, is_cache_unfinished: bool = False, **kwargs
|
||||
) -> MatchResult:
|
||||
"""Find the longest cached prefix of ``key`` in the radix tree.
|
||||
|
||||
The logical namespace for prefix matching is determined by both the
|
||||
@@ -295,7 +302,9 @@ class RadixCache(BasePrefixCache):
|
||||
if len(key) == 0:
|
||||
return empty_match_result()
|
||||
|
||||
value, last_node = self._match_prefix_helper(self.root_node, key)
|
||||
value, last_node = self._match_prefix_helper(
|
||||
self.root_node, key, is_cache_unfinished=is_cache_unfinished
|
||||
)
|
||||
if value:
|
||||
value = torch.cat(value)
|
||||
else:
|
||||
@@ -418,7 +427,8 @@ class RadixCache(BasePrefixCache):
|
||||
|
||||
# The prefix indices could be updated, reuse it
|
||||
new_indices, new_last_node, _, _ = self.match_prefix(
|
||||
RadixKey(token_ids=page_aligned_token_ids, extra_key=req.extra_key)
|
||||
RadixKey(token_ids=page_aligned_token_ids, extra_key=req.extra_key),
|
||||
is_cache_unfinished=True,
|
||||
)
|
||||
self.req_to_token_pool.write(
|
||||
(req.req_pool_idx, slice(old_prefix_len, len(new_indices))),
|
||||
@@ -534,16 +544,58 @@ class RadixCache(BasePrefixCache):
|
||||
|
||||
##### Internal Helper Functions #####
|
||||
|
||||
def _match_prefix_helper(self, node: TreeNode, key: RadixKey):
|
||||
def _match_prefix_helper(
|
||||
self, node: TreeNode, key: RadixKey, is_cache_unfinished: bool
|
||||
):
|
||||
node.last_access_time = time.monotonic()
|
||||
|
||||
child_key = self.get_child_key_fn(key)
|
||||
|
||||
value = []
|
||||
align_split_size = (
|
||||
not is_cache_unfinished and self.enable_deterministic_inference
|
||||
)
|
||||
match_history = [node] if align_split_size else None
|
||||
|
||||
if align_split_size and len(key) < self.split_size:
|
||||
# fast path: directly return the root node if the split point is 0
|
||||
return value, node
|
||||
|
||||
# use the access history to first find a split point at split_size and then return the value and node at that point.
|
||||
def reconstruct_at_split_point(match_history, value_len):
|
||||
# reverse the search process to find the last node right above the split_size, split here
|
||||
split_point = value_len // self.split_size * self.split_size
|
||||
# rebuild value form history
|
||||
value = []
|
||||
current_value_len = 0
|
||||
node = match_history[0] # this is the root node
|
||||
for idx, node in enumerate(match_history):
|
||||
match_len = len(node.value)
|
||||
if current_value_len + match_len > split_point:
|
||||
# split the node at the desired split point
|
||||
node = self._split_node(
|
||||
node.key, node, split_point - current_value_len
|
||||
)
|
||||
value.append(node.value)
|
||||
return value, node
|
||||
elif current_value_len + match_len == split_point:
|
||||
if idx != 0:
|
||||
value.append(node.value)
|
||||
return value, node
|
||||
current_value_len += match_len
|
||||
if idx != 0:
|
||||
# the root node always has empty value, skip
|
||||
value.append(node.value)
|
||||
# return the root node as the corresponding node doesn't exist yet
|
||||
# and the previously computed node is not at the split boundary
|
||||
return [], match_history[0]
|
||||
|
||||
while len(key) > 0 and child_key in node.children.keys():
|
||||
child = node.children[child_key]
|
||||
child.last_access_time = time.monotonic()
|
||||
prefix_len = self.key_match_fn(child.key, key)
|
||||
if align_split_size:
|
||||
match_history.append(child)
|
||||
if prefix_len < len(child.key):
|
||||
new_node = self._split_node(child.key, child, prefix_len)
|
||||
value.append(new_node.value)
|
||||
@@ -557,6 +609,13 @@ class RadixCache(BasePrefixCache):
|
||||
if len(key):
|
||||
child_key = self.get_child_key_fn(key)
|
||||
|
||||
if align_split_size:
|
||||
value_len = sum(map(len, value))
|
||||
value, node = reconstruct_at_split_point(match_history, value_len)
|
||||
assert (
|
||||
sum(map(len, value)) % self.split_size == 0
|
||||
), "The value length is not aligned with the split size"
|
||||
|
||||
return value, node
|
||||
|
||||
def _split_node(self, key: RadixKey, child: TreeNode, split_len: int):
|
||||
|
||||
@@ -1381,13 +1381,6 @@ class ServerArgs:
|
||||
f"Currently only {DETERMINISTIC_ATTENTION_BACKEND_CHOICES} attention backends are supported for deterministic inference."
|
||||
)
|
||||
|
||||
# Currently, only FA3 supports radix cache. Support for other backends is in progress
|
||||
if self.attention_backend != "fa3":
|
||||
self.disable_radix_cache = True
|
||||
logger.warning(
|
||||
f"Currently radix cache is not compatible with {self.attention_backend} attention backend for deterministic inference. It will be supported in the future."
|
||||
)
|
||||
|
||||
# Check TP size
|
||||
if self.tp_size > 1:
|
||||
os.environ["NCCL_ALGO"] = "allreduce:tree"
|
||||
|
||||
@@ -3441,3 +3441,16 @@ def cached_triton_kernel(key_fn=None):
|
||||
return CachedKernel(fn, key_fn)
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
DEFAULT_DETERMINISTIC_INFERENCE_BACKEND_SIZE = 4096
|
||||
DEFAULT_DETERMINISTIC_INFERENCE_BACKEND_SIZE_CONFIG = {
|
||||
"flashinfer": (
|
||||
"SGLANG_FLASHINFER_PREFILL_SPLIT_TILE_SIZE",
|
||||
DEFAULT_DETERMINISTIC_INFERENCE_BACKEND_SIZE,
|
||||
),
|
||||
"triton": (
|
||||
"SGLANG_TRITON_PREFILL_TRUNCATION_ALIGN_SIZE",
|
||||
DEFAULT_DETERMINISTIC_INFERENCE_BACKEND_SIZE,
|
||||
),
|
||||
}
|
||||
|
||||
@@ -277,9 +277,10 @@ def test_deterministic(args):
|
||||
|
||||
elif args.test_mode == "prefix":
|
||||
# In prefix mode, we create prompts from the same long prompt, with different lengths of common prefix.
|
||||
len_prefix = [1, 511, 2048, 4097]
|
||||
len_prefix = [1, 8000, 10000, 12500]
|
||||
num_prompts = len(len_prefix)
|
||||
outputs = {i: [] for i in range(4)}
|
||||
assert all(i <= len(LONG_PROMPT) for i in len_prefix)
|
||||
prompts = [LONG_PROMPT[: len_prefix[i]] for i in range(4)]
|
||||
for i in range(args.n_start, args.n_start + args.n_trials):
|
||||
batch_size = i
|
||||
|
||||
Reference in New Issue
Block a user