make radix cache deterministic (#10721)
Signed-off-by: Alex Chi Z <iskyzh@gmail.com>
This commit is contained in:
@@ -163,6 +163,7 @@ from sglang.srt.tracing.trace import (
|
|||||||
)
|
)
|
||||||
from sglang.srt.two_batch_overlap import TboDPAttentionPreparer
|
from sglang.srt.two_batch_overlap import TboDPAttentionPreparer
|
||||||
from sglang.srt.utils import (
|
from sglang.srt.utils import (
|
||||||
|
DEFAULT_DETERMINISTIC_INFERENCE_BACKEND_SIZE_CONFIG,
|
||||||
DynamicGradMode,
|
DynamicGradMode,
|
||||||
broadcast_pyobj,
|
broadcast_pyobj,
|
||||||
configure_gc_logger,
|
configure_gc_logger,
|
||||||
@@ -705,11 +706,7 @@ class Scheduler(
|
|||||||
self.truncation_align_size = None
|
self.truncation_align_size = None
|
||||||
return
|
return
|
||||||
|
|
||||||
backend_sizes = {
|
env_var, default_size = DEFAULT_DETERMINISTIC_INFERENCE_BACKEND_SIZE_CONFIG.get(
|
||||||
"flashinfer": ("SGLANG_FLASHINFER_PREFILL_SPLIT_TILE_SIZE", 4096),
|
|
||||||
"triton": ("SGLANG_TRITON_PREFILL_TRUNCATION_ALIGN_SIZE", 4096),
|
|
||||||
}
|
|
||||||
env_var, default_size = backend_sizes.get(
|
|
||||||
self.server_args.attention_backend, (None, None)
|
self.server_args.attention_backend, (None, None)
|
||||||
)
|
)
|
||||||
self.truncation_align_size = (
|
self.truncation_align_size = (
|
||||||
@@ -849,6 +846,7 @@ class Scheduler(
|
|||||||
disable=server_args.disable_radix_cache,
|
disable=server_args.disable_radix_cache,
|
||||||
enable_kv_cache_events=self.enable_kv_cache_events,
|
enable_kv_cache_events=self.enable_kv_cache_events,
|
||||||
eviction_policy=server_args.radix_eviction_policy,
|
eviction_policy=server_args.radix_eviction_policy,
|
||||||
|
enable_deterministic_inference=server_args.enable_deterministic_inference,
|
||||||
is_eagle=self.spec_algorithm.is_eagle(),
|
is_eagle=self.spec_algorithm.is_eagle(),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +1,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from sglang.srt.utils import DEFAULT_DETERMINISTIC_INFERENCE_BACKEND_SIZE
|
||||||
|
|
||||||
"""
|
"""
|
||||||
Copyright 2023-2024 SGLang Team
|
Copyright 2023-2024 SGLang Team
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
@@ -185,6 +187,7 @@ class RadixCache(BasePrefixCache):
|
|||||||
disable: bool = False,
|
disable: bool = False,
|
||||||
enable_kv_cache_events: bool = False,
|
enable_kv_cache_events: bool = False,
|
||||||
eviction_policy: str = "lru",
|
eviction_policy: str = "lru",
|
||||||
|
enable_deterministic_inference: bool = False,
|
||||||
is_eagle: bool = False,
|
is_eagle: bool = False,
|
||||||
):
|
):
|
||||||
self.req_to_token_pool = req_to_token_pool
|
self.req_to_token_pool = req_to_token_pool
|
||||||
@@ -193,6 +196,8 @@ class RadixCache(BasePrefixCache):
|
|||||||
self.disable = disable
|
self.disable = disable
|
||||||
self.enable_kv_cache_events = enable_kv_cache_events
|
self.enable_kv_cache_events = enable_kv_cache_events
|
||||||
self.kv_event_queue = []
|
self.kv_event_queue = []
|
||||||
|
self.enable_deterministic_inference = enable_deterministic_inference
|
||||||
|
self.split_size = DEFAULT_DETERMINISTIC_INFERENCE_BACKEND_SIZE
|
||||||
self.is_eagle = is_eagle
|
self.is_eagle = is_eagle
|
||||||
|
|
||||||
if self.token_to_kv_pool_allocator:
|
if self.token_to_kv_pool_allocator:
|
||||||
@@ -234,7 +239,9 @@ class RadixCache(BasePrefixCache):
|
|||||||
self.protected_size_ = 0
|
self.protected_size_ = 0
|
||||||
self._record_all_cleared_event()
|
self._record_all_cleared_event()
|
||||||
|
|
||||||
def match_prefix(self, key: RadixKey, **kwargs) -> MatchResult:
|
def match_prefix(
|
||||||
|
self, key: RadixKey, is_cache_unfinished: bool = False, **kwargs
|
||||||
|
) -> MatchResult:
|
||||||
"""Find the longest cached prefix of ``key`` in the radix tree.
|
"""Find the longest cached prefix of ``key`` in the radix tree.
|
||||||
|
|
||||||
The logical namespace for prefix matching is determined by both the
|
The logical namespace for prefix matching is determined by both the
|
||||||
@@ -295,7 +302,9 @@ class RadixCache(BasePrefixCache):
|
|||||||
if len(key) == 0:
|
if len(key) == 0:
|
||||||
return empty_match_result()
|
return empty_match_result()
|
||||||
|
|
||||||
value, last_node = self._match_prefix_helper(self.root_node, key)
|
value, last_node = self._match_prefix_helper(
|
||||||
|
self.root_node, key, is_cache_unfinished=is_cache_unfinished
|
||||||
|
)
|
||||||
if value:
|
if value:
|
||||||
value = torch.cat(value)
|
value = torch.cat(value)
|
||||||
else:
|
else:
|
||||||
@@ -418,7 +427,8 @@ class RadixCache(BasePrefixCache):
|
|||||||
|
|
||||||
# The prefix indices could be updated, reuse it
|
# The prefix indices could be updated, reuse it
|
||||||
new_indices, new_last_node, _, _ = self.match_prefix(
|
new_indices, new_last_node, _, _ = self.match_prefix(
|
||||||
RadixKey(token_ids=page_aligned_token_ids, extra_key=req.extra_key)
|
RadixKey(token_ids=page_aligned_token_ids, extra_key=req.extra_key),
|
||||||
|
is_cache_unfinished=True,
|
||||||
)
|
)
|
||||||
self.req_to_token_pool.write(
|
self.req_to_token_pool.write(
|
||||||
(req.req_pool_idx, slice(old_prefix_len, len(new_indices))),
|
(req.req_pool_idx, slice(old_prefix_len, len(new_indices))),
|
||||||
@@ -534,16 +544,58 @@ class RadixCache(BasePrefixCache):
|
|||||||
|
|
||||||
##### Internal Helper Functions #####
|
##### Internal Helper Functions #####
|
||||||
|
|
||||||
def _match_prefix_helper(self, node: TreeNode, key: RadixKey):
|
def _match_prefix_helper(
|
||||||
|
self, node: TreeNode, key: RadixKey, is_cache_unfinished: bool
|
||||||
|
):
|
||||||
node.last_access_time = time.monotonic()
|
node.last_access_time = time.monotonic()
|
||||||
|
|
||||||
child_key = self.get_child_key_fn(key)
|
child_key = self.get_child_key_fn(key)
|
||||||
|
|
||||||
value = []
|
value = []
|
||||||
|
align_split_size = (
|
||||||
|
not is_cache_unfinished and self.enable_deterministic_inference
|
||||||
|
)
|
||||||
|
match_history = [node] if align_split_size else None
|
||||||
|
|
||||||
|
if align_split_size and len(key) < self.split_size:
|
||||||
|
# fast path: directly return the root node if the split point is 0
|
||||||
|
return value, node
|
||||||
|
|
||||||
|
# use the access history to first find a split point at split_size and then return the value and node at that point.
|
||||||
|
def reconstruct_at_split_point(match_history, value_len):
|
||||||
|
# reverse the search process to find the last node right above the split_size, split here
|
||||||
|
split_point = value_len // self.split_size * self.split_size
|
||||||
|
# rebuild value form history
|
||||||
|
value = []
|
||||||
|
current_value_len = 0
|
||||||
|
node = match_history[0] # this is the root node
|
||||||
|
for idx, node in enumerate(match_history):
|
||||||
|
match_len = len(node.value)
|
||||||
|
if current_value_len + match_len > split_point:
|
||||||
|
# split the node at the desired split point
|
||||||
|
node = self._split_node(
|
||||||
|
node.key, node, split_point - current_value_len
|
||||||
|
)
|
||||||
|
value.append(node.value)
|
||||||
|
return value, node
|
||||||
|
elif current_value_len + match_len == split_point:
|
||||||
|
if idx != 0:
|
||||||
|
value.append(node.value)
|
||||||
|
return value, node
|
||||||
|
current_value_len += match_len
|
||||||
|
if idx != 0:
|
||||||
|
# the root node always has empty value, skip
|
||||||
|
value.append(node.value)
|
||||||
|
# return the root node as the corresponding node doesn't exist yet
|
||||||
|
# and the previously computed node is not at the split boundary
|
||||||
|
return [], match_history[0]
|
||||||
|
|
||||||
while len(key) > 0 and child_key in node.children.keys():
|
while len(key) > 0 and child_key in node.children.keys():
|
||||||
child = node.children[child_key]
|
child = node.children[child_key]
|
||||||
child.last_access_time = time.monotonic()
|
child.last_access_time = time.monotonic()
|
||||||
prefix_len = self.key_match_fn(child.key, key)
|
prefix_len = self.key_match_fn(child.key, key)
|
||||||
|
if align_split_size:
|
||||||
|
match_history.append(child)
|
||||||
if prefix_len < len(child.key):
|
if prefix_len < len(child.key):
|
||||||
new_node = self._split_node(child.key, child, prefix_len)
|
new_node = self._split_node(child.key, child, prefix_len)
|
||||||
value.append(new_node.value)
|
value.append(new_node.value)
|
||||||
@@ -557,6 +609,13 @@ class RadixCache(BasePrefixCache):
|
|||||||
if len(key):
|
if len(key):
|
||||||
child_key = self.get_child_key_fn(key)
|
child_key = self.get_child_key_fn(key)
|
||||||
|
|
||||||
|
if align_split_size:
|
||||||
|
value_len = sum(map(len, value))
|
||||||
|
value, node = reconstruct_at_split_point(match_history, value_len)
|
||||||
|
assert (
|
||||||
|
sum(map(len, value)) % self.split_size == 0
|
||||||
|
), "The value length is not aligned with the split size"
|
||||||
|
|
||||||
return value, node
|
return value, node
|
||||||
|
|
||||||
def _split_node(self, key: RadixKey, child: TreeNode, split_len: int):
|
def _split_node(self, key: RadixKey, child: TreeNode, split_len: int):
|
||||||
|
|||||||
@@ -1381,13 +1381,6 @@ class ServerArgs:
|
|||||||
f"Currently only {DETERMINISTIC_ATTENTION_BACKEND_CHOICES} attention backends are supported for deterministic inference."
|
f"Currently only {DETERMINISTIC_ATTENTION_BACKEND_CHOICES} attention backends are supported for deterministic inference."
|
||||||
)
|
)
|
||||||
|
|
||||||
# Currently, only FA3 supports radix cache. Support for other backends is in progress
|
|
||||||
if self.attention_backend != "fa3":
|
|
||||||
self.disable_radix_cache = True
|
|
||||||
logger.warning(
|
|
||||||
f"Currently radix cache is not compatible with {self.attention_backend} attention backend for deterministic inference. It will be supported in the future."
|
|
||||||
)
|
|
||||||
|
|
||||||
# Check TP size
|
# Check TP size
|
||||||
if self.tp_size > 1:
|
if self.tp_size > 1:
|
||||||
os.environ["NCCL_ALGO"] = "allreduce:tree"
|
os.environ["NCCL_ALGO"] = "allreduce:tree"
|
||||||
|
|||||||
@@ -3441,3 +3441,16 @@ def cached_triton_kernel(key_fn=None):
|
|||||||
return CachedKernel(fn, key_fn)
|
return CachedKernel(fn, key_fn)
|
||||||
|
|
||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
|
DEFAULT_DETERMINISTIC_INFERENCE_BACKEND_SIZE = 4096
|
||||||
|
DEFAULT_DETERMINISTIC_INFERENCE_BACKEND_SIZE_CONFIG = {
|
||||||
|
"flashinfer": (
|
||||||
|
"SGLANG_FLASHINFER_PREFILL_SPLIT_TILE_SIZE",
|
||||||
|
DEFAULT_DETERMINISTIC_INFERENCE_BACKEND_SIZE,
|
||||||
|
),
|
||||||
|
"triton": (
|
||||||
|
"SGLANG_TRITON_PREFILL_TRUNCATION_ALIGN_SIZE",
|
||||||
|
DEFAULT_DETERMINISTIC_INFERENCE_BACKEND_SIZE,
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|||||||
@@ -277,9 +277,10 @@ def test_deterministic(args):
|
|||||||
|
|
||||||
elif args.test_mode == "prefix":
|
elif args.test_mode == "prefix":
|
||||||
# In prefix mode, we create prompts from the same long prompt, with different lengths of common prefix.
|
# In prefix mode, we create prompts from the same long prompt, with different lengths of common prefix.
|
||||||
len_prefix = [1, 511, 2048, 4097]
|
len_prefix = [1, 8000, 10000, 12500]
|
||||||
num_prompts = len(len_prefix)
|
num_prompts = len(len_prefix)
|
||||||
outputs = {i: [] for i in range(4)}
|
outputs = {i: [] for i in range(4)}
|
||||||
|
assert all(i <= len(LONG_PROMPT) for i in len_prefix)
|
||||||
prompts = [LONG_PROMPT[: len_prefix[i]] for i in range(4)]
|
prompts = [LONG_PROMPT[: len_prefix[i]] for i in range(4)]
|
||||||
for i in range(args.n_start, args.n_start + args.n_trials):
|
for i in range(args.n_start, args.n_start + args.n_trials):
|
||||||
batch_size = i
|
batch_size = i
|
||||||
|
|||||||
Reference in New Issue
Block a user