From ac964d2e580ba76889699b9991bd09265cd09add Mon Sep 17 00:00:00 2001 From: fzyzcjy <5236035+fzyzcjy@users.noreply.github.com> Date: Sun, 14 Sep 2025 16:17:00 +0800 Subject: [PATCH] Support global scale in addition to per expert scale for cutedsl moe (#10270) --- .../srt/layers/quantization/modelopt_quant.py | 20 +++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/layers/quantization/modelopt_quant.py b/python/sglang/srt/layers/quantization/modelopt_quant.py index 38894f8c9..0ab963396 100755 --- a/python/sglang/srt/layers/quantization/modelopt_quant.py +++ b/python/sglang/srt/layers/quantization/modelopt_quant.py @@ -39,7 +39,7 @@ from sglang.srt.layers.quantization.utils import ( requantize_with_max_scale, ) from sglang.srt.layers.radix_attention import RadixAttention -from sglang.srt.utils import is_cuda, next_power_of_2 +from sglang.srt.utils import get_bool_env_var, is_cuda, next_power_of_2 if TYPE_CHECKING: from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE @@ -74,6 +74,10 @@ except ImportError: # Initialize logger for the module logger = logging.getLogger(__name__) +CUTEDSL_MOE_SCALAR_INPUT_SCALE = get_bool_env_var( + "SGLANG_CUTEDSL_MOE_SCALAR_INPUT_SCALE", "true" +) + # Supported activation schemes for the current configuration ACTIVATION_SCHEMES = ["static"] @@ -1190,7 +1194,19 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase): w13_input_scale = layer.w13_input_scale.max().to(torch.float32) w2_input_scale = layer.w2_input_scale.max().to(torch.float32) elif self.enable_flashinfer_cutedsl_moe: - w13_input_scale = layer.w13_input_scale.max(dim=1).values.to(torch.float32) + # All-expert-one-input-scale is mathematically different from default per-expert-input-scale + # Thus we allow users to switch the flag to do thorough testing + if CUTEDSL_MOE_SCALAR_INPUT_SCALE: + w13_input_scale = ( + layer.w13_input_scale.max() + .to(torch.float32) + .repeat(layer.w13_input_scale.shape[0]) + ) + else: + w13_input_scale = layer.w13_input_scale.max(dim=1).values.to( + torch.float32 + ) + w2_input_scale = layer.w2_input_scale def _slice_scale(w):