diff --git a/python/sglang/srt/layers/quantization/modelopt_quant.py b/python/sglang/srt/layers/quantization/modelopt_quant.py index 9f7dafd05..bf8c6c1ed 100755 --- a/python/sglang/srt/layers/quantization/modelopt_quant.py +++ b/python/sglang/srt/layers/quantization/modelopt_quant.py @@ -1187,6 +1187,21 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase): if self.enable_flashinfer_cutlass_moe or self.enable_flashinfer_trtllm_moe: w13_input_scale = layer.w13_input_scale.max().to(torch.float32) w2_input_scale = layer.w2_input_scale.max().to(torch.float32) + elif self.enable_flashinfer_cutedsl_moe: + w13_input_scale = layer.w13_input_scale.max(dim=1).values.to(torch.float32) + w2_input_scale = layer.w2_input_scale + + def _slice_scale(w): + assert w.shape == (layer.num_experts,) + assert layer.moe_ep_size * layer.num_local_experts == layer.num_experts + return w[ + layer.moe_ep_rank + * layer.num_local_experts : (layer.moe_ep_rank + 1) + * layer.num_local_experts + ] + + w13_input_scale = _slice_scale(w13_input_scale) + w2_input_scale = _slice_scale(w2_input_scale) else: w13_input_scale = layer.w13_input_scale.max(dim=1).values.to(torch.float32) w2_input_scale = layer.w2_input_scale