diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index b26f1e77f..758a50f53 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -1336,28 +1336,16 @@ class DeepseekV2DecoderLayer(nn.Module): ) if self.attn_tp_size != 1: - if self.input_is_scattered: - tensor_list = list(hidden_states.tensor_split(self.attn_tp_size)) - hidden_states = tensor_list[self.attn_tp_rank] - attn_tp_reduce_scatter(hidden_states, tensor_list) - if hidden_states.shape[0] != 0: - hidden_states, residual = self.post_attention_layernorm( - hidden_states, residual - ) - else: - if self.attn_tp_rank == 0: - hidden_states += residual - tensor_list = list(hidden_states.tensor_split(self.attn_tp_size)) - hidden_states = tensor_list[self.attn_tp_rank] - attn_tp_reduce_scatter(hidden_states, tensor_list) - residual = hidden_states - if hidden_states.shape[0] != 0: - hidden_states = self.post_attention_layernorm(hidden_states) - else: - if hidden_states.shape[0] != 0: - hidden_states, residual = self.post_attention_layernorm( - hidden_states, residual - ) + tensor_list = list(hidden_states.tensor_split(self.attn_tp_size)) + hidden_states = tensor_list[self.attn_tp_rank] + attn_tp_reduce_scatter(hidden_states, tensor_list) + if not self.input_is_scattered: + residual = residual.tensor_split(self.attn_tp_size)[self.attn_tp_rank] + + if hidden_states.shape[0] != 0: + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual + ) if not ( self._enable_moe_dense_fully_dp() @@ -1859,7 +1847,6 @@ class DeepseekV2ForCausalLM(nn.Module): q_a_proj_name in cached_a_proj and kv_a_proj_name in cached_a_proj ): - q_a_proj_weight = cached_a_proj[q_a_proj_name] kv_a_proj_weight = cached_a_proj[kv_a_proj_name] fused_weight = torch.cat(