simplify the control logic for using shared experts fusion (#5504)

Xiaoyu Zhang
2025-04-20 04:17:35 +08:00
committed by GitHub
parent bf86c5e990
commit d58e354472
16 changed files with 69 additions and 54 deletions


@@ -189,11 +189,7 @@ class DeepseekV2MoE(nn.Module):
         self.tp_size = get_tensor_model_parallel_world_size()
         self.routed_scaling_factor = config.routed_scaling_factor
         self.n_shared_experts = config.n_shared_experts
-        self.n_share_experts_fusion = (
-            global_server_args_dict["n_share_experts_fusion"]
-            if global_server_args_dict["n_share_experts_fusion"] is not None
-            else 0
-        )
+        self.n_share_experts_fusion = global_server_args_dict["n_share_experts_fusion"]
 
         if self.tp_size > config.n_routed_experts:
             raise ValueError(
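
The `is not None` guard disappears because, after this change, the server args are expected to always hand over an integer (0 when fusion is off). A minimal sketch of the before/after resolution, assuming that integer default; names and values are illustrative, not the sglang implementation:

global_server_args_dict = {"n_share_experts_fusion": 0}  # assumed default

# Before: a three-state value (None / 0 / N) had to be normalized at every use.
n_fusion_old = (
    global_server_args_dict["n_share_experts_fusion"]
    if global_server_args_dict["n_share_experts_fusion"] is not None
    else 0
)

# After: the value is always an int, so a plain read is enough.
n_fusion_new = global_server_args_dict["n_share_experts_fusion"]
assert n_fusion_old == n_fusion_new == 0
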
@@ -226,6 +222,7 @@ class DeepseekV2MoE(nn.Module):
             num_expert_group=config.n_group,
             topk_group=config.topk_group,
             correction_bias=self.gate.e_score_correction_bias,
+            routed_scaling_factor=self.routed_scaling_factor,
             prefix=add_prefix("experts", prefix),
             **(
                 dict(deepep_mode=DeepEPMode[global_server_args_dict["deepep_mode"]])
@@ -334,6 +331,7 @@ class DeepseekV2MoE(nn.Module):
                 topk_group=self.topk_group,
                 num_expert_group=self.num_expert_group,
                 correction_bias=self.correction_bias,
+                routed_scaling_factor=self.routed_scaling_factor,
             )
         if self.ep_size > 1:
             # TODO(ch-wan): allow users to set num_max_dispatch_tokens_per_rank value
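
The two `routed_scaling_factor=` additions in the hunks above are companions to the fusion cleanup: they hand the model's scaling factor (2.5 for DeepSeek V3/R1) down to the expert layer and to the top-k selection. DeepSeek scales the routed experts' combined output by this factor, so when a shared expert rides along as an extra routed expert, the selection code presumably needs the factor to keep that expert's effective weight at 1.0. A toy illustration of that compensation, under that assumption rather than as the actual sglang kernel logic:

import torch

routed_scaling_factor = 2.5  # DeepSeek V3/R1 value from the model config

# Toy top-k weights: two routed experts plus one fused shared expert whose
# weight is pre-divided by the scaling factor.
routed_weights = torch.tensor([0.6, 0.4])
fused_shared_weight = torch.tensor([1.0 / routed_scaling_factor])
topk_weights = torch.cat([routed_weights, fused_shared_weight])

# The MoE layer later multiplies every expert's contribution by the factor,
# which restores the fused shared expert's effective weight to exactly 1.0.
print(topk_weights * routed_scaling_factor)  # tensor([1.5000, 1.0000, 1.0000])
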
@@ -374,7 +372,7 @@ class DeepseekV2MoE(nn.Module):
         return final_hidden_states
 
     def _forward_shared_experts(self, hidden_states):
-        if self.n_shared_experts is not None and self.n_share_experts_fusion == 0:
+        if self.n_share_experts_fusion == 0:
             return self.shared_experts(hidden_states)
         else:
             return None
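
Dropping the `n_shared_experts is not None` clause is safe because fusion is only ever configured for models that have shared experts, so the single integer now decides the whole branch: run the standalone shared-experts MLP when fusion is off, skip it when the shared experts are folded into the routed-expert weights. A stand-alone sketch of that dispatch, illustrative rather than the module itself:

from typing import Callable, Optional


def forward_shared_experts(
    hidden_states, n_share_experts_fusion: int, shared_experts: Callable
) -> Optional[object]:
    # Fusion disabled: the separate shared-experts MLP still runs.
    if n_share_experts_fusion == 0:
        return shared_experts(hidden_states)
    # Fusion enabled: the shared experts already live inside the fused
    # routed-expert weights, so there is nothing extra to compute here.
    return None


mlp = lambda x: x  # stand-in for the shared-experts MLP
print(forward_shared_experts("h", 0, mlp))  # 'h'  -> shared experts ran
print(forward_shared_experts("h", 8, mlp))  # None -> handled by fused experts
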
@@ -1346,24 +1344,21 @@ class DeepseekV2ForCausalLM(nn.Module):
         self.tp_size = get_tensor_model_parallel_world_size()
         self.quant_config = quant_config
         self.n_share_experts_fusion = global_server_args_dict["n_share_experts_fusion"]
-        # Only Deepseek V3/R1 can use shared experts fusion optimization now.
-        if (
-            global_server_args_dict.get("disable_shared_experts_fusion", False)
-            or self.config.architectures[0] != "DeepseekV3ForCausalLM"
-            or self.config.n_routed_experts != 256
-            or self.config.routed_scaling_factor != 2.5
-        ):
-            self.n_share_experts_fusion = None
-            global_server_args_dict["n_share_experts_fusion"] = None
-            logger.info(
-                "Only Deepseek V3/R1 can use shared experts fusion optimization. Shared experts fusion optimization is disabled."
-            )
-        elif self.n_share_experts_fusion is None:
-            global_server_args_dict["n_share_experts_fusion"] = self.tp_size
-            self.n_share_experts_fusion = self.tp_size
-            logger.info(
-                f"Shared experts fusion optimization is default enabled in DeepSeek V3/R1, and n_share_experts_fusion is set to {self.tp_size}. You can tune it by setting --n_share_experts_fusion or disable it by setting --disable_shared_experts_fusion."
-            )
+        if self.n_share_experts_fusion > 0:
+            # Only Deepseek V3/R1 can use shared experts fusion optimization now.
+            if (
+                self.config.architectures[0] != "DeepseekV3ForCausalLM"
+                or self.config.n_routed_experts != 256
+            ):
+                self.n_share_experts_fusion = 0
+                global_server_args_dict["n_share_experts_fusion"] = 0
+                logger.info(
+                    "Only Deepseek V3/R1 can use shared experts fusion optimization. Shared experts fusion optimization is disabled."
+                )
+            else:
+                assert (
+                    self.n_share_experts_fusion == self.tp_size
+                ), f"Shared experts fusion optimization is enabled in DeepSeek V3/R1; set it to {self.tp_size} to get the best performance."
 
         self.model = DeepseekV2Model(
             config, quant_config, prefix=add_prefix("model", prefix)
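
The rewritten gate inverts the old logic: instead of starting from a possibly-None value and backfilling a default, the value arrives as an integer and is only validated when it is positive; unsupported models fall back to 0, supported ones must request exactly one fused shared-expert copy per tensor-parallel rank. A hedged sketch of that decision flow with assumed example values, not the sglang code:

tp_size = 8
architectures = ["DeepseekV3ForCausalLM"]
n_routed_experts = 256
n_share_experts_fusion = 8  # integer handed over by the server args

if n_share_experts_fusion > 0:
    # Fusion was requested: drop it for unsupported models...
    if architectures[0] != "DeepseekV3ForCausalLM" or n_routed_experts != 256:
        n_share_experts_fusion = 0
    # ...otherwise require one fused shared-expert copy per TP rank.
    else:
        assert n_share_experts_fusion == tp_size

print(n_share_experts_fusion)  # 8
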
@@ -1484,7 +1479,7 @@ class DeepseekV2ForCausalLM(nn.Module):
             ("gate_up_proj", "gate_proj", 0),
             ("gate_up_proj", "up_proj", 1),
         ]
-        if self.n_share_experts_fusion is not None and self.n_share_experts_fusion > 0:
+        if self.n_share_experts_fusion > 0:
             weights_list = list(weights)
             weights_dict = dict(weights_list)
             if self.quant_config.get_name() == "w8a8_int8":
@@ -1543,12 +1538,7 @@ class DeepseekV2ForCausalLM(nn.Module):
             ckpt_gate_proj_name="gate_proj",
             ckpt_down_proj_name="down_proj",
             ckpt_up_proj_name="up_proj",
-            num_experts=self.config.n_routed_experts
-            + (
-                self.n_share_experts_fusion
-                if self.n_share_experts_fusion is not None
-                else 0
-            ),
+            num_experts=self.config.n_routed_experts + self.n_share_experts_fusion,
         )
 
         params_dict = dict(self.named_parameters())
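
With `n_share_experts_fusion` guaranteed to be an integer, the expert count seen by the weight-loading mapping collapses to a single addition. Worked numbers, assuming a DeepSeek V3/R1 checkpoint (256 routed experts) on an 8-way tensor-parallel deployment:

n_routed_experts = 256

n_share_experts_fusion = 8   # fusion enabled, one copy per TP rank
print(n_routed_experts + n_share_experts_fusion)  # 264 expert slots

n_share_experts_fusion = 0   # fusion disabled
print(n_routed_experts + n_share_experts_fusion)  # 256 expert slots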