From 8dc191f237550d81eba9137f1e04b499ed8a6742 Mon Sep 17 00:00:00 2001 From: fzyzcjy <5236035+fzyzcjy@users.noreply.github.com> Date: Sat, 17 May 2025 10:05:33 +0800 Subject: [PATCH] Fix one wasted kernel in DeepSeek and minor refactor (#6316) --- python/sglang/srt/models/deepseek_v2.py | 33 ++++++++----------------- 1 file changed, 10 insertions(+), 23 deletions(-) diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index b26f1e77f..758a50f53 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -1336,28 +1336,16 @@ class DeepseekV2DecoderLayer(nn.Module): ) if self.attn_tp_size != 1: - if self.input_is_scattered: - tensor_list = list(hidden_states.tensor_split(self.attn_tp_size)) - hidden_states = tensor_list[self.attn_tp_rank] - attn_tp_reduce_scatter(hidden_states, tensor_list) - if hidden_states.shape[0] != 0: - hidden_states, residual = self.post_attention_layernorm( - hidden_states, residual - ) - else: - if self.attn_tp_rank == 0: - hidden_states += residual - tensor_list = list(hidden_states.tensor_split(self.attn_tp_size)) - hidden_states = tensor_list[self.attn_tp_rank] - attn_tp_reduce_scatter(hidden_states, tensor_list) - residual = hidden_states - if hidden_states.shape[0] != 0: - hidden_states = self.post_attention_layernorm(hidden_states) - else: - if hidden_states.shape[0] != 0: - hidden_states, residual = self.post_attention_layernorm( - hidden_states, residual - ) + tensor_list = list(hidden_states.tensor_split(self.attn_tp_size)) + hidden_states = tensor_list[self.attn_tp_rank] + attn_tp_reduce_scatter(hidden_states, tensor_list) + if not self.input_is_scattered: + residual = residual.tensor_split(self.attn_tp_size)[self.attn_tp_rank] + + if hidden_states.shape[0] != 0: + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual + ) if not ( self._enable_moe_dense_fully_dp() @@ -1859,7 +1847,6 @@ class DeepseekV2ForCausalLM(nn.Module): q_a_proj_name in cached_a_proj and kv_a_proj_name in cached_a_proj ): - q_a_proj_weight = cached_a_proj[q_a_proj_name] kv_a_proj_weight = cached_a_proj[kv_a_proj_name] fused_weight = torch.cat(