feat: use fa3 mla by default on hopper (#5210)

Co-authored-by: yundai424 <yundai424@gmail.com>
Co-authored-by: hebiao064 <hebiaobuaa@gmail.com>
This commit is contained in:
Yineng Zhang
2025-04-12 01:09:25 -07:00
committed by GitHub
parent 115ae2e728
commit 57de7c6b5f
3 changed files with 42 additions and 11 deletions

View File

@@ -80,6 +80,7 @@ from sglang.srt.utils import (
is_cuda,
is_flashinfer_available,
is_hip,
is_hopper_with_cuda_12_3,
monkey_patch_p2p_access_check,
monkey_patch_vllm_gguf_config,
set_cpu_offload_max_bytes,
@@ -245,7 +246,16 @@ class ModelRunner:
"flashinfer" if is_flashinfer_available() else "triton"
)
else:
server_args.attention_backend = "triton"
if is_hopper_with_cuda_12_3():
if server_args.speculative_eagle_topk is None or (
server_args.speculative_eagle_topk is not None
and server_args.speculative_eagle_topk == 1
):
server_args.attention_backend = "fa3"
else:
server_args.attention_backend = "triton"
else:
server_args.attention_backend = "triton"
logger.info(
f"Attention backend not set. Use {server_args.attention_backend} backend by default."
)
@@ -263,6 +273,16 @@ class ModelRunner:
else:
raise ValueError(f"MLA optimization not supported on CPU.")
if (
server_args.attention_backend == "fa3"
and server_args.kv_cache_dtype == "fp8_e5m2"
):
logger.warning(
"FlashAttention3 only supports fp8_e4m3 if using FP8; "
"Setting attention backend to triton."
)
server_args.attention_backend = "triton"
if server_args.enable_double_sparsity:
logger.info(
"Double sparsity optimization is turned on. Use triton backend without CUDA graph."
@@ -889,9 +909,6 @@ class ModelRunner:
"FlashAttention v3 Backend requires SM>=90. "
"Please use `--attention-backend flashinfer`."
)
logger.warning(
"FlashAttention v3 Backend is in Beta. FP8 is not supported."
)
from sglang.srt.layers.attention.flashattention_backend import (
FlashAttentionBackend,
)