[fix] fix dsv3 weight loader tqdm and simplify shared experts fusion (#7181)
This commit is contained in:
@@ -1033,7 +1033,7 @@ class DeepseekV2AttentionMLA(nn.Module):
|
|||||||
attn_bmm_output = attn_output.new_empty(
|
attn_bmm_output = attn_output.new_empty(
|
||||||
(self.num_local_heads, aligned_m, self.v_head_dim)
|
(self.num_local_heads, aligned_m, self.v_head_dim)
|
||||||
)
|
)
|
||||||
deep_gemm_grouped_gemm_nt_f8f8bf16_masked(
|
deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
|
||||||
(attn_output_val, attn_output_scale),
|
(attn_output_val, attn_output_scale),
|
||||||
(self.w_vc, self.w_scale_v),
|
(self.w_vc, self.w_scale_v),
|
||||||
attn_bmm_output,
|
attn_bmm_output,
|
||||||
@@ -2008,101 +2008,6 @@ class DeepseekV2ForCausalLM(nn.Module):
|
|||||||
("gate_up_proj", "gate_proj", 0),
|
("gate_up_proj", "gate_proj", 0),
|
||||||
("gate_up_proj", "up_proj", 1),
|
("gate_up_proj", "up_proj", 1),
|
||||||
]
|
]
|
||||||
if self.num_fused_shared_experts > 0:
|
|
||||||
assert self.num_fused_shared_experts == 1
|
|
||||||
weights_list = list(weights)
|
|
||||||
weights_dict = dict(weights_list)
|
|
||||||
if self.quant_config is not None:
|
|
||||||
if self.quant_config.get_name() == "w8a8_int8":
|
|
||||||
suffix_list = [
|
|
||||||
"down_proj.weight",
|
|
||||||
"down_proj.weight_scale",
|
|
||||||
"gate_proj.weight",
|
|
||||||
"gate_proj.weight_scale",
|
|
||||||
"up_proj.weight",
|
|
||||||
"up_proj.weight_scale",
|
|
||||||
]
|
|
||||||
elif (
|
|
||||||
self.quant_config.get_name() == "fp8"
|
|
||||||
or self.quant_config.get_name() == "blockwise_int8"
|
|
||||||
):
|
|
||||||
suffix_list = [
|
|
||||||
"down_proj.weight",
|
|
||||||
"down_proj.weight_scale_inv",
|
|
||||||
"gate_proj.weight",
|
|
||||||
"gate_proj.weight_scale_inv",
|
|
||||||
"up_proj.weight",
|
|
||||||
"up_proj.weight_scale_inv",
|
|
||||||
]
|
|
||||||
elif self.quant_config.get_name() == "awq":
|
|
||||||
suffix_list = [
|
|
||||||
"down_proj.qweight",
|
|
||||||
"down_proj.qzeros",
|
|
||||||
"down_proj.scales",
|
|
||||||
"gate_proj.qweight",
|
|
||||||
"gate_proj.qzeros",
|
|
||||||
"gate_proj.scales",
|
|
||||||
"up_proj.qweight",
|
|
||||||
"up_proj.qzeros",
|
|
||||||
"up_proj.scales",
|
|
||||||
]
|
|
||||||
elif self.quant_config.get_name() == "modelopt_fp4":
|
|
||||||
suffix_list = [
|
|
||||||
"down_proj.weight",
|
|
||||||
"down_proj.weight_scale",
|
|
||||||
"down_proj.weight_scale_2",
|
|
||||||
"down_proj.input_scale",
|
|
||||||
"gate_proj.weight",
|
|
||||||
"gate_proj.weight_scale",
|
|
||||||
"gate_proj.weight_scale_2",
|
|
||||||
"gate_proj.input_scale",
|
|
||||||
"up_proj.weight",
|
|
||||||
"up_proj.weight_scale",
|
|
||||||
"up_proj.weight_scale_2",
|
|
||||||
"up_proj.input_scale",
|
|
||||||
]
|
|
||||||
else:
|
|
||||||
raise ValueError(
|
|
||||||
f"Unsupported shared expert fusion for quantization: {self.quant_config.get_name()}."
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
suffix_list = [
|
|
||||||
"down_proj.weight",
|
|
||||||
"gate_proj.weight",
|
|
||||||
"up_proj.weight",
|
|
||||||
]
|
|
||||||
names_to_remove = []
|
|
||||||
|
|
||||||
moe_layers = (
|
|
||||||
range(
|
|
||||||
self.config.first_k_dense_replace,
|
|
||||||
self.config.num_hidden_layers,
|
|
||||||
self.config.moe_layer_freq,
|
|
||||||
)
|
|
||||||
if not is_nextn
|
|
||||||
else [nextn_layer_id]
|
|
||||||
)
|
|
||||||
|
|
||||||
for moe_layer in tqdm(
|
|
||||||
moe_layers,
|
|
||||||
desc=f"Cloning {self.num_fused_shared_experts} "
|
|
||||||
"shared expert into MoE",
|
|
||||||
):
|
|
||||||
for suffix in suffix_list:
|
|
||||||
shared_expert_weight_name = (
|
|
||||||
f"model.layers.{moe_layer}.mlp.shared_experts.{suffix}"
|
|
||||||
)
|
|
||||||
weights_list.append(
|
|
||||||
(
|
|
||||||
f"model.layers.{moe_layer}."
|
|
||||||
f"mlp.experts."
|
|
||||||
f"{self.config.n_routed_experts + 0}"
|
|
||||||
f".{suffix}",
|
|
||||||
weights_dict[shared_expert_weight_name],
|
|
||||||
)
|
|
||||||
)
|
|
||||||
names_to_remove += [shared_expert_weight_name]
|
|
||||||
weights = [w for w in weights_list if w[0] not in names_to_remove]
|
|
||||||
|
|
||||||
# Params for weights, fp8 weight scales, fp8 activation scales
|
# Params for weights, fp8 weight scales, fp8 activation scales
|
||||||
# (param_name, weight_name, expert_id, shard_id)
|
# (param_name, weight_name, expert_id, shard_id)
|
||||||
@@ -2128,9 +2033,19 @@ class DeepseekV2ForCausalLM(nn.Module):
|
|||||||
"hnorm",
|
"hnorm",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
if self.num_fused_shared_experts > 0:
|
||||||
|
assert self.num_fused_shared_experts == 1
|
||||||
|
logger.info("Shared experts fusion optimization enabled.")
|
||||||
|
|
||||||
params_dict = dict(self.named_parameters())
|
params_dict = dict(self.named_parameters())
|
||||||
weight_names = []
|
weight_names = []
|
||||||
for name, loaded_weight in weights:
|
for name, loaded_weight in weights:
|
||||||
|
if self.num_fused_shared_experts > 0 and "mlp.shared_experts" in name:
|
||||||
|
name = name.replace(
|
||||||
|
"mlp.shared_experts",
|
||||||
|
f"mlp.experts.{self.config.n_routed_experts}",
|
||||||
|
)
|
||||||
|
|
||||||
weight_names.append(name)
|
weight_names.append(name)
|
||||||
|
|
||||||
if not is_nextn:
|
if not is_nextn:
|
||||||
|
|||||||
Reference in New Issue
Block a user