[2/2] Support MHA prefill with FlashAttention 4. (#10937)

Co-authored-by: Hieu Pham <hyhieu@gmail.com>
Authored by: Lifu Huang
Date: 2025-10-08 00:54:20 -07:00
Committed by: GitHub
parent 97cd38e58d
commit edefab0c64
7 changed files with 34 additions and 23 deletions

View File

@@ -129,9 +129,6 @@ def create_flashattention_v3_backend(runner):
@register_attention_backend("fa4")
def create_flashattention_v4_backend(runner):
    """Create a FlashAttention v4 attention backend for *runner*.

    Args:
        runner: Model runner instance. Must be using the MLA backend,
            since FA4 support currently covers MLA models only.

    Returns:
        A ``FlashAttentionBackend`` configured for FA v4
        (``fa_impl_ver=4``).

    Raises:
        ValueError: If the runner does not use the MLA backend.
    """
    # Original code used `assert` for this validation; asserts are
    # stripped under `python -O`, so validate explicitly and raise.
    if not runner.use_mla_backend:
        raise ValueError(
            "FlashAttention v4 Support is at an early stage, only MLA model supported now"
        )
    # Local import keeps the heavy backend module lazily loaded until
    # this backend is actually selected.
    from sglang.srt.layers.attention.flashattention_backend import FlashAttentionBackend

    return FlashAttentionBackend(runner, fa_impl_ver=4)

View File

@@ -754,7 +754,6 @@ class FlashAttentionBackend(AttentionBackend):
# Use Flash Attention for prefill
if not self.use_mla:
assert self.fa_impl_ver in [3], "Only FA3 support here"
# Do multi-head attention
key_cache, value_cache = forward_batch.token_to_kv_pool.get_kv_buffer(
layer.layer_id