From 57de7c6b5fb3fa99a386ee0500b25cdbf98eb1a2 Mon Sep 17 00:00:00 2001
From: Yineng Zhang
Date: Sat, 12 Apr 2025 01:09:25 -0700
Subject: [PATCH] feat: use fa3 mla by default on hopper (#5210)

Co-authored-by: yundai424
Co-authored-by: hebiao064
---
 .../attention/flashattention_backend.py       | 19 ++++++++------
 .../sglang/srt/model_executor/model_runner.py | 25 ++++++++++++++++---
 python/sglang/srt/utils.py                    |  9 +++++++
 3 files changed, 42 insertions(+), 11 deletions(-)

diff --git a/python/sglang/srt/layers/attention/flashattention_backend.py b/python/sglang/srt/layers/attention/flashattention_backend.py
index 76496300c..a2425f1a2 100644
--- a/python/sglang/srt/layers/attention/flashattention_backend.py
+++ b/python/sglang/srt/layers/attention/flashattention_backend.py
@@ -325,7 +325,7 @@ class FlashAttentionBackend(AttentionBackend):
         batch_size = len(seqlens_in_batch)
         device = seqlens_in_batch.device
 
-        if forward_batch.forward_mode.is_decode():
+        if forward_batch.forward_mode.is_decode_or_idle():
             # Draft Decode
             if forward_batch.spec_info is not None:
                 metadata.cache_seqlens_int32 = (
@@ -527,7 +527,9 @@ class FlashAttentionBackend(AttentionBackend):
             else (-1, -1)
         )
         k_descale, v_descale = None, None
-        if self.kv_cache_dtype_str != "auto":
+        # only use kv scaling if: 1) fp8 kv is explicitly enabled, 2) RadixAttention
+        # has corresponding quantization method so that layer.k_scale is not None
+        if self.kv_cache_dtype_str != "auto" and layer.k_scale is not None:
             descale_shape = (forward_batch.batch_size, layer.tp_k_head_num)
             k_descale = layer.k_scale.expand(descale_shape)
             v_descale = layer.v_scale.expand(descale_shape)
@@ -670,10 +672,13 @@ class FlashAttentionBackend(AttentionBackend):
         causal = not layer.is_cross_attention
 
         k_descale, v_descale = None, None
+        # only use kv scaling if: 1) fp8 kv is explicitly enabled, 2) RadixAttention
+        # has corresponding quantization method so that layer.k_scale is not None
         if self.kv_cache_dtype_str != "auto":
-            descale_shape = (forward_batch.batch_size, layer.tp_k_head_num)
-            k_descale = layer.k_scale.expand(descale_shape)
-            v_descale = layer.v_scale.expand(descale_shape)
+            if layer.k_scale is not None:
+                descale_shape = (forward_batch.batch_size, layer.tp_k_head_num)
+                k_descale = layer.k_scale.expand(descale_shape)
+                v_descale = layer.v_scale.expand(descale_shape)
         q = q.to(self.kv_cache_dtype)
 
         if not self.use_mla:
@@ -834,7 +839,7 @@ class FlashAttentionBackend(AttentionBackend):
         """Initialize forward metadata for capturing CUDA graph."""
         metadata = FlashAttentionMetadata()
         device = seq_lens.device
-        if forward_mode.is_decode():
+        if forward_mode.is_decode_or_idle():
             if spec_info is not None:
                 # Draft Decode
                 metadata.cache_seqlens_int32 = self.decode_cuda_graph_metadata[
@@ -937,7 +942,7 @@ class FlashAttentionBackend(AttentionBackend):
         seq_lens = seq_lens[:bs]
         seq_lens_cpu = seq_lens_cpu[:bs]
         req_pool_indices = req_pool_indices[:bs]
-        if forward_mode.is_decode():
+        if forward_mode.is_decode_or_idle():
             metadata = self.decode_cuda_graph_metadata[bs]
 
             if spec_info is not None:
diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py
index d254cb73f..995c613ce 100644
--- a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -80,6 +80,7 @@ from sglang.srt.utils import (
     is_cuda,
     is_flashinfer_available,
     is_hip,
+    is_hopper_with_cuda_12_3,
     monkey_patch_p2p_access_check,
     monkey_patch_vllm_gguf_config,
     set_cpu_offload_max_bytes,
@@ -245,7 +246,16 @@ class ModelRunner:
                     "flashinfer" if is_flashinfer_available() else "triton"
                 )
             else:
-                server_args.attention_backend = "triton"
+                if is_hopper_with_cuda_12_3():
+                    if server_args.speculative_eagle_topk is None or (
+                        server_args.speculative_eagle_topk is not None
+                        and server_args.speculative_eagle_topk == 1
+                    ):
+                        server_args.attention_backend = "fa3"
+                    else:
+                        server_args.attention_backend = "triton"
+                else:
+                    server_args.attention_backend = "triton"
             logger.info(
                 f"Attention backend not set. Use {server_args.attention_backend} backend by default."
             )
@@ -263,6 +273,16 @@ class ModelRunner:
             else:
                 raise ValueError(f"MLA optimization not supported on CPU.")
 
+        if (
+            server_args.attention_backend == "fa3"
+            and server_args.kv_cache_dtype == "fp8_e5m2"
+        ):
+            logger.warning(
+                "FlashAttention3 only supports fp8_e4m3 if using FP8; "
+                "Setting attention backend to triton."
+            )
+            server_args.attention_backend = "triton"
+
         if server_args.enable_double_sparsity:
             logger.info(
                 "Double sparsity optimization is turned on. Use triton backend without CUDA graph."
@@ -889,9 +909,6 @@ class ModelRunner:
                 "FlashAttention v3 Backend requires SM>=90. "
                 "Please use `--attention-backend flashinfer`."
             )
-            logger.warning(
-                "FlashAttention v3 Backend is in Beta. FP8 is not supported."
-            )
             from sglang.srt.layers.attention.flashattention_backend import (
                 FlashAttentionBackend,
             )
diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py
index d68fa489b..60dd27116 100644
--- a/python/sglang/srt/utils.py
+++ b/python/sglang/srt/utils.py
@@ -1828,3 +1828,12 @@ def fast_topk(values, topk, dim):
     else:
         # Use topk for efficiency with larger k values
         return torch.topk(values, topk, dim=dim)
+
+
+def is_hopper_with_cuda_12_3():
+    if not is_cuda():
+        return False
+    is_hopper = torch.cuda.get_device_capability()[0] == 9
+    cuda_version = torch.version.cuda.split(".")
+    is_cuda_compatible = int(cuda_version[0]) == 12 and int(cuda_version[1]) >= 3
+    return is_hopper and is_cuda_compatible
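
Below is a minimal standalone sketch (not part of the patch) of the MLA default-backend selection this change introduces. It mirrors is_hopper_with_cuda_12_3() from python/sglang/srt/utils.py, but substitutes torch.cuda.is_available() for SGLang's is_cuda() helper, and wraps the ModelRunner logic in a hypothetical pick_mla_attention_backend() function purely for illustration.

# Standalone sketch of the defaulting logic added by this patch.
# Assumptions: torch.cuda.is_available() stands in for sglang's is_cuda(),
# and pick_mla_attention_backend() is a hypothetical helper, not SGLang API.
from typing import Optional

import torch


def is_hopper_with_cuda_12_3() -> bool:
    # Hopper GPUs report compute capability major == 9; FA3 also needs CUDA 12.3+.
    if not torch.cuda.is_available():
        return False
    if torch.version.cuda is None:
        return False
    major, minor = (int(x) for x in torch.version.cuda.split(".")[:2])
    is_hopper = torch.cuda.get_device_capability()[0] == 9
    return is_hopper and major == 12 and minor >= 3


def pick_mla_attention_backend(
    kv_cache_dtype: str = "auto",
    speculative_eagle_topk: Optional[int] = None,
) -> str:
    # MLA models default to FA3 on Hopper + CUDA 12.3+, unless EAGLE top-k > 1 is used.
    if is_hopper_with_cuda_12_3() and speculative_eagle_topk in (None, 1):
        backend = "fa3"
    else:
        backend = "triton"
    # FA3 only supports fp8_e4m3 for FP8 KV cache; fall back to triton for fp8_e5m2.
    if backend == "fa3" and kv_cache_dtype == "fp8_e5m2":
        backend = "triton"
    return backend


if __name__ == "__main__":
    print(pick_mla_attention_backend())
    print(pick_mla_attention_backend(kv_cache_dtype="fp8_e5m2"))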