feat: use fa3 mla by default on hopper (#5210)
Co-authored-by: yundai424 <yundai424@gmail.com>
Co-authored-by: hebiao064 <hebiaobuaa@gmail.com>
@@ -325,7 +325,7 @@ class FlashAttentionBackend(AttentionBackend):
         batch_size = len(seqlens_in_batch)
         device = seqlens_in_batch.device
 
-        if forward_batch.forward_mode.is_decode():
+        if forward_batch.forward_mode.is_decode_or_idle():
             # Draft Decode
             if forward_batch.spec_info is not None:
                 metadata.cache_seqlens_int32 = (
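This hunk (and the two CUDA-graph hunks below) widens the decode check from is_decode() to is_decode_or_idle(), so idle batches (e.g. a rank that has no requests to run in a given step) take the same decode metadata path instead of falling through. A minimal sketch of the distinction, assuming a ForwardMode enum with DECODE and IDLE members; the names and members here are illustrative, not the project's actual definition:

from enum import Enum, auto


class ForwardMode(Enum):
    """Illustrative subset of a forward-mode enum; not the real definition."""

    EXTEND = auto()
    DECODE = auto()
    IDLE = auto()  # a batch with no requests, e.g. an idle data-parallel rank

    def is_decode(self) -> bool:
        return self is ForwardMode.DECODE

    def is_decode_or_idle(self) -> bool:
        # Broader predicate used by the new code: idle batches reuse the
        # decode-style metadata instead of needing a separate branch.
        return self in (ForwardMode.DECODE, ForwardMode.IDLE)

With that predicate, an idle forward pass presumably builds the same metadata as a normal decode step, so the backend does not need a dedicated idle branch.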
@@ -527,7 +527,9 @@ class FlashAttentionBackend(AttentionBackend):
             else (-1, -1)
         )
         k_descale, v_descale = None, None
-        if self.kv_cache_dtype_str != "auto":
+        # only use kv scaling if: 1) fp8 kv is explicitly enabled, 2) RadixAttention
+        # has corresponding quantization method so that layer.k_scale is not None
+        if self.kv_cache_dtype_str != "auto" and layer.k_scale is not None:
             descale_shape = (forward_batch.batch_size, layer.tp_k_head_num)
             k_descale = layer.k_scale.expand(descale_shape)
             v_descale = layer.v_scale.expand(descale_shape)
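The added `layer.k_scale is not None` condition means fp8 KV-cache descaling is only attempted when the layer's quantization method actually attached scales; with an fp8 cache but no scales, k_descale and v_descale stay None. A standalone sketch of the guarded computation, with the surrounding state pulled out as plain parameters (a hypothetical helper, not code from the PR):

from typing import Optional, Tuple

import torch


def compute_kv_descale(
    kv_cache_dtype_str: str,
    k_scale: Optional[torch.Tensor],
    v_scale: Optional[torch.Tensor],
    batch_size: int,
    tp_k_head_num: int,
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
    """Return per-(batch, head) descale factors, or (None, None) when unused."""
    k_descale, v_descale = None, None
    # Descale only if: 1) an fp8 KV cache was explicitly requested, and
    # 2) the layer carries scales (its quantization method set k_scale).
    if kv_cache_dtype_str != "auto" and k_scale is not None:
        descale_shape = (batch_size, tp_k_head_num)
        k_descale = k_scale.expand(descale_shape)  # broadcast the scalar scale
        v_descale = v_scale.expand(descale_shape)
    return k_descale, v_descale

k_scale is assumed here to be a scalar (0-d or one-element) tensor, so expand simply broadcasts it to one value per batch element and KV head.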
@@ -670,10 +672,13 @@ class FlashAttentionBackend(AttentionBackend):
         causal = not layer.is_cross_attention
 
         k_descale, v_descale = None, None
+        # only use kv scaling if: 1) fp8 kv is explicitly enabled, 2) RadixAttention
+        # has corresponding quantization method so that layer.k_scale is not None
         if self.kv_cache_dtype_str != "auto":
-            descale_shape = (forward_batch.batch_size, layer.tp_k_head_num)
-            k_descale = layer.k_scale.expand(descale_shape)
-            v_descale = layer.v_scale.expand(descale_shape)
+            if layer.k_scale is not None:
+                descale_shape = (forward_batch.batch_size, layer.tp_k_head_num)
+                k_descale = layer.k_scale.expand(descale_shape)
+                v_descale = layer.v_scale.expand(descale_shape)
             q = q.to(self.kv_cache_dtype)
 
         if not self.use_mla:
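Here the None check is nested inside the dtype check rather than folded into it, so q = q.to(self.kv_cache_dtype) still runs whenever an fp8 cache is configured, even for layers without scales. The descale factors follow the usual symmetric-quantization convention (stored value ≈ real value / scale, and the kernel multiplies back by the scale); a toy numeric sketch of that relationship, not the project's kernel code:

import torch

batch_size, heads, head_dim = 2, 4, 8
k_real = torch.randn(batch_size, heads, head_dim)

# Quantization-style scaling (no actual fp8 cast, just the arithmetic):
k_scale = k_real.abs().max() / 448.0            # 448 ~ max finite value of float8_e4m3
k_stored = k_real / k_scale                     # what an fp8 cache would hold
k_descale = k_scale.expand(batch_size, heads)   # shape the backend passes down

# The attention kernel conceptually multiplies back by the descale factor.
k_restored = k_stored * k_descale[..., None]
assert torch.allclose(k_restored, k_real, atol=1e-5)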
@@ -834,7 +839,7 @@ class FlashAttentionBackend(AttentionBackend):
         """Initialize forward metadata for capturing CUDA graph."""
         metadata = FlashAttentionMetadata()
         device = seq_lens.device
-        if forward_mode.is_decode():
+        if forward_mode.is_decode_or_idle():
             if spec_info is not None:
                 # Draft Decode
                 metadata.cache_seqlens_int32 = self.decode_cuda_graph_metadata[
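The same predicate change is applied where metadata is built for CUDA graph capture. Captured graphs replay fixed memory addresses, so this metadata has to come from buffers preallocated per batch size rather than tensors created on the fly. A minimal sketch of that preallocation pattern; it mirrors the role of decode_cuda_graph_metadata seen in the diff, but the names and structure here are assumptions:

import torch

max_bs = 8
device = "cuda" if torch.cuda.is_available() else "cpu"

# One set of metadata buffers per capturable batch size, allocated up front.
decode_graph_buffers = {
    bs: {
        "cache_seqlens_int32": torch.zeros(bs, dtype=torch.int32, device=device),
        "cu_seqlens_q": torch.arange(0, bs + 1, dtype=torch.int32, device=device),
    }
    for bs in range(1, max_bs + 1)
}

# During capture the backend reads from these fixed buffers, so later replays
# see the same addresses and only the contents change.
metadata_for_capture = decode_graph_buffers[4]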
@@ -937,7 +942,7 @@ class FlashAttentionBackend(AttentionBackend):
         seq_lens = seq_lens[:bs]
         seq_lens_cpu = seq_lens_cpu[:bs]
         req_pool_indices = req_pool_indices[:bs]
-        if forward_mode.is_decode():
+        if forward_mode.is_decode_or_idle():
             metadata = self.decode_cuda_graph_metadata[bs]
 
             if spec_info is not None:
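The replay hunk makes the same substitution at the point where inputs are sliced down to the live batch size and the captured buffers are refreshed in place. A self-contained sketch of that slice-and-copy step, with hypothetical names:

import torch

max_bs = 8
device = "cuda" if torch.cuda.is_available() else "cpu"
# Persistent buffers captured earlier, one per batch size (hypothetical names).
captured_seqlens = {
    n: torch.zeros(n, dtype=torch.int32, device=device) for n in range(1, max_bs + 1)
}

bs = 3
seq_lens = torch.tensor([17, 5, 9, 0, 0], dtype=torch.int32)[:bs]  # drop padded slots
# Copy in place: the graph keeps pointing at the same storage, only values change.
captured_seqlens[bs].copy_(seq_lens)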