From d9dd529854f7404b1e329d552a4f920579e1f820 Mon Sep 17 00:00:00 2001
From: Xiaoyu Zhang <35585791+BBuf@users.noreply.github.com>
Date: Mon, 21 Apr 2025 03:46:42 +0800
Subject: [PATCH] enable DeepSeek V3 shared_experts_fusion in sm90 (#5571)

---
 python/sglang/srt/models/deepseek_v2.py | 12 ++++++++++++
 1 file changed, 12 insertions(+)

diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py
index b3bd49173..26c5e617a 100644
--- a/python/sglang/srt/models/deepseek_v2.py
+++ b/python/sglang/srt/models/deepseek_v2.py
@@ -1427,6 +1427,18 @@ class DeepseekV2ForCausalLM(nn.Module):
                 assert (
                     self.n_share_experts_fusion == self.tp_size
                 ), f"Shared experts fusion optimization is enabled in DeepSeek V3/R1, set it to {self.tp_size} to get the best optimized performance."
+        elif self.n_share_experts_fusion == 0:
+            if (
+                torch.cuda.get_device_capability("cuda") >= (9, 0)
+                and self.config.architectures[0] == "DeepseekV3ForCausalLM"
+                and self.config.n_routed_experts == 256
+                and (not global_server_args_dict["enable_deepep_moe"])
+            ):
+                self.n_share_experts_fusion = self.tp_size
+                global_server_args_dict["n_share_experts_fusion"] = self.tp_size
+                logger.info(
+                    "DeepSeek V3/R1 with fp8 can use shared experts fusion optimization when SM version >= 90. Shared experts fusion optimization is enabled."
+                )
         self.model = DeepseekV2Model(
             config, quant_config, prefix=add_prefix("model", prefix)
         )
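
Note: the sketch below restates the auto-enable condition this diff adds as a
standalone, pure function, so the decision logic can be read and tested outside
of sglang. The helper `should_auto_enable_fusion`, its parameters, and the
`N_ROUTED_EXPERTS_V3` constant are hypothetical stand-ins for values the real
code reads from the model config and from sglang's `global_server_args_dict`;
this is a minimal sketch under those assumptions, not the patch itself.

    # Standalone sketch of the auto-enable check added by this patch.
    # The function name, parameters, and constant are hypothetical; the
    # real code reads these values from the HuggingFace model config and
    # from sglang's global_server_args_dict.
    import torch

    N_ROUTED_EXPERTS_V3 = 256  # DeepSeek V3/R1 router width, per the patch


    def should_auto_enable_fusion(
        architectures: list[str],
        n_routed_experts: int,
        enable_deepep_moe: bool,
        tp_size: int,
    ) -> int:
        """Return the n_share_experts_fusion value the patch would pick (0 = off)."""
        if not torch.cuda.is_available():
            return 0
        # SM 9.0 (Hopper) or newer is required for this fused path.
        if torch.cuda.get_device_capability("cuda") < (9, 0):
            return 0
        # Only DeepSeek V3/R1 checkpoints qualify.
        if not architectures or architectures[0] != "DeepseekV3ForCausalLM":
            return 0
        if n_routed_experts != N_ROUTED_EXPERTS_V3:
            return 0
        # DeepEP MoE uses its own dispatch path, so fusion stays off there.
        if enable_deepep_moe:
            return 0
        # The patch sets the fusion degree to the tensor-parallel size.
        return tp_size

For example, on a Hopper GPU (SM 9.0) with tp_size=8, a DeepseekV3ForCausalLM
checkpoint with 256 routed experts, and DeepEP MoE disabled, the helper returns
8, matching the fusion degree the patch writes back into
global_server_args_dict["n_share_experts_fusion"].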