diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 0512fba87..352480f2d 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -1709,53 +1709,35 @@ class DeepseekV2ForCausalLM(nn.Module): def determine_num_fused_shared_experts( self, architecture: str = "DeepseekV3ForCausalLM" ): - self.num_fused_shared_experts = ( - 0 - if global_server_args_dict["disable_shared_experts_fusion"] - else self.config.n_shared_experts - ) - if self.num_fused_shared_experts > 0: - # Only Deepseek V3/R1 can use shared experts fusion optimization now. - if ( - not _is_cuda - or self.config.architectures[0] != architecture - or self.config.n_routed_experts != 256 - ): - self.num_fused_shared_experts = 0 - global_server_args_dict["disable_shared_experts_fusion"] = True - log_info_on_rank0( - logger, - "Only Deepseek V3/R1 on NV-platform can use shared experts fusion optimization. Shared experts fusion optimization is disabled.", - ) - elif ( - global_server_args_dict["enable_deepep_moe"] - or global_server_args_dict["enable_ep_moe"] - ): - self.num_fused_shared_experts = 0 - global_server_args_dict["disable_shared_experts_fusion"] = True - log_info_on_rank0( - logger, - "Deepseek V3/R1 can not use shared experts fusion optimization when in deepep_moe or ep_moe mode. Shared experts fusion optimization is disabled.", - ) - elif self.num_fused_shared_experts == 0: - if ( - _is_cuda - and torch.cuda.get_device_capability("cuda") >= (9, 0) - and self.config.architectures[0] == architecture - and self.config.n_routed_experts == 256 - and ( - not ( - global_server_args_dict["enable_deepep_moe"] - or global_server_args_dict["enable_ep_moe"] - ) - ) - ): - self.num_fused_shared_experts = self.config.n_shared_experts - global_server_args_dict["disable_shared_experts_fusion"] = False - log_info_on_rank0( - logger, - "Deepseek V3/R1 with fp8/fp4 can use shared experts fusion optimization when SM version >=90. Shared experts fusion optimization is enabled.", - ) + self.num_fused_shared_experts = 0 + if global_server_args_dict["disable_shared_experts_fusion"]: + return + + # Only Deepseek V3/R1 can use shared experts fusion optimization now. + disable_reason = None + if ( + not _is_cuda + or torch.cuda.get_device_capability("cuda") < (9, 0) + or self.config.architectures[0] != architecture + or self.config.n_routed_experts != 256 + or self.config.n_shared_experts != 1 + ): + disable_reason = "Only Deepseek V3/R1 on NV-platform with capability >= 90 can use shared experts fusion optimization." + elif ( + global_server_args_dict["enable_deepep_moe"] + or global_server_args_dict["enable_ep_moe"] + ): + disable_reason = "Deepseek V3/R1 can not use shared experts fusion optimization when in deepep_moe or ep_moe mode." + + if disable_reason is not None: + global_server_args_dict["disable_shared_experts_fusion"] = True + log_info_on_rank0( + logger, + f"{disable_reason} Shared experts fusion optimization is disabled.", + ) + return + + self.num_fused_shared_experts = self.config.n_shared_experts def get_input_embeddings(self) -> nn.Embedding: return self.model.embed_tokens