[fix] fix determine_n_share_experts_fusion (#6118)
This commit is contained in:
@@ -1486,14 +1486,15 @@ class DeepseekV2ForCausalLM(nn.Module):
|
|||||||
if self.n_share_experts_fusion > 0:
|
if self.n_share_experts_fusion > 0:
|
||||||
# Only Deepseek V3/R1 can use shared experts fusion optimization now.
|
# Only Deepseek V3/R1 can use shared experts fusion optimization now.
|
||||||
if (
|
if (
|
||||||
self.config.architectures[0] != architecture
|
not _is_cuda
|
||||||
|
or self.config.architectures[0] != architecture
|
||||||
or self.config.n_routed_experts != 256
|
or self.config.n_routed_experts != 256
|
||||||
):
|
):
|
||||||
self.n_share_experts_fusion = 0
|
self.n_share_experts_fusion = 0
|
||||||
global_server_args_dict["n_share_experts_fusion"] = 0
|
global_server_args_dict["n_share_experts_fusion"] = 0
|
||||||
log_info_on_rank0(
|
log_info_on_rank0(
|
||||||
logger,
|
logger,
|
||||||
"Only Deepseek V3/R1 can use shared experts fusion optimization. Shared experts fusion optimization is disabled.",
|
"Only Deepseek V3/R1 on NV-platform can use shared experts fusion optimization. Shared experts fusion optimization is disabled.",
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
assert (
|
assert (
|
||||||
@@ -1501,7 +1502,8 @@ class DeepseekV2ForCausalLM(nn.Module):
|
|||||||
), f"Shared experts fusion optimization is enabled in DeepSeek V3/R1, set it to {self.tp_size} can get best optimized performace."
|
), f"Shared experts fusion optimization is enabled in DeepSeek V3/R1, set it to {self.tp_size} can get best optimized performace."
|
||||||
elif self.n_share_experts_fusion == 0:
|
elif self.n_share_experts_fusion == 0:
|
||||||
if (
|
if (
|
||||||
torch.cuda.get_device_capability("cuda") >= (9, 0)
|
_is_cuda
|
||||||
|
and torch.cuda.get_device_capability("cuda") >= (9, 0)
|
||||||
and self.config.architectures[0] == architecture
|
and self.config.architectures[0] == architecture
|
||||||
and self.config.n_routed_experts == 256
|
and self.config.n_routed_experts == 256
|
||||||
and (not global_server_args_dict["enable_deepep_moe"])
|
and (not global_server_args_dict["enable_deepep_moe"])
|
||||||
|
|||||||
Reference in New Issue
Block a user