[Feature] Support deterministic inference with FA3 backend (#10651)
@@ -355,6 +355,13 @@ class FlashAttentionBackend(AttentionBackend):
             self.sliding_window_size is not None and self.sliding_window_size > -1
         )
 
+        # If num_splits == 0, we use a heuristic to automatically determine the number of splits.
+        # We set num_splits to 1 if deterministic inference is enabled.
+        # See https://thinkingmachines.ai/blog/defeating-nondeterminism-in-llm-inference/ for more details.
+        self.num_splits = (
+            1 if model_runner.server_args.enable_deterministic_inference else 0
+        )
+
     def init_forward_metadata(self, forward_batch: ForwardBatch):
         """Initialize forward metadata hence all layers in the forward pass can reuse it."""
         metadata = FlashAttentionMetadata()
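A note on why the new `num_splits` plumbing matters for determinism: split-KV attention combines per-split partial results, and floating-point addition is not associative, so a split count that varies from run to run (the `num_splits == 0` heuristic) can change outputs in the last bits. Below is a minimal, standalone Python sketch of that effect; the `chunked_sum` helper is purely illustrative and does not correspond to any kernel code.

```python
# Illustration only: floating-point addition is not associative, so the
# grouping of a reduction changes its result. Split-KV attention reduces
# per-split partial outputs, so a varying split count can produce
# run-to-run differences; num_splits == 1 pins one fixed reduction order.
print((0.1 + 0.2) + 0.3 == 0.1 + (0.2 + 0.3))  # False

def chunked_sum(xs, num_chunks):
    """Sum `xs` in `num_chunks` chunks, then combine the partial sums."""
    step = -(-len(xs) // num_chunks)  # ceiling division
    partials = [sum(xs[i:i + step]) for i in range(0, len(xs), step)]
    return sum(partials)

vals = [1e-8] * 1_000_000 + [1e8, -1e8]
# Different chunk counts group the same numbers differently; the results
# typically differ in the low-order bits.
print({chunked_sum(vals, k) for k in (1, 2, 7, 64)})
```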
@@ -776,6 +783,7 @@ class FlashAttentionBackend(AttentionBackend):
                 k_descale=k_descale,
                 v_descale=v_descale,
                 return_softmax_lse=use_cascade_attn,
+                num_splits=self.num_splits,
                 **kwargs,
             )
 
@@ -797,6 +805,7 @@ class FlashAttentionBackend(AttentionBackend):
                 k_descale=k_descale,
                 v_descale=v_descale,
                 return_softmax_lse=True,
+                num_splits=self.num_splits,
                 **kwargs,
             )
             o, _ = merge_state_v2_wrapper(
@@ -901,6 +910,7 @@ class FlashAttentionBackend(AttentionBackend):
                 k_descale=k_descale,
                 v_descale=v_descale,
                 return_softmax_lse=use_cascade_attn,
+                num_splits=self.num_splits,
             )
             if use_cascade_attn:
                 o, softmax_lse, *rest = result
@@ -922,6 +932,7 @@ class FlashAttentionBackend(AttentionBackend):
                     k_descale=k_descale,
                     v_descale=v_descale,
                     return_softmax_lse=True,
+                    num_splits=self.num_splits,
                 )
             )
             o, _ = merge_state_v2_wrapper(
@@ -1042,6 +1053,7 @@ class FlashAttentionBackend(AttentionBackend):
                 softcap=layer.logit_cap,
                 k_descale=k_descale,
                 v_descale=v_descale,
+                num_splits=self.num_splits,
                 **kwargs,
             )
         elif use_local_attn:
@@ -1061,6 +1073,7 @@ class FlashAttentionBackend(AttentionBackend):
                 softcap=layer.logit_cap,
                 k_descale=k_descale,
                 v_descale=v_descale,
+                num_splits=self.num_splits,
                 **kwargs,
             )
         else:
@@ -1089,6 +1102,7 @@ class FlashAttentionBackend(AttentionBackend):
                 k_descale=k_descale,
                 v_descale=v_descale,
                 return_softmax_lse=use_cascade_attn,
+                num_splits=self.num_splits,
                 **kwargs,
             )
             if use_cascade_attn:
@@ -1110,6 +1124,7 @@ class FlashAttentionBackend(AttentionBackend):
                     k_descale=k_descale,
                     v_descale=v_descale,
                     return_softmax_lse=True,
+                    num_splits=self.num_splits,
                     **kwargs,
                 )
             )
@@ -1165,6 +1180,7 @@ class FlashAttentionBackend(AttentionBackend):
                 k_descale=k_descale,
                 v_descale=v_descale,
                 return_softmax_lse=use_cascade_attn,  # softmax_lse is needed for merge states
+                num_splits=self.num_splits,
             )
             if use_cascade_attn:
                 o, softmax_lse, *rest = result
@@ -1185,6 +1201,7 @@ class FlashAttentionBackend(AttentionBackend):
                 k_descale=k_descale,
                 v_descale=v_descale,
                 return_softmax_lse=True,
+                num_splits=self.num_splits,
             )
             o, _ = merge_state_v2(
                 o,
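Every call site above forwards `self.num_splits` unchanged to the attention kernel. For reference, here is a hedged sketch of such a call against the upstream flash-attn `flash_attn_with_kvcache` interface (the FA3 backend actually goes through the sgl-kernel build, which exposes a similar signature; the shapes and values are illustrative, and the snippet assumes a CUDA device with flash-attn installed):

```python
# Illustrative only: calls the upstream flash-attn interface, not SGLang's
# sgl-kernel wrapper. num_splits follows the semantics in the comment added
# above: 0 = heuristic split count, 1 = single split / fixed reduction order.
import torch
from flash_attn import flash_attn_with_kvcache

batch, heads, head_dim, max_seq = 2, 8, 128, 1024
q = torch.randn(batch, 1, heads, head_dim, dtype=torch.float16, device="cuda")
k_cache = torch.randn(batch, max_seq, heads, head_dim, dtype=torch.float16, device="cuda")
v_cache = torch.randn(batch, max_seq, heads, head_dim, dtype=torch.float16, device="cuda")
cache_seqlens = torch.tensor([512, 700], dtype=torch.int32, device="cuda")

out = flash_attn_with_kvcache(
    q, k_cache, v_cache,
    cache_seqlens=cache_seqlens,
    causal=True,
    num_splits=1,  # 1 pins the reduction order; 0 lets the kernel choose
)
print(out.shape)  # (batch, 1, heads, head_dim)
```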
@@ -118,7 +118,7 @@ DISAGG_TRANSFER_BACKEND_CHOICES = ["mooncake", "nixl", "ascend", "fake"]
 
 GRAMMAR_BACKEND_CHOICES = ["xgrammar", "outlines", "llguidance", "none"]
 
-DETERMINISTIC_ATTENTION_BACKEND_CHOICES = ["flashinfer"]
+DETERMINISTIC_ATTENTION_BACKEND_CHOICES = ["flashinfer", "fa3"]
 
 
 # Allow external code to add more choices
@@ -998,11 +998,13 @@ class ServerArgs:
                 "batch_invariant_ops is not installed. Please install it from https://github.com/thinking-machines-lab/batch_invariant_ops/."
             )
             # Check some settings
-            self.disable_radix_cache = True
-            logger.warning(
-                "Currently radix cache is disabled for deterministic inference. It will be supported in the future."
-            )
+            # Currently, only FA3 supports radix cache. Support for other backends is in progress
+            if self.attention_backend != "fa3":
+                self.disable_radix_cache = True
+                logger.warning(
+                    "Currently radix cache is disabled for deterministic inference. It will be supported in the future."
+                )
 
             if self.attention_backend not in DETERMINISTIC_ATTENTION_BACKEND_CHOICES:
                 raise ValueError(
                     f"Currently only {DETERMINISTIC_ATTENTION_BACKEND_CHOICES} attention backends are supported for deterministic inference."
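With "fa3" accepted as a deterministic attention backend, the feature can be exercised end to end. Below is a hedged usage sketch via the offline Engine API; it assumes the Engine forwards keyword arguments to the ServerArgs fields touched in this diff, and the model path and sampling settings are placeholders.

```python
# Hedged usage sketch: assumes sglang's offline Engine accepts ServerArgs
# fields as keyword arguments and returns a dict with a "text" field for a
# single prompt (model path is a placeholder).
import sglang as sgl

llm = sgl.Engine(
    model_path="meta-llama/Llama-3.1-8B-Instruct",
    attention_backend="fa3",
    enable_deterministic_inference=True,
)

prompt = "Explain batch-invariant inference in one sentence."
params = {"temperature": 0.0, "max_new_tokens": 64}

first = llm.generate(prompt, params)["text"]
second = llm.generate(prompt, params)["text"]
assert first == second  # expected to match bit-for-bit under deterministic inference

llm.shutdown()
```

The server path should behave the same when launched with `--attention-backend fa3` together with the corresponding deterministic-inference flag.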