feat: use fa3 mla by default on hopper (#5210)
Co-authored-by: yundai424 <yundai424@gmail.com> Co-authored-by: hebiao064 <hebiaobuaa@gmail.com>
This commit is contained in:
@@ -325,7 +325,7 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
batch_size = len(seqlens_in_batch)
|
batch_size = len(seqlens_in_batch)
|
||||||
device = seqlens_in_batch.device
|
device = seqlens_in_batch.device
|
||||||
|
|
||||||
if forward_batch.forward_mode.is_decode():
|
if forward_batch.forward_mode.is_decode_or_idle():
|
||||||
# Draft Decode
|
# Draft Decode
|
||||||
if forward_batch.spec_info is not None:
|
if forward_batch.spec_info is not None:
|
||||||
metadata.cache_seqlens_int32 = (
|
metadata.cache_seqlens_int32 = (
|
||||||
@@ -527,7 +527,9 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
else (-1, -1)
|
else (-1, -1)
|
||||||
)
|
)
|
||||||
k_descale, v_descale = None, None
|
k_descale, v_descale = None, None
|
||||||
if self.kv_cache_dtype_str != "auto":
|
# only use kv scaling if: 1) fp8 kv is explicitly enabled, 2) RadixAttention
|
||||||
|
# has corresponding quantization method so that layer.k_scale is not None
|
||||||
|
if self.kv_cache_dtype_str != "auto" and layer.k_scale is not None:
|
||||||
descale_shape = (forward_batch.batch_size, layer.tp_k_head_num)
|
descale_shape = (forward_batch.batch_size, layer.tp_k_head_num)
|
||||||
k_descale = layer.k_scale.expand(descale_shape)
|
k_descale = layer.k_scale.expand(descale_shape)
|
||||||
v_descale = layer.v_scale.expand(descale_shape)
|
v_descale = layer.v_scale.expand(descale_shape)
|
||||||
@@ -670,10 +672,13 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
causal = not layer.is_cross_attention
|
causal = not layer.is_cross_attention
|
||||||
|
|
||||||
k_descale, v_descale = None, None
|
k_descale, v_descale = None, None
|
||||||
|
# only use kv scaling if: 1) fp8 kv is explicitly enabled, 2) RadixAttention
|
||||||
|
# has corresponding quantization method so that layer.k_scale is not None
|
||||||
if self.kv_cache_dtype_str != "auto":
|
if self.kv_cache_dtype_str != "auto":
|
||||||
descale_shape = (forward_batch.batch_size, layer.tp_k_head_num)
|
if layer.k_scale is not None:
|
||||||
k_descale = layer.k_scale.expand(descale_shape)
|
descale_shape = (forward_batch.batch_size, layer.tp_k_head_num)
|
||||||
v_descale = layer.v_scale.expand(descale_shape)
|
k_descale = layer.k_scale.expand(descale_shape)
|
||||||
|
v_descale = layer.v_scale.expand(descale_shape)
|
||||||
q = q.to(self.kv_cache_dtype)
|
q = q.to(self.kv_cache_dtype)
|
||||||
|
|
||||||
if not self.use_mla:
|
if not self.use_mla:
|
||||||
@@ -834,7 +839,7 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
"""Initialize forward metadata for capturing CUDA graph."""
|
"""Initialize forward metadata for capturing CUDA graph."""
|
||||||
metadata = FlashAttentionMetadata()
|
metadata = FlashAttentionMetadata()
|
||||||
device = seq_lens.device
|
device = seq_lens.device
|
||||||
if forward_mode.is_decode():
|
if forward_mode.is_decode_or_idle():
|
||||||
if spec_info is not None:
|
if spec_info is not None:
|
||||||
# Draft Decode
|
# Draft Decode
|
||||||
metadata.cache_seqlens_int32 = self.decode_cuda_graph_metadata[
|
metadata.cache_seqlens_int32 = self.decode_cuda_graph_metadata[
|
||||||
@@ -937,7 +942,7 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
seq_lens = seq_lens[:bs]
|
seq_lens = seq_lens[:bs]
|
||||||
seq_lens_cpu = seq_lens_cpu[:bs]
|
seq_lens_cpu = seq_lens_cpu[:bs]
|
||||||
req_pool_indices = req_pool_indices[:bs]
|
req_pool_indices = req_pool_indices[:bs]
|
||||||
if forward_mode.is_decode():
|
if forward_mode.is_decode_or_idle():
|
||||||
metadata = self.decode_cuda_graph_metadata[bs]
|
metadata = self.decode_cuda_graph_metadata[bs]
|
||||||
|
|
||||||
if spec_info is not None:
|
if spec_info is not None:
|
||||||
|
|||||||
@@ -80,6 +80,7 @@ from sglang.srt.utils import (
|
|||||||
is_cuda,
|
is_cuda,
|
||||||
is_flashinfer_available,
|
is_flashinfer_available,
|
||||||
is_hip,
|
is_hip,
|
||||||
|
is_hopper_with_cuda_12_3,
|
||||||
monkey_patch_p2p_access_check,
|
monkey_patch_p2p_access_check,
|
||||||
monkey_patch_vllm_gguf_config,
|
monkey_patch_vllm_gguf_config,
|
||||||
set_cpu_offload_max_bytes,
|
set_cpu_offload_max_bytes,
|
||||||
@@ -245,7 +246,16 @@ class ModelRunner:
|
|||||||
"flashinfer" if is_flashinfer_available() else "triton"
|
"flashinfer" if is_flashinfer_available() else "triton"
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
server_args.attention_backend = "triton"
|
if is_hopper_with_cuda_12_3():
|
||||||
|
if server_args.speculative_eagle_topk is None or (
|
||||||
|
server_args.speculative_eagle_topk is not None
|
||||||
|
and server_args.speculative_eagle_topk == 1
|
||||||
|
):
|
||||||
|
server_args.attention_backend = "fa3"
|
||||||
|
else:
|
||||||
|
server_args.attention_backend = "triton"
|
||||||
|
else:
|
||||||
|
server_args.attention_backend = "triton"
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Attention backend not set. Use {server_args.attention_backend} backend by default."
|
f"Attention backend not set. Use {server_args.attention_backend} backend by default."
|
||||||
)
|
)
|
||||||
@@ -263,6 +273,16 @@ class ModelRunner:
|
|||||||
else:
|
else:
|
||||||
raise ValueError(f"MLA optimization not supported on CPU.")
|
raise ValueError(f"MLA optimization not supported on CPU.")
|
||||||
|
|
||||||
|
if (
|
||||||
|
server_args.attention_backend == "fa3"
|
||||||
|
and server_args.kv_cache_dtype == "fp8_e5m2"
|
||||||
|
):
|
||||||
|
logger.warning(
|
||||||
|
"FlashAttention3 only supports fp8_e4m3 if using FP8; "
|
||||||
|
"Setting attention backend to triton."
|
||||||
|
)
|
||||||
|
server_args.attention_backend = "triton"
|
||||||
|
|
||||||
if server_args.enable_double_sparsity:
|
if server_args.enable_double_sparsity:
|
||||||
logger.info(
|
logger.info(
|
||||||
"Double sparsity optimization is turned on. Use triton backend without CUDA graph."
|
"Double sparsity optimization is turned on. Use triton backend without CUDA graph."
|
||||||
@@ -889,9 +909,6 @@ class ModelRunner:
|
|||||||
"FlashAttention v3 Backend requires SM>=90. "
|
"FlashAttention v3 Backend requires SM>=90. "
|
||||||
"Please use `--attention-backend flashinfer`."
|
"Please use `--attention-backend flashinfer`."
|
||||||
)
|
)
|
||||||
logger.warning(
|
|
||||||
"FlashAttention v3 Backend is in Beta. FP8 is not supported."
|
|
||||||
)
|
|
||||||
from sglang.srt.layers.attention.flashattention_backend import (
|
from sglang.srt.layers.attention.flashattention_backend import (
|
||||||
FlashAttentionBackend,
|
FlashAttentionBackend,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1828,3 +1828,12 @@ def fast_topk(values, topk, dim):
|
|||||||
else:
|
else:
|
||||||
# Use topk for efficiency with larger k values
|
# Use topk for efficiency with larger k values
|
||||||
return torch.topk(values, topk, dim=dim)
|
return torch.topk(values, topk, dim=dim)
|
||||||
|
|
||||||
|
|
||||||
|
def is_hopper_with_cuda_12_3():
|
||||||
|
if not is_cuda():
|
||||||
|
return False
|
||||||
|
is_hopper = torch.cuda.get_device_capability()[0] == 9
|
||||||
|
cuda_version = torch.version.cuda.split(".")
|
||||||
|
is_cuda_compatible = int(cuda_version[0]) == 12 and int(cuda_version[1]) >= 3
|
||||||
|
return is_hopper and is_cuda_compatible
|
||||||
|
|||||||
Reference in New Issue
Block a user