[Feature] Support Flashinfer fp8 blockwise GEMM kernel on Blackwell (#6479)
This commit is contained in:
@@ -57,6 +57,10 @@ SGLang supports various environment variables that can be used to configure its
|
|||||||
| `SGLANG_INT4_WEIGHT` | Enable INT4 weight quantization | `false` |
|
| `SGLANG_INT4_WEIGHT` | Enable INT4 weight quantization | `false` |
|
||||||
| `SGLANG_MOE_PADDING` | Enable MoE padding (sets padding size to 128 if value is `1`, often set to `1` in Docker builds) | `0` |
|
| `SGLANG_MOE_PADDING` | Enable MoE padding (sets padding size to 128 if value is `1`, often set to `1` in Docker builds) | `0` |
|
||||||
| `SGLANG_FORCE_FP8_MARLIN` | Force using FP8 MARLIN kernels even if other FP8 kernels are available | `false` |
|
| `SGLANG_FORCE_FP8_MARLIN` | Force using FP8 MARLIN kernels even if other FP8 kernels are available | `false` |
|
||||||
|
| `SGLANG_ENABLE_FLASHINFER_GEMM` | Use flashinfer kernels when running blockwise fp8 GEMM on Blackwell GPUs | `false` |
|
||||||
|
| `SGLANG_SUPPORT_CUTLASS_BLOCK_FP8` | Use Cutlass kernels when running blockwise fp8 GEMM on Hopper or Blackwell GPUs | `false` |
|
||||||
|
| `SGLANG_CUTLASS_MOE` | Use Cutlass FP8 MoE kernel on Blackwell GPUs | `false` |
|
||||||
|
|
||||||
|
|
||||||
## Distributed Computing
|
## Distributed Computing
|
||||||
|
|
||||||
|
|||||||
@@ -571,7 +571,7 @@ class Fp8MoEMethod:
|
|||||||
layer.register_parameter("w2_weight_scale_inv", w2_weight_scale)
|
layer.register_parameter("w2_weight_scale_inv", w2_weight_scale)
|
||||||
assert self.quant_config.activation_scheme == "dynamic"
|
assert self.quant_config.activation_scheme == "dynamic"
|
||||||
if (
|
if (
|
||||||
get_bool_env_var("CUTLASS_MOE")
|
get_bool_env_var("SGLANG_CUTLASS_MOE")
|
||||||
and self.cutlass_fp8_supported
|
and self.cutlass_fp8_supported
|
||||||
and is_sm100_supported()
|
and is_sm100_supported()
|
||||||
):
|
):
|
||||||
@@ -973,7 +973,7 @@ class Fp8MoEMethod:
|
|||||||
return ret
|
return ret
|
||||||
|
|
||||||
if (
|
if (
|
||||||
get_bool_env_var("CUTLASS_MOE")
|
get_bool_env_var("SGLANG_CUTLASS_MOE")
|
||||||
and self.cutlass_fp8_supported
|
and self.cutlass_fp8_supported
|
||||||
and self.block_quant
|
and self.block_quant
|
||||||
and is_sm100_supported()
|
and is_sm100_supported()
|
||||||
|
|||||||
@@ -28,6 +28,7 @@ from sglang.srt.utils import (
|
|||||||
get_cuda_version,
|
get_cuda_version,
|
||||||
get_device_capability,
|
get_device_capability,
|
||||||
is_cuda,
|
is_cuda,
|
||||||
|
is_flashinfer_available,
|
||||||
is_hip,
|
is_hip,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -35,6 +36,7 @@ _is_hip = is_hip()
|
|||||||
_is_cuda = is_cuda()
|
_is_cuda = is_cuda()
|
||||||
_is_fp8_fnuz = is_fp8_fnuz()
|
_is_fp8_fnuz = is_fp8_fnuz()
|
||||||
|
|
||||||
|
|
||||||
use_aiter_moe = get_bool_env_var("SGLANG_AITER_MOE")
|
use_aiter_moe = get_bool_env_var("SGLANG_AITER_MOE")
|
||||||
|
|
||||||
if _is_hip and use_aiter_moe:
|
if _is_hip and use_aiter_moe:
|
||||||
@@ -111,7 +113,7 @@ def normalize_e4m3fn_to_e4m3fnuz(
|
|||||||
|
|
||||||
|
|
||||||
def cutlass_block_fp8_supported() -> bool:
|
def cutlass_block_fp8_supported() -> bool:
|
||||||
if not get_bool_env_var("SUPPORT_CUTLASS_BLOCK_FP8"):
|
if not get_bool_env_var("SGLANG_SUPPORT_CUTLASS_BLOCK_FP8"):
|
||||||
return False
|
return False
|
||||||
if _is_cuda:
|
if _is_cuda:
|
||||||
major, minor = torch.cuda.get_device_capability()
|
major, minor = torch.cuda.get_device_capability()
|
||||||
@@ -123,6 +125,13 @@ def cutlass_block_fp8_supported() -> bool:
|
|||||||
|
|
||||||
|
|
||||||
CUTLASS_BLOCK_FP8_SUPPORTED = cutlass_block_fp8_supported()
|
CUTLASS_BLOCK_FP8_SUPPORTED = cutlass_block_fp8_supported()
|
||||||
|
ENABLE_FLASHINFER_GEMM = (
|
||||||
|
get_bool_env_var("SGLANG_ENABLE_FLASHINFER_GEMM")
|
||||||
|
and is_sm100_supported()
|
||||||
|
and is_flashinfer_available()
|
||||||
|
)
|
||||||
|
if ENABLE_FLASHINFER_GEMM:
|
||||||
|
from flashinfer.gemm import gemm_fp8_nt_groupwise
|
||||||
|
|
||||||
|
|
||||||
def apply_w8a8_block_fp8_linear(
|
def apply_w8a8_block_fp8_linear(
|
||||||
@@ -141,7 +150,16 @@ def apply_w8a8_block_fp8_linear(
|
|||||||
shape_supported_by_cutlass = (
|
shape_supported_by_cutlass = (
|
||||||
weight.shape[0] % 128 == 0 and weight.shape[1] % 128 == 0
|
weight.shape[0] % 128 == 0 and weight.shape[1] % 128 == 0
|
||||||
)
|
)
|
||||||
if CUTLASS_BLOCK_FP8_SUPPORTED and shape_supported_by_cutlass:
|
if ENABLE_FLASHINFER_GEMM:
|
||||||
|
q_input, x_scale = sglang_per_token_group_quant_fp8(
|
||||||
|
input_2d, block_size[1], column_major_scales=False
|
||||||
|
)
|
||||||
|
x_scale_input = x_scale.T.contiguous()
|
||||||
|
weight_scale_input = weight_scale.T.contiguous()
|
||||||
|
output = gemm_fp8_nt_groupwise(
|
||||||
|
q_input, weight, x_scale_input, weight_scale_input, out_dtype=input.dtype
|
||||||
|
)
|
||||||
|
elif CUTLASS_BLOCK_FP8_SUPPORTED and shape_supported_by_cutlass:
|
||||||
q_input, x_scale = per_token_group_quant_fp8(
|
q_input, x_scale = per_token_group_quant_fp8(
|
||||||
input_2d, block_size[1], column_major_scales=True
|
input_2d, block_size[1], column_major_scales=True
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user