From 8c298031d57ee63e595cf5d59b4a49c4e872bf82 Mon Sep 17 00:00:00 2001
From: Yi Zhang <1109276519@qq.com>
Date: Fri, 4 Jul 2025 13:48:11 +0800
Subject: [PATCH] refactor llama4 dp attention logic (#7729)

---
 python/sglang/srt/models/llama4.py | 77 +++++++++++++-----------------
 1 file changed, 32 insertions(+), 45 deletions(-)

diff --git a/python/sglang/srt/models/llama4.py b/python/sglang/srt/models/llama4.py
index 082b97ae0..1bb6fcc12 100644
--- a/python/sglang/srt/models/llama4.py
+++ b/python/sglang/srt/models/llama4.py
@@ -27,9 +27,8 @@ from sglang.srt.distributed import (
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
 )
+from sglang.srt.layers.communicator import LayerCommunicator, LayerScatterModes
 from sglang.srt.layers.dp_attention import (
-    dp_gather_partial,
-    dp_scatter,
     get_attention_tp_rank,
     get_attention_tp_size,
     get_local_attention_dp_size,
@@ -367,7 +366,10 @@ class Llama4DecoderLayer(nn.Module):
             bias_o_proj=False,
             prefix=add_prefix("self_attn", prefix),
         )
-        is_moe_layer = (layer_id + 1) % config.interleave_moe_layer_step == 0
+        self.config = config
+        is_moe_layer = self._is_moe_layer(layer_id)
+        is_previous_moe_layer = self._is_moe_layer(layer_id - 1)
+
         if is_moe_layer:
             self.feed_forward = Llama4MoE(
                 config=config,
@@ -387,6 +389,22 @@ class Llama4DecoderLayer(nn.Module):
             config.hidden_size, eps=config.rms_norm_eps
         )
 
+        self.layer_scatter_modes = LayerScatterModes.init_new(
+            layer_id=layer_id,
+            num_layers=config.num_hidden_layers,
+            is_layer_sparse=is_moe_layer,
+            is_previous_layer_sparse=is_previous_moe_layer,
+        )
+
+        self.layer_communicator = LayerCommunicator(
+            layer_scatter_modes=self.layer_scatter_modes,
+            input_layernorm=self.input_layernorm,
+            post_attention_layernorm=self.post_attention_layernorm,
+        )
+
+    def _is_moe_layer(self, layer_id: int) -> bool:
+        return (layer_id + 1) % self.config.interleave_moe_layer_step == 0
+
     def forward(
         self,
         positions: torch.Tensor,
@@ -394,57 +412,26 @@ class Llama4DecoderLayer(nn.Module):
         forward_batch: ForwardBatch,
         residual: Optional[torch.Tensor],
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        if hidden_states.shape[0] == 0:
-            residual = hidden_states
-        else:
-            # Self Attention
-            if residual is None:
-                residual = hidden_states
-                hidden_states = self.input_layernorm(hidden_states)
-            else:
-                hidden_states, residual = self.input_layernorm(hidden_states, residual)
+        hidden_states, residual = self.layer_communicator.prepare_attn(
+            hidden_states, residual, forward_batch
+        )
+
+        if hidden_states.shape[0] != 0:
             hidden_states = self.self_attn(
                 positions=positions,
                 hidden_states=hidden_states,
                 forward_batch=forward_batch,
             )
 
-            # Gather
-            if get_tensor_model_parallel_world_size() > 1:
-                # all gather and all reduce
-                if self.local_dp_size != 1:
-                    if self.attn_tp_rank == 0:
-                        hidden_states += residual
-                    hidden_states, local_hidden_states = (
-                        forward_batch.gathered_buffer,
-                        hidden_states,
-                    )
-                    dp_gather_partial(hidden_states, local_hidden_states, forward_batch)
-                    dp_scatter(residual, hidden_states, forward_batch)
-                    hidden_states = self.post_attention_layernorm(hidden_states)
-                else:
-                    hidden_states = tensor_model_parallel_all_reduce(hidden_states)
-                    hidden_states, residual = self.post_attention_layernorm(
-                        hidden_states, residual
-                    )
-            else:
-                hidden_states, residual = self.post_attention_layernorm(
-                    hidden_states, residual
-                )
+        hidden_states, residual = self.layer_communicator.prepare_mlp(
+            hidden_states, residual, forward_batch
+        )
 
         # Fully Connected
         hidden_states = self.feed_forward(hidden_states, forward_batch)
-
-        # TODO(ch-wan): use reduce-scatter in MLP to avoid this scatter
-        # Scatter
-        if self.local_dp_size != 1:
-            # important: forward batch.gathered_buffer is used both after scatter and after gather.
-            # be careful about this!
-            hidden_states, global_hidden_states = (
-                forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
-                hidden_states,
-            )
-            dp_scatter(hidden_states, global_hidden_states, forward_batch)
+        hidden_states, residual = self.layer_communicator.postprocess_layer(
+            hidden_states, residual, forward_batch
+        )
 
         return hidden_states, residual
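
Note on the new call pattern: after this refactor, Llama4DecoderLayer.forward no
longer branches on tensor/data parallel sizes itself. prepare_attn, prepare_mlp,
and postprocess_layer on LayerCommunicator own the residual/layernorm handling
plus the gather/scatter/all-reduce decisions that were previously written inline
with dp_gather_partial, dp_scatter, and tensor_model_parallel_all_reduce, and
LayerScatterModes.init_new picks the communication mode from whether the current
and previous layers are sparse (MoE) or dense. The sketch below is a minimal,
framework-free illustration of that call pattern only: StubLayerCommunicator and
decoder_layer_forward are hypothetical stand-ins, not the real
sglang.srt.layers.communicator API; only the three method names and their
(hidden_states, residual, forward_batch) -> (hidden_states, residual) shape are
taken from the patch.

    # Hypothetical stand-in for sglang's LayerCommunicator, reduced to plain PyTorch.
    # It only mimics the call signatures used in the patch; the real class also
    # performs the DP gather/scatter and TP all-reduce steps chosen by LayerScatterModes.
    import torch
    from torch import nn


    class StubLayerCommunicator:
        def __init__(self, input_layernorm: nn.Module, post_attention_layernorm: nn.Module):
            self.input_layernorm = input_layernorm
            self.post_attention_layernorm = post_attention_layernorm

        def prepare_attn(self, hidden_states, residual, forward_batch=None):
            # Pre-attention: seed or accumulate the residual stream, then normalize
            # (the removed code did this inline via the fused RMSNorm call).
            if residual is None:
                residual = hidden_states
            else:
                residual = hidden_states + residual
            return self.input_layernorm(residual), residual

        def prepare_mlp(self, hidden_states, residual, forward_batch=None):
            # Post-attention: add the attention output to the residual and normalize.
            residual = hidden_states + residual
            return self.post_attention_layernorm(residual), residual

        def postprocess_layer(self, hidden_states, residual, forward_batch=None):
            # The real communicator would scatter the MLP output back to the local
            # DP rank here (formerly the manual dp_scatter at the end of forward);
            # the stub just passes the tensors through.
            return hidden_states, residual


    def decoder_layer_forward(comm, self_attn, feed_forward, hidden_states, residual):
        # Same control flow as the refactored Llama4DecoderLayer.forward
        # (positions and forward_batch omitted for brevity).
        hidden_states, residual = comm.prepare_attn(hidden_states, residual)
        if hidden_states.shape[0] != 0:
            hidden_states = self_attn(hidden_states)
        hidden_states, residual = comm.prepare_mlp(hidden_states, residual)
        hidden_states = feed_forward(hidden_states)
        return comm.postprocess_layer(hidden_states, residual)


    if __name__ == "__main__":
        dim = 16
        comm = StubLayerCommunicator(nn.LayerNorm(dim), nn.LayerNorm(dim))
        attn, mlp = nn.Linear(dim, dim), nn.Linear(dim, dim)  # toy attention / MLP
        h, r = decoder_layer_forward(comm, attn, mlp, torch.randn(4, dim), None)
        print(h.shape, r.shape)  # torch.Size([4, 16]) torch.Size([4, 16])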