Set num_fused_shared_experts as num_shared_experts when shared_experts fusion is not disabled (#6736)

This commit is contained in:
Cheng Wan
2025-06-04 15:53:22 -07:00
committed by GitHub
parent f0f84975f4
commit 81964328b7
22 changed files with 381 additions and 45 deletions

View File

@@ -224,9 +224,11 @@ class DeepseekV2MoE(nn.Module):
self.tp_size = get_tensor_model_parallel_world_size()
self.routed_scaling_factor = config.routed_scaling_factor
self.n_shared_experts = config.n_shared_experts
self.num_fused_shared_experts = global_server_args_dict[
"num_fused_shared_experts"
]
self.num_fused_shared_experts = (
0
if global_server_args_dict["disable_shared_experts_fusion"]
else config.n_shared_experts
)
self.config = config
self.layer_id = layer_id
@@ -248,7 +250,7 @@ class DeepseekV2MoE(nn.Module):
num_experts=config.n_routed_experts
+ self.num_fused_shared_experts
+ global_server_args_dict["ep_num_redundant_experts"],
top_k=config.num_experts_per_tok + min(self.num_fused_shared_experts, 1),
top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size,
layer_id=self.layer_id,
@@ -256,6 +258,7 @@ class DeepseekV2MoE(nn.Module):
quant_config=quant_config,
use_grouped_topk=True,
num_expert_group=config.n_group,
num_fused_shared_experts=self.num_fused_shared_experts,
topk_group=config.topk_group,
correction_bias=self.gate.e_score_correction_bias,
routed_scaling_factor=self.routed_scaling_factor,
@@ -363,6 +366,7 @@ class DeepseekV2MoE(nn.Module):
renormalize=self.renormalize,
topk_group=self.topk_group,
num_expert_group=self.num_expert_group,
num_fused_shared_experts=self.num_fused_shared_experts,
correction_bias=self.correction_bias,
routed_scaling_factor=self.routed_scaling_factor,
num_token_non_padded=forward_batch.num_token_non_padded,
@@ -456,6 +460,7 @@ class DeepseekV2MoE(nn.Module):
renormalize=self.renormalize,
topk_group=self.topk_group,
num_expert_group=self.num_expert_group,
num_fused_shared_experts=self.num_fused_shared_experts,
correction_bias=self.correction_bias,
routed_scaling_factor=self.routed_scaling_factor,
num_token_non_padded=state.forward_batch.num_token_non_padded,
@@ -1679,9 +1684,11 @@ class DeepseekV2ForCausalLM(nn.Module):
def determine_num_fused_shared_experts(
self, architecture: str = "DeepseekV3ForCausalLM"
):
self.num_fused_shared_experts = global_server_args_dict[
"num_fused_shared_experts"
]
self.num_fused_shared_experts = (
0
if global_server_args_dict["disable_shared_experts_fusion"]
else self.config.n_shared_experts
)
if self.num_fused_shared_experts > 0:
# Only Deepseek V3/R1 can use shared experts fusion optimization now.
if (
@@ -1690,15 +1697,11 @@ class DeepseekV2ForCausalLM(nn.Module):
or self.config.n_routed_experts != 256
):
self.num_fused_shared_experts = 0
global_server_args_dict["num_fused_shared_experts"] = 0
global_server_args_dict["disable_shared_experts_fusion"] = 1
log_info_on_rank0(
logger,
"Only Deepseek V3/R1 on NV-platform can use shared experts fusion optimization. Shared experts fusion optimization is disabled.",
)
else:
assert (
self.num_fused_shared_experts == self.tp_size
), f"Shared experts fusion optimization is enabled in DeepSeek V3/R1, set it to {self.tp_size} can get best optimized performance."
elif self.num_fused_shared_experts == 0:
if (
_is_cuda
@@ -1707,8 +1710,8 @@ class DeepseekV2ForCausalLM(nn.Module):
and self.config.n_routed_experts == 256
and (not global_server_args_dict["enable_deepep_moe"])
):
self.num_fused_shared_experts = self.tp_size
global_server_args_dict["num_fused_shared_experts"] = self.tp_size
self.num_fused_shared_experts = self.config.n_shared_experts
global_server_args_dict["disable_shared_experts_fusion"] = 0
log_info_on_rank0(
logger,
"Deepseek V3/R1 with fp8 can use shared experts fusion optimization when SM version >=90. Shared experts fusion optimization is enabled.",
@@ -1910,6 +1913,7 @@ class DeepseekV2ForCausalLM(nn.Module):
("gate_up_proj", "up_proj", 1),
]
if self.num_fused_shared_experts > 0:
assert self.num_fused_shared_experts == 1
weights_list = list(weights)
weights_dict = dict(weights_list)
if self.quant_config is not None:
@@ -1971,22 +1975,21 @@ class DeepseekV2ForCausalLM(nn.Module):
for moe_layer in tqdm(
moe_layers,
desc=f"Cloning {self.num_fused_shared_experts} "
"replicas of the shared expert into MoE",
"shared expert into MoE",
):
for suffix in suffix_list:
shared_expert_weight_name = (
f"model.layers.{moe_layer}.mlp.shared_experts.{suffix}"
)
for num_repeat in range(self.num_fused_shared_experts):
weights_list.append(
(
f"model.layers.{moe_layer}."
f"mlp.experts."
f"{self.config.n_routed_experts + num_repeat}"
f".{suffix}",
weights_dict[shared_expert_weight_name],
)
weights_list.append(
(
f"model.layers.{moe_layer}."
f"mlp.experts."
f"{self.config.n_routed_experts + 0}"
f".{suffix}",
weights_dict[shared_expert_weight_name],
)
)
names_to_remove += [shared_expert_weight_name]
weights = [w for w in weights_list if w[0] not in names_to_remove]