[Feature] Support FA3 backend for MLA (#4831)

This commit is contained in:
Baizhou Zhang
2025-03-28 18:30:14 -07:00
committed by GitHub
parent ec3ee0289d
commit 20c90be23d
3 changed files with 180 additions and 74 deletions

View File

@@ -655,6 +655,7 @@ class DeepseekV2AttentionMLA(nn.Module):
self.flashinfer_mla_disable_ragged = global_server_args_dict[
"flashinfer_mla_disable_ragged"
]
self.attention_backend = global_server_args_dict["attention_backend"]
self.rocm_fused_decode_mla = os.getenv("SGLANG_ROCM_FUSED_DECODE_MLA") == "1"
def no_absorb(self, forward_batch: ForwardBatch) -> bool:
@@ -667,6 +668,9 @@ class DeepseekV2AttentionMLA(nn.Module):
and not forward_batch.forward_mode.is_draft_extend()
and sum(forward_batch.extend_prefix_lens_cpu) == 0
)
elif self.attention_backend == "fa3":
# Flash Attention: Keep absorbing for all extend/decode
return False
else:
# Triton: Use normal computation for prefill and use weight absorption for extend/decode
return (