From 0d1e27a0c572fe3e5ecc70b61e0e47c1682ef245 Mon Sep 17 00:00:00 2001 From: Xiaoyu Zhang <35585791+BBuf@users.noreply.github.com> Date: Fri, 8 Aug 2025 15:11:48 +0800 Subject: [PATCH] Better optimization log for gpt-oss model (#8953) --- python/sglang/srt/layers/quantization/mxfp4.py | 9 +++++---- python/sglang/srt/server_args.py | 6 ++++++ 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/python/sglang/srt/layers/quantization/mxfp4.py b/python/sglang/srt/layers/quantization/mxfp4.py index dc9a208fd..619f0bfc9 100644 --- a/python/sglang/srt/layers/quantization/mxfp4.py +++ b/python/sglang/srt/layers/quantization/mxfp4.py @@ -24,6 +24,7 @@ from sglang.srt.utils import ( is_cuda, is_flashinfer_available, is_hip, + log_info_on_rank0, next_power_of_2, round_up, set_weight_attrs, @@ -34,7 +35,6 @@ has_triton_kernels = importlib.util.find_spec("triton_kernels") is not None if is_flashinfer_available(): - # from flashinfer.fused_moe import cutlass_fused_moe from flashinfer import ( mxfp8_quantize, shuffle_matrix_a, @@ -63,7 +63,7 @@ def _swizzle_mxfp4(quant_tensor, scale, num_warps): scale_layout, scale_layout_opts = layout.make_default_matmul_mxfp4_w_scale_layout( mx_axis=1, num_warps=num_warps ) - if is_cuda() and torch.cuda.get_device_capability()[0] == 10: + if _is_sm100_supported: constraints = { "is_persistent": True, "epilogue_subtile": 1, @@ -331,8 +331,9 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): def process_weights_after_loading(self, layer): if self.use_flashinfer: - logger.info( - "Shuffling MoE weights for FlashInfer, it might take a while..." + log_info_on_rank0( + logger, + "Shuffling MoE weights for FlashInfer MXFP4 moe kernel, it might take a while...", ) layer.gemm1_alpha = Parameter( torch.tensor([1.702] * self.num_experts, dtype=torch.float32).cuda(), diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 30a210980..9d1839ff4 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -488,8 +488,14 @@ class ServerArgs: if is_sm100_supported() and is_mxfp4_quant_format: self.enable_flashinfer_mxfp4_moe = True self.enable_triton_kernel_moe = False + logger.info( + "Detected SM100 and MXFP4 quantization format for GPT-OSS model, enabling FlashInfer MXFP4 MOE kernel." + ) else: self.enable_triton_kernel_moe = True + logger.info( + "Detected GPT-OSS model, enabling triton_kernels MOE kernel." + ) self.disable_hybrid_swa_memory = True