Simplify the control logic for using shared experts fusion (#5504)
This commit is contained in:
@@ -189,11 +189,7 @@ class DeepseekV2MoE(nn.Module):
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
self.routed_scaling_factor = config.routed_scaling_factor
|
||||
self.n_shared_experts = config.n_shared_experts
|
||||
self.n_share_experts_fusion = (
|
||||
global_server_args_dict["n_share_experts_fusion"]
|
||||
if global_server_args_dict["n_share_experts_fusion"] is not None
|
||||
else 0
|
||||
)
|
||||
self.n_share_experts_fusion = global_server_args_dict["n_share_experts_fusion"]
|
||||
|
||||
if self.tp_size > config.n_routed_experts:
|
||||
raise ValueError(
|
||||
@@ -226,6 +222,7 @@ class DeepseekV2MoE(nn.Module):
|
||||
num_expert_group=config.n_group,
|
||||
topk_group=config.topk_group,
|
||||
correction_bias=self.gate.e_score_correction_bias,
|
||||
routed_scaling_factor=self.routed_scaling_factor,
|
||||
prefix=add_prefix("experts", prefix),
|
||||
**(
|
||||
dict(deepep_mode=DeepEPMode[global_server_args_dict["deepep_mode"]])
|
||||
@@ -334,6 +331,7 @@ class DeepseekV2MoE(nn.Module):
|
||||
topk_group=self.topk_group,
|
||||
num_expert_group=self.num_expert_group,
|
||||
correction_bias=self.correction_bias,
|
||||
routed_scaling_factor=self.routed_scaling_factor,
|
||||
)
|
||||
if self.ep_size > 1:
|
||||
# TODO(ch-wan): allow users to set num_max_dispatch_tokens_per_rank value
|
||||
@@ -374,7 +372,7 @@ class DeepseekV2MoE(nn.Module):
|
||||
return final_hidden_states
|
||||
|
||||
def _forward_shared_experts(self, hidden_states):
|
||||
if self.n_shared_experts is not None and self.n_share_experts_fusion == 0:
|
||||
if self.n_share_experts_fusion == 0:
|
||||
return self.shared_experts(hidden_states)
|
||||
else:
|
||||
return None
|
||||
@@ -1346,24 +1344,21 @@ class DeepseekV2ForCausalLM(nn.Module):
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
self.quant_config = quant_config
|
||||
self.n_share_experts_fusion = global_server_args_dict["n_share_experts_fusion"]
|
||||
# Only Deepseek V3/R1 can use shared experts fusion optimization now.
|
||||
if (
|
||||
global_server_args_dict.get("disable_shared_experts_fusion", False)
|
||||
or self.config.architectures[0] != "DeepseekV3ForCausalLM"
|
||||
or self.config.n_routed_experts != 256
|
||||
or self.config.routed_scaling_factor != 2.5
|
||||
):
|
||||
self.n_share_experts_fusion = None
|
||||
global_server_args_dict["n_share_experts_fusion"] = None
|
||||
logger.info(
|
||||
"Only Deepseek V3/R1 can use shared experts fusion optimization. Shared experts fusion optimization is disabled."
|
||||
)
|
||||
elif self.n_share_experts_fusion is None:
|
||||
global_server_args_dict["n_share_experts_fusion"] = self.tp_size
|
||||
self.n_share_experts_fusion = self.tp_size
|
||||
logger.info(
|
||||
f"Shared experts fusion optimization is default enabled in DeepSeek V3/R1, and n_share_experts_fusion is set to {self.tp_size}. You can tune it by setting --n_share_experts_fusion or disable it by setting --disable_shared_experts_fusion."
|
||||
)
|
||||
if self.n_share_experts_fusion > 0:
|
||||
# Only Deepseek V3/R1 can use shared experts fusion optimization now.
|
||||
if (
|
||||
self.config.architectures[0] != "DeepseekV3ForCausalLM"
|
||||
or self.config.n_routed_experts != 256
|
||||
):
|
||||
self.n_share_experts_fusion = 0
|
||||
global_server_args_dict["n_share_experts_fusion"] = 0
|
||||
logger.info(
|
||||
"Only Deepseek V3/R1 can use shared experts fusion optimization. Shared experts fusion optimization is disabled."
|
||||
)
|
||||
else:
|
||||
assert (
|
||||
self.n_share_experts_fusion == self.tp_size
|
||||
), f"Shared experts fusion optimization is enabled in DeepSeek V3/R1, set it to {self.tp_size} can get best optimized performace."
|
||||
|
||||
self.model = DeepseekV2Model(
|
||||
config, quant_config, prefix=add_prefix("model", prefix)
|
||||
@@ -1484,7 +1479,7 @@ class DeepseekV2ForCausalLM(nn.Module):
|
||||
("gate_up_proj", "gate_proj", 0),
|
||||
("gate_up_proj", "up_proj", 1),
|
||||
]
|
||||
if self.n_share_experts_fusion is not None and self.n_share_experts_fusion > 0:
|
||||
if self.n_share_experts_fusion > 0:
|
||||
weights_list = list(weights)
|
||||
weights_dict = dict(weights_list)
|
||||
if self.quant_config.get_name() == "w8a8_int8":
|
||||
@@ -1543,12 +1538,7 @@ class DeepseekV2ForCausalLM(nn.Module):
|
||||
ckpt_gate_proj_name="gate_proj",
|
||||
ckpt_down_proj_name="down_proj",
|
||||
ckpt_up_proj_name="up_proj",
|
||||
num_experts=self.config.n_routed_experts
|
||||
+ (
|
||||
self.n_share_experts_fusion
|
||||
if self.n_share_experts_fusion is not None
|
||||
else 0
|
||||
),
|
||||
num_experts=self.config.n_routed_experts + self.n_share_experts_fusion,
|
||||
)
|
||||
|
||||
params_dict = dict(self.named_parameters())
|
||||
|
||||
Reference in New Issue
Block a user