diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index a27bd1ea3..c89f9feac 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -1102,6 +1102,10 @@ class DeepseekV2DecoderLayer(nn.Module): else: hidden_states, residual = self.input_layernorm(hidden_states, residual) + assert not ( + self.attn_tp_size != 1 and self.input_is_scattered + ), "moe_layer_freq > 1 is not supported when attn_tp_size > 1" + # Self Attention hidden_states = self.self_attn( positions=positions, @@ -1109,22 +1113,6 @@ class DeepseekV2DecoderLayer(nn.Module): forward_batch=forward_batch, ) - if self.attn_tp_size != 1 and self.input_is_scattered: - hidden_states, local_hidden_states = ( - forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]], - hidden_states, - ) - tp_all_gather( - list(hidden_states.tensor_split(self.attn_tp_size)), local_hidden_states - ) - residual, local_residual = ( - forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]], - residual, - ) - tp_all_gather( - list(residual.tensor_split(self.attn_tp_size)), local_residual - ) - # Gather if get_tensor_model_parallel_world_size() > 1: # all gather and all reduce @@ -1223,6 +1211,8 @@ class DeepseekV2DecoderLayer(nn.Module): hidden_states = self.mlp(hidden_states, forward_batch.forward_mode) if self.is_last_layer and self.attn_tp_size != 1: + hidden_states += residual + residual = None hidden_states, local_hidden_states = ( forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]], hidden_states, @@ -1230,19 +1220,11 @@ class DeepseekV2DecoderLayer(nn.Module): tp_all_gather( list(hidden_states.tensor_split(self.attn_tp_size)), local_hidden_states ) - residual, local_residual = ( - forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]], - residual, - ) - tp_all_gather( - list(residual.tensor_split(self.attn_tp_size)), local_residual - ) return hidden_states, residual class DeepseekV2Model(nn.Module): - fall_back_to_pt_during_load = False def __init__( @@ -1296,7 +1278,10 @@ class DeepseekV2Model(nn.Module): positions, hidden_states, forward_batch, residual ) if not forward_batch.forward_mode.is_idle(): - hidden_states, _ = self.norm(hidden_states, residual) + if residual is None: + hidden_states = self.norm(hidden_states) + else: + hidden_states, _ = self.norm(hidden_states, residual) return hidden_states