From c7e85f53787021c2eb84b0d2a59a8e20bd9980a7 Mon Sep 17 00:00:00 2001
From: Trevor Morris
Date: Thu, 11 Sep 2025 20:19:17 -0700
Subject: [PATCH] fix: flashinfer_cutlass_moe: Use max of global expert scales
 instead of local for input scale (#10296)

---
 python/sglang/srt/layers/moe/fused_moe_triton/layer.py  | 8 +++++++-
 python/sglang/srt/layers/quantization/modelopt_quant.py | 4 ++--
 2 files changed, 9 insertions(+), 3 deletions(-)

diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py
index 0ea1fa1eb..4ceba1d49 100644
--- a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py
+++ b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py
@@ -503,8 +503,14 @@ class FusedMoE(torch.nn.Module):
             param.data[:, :dim1, :dim2].copy_(loaded_weight)
             return
 
+        # ModelOptNvFp4FusedMoEMethod uses max of global expert scaling factors for input scaling factor
+        load_global_experts = (
+            isinstance(self.quant_method, ModelOptNvFp4FusedMoEMethod)
+            and "input_scale" in weight_name
+        )
+
         global_expert_location_metadata = get_global_expert_location_metadata()
-        if global_expert_location_metadata is None:
+        if global_expert_location_metadata is None or load_global_experts:
             self._weight_loader_impl(
                 param=param,
                 loaded_weight=loaded_weight,
diff --git a/python/sglang/srt/layers/quantization/modelopt_quant.py b/python/sglang/srt/layers/quantization/modelopt_quant.py
index 45a4ba559..9f7dafd05 100755
--- a/python/sglang/srt/layers/quantization/modelopt_quant.py
+++ b/python/sglang/srt/layers/quantization/modelopt_quant.py
@@ -996,13 +996,13 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
         )
 
         w13_input_scale = PerTensorScaleParameter(
-            data=torch.empty(layer.num_local_experts, 2, dtype=torch.float32),
+            data=torch.empty(layer.num_experts, 2, dtype=torch.float32),
             weight_loader=weight_loader,
         )
         layer.register_parameter("w13_input_scale", w13_input_scale)
 
         w2_input_scale = PerTensorScaleParameter(
-            data=torch.empty(layer.num_local_experts, dtype=torch.float32),
+            data=torch.empty(layer.num_experts, dtype=torch.float32),
             weight_loader=weight_loader,
        )
        layer.register_parameter("w2_input_scale", w2_input_scale)