From 5f91c82526c19988101dcd1eddcacf9243fdd148 Mon Sep 17 00:00:00 2001 From: Jianan Ji <72958002+NorthmanPKU@users.noreply.github.com> Date: Fri, 6 Jun 2025 15:57:50 -0400 Subject: [PATCH] [Feature] Support Flashinfer fmha on Blackwell (#6930) --- python/sglang/srt/layers/attention/flashinfer_backend.py | 6 +++++- .../sglang/srt/layers/attention/flashinfer_mla_backend.py | 6 +++++- python/sglang/srt/layers/quantization/fp8.py | 2 +- python/sglang/srt/layers/quantization/fp8_utils.py | 7 +------ python/sglang/srt/layers/utils.py | 6 ++++++ 5 files changed, 18 insertions(+), 9 deletions(-) diff --git a/python/sglang/srt/layers/attention/flashinfer_backend.py b/python/sglang/srt/layers/attention/flashinfer_backend.py index b4cf99dff..876141083 100644 --- a/python/sglang/srt/layers/attention/flashinfer_backend.py +++ b/python/sglang/srt/layers/attention/flashinfer_backend.py @@ -25,6 +25,7 @@ from sglang.global_config import global_config from sglang.srt.layers.attention.base_attn_backend import AttentionBackend from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton from sglang.srt.layers.dp_attention import get_attention_tp_size +from sglang.srt.layers.utils import is_sm100_supported from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput from sglang.srt.utils import is_flashinfer_available, next_power_of_2 @@ -149,8 +150,11 @@ class FlashInferAttnBackend(AttentionBackend): for _ in range(self.num_wrappers) ] + fmha_backend = "auto" + if is_sm100_supported(): + fmha_backend = "cutlass" self.prefill_wrapper_ragged = BatchPrefillWithRaggedKVCacheWrapper( - self.workspace_buffer, "NHD" + self.workspace_buffer, "NHD", backend=fmha_backend ) # Two wrappers: one for sliding window attention and one for full attention. 
diff --git a/python/sglang/srt/layers/attention/flashinfer_mla_backend.py b/python/sglang/srt/layers/attention/flashinfer_mla_backend.py index 57ad6fc30..918184dfc 100644 --- a/python/sglang/srt/layers/attention/flashinfer_mla_backend.py +++ b/python/sglang/srt/layers/attention/flashinfer_mla_backend.py @@ -29,6 +29,7 @@ from sglang.srt.layers.attention.flashinfer_backend import ( create_flashinfer_kv_indices_triton, ) from sglang.srt.layers.dp_attention import get_attention_tp_size +from sglang.srt.layers.utils import is_sm100_supported from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput @@ -108,8 +109,11 @@ class FlashInferMLAAttnBackend(AttentionBackend): else: self.q_indptr_decode = q_indptr_decode_buf + fmha_backend = "auto" + if is_sm100_supported(): + fmha_backend = "cutlass" self.prefill_wrapper_ragged = BatchPrefillWithRaggedKVCacheWrapper( - self.workspace_buffer, "NHD" + self.workspace_buffer, "NHD", backend=fmha_backend ) if not self.skip_prefill: diff --git a/python/sglang/srt/layers/quantization/fp8.py b/python/sglang/srt/layers/quantization/fp8.py index 7847326a6..b561b2660 100644 --- a/python/sglang/srt/layers/quantization/fp8.py +++ b/python/sglang/srt/layers/quantization/fp8.py @@ -52,7 +52,6 @@ from sglang.srt.layers.quantization.fp8_utils import ( cutlass_fp8_supported, dispatch_w8a8_block_fp8_linear, input_to_float8, - is_sm100_supported, normalize_e4m3fn_to_e4m3fnuz, ) from sglang.srt.layers.quantization.kv_cache import BaseKVCacheMethod @@ -63,6 +62,7 @@ from sglang.srt.layers.quantization.utils import ( per_tensor_dequantize, requantize_with_max_scale, ) +from sglang.srt.layers.utils import is_sm100_supported from sglang.srt.utils import ( get_bool_env_var, is_cuda, diff --git a/python/sglang/srt/layers/quantization/fp8_utils.py 
b/python/sglang/srt/layers/quantization/fp8_utils.py
index e9e663e59..0e1640fcf 100644
--- a/python/sglang/srt/layers/quantization/fp8_utils.py
+++ b/python/sglang/srt/layers/quantization/fp8_utils.py
@@ -5,6 +5,7 @@ from typing import Callable, List, Optional, Tuple
 import torch
 
 from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_group_quant_fp8
+from sglang.srt.layers.utils import is_sm100_supported
 
 try:
     from vllm import _custom_ops as ops
@@ -83,12 +84,6 @@ def cutlass_fp8_supported():
     return False
 
 
-def is_sm100_supported(device=None) -> bool:
-    return (torch.cuda.get_device_capability(device)[0] == 10) and (
-        torch.version.cuda >= "12.8"
-    )
-
-
 def normalize_e4m3fn_to_e4m3fnuz(
     weight: torch.Tensor,
     weight_scale: torch.Tensor,
diff --git a/python/sglang/srt/layers/utils.py b/python/sglang/srt/layers/utils.py
index 63e775618..f61b86293 100644
--- a/python/sglang/srt/layers/utils.py
+++ b/python/sglang/srt/layers/utils.py
@@ -33,3 +33,26 @@ class PPMissingLayer(torch.nn.Identity):
         """
         input = args[0] if args else next(iter(kwargs.values()))
         return (input,) if self.return_tuple else input
+
+
+def is_sm100_supported(device=None) -> bool:
+    """Return whether ``device`` is an SM100 GPU on a CUDA 12.8+ build.
+
+    Args:
+        device: Forwarded to ``torch.cuda.get_device_capability``; ``None``
+            selects the current device.
+
+    Returns:
+        True only when the device's major compute capability is 10 and
+        ``torch.version.cuda`` is at least 12.8; False when CUDA is
+        unavailable or the build reports no CUDA version.
+    """
+    cuda_version = torch.version.cuda
+    # torch.version.cuda is None on builds without CUDA (CPU-only, ROCm).
+    if cuda_version is None or not torch.cuda.is_available():
+        return False
+    if torch.cuda.get_device_capability(device)[0] != 10:
+        return False
+    # Compare numerically; as strings "12.10" would sort below "12.8".
+    release = tuple(int(part) for part in cuda_version.split(".")[:2])
+    return release >= (12, 8)