[Feature] Support Flashinfer fp8 blockwise GEMM kernel on Blackwell (#6479)

Baizhou Zhang (committed by GitHub), 2025-05-28 16:03:43 -07:00
parent 31589e177e, commit 791b3bfabb
3 changed files with 26 additions and 4 deletions
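Summary: this commit adds a FlashInfer-backed fp8 blockwise GEMM path (flashinfer.gemm.gemm_fp8_nt_groupwise) for SM100 (Blackwell) GPUs, gated behind the SGLANG_ENABLE_FLASHINFER_GEMM environment variable, and renames the CUTLASS gating variables to carry the SGLANG_ prefix (CUTLASS_MOE becomes SGLANG_CUTLASS_MOE, SUPPORT_CUTLASS_BLOCK_FP8 becomes SGLANG_SUPPORT_CUTLASS_BLOCK_FP8).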


@@ -571,7 +571,7 @@ class Fp8MoEMethod:
         layer.register_parameter("w2_weight_scale_inv", w2_weight_scale)
         assert self.quant_config.activation_scheme == "dynamic"
         if (
-            get_bool_env_var("CUTLASS_MOE")
+            get_bool_env_var("SGLANG_CUTLASS_MOE")
             and self.cutlass_fp8_supported
             and is_sm100_supported()
         ):
@@ -973,7 +973,7 @@ class Fp8MoEMethod:
             return ret
         if (
-            get_bool_env_var("CUTLASS_MOE")
+            get_bool_env_var("SGLANG_CUTLASS_MOE")
             and self.cutlass_fp8_supported
             and self.block_quant
             and is_sm100_supported()
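Both hunks only rename the environment variable, but the rename is breaking for existing deployments: the old CUTLASS_MOE flag is no longer read. A minimal sketch of the expected gating behavior, assuming sglang's get_bool_env_var treats "true"/"1" (case-insensitive) as truthy, which is an assumption rather than the library's documented contract:

```python
# Sketch of the env-var gate after the rename; get_bool_env_var_sketch is a
# hypothetical stand-in for sglang.srt.utils.get_bool_env_var.
import os

def get_bool_env_var_sketch(name: str, default: str = "false") -> bool:
    # Assumed semantics: "true"/"1" enable the flag, anything else disables it.
    return os.getenv(name, default).lower() in ("true", "1")

os.environ["SGLANG_CUTLASS_MOE"] = "1"  # was CUTLASS_MOE before this commit
assert get_bool_env_var_sketch("SGLANG_CUTLASS_MOE")
assert not get_bool_env_var_sketch("CUTLASS_MOE")  # old name is no longer read
```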


@@ -28,6 +28,7 @@ from sglang.srt.utils import (
     get_cuda_version,
     get_device_capability,
     is_cuda,
+    is_flashinfer_available,
     is_hip,
 )
@@ -35,6 +36,7 @@ _is_hip = is_hip()
 _is_cuda = is_cuda()
 _is_fp8_fnuz = is_fp8_fnuz()

 use_aiter_moe = get_bool_env_var("SGLANG_AITER_MOE")

 if _is_hip and use_aiter_moe:
@@ -111,7 +113,7 @@ def normalize_e4m3fn_to_e4m3fnuz(
 def cutlass_block_fp8_supported() -> bool:
-    if not get_bool_env_var("SUPPORT_CUTLASS_BLOCK_FP8"):
+    if not get_bool_env_var("SGLANG_SUPPORT_CUTLASS_BLOCK_FP8"):
         return False
     if _is_cuda:
         major, minor = torch.cuda.get_device_capability()
@@ -123,6 +125,13 @@ def cutlass_block_fp8_supported() -> bool:
 CUTLASS_BLOCK_FP8_SUPPORTED = cutlass_block_fp8_supported()

+ENABLE_FLASHINFER_GEMM = (
+    get_bool_env_var("SGLANG_ENABLE_FLASHINFER_GEMM")
+    and is_sm100_supported()
+    and is_flashinfer_available()
+)
+if ENABLE_FLASHINFER_GEMM:
+    from flashinfer.gemm import gemm_fp8_nt_groupwise

 def apply_w8a8_block_fp8_linear(
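Note that the gemm_fp8_nt_groupwise import is itself guarded by ENABLE_FLASHINFER_GEMM, so flashinfer stays an optional dependency: the module is imported only when the environment variable is set, the device reports SM100 (Blackwell), and flashinfer is actually installed.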
@@ -141,7 +150,16 @@ def apply_w8a8_block_fp8_linear(
     shape_supported_by_cutlass = (
         weight.shape[0] % 128 == 0 and weight.shape[1] % 128 == 0
     )
-    if CUTLASS_BLOCK_FP8_SUPPORTED and shape_supported_by_cutlass:
+    if ENABLE_FLASHINFER_GEMM:
+        q_input, x_scale = sglang_per_token_group_quant_fp8(
+            input_2d, block_size[1], column_major_scales=False
+        )
+        x_scale_input = x_scale.T.contiguous()
+        weight_scale_input = weight_scale.T.contiguous()
+        output = gemm_fp8_nt_groupwise(
+            q_input, weight, x_scale_input, weight_scale_input, out_dtype=input.dtype
+        )
+    elif CUTLASS_BLOCK_FP8_SUPPORTED and shape_supported_by_cutlass:
         q_input, x_scale = per_token_group_quant_fp8(
             input_2d, block_size[1], column_major_scales=True
         )
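To make the new branch concrete, here is a self-contained sketch of the FlashInfer path: quantize the activation per token per 128-column group, quantize the weight per 128x128 block, transpose both scale tensors to K-major as the diff does, and call gemm_fp8_nt_groupwise. The two quantizers below are simplified illustrative stand-ins for sglang's sglang_per_token_group_quant_fp8 and its block weight quantization, not the kernels the commit uses; running the GEMM call requires an SM100 GPU with flashinfer installed.

```python
# Hedged sketch of the FlashInfer fp8 blockwise GEMM path added above.
# quant_per_token_group / quant_per_block are reference implementations,
# assumed stand-ins for sglang's fused quantization kernels.
import torch

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max

def quant_per_token_group(x: torch.Tensor, group: int = 128):
    """Quantize [M, K] per (token, K-group); returns fp8 tensor + [M, K//group] scales."""
    m, k = x.shape
    g = x.float().view(m, k // group, group)
    scale = g.abs().amax(dim=-1).clamp(min=1e-12) / FP8_MAX
    q = (g / scale.unsqueeze(-1)).view(m, k).to(torch.float8_e4m3fn)
    return q, scale

def quant_per_block(w: torch.Tensor, block: int = 128):
    """Quantize [N, K] per (block x block) tile; returns fp8 tensor + [N//block, K//block] scales."""
    n, k = w.shape
    t = w.float().view(n // block, block, k // block, block)
    scale = t.abs().amax(dim=(1, 3)).clamp(min=1e-12) / FP8_MAX
    q = (t / scale[:, None, :, None]).view(n, k).to(torch.float8_e4m3fn)
    return q, scale

if torch.cuda.is_available():  # additionally requires SM100 (Blackwell)
    from flashinfer.gemm import gemm_fp8_nt_groupwise  # optional dependency

    x = torch.randn(16, 512, device="cuda")
    w = torch.randn(256, 512, device="cuda")
    q_x, x_scale = quant_per_token_group(x)  # fp8 [16, 512], scales [16, 4]
    q_w, w_scale = quant_per_block(w)        # fp8 [256, 512], scales [2, 4]
    # As in the diff: both scale tensors are transposed to K-major and made
    # contiguous before the call; the call mirrors the new branch exactly.
    out = gemm_fp8_nt_groupwise(
        q_x, q_w, x_scale.T.contiguous(), w_scale.T.contiguous(),
        out_dtype=torch.bfloat16,
    )
    print(out.shape)  # torch.Size([16, 256])
```

The design choice mirrored here is that the FlashInfer branch takes priority over CUTLASS (if / elif), and unlike the CUTLASS path it requests row-major scales from the quantizer (column_major_scales=False) and transposes them itself.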