From 288ae41f7ae9203d3c87a07bf6a86f5ac0486de6 Mon Sep 17 00:00:00 2001
From: Shu Wang
Date: Wed, 6 Aug 2025 16:35:07 -0500
Subject: [PATCH] [NVIDIA] Fix num_experts in modelopt_quant (#8811)

---
 python/sglang/srt/layers/moe/fused_moe_triton/layer.py  | 5 +++++
 python/sglang/srt/layers/quantization/modelopt_quant.py | 6 ++----
 2 files changed, 7 insertions(+), 4 deletions(-)

diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py
index 35f06c6de..2c02a7463 100644
--- a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py
+++ b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py
@@ -1063,10 +1063,15 @@ class FlashInferFP4MoE(FusedMoE):
             gemm1_weights_scale=self.gemm1_scales_fp4_shuffled.data.view(
                 torch.float8_e4m3fn
             ),
+            gemm1_bias=None,
+            gemm1_alpha=None,
+            gemm1_beta=None,
+            gemm1_clamp_limit=None,
             gemm2_weights=self.gemm2_weights_fp4_shuffled.data,
             gemm2_weights_scale=self.gemm2_scales_fp4_shuffled.data.view(
                 torch.float8_e4m3fn
             ),
+            gemm2_bias=None,
             output1_scale_scalar=self.g1_scale_c.data,
             output1_scale_gate_scalar=self.g1_alphas.data,
             output2_scale_scalar=self.g2_alphas.data,
diff --git a/python/sglang/srt/layers/quantization/modelopt_quant.py b/python/sglang/srt/layers/quantization/modelopt_quant.py
index fca0ee38b..4e2b3a53e 100755
--- a/python/sglang/srt/layers/quantization/modelopt_quant.py
+++ b/python/sglang/srt/layers/quantization/modelopt_quant.py
@@ -764,8 +764,6 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
         )
 
         # TODO(ch-wan): check if this is needed
-        layer.num_experts = num_experts
-        layer.num_local_experts = num_experts
         layer.intermediate_size_per_partition = intermediate_size_per_partition
         layer.params_dtype = params_dtype
         layer.quant_config = self.quant_config
@@ -1106,7 +1104,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
                 layer.w13_weight_scale,
             )
 
-            print("Applied flashinfer weight processing for both w13 and w2")
+            logger.info_once("Applied flashinfer weight processing for both w13 and w2")
 
         else:
             # CUTLASS processing - handle w13 and w2 separately
@@ -1126,7 +1124,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
             layer.w2_weight = Parameter(layer.w2_weight.data, requires_grad=False)
 
             # Both flashinfer cutlass and regular cutlass use same processing for w2
-            print("Applied weight processing for both w13 and w2")
+            logger.info_once("Applied weight processing for both w13 and w2")
 
         # Set up CUTLASS MoE parameters
         device = layer.w13_weight.device
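
Notes: besides dropping the redundant num_experts assignments and passing the new
explicit gemm bias/alpha/beta/clamp arguments (all None here), the patch replaces
print() with logger.info_once() so the weight-processing message is emitted a single
time rather than once per MoE layer. The self-contained sketch below illustrates the
info_once idea using only the stdlib; it is an assumed illustration of the pattern,
not sglang's actual logger implementation.

    # Minimal sketch of an "info_once"-style helper, assuming a plain stdlib
    # logger. sglang provides its own logger.info_once; this only shows the idea.
    import logging

    logger = logging.getLogger(__name__)
    _seen_messages: set[str] = set()

    def info_once(msg: str) -> None:
        """Log msg at INFO level only the first time it is seen."""
        if msg not in _seen_messages:
            _seen_messages.add(msg)
            logger.info(msg)

    if __name__ == "__main__":
        logging.basicConfig(level=logging.INFO)
        # Called once per layer, but logged only once overall.
        for _ in range(3):
            info_once("Applied weight processing for both w13 and w2")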