[Feature] Support deterministic inference with FA3 backend (#10651)

This commit is contained in:
Stefan He
2025-09-20 17:50:21 -07:00
committed by GitHub
parent f1d7892318
commit cba0d8c309
2 changed files with 25 additions and 6 deletions

View File

@@ -355,6 +355,13 @@ class FlashAttentionBackend(AttentionBackend):
self.sliding_window_size is not None and self.sliding_window_size > -1
)
# If num_splits == 0, we use a heuristic to automatically determine the number of splits.
# We set nums splits to 1 if deterministic inference is enabled.
# See https://thinkingmachines.ai/blog/defeating-nondeterminism-in-llm-inference/ for more details.
self.num_splits = (
1 if model_runner.server_args.enable_deterministic_inference else 0
)
def init_forward_metadata(self, forward_batch: ForwardBatch):
"""Initialize forward metadata hence all layers in the forward pass can reuse it."""
metadata = FlashAttentionMetadata()
@@ -776,6 +783,7 @@ class FlashAttentionBackend(AttentionBackend):
k_descale=k_descale,
v_descale=v_descale,
return_softmax_lse=use_cascade_attn,
num_splits=self.num_splits,
**kwargs,
)
@@ -797,6 +805,7 @@ class FlashAttentionBackend(AttentionBackend):
k_descale=k_descale,
v_descale=v_descale,
return_softmax_lse=True,
num_splits=self.num_splits,
**kwargs,
)
o, _ = merge_state_v2_wrapper(
@@ -901,6 +910,7 @@ class FlashAttentionBackend(AttentionBackend):
k_descale=k_descale,
v_descale=v_descale,
return_softmax_lse=use_cascade_attn,
num_splits=self.num_splits,
)
if use_cascade_attn:
o, softmax_lse, *rest = result
@@ -922,6 +932,7 @@ class FlashAttentionBackend(AttentionBackend):
k_descale=k_descale,
v_descale=v_descale,
return_softmax_lse=True,
num_splits=self.num_splits,
)
)
o, _ = merge_state_v2_wrapper(
@@ -1042,6 +1053,7 @@ class FlashAttentionBackend(AttentionBackend):
softcap=layer.logit_cap,
k_descale=k_descale,
v_descale=v_descale,
num_splits=self.num_splits,
**kwargs,
)
elif use_local_attn:
@@ -1061,6 +1073,7 @@ class FlashAttentionBackend(AttentionBackend):
softcap=layer.logit_cap,
k_descale=k_descale,
v_descale=v_descale,
num_splits=self.num_splits,
**kwargs,
)
else:
@@ -1089,6 +1102,7 @@ class FlashAttentionBackend(AttentionBackend):
k_descale=k_descale,
v_descale=v_descale,
return_softmax_lse=use_cascade_attn,
num_splits=self.num_splits,
**kwargs,
)
if use_cascade_attn:
@@ -1110,6 +1124,7 @@ class FlashAttentionBackend(AttentionBackend):
k_descale=k_descale,
v_descale=v_descale,
return_softmax_lse=True,
num_splits=self.num_splits,
**kwargs,
)
)
@@ -1165,6 +1180,7 @@ class FlashAttentionBackend(AttentionBackend):
k_descale=k_descale,
v_descale=v_descale,
return_softmax_lse=use_cascade_attn, # softmax_lse is needed for merge states
num_splits=self.num_splits,
)
if use_cascade_attn:
o, softmax_lse, *rest = result
@@ -1185,6 +1201,7 @@ class FlashAttentionBackend(AttentionBackend):
k_descale=k_descale,
v_descale=v_descale,
return_softmax_lse=True,
num_splits=self.num_splits,
)
o, _ = merge_state_v2(
o,

View File

@@ -118,7 +118,7 @@ DISAGG_TRANSFER_BACKEND_CHOICES = ["mooncake", "nixl", "ascend", "fake"]
GRAMMAR_BACKEND_CHOICES = ["xgrammar", "outlines", "llguidance", "none"]
DETERMINISTIC_ATTENTION_BACKEND_CHOICES = ["flashinfer"]
DETERMINISTIC_ATTENTION_BACKEND_CHOICES = ["flashinfer", "fa3"]
# Allow external code to add more choices
@@ -998,11 +998,13 @@ class ServerArgs:
"batch_invariant_ops is not installed. Please install it from https://github.com/thinking-machines-lab/batch_invariant_ops/."
)
# Check some settings
self.disable_radix_cache = True
logger.warning(
"Currently radix cache is disabled for deterministic inference. It will be supported in the future."
)
# Currently, only FA3 supports radix cache. Support for other backends is in progress
if self.attention_backend != "fa3":
self.disable_radix_cache = True
logger.warning(
"Currently radix cache is disabled for deterministic inference. It will be supported in the future."
)
if self.attention_backend not in DETERMINISTIC_ATTENTION_BACKEND_CHOICES:
raise ValueError(
f"Currently only {DETERMINISTIC_ATTENTION_BACKEND_CHOICES} attention backends are supported for deterministic inference."