[RL] Remove the w13 weight_scale and input_scale for UnquantizedEPMoEMethod (#6308)
This commit is contained in:
@@ -497,7 +497,8 @@ class EPMoE(torch.nn.Module):
|
||||
# Input scales can be loaded directly and should be equal.
|
||||
if "input_scale" in weight_name:
|
||||
if (
|
||||
-param_data[expert_id] != 1
|
||||
+(shard_id == "w1" or shard_id == "w3")
|
||||
+and param_data[expert_id] != 1
|
||||
and (param_data[expert_id] - loaded_weight).abs() > 1e-5
|
||||
):
|
||||
raise ValueError(
|
||||
@@ -571,13 +572,10 @@ class UnquantizedEPMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
set_weight_attrs(w2_weight, extra_weight_attrs)
|
||||
|
||||
# scale
|
||||
+layer.register_parameter("w13_input_scale", None)
|
||||
+layer.register_parameter("w13_weight_scale", None)
|
||||
|
||||
ones_tensor = torch.ones(num_experts_per_partition, dtype=torch.float32)
|
||||
-w13_input_scale = torch.nn.Parameter(
|
||||
-ones_tensor,
|
||||
-requires_grad=False,
|
||||
-)
|
||||
-layer.register_parameter("w13_input_scale", w13_input_scale)
|
||||
-set_weight_attrs(w13_input_scale, extra_weight_attrs)
|
||||
|
||||
w2_input_scale = torch.nn.Parameter(
|
||||
ones_tensor,
|
||||
@@ -586,13 +584,6 @@ class UnquantizedEPMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
layer.register_parameter("w2_input_scale", w2_input_scale)
|
||||
set_weight_attrs(w2_input_scale, extra_weight_attrs)
|
||||
|
||||
-w13_weight_scale = torch.nn.Parameter(
|
||||
-ones_tensor,
|
||||
-requires_grad=False,
|
||||
-)
|
||||
-layer.register_parameter("w13_weight_scale", w13_weight_scale)
|
||||
-set_weight_attrs(w13_weight_scale, extra_weight_attrs)
|
||||
|
||||
w2_weight_scale = torch.nn.Parameter(
|
||||
ones_tensor,
|
||||
requires_grad=False,
|
||||
|
||||
Reference in New Issue
Block a user