[Feature] Support Flashinfer fmha on Blackwell (#6930)
This commit is contained in:
@@ -25,6 +25,7 @@ from sglang.global_config import global_config
|
|||||||
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
|
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
|
||||||
from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
|
from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
|
||||||
from sglang.srt.layers.dp_attention import get_attention_tp_size
|
from sglang.srt.layers.dp_attention import get_attention_tp_size
|
||||||
|
from sglang.srt.layers.utils import is_sm100_supported
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
||||||
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
|
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
|
||||||
from sglang.srt.utils import is_flashinfer_available, next_power_of_2
|
from sglang.srt.utils import is_flashinfer_available, next_power_of_2
|
||||||
@@ -149,8 +150,11 @@ class FlashInferAttnBackend(AttentionBackend):
|
|||||||
for _ in range(self.num_wrappers)
|
for _ in range(self.num_wrappers)
|
||||||
]
|
]
|
||||||
|
|
||||||
|
fmha_backend = "auto"
|
||||||
|
if is_sm100_supported():
|
||||||
|
fmha_backend = "cutlass"
|
||||||
self.prefill_wrapper_ragged = BatchPrefillWithRaggedKVCacheWrapper(
|
self.prefill_wrapper_ragged = BatchPrefillWithRaggedKVCacheWrapper(
|
||||||
self.workspace_buffer, "NHD"
|
self.workspace_buffer, "NHD", backend=fmha_backend
|
||||||
)
|
)
|
||||||
|
|
||||||
# Two wrappers: one for sliding window attention and one for full attention.
|
# Two wrappers: one for sliding window attention and one for full attention.
|
||||||
|
|||||||
@@ -29,6 +29,7 @@ from sglang.srt.layers.attention.flashinfer_backend import (
|
|||||||
create_flashinfer_kv_indices_triton,
|
create_flashinfer_kv_indices_triton,
|
||||||
)
|
)
|
||||||
from sglang.srt.layers.dp_attention import get_attention_tp_size
|
from sglang.srt.layers.dp_attention import get_attention_tp_size
|
||||||
|
from sglang.srt.layers.utils import is_sm100_supported
|
||||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
||||||
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
|
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
|
||||||
@@ -108,8 +109,11 @@ class FlashInferMLAAttnBackend(AttentionBackend):
|
|||||||
else:
|
else:
|
||||||
self.q_indptr_decode = q_indptr_decode_buf
|
self.q_indptr_decode = q_indptr_decode_buf
|
||||||
|
|
||||||
|
fmha_backend = "auto"
|
||||||
|
if is_sm100_supported():
|
||||||
|
fmha_backend = "cutlass"
|
||||||
self.prefill_wrapper_ragged = BatchPrefillWithRaggedKVCacheWrapper(
|
self.prefill_wrapper_ragged = BatchPrefillWithRaggedKVCacheWrapper(
|
||||||
self.workspace_buffer, "NHD"
|
self.workspace_buffer, "NHD", backend=fmha_backend
|
||||||
)
|
)
|
||||||
|
|
||||||
if not self.skip_prefill:
|
if not self.skip_prefill:
|
||||||
|
|||||||
@@ -52,7 +52,6 @@ from sglang.srt.layers.quantization.fp8_utils import (
|
|||||||
cutlass_fp8_supported,
|
cutlass_fp8_supported,
|
||||||
dispatch_w8a8_block_fp8_linear,
|
dispatch_w8a8_block_fp8_linear,
|
||||||
input_to_float8,
|
input_to_float8,
|
||||||
is_sm100_supported,
|
|
||||||
normalize_e4m3fn_to_e4m3fnuz,
|
normalize_e4m3fn_to_e4m3fnuz,
|
||||||
)
|
)
|
||||||
from sglang.srt.layers.quantization.kv_cache import BaseKVCacheMethod
|
from sglang.srt.layers.quantization.kv_cache import BaseKVCacheMethod
|
||||||
@@ -63,6 +62,7 @@ from sglang.srt.layers.quantization.utils import (
|
|||||||
per_tensor_dequantize,
|
per_tensor_dequantize,
|
||||||
requantize_with_max_scale,
|
requantize_with_max_scale,
|
||||||
)
|
)
|
||||||
|
from sglang.srt.layers.utils import is_sm100_supported
|
||||||
from sglang.srt.utils import (
|
from sglang.srt.utils import (
|
||||||
get_bool_env_var,
|
get_bool_env_var,
|
||||||
is_cuda,
|
is_cuda,
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ from typing import Callable, List, Optional, Tuple
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_group_quant_fp8
|
from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_group_quant_fp8
|
||||||
|
from sglang.srt.layers.utils import is_sm100_supported
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from vllm import _custom_ops as ops
|
from vllm import _custom_ops as ops
|
||||||
@@ -83,12 +84,6 @@ def cutlass_fp8_supported():
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
def is_sm100_supported(device=None) -> bool:
|
|
||||||
return (torch.cuda.get_device_capability(device)[0] == 10) and (
|
|
||||||
torch.version.cuda >= "12.8"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def normalize_e4m3fn_to_e4m3fnuz(
|
def normalize_e4m3fn_to_e4m3fnuz(
|
||||||
weight: torch.Tensor,
|
weight: torch.Tensor,
|
||||||
weight_scale: torch.Tensor,
|
weight_scale: torch.Tensor,
|
||||||
|
|||||||
@@ -33,3 +33,9 @@ class PPMissingLayer(torch.nn.Identity):
|
|||||||
"""
|
"""
|
||||||
input = args[0] if args else next(iter(kwargs.values()))
|
input = args[0] if args else next(iter(kwargs.values()))
|
||||||
return (input,) if self.return_tuple else input
|
return (input,) if self.return_tuple else input
|
||||||
|
|
||||||
|
|
||||||
|
def is_sm100_supported(device=None) -> bool:
    """Return True when `device` has compute capability 10.x and CUDA >= 12.8.

    Args:
        device: Forwarded unchanged to ``torch.cuda.get_device_capability``;
            ``None`` lets torch pick its default device.

    Returns:
        True only if the device's major compute capability equals 10 and the
        CUDA version PyTorch was built against is at least 12.8.
    """
    if torch.cuda.get_device_capability(device)[0] != 10:
        return False

    cuda_version = torch.version.cuda
    if cuda_version is None:
        # Builds without CUDA report no toolkit version; treat as unsupported
        # instead of failing on a None comparison.
        return False

    # Compare numerically. A plain string comparison is lexicographic, so it
    # would reject "12.10" (sorts before "12.8") and accept "9.0" (sorts after).
    major_minor = tuple(int(part) for part in cuda_version.split(".")[:2])
    return major_minor >= (12, 8)
|
||||||
|
|||||||
Reference in New Issue
Block a user