From 0374304a2cb6ccec8f5653a0bdda6e1bc057c39b Mon Sep 17 00:00:00 2001 From: fzyzcjy <5236035+fzyzcjy@users.noreply.github.com> Date: Sat, 23 Aug 2025 15:38:40 +0800 Subject: [PATCH] Add enable_flashinfer_mxfp4_bf16_moe for higher precision and slower moe backend (#9004) --- .../sglang/srt/layers/quantization/mxfp4.py | 32 ++++++++++++++++--- python/sglang/srt/managers/schedule_batch.py | 1 + python/sglang/srt/server_args.py | 9 ++++++ 3 files changed, 37 insertions(+), 5 deletions(-) diff --git a/python/sglang/srt/layers/quantization/mxfp4.py b/python/sglang/srt/layers/quantization/mxfp4.py index 4cb28d421..1e46cc868 100644 --- a/python/sglang/srt/layers/quantization/mxfp4.py +++ b/python/sglang/srt/layers/quantization/mxfp4.py @@ -30,6 +30,7 @@ from sglang.srt.layers.quantization.base_config import ( ) from sglang.srt.layers.quantization.utils import is_layer_skipped from sglang.srt.layers.utils import is_sm100_supported +from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.utils import ( direct_register_custom_op, get_bool_env_var, @@ -262,6 +263,9 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): self.use_triton_kernels = get_moe_runner_backend().is_triton_kernel() self.with_bias = False self.use_flashinfer = get_moe_runner_backend().is_flashinfer_mxfp4() + self.flashinfer_mxfp4_moe_precision = global_server_args_dict[ + "flashinfer_mxfp4_moe_precision" + ] self.triton_kernel_moe_forward = None self.triton_kernel_moe_with_bias_forward = None @@ -615,11 +619,29 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): from sglang.srt.layers.moe.topk import TopKOutputChecker if self.use_flashinfer: - # Based on profiling results, we need to quantize x to mxfp8 here to achieve better performance - x_quant, x_scale = mxfp8_quantize( - x, False, alignment=self.hidden_size - ) # to mxfp8 - x_scale = x_scale.view(torch.float8_e4m3fn).reshape(-1) + # When bf16 mode is enabled, we don't need to quantize the input, + # TRT-LLM automatically handles 
quantization in the kernel implementation and pipelines it with GEMM operations, + # which can theoretically improve performance + if self.flashinfer_mxfp4_moe_precision == "bf16": + assert x.dtype == torch.bfloat16 + x_quant = x + x_scale = None + + # May be fused later if this code branch is frequently needed + origin_hidden_states_dim = x_quant.shape[-1] + if self.hidden_size != origin_hidden_states_dim: + x_quant = torch.nn.functional.pad( + x_quant, + (0, self.hidden_size - origin_hidden_states_dim), + mode="constant", + value=0.0, + ) + elif self.flashinfer_mxfp4_moe_precision == "default": + x_quant, x_scale = mxfp8_quantize(x, False, alignment=self.hidden_size) + x_scale = x_scale.view(torch.float8_e4m3fn).reshape(-1) + else: + raise NotImplementedError + assert x_quant.shape[-1] == self.hidden_size assert TopKOutputChecker.format_is_bypassed(topk_output) diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 95ec32999..a35ba0253 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -87,6 +87,7 @@ GLOBAL_SERVER_ARGS_KEYS = [ "disable_flashinfer_cutlass_moe_fp4_allgather", "disable_radix_cache", "enable_dp_lm_head", + "flashinfer_mxfp4_moe_precision", "enable_flashinfer_allreduce_fusion", "moe_dense_tp_size", "ep_dispatch_algorithm", diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 83fec562b..d32227390 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -190,6 +190,7 @@ class ServerArgs: "flashinfer_cutlass", "flashinfer_mxfp4", ] = "auto" + flashinfer_mxfp4_moe_precision: Literal["default", "bf16"] = "default" enable_flashinfer_allreduce_fusion: bool = False deepep_mode: Literal["auto", "normal", "low_latency"] = "auto" ep_num_redundant_experts: int = 0 @@ -1496,10 +1497,18 @@ class ServerArgs: "triton_kernel", "flashinfer_trtllm", "flashinfer_cutlass", + 
"flashinfer_mxfp4", ], default=ServerArgs.moe_runner_backend, help="Choose the runner backend for MoE.", ) + parser.add_argument( + "--flashinfer-mxfp4-moe-precision", + type=str, + choices=["default", "bf16"], + default=ServerArgs.flashinfer_mxfp4_moe_precision, + help="Choose the computation precision of flashinfer mxfp4 moe", + ) parser.add_argument( "--enable-flashinfer-allreduce-fusion", action="store_true",