[fix] fix determine_num_fused_shared_experts (#7180)
This commit is contained in:
@@ -1709,53 +1709,35 @@ class DeepseekV2ForCausalLM(nn.Module):
|
|||||||
def determine_num_fused_shared_experts(
|
def determine_num_fused_shared_experts(
|
||||||
self, architecture: str = "DeepseekV3ForCausalLM"
|
self, architecture: str = "DeepseekV3ForCausalLM"
|
||||||
):
|
):
|
||||||
self.num_fused_shared_experts = (
|
self.num_fused_shared_experts = 0
|
||||||
0
|
if global_server_args_dict["disable_shared_experts_fusion"]:
|
||||||
if global_server_args_dict["disable_shared_experts_fusion"]
|
return
|
||||||
else self.config.n_shared_experts
|
|
||||||
)
|
|
||||||
if self.num_fused_shared_experts > 0:
|
|
||||||
# Only Deepseek V3/R1 can use shared experts fusion optimization now.
|
# Only Deepseek V3/R1 can use shared experts fusion optimization now.
|
||||||
|
disable_reason = None
|
||||||
if (
|
if (
|
||||||
not _is_cuda
|
not _is_cuda
|
||||||
|
or torch.cuda.get_device_capability("cuda") < (9, 0)
|
||||||
or self.config.architectures[0] != architecture
|
or self.config.architectures[0] != architecture
|
||||||
or self.config.n_routed_experts != 256
|
or self.config.n_routed_experts != 256
|
||||||
|
or self.config.n_shared_experts != 1
|
||||||
):
|
):
|
||||||
self.num_fused_shared_experts = 0
|
disable_reason = "Only Deepseek V3/R1 on NV-platform with capability >= 90 can use shared experts fusion optimization."
|
||||||
global_server_args_dict["disable_shared_experts_fusion"] = True
|
|
||||||
log_info_on_rank0(
|
|
||||||
logger,
|
|
||||||
"Only Deepseek V3/R1 on NV-platform can use shared experts fusion optimization. Shared experts fusion optimization is disabled.",
|
|
||||||
)
|
|
||||||
elif (
|
elif (
|
||||||
global_server_args_dict["enable_deepep_moe"]
|
global_server_args_dict["enable_deepep_moe"]
|
||||||
or global_server_args_dict["enable_ep_moe"]
|
or global_server_args_dict["enable_ep_moe"]
|
||||||
):
|
):
|
||||||
self.num_fused_shared_experts = 0
|
disable_reason = "Deepseek V3/R1 can not use shared experts fusion optimization when in deepep_moe or ep_moe mode."
|
||||||
|
|
||||||
|
if disable_reason is not None:
|
||||||
global_server_args_dict["disable_shared_experts_fusion"] = True
|
global_server_args_dict["disable_shared_experts_fusion"] = True
|
||||||
log_info_on_rank0(
|
log_info_on_rank0(
|
||||||
logger,
|
logger,
|
||||||
"Deepseek V3/R1 can not use shared experts fusion optimization when in deepep_moe or ep_moe mode. Shared experts fusion optimization is disabled.",
|
f"{disable_reason} Shared experts fusion optimization is disabled.",
|
||||||
)
|
)
|
||||||
elif self.num_fused_shared_experts == 0:
|
return
|
||||||
if (
|
|
||||||
_is_cuda
|
|
||||||
and torch.cuda.get_device_capability("cuda") >= (9, 0)
|
|
||||||
and self.config.architectures[0] == architecture
|
|
||||||
and self.config.n_routed_experts == 256
|
|
||||||
and (
|
|
||||||
not (
|
|
||||||
global_server_args_dict["enable_deepep_moe"]
|
|
||||||
or global_server_args_dict["enable_ep_moe"]
|
|
||||||
)
|
|
||||||
)
|
|
||||||
):
|
|
||||||
self.num_fused_shared_experts = self.config.n_shared_experts
|
self.num_fused_shared_experts = self.config.n_shared_experts
|
||||||
global_server_args_dict["disable_shared_experts_fusion"] = False
|
|
||||||
log_info_on_rank0(
|
|
||||||
logger,
|
|
||||||
"Deepseek V3/R1 with fp8/fp4 can use shared experts fusion optimization when SM version >=90. Shared experts fusion optimization is enabled.",
|
|
||||||
)
|
|
||||||
|
|
||||||
def get_input_embeddings(self) -> nn.Embedding:
|
def get_input_embeddings(self) -> nn.Embedding:
|
||||||
return self.model.embed_tokens
|
return self.model.embed_tokens
|
||||||
|
|||||||
Reference in New Issue
Block a user