From e9feb4883830c53ed6def9a0ca9c83ebebb96d08 Mon Sep 17 00:00:00 2001
From: Zilin Zhu
Date: Thu, 22 May 2025 13:03:15 +0800
Subject: [PATCH] =?UTF-8?q?[RL]=20Remove=20the=20w13=20weight=5Fscale=20an?=
 =?UTF-8?q?d=20input=5Fscale=20for=20UnquantizedEPMoE=E2=80=A6=20(#6308)?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 python/sglang/srt/layers/moe/ep_moe/layer.py | 19 +++++--------------
 1 file changed, 5 insertions(+), 14 deletions(-)

diff --git a/python/sglang/srt/layers/moe/ep_moe/layer.py b/python/sglang/srt/layers/moe/ep_moe/layer.py
index d24caaaba..6d629333e 100644
--- a/python/sglang/srt/layers/moe/ep_moe/layer.py
+++ b/python/sglang/srt/layers/moe/ep_moe/layer.py
@@ -497,7 +497,8 @@ class EPMoE(torch.nn.Module):
         # Input scales can be loaded directly and should be equal.
         if "input_scale" in weight_name:
             if (
-                param_data[expert_id] != 1
+                (shard_id == "w1" or shard_id == "w3")
+                and param_data[expert_id] != 1
                 and (param_data[expert_id] - loaded_weight).abs() > 1e-5
             ):
                 raise ValueError(
@@ -571,13 +572,10 @@ class UnquantizedEPMoEMethod(FusedMoEMethodBase, CustomOp):
         set_weight_attrs(w2_weight, extra_weight_attrs)
 
         # scale
+        layer.register_parameter("w13_input_scale", None)
+        layer.register_parameter("w13_weight_scale", None)
+
         ones_tensor = torch.ones(num_experts_per_partition, dtype=torch.float32)
-        w13_input_scale = torch.nn.Parameter(
-            ones_tensor,
-            requires_grad=False,
-        )
-        layer.register_parameter("w13_input_scale", w13_input_scale)
-        set_weight_attrs(w13_input_scale, extra_weight_attrs)
 
         w2_input_scale = torch.nn.Parameter(
             ones_tensor,
@@ -586,13 +584,6 @@ class UnquantizedEPMoEMethod(FusedMoEMethodBase, CustomOp):
         layer.register_parameter("w2_input_scale", w2_input_scale)
         set_weight_attrs(w2_input_scale, extra_weight_attrs)
 
-        w13_weight_scale = torch.nn.Parameter(
-            ones_tensor,
-            requires_grad=False,
-        )
-        layer.register_parameter("w13_weight_scale", w13_weight_scale)
-        set_weight_attrs(w13_weight_scale, extra_weight_attrs)
-
         w2_weight_scale = torch.nn.Parameter(
             ones_tensor,
             requires_grad=False,