feat: use fa3 mla by default on hopper (#5210)

Co-authored-by: yundai424 <yundai424@gmail.com>
Co-authored-by: hebiao064 <hebiaobuaa@gmail.com>
This commit is contained in:
Yineng Zhang
2025-04-12 01:09:25 -07:00
committed by GitHub
parent 115ae2e728
commit 57de7c6b5f
3 changed files with 42 additions and 11 deletions

View File

@@ -325,7 +325,7 @@ class FlashAttentionBackend(AttentionBackend):
batch_size = len(seqlens_in_batch)
device = seqlens_in_batch.device
if forward_batch.forward_mode.is_decode():
if forward_batch.forward_mode.is_decode_or_idle():
# Draft Decode
if forward_batch.spec_info is not None:
metadata.cache_seqlens_int32 = (
@@ -527,7 +527,9 @@ class FlashAttentionBackend(AttentionBackend):
else (-1, -1)
)
k_descale, v_descale = None, None
if self.kv_cache_dtype_str != "auto":
# only use kv scaling if: 1) fp8 kv is explicitly enabled, 2) RadixAttention
# has corresponding quantization method so that layer.k_scale is not None
if self.kv_cache_dtype_str != "auto" and layer.k_scale is not None:
descale_shape = (forward_batch.batch_size, layer.tp_k_head_num)
k_descale = layer.k_scale.expand(descale_shape)
v_descale = layer.v_scale.expand(descale_shape)
@@ -670,10 +672,13 @@ class FlashAttentionBackend(AttentionBackend):
causal = not layer.is_cross_attention
k_descale, v_descale = None, None
# only use kv scaling if: 1) fp8 kv is explicitly enabled, 2) RadixAttention
# has corresponding quantization method so that layer.k_scale is not None
if self.kv_cache_dtype_str != "auto":
descale_shape = (forward_batch.batch_size, layer.tp_k_head_num)
k_descale = layer.k_scale.expand(descale_shape)
v_descale = layer.v_scale.expand(descale_shape)
if layer.k_scale is not None:
descale_shape = (forward_batch.batch_size, layer.tp_k_head_num)
k_descale = layer.k_scale.expand(descale_shape)
v_descale = layer.v_scale.expand(descale_shape)
q = q.to(self.kv_cache_dtype)
if not self.use_mla:
@@ -834,7 +839,7 @@ class FlashAttentionBackend(AttentionBackend):
"""Initialize forward metadata for capturing CUDA graph."""
metadata = FlashAttentionMetadata()
device = seq_lens.device
if forward_mode.is_decode():
if forward_mode.is_decode_or_idle():
if spec_info is not None:
# Draft Decode
metadata.cache_seqlens_int32 = self.decode_cuda_graph_metadata[
@@ -937,7 +942,7 @@ class FlashAttentionBackend(AttentionBackend):
seq_lens = seq_lens[:bs]
seq_lens_cpu = seq_lens_cpu[:bs]
req_pool_indices = req_pool_indices[:bs]
if forward_mode.is_decode():
if forward_mode.is_decode_or_idle():
metadata = self.decode_cuda_graph_metadata[bs]
if spec_info is not None: