2025-05-20 11:14:48 +08:00
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from dataclasses import dataclass
from enum import Enum , auto
2025-05-25 08:34:35 +08:00
from functools import partial
from typing import Dict , Optional
2025-05-20 11:14:48 +08:00
2025-08-14 21:14:53 -07:00
import torch
2025-05-20 11:14:48 +08:00
from sglang . srt . distributed import (
get_tensor_model_parallel_world_size ,
tensor_model_parallel_all_reduce ,
)
from sglang . srt . layers . dp_attention import (
2025-07-24 21:36:21 -07:00
attn_tp_all_gather_into_tensor ,
attn_tp_reduce_scatter_tensor ,
2025-05-20 11:14:48 +08:00
dp_gather_partial ,
2025-08-06 16:21:26 -07:00
dp_reduce_scatter_tensor ,
2025-05-20 11:14:48 +08:00
dp_scatter ,
2025-06-20 14:43:11 -07:00
get_attention_dp_size ,
2025-05-20 11:14:48 +08:00
get_attention_tp_rank ,
get_attention_tp_size ,
2025-08-13 21:08:06 -07:00
get_global_dp_buffer ,
get_local_dp_buffer ,
2025-08-20 17:03:08 +08:00
is_dp_attention_enabled ,
2025-05-20 11:14:48 +08:00
)
2025-08-15 22:08:11 -07:00
from sglang . srt . layers . moe import (
get_moe_a2a_backend ,
should_use_flashinfer_cutlass_moe_fp4_allgather ,
)
2025-05-20 11:14:48 +08:00
from sglang . srt . managers . schedule_batch import global_server_args_dict
from sglang . srt . model_executor . forward_batch_info import ForwardBatch
2025-09-04 20:48:53 +08:00
from sglang . srt . utils import (
2025-09-05 06:11:22 +08:00
get_bool_env_var ,
2025-09-04 20:48:53 +08:00
is_cuda ,
is_flashinfer_available ,
2025-09-05 06:11:22 +08:00
is_gfx95_supported ,
is_hip ,
2025-09-04 20:48:53 +08:00
is_sm90_supported ,
is_sm100_supported ,
)
2025-07-03 10:36:20 +08:00
# Platform / library capability flags, evaluated once when this module is imported.
_is_flashinfer_available = is_flashinfer_available()
# The SM90 / SM100 probes are only consulted on CUDA builds.
_is_sm90_supported = is_cuda() and is_sm90_supported()
_is_sm100_supported = is_cuda() and is_sm100_supported()
# aiter kernels are opt-in through the SGLANG_USE_AITER env var and only used on HIP.
_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and is_hip()
_is_gfx95_supported = is_gfx95_supported()
if _use_aiter and _is_gfx95_supported:
    # Imported only when both flags are set; LayerCommunicator.prepare_attn checks
    # the same two flags (plus "mxfp4" in the quant format) before calling it.
    from sglang.srt.layers.quantization.rocm_mxfp4_utils import fused_rms_mxfp4_quant

# Largest input_ids length for which
# LayerCommunicator.should_fuse_mlp_allreduce_with_next_layer allows fusion.
FUSE_ALLREDUCE_MAX_BATCH_SIZE = 2048
2025-05-20 11:14:48 +08:00
class ScatterMode(Enum):
    """How activations are distributed across ranks.

    Example: TP = 4, DP = 2, --enable-dp-attention, and the system handles
    sequences a, b, c, d.

    Model input / output: [ab, ab, cd, cd] for the four ranks respectively.

    * SCATTERED:    [a, b, c, d]
    * TP_ATTN_FULL: [ab, ab, cd, cd], i.e. all ranks inside a TP attention
      group hold the full data of that group.
    * FULL:         [abcd, abcd, abcd, abcd]
    """

    # Explicit values; identical to what ``auto()`` assigns (1, 2, 3).
    SCATTERED = 1
    TP_ATTN_FULL = 2
    FULL = 3

    @staticmethod
    def model_input_output():
        """Return the scatter mode of the model forward pass input and output."""
        boundary_mode = ScatterMode.TP_ATTN_FULL
        return boundary_mode
2025-05-20 11:14:48 +08:00
@dataclass
class _LayerModeComputationContext :
num_layers : int
layer_id : int
is_layer_sparse : bool
is_previous_layer_sparse : Optional [ bool ]
def previous_layer ( self ) :
assert self . is_previous_layer_sparse is not None
return _LayerModeComputationContext (
layer_id = self . layer_id - 1 ,
is_layer_sparse = self . is_previous_layer_sparse ,
is_previous_layer_sparse = None ,
num_layers = self . num_layers ,
)
@dataclass
class LayerScatterModes:
    """Scatter mode at each hand-off point of one decoder layer.

    Build instances with ``init_new``, which takes the keyword arguments of
    ``_LayerModeComputationContext`` (num_layers, layer_id, is_layer_sparse,
    is_previous_layer_sparse).
    """

    layer_input_mode: ScatterMode
    attn_mode: ScatterMode
    # Can be further split into e.g. mlp_input_mode and mlp_output_mode if needed
    mlp_mode: ScatterMode
    middle_residual_mode: ScatterMode
    layer_output_mode: ScatterMode

    @classmethod
    def init_new(cls, **kwargs):
        """Compute every mode for the layer described by ``kwargs``."""
        ctx = _LayerModeComputationContext(**kwargs)
        input_mode = cls._compute_layer_input_mode(ctx)
        mlp_mode = cls._compute_mlp_mode(ctx)
        residual_mode = cls._compute_middle_residual_mode(ctx)
        output_mode = cls._compute_layer_output_mode(ctx)
        # Attention always operates on the full data of its attention-TP group.
        return cls(
            layer_input_mode=input_mode,
            attn_mode=ScatterMode.TP_ATTN_FULL,
            mlp_mode=mlp_mode,
            middle_residual_mode=residual_mode,
            layer_output_mode=output_mode,
        )

    @classmethod
    def _compute_layer_input_mode(cls, context: _LayerModeComputationContext):
        # Every layer but the first consumes exactly what its predecessor emitted.
        if context.layer_id != 0:
            return cls._compute_layer_output_mode(context.previous_layer())
        return ScatterMode.model_input_output()

    @classmethod
    def _compute_mlp_mode(cls, context: _LayerModeComputationContext):
        if not context.is_layer_sparse:
            # Dense MLP: scattered only when enable_moe_dense_fully_dp() holds.
            dense_scattered = enable_moe_dense_fully_dp()
            return ScatterMode.SCATTERED if dense_scattered else ScatterMode.FULL
        # Token dispatch/combine will be handled outside of LayerCommunicator for these modes.
        dispatch_outside = (
            not get_moe_a2a_backend().is_none()
            or should_use_flashinfer_cutlass_moe_fp4_allgather()
        )
        return ScatterMode.SCATTERED if dispatch_outside else ScatterMode.FULL

    @staticmethod
    def _mode_after_mlp(mlp_mode: ScatterMode) -> ScatterMode:
        """Layout that follows an MLP running in ``mlp_mode``.

        SCATTERED stays SCATTERED; FULL shrinks back to TP_ATTN_FULL.

        Raises:
            NotImplementedError: for any other mode.
        """
        if mlp_mode == ScatterMode.SCATTERED:
            return ScatterMode.SCATTERED
        if mlp_mode == ScatterMode.FULL:
            return ScatterMode.TP_ATTN_FULL
        raise NotImplementedError

    @classmethod
    def _compute_middle_residual_mode(cls, context: _LayerModeComputationContext):
        return cls._mode_after_mlp(cls._compute_mlp_mode(context))

    @classmethod
    def _compute_layer_output_mode(cls, context: _LayerModeComputationContext):
        # mlp_mode is computed before the last-layer check, as before.
        mlp_mode = cls._compute_mlp_mode(context)
        is_final_layer = context.layer_id == context.num_layers - 1
        if is_final_layer:
            # The last layer hands back the model-level output layout.
            return ScatterMode.model_input_output()
        return cls._mode_after_mlp(mlp_mode)
