[RL] Remove the w13 weight_scale and input_scale for UnquantizedEPMoEMethod (#6308)

This commit is contained in:
Zilin Zhu
2025-05-22 13:03:15 +08:00
committed by GitHub
parent fc992a09f9
commit e9feb48838

View File

@@ -497,7 +497,8 @@ class EPMoE(torch.nn.Module):
# Input scales can be loaded directly and should be equal. # Input scales can be loaded directly and should be equal.
if "input_scale" in weight_name: if "input_scale" in weight_name:
if ( if (
param_data[expert_id] != 1 (shard_id == "w1" or shard_id == "w3")
and param_data[expert_id] != 1
and (param_data[expert_id] - loaded_weight).abs() > 1e-5 and (param_data[expert_id] - loaded_weight).abs() > 1e-5
): ):
raise ValueError( raise ValueError(
@@ -571,13 +572,10 @@ class UnquantizedEPMoEMethod(FusedMoEMethodBase, CustomOp):
set_weight_attrs(w2_weight, extra_weight_attrs) set_weight_attrs(w2_weight, extra_weight_attrs)
# scale # scale
layer.register_parameter("w13_input_scale", None)
layer.register_parameter("w13_weight_scale", None)
ones_tensor = torch.ones(num_experts_per_partition, dtype=torch.float32) ones_tensor = torch.ones(num_experts_per_partition, dtype=torch.float32)
w13_input_scale = torch.nn.Parameter(
ones_tensor,
requires_grad=False,
)
layer.register_parameter("w13_input_scale", w13_input_scale)
set_weight_attrs(w13_input_scale, extra_weight_attrs)
w2_input_scale = torch.nn.Parameter( w2_input_scale = torch.nn.Parameter(
ones_tensor, ones_tensor,
@@ -586,13 +584,6 @@ class UnquantizedEPMoEMethod(FusedMoEMethodBase, CustomOp):
layer.register_parameter("w2_input_scale", w2_input_scale) layer.register_parameter("w2_input_scale", w2_input_scale)
set_weight_attrs(w2_input_scale, extra_weight_attrs) set_weight_attrs(w2_input_scale, extra_weight_attrs)
w13_weight_scale = torch.nn.Parameter(
ones_tensor,
requires_grad=False,
)
layer.register_parameter("w13_weight_scale", w13_weight_scale)
set_weight_attrs(w13_weight_scale, extra_weight_attrs)
w2_weight_scale = torch.nn.Parameter( w2_weight_scale = torch.nn.Parameter(
ones_tensor, ones_tensor,
requires_grad=False, requires_grad=False,