From b0d1d717e178c8cc554d7b3536458a847fa21b2e Mon Sep 17 00:00:00 2001 From: Baizhou Zhang Date: Thu, 16 Oct 2025 16:36:15 -0500 Subject: [PATCH] Revert "make radix cache deterministic" (#11728) --- python/sglang/srt/managers/scheduler.py | 8 ++- python/sglang/srt/mem_cache/radix_cache.py | 67 ++-------------------- python/sglang/srt/server_args.py | 7 +++ python/sglang/srt/utils/common.py | 13 ----- python/sglang/test/test_deterministic.py | 3 +- 5 files changed, 17 insertions(+), 81 deletions(-) diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index ce0148e98..cec6af433 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -163,7 +163,6 @@ from sglang.srt.tracing.trace import ( ) from sglang.srt.two_batch_overlap import TboDPAttentionPreparer from sglang.srt.utils import ( - DEFAULT_DETERMINISTIC_INFERENCE_BACKEND_SIZE_CONFIG, DynamicGradMode, broadcast_pyobj, configure_gc_logger, @@ -712,7 +711,11 @@ class Scheduler( self.truncation_align_size = None return - env_var, default_size = DEFAULT_DETERMINISTIC_INFERENCE_BACKEND_SIZE_CONFIG.get( + backend_sizes = { + "flashinfer": ("SGLANG_FLASHINFER_PREFILL_SPLIT_TILE_SIZE", 4096), + "triton": ("SGLANG_TRITON_PREFILL_TRUNCATION_ALIGN_SIZE", 4096), + } + env_var, default_size = backend_sizes.get( self.server_args.attention_backend, (None, None) ) self.truncation_align_size = ( @@ -846,7 +849,6 @@ class Scheduler( disable=server_args.disable_radix_cache, enable_kv_cache_events=self.enable_kv_cache_events, eviction_policy=server_args.radix_eviction_policy, - enable_deterministic_inference=server_args.enable_deterministic_inference, is_eagle=self.spec_algorithm.is_eagle(), ) diff --git a/python/sglang/srt/mem_cache/radix_cache.py b/python/sglang/srt/mem_cache/radix_cache.py index 05404dc2b..f82594330 100644 --- a/python/sglang/srt/mem_cache/radix_cache.py +++ b/python/sglang/srt/mem_cache/radix_cache.py @@ -1,7 +1,5 @@ from __future__ import annotations -from sglang.srt.utils import DEFAULT_DETERMINISTIC_INFERENCE_BACKEND_SIZE - """ Copyright 2023-2024 SGLang Team Licensed under the Apache License, Version 2.0 (the "License"); @@ -187,7 +185,6 @@ class RadixCache(BasePrefixCache): disable: bool = False, enable_kv_cache_events: bool = False, eviction_policy: str = "lru", - enable_deterministic_inference: bool = False, is_eagle: bool = False, ): self.req_to_token_pool = req_to_token_pool @@ -196,8 +193,6 @@ class RadixCache(BasePrefixCache): self.disable = disable self.enable_kv_cache_events = enable_kv_cache_events self.kv_event_queue = [] - self.enable_deterministic_inference = enable_deterministic_inference - self.split_size = DEFAULT_DETERMINISTIC_INFERENCE_BACKEND_SIZE self.is_eagle = is_eagle if self.token_to_kv_pool_allocator: @@ -239,9 +234,7 @@ class RadixCache(BasePrefixCache): self.protected_size_ = 0 self._record_all_cleared_event() - def match_prefix( - self, key: RadixKey, is_cache_unfinished: bool = False, **kwargs - ) -> MatchResult: + def match_prefix(self, key: RadixKey, **kwargs) -> MatchResult: """Find the longest cached prefix of ``key`` in the radix tree. The logical namespace for prefix matching is determined by both the @@ -302,9 +295,7 @@ class RadixCache(BasePrefixCache): if len(key) == 0: return empty_match_result() - value, last_node = self._match_prefix_helper( - self.root_node, key, is_cache_unfinished=is_cache_unfinished - ) + value, last_node = self._match_prefix_helper(self.root_node, key) if value: value = torch.cat(value) else: @@ -435,8 +426,7 @@ class RadixCache(BasePrefixCache): # The prefix indices could be updated, reuse it new_indices, new_last_node, _, _ = self.match_prefix( - RadixKey(token_ids=page_aligned_token_ids, extra_key=req.extra_key), - is_cache_unfinished=True, + RadixKey(token_ids=page_aligned_token_ids, extra_key=req.extra_key) ) self.req_to_token_pool.write( (req.req_pool_idx, slice(old_prefix_len, len(new_indices))), @@ -552,58 +542,16 @@ class RadixCache(BasePrefixCache): ##### Internal Helper Functions ##### - def _match_prefix_helper( - self, node: TreeNode, key: RadixKey, is_cache_unfinished: bool - ): + def _match_prefix_helper(self, node: TreeNode, key: RadixKey): node.last_access_time = time.monotonic() child_key = self.get_child_key_fn(key) value = [] - align_split_size = ( - not is_cache_unfinished and self.enable_deterministic_inference - ) - match_history = [node] if align_split_size else None - - if align_split_size and len(key) < self.split_size: - # fast path: directly return the root node if the split point is 0 - return value, node - - # use the access history to first find a split point at split_size and then return the value and node at that point. - def reconstruct_at_split_point(match_history, value_len): - # reverse the search process to find the last node right above the split_size, split here - split_point = value_len // self.split_size * self.split_size - # rebuild value form history - value = [] - current_value_len = 0 - node = match_history[0] # this is the root node - for idx, node in enumerate(match_history): - match_len = len(node.value) - if current_value_len + match_len > split_point: - # split the node at the desired split point - node = self._split_node( - node.key, node, split_point - current_value_len - ) - value.append(node.value) - return value, node - elif current_value_len + match_len == split_point: - if idx != 0: - value.append(node.value) - return value, node - current_value_len += match_len - if idx != 0: - # the root node always has empty value, skip - value.append(node.value) - # return the root node as the corresponding node doesn't exist yet - # and the previously computed node is not at the split boundary - return [], match_history[0] - while len(key) > 0 and child_key in node.children.keys(): child = node.children[child_key] child.last_access_time = time.monotonic() prefix_len = self.key_match_fn(child.key, key) - if align_split_size: - match_history.append(child) if prefix_len < len(child.key): new_node = self._split_node(child.key, child, prefix_len) value.append(new_node.value) @@ -617,13 +565,6 @@ class RadixCache(BasePrefixCache): if len(key): child_key = self.get_child_key_fn(key) - if align_split_size: - value_len = sum(map(len, value)) - value, node = reconstruct_at_split_point(match_history, value_len) - assert ( - sum(map(len, value)) % self.split_size == 0 - ), "The value length is not aligned with the split size" - return value, node def _split_node(self, key: RadixKey, child: TreeNode, split_len: int): diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 94179cf7f..864432496 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -1411,6 +1411,13 @@ class ServerArgs: f"Currently only {DETERMINISTIC_ATTENTION_BACKEND_CHOICES} attention backends are supported for deterministic inference." ) + # Currently, only FA3 supports radix cache. Support for other backends is in progress + if self.attention_backend != "fa3": + self.disable_radix_cache = True + logger.warning( + f"Currently radix cache is not compatible with {self.attention_backend} attention backend for deterministic inference. It will be supported in the future." + ) + # Check TP size if self.tp_size > 1: os.environ["NCCL_ALGO"] = "allreduce:tree" diff --git a/python/sglang/srt/utils/common.py b/python/sglang/srt/utils/common.py index 5f735fe88..084065b61 100644 --- a/python/sglang/srt/utils/common.py +++ b/python/sglang/srt/utils/common.py @@ -3441,16 +3441,3 @@ def cached_triton_kernel(key_fn=None): return CachedKernel(fn, key_fn) return decorator - - -DEFAULT_DETERMINISTIC_INFERENCE_BACKEND_SIZE = 4096 -DEFAULT_DETERMINISTIC_INFERENCE_BACKEND_SIZE_CONFIG = { - "flashinfer": ( - "SGLANG_FLASHINFER_PREFILL_SPLIT_TILE_SIZE", - DEFAULT_DETERMINISTIC_INFERENCE_BACKEND_SIZE, - ), - "triton": ( - "SGLANG_TRITON_PREFILL_TRUNCATION_ALIGN_SIZE", - DEFAULT_DETERMINISTIC_INFERENCE_BACKEND_SIZE, - ), -} diff --git a/python/sglang/test/test_deterministic.py b/python/sglang/test/test_deterministic.py index 1df868315..1175a35e5 100644 --- a/python/sglang/test/test_deterministic.py +++ b/python/sglang/test/test_deterministic.py @@ -273,10 +273,9 @@ def test_deterministic(args): elif args.test_mode == "prefix": # In prefix mode, we create prompts from the same long prompt, with different lengths of common prefix. - len_prefix = [1, 8000, 10000, 12500] + len_prefix = [1, 511, 2048, 4097] num_prompts = len(len_prefix) outputs = {i: [] for i in range(4)} - assert all(i <= len(LONG_PROMPT) for i in len_prefix) prompts = [LONG_PROMPT[: len_prefix[i]] for i in range(4)] for i in range(args.n_start, args.n_start + args.n_trials): batch_size = i