From c178abdabc563564664c935613a99c368f2a7d50 Mon Sep 17 00:00:00 2001
From: JieXin Liang
Date: Sat, 10 May 2025 16:19:09 +0800
Subject: [PATCH] [fix] fix determine_n_share_experts_fusion (#6118)

---
 python/sglang/srt/models/deepseek_v2.py | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py
index 7763fb18c..b297775a1 100644
--- a/python/sglang/srt/models/deepseek_v2.py
+++ b/python/sglang/srt/models/deepseek_v2.py
@@ -1486,14 +1486,15 @@ class DeepseekV2ForCausalLM(nn.Module):
         if self.n_share_experts_fusion > 0:
             # Only Deepseek V3/R1 can use shared experts fusion optimization now.
             if (
-                self.config.architectures[0] != architecture
+                not _is_cuda
+                or self.config.architectures[0] != architecture
                 or self.config.n_routed_experts != 256
             ):
                 self.n_share_experts_fusion = 0
                 global_server_args_dict["n_share_experts_fusion"] = 0
                 log_info_on_rank0(
                     logger,
-                    "Only Deepseek V3/R1 can use shared experts fusion optimization. Shared experts fusion optimization is disabled.",
+                    "Only Deepseek V3/R1 on NV-platform can use shared experts fusion optimization. Shared experts fusion optimization is disabled.",
                 )
             else:
                 assert (
@@ -1501,7 +1502,8 @@ class DeepseekV2ForCausalLM(nn.Module):
                 ), f"Shared experts fusion optimization is enabled in DeepSeek V3/R1, set it to {self.tp_size} can get best optimized performace."
         elif self.n_share_experts_fusion == 0:
             if (
-                torch.cuda.get_device_capability("cuda") >= (9, 0)
+                _is_cuda
+                and torch.cuda.get_device_capability("cuda") >= (9, 0)
                 and self.config.architectures[0] == architecture
                 and self.config.n_routed_experts == 256
                 and (not global_server_args_dict["enable_deepep_moe"])