From a303325fdb2fe18cec501e1ccd7245b708f1f51b Mon Sep 17 00:00:00 2001
From: fzyzcjy <5236035+fzyzcjy@users.noreply.github.com>
Date: Mon, 31 Mar 2025 11:10:21 +0800
Subject: [PATCH] Fix DeepSeek bug causing 2.2% MMLU drop when TP!=DP (#4883)

Co-authored-by: ch-wan
---
 python/sglang/srt/models/deepseek_v2.py | 35 +++++++------------------
 1 file changed, 10 insertions(+), 25 deletions(-)

diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py
index a27bd1ea3..c89f9feac 100644
--- a/python/sglang/srt/models/deepseek_v2.py
+++ b/python/sglang/srt/models/deepseek_v2.py
@@ -1102,6 +1102,10 @@ class DeepseekV2DecoderLayer(nn.Module):
         else:
             hidden_states, residual = self.input_layernorm(hidden_states, residual)
 
+        assert not (
+            self.attn_tp_size != 1 and self.input_is_scattered
+        ), "moe_layer_freq > 1 is not supported when attn_tp_size > 1"
+
         # Self Attention
         hidden_states = self.self_attn(
             positions=positions,
@@ -1109,22 +1113,6 @@
             forward_batch=forward_batch,
         )
 
-        if self.attn_tp_size != 1 and self.input_is_scattered:
-            hidden_states, local_hidden_states = (
-                forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
-                hidden_states,
-            )
-            tp_all_gather(
-                list(hidden_states.tensor_split(self.attn_tp_size)), local_hidden_states
-            )
-            residual, local_residual = (
-                forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
-                residual,
-            )
-            tp_all_gather(
-                list(residual.tensor_split(self.attn_tp_size)), local_residual
-            )
-
         # Gather
         if get_tensor_model_parallel_world_size() > 1:
             # all gather and all reduce
@@ -1223,6 +1211,8 @@
         hidden_states = self.mlp(hidden_states, forward_batch.forward_mode)
 
         if self.is_last_layer and self.attn_tp_size != 1:
+            hidden_states += residual
+            residual = None
             hidden_states, local_hidden_states = (
                 forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
                 hidden_states,
@@ -1230,19 +1220,11 @@ class DeepseekV2DecoderLayer(nn.Module):
             tp_all_gather(
                 list(hidden_states.tensor_split(self.attn_tp_size)), local_hidden_states
             )
-            residual, local_residual = (
-                forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
-                residual,
-            )
-            tp_all_gather(
-                list(residual.tensor_split(self.attn_tp_size)), local_residual
-            )
 
         return hidden_states, residual
 
 
 class DeepseekV2Model(nn.Module):
-
     fall_back_to_pt_during_load = False
 
     def __init__(
@@ -1296,7 +1278,10 @@ class DeepseekV2Model(nn.Module):
             positions, hidden_states, forward_batch, residual
         )
         if not forward_batch.forward_mode.is_idle():
-            hidden_states, _ = self.norm(hidden_states, residual)
+            if residual is None:
+                hidden_states = self.norm(hidden_states)
+            else:
+                hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states
 
 