From dc965db0e0c7be118c6d83184ea13cd56e054e89 Mon Sep 17 00:00:00 2001 From: Alex Chi Z <4198311+skyzh@users.noreply.github.com> Date: Tue, 14 Oct 2025 15:01:52 +0200 Subject: [PATCH] make radix cache deterministic (#10721) Signed-off-by: Alex Chi Z --- python/sglang/srt/managers/scheduler.py | 8 +-- python/sglang/srt/mem_cache/radix_cache.py | 67 ++++++++++++++++++++-- python/sglang/srt/server_args.py | 7 --- python/sglang/srt/utils/common.py | 13 +++++ python/sglang/test/test_deterministic.py | 3 +- 5 files changed, 81 insertions(+), 17 deletions(-) diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 10110751c..3e0a8945d 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -163,6 +163,7 @@ from sglang.srt.tracing.trace import ( ) from sglang.srt.two_batch_overlap import TboDPAttentionPreparer from sglang.srt.utils import ( + DEFAULT_DETERMINISTIC_INFERENCE_BACKEND_SIZE_CONFIG, DynamicGradMode, broadcast_pyobj, configure_gc_logger, @@ -705,11 +706,7 @@ class Scheduler( self.truncation_align_size = None return - backend_sizes = { - "flashinfer": ("SGLANG_FLASHINFER_PREFILL_SPLIT_TILE_SIZE", 4096), - "triton": ("SGLANG_TRITON_PREFILL_TRUNCATION_ALIGN_SIZE", 4096), - } - env_var, default_size = backend_sizes.get( + env_var, default_size = DEFAULT_DETERMINISTIC_INFERENCE_BACKEND_SIZE_CONFIG.get( self.server_args.attention_backend, (None, None) ) self.truncation_align_size = ( @@ -849,6 +846,7 @@ class Scheduler( disable=server_args.disable_radix_cache, enable_kv_cache_events=self.enable_kv_cache_events, eviction_policy=server_args.radix_eviction_policy, + enable_deterministic_inference=server_args.enable_deterministic_inference, is_eagle=self.spec_algorithm.is_eagle(), ) diff --git a/python/sglang/srt/mem_cache/radix_cache.py b/python/sglang/srt/mem_cache/radix_cache.py index bed7923f6..2ffc088df 100644 --- a/python/sglang/srt/mem_cache/radix_cache.py +++ 
b/python/sglang/srt/mem_cache/radix_cache.py @@ -1,5 +1,7 @@ from __future__ import annotations +from sglang.srt.utils import DEFAULT_DETERMINISTIC_INFERENCE_BACKEND_SIZE + """ Copyright 2023-2024 SGLang Team Licensed under the Apache License, Version 2.0 (the "License"); @@ -185,6 +187,7 @@ class RadixCache(BasePrefixCache): disable: bool = False, enable_kv_cache_events: bool = False, eviction_policy: str = "lru", + enable_deterministic_inference: bool = False, is_eagle: bool = False, ): self.req_to_token_pool = req_to_token_pool @@ -193,6 +196,8 @@ class RadixCache(BasePrefixCache): self.disable = disable self.enable_kv_cache_events = enable_kv_cache_events self.kv_event_queue = [] + self.enable_deterministic_inference = enable_deterministic_inference + self.split_size = DEFAULT_DETERMINISTIC_INFERENCE_BACKEND_SIZE self.is_eagle = is_eagle if self.token_to_kv_pool_allocator: @@ -234,7 +239,9 @@ class RadixCache(BasePrefixCache): self.protected_size_ = 0 self._record_all_cleared_event() - def match_prefix(self, key: RadixKey, **kwargs) -> MatchResult: + def match_prefix( + self, key: RadixKey, is_cache_unfinished: bool = False, **kwargs + ) -> MatchResult: """Find the longest cached prefix of ``key`` in the radix tree. 
The logical namespace for prefix matching is determined by both the @@ -295,7 +302,9 @@ class RadixCache(BasePrefixCache): if len(key) == 0: return empty_match_result() - value, last_node = self._match_prefix_helper(self.root_node, key) + value, last_node = self._match_prefix_helper( + self.root_node, key, is_cache_unfinished=is_cache_unfinished + ) if value: value = torch.cat(value) else: @@ -418,7 +427,8 @@ class RadixCache(BasePrefixCache): # The prefix indices could be updated, reuse it new_indices, new_last_node, _, _ = self.match_prefix( - RadixKey(token_ids=page_aligned_token_ids, extra_key=req.extra_key) + RadixKey(token_ids=page_aligned_token_ids, extra_key=req.extra_key), + is_cache_unfinished=True, ) self.req_to_token_pool.write( (req.req_pool_idx, slice(old_prefix_len, len(new_indices))), @@ -534,16 +544,58 @@ class RadixCache(BasePrefixCache): ##### Internal Helper Functions ##### - def _match_prefix_helper(self, node: TreeNode, key: RadixKey): + def _match_prefix_helper( + self, node: TreeNode, key: RadixKey, is_cache_unfinished: bool + ): node.last_access_time = time.monotonic() child_key = self.get_child_key_fn(key) value = [] + align_split_size = ( + not is_cache_unfinished and self.enable_deterministic_inference + ) + match_history = [node] if align_split_size else None + + if align_split_size and len(key) < self.split_size: + # fast path: directly return the root node if the split point is 0 + return value, node + + # use the access history to first find a split point at split_size and then return the value and node at that point. 
+ def reconstruct_at_split_point(match_history, value_len): + # reverse the search process to find the last node right above the split_size, split here + split_point = value_len // self.split_size * self.split_size + # rebuild value from history + value = [] + current_value_len = 0 + node = match_history[0] # this is the root node + for idx, node in enumerate(match_history): + match_len = len(node.value) + if current_value_len + match_len > split_point: + # split the node at the desired split point + node = self._split_node( + node.key, node, split_point - current_value_len + ) + value.append(node.value) + return value, node + elif current_value_len + match_len == split_point: + if idx != 0: + value.append(node.value) + return value, node + current_value_len += match_len + if idx != 0: + # the root node always has empty value, skip + value.append(node.value) + # return the root node as the corresponding node doesn't exist yet + # and the previously computed node is not at the split boundary + return [], match_history[0] + while len(key) > 0 and child_key in node.children.keys(): child = node.children[child_key] child.last_access_time = time.monotonic() prefix_len = self.key_match_fn(child.key, key) + if align_split_size: + match_history.append(child) if prefix_len < len(child.key): new_node = self._split_node(child.key, child, prefix_len) value.append(new_node.value) @@ -557,6 +609,13 @@ class RadixCache(BasePrefixCache): if len(key): child_key = self.get_child_key_fn(key) + if align_split_size: + value_len = sum(map(len, value)) + value, node = reconstruct_at_split_point(match_history, value_len) + assert ( + sum(map(len, value)) % self.split_size == 0 + ), "The value length is not aligned with the split size" + return value, node def _split_node(self, key: RadixKey, child: TreeNode, split_len: int): diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index e5bce7457..07a122f65 100644 --- a/python/sglang/srt/server_args.py +++ 
b/python/sglang/srt/server_args.py @@ -1381,13 +1381,6 @@ class ServerArgs: f"Currently only {DETERMINISTIC_ATTENTION_BACKEND_CHOICES} attention backends are supported for deterministic inference." ) - # Currently, only FA3 supports radix cache. Support for other backends is in progress - if self.attention_backend != "fa3": - self.disable_radix_cache = True - logger.warning( - f"Currently radix cache is not compatible with {self.attention_backend} attention backend for deterministic inference. It will be supported in the future." - ) - # Check TP size if self.tp_size > 1: os.environ["NCCL_ALGO"] = "allreduce:tree" diff --git a/python/sglang/srt/utils/common.py b/python/sglang/srt/utils/common.py index 084065b61..5f735fe88 100644 --- a/python/sglang/srt/utils/common.py +++ b/python/sglang/srt/utils/common.py @@ -3441,3 +3441,16 @@ def cached_triton_kernel(key_fn=None): return CachedKernel(fn, key_fn) return decorator + + +DEFAULT_DETERMINISTIC_INFERENCE_BACKEND_SIZE = 4096 +DEFAULT_DETERMINISTIC_INFERENCE_BACKEND_SIZE_CONFIG = { + "flashinfer": ( + "SGLANG_FLASHINFER_PREFILL_SPLIT_TILE_SIZE", + DEFAULT_DETERMINISTIC_INFERENCE_BACKEND_SIZE, + ), + "triton": ( + "SGLANG_TRITON_PREFILL_TRUNCATION_ALIGN_SIZE", + DEFAULT_DETERMINISTIC_INFERENCE_BACKEND_SIZE, + ), +} diff --git a/python/sglang/test/test_deterministic.py b/python/sglang/test/test_deterministic.py index 8c513cb6a..8556c03a7 100644 --- a/python/sglang/test/test_deterministic.py +++ b/python/sglang/test/test_deterministic.py @@ -277,9 +277,10 @@ def test_deterministic(args): elif args.test_mode == "prefix": # In prefix mode, we create prompts from the same long prompt, with different lengths of common prefix. 
- len_prefix = [1, 511, 2048, 4097] + len_prefix = [1, 8000, 10000, 12500] num_prompts = len(len_prefix) outputs = {i: [] for i in range(4)} + assert all(i <= len(LONG_PROMPT) for i in len_prefix) prompts = [LONG_PROMPT[: len_prefix[i]] for i in range(4)] for i in range(args.n_start, args.n_start + args.n_trials): batch_size = i