def enable_moe_dense_fully_dp():
    """Return True when the server argument ``moe_dense_tp_size`` equals 1.

    LayerScatterModes uses this to place dense (non-sparse) MLP layers in
    ScatterMode.SCATTERED instead of ScatterMode.FULL.
    """
    dense_tp_size = global_server_args_dict["moe_dense_tp_size"]
    return dense_tp_size == 1
class LayerCommunicator:
    """Performs the communication around one decoder layer.

    Three hand-off functions are selected once, from ``layer_scatter_modes``, at
    construction time:

    * ``prepare_attn``: input layernorm, then layer-input mode -> attention mode.
    * ``prepare_mlp``: all-reduce plus post-attention layernorm; hidden_states
      go attention mode -> MLP mode, residual goes layer-input mode ->
      middle-residual mode.
    * ``postprocess_layer``: (MLP mode, middle-residual mode) -> layer-output mode.
    """

    def __init__(
        self,
        layer_scatter_modes: LayerScatterModes,
        input_layernorm: torch.nn.Module,
        post_attention_layernorm: torch.nn.Module,
        # Reduce scatter requires skipping all-reduce in model code after MoE/MLP, so only enable for models which have that implemented. Remove flag once done for all models that use LayerCommunicator.
        allow_reduce_scatter: bool = False,
        is_last_layer: bool = False,
    ):
        self.layer_scatter_modes = layer_scatter_modes
        self.input_layernorm = input_layernorm
        self.post_attention_layernorm = post_attention_layernorm
        self.allow_reduce_scatter = allow_reduce_scatter
        self.is_last_layer = is_last_layer

        self._context = CommunicateContext.init_new()
        self._communicate_simple_fn = CommunicateSimpleFn.get_fn(
            input_mode=self.layer_scatter_modes.layer_input_mode,
            output_mode=self.layer_scatter_modes.attn_mode,
            context=self._context,
        )
        self._communicate_with_all_reduce_and_layer_norm_fn = (
            CommunicateWithAllReduceAndLayerNormFn.get_fn(
                hidden_states_input_mode=self.layer_scatter_modes.attn_mode,
                residual_input_mode=self.layer_scatter_modes.layer_input_mode,
                hidden_states_output_mode=self.layer_scatter_modes.mlp_mode,
                residual_output_mode=self.layer_scatter_modes.middle_residual_mode,
                context=self._context,
            )
        )
        self._communicate_summable_tensor_pair_fn = (
            CommunicateSummableTensorPairFn.get_fn(
                hidden_states_input_mode=self.layer_scatter_modes.mlp_mode,
                residual_input_mode=self.layer_scatter_modes.middle_residual_mode,
                output_mode=self.layer_scatter_modes.layer_output_mode,
                context=self._context,
            )
        )

    def prepare_attn(
        self,
        hidden_states: torch.Tensor,
        residual: torch.Tensor,
        forward_batch: ForwardBatch,
        qaunt_format: str = "",
    ):
        """Apply the input layernorm and move hidden_states to the attention mode.

        Args:
            hidden_states: layer input.
            residual: running residual, or None (then it starts as hidden_states).
            forward_batch: batch metadata forwarded to the communication function.
            qaunt_format: quantization format string. The name is misspelled but
                kept as-is for callers passing it by keyword. When it contains
                "mxfp4" and aiter on gfx95 is enabled, the fused RMSNorm + MXFP4
                quant kernel replaces the plain layernorm.

        Returns:
            (hidden_states, residual)
        """
        if hidden_states.shape[0] == 0:
            residual = hidden_states
        else:
            if (
                residual is not None
                and hasattr(hidden_states, "_sglang_needs_allreduce_fusion")
                and hidden_states._sglang_needs_allreduce_fusion
            ):
                # The previous layer deferred its all-reduce; do it fused with the norm.
                hidden_states, residual = (
                    self.input_layernorm.forward_with_allreduce_fusion(
                        hidden_states, residual
                    )
                )
            else:
                use_mxfp4_norm = (
                    _use_aiter and _is_gfx95_supported and ("mxfp4" in qaunt_format)
                )
                if residual is None:
                    residual = hidden_states

                    if use_mxfp4_norm:
                        hidden_states = fused_rms_mxfp4_quant(
                            hidden_states,
                            self.input_layernorm.weight,
                            self.input_layernorm.variance_epsilon,
                            None,
                            None,
                            None,
                            None,
                        )
                    else:
                        hidden_states = self.input_layernorm(hidden_states)
                else:
                    if use_mxfp4_norm:
                        hidden_states, residual = fused_rms_mxfp4_quant(
                            hidden_states,
                            self.input_layernorm.weight,
                            self.input_layernorm.variance_epsilon,
                            None,
                            None,
                            None,
                            residual,
                        )
                    else:
                        hidden_states, residual = self.input_layernorm(
                            hidden_states, residual
                        )

        hidden_states = self._communicate_simple_fn(
            hidden_states=hidden_states,
            forward_batch=forward_batch,
            context=self._context,
        )
        return hidden_states, residual

    def prepare_mlp(
        self,
        hidden_states: torch.Tensor,
        residual: torch.Tensor,
        forward_batch: ForwardBatch,
    ):
        """Reduce attention output, apply the post-attention layernorm and move
        (hidden_states, residual) to the MLP / middle-residual modes."""
        return self._communicate_with_all_reduce_and_layer_norm_fn(
            hidden_states=hidden_states,
            residual=residual,
            forward_batch=forward_batch,
            layernorm=self.post_attention_layernorm,
            context=self._context,
        )

    def postprocess_layer(
        self,
        hidden_states: torch.Tensor,
        residual: torch.Tensor,
        forward_batch: ForwardBatch,
    ):
        """Move (hidden_states, residual) to the layer-output mode.

        Depending on the selected function the residual may be folded into
        hidden_states and returned as None.
        """
        return self._communicate_summable_tensor_pair_fn(
            hidden_states=hidden_states,
            residual=residual,
            forward_batch=forward_batch,
            context=self._context,
            allow_reduce_scatter=self.allow_reduce_scatter,
        )

    def should_use_reduce_scatter(self, forward_batch: ForwardBatch):
        """True when postprocess_layer will reduce-scatter instead of scatter,
        i.e. when the model code must skip its own all-reduce after the MLP/MoE."""
        return (
            self.allow_reduce_scatter
            and self._communicate_summable_tensor_pair_fn
            is CommunicateSummableTensorPairFn._scatter_hidden_states
            and forward_batch.dp_padding_mode.is_max_len()
        )

    def should_fuse_mlp_allreduce_with_next_layer(
        self, forward_batch: ForwardBatch
    ) -> bool:
        """Decide whether the MLP all-reduce may be deferred to the next layer.

        Presumably the deferred all-reduce is then performed by the next layer's
        ``prepare_attn`` via ``forward_with_allreduce_fusion`` (the
        ``_sglang_needs_allreduce_fusion`` path).

        Returns False when any of these holds:
          * DP attention is enabled together with an EAGLE speculative algorithm;
          * ``forward_batch.input_ids`` is missing/None, empty, or longer than
            FUSE_ALLREDUCE_MAX_BATCH_SIZE along dim 0;
          * this is the last layer or tensor parallelism is off (tp_size == 1);
          * flashinfer all-reduce fusion is disabled or flashinfer is unavailable.
        """
        speculative_algo = global_server_args_dict.get("speculative_algorithm", None)
        if (
            is_dp_attention_enabled()
            and speculative_algo is not None
            and speculative_algo.is_eagle()
        ):
            return False

        # getattr(..., None) also covers an attribute that exists but is None,
        # which previously raised AttributeError on `.shape`.
        input_ids = getattr(forward_batch, "input_ids", None)
        batch_size = input_ids.shape[0] if input_ids is not None else 0
        if not 0 < batch_size <= FUSE_ALLREDUCE_MAX_BATCH_SIZE:
            return False

        if self.is_last_layer or self._context.tp_size <= 1:
            return False

        return bool(
            global_server_args_dict.get("enable_flashinfer_allreduce_fusion", False)
            and _is_flashinfer_available
        )
