[Feature] Support deterministic inference with FA3 backend (#10651)

This commit is contained in:
Stefan He
2025-09-20 17:50:21 -07:00
committed by GitHub
parent f1d7892318
commit cba0d8c309
2 changed files with 25 additions and 6 deletions

View File

@@ -355,6 +355,13 @@ class FlashAttentionBackend(AttentionBackend):
self.sliding_window_size is not None and self.sliding_window_size > -1 self.sliding_window_size is not None and self.sliding_window_size > -1
) )
# If num_splits == 0, we use a heuristic to automatically determine the number of splits.
# We set num_splits to 1 if deterministic inference is enabled.
# See https://thinkingmachines.ai/blog/defeating-nondeterminism-in-llm-inference/ for more details.
self.num_splits = (
1 if model_runner.server_args.enable_deterministic_inference else 0
)
def init_forward_metadata(self, forward_batch: ForwardBatch): def init_forward_metadata(self, forward_batch: ForwardBatch):
"""Initialize forward metadata hence all layers in the forward pass can reuse it.""" """Initialize forward metadata hence all layers in the forward pass can reuse it."""
metadata = FlashAttentionMetadata() metadata = FlashAttentionMetadata()
@@ -776,6 +783,7 @@ class FlashAttentionBackend(AttentionBackend):
k_descale=k_descale, k_descale=k_descale,
v_descale=v_descale, v_descale=v_descale,
return_softmax_lse=use_cascade_attn, return_softmax_lse=use_cascade_attn,
num_splits=self.num_splits,
**kwargs, **kwargs,
) )
@@ -797,6 +805,7 @@ class FlashAttentionBackend(AttentionBackend):
k_descale=k_descale, k_descale=k_descale,
v_descale=v_descale, v_descale=v_descale,
return_softmax_lse=True, return_softmax_lse=True,
num_splits=self.num_splits,
**kwargs, **kwargs,
) )
o, _ = merge_state_v2_wrapper( o, _ = merge_state_v2_wrapper(
@@ -901,6 +910,7 @@ class FlashAttentionBackend(AttentionBackend):
k_descale=k_descale, k_descale=k_descale,
v_descale=v_descale, v_descale=v_descale,
return_softmax_lse=use_cascade_attn, return_softmax_lse=use_cascade_attn,
num_splits=self.num_splits,
) )
if use_cascade_attn: if use_cascade_attn:
o, softmax_lse, *rest = result o, softmax_lse, *rest = result
@@ -922,6 +932,7 @@ class FlashAttentionBackend(AttentionBackend):
k_descale=k_descale, k_descale=k_descale,
v_descale=v_descale, v_descale=v_descale,
return_softmax_lse=True, return_softmax_lse=True,
num_splits=self.num_splits,
) )
) )
o, _ = merge_state_v2_wrapper( o, _ = merge_state_v2_wrapper(
@@ -1042,6 +1053,7 @@ class FlashAttentionBackend(AttentionBackend):
softcap=layer.logit_cap, softcap=layer.logit_cap,
k_descale=k_descale, k_descale=k_descale,
v_descale=v_descale, v_descale=v_descale,
num_splits=self.num_splits,
**kwargs, **kwargs,
) )
elif use_local_attn: elif use_local_attn:
@@ -1061,6 +1073,7 @@ class FlashAttentionBackend(AttentionBackend):
softcap=layer.logit_cap, softcap=layer.logit_cap,
k_descale=k_descale, k_descale=k_descale,
v_descale=v_descale, v_descale=v_descale,
num_splits=self.num_splits,
**kwargs, **kwargs,
) )
else: else:
@@ -1089,6 +1102,7 @@ class FlashAttentionBackend(AttentionBackend):
k_descale=k_descale, k_descale=k_descale,
v_descale=v_descale, v_descale=v_descale,
return_softmax_lse=use_cascade_attn, return_softmax_lse=use_cascade_attn,
num_splits=self.num_splits,
**kwargs, **kwargs,
) )
if use_cascade_attn: if use_cascade_attn:
@@ -1110,6 +1124,7 @@ class FlashAttentionBackend(AttentionBackend):
k_descale=k_descale, k_descale=k_descale,
v_descale=v_descale, v_descale=v_descale,
return_softmax_lse=True, return_softmax_lse=True,
num_splits=self.num_splits,
**kwargs, **kwargs,
) )
) )
@@ -1165,6 +1180,7 @@ class FlashAttentionBackend(AttentionBackend):
k_descale=k_descale, k_descale=k_descale,
v_descale=v_descale, v_descale=v_descale,
return_softmax_lse=use_cascade_attn, # softmax_lse is needed for merge states return_softmax_lse=use_cascade_attn, # softmax_lse is needed for merge states
num_splits=self.num_splits,
) )
if use_cascade_attn: if use_cascade_attn:
o, softmax_lse, *rest = result o, softmax_lse, *rest = result
@@ -1185,6 +1201,7 @@ class FlashAttentionBackend(AttentionBackend):
k_descale=k_descale, k_descale=k_descale,
v_descale=v_descale, v_descale=v_descale,
return_softmax_lse=True, return_softmax_lse=True,
num_splits=self.num_splits,
) )
o, _ = merge_state_v2( o, _ = merge_state_v2(
o, o,

View File

@@ -118,7 +118,7 @@ DISAGG_TRANSFER_BACKEND_CHOICES = ["mooncake", "nixl", "ascend", "fake"]
GRAMMAR_BACKEND_CHOICES = ["xgrammar", "outlines", "llguidance", "none"] GRAMMAR_BACKEND_CHOICES = ["xgrammar", "outlines", "llguidance", "none"]
DETERMINISTIC_ATTENTION_BACKEND_CHOICES = ["flashinfer"] DETERMINISTIC_ATTENTION_BACKEND_CHOICES = ["flashinfer", "fa3"]
# Allow external code to add more choices # Allow external code to add more choices
@@ -998,11 +998,13 @@ class ServerArgs:
"batch_invariant_ops is not installed. Please install it from https://github.com/thinking-machines-lab/batch_invariant_ops/." "batch_invariant_ops is not installed. Please install it from https://github.com/thinking-machines-lab/batch_invariant_ops/."
) )
# Check some settings # Currently, only FA3 supports radix cache. Support for other backends is in progress
if self.attention_backend != "fa3":
self.disable_radix_cache = True self.disable_radix_cache = True
logger.warning( logger.warning(
"Currently radix cache is disabled for deterministic inference. It will be supported in the future." "Currently radix cache is disabled for deterministic inference. It will be supported in the future."
) )
if self.attention_backend not in DETERMINISTIC_ATTENTION_BACKEND_CHOICES: if self.attention_backend not in DETERMINISTIC_ATTENTION_BACKEND_CHOICES:
raise ValueError( raise ValueError(
f"Currently only {DETERMINISTIC_ATTENTION_BACKEND_CHOICES} attention backends are supported for deterministic inference." f"Currently only {DETERMINISTIC_ATTENTION_BACKEND_CHOICES} attention backends are supported for deterministic inference."