[BugFix] Fix combination of MTP and --n-share-experts-fusion with R1 (#5707)

Yuhong Guo
2025-04-24 21:13:51 +08:00
committed by GitHub
parent c998d04b46
commit 5d93a950ee
2 changed files with 68 additions and 15 deletions


@@ -1440,11 +1440,27 @@ class DeepseekV2ForCausalLM(nn.Module):
         self.config = config
         self.tp_size = get_tensor_model_parallel_world_size()
         self.quant_config = quant_config
+        self.determine_n_share_experts_fusion()
+        self.model = DeepseekV2Model(
+            config, quant_config, prefix=add_prefix("model", prefix)
+        )
+        self.lm_head = ParallelLMHead(
+            config.vocab_size,
+            config.hidden_size,
+            quant_config=quant_config,
+            prefix=add_prefix("lm_head", prefix),
+        )
+        self.logits_processor = LogitsProcessor(config)
+        self.dp_size = get_attention_dp_size()
+
+    def determine_n_share_experts_fusion(
+        self, architecture: str = "DeepseekV3ForCausalLM"
+    ):
         self.n_share_experts_fusion = global_server_args_dict["n_share_experts_fusion"]
         if self.n_share_experts_fusion > 0:
             # Only Deepseek V3/R1 can use shared experts fusion optimization now.
             if (
-                self.config.architectures[0] != "DeepseekV3ForCausalLM"
+                self.config.architectures[0] != architecture
                 or self.config.n_routed_experts != 256
             ):
                 self.n_share_experts_fusion = 0
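Note: the second changed file in this commit (the MTP/NextN draft module) is not shown in this view. A hedged sketch of how its constructor can now reuse the extracted helper — the class name and architecture string below are assumptions, not a quote from that file:

# Hypothetical call site on the draft-model side. The MTP draft passes
# its own architecture string to the shared helper, so it reaches the
# same fusion decision as the target R1 model instead of failing the
# previously hard-coded "DeepseekV3ForCausalLM" check.
class DeepseekV3ForCausalLMNextN(DeepseekV2ForCausalLM):
    def __init__(self, config, quant_config=None, prefix=""):
        nn.Module.__init__(self)
        self.config = config
        self.tp_size = get_tensor_model_parallel_world_size()
        self.quant_config = quant_config
        self.determine_n_share_experts_fusion(
            architecture="DeepseekV3ForCausalLMNextN"
        )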
@@ -1459,7 +1475,7 @@ class DeepseekV2ForCausalLM(nn.Module):
         elif self.n_share_experts_fusion == 0:
             if (
                 torch.cuda.get_device_capability("cuda") >= (9, 0)
-                and self.config.architectures[0] == "DeepseekV3ForCausalLM"
+                and self.config.architectures[0] == architecture
                 and self.config.n_routed_experts == 256
                 and (not global_server_args_dict["enable_deepep_moe"])
             ):
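The auto-enable path gates on SM version: torch.cuda.get_device_capability returns a (major, minor) compute-capability tuple, so the >= (9, 0) comparison is true on Hopper (SM90) and newer. A minimal standalone sketch of the same check (helper name is ours, for illustration only):

import torch

def supports_fp8_shared_experts_fusion() -> bool:
    # (major, minor) compute capability of the current CUDA device;
    # (9, 0) corresponds to Hopper (SM90), the minimum for this fp8 path.
    return (
        torch.cuda.is_available()
        and torch.cuda.get_device_capability("cuda") >= (9, 0)
    )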
@@ -1469,18 +1485,6 @@ class DeepseekV2ForCausalLM(nn.Module):
                     "Deepseek V3/R1 with fp8 can use shared experts fusion optimization when SM version >=90. Shared experts fusion optimization is enabled."
                 )
 
-        self.model = DeepseekV2Model(
-            config, quant_config, prefix=add_prefix("model", prefix)
-        )
-        self.lm_head = ParallelLMHead(
-            config.vocab_size,
-            config.hidden_size,
-            quant_config=quant_config,
-            prefix=add_prefix("lm_head", prefix),
-        )
-        self.logits_processor = LogitsProcessor(config)
-        self.dp_size = get_attention_dp_size()
-
     def get_input_embeddings(self) -> nn.Embedding:
         return self.model.embed_tokens
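Why the architecture parameter matters: before this fix, the fusion check compared config.architectures[0] against the hard-coded "DeepseekV3ForCausalLM", so the MTP draft head (whose architecture string differs) silently disabled fusion while the target model enabled it, and the two disagreed. A self-contained toy reproduction of the guard logic — the function and architecture names here are ours, for illustration, not SGLang API:

def fusion_applicable(architectures, n_routed_experts,
                      architecture="DeepseekV3ForCausalLM"):
    # Mirrors the guard in the diff above: only the named architecture
    # with 256 routed experts (DeepSeek V3/R1) may fuse shared experts.
    return architectures[0] == architecture and n_routed_experts == 256

# Target R1 model: matches the default architecture name.
print(fusion_applicable(["DeepseekV3ForCausalLM"], 256))             # True

# MTP draft head: rejected by the old hard-coded check...
print(fusion_applicable(["DeepseekV3ForCausalLMNextN"], 256))        # False

# ...but accepted once it supplies its own architecture string (the fix).
print(fusion_applicable(["DeepseekV3ForCausalLMNextN"], 256,
                        architecture="DeepseekV3ForCausalLMNextN"))  # True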