2025-05-20 11:14:48 +08:00
@dataclass
class CommunicateContext:
    """Snapshot of the parallel layout needed by the communication functions."""

    # Process-group size associated with each scatter mode (SCATTERED is always 1).
    process_group_sizes: Dict[ScatterMode, int]
    attn_tp_rank: int
    attn_tp_size: int
    attn_dp_size: int
    tp_size: int

    def is_same_group_size(self, a: ScatterMode, b: ScatterMode):
        """Return True when modes ``a`` and ``b`` map to equal group sizes."""
        sizes = self.process_group_sizes
        return sizes[a] == sizes[b]

    @classmethod
    def init_new(cls):
        """Build a context from the current distributed state."""
        rank_in_attn_tp = get_attention_tp_rank()
        attn_tp_world = get_attention_tp_size()
        attn_dp_world = get_attention_dp_size()
        tp_world = get_tensor_model_parallel_world_size()
        return cls(
            process_group_sizes={
                ScatterMode.SCATTERED: 1,
                ScatterMode.TP_ATTN_FULL: attn_tp_world,
                # TODO: support --moe-dense-tp-size > 1
                ScatterMode.FULL: tp_world,
            },
            attn_tp_rank=rank_in_attn_tp,
            attn_tp_size=attn_tp_world,
            attn_dp_size=attn_dp_world,
            tp_size=tp_world,
        )
