[2/2] Support MHA prefill with FlashAttention 4. (#10937)
Co-authored-by: Hieu Pham <hyhieu@gmail.com>
This commit is contained in:
@@ -129,9 +129,6 @@ def create_flashattention_v3_backend(runner):
|
||||
|
||||
@register_attention_backend("fa4")
def create_flashattention_v4_backend(runner):
    """Build a FlashAttention v4 attention backend for *runner*.

    FA4 support is early-stage: the runner must be configured for the MLA
    backend, otherwise construction is refused.
    """
    # Guard first, before touching the backend module at all.
    assert runner.use_mla_backend, (
        "FlashAttention v4 Support is at an early stage, only MLA model supported now"
    )

    # Deferred import keeps backend registration cheap when FA4 is unused.
    from sglang.srt.layers.attention.flashattention_backend import (
        FlashAttentionBackend,
    )

    return FlashAttentionBackend(runner, fa_impl_ver=4)
|
||||
|
||||
@@ -754,7 +754,6 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
|
||||
# Use Flash Attention for prefill
|
||||
if not self.use_mla:
|
||||
assert self.fa_impl_ver in [3], "Only FA3 support here"
|
||||
# Do multi-head attention
|
||||
key_cache, value_cache = forward_batch.token_to_kv_pool.get_kv_buffer(
|
||||
layer.layer_id
|
||||
|
||||
Reference in New Issue
Block a user