Set num_fused_shared_experts as num_shared_experts when shared_experts fusion is not disabled (#6736)
@@ -224,9 +224,11 @@ class DeepseekV2MoE(nn.Module):
         self.tp_size = get_tensor_model_parallel_world_size()
         self.routed_scaling_factor = config.routed_scaling_factor
         self.n_shared_experts = config.n_shared_experts
-        self.num_fused_shared_experts = global_server_args_dict[
-            "num_fused_shared_experts"
-        ]
+        self.num_fused_shared_experts = (
+            0
+            if global_server_args_dict["disable_shared_experts_fusion"]
+            else config.n_shared_experts
+        )
         self.config = config
         self.layer_id = layer_id
 
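The new initialization boils down to a single conditional on the disable flag. A minimal standalone sketch of that selection logic, assuming a plain dict in place of global_server_args_dict and a namespace in place of the HF config (DeepSeek V3/R1 ships n_shared_experts = 1):

from types import SimpleNamespace

def select_num_fused_shared_experts(server_args: dict, config) -> int:
    # 0 when fusion is disabled, otherwise fuse every shared expert
    # into the routed-expert pool (1 for DeepSeek V3/R1).
    if server_args["disable_shared_experts_fusion"]:
        return 0
    return config.n_shared_experts

cfg = SimpleNamespace(n_shared_experts=1)  # assumed DeepSeek V3/R1 value
print(select_num_fused_shared_experts({"disable_shared_experts_fusion": False}, cfg))  # 1
print(select_num_fused_shared_experts({"disable_shared_experts_fusion": True}, cfg))   # 0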
@@ -248,7 +250,7 @@ class DeepseekV2MoE(nn.Module):
             num_experts=config.n_routed_experts
             + self.num_fused_shared_experts
             + global_server_args_dict["ep_num_redundant_experts"],
-            top_k=config.num_experts_per_tok + min(self.num_fused_shared_experts, 1),
+            top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
             hidden_size=config.hidden_size,
             intermediate_size=config.moe_intermediate_size,
             layer_id=self.layer_id,
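A rough sizing check for these arguments, using assumed DeepSeek V3/R1 numbers (256 routed experts, 8 experts per token, 1 shared expert, no redundant EP experts):

# Assumed DeepSeek V3/R1 values; not read from a real config.
n_routed_experts = 256
num_experts_per_tok = 8
num_fused_shared_experts = 1      # == n_shared_experts when fusion is enabled
ep_num_redundant_experts = 0

num_experts = n_routed_experts + num_fused_shared_experts + ep_num_redundant_experts
top_k = num_experts_per_tok + num_fused_shared_experts
print(num_experts, top_k)  # 257 9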
@@ -256,6 +258,7 @@ class DeepseekV2MoE(nn.Module):
             quant_config=quant_config,
             use_grouped_topk=True,
             num_expert_group=config.n_group,
+            num_fused_shared_experts=self.num_fused_shared_experts,
             topk_group=config.topk_group,
             correction_bias=self.gate.e_score_correction_bias,
             routed_scaling_factor=self.routed_scaling_factor,
@@ -363,6 +366,7 @@ class DeepseekV2MoE(nn.Module):
             renormalize=self.renormalize,
             topk_group=self.topk_group,
             num_expert_group=self.num_expert_group,
+            num_fused_shared_experts=self.num_fused_shared_experts,
             correction_bias=self.correction_bias,
             routed_scaling_factor=self.routed_scaling_factor,
             num_token_non_padded=forward_batch.num_token_non_padded,
@@ -456,6 +460,7 @@ class DeepseekV2MoE(nn.Module):
             renormalize=self.renormalize,
             topk_group=self.topk_group,
             num_expert_group=self.num_expert_group,
+            num_fused_shared_experts=self.num_fused_shared_experts,
             correction_bias=self.correction_bias,
             routed_scaling_factor=self.routed_scaling_factor,
             num_token_non_padded=state.forward_batch.num_token_non_padded,
@@ -1679,9 +1684,11 @@ class DeepseekV2ForCausalLM(nn.Module):
     def determine_num_fused_shared_experts(
         self, architecture: str = "DeepseekV3ForCausalLM"
     ):
-        self.num_fused_shared_experts = global_server_args_dict[
-            "num_fused_shared_experts"
-        ]
+        self.num_fused_shared_experts = (
+            0
+            if global_server_args_dict["disable_shared_experts_fusion"]
+            else self.config.n_shared_experts
+        )
         if self.num_fused_shared_experts > 0:
             # Only Deepseek V3/R1 can use shared experts fusion optimization now.
             if (
@@ -1690,15 +1697,11 @@ class DeepseekV2ForCausalLM(nn.Module):
                 or self.config.n_routed_experts != 256
             ):
                 self.num_fused_shared_experts = 0
-                global_server_args_dict["num_fused_shared_experts"] = 0
+                global_server_args_dict["disable_shared_experts_fusion"] = 1
                 log_info_on_rank0(
                     logger,
                     "Only Deepseek V3/R1 on NV-platform can use shared experts fusion optimization. Shared experts fusion optimization is disabled.",
                 )
-            else:
-                assert (
-                    self.num_fused_shared_experts == self.tp_size
-                ), f"Shared experts fusion optimization is enabled in DeepSeek V3/R1, set it to {self.tp_size} can get best optimized performance."
         elif self.num_fused_shared_experts == 0:
             if (
                 _is_cuda
@@ -1707,8 +1710,8 @@ class DeepseekV2ForCausalLM(nn.Module):
                 and self.config.n_routed_experts == 256
                 and (not global_server_args_dict["enable_deepep_moe"])
             ):
-                self.num_fused_shared_experts = self.tp_size
-                global_server_args_dict["num_fused_shared_experts"] = self.tp_size
+                self.num_fused_shared_experts = self.config.n_shared_experts
+                global_server_args_dict["disable_shared_experts_fusion"] = 0
                 log_info_on_rank0(
                     logger,
                     "Deepseek V3/R1 with fp8 can use shared experts fusion optimization when SM version >=90. Shared experts fusion optimization is enabled.",
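Read together with the log messages, the visible parts of these two branches suggest the fusion decision now reduces to: keep the single fused shared expert only for DeepSeek V3/R1 (256 routed experts) on an NV platform, otherwise fall back to 0. A hedged paraphrase (the surrounding fp8/DeepEP conditions are not fully shown in this diff, so this helper is an approximation, not the exact sglang control flow):

def resolve_num_fused_shared_experts(
    disable_flag: bool,
    is_cuda: bool,
    architecture: str,
    n_routed_experts: int,
    n_shared_experts: int,
) -> int:
    # Approximation of determine_num_fused_shared_experts after this change.
    if disable_flag:
        return 0
    if not is_cuda or architecture != "DeepseekV3ForCausalLM" or n_routed_experts != 256:
        return 0  # "Only Deepseek V3/R1 on NV-platform can use shared experts fusion"
    return n_shared_experts  # 1 for DeepSeek V3/R1

print(resolve_num_fused_shared_experts(False, True, "DeepseekV3ForCausalLM", 256, 1))   # 1
print(resolve_num_fused_shared_experts(False, False, "DeepseekV3ForCausalLM", 256, 1))  # 0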
@@ -1910,6 +1913,7 @@ class DeepseekV2ForCausalLM(nn.Module):
             ("gate_up_proj", "up_proj", 1),
         ]
         if self.num_fused_shared_experts > 0:
+            assert self.num_fused_shared_experts == 1
             weights_list = list(weights)
             weights_dict = dict(weights_list)
             if self.quant_config is not None:
@@ -1971,22 +1975,21 @@ class DeepseekV2ForCausalLM(nn.Module):
             for moe_layer in tqdm(
                 moe_layers,
                 desc=f"Cloning {self.num_fused_shared_experts} "
-                "replicas of the shared expert into MoE",
+                "shared expert into MoE",
             ):
                 for suffix in suffix_list:
                     shared_expert_weight_name = (
                         f"model.layers.{moe_layer}.mlp.shared_experts.{suffix}"
                     )
-                    for num_repeat in range(self.num_fused_shared_experts):
-                        weights_list.append(
-                            (
-                                f"model.layers.{moe_layer}."
-                                f"mlp.experts."
-                                f"{self.config.n_routed_experts + num_repeat}"
-                                f".{suffix}",
-                                weights_dict[shared_expert_weight_name],
-                            )
+                    weights_list.append(
+                        (
+                            f"model.layers.{moe_layer}."
+                            f"mlp.experts."
+                            f"{self.config.n_routed_experts + 0}"
+                            f".{suffix}",
+                            weights_dict[shared_expert_weight_name],
                         )
+                    )
                     names_to_remove += [shared_expert_weight_name]
             weights = [w for w in weights_list if w[0] not in names_to_remove]
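With a single fused shared expert, the loader simply re-registers each shared-expert tensor under routed-expert index n_routed_experts (256 for DeepSeek V3/R1). A small sketch of that renaming, with an illustrative suffix list (the real suffix_list also carries quantization-specific entries):

n_routed_experts = 256  # assumed DeepSeek V3/R1 value
suffix_list = ["gate_proj.weight", "up_proj.weight", "down_proj.weight"]  # illustrative only

def fused_shared_expert_renames(moe_layer: int):
    # Maps each shared-expert weight name onto the extra expert slot appended
    # after the routed experts, mirroring the loop above.
    for suffix in suffix_list:
        shared = f"model.layers.{moe_layer}.mlp.shared_experts.{suffix}"
        fused = f"model.layers.{moe_layer}.mlp.experts.{n_routed_experts + 0}.{suffix}"
        yield shared, fused

for shared, fused in fused_shared_expert_renames(moe_layer=3):
    print(shared, "->", fused)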