2025-05-20 11:14:48 +08:00
2025-05-27 11:24:17 +08:00
class CommunicateSimpleFn:
    """Moves hidden_states between scatter modes without any reduction."""

    @staticmethod
    def get_fn(
        input_mode: ScatterMode,
        output_mode: ScatterMode,
        context: CommunicateContext,
    ):
        """Select the transfer function for ``input_mode`` -> ``output_mode``.

        Raises:
            NotImplementedError: for mode pairs without an implementation.
        """
        if context.is_same_group_size(input_mode, output_mode):
            return CommunicateSimpleFn._trivial
        gathers_into_attn_group = (
            input_mode == ScatterMode.SCATTERED
            and output_mode == ScatterMode.TP_ATTN_FULL
        )
        if not gathers_into_attn_group:
            raise NotImplementedError(f"{input_mode=} {output_mode=}")
        return CommunicateSimpleFn._scattered_to_tp_attn_full

    @staticmethod
    def _trivial(
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
        context: CommunicateContext,
    ) -> torch.Tensor:
        # Layouts already match: nothing to move.
        return hidden_states

    @staticmethod
    def _scattered_to_tp_attn_full(
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
        context: CommunicateContext,
    ) -> torch.Tensor:
        # All-gather this rank's shard across the attention-TP group into the
        # local DP buffer, and return that buffer.
        shard = hidden_states
        gathered = get_local_dp_buffer()
        attn_tp_all_gather_into_tensor(gathered, shard)
        return gathered
2025-05-27 11:24:17 +08:00
class CommunicateWithAllReduceAndLayerNormFn:
    """Besides communication, needs to
    1. All reduce in tp_attn_group on hidden_states
    2. Apply layer norm

    Moves hidden_states from ``hidden_states_input_mode`` to
    ``hidden_states_output_mode`` and the residual from ``residual_input_mode``
    to ``residual_output_mode``, applying ``layernorm`` along the way.
    Every selected function returns ``(hidden_states, residual)``.
    """

    @staticmethod
    def get_fn(
        hidden_states_input_mode: ScatterMode,
        residual_input_mode: ScatterMode,
        hidden_states_output_mode: ScatterMode,
        residual_output_mode: ScatterMode,
        context: CommunicateContext,
    ):
        """Select the implementation for the given combination of modes.

        Raises:
            NotImplementedError: if no implementation covers the combination.
        """
        # No layout change and a single-rank attention-TP group: layernorm only.
        if (
            context.is_same_group_size(
                hidden_states_input_mode, hidden_states_output_mode
            )
            and context.is_same_group_size(residual_input_mode, residual_output_mode)
            and context.attn_tp_size == 1
        ):
            return CommunicateWithAllReduceAndLayerNormFn._simple

        # Grow hidden_states to FULL (e.g. a dense/TP MLP).
        if (
            (hidden_states_input_mode == ScatterMode.TP_ATTN_FULL)
            and (
                residual_input_mode in [ScatterMode.SCATTERED, ScatterMode.TP_ATTN_FULL]
            )
            and (hidden_states_output_mode == ScatterMode.FULL)
            and (residual_output_mode == ScatterMode.TP_ATTN_FULL)
        ):
            return partial(
                CommunicateWithAllReduceAndLayerNormFn._gather_hidden_states_and_residual,
                residual_input_mode=residual_input_mode,
            )

        # Shrink both tensors to SCATTERED (MLP runs on per-rank shards).
        if (
            (hidden_states_input_mode == ScatterMode.TP_ATTN_FULL)
            and (
                residual_input_mode in [ScatterMode.SCATTERED, ScatterMode.TP_ATTN_FULL]
            )
            and (hidden_states_output_mode == ScatterMode.SCATTERED)
            and (residual_output_mode == ScatterMode.SCATTERED)
        ):
            return partial(
                CommunicateWithAllReduceAndLayerNormFn._scatter_hidden_states_and_residual,
                residual_input_mode=residual_input_mode,
            )

        raise NotImplementedError(
            f"{hidden_states_input_mode=} {residual_input_mode=} {hidden_states_output_mode=} {residual_output_mode=}"
        )

    @staticmethod
    def _simple(
        hidden_states: torch.Tensor,
        residual: torch.Tensor,
        forward_batch: ForwardBatch,
        layernorm: torch.nn.Module,
        context: CommunicateContext,
    ):
        """No communication; apply ``layernorm(hidden_states, residual)`` if non-empty."""
        # TODO move these `if shape != 0` into LayerNorm itself
        if hidden_states.shape[0] != 0:
            hidden_states, residual = layernorm(hidden_states, residual)
        return hidden_states, residual

    @staticmethod
    def _gather_hidden_states_and_residual(
        hidden_states: torch.Tensor,
        residual: torch.Tensor,
        forward_batch: ForwardBatch,
        layernorm: torch.nn.Module,
        context: CommunicateContext,
        *,
        residual_input_mode,
    ):
        """Reduce attention output, gather hidden_states to FULL and apply layernorm.

        The residual leaves in TP_ATTN_FULL mode.
        """
        # A SCATTERED residual is first all-gathered within the attention-TP
        # group into the local DP buffer, so that it matches hidden_states.
        if residual_input_mode == ScatterMode.SCATTERED and context.attn_tp_size > 1:
            residual, local_residual = (
                get_local_dp_buffer(),
                residual,
            )
            attn_tp_all_gather_into_tensor(residual, local_residual)
        if context.attn_dp_size != 1:
            # Residual is added on attention-TP rank 0 only. Presumably
            # dp_gather_partial sums contributions across the attention-TP group,
            # so the residual is counted exactly once — confirm.
            if context.attn_tp_rank == 0:
                hidden_states += residual

            # Perform layernorm on smaller data before comm. Only valid when attn_tp_size is 1 (tp_size == dp_size)
            use_layer_norm_before_gather = context.attn_tp_size == 1
            if use_layer_norm_before_gather and hidden_states.shape[0] != 0:
                # residual becomes the pre-norm sum; hidden_states the normed one.
                residual = hidden_states
                hidden_states = layernorm(hidden_states)

            hidden_states, local_hidden_states = (
                get_global_dp_buffer(),
                hidden_states,
            )
            dp_gather_partial(hidden_states, local_hidden_states, forward_batch)

            if not use_layer_norm_before_gather:
                # Presumably copies this rank's rows of the gathered (reduced)
                # sum back into residual — confirm against dp_scatter.
                dp_scatter(residual, hidden_states, forward_batch)
                if hidden_states.shape[0] != 0:
                    hidden_states = layernorm(hidden_states)
        else:
            # According to the discussion in https://github.com/flashinfer-ai/flashinfer/issues/1223#issuecomment-3047256465
            # We set the max token num to 128 for allreduce fusion with min-latency case(use_oneshot=True).
            # NOTE(review): the check below uses 4096 tokens, not 128; presumably the
            # 128 limit applies inside forward_with_allreduce_fusion's one-shot path — confirm.
            if (
                (_is_sm100_supported or _is_sm90_supported)
                and _is_flashinfer_available
                and hasattr(layernorm, "forward_with_allreduce_fusion")
                and global_server_args_dict["enable_flashinfer_allreduce_fusion"]
                and hidden_states.shape[0] <= 4096
            ):
                hidden_states, residual = layernorm.forward_with_allreduce_fusion(
                    hidden_states, residual
                )
            else:
                # Unfused fallback: TP all-reduce, then residual-add + layernorm.
                hidden_states = tensor_model_parallel_all_reduce(hidden_states)
                hidden_states, residual = layernorm(hidden_states, residual)
        return hidden_states, residual

    @staticmethod
    def _scatter_hidden_states_and_residual(
        hidden_states: torch.Tensor,
        residual: torch.Tensor,
        forward_batch: ForwardBatch,
        layernorm: torch.nn.Module,
        context: CommunicateContext,
        *,
        residual_input_mode,
    ):
        """Reduce-scatter hidden_states within the attention-TP group, slice the
        residual to the same shard, then apply layernorm."""
        input_hidden_states = hidden_states
        # The output buffer is this rank's view into the input tensor. Presumably
        # the collective supports this in-place layout — confirm.
        hidden_states = hidden_states.tensor_split(context.attn_tp_size)[
            context.attn_tp_rank
        ]
        attn_tp_reduce_scatter_tensor(hidden_states, input_hidden_states)
        if residual_input_mode == ScatterMode.TP_ATTN_FULL:
            residual = residual.tensor_split(context.attn_tp_size)[context.attn_tp_rank]
        if hidden_states.shape[0] != 0:
            hidden_states, residual = layernorm(hidden_states, residual)
        return hidden_states, residual
2025-05-20 11:14:48 +08:00
2025-05-27 11:24:17 +08:00
class CommunicateSummableTensorPairFn:
    """Moves a (hidden_states, residual) pair into the layer output mode.

    It is allowed to make (hidden_states, residual) := (hidden_states + residual, None) if needed.

    Every function returned by ``get_fn`` accepts the same keyword interface
    (hidden_states, residual, forward_batch, context, plus optional extras such
    as ``allow_reduce_scatter``). Implementations that do not use an extra
    keyword accept and ignore it via ``**kwargs``. LayerCommunicator always
    passes ``allow_reduce_scatter``, whichever function is selected.
    """

    @classmethod
    def execute(
        cls,
        hidden_states_input_mode,
        residual_input_mode,
        output_mode,
        context,
        **kwargs,
    ):
        """Select the function for the given modes and call it immediately."""
        return cls.get_fn(
            hidden_states_input_mode=hidden_states_input_mode,
            residual_input_mode=residual_input_mode,
            output_mode=output_mode,
            context=context,
        )(context=context, **kwargs)

    @staticmethod
    def get_fn(
        hidden_states_input_mode: ScatterMode,
        residual_input_mode: ScatterMode,
        output_mode: ScatterMode,
        context: CommunicateContext,
    ):
        """Select the implementation for the given modes.

        Raises:
            NotImplementedError: if no implementation covers the combination.
        """
        if context.is_same_group_size(
            hidden_states_input_mode, output_mode
        ) and context.is_same_group_size(residual_input_mode, output_mode):
            return CommunicateSummableTensorPairFn._trivial

        if (
            (hidden_states_input_mode == ScatterMode.FULL)
            and (residual_input_mode == ScatterMode.TP_ATTN_FULL)
            and (output_mode == ScatterMode.TP_ATTN_FULL)
        ):
            return CommunicateSummableTensorPairFn._scatter_hidden_states

        if (
            (hidden_states_input_mode == ScatterMode.SCATTERED)
            and (residual_input_mode == ScatterMode.SCATTERED)
            and (output_mode == ScatterMode.TP_ATTN_FULL)
        ):
            return CommunicateSummableTensorPairFn._gather

        if (
            (hidden_states_input_mode == ScatterMode.TP_ATTN_FULL)
            and (residual_input_mode == ScatterMode.TP_ATTN_FULL)
            and (output_mode == ScatterMode.SCATTERED)
        ):
            return CommunicateSummableTensorPairFn._scatter

        raise NotImplementedError(
            f"{hidden_states_input_mode=} {residual_input_mode=} {output_mode=}"
        )

    @staticmethod
    def _trivial(
        hidden_states: torch.Tensor,
        residual: torch.Tensor,
        forward_batch: ForwardBatch,
        context: CommunicateContext,
        **kwargs,
    ):
        # Layouts already match: nothing to move.
        return hidden_states, residual

    @staticmethod
    def _scatter_hidden_states(
        hidden_states: torch.Tensor,
        residual: torch.Tensor,
        forward_batch: ForwardBatch,
        context: CommunicateContext,
        allow_reduce_scatter: bool = False,
    ):
        """Shrink hidden_states from FULL back to this rank's local DP buffer."""
        hidden_states, global_hidden_states = (
            get_local_dp_buffer(),
            hidden_states,
        )
        if allow_reduce_scatter and forward_batch.dp_padding_mode.is_max_len():
            # When using padding, all_reduce is skipped after MLP and MOE and reduce scatter is used here instead.
            dp_reduce_scatter_tensor(hidden_states, global_hidden_states)
        else:
            dp_scatter(hidden_states, global_hidden_states, forward_batch)
        return hidden_states, residual

    @staticmethod
    def _gather(
        hidden_states: torch.Tensor,
        residual: torch.Tensor,
        forward_batch: ForwardBatch,
        context: CommunicateContext,
        **kwargs,
    ):
        """Fold the residual in, then all-gather within the attention-TP group."""
        hidden_states += residual
        residual = None
        hidden_states, local_hidden_states = (
            get_local_dp_buffer(),
            hidden_states,
        )
        attn_tp_all_gather_into_tensor(
            hidden_states,
            local_hidden_states,
        )
        return hidden_states, residual

    @staticmethod
    def _scatter(
        hidden_states: torch.Tensor,
        residual: torch.Tensor,
        forward_batch: ForwardBatch,
        context: CommunicateContext,
        # Accept and ignore extra keywords (e.g. allow_reduce_scatter) like the
        # sibling implementations; without this, LayerCommunicator.postprocess_layer
        # raised TypeError whenever this function was selected.
        **kwargs,
    ):
        """Keep only this rank's slice of hidden_states within the attention-TP group."""
        assert residual is None, "not yet handled residual!=None"
        hidden_states = hidden_states.tensor_split(context.attn_tp_size)[
            context.attn_tp_rank
        ]
        return hidden_states, residual