From 47824c14881b7e8d961b41aea2b6d0fcd3e66759 Mon Sep 17 00:00:00 2001
From: Xiaoyu Zhang <35585791+BBuf@users.noreply.github.com>
Date: Thu, 7 Aug 2025 16:08:41 +0800
Subject: [PATCH] [Perf] Auto enable best flashinfer mxfp4 kernel in b200
 (#8898)

---
 .../srt/layers/moe/fused_moe_triton/layer.py |  8 +--
 .../sglang/srt/layers/quantization/mxfp4.py  | 51 +++++++++----------
 python/sglang/srt/managers/schedule_batch.py |  1 +
 python/sglang/srt/models/gpt_oss.py          | 14 +++--
 python/sglang/srt/server_args.py             | 22 ++++----
 5 files changed, 48 insertions(+), 48 deletions(-)

diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py
index ca0c2c5f0..ec702ddb9 100644
--- a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py
+++ b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py
@@ -206,13 +206,13 @@ class FusedMoE(torch.nn.Module):
         assert self.quant_method is not None
 
         self.quant_config = quant_config
+        self.use_enable_flashinfer_mxfp4_moe = global_server_args_dict.get(
+            "enable_flashinfer_mxfp4_moe", False
+        )
         if (
             self.quant_config is not None
             and self.quant_config.get_name() == "mxfp4"
-            and (
-                get_bool_env_var("SGLANG_USE_FLASHINFER_MXFP4_MOE")
-                or get_bool_env_var("SGLANG_USE_FLASHINFER_MXFP4_BF16_MOE")
-            )
+            and self.use_enable_flashinfer_mxfp4_moe
         ):
             hidden_size = round_up(hidden_size, 256)
         self.hidden_size = hidden_size
diff --git a/python/sglang/srt/layers/quantization/mxfp4.py b/python/sglang/srt/layers/quantization/mxfp4.py
index db5d23acc..1d5f54deb 100644
--- a/python/sglang/srt/layers/quantization/mxfp4.py
+++ b/python/sglang/srt/layers/quantization/mxfp4.py
@@ -3,22 +3,20 @@
 
 from __future__ import annotations
 
-import importlib
+import importlib.util
 import logging
-from typing import TYPE_CHECKING, Callable, List, Optional
+from typing import TYPE_CHECKING, List, Optional
 
 import torch
 from torch.nn.parameter import Parameter
 
-# from vllm.model_executor.layers.fused_moe import (
-#     FusedMoE, FusedMoEActivationFormat, FusedMoEConfig, FusedMoEMethodBase,
-#     FusedMoEPermuteExpertsUnpermute, FusedMoEPrepareAndFinalize)
 from sglang.srt.layers.quantization.base_config import (
     FusedMoEMethodBase,
     QuantizationConfig,
     QuantizeMethodBase,
 )
 from sglang.srt.layers.quantization.utils import is_layer_skipped
+from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.utils import (
     direct_register_custom_op,
     get_bool_env_var,
@@ -32,11 +30,6 @@ from sglang.srt.utils import (
 
 has_triton_kernels = importlib.util.find_spec("triton_kernels") is not None
 
-# Environment variables for FlashInfer MXFP4 MoE backend
-USE_FLASHINFER_MXFP4_MOE = get_bool_env_var("SGLANG_USE_FLASHINFER_MXFP4_MOE", "false")
-USE_FLASHINFER_MXFP4_BF16_MOE = get_bool_env_var(
-    "SGLANG_USE_FLASHINFER_MXFP4_BF16_MOE", "false"
-)
 
 if is_flashinfer_available():
     # from flashinfer.fused_moe import cutlass_fused_moe
@@ -193,7 +186,12 @@ class Mxfp4Config(QuantizationConfig):
             ):
                 return UnquantizedLinearMethod()
             elif isinstance(layer, FusedMoE):
-                return Mxfp4MoEMethod(use_triton_kernels=True, with_bias=True)
+                use_flashinfer = global_server_args_dict.get(
+                    "enable_flashinfer_mxfp4_moe", False
+                )
+                return Mxfp4MoEMethod(
+                    use_triton_kernels=True, with_bias=True, use_flashinfer=use_flashinfer
+                )
             else:
                 raise NotImplementedError("Mxfp4 attention layer is not implemented")
         return None
@@ -204,11 +202,18 @@ class Mxfp4Config(QuantizationConfig):
 
 class Mxfp4MoEMethod(FusedMoEMethodBase):
 
-    def __init__(self, use_triton_kernels: bool = True, with_bias: bool = True):
+    def __init__(
+        self,
+        use_triton_kernels: bool = True,
+        with_bias: bool = True,
+        use_flashinfer: bool = False,
+    ):
         super().__init__()
         self.topk_indices_dtype = None
         self.use_triton_kernels = use_triton_kernels
         self.with_bias = with_bias
+        self.use_flashinfer = use_flashinfer
+
         self.triton_kernel_moe_forward = None
         self.triton_kernel_moe_with_bias_forward = None
         if torch.cuda.is_available() and has_triton_kernels:
@@ -239,7 +244,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
 
         # pad the intermediate size to be a multiple of 2 * mxfp4_block
         # for to hold non-uniform sharded tensor as well as swizzling
-        if USE_FLASHINFER_MXFP4_MOE or USE_FLASHINFER_MXFP4_BF16_MOE:
+        if self.use_flashinfer:
             intermediate_size_per_partition_after_pad = round_up(intermediate_size, 256)
             hidden_size = round_up(hidden_size, 256)
         elif is_hip():
@@ -319,7 +324,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
         set_weight_attrs(w2_weight_bias, extra_weight_attrs)
 
     def process_weights_after_loading(self, layer):
-        if USE_FLASHINFER_MXFP4_MOE or USE_FLASHINFER_MXFP4_BF16_MOE:
+        if self.use_flashinfer:
             logger.info(
                 "Shuffling MoE weights for FlashInfer, it might take a while..."
             )
@@ -544,20 +549,12 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
         activation_alpha: Optional[float] = None,
         swiglu_limit: Optional[float] = None,
     ) -> torch.Tensor:
-        if USE_FLASHINFER_MXFP4_MOE or USE_FLASHINFER_MXFP4_BF16_MOE:
-            # When USE_FLASHINFER_MXFP4_BF16_MOE is enabled, we don't need to quantize the input,
-            # TRT-LLM automatically handles quantization in the kernel implementation and pipelines it with GEMM operations,
-            # which can theoretically improve performance
-            if USE_FLASHINFER_MXFP4_BF16_MOE:
-                assert x.dtype == torch.bfloat16
-                x_quant = x
-                x_scale = None
-            else:
-                x_quant, x_scale = mxfp8_quantize(x, False)  # to mxfp8
-                x_scale = x_scale.view(torch.float8_e4m3fn).reshape(-1)
+        if self.use_flashinfer:
+            # Based on profiling results, we need to quantize x to mxfp8 here to achieve better performance
+            x_quant, x_scale = mxfp8_quantize(x, False)  # to mxfp8
+            x_scale = x_scale.view(torch.float8_e4m3fn).reshape(-1)
 
-            topk_weights, topk_ids, router_logits = topk_output
-            top_k = topk_weights.shape[-1]
+            top_k, router_logits = topk_output
 
             trtllm_gen_output = trtllm_fp4_block_scale_moe(
                 router_logits.to(torch.bfloat16),
diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py
index da47667bd..dca2cbfb7 100644
--- a/python/sglang/srt/managers/schedule_batch.py
+++ b/python/sglang/srt/managers/schedule_batch.py
@@ -107,6 +107,7 @@ GLOBAL_SERVER_ARGS_KEYS = [
     "num_reserved_decode_tokens",
     "weight_loader_disable_mmap",
     "enable_triton_kernel_moe",
+    "enable_flashinfer_mxfp4_moe",
     "enable_multimodal",
     "enable_symm_mem",
     "quantization",
diff --git a/python/sglang/srt/models/gpt_oss.py b/python/sglang/srt/models/gpt_oss.py
index b523c2e1b..fd9d9441c 100644
--- a/python/sglang/srt/models/gpt_oss.py
+++ b/python/sglang/srt/models/gpt_oss.py
@@ -102,11 +102,15 @@ class GptOssSparseMoeBlock(nn.Module):
                 f"the number of experts {config.num_local_experts}."
             )
 
-        self.topk = TopK(
-            top_k=config.num_experts_per_tok,
-            renormalize=True,
-        )
+        if global_server_args_dict["enable_flashinfer_mxfp4_moe"]:
+            self.topk = None
+        else:
+            self.topk = TopK(
+                top_k=config.num_experts_per_tok,
+                renormalize=True,
+            )
 
+        self.top_k = config.num_experts_per_tok
         experts_type = get_moe_impl_class()
         extra_kwargs = {}
         if experts_type.__name__ == "FusedMoE":
@@ -176,7 +180,7 @@ class GptOssSparseMoeBlock(nn.Module):
         if self.topk is not None:
             kwargs["topk_output"] = self.topk(hidden_states, router_logits)
         else:
-            kwargs["router_logits"] = router_logits
+            kwargs["topk_output"] = (self.top_k, router_logits)
         final_hidden_states = self.experts(**kwargs)
 
         if self.tp_size > 1:
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index 27362ad27..391d8e714 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -248,6 +248,7 @@ class ServerArgs:
     disable_fast_image_processor: bool = False
     enable_return_hidden_states: bool = False
     enable_triton_kernel_moe: bool = False
+    enable_flashinfer_mxfp4_moe: bool = False
 
     # Debug tensor dumps
     debug_tensor_dump_output_folder: Optional[str] = None
@@ -476,18 +477,10 @@ class ServerArgs:
                 or self.attention_backend == "triton"
             )
 
-            # Check if FlashInfer MXFP4 MoE is enabled
-            from sglang.srt.utils import get_bool_env_var
-
-            USE_FLASHINFER_MXFP4_MOE = get_bool_env_var(
-                "SGLANG_USE_FLASHINFER_MXFP4_MOE", "false"
-            )
-            USE_FLASHINFER_MXFP4_BF16_MOE = get_bool_env_var(
-                "SGLANG_USE_FLASHINFER_MXFP4_BF16_MOE", "false"
-            )
-
-            # Only enable Triton kernel MoE if FlashInfer is not enabled
-            if not (USE_FLASHINFER_MXFP4_MOE or USE_FLASHINFER_MXFP4_BF16_MOE):
+            if is_sm100_supported():
+                self.enable_flashinfer_mxfp4_moe = True
+                self.enable_triton_kernel_moe = False
+            else:
                 self.enable_triton_kernel_moe = True
             self.disable_hybrid_swa_memory = True
 
@@ -1846,6 +1839,11 @@ class ServerArgs:
             action="store_true",
             help="Use triton moe grouped gemm kernel.",
         )
+        parser.add_argument(
+            "--enable-flashinfer-mxfp4-moe",
+            action="store_true",
+            help="Enable FlashInfer MXFP4 MoE backend for modelopt_fp4 quant on Blackwell.",
+        )
 
         # Debug tensor dumps
         parser.add_argument(
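
Note for reviewers (not part of the patch): the selection logic added in server_args.py
boils down to "prefer the FlashInfer MXFP4 MoE kernel on SM100 (B200-class) GPUs,
otherwise keep the Triton grouped-GEMM MoE". The snippet below is a minimal standalone
sketch of that rule; _is_sm100() is a hypothetical stand-in for sglang's
is_sm100_supported() helper and simply checks for CUDA compute capability 10.x.

    import torch


    def _is_sm100(device: int = 0) -> bool:
        # Hypothetical stand-in for is_sm100_supported(): True on Blackwell (SM 10.x).
        if not torch.cuda.is_available():
            return False
        major, _minor = torch.cuda.get_device_capability(device)
        return major == 10


    def resolve_moe_backend(enable_flashinfer_mxfp4_moe: bool = False) -> dict:
        # Mirrors the auto-enable rule from this patch: on SM100 the FlashInfer
        # MXFP4 MoE path is enabled and the Triton kernel MoE is disabled;
        # otherwise the Triton kernel MoE stays on.
        if _is_sm100():
            enable_flashinfer_mxfp4_moe = True
            enable_triton_kernel_moe = False
        else:
            enable_triton_kernel_moe = True
        return {
            "enable_flashinfer_mxfp4_moe": enable_flashinfer_mxfp4_moe,
            "enable_triton_kernel_moe": enable_triton_kernel_moe,
        }


    if __name__ == "__main__":
        print(resolve_moe_backend())

On hardware that is not auto-detected, the same path can still be requested explicitly
with the new --enable-flashinfer-mxfp4-moe flag; its value is propagated through
global_server_args_dict["enable_flashinfer_mxfp4_moe"] and consumed by FusedMoE,
Mxfp4MoEMethod, and GptOssSparseMoeBlock.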