From 32f2815451f6893424c587c644992cbb7558afa5 Mon Sep 17 00:00:00 2001
From: Trevor Morris
Date: Sun, 3 Aug 2025 00:53:08 -0700
Subject: [PATCH] Do layernorm before allgather for DP attention (#8631)

---
 python/sglang/srt/layers/communicator.py | 16 +++++++++++++---
 1 file changed, 13 insertions(+), 3 deletions(-)

diff --git a/python/sglang/srt/layers/communicator.py b/python/sglang/srt/layers/communicator.py
index 6c61675cb..2e20c01bd 100644
--- a/python/sglang/srt/layers/communicator.py
+++ b/python/sglang/srt/layers/communicator.py
@@ -404,14 +404,24 @@ class CommunicateWithAllReduceAndLayerNormFn:
         if context.attn_dp_size != 1:
             if context.attn_tp_rank == 0:
                 hidden_states += residual
+
+            # Perform layernorm on smaller data before comm. Only valid when attn_tp_size is 1 (tp_size == dp_size)
+            use_layer_norm_before_gather = context.attn_tp_size == 1
+            if use_layer_norm_before_gather:
+                residual.copy_(hidden_states)
+                if hidden_states.shape[0] != 0:
+                    hidden_states = layernorm(hidden_states)
+
             hidden_states, local_hidden_states = (
                 forward_batch.gathered_buffer,
                 hidden_states,
             )
             dp_gather_partial(hidden_states, local_hidden_states, forward_batch)
-            dp_scatter(residual, hidden_states, forward_batch)
-            if hidden_states.shape[0] != 0:
-                hidden_states = layernorm(hidden_states)
+
+            if not use_layer_norm_before_gather:
+                dp_scatter(residual, hidden_states, forward_batch)
+                if hidden_states.shape[0] != 0:
+                    hidden_states = layernorm(hidden_states)
         else:
             # According to the discussion in https://github.com/flashinfer-ai/flashinfer/issues/1223#issuecomment-3047256465
             # We set the max token num to 128 for allreduce fusion with min-latency case(use_oneshot=True).
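
In short: when `attn_tp_size == 1` (i.e. `tp_size == dp_size`), the patch moves the layernorm from after the DP all-gather to before it, so each rank normalizes only its local slice of tokens instead of the full gathered batch, and the `dp_scatter` of the residual is skipped because the residual is captured locally via `residual.copy_(hidden_states)`. The sketch below is not sglang code; it only illustrates why the reordering is numerically safe, under the assumption that the norm (layernorm or RMSNorm) acts on each token row independently, with the DP all-gather stood in by a plain concatenation along the token dimension.

```python
# Standalone sketch (illustrative names, not sglang APIs): a row-wise
# normalization commutes with an all-gather over the token dimension.
import torch


def layernorm(x: torch.Tensor) -> torch.Tensor:
    # Row-wise normalization: each token is normalized independently.
    return torch.nn.functional.layer_norm(x, x.shape[-1:])


def dp_gather(local_chunks: list[torch.Tensor]) -> torch.Tensor:
    # Stand-in for the all-gather across DP ranks: concatenate along tokens.
    return torch.cat(local_chunks, dim=0)


torch.manual_seed(0)
hidden = 16
# Per-rank local hidden states (uneven token counts across ranks are fine).
local_states = [torch.randn(n, hidden) for n in (3, 5, 2)]

# Old order: gather the full batch first, then layernorm over all tokens.
norm_after_gather = layernorm(dp_gather(local_states))

# New order: each rank normalizes only its own (smaller) chunk, then gathers.
norm_before_gather = dp_gather([layernorm(x) for x in local_states])

# The two orders match because normalization never mixes rows (tokens).
torch.testing.assert_close(norm_after_gather, norm_before_gather)
print("layernorm-before-gather == layernorm-after-gather")
```

The payoff is that each rank normalizes roughly `1/dp_size` of the tokens rather than the whole gathered buffer, and one `dp_scatter` disappears from the path. The `attn_tp_size == 1` guard keeps the original order for the case where the residual still has to be redistributed across attention TP ranks after the gather.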