[fix] fix dsv3 weight loader tqdm and simplify shared experts fusion (#7181)
@@ -1033,7 +1033,7 @@ class DeepseekV2AttentionMLA(nn.Module):
             attn_bmm_output = attn_output.new_empty(
                 (self.num_local_heads, aligned_m, self.v_head_dim)
             )
-            deep_gemm_grouped_gemm_nt_f8f8bf16_masked(
+            deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
                 (attn_output_val, attn_output_scale),
                 (self.w_vc, self.w_scale_v),
                 attn_bmm_output,
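For orientation: ignoring FP8 quantization, the per-tensor scales, and the masked/grouped layout handled by the DeepGEMM kernel, the call above amounts to a per-head batched matmul of the attention output against w_vc into a preallocated buffer. A minimal sketch under that assumption (shapes, including the kv_lora_rank dimension, are illustrative and not taken from the diff; the real w_vc storage layout follows the kernel's "nt" convention):

import torch

# Illustrative sizes only; the real kernel operates on FP8 data with scales.
num_local_heads, aligned_m, kv_lora_rank, v_head_dim = 16, 128, 512, 128

attn_output = torch.randn(num_local_heads, aligned_m, kv_lora_rank)
w_vc = torch.randn(num_local_heads, kv_lora_rank, v_head_dim)

# Per-head batched matmul into a preallocated buffer, mirroring the diff above.
attn_bmm_output = attn_output.new_empty((num_local_heads, aligned_m, v_head_dim))
torch.bmm(attn_output, w_vc, out=attn_bmm_output)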
@@ -2008,101 +2008,6 @@ class DeepseekV2ForCausalLM(nn.Module):
             ("gate_up_proj", "gate_proj", 0),
             ("gate_up_proj", "up_proj", 1),
         ]
-        if self.num_fused_shared_experts > 0:
-            assert self.num_fused_shared_experts == 1
-            weights_list = list(weights)
-            weights_dict = dict(weights_list)
-            if self.quant_config is not None:
-                if self.quant_config.get_name() == "w8a8_int8":
-                    suffix_list = [
-                        "down_proj.weight",
-                        "down_proj.weight_scale",
-                        "gate_proj.weight",
-                        "gate_proj.weight_scale",
-                        "up_proj.weight",
-                        "up_proj.weight_scale",
-                    ]
-                elif (
-                    self.quant_config.get_name() == "fp8"
-                    or self.quant_config.get_name() == "blockwise_int8"
-                ):
-                    suffix_list = [
-                        "down_proj.weight",
-                        "down_proj.weight_scale_inv",
-                        "gate_proj.weight",
-                        "gate_proj.weight_scale_inv",
-                        "up_proj.weight",
-                        "up_proj.weight_scale_inv",
-                    ]
-                elif self.quant_config.get_name() == "awq":
-                    suffix_list = [
-                        "down_proj.qweight",
-                        "down_proj.qzeros",
-                        "down_proj.scales",
-                        "gate_proj.qweight",
-                        "gate_proj.qzeros",
-                        "gate_proj.scales",
-                        "up_proj.qweight",
-                        "up_proj.qzeros",
-                        "up_proj.scales",
-                    ]
-                elif self.quant_config.get_name() == "modelopt_fp4":
-                    suffix_list = [
-                        "down_proj.weight",
-                        "down_proj.weight_scale",
-                        "down_proj.weight_scale_2",
-                        "down_proj.input_scale",
-                        "gate_proj.weight",
-                        "gate_proj.weight_scale",
-                        "gate_proj.weight_scale_2",
-                        "gate_proj.input_scale",
-                        "up_proj.weight",
-                        "up_proj.weight_scale",
-                        "up_proj.weight_scale_2",
-                        "up_proj.input_scale",
-                    ]
-                else:
-                    raise ValueError(
-                        f"Unsupported shared expert fusion for quantization: {self.quant_config.get_name()}."
-                    )
-            else:
-                suffix_list = [
-                    "down_proj.weight",
-                    "gate_proj.weight",
-                    "up_proj.weight",
-                ]
-            names_to_remove = []
-
-            moe_layers = (
-                range(
-                    self.config.first_k_dense_replace,
-                    self.config.num_hidden_layers,
-                    self.config.moe_layer_freq,
-                )
-                if not is_nextn
-                else [nextn_layer_id]
-            )
-
-            for moe_layer in tqdm(
-                moe_layers,
-                desc=f"Cloning {self.num_fused_shared_experts} "
-                "shared expert into MoE",
-            ):
-                for suffix in suffix_list:
-                    shared_expert_weight_name = (
-                        f"model.layers.{moe_layer}.mlp.shared_experts.{suffix}"
-                    )
-                    weights_list.append(
-                        (
-                            f"model.layers.{moe_layer}."
-                            f"mlp.experts."
-                            f"{self.config.n_routed_experts + 0}"
-                            f".{suffix}",
-                            weights_dict[shared_expert_weight_name],
-                        )
-                    )
-                    names_to_remove += [shared_expert_weight_name]
-            weights = [w for w in weights_list if w[0] not in names_to_remove]
-
         # Params for weights, fp8 weight scales, fp8 activation scales
         # (param_name, weight_name, expert_id, shard_id)
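For context, the block removed above eagerly materialized the weights iterator, built a name-to-tensor dict, and appended a renamed copy of every shared-expert tensor under the extra routed-expert index. A simplified, self-contained restatement of that behavior (hypothetical helper, not the repository's code):

def clone_shared_experts(weights, moe_layers, suffixes, n_routed_experts):
    """Duplicate each shared-expert tensor under the extra routed-expert index."""
    weights_list = list(weights)      # materializes the whole iterator up front
    weights_dict = dict(weights_list)
    cloned, to_remove = [], set()
    for layer in moe_layers:
        for suffix in suffixes:
            src = f"model.layers.{layer}.mlp.shared_experts.{suffix}"
            dst = f"model.layers.{layer}.mlp.experts.{n_routed_experts}.{suffix}"
            cloned.append((dst, weights_dict[src]))
            to_remove.add(src)
    return [w for w in weights_list if w[0] not in to_remove] + cloned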
@@ -2128,9 +2033,19 @@ class DeepseekV2ForCausalLM(nn.Module):
             "hnorm",
         ]
 
+        if self.num_fused_shared_experts > 0:
+            assert self.num_fused_shared_experts == 1
+            logger.info("Shared experts fusion optimization enabled.")
+
         params_dict = dict(self.named_parameters())
         weight_names = []
         for name, loaded_weight in weights:
+            if self.num_fused_shared_experts > 0 and "mlp.shared_experts" in name:
+                name = name.replace(
+                    "mlp.shared_experts",
+                    f"mlp.experts.{self.config.n_routed_experts}",
+                )
+
             weight_names.append(name)
 
         if not is_nextn:
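The logic added above instead redirects shared-expert weight names to the extra expert slot as they stream through the loader, so no list/dict copy of the checkpoint tensors is needed. A minimal standalone sketch of the renaming rule (hypothetical helper; the expert count below is an illustrative value):

def rename_shared_expert(name, n_routed_experts, num_fused_shared_experts=1):
    """Redirect a shared-expert weight name to the extra expert slot."""
    if num_fused_shared_experts > 0 and "mlp.shared_experts" in name:
        return name.replace("mlp.shared_experts", f"mlp.experts.{n_routed_experts}")
    return name

# Example with an illustrative expert count of 256:
# rename_shared_expert("model.layers.3.mlp.shared_experts.gate_proj.weight", 256)
# -> "model.layers.3.mlp.experts.256.gate_proj.weight"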