diff --git a/python/sglang/srt/layers/moe/ep_moe/layer.py b/python/sglang/srt/layers/moe/ep_moe/layer.py
index d24caaaba..6d629333e 100644
--- a/python/sglang/srt/layers/moe/ep_moe/layer.py
+++ b/python/sglang/srt/layers/moe/ep_moe/layer.py
@@ -497,7 +497,8 @@ class EPMoE(torch.nn.Module):
         # Input scales can be loaded directly and should be equal.
         if "input_scale" in weight_name:
             if (
-                param_data[expert_id] != 1
+                (shard_id == "w1" or shard_id == "w3")
+                and param_data[expert_id] != 1
                 and (param_data[expert_id] - loaded_weight).abs() > 1e-5
             ):
                 raise ValueError(
@@ -571,13 +572,10 @@ class UnquantizedEPMoEMethod(FusedMoEMethodBase, CustomOp):
         set_weight_attrs(w2_weight, extra_weight_attrs)
 
         # scale
+        layer.register_parameter("w13_input_scale", None)
+        layer.register_parameter("w13_weight_scale", None)
+
         ones_tensor = torch.ones(num_experts_per_partition, dtype=torch.float32)
-        w13_input_scale = torch.nn.Parameter(
-            ones_tensor,
-            requires_grad=False,
-        )
-        layer.register_parameter("w13_input_scale", w13_input_scale)
-        set_weight_attrs(w13_input_scale, extra_weight_attrs)
 
         w2_input_scale = torch.nn.Parameter(
             ones_tensor,
@@ -586,13 +584,6 @@ class UnquantizedEPMoEMethod(FusedMoEMethodBase, CustomOp):
         layer.register_parameter("w2_input_scale", w2_input_scale)
         set_weight_attrs(w2_input_scale, extra_weight_attrs)
 
-        w13_weight_scale = torch.nn.Parameter(
-            ones_tensor,
-            requires_grad=False,
-        )
-        layer.register_parameter("w13_weight_scale", w13_weight_scale)
-        set_weight_attrs(w13_weight_scale, extra_weight_attrs)
-
         w2_weight_scale = torch.nn.Parameter(
             ones_tensor,
             requires_grad=False,
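
For context on the second and third hunks: they replace the all-ones `w13_input_scale` / `w13_weight_scale` parameters with `register_parameter(name, None)`. The sketch below is not part of the diff; it is a standalone illustration of stock PyTorch behavior (the `Demo` module is hypothetical) showing what registering `None` does: the attribute exists and reads as `None`, but no tensor is allocated and the name is excluded from `parameters()` and `state_dict()`.

```python
import torch


class Demo(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # A real parameter, analogous to the w2_* scales the patch keeps.
        self.register_parameter(
            "w2_input_scale",
            torch.nn.Parameter(torch.ones(4), requires_grad=False),
        )
        # Registering None defines the attribute without allocating anything,
        # analogous to the w13_* scales after the patch.
        self.register_parameter("w13_input_scale", None)


m = Demo()
print(m.w13_input_scale)            # None
print(len(list(m.parameters())))    # 1 -- only w2_input_scale
print(list(m.state_dict().keys()))  # ['w2_input_scale']
```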