[Feature] Support Flashinfer fp8 blockwise GEMM kernel on Blackwell (#6479)
This commit is contained in:
@@ -571,7 +571,7 @@ class Fp8MoEMethod:
|
||||
layer.register_parameter("w2_weight_scale_inv", w2_weight_scale)
|
||||
assert self.quant_config.activation_scheme == "dynamic"
|
||||
if (
|
||||
get_bool_env_var("CUTLASS_MOE")
|
||||
get_bool_env_var("SGLANG_CUTLASS_MOE")
|
||||
and self.cutlass_fp8_supported
|
||||
and is_sm100_supported()
|
||||
):
|
||||
@@ -973,7 +973,7 @@ class Fp8MoEMethod:
|
||||
return ret
|
||||
|
||||
if (
|
||||
get_bool_env_var("CUTLASS_MOE")
|
||||
get_bool_env_var("SGLANG_CUTLASS_MOE")
|
||||
and self.cutlass_fp8_supported
|
||||
and self.block_quant
|
||||
and is_sm100_supported()
|
||||
|
||||
@@ -28,6 +28,7 @@ from sglang.srt.utils import (
|
||||
get_cuda_version,
|
||||
get_device_capability,
|
||||
is_cuda,
|
||||
is_flashinfer_available,
|
||||
is_hip,
|
||||
)
|
||||
|
||||
@@ -35,6 +36,7 @@ _is_hip = is_hip()
|
||||
_is_cuda = is_cuda()
|
||||
_is_fp8_fnuz = is_fp8_fnuz()
|
||||
|
||||
|
||||
use_aiter_moe = get_bool_env_var("SGLANG_AITER_MOE")
|
||||
|
||||
if _is_hip and use_aiter_moe:
|
||||
@@ -111,7 +113,7 @@ def normalize_e4m3fn_to_e4m3fnuz(
|
||||
|
||||
|
||||
def cutlass_block_fp8_supported() -> bool:
|
||||
if not get_bool_env_var("SUPPORT_CUTLASS_BLOCK_FP8"):
|
||||
if not get_bool_env_var("SGLANG_SUPPORT_CUTLASS_BLOCK_FP8"):
|
||||
return False
|
||||
if _is_cuda:
|
||||
major, minor = torch.cuda.get_device_capability()
|
||||
@@ -123,6 +125,13 @@ def cutlass_block_fp8_supported() -> bool:
|
||||
|
||||
|
||||
CUTLASS_BLOCK_FP8_SUPPORTED = cutlass_block_fp8_supported()
|
||||
ENABLE_FLASHINFER_GEMM = (
|
||||
get_bool_env_var("SGLANG_ENABLE_FLASHINFER_GEMM")
|
||||
and is_sm100_supported()
|
||||
and is_flashinfer_available()
|
||||
)
|
||||
if ENABLE_FLASHINFER_GEMM:
|
||||
from flashinfer.gemm import gemm_fp8_nt_groupwise
|
||||
|
||||
|
||||
def apply_w8a8_block_fp8_linear(
|
||||
@@ -141,7 +150,16 @@ def apply_w8a8_block_fp8_linear(
|
||||
shape_supported_by_cutlass = (
|
||||
weight.shape[0] % 128 == 0 and weight.shape[1] % 128 == 0
|
||||
)
|
||||
if CUTLASS_BLOCK_FP8_SUPPORTED and shape_supported_by_cutlass:
|
||||
if ENABLE_FLASHINFER_GEMM:
|
||||
q_input, x_scale = sglang_per_token_group_quant_fp8(
|
||||
input_2d, block_size[1], column_major_scales=False
|
||||
)
|
||||
x_scale_input = x_scale.T.contiguous()
|
||||
weight_scale_input = weight_scale.T.contiguous()
|
||||
output = gemm_fp8_nt_groupwise(
|
||||
q_input, weight, x_scale_input, weight_scale_input, out_dtype=input.dtype
|
||||
)
|
||||
elif CUTLASS_BLOCK_FP8_SUPPORTED and shape_supported_by_cutlass:
|
||||
q_input, x_scale = per_token_group_quant_fp8(
|
||||
input_2d, block_size[1], column_major_scales=True
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user