[Feature] Support deterministic inference with FA3 backend (#10651)
This commit is contained in:
@@ -355,6 +355,13 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
self.sliding_window_size is not None and self.sliding_window_size > -1
|
self.sliding_window_size is not None and self.sliding_window_size > -1
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# If num_splits == 0, we use a heuristic to automatically determine the number of splits.
|
||||||
|
# We set num_splits to 1 if deterministic inference is enabled.
|
||||||
|
# See https://thinkingmachines.ai/blog/defeating-nondeterminism-in-llm-inference/ for more details.
|
||||||
|
self.num_splits = (
|
||||||
|
1 if model_runner.server_args.enable_deterministic_inference else 0
|
||||||
|
)
|
||||||
|
|
||||||
def init_forward_metadata(self, forward_batch: ForwardBatch):
|
def init_forward_metadata(self, forward_batch: ForwardBatch):
|
||||||
"""Initialize forward metadata hence all layers in the forward pass can reuse it."""
|
"""Initialize forward metadata hence all layers in the forward pass can reuse it."""
|
||||||
metadata = FlashAttentionMetadata()
|
metadata = FlashAttentionMetadata()
|
||||||
@@ -776,6 +783,7 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
k_descale=k_descale,
|
k_descale=k_descale,
|
||||||
v_descale=v_descale,
|
v_descale=v_descale,
|
||||||
return_softmax_lse=use_cascade_attn,
|
return_softmax_lse=use_cascade_attn,
|
||||||
|
num_splits=self.num_splits,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -797,6 +805,7 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
k_descale=k_descale,
|
k_descale=k_descale,
|
||||||
v_descale=v_descale,
|
v_descale=v_descale,
|
||||||
return_softmax_lse=True,
|
return_softmax_lse=True,
|
||||||
|
num_splits=self.num_splits,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
o, _ = merge_state_v2_wrapper(
|
o, _ = merge_state_v2_wrapper(
|
||||||
@@ -901,6 +910,7 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
k_descale=k_descale,
|
k_descale=k_descale,
|
||||||
v_descale=v_descale,
|
v_descale=v_descale,
|
||||||
return_softmax_lse=use_cascade_attn,
|
return_softmax_lse=use_cascade_attn,
|
||||||
|
num_splits=self.num_splits,
|
||||||
)
|
)
|
||||||
if use_cascade_attn:
|
if use_cascade_attn:
|
||||||
o, softmax_lse, *rest = result
|
o, softmax_lse, *rest = result
|
||||||
@@ -922,6 +932,7 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
k_descale=k_descale,
|
k_descale=k_descale,
|
||||||
v_descale=v_descale,
|
v_descale=v_descale,
|
||||||
return_softmax_lse=True,
|
return_softmax_lse=True,
|
||||||
|
num_splits=self.num_splits,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
o, _ = merge_state_v2_wrapper(
|
o, _ = merge_state_v2_wrapper(
|
||||||
@@ -1042,6 +1053,7 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
softcap=layer.logit_cap,
|
softcap=layer.logit_cap,
|
||||||
k_descale=k_descale,
|
k_descale=k_descale,
|
||||||
v_descale=v_descale,
|
v_descale=v_descale,
|
||||||
|
num_splits=self.num_splits,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
elif use_local_attn:
|
elif use_local_attn:
|
||||||
@@ -1061,6 +1073,7 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
softcap=layer.logit_cap,
|
softcap=layer.logit_cap,
|
||||||
k_descale=k_descale,
|
k_descale=k_descale,
|
||||||
v_descale=v_descale,
|
v_descale=v_descale,
|
||||||
|
num_splits=self.num_splits,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
@@ -1089,6 +1102,7 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
k_descale=k_descale,
|
k_descale=k_descale,
|
||||||
v_descale=v_descale,
|
v_descale=v_descale,
|
||||||
return_softmax_lse=use_cascade_attn,
|
return_softmax_lse=use_cascade_attn,
|
||||||
|
num_splits=self.num_splits,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
if use_cascade_attn:
|
if use_cascade_attn:
|
||||||
@@ -1110,6 +1124,7 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
k_descale=k_descale,
|
k_descale=k_descale,
|
||||||
v_descale=v_descale,
|
v_descale=v_descale,
|
||||||
return_softmax_lse=True,
|
return_softmax_lse=True,
|
||||||
|
num_splits=self.num_splits,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
@@ -1165,6 +1180,7 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
k_descale=k_descale,
|
k_descale=k_descale,
|
||||||
v_descale=v_descale,
|
v_descale=v_descale,
|
||||||
return_softmax_lse=use_cascade_attn, # softmax_lse is needed for merge states
|
return_softmax_lse=use_cascade_attn, # softmax_lse is needed for merge states
|
||||||
|
num_splits=self.num_splits,
|
||||||
)
|
)
|
||||||
if use_cascade_attn:
|
if use_cascade_attn:
|
||||||
o, softmax_lse, *rest = result
|
o, softmax_lse, *rest = result
|
||||||
@@ -1185,6 +1201,7 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
k_descale=k_descale,
|
k_descale=k_descale,
|
||||||
v_descale=v_descale,
|
v_descale=v_descale,
|
||||||
return_softmax_lse=True,
|
return_softmax_lse=True,
|
||||||
|
num_splits=self.num_splits,
|
||||||
)
|
)
|
||||||
o, _ = merge_state_v2(
|
o, _ = merge_state_v2(
|
||||||
o,
|
o,
|
||||||
|
|||||||
@@ -118,7 +118,7 @@ DISAGG_TRANSFER_BACKEND_CHOICES = ["mooncake", "nixl", "ascend", "fake"]
|
|||||||
|
|
||||||
GRAMMAR_BACKEND_CHOICES = ["xgrammar", "outlines", "llguidance", "none"]
|
GRAMMAR_BACKEND_CHOICES = ["xgrammar", "outlines", "llguidance", "none"]
|
||||||
|
|
||||||
DETERMINISTIC_ATTENTION_BACKEND_CHOICES = ["flashinfer"]
|
DETERMINISTIC_ATTENTION_BACKEND_CHOICES = ["flashinfer", "fa3"]
|
||||||
|
|
||||||
|
|
||||||
# Allow external code to add more choices
|
# Allow external code to add more choices
|
||||||
@@ -998,11 +998,13 @@ class ServerArgs:
|
|||||||
"batch_invariant_ops is not installed. Please install it from https://github.com/thinking-machines-lab/batch_invariant_ops/."
|
"batch_invariant_ops is not installed. Please install it from https://github.com/thinking-machines-lab/batch_invariant_ops/."
|
||||||
)
|
)
|
||||||
|
|
||||||
# Check some settings
|
# Currently, only FA3 supports radix cache with deterministic inference. Support for other backends is in progress.
|
||||||
|
if self.attention_backend != "fa3":
|
||||||
self.disable_radix_cache = True
|
self.disable_radix_cache = True
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"Currently radix cache is disabled for deterministic inference. It will be supported in the future."
|
"Currently radix cache is disabled for deterministic inference. It will be supported in the future."
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.attention_backend not in DETERMINISTIC_ATTENTION_BACKEND_CHOICES:
|
if self.attention_backend not in DETERMINISTIC_ATTENTION_BACKEND_CHOICES:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Currently only {DETERMINISTIC_ATTENTION_BACKEND_CHOICES} attention backends are supported for deterministic inference."
|
f"Currently only {DETERMINISTIC_ATTENTION_BACKEND_CHOICES} attention backends are supported for deterministic inference."
|
||||||
|
|||||||
Reference in New Issue
Block a user