[Feature] Support FA3 backend for MLA (#4831)
This commit is contained in:
@@ -655,6 +655,7 @@ class DeepseekV2AttentionMLA(nn.Module):
|
||||
self.flashinfer_mla_disable_ragged = global_server_args_dict[
|
||||
"flashinfer_mla_disable_ragged"
|
||||
]
|
||||
self.attention_backend = global_server_args_dict["attention_backend"]
|
||||
self.rocm_fused_decode_mla = os.getenv("SGLANG_ROCM_FUSED_DECODE_MLA") == "1"
|
||||
|
||||
def no_absorb(self, forward_batch: ForwardBatch) -> bool:
|
||||
@@ -667,6 +668,9 @@ class DeepseekV2AttentionMLA(nn.Module):
|
||||
and not forward_batch.forward_mode.is_draft_extend()
|
||||
and sum(forward_batch.extend_prefix_lens_cpu) == 0
|
||||
)
|
||||
elif self.attention_backend == "fa3":
|
||||
# Flash Attention: Keep absorbing for all extend/decode
|
||||
return False
|
||||
else:
|
||||
# Triton: Use normal computation for prefill and use weight absorption for extend/decode
|
||||
return (
|
||||
|
||||
Reference in New Issue
Block a user