From 791b3bfabb61fad34e35f944e2833817d8e6929f Mon Sep 17 00:00:00 2001 From: Baizhou Zhang Date: Wed, 28 May 2025 16:03:43 -0700 Subject: [PATCH] [Feature] Support Flashinfer fp8 blockwise GEMM kernel on Blackwell (#6479) --- docs/references/environment_variables.md | 4 ++++ python/sglang/srt/layers/quantization/fp8.py | 4 ++-- .../srt/layers/quantization/fp8_utils.py | 22 +++++++++++++++++-- 3 files changed, 26 insertions(+), 4 deletions(-) diff --git a/docs/references/environment_variables.md b/docs/references/environment_variables.md index 344ebaca6..27b8b12fc 100644 --- a/docs/references/environment_variables.md +++ b/docs/references/environment_variables.md @@ -57,6 +57,10 @@ SGLang supports various environment variables that can be used to configure its | `SGLANG_INT4_WEIGHT` | Enable INT4 weight quantization | `false` | | `SGLANG_MOE_PADDING` | Enable MoE padding (sets padding size to 128 if value is `1`, often set to `1` in Docker builds) | `0` | | `SGLANG_FORCE_FP8_MARLIN` | Force using FP8 MARLIN kernels even if other FP8 kernels are available | `false` | +| `SGLANG_ENABLE_FLASHINFER_GEMM` | Use flashinfer kernels when running blockwise fp8 GEMM on Blackwell GPUs | `false` | +| `SGLANG_SUPPORT_CUTLASS_BLOCK_FP8` | Use Cutlass kernels when running blockwise fp8 GEMM on Hopper or Blackwell GPUs | `false` | +| `SGLANG_CUTLASS_MOE` | Use Cutlass FP8 MoE kernel on Blackwell GPUs | `false` | + ## Distributed Computing diff --git a/python/sglang/srt/layers/quantization/fp8.py b/python/sglang/srt/layers/quantization/fp8.py index e43d1f0bb..1f961ed68 100644 --- a/python/sglang/srt/layers/quantization/fp8.py +++ b/python/sglang/srt/layers/quantization/fp8.py @@ -571,7 +571,7 @@ class Fp8MoEMethod: layer.register_parameter("w2_weight_scale_inv", w2_weight_scale) assert self.quant_config.activation_scheme == "dynamic" if ( - get_bool_env_var("CUTLASS_MOE") + get_bool_env_var("SGLANG_CUTLASS_MOE") and self.cutlass_fp8_supported and is_sm100_supported() ): @@ -973,7 +973,7 @@ class Fp8MoEMethod: return ret if ( - get_bool_env_var("CUTLASS_MOE") + get_bool_env_var("SGLANG_CUTLASS_MOE") and self.cutlass_fp8_supported and self.block_quant and is_sm100_supported() diff --git a/python/sglang/srt/layers/quantization/fp8_utils.py b/python/sglang/srt/layers/quantization/fp8_utils.py index 05e43fe3f..f38284389 100644 --- a/python/sglang/srt/layers/quantization/fp8_utils.py +++ b/python/sglang/srt/layers/quantization/fp8_utils.py @@ -28,6 +28,7 @@ from sglang.srt.utils import ( get_cuda_version, get_device_capability, is_cuda, + is_flashinfer_available, is_hip, ) @@ -35,6 +36,7 @@ _is_hip = is_hip() _is_cuda = is_cuda() _is_fp8_fnuz = is_fp8_fnuz() + use_aiter_moe = get_bool_env_var("SGLANG_AITER_MOE") if _is_hip and use_aiter_moe: @@ -111,7 +113,7 @@ def normalize_e4m3fn_to_e4m3fnuz( def cutlass_block_fp8_supported() -> bool: - if not get_bool_env_var("SUPPORT_CUTLASS_BLOCK_FP8"): + if not get_bool_env_var("SGLANG_SUPPORT_CUTLASS_BLOCK_FP8"): return False if _is_cuda: major, minor = torch.cuda.get_device_capability() @@ -123,6 +125,13 @@ def cutlass_block_fp8_supported() -> bool: CUTLASS_BLOCK_FP8_SUPPORTED = cutlass_block_fp8_supported() +ENABLE_FLASHINFER_GEMM = ( + get_bool_env_var("SGLANG_ENABLE_FLASHINFER_GEMM") + and is_sm100_supported() + and is_flashinfer_available() +) +if ENABLE_FLASHINFER_GEMM: + from flashinfer.gemm import gemm_fp8_nt_groupwise def apply_w8a8_block_fp8_linear( @@ -141,7 +150,16 @@ def apply_w8a8_block_fp8_linear( shape_supported_by_cutlass = ( weight.shape[0] % 128 == 0 and weight.shape[1] % 128 == 0 ) - if CUTLASS_BLOCK_FP8_SUPPORTED and shape_supported_by_cutlass: + if ENABLE_FLASHINFER_GEMM: + q_input, x_scale = sglang_per_token_group_quant_fp8( + input_2d, block_size[1], column_major_scales=False + ) + x_scale_input = x_scale.T.contiguous() + weight_scale_input = weight_scale.T.contiguous() + output = gemm_fp8_nt_groupwise( + q_input, weight, x_scale_input, weight_scale_input, out_dtype=input.dtype + ) + elif CUTLASS_BLOCK_FP8_SUPPORTED and shape_supported_by_cutlass: q_input, x_scale = per_token_group_quant_fp8( input_2d, block_size[1], column_major_scales=True )