[Feature] Support Flashinfer fmha on Blackwell (#6930)
This commit is contained in:
@@ -25,6 +25,7 @@ from sglang.global_config import global_config
|
||||
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
|
||||
from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
|
||||
from sglang.srt.layers.dp_attention import get_attention_tp_size
|
||||
from sglang.srt.layers.utils import is_sm100_supported
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
||||
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
|
||||
from sglang.srt.utils import is_flashinfer_available, next_power_of_2
|
||||
@@ -149,8 +150,11 @@ class FlashInferAttnBackend(AttentionBackend):
|
||||
for _ in range(self.num_wrappers)
|
||||
]
|
||||
|
||||
fmha_backend = "auto"
|
||||
if is_sm100_supported():
|
||||
fmha_backend = "cutlass"
|
||||
self.prefill_wrapper_ragged = BatchPrefillWithRaggedKVCacheWrapper(
|
||||
self.workspace_buffer, "NHD"
|
||||
self.workspace_buffer, "NHD", backend=fmha_backend
|
||||
)
|
||||
|
||||
# Two wrappers: one for sliding window attention and one for full attention.
|
||||
|
||||
@@ -29,6 +29,7 @@ from sglang.srt.layers.attention.flashinfer_backend import (
|
||||
create_flashinfer_kv_indices_triton,
|
||||
)
|
||||
from sglang.srt.layers.dp_attention import get_attention_tp_size
|
||||
from sglang.srt.layers.utils import is_sm100_supported
|
||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
||||
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
|
||||
@@ -108,8 +109,11 @@ class FlashInferMLAAttnBackend(AttentionBackend):
|
||||
else:
|
||||
self.q_indptr_decode = q_indptr_decode_buf
|
||||
|
||||
fmha_backend = "auto"
|
||||
if is_sm100_supported():
|
||||
fmha_backend = "cutlass"
|
||||
self.prefill_wrapper_ragged = BatchPrefillWithRaggedKVCacheWrapper(
|
||||
self.workspace_buffer, "NHD"
|
||||
self.workspace_buffer, "NHD", backend=fmha_backend
|
||||
)
|
||||
|
||||
if not self.skip_prefill:
|
||||
|
||||
@@ -52,7 +52,6 @@ from sglang.srt.layers.quantization.fp8_utils import (
|
||||
cutlass_fp8_supported,
|
||||
dispatch_w8a8_block_fp8_linear,
|
||||
input_to_float8,
|
||||
is_sm100_supported,
|
||||
normalize_e4m3fn_to_e4m3fnuz,
|
||||
)
|
||||
from sglang.srt.layers.quantization.kv_cache import BaseKVCacheMethod
|
||||
@@ -63,6 +62,7 @@ from sglang.srt.layers.quantization.utils import (
|
||||
per_tensor_dequantize,
|
||||
requantize_with_max_scale,
|
||||
)
|
||||
from sglang.srt.layers.utils import is_sm100_supported
|
||||
from sglang.srt.utils import (
|
||||
get_bool_env_var,
|
||||
is_cuda,
|
||||
|
||||
@@ -5,6 +5,7 @@ from typing import Callable, List, Optional, Tuple
|
||||
import torch
|
||||
|
||||
from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_group_quant_fp8
|
||||
from sglang.srt.layers.utils import is_sm100_supported
|
||||
|
||||
try:
|
||||
from vllm import _custom_ops as ops
|
||||
@@ -83,12 +84,6 @@ def cutlass_fp8_supported():
|
||||
return False
|
||||
|
||||
|
||||
def is_sm100_supported(device=None) -> bool:
|
||||
return (torch.cuda.get_device_capability(device)[0] == 10) and (
|
||||
torch.version.cuda >= "12.8"
|
||||
)
|
||||
|
||||
|
||||
def normalize_e4m3fn_to_e4m3fnuz(
|
||||
weight: torch.Tensor,
|
||||
weight_scale: torch.Tensor,
|
||||
|
||||
@@ -33,3 +33,9 @@ class PPMissingLayer(torch.nn.Identity):
|
||||
"""
|
||||
input = args[0] if args else next(iter(kwargs.values()))
|
||||
return (input,) if self.return_tuple else input
|
||||
|
||||
|
||||
def is_sm100_supported(device=None) -> bool:
|
||||
return (torch.cuda.get_device_capability(device)[0] == 10) and (
|
||||
torch.version.cuda >= "12.8"
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user