[RL] Remove the w13 weight_scale and input_scale for UnquantizedEPMoE… (#6308)
This commit is contained in:
@@ -497,7 +497,8 @@ class EPMoE(torch.nn.Module):
|
|||||||
# Input scales can be loaded directly and should be equal.
|
# Input scales can be loaded directly and should be equal.
|
||||||
if "input_scale" in weight_name:
|
if "input_scale" in weight_name:
|
||||||
if (
|
if (
|
||||||
param_data[expert_id] != 1
|
(shard_id == "w1" or shard_id == "w3")
|
||||||
|
and param_data[expert_id] != 1
|
||||||
and (param_data[expert_id] - loaded_weight).abs() > 1e-5
|
and (param_data[expert_id] - loaded_weight).abs() > 1e-5
|
||||||
):
|
):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@@ -571,13 +572,10 @@ class UnquantizedEPMoEMethod(FusedMoEMethodBase, CustomOp):
|
|||||||
set_weight_attrs(w2_weight, extra_weight_attrs)
|
set_weight_attrs(w2_weight, extra_weight_attrs)
|
||||||
|
|
||||||
# scale
|
# scale
|
||||||
|
layer.register_parameter("w13_input_scale", None)
|
||||||
|
layer.register_parameter("w13_weight_scale", None)
|
||||||
|
|
||||||
ones_tensor = torch.ones(num_experts_per_partition, dtype=torch.float32)
|
ones_tensor = torch.ones(num_experts_per_partition, dtype=torch.float32)
|
||||||
w13_input_scale = torch.nn.Parameter(
|
|
||||||
ones_tensor,
|
|
||||||
requires_grad=False,
|
|
||||||
)
|
|
||||||
layer.register_parameter("w13_input_scale", w13_input_scale)
|
|
||||||
set_weight_attrs(w13_input_scale, extra_weight_attrs)
|
|
||||||
|
|
||||||
w2_input_scale = torch.nn.Parameter(
|
w2_input_scale = torch.nn.Parameter(
|
||||||
ones_tensor,
|
ones_tensor,
|
||||||
@@ -586,13 +584,6 @@ class UnquantizedEPMoEMethod(FusedMoEMethodBase, CustomOp):
|
|||||||
layer.register_parameter("w2_input_scale", w2_input_scale)
|
layer.register_parameter("w2_input_scale", w2_input_scale)
|
||||||
set_weight_attrs(w2_input_scale, extra_weight_attrs)
|
set_weight_attrs(w2_input_scale, extra_weight_attrs)
|
||||||
|
|
||||||
w13_weight_scale = torch.nn.Parameter(
|
|
||||||
ones_tensor,
|
|
||||||
requires_grad=False,
|
|
||||||
)
|
|
||||||
layer.register_parameter("w13_weight_scale", w13_weight_scale)
|
|
||||||
set_weight_attrs(w13_weight_scale, extra_weight_attrs)
|
|
||||||
|
|
||||||
w2_weight_scale = torch.nn.Parameter(
|
w2_weight_scale = torch.nn.Parameter(
|
||||||
ones_tensor,
|
ones_tensor,
|
||||||
requires_grad=False,
|
requires_grad=False,
|
||||||
|
|||||||
Reference in New Issue
Block a user