From 3c2274fbee3c0bcbbb60d08cff4ad3f7c35a4fbd Mon Sep 17 00:00:00 2001 From: Cheng Wan <54331508+ch-wan@users.noreply.github.com> Date: Sun, 15 Jun 2025 21:08:56 -0700 Subject: [PATCH] Implement gather before attn (#6378) --- python/sglang/srt/layers/communicator.py | 28 +++++++++++++++++++----- 1 file changed, 23 insertions(+), 5 deletions(-) diff --git a/python/sglang/srt/layers/communicator.py b/python/sglang/srt/layers/communicator.py index 5dad481cd..a85272150 100644 --- a/python/sglang/srt/layers/communicator.py +++ b/python/sglang/srt/layers/communicator.py @@ -226,13 +226,13 @@ class LayerCommunicator: @dataclass class CommunicateContext: - process_group_sizes: Dict["ScatterMode", int] + process_group_sizes: Dict[ScatterMode, int] attn_tp_rank: int attn_tp_size: int local_attn_dp_size: int tp_size: int - def is_same_group_size(self, a: "ScatterMode", b: "ScatterMode"): + def is_same_group_size(self, a: ScatterMode, b: ScatterMode): return self.process_group_sizes[a] == self.process_group_sizes[b] @classmethod @@ -244,6 +244,7 @@ class CommunicateContext: process_group_sizes = { ScatterMode.SCATTERED: 1, ScatterMode.TP_ATTN_FULL: attn_tp_size, + # TODO: support --moe-dense-tp-size > 1 ScatterMode.FULL: tp_size, } return cls( @@ -323,11 +324,16 @@ class CommunicateWithAllReduceAndLayerNormFn: if ( (hidden_states_input_mode == ScatterMode.TP_ATTN_FULL) - and (residual_input_mode == ScatterMode.TP_ATTN_FULL) + and ( + residual_input_mode in [ScatterMode.SCATTERED, ScatterMode.TP_ATTN_FULL] + ) and (hidden_states_output_mode == ScatterMode.FULL) and (residual_output_mode == ScatterMode.TP_ATTN_FULL) ): - return CommunicateWithAllReduceAndLayerNormFn._gather_hidden_states + return partial( + CommunicateWithAllReduceAndLayerNormFn._gather_hidden_states_and_residual, + residual_input_mode=residual_input_mode, + ) if ( (hidden_states_input_mode == ScatterMode.TP_ATTN_FULL) @@ -360,13 +366,25 @@ class CommunicateWithAllReduceAndLayerNormFn: return hidden_states, residual @staticmethod - def _gather_hidden_states( + def _gather_hidden_states_and_residual( hidden_states: torch.Tensor, residual: torch.Tensor, forward_batch: ForwardBatch, layernorm: torch.nn.Module, context: CommunicateContext, + *, + residual_input_mode, ): + if residual_input_mode == ScatterMode.SCATTERED and context.attn_tp_size > 1: + residual, local_residual = ( + forward_batch.gathered_buffer[ + : forward_batch.input_ids.shape[0] + ].clone(), + residual, + ) + attn_tp_all_gather( + list(residual.tensor_split(context.attn_tp_size)), local_residual + ) if context.local_attn_dp_size != 1: if context.attn_tp_rank == 0: hidden_states += residual