diff --git a/python/sglang/srt/layers/attention/flashattention_backend.py b/python/sglang/srt/layers/attention/flashattention_backend.py index 0ccde7c81..be7fed8de 100644 --- a/python/sglang/srt/layers/attention/flashattention_backend.py +++ b/python/sglang/srt/layers/attention/flashattention_backend.py @@ -355,6 +355,13 @@ class FlashAttentionBackend(AttentionBackend): self.sliding_window_size is not None and self.sliding_window_size > -1 ) + # If num_splits == 0, we use a heuristic to automatically determine the number of splits. + # We set num_splits to 1 if deterministic inference is enabled. + # See https://thinkingmachines.ai/blog/defeating-nondeterminism-in-llm-inference/ for more details. + self.num_splits = ( + 1 if model_runner.server_args.enable_deterministic_inference else 0 + ) + def init_forward_metadata(self, forward_batch: ForwardBatch): """Initialize forward metadata hence all layers in the forward pass can reuse it.""" metadata = FlashAttentionMetadata() @@ -776,6 +783,7 @@ class FlashAttentionBackend(AttentionBackend): k_descale=k_descale, v_descale=v_descale, return_softmax_lse=use_cascade_attn, + num_splits=self.num_splits, **kwargs, ) @@ -797,6 +805,7 @@ class FlashAttentionBackend(AttentionBackend): k_descale=k_descale, v_descale=v_descale, return_softmax_lse=True, + num_splits=self.num_splits, **kwargs, ) o, _ = merge_state_v2_wrapper( @@ -901,6 +910,7 @@ class FlashAttentionBackend(AttentionBackend): k_descale=k_descale, v_descale=v_descale, return_softmax_lse=use_cascade_attn, + num_splits=self.num_splits, ) if use_cascade_attn: o, softmax_lse, *rest = result @@ -922,6+932,7 @@ class FlashAttentionBackend(AttentionBackend): k_descale=k_descale, v_descale=v_descale, return_softmax_lse=True, + num_splits=self.num_splits, ) ) o, _ = merge_state_v2_wrapper( @@ -1042,6 +1053,7 @@ class FlashAttentionBackend(AttentionBackend): softcap=layer.logit_cap, k_descale=k_descale, v_descale=v_descale, + num_splits=self.num_splits, **kwargs, ) elif 
use_local_attn: @@ -1061,6 +1073,7 @@ class FlashAttentionBackend(AttentionBackend): softcap=layer.logit_cap, k_descale=k_descale, v_descale=v_descale, + num_splits=self.num_splits, **kwargs, ) else: @@ -1089,6 +1102,7 @@ class FlashAttentionBackend(AttentionBackend): k_descale=k_descale, v_descale=v_descale, return_softmax_lse=use_cascade_attn, + num_splits=self.num_splits, **kwargs, ) if use_cascade_attn: @@ -1110,6 +1124,7 @@ class FlashAttentionBackend(AttentionBackend): k_descale=k_descale, v_descale=v_descale, return_softmax_lse=True, + num_splits=self.num_splits, **kwargs, ) ) @@ -1165,6 +1180,7 @@ class FlashAttentionBackend(AttentionBackend): k_descale=k_descale, v_descale=v_descale, return_softmax_lse=use_cascade_attn, # softmax_lse is needed for merge states + num_splits=self.num_splits, ) if use_cascade_attn: o, softmax_lse, *rest = result @@ -1185,6 +1201,7 @@ class FlashAttentionBackend(AttentionBackend): k_descale=k_descale, v_descale=v_descale, return_softmax_lse=True, + num_splits=self.num_splits, ) o, _ = merge_state_v2( o, diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 331897ae4..816690e8a 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -118,7 +118,7 @@ DISAGG_TRANSFER_BACKEND_CHOICES = ["mooncake", "nixl", "ascend", "fake"] GRAMMAR_BACKEND_CHOICES = ["xgrammar", "outlines", "llguidance", "none"] -DETERMINISTIC_ATTENTION_BACKEND_CHOICES = ["flashinfer"] +DETERMINISTIC_ATTENTION_BACKEND_CHOICES = ["flashinfer", "fa3"] # Allow external code to add more choices @@ -998,11 +998,13 @@ class ServerArgs: "batch_invariant_ops is not installed. Please install it from https://github.com/thinking-machines-lab/batch_invariant_ops/." ) - # Check some settings - self.disable_radix_cache = True - logger.warning( - "Currently radix cache is disabled for deterministic inference. It will be supported in the future." - ) + # Currently, only FA3 supports radix cache. 
Support for other backends is in progress. + if self.attention_backend != "fa3": + self.disable_radix_cache = True + logger.warning( + "Currently radix cache is disabled for deterministic inference. It will be supported in the future." + ) + if self.attention_backend not in DETERMINISTIC_ATTENTION_BACKEND_CHOICES: raise ValueError( f"Currently only {DETERMINISTIC_ATTENTION_BACKEND_CHOICES} attention backends are supported for deterministic inference."