[BugFix] Fix combination of MTP and --n-share-experts-fusion with R1 (#5707)
@@ -1440,11 +1440,27 @@ class DeepseekV2ForCausalLM(nn.Module):
         self.config = config
         self.tp_size = get_tensor_model_parallel_world_size()
         self.quant_config = quant_config
+        self.determine_n_share_experts_fusion()
+        self.model = DeepseekV2Model(
+            config, quant_config, prefix=add_prefix("model", prefix)
+        )
+        self.lm_head = ParallelLMHead(
+            config.vocab_size,
+            config.hidden_size,
+            quant_config=quant_config,
+            prefix=add_prefix("lm_head", prefix),
+        )
+        self.logits_processor = LogitsProcessor(config)
+        self.dp_size = get_attention_dp_size()
+
+    def determine_n_share_experts_fusion(
+        self, architecture: str = "DeepseekV3ForCausalLM"
+    ):
         self.n_share_experts_fusion = global_server_args_dict["n_share_experts_fusion"]
         if self.n_share_experts_fusion > 0:
             # Only Deepseek V3/R1 can use shared experts fusion optimization now.
             if (
-                self.config.architectures[0] != "DeepseekV3ForCausalLM"
+                self.config.architectures[0] != architecture
                 or self.config.n_routed_experts != 256
             ):
                 self.n_share_experts_fusion = 0
@@ -1459,7 +1475,7 @@ class DeepseekV2ForCausalLM(nn.Module):
         elif self.n_share_experts_fusion == 0:
             if (
                 torch.cuda.get_device_capability("cuda") >= (9, 0)
-                and self.config.architectures[0] == "DeepseekV3ForCausalLM"
+                and self.config.architectures[0] == architecture
                 and self.config.n_routed_experts == 256
                 and (not global_server_args_dict["enable_deepep_moe"])
             ):
@@ -1469,18 +1485,6 @@ class DeepseekV2ForCausalLM(nn.Module):
                 "Deepseek V3/R1 with fp8 can use shared experts fusion optimization when SM version >=90. Shared experts fusion optimization is enabled."
             )
 
-        self.model = DeepseekV2Model(
-            config, quant_config, prefix=add_prefix("model", prefix)
-        )
-        self.lm_head = ParallelLMHead(
-            config.vocab_size,
-            config.hidden_size,
-            quant_config=quant_config,
-            prefix=add_prefix("lm_head", prefix),
-        )
-        self.logits_processor = LogitsProcessor(config)
-        self.dp_size = get_attention_dp_size()
-
     def get_input_embeddings(self) -> nn.Embedding:
         return self.model.embed_tokens
 
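Why the new `architecture` parameter matters: with MTP speculative decoding, the draft model is loaded under a different architecture name (presumably the NextN draft variant), so the previously hard-coded `!= "DeepseekV3ForCausalLM"` check always disabled shared-experts fusion for the draft model even when the target R1 model had it enabled. Extracting the logic into `determine_n_share_experts_fusion(architecture=...)` lets each model run the same check against its own name. Below is a minimal, self-contained sketch of that dispatch pattern; `DummyConfig` and the model classes are illustrative stand-ins, not sglang's types.

# Standalone sketch of the pattern this commit introduces; names here are
# illustrative assumptions, not the actual sglang API.
from dataclasses import dataclass, field


@dataclass
class DummyConfig:
    architectures: list = field(default_factory=lambda: ["DeepseekV3ForCausalLM"])
    n_routed_experts: int = 256


class TargetModel:
    ARCHITECTURE = "DeepseekV3ForCausalLM"

    def __init__(self, config: DummyConfig, n_share_experts_fusion: int):
        self.config = config
        self.n_share_experts_fusion = n_share_experts_fusion
        # The check is parameterized, so subclasses can match their own name.
        self.determine_n_share_experts_fusion(self.ARCHITECTURE)

    def determine_n_share_experts_fusion(
        self, architecture: str = "DeepseekV3ForCausalLM"
    ):
        # Fusion is only kept for a V3/R1-shaped checkpoint whose declared
        # architecture matches what the caller expects; otherwise fall back.
        if self.n_share_experts_fusion > 0 and (
            self.config.architectures[0] != architecture
            or self.config.n_routed_experts != 256
        ):
            self.n_share_experts_fusion = 0


class DraftModel(TargetModel):
    # MTP draft model: same config shape, different architecture name.
    ARCHITECTURE = "DeepseekV3ForCausalLMNextN"


if __name__ == "__main__":
    # Before this fix, the draft model compared its NextN name against the
    # hard-coded "DeepseekV3ForCausalLM" and always disabled fusion.
    draft_cfg = DummyConfig(architectures=["DeepseekV3ForCausalLMNextN"])
    draft = DraftModel(draft_cfg, n_share_experts_fusion=8)
    print(draft.n_share_experts_fusion)  # -> 8 (fusion stays enabled)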