diff --git a/python/sglang/srt/layers/communicator.py b/python/sglang/srt/layers/communicator.py
index 362fe3ba3..fb3021706 100644
--- a/python/sglang/srt/layers/communicator.py
+++ b/python/sglang/srt/layers/communicator.py
@@ -37,10 +37,23 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 
 
 class ScatterMode(Enum):
+    """
+    Suppose we have TP=4, DP=2, enable-dp-attention, and the system handles seq a,b,c,d
+    Model input/output: [ab, ab, cd, cd] for four ranks respectively
+    SCATTERED: [a, b, c, d]
+    TP_ATTN_FULL: [ab, ab, cd, cd], i.e. all ranks inside a TP attn group have full data of the group
+    FULL: [abcd, abcd, abcd, abcd]
+    """
+
     SCATTERED = auto()
     TP_ATTN_FULL = auto()
     FULL = auto()
 
+    @staticmethod
+    def model_input_output():
+        """The scatter mode for model forward pass input and output data"""
+        return ScatterMode.TP_ATTN_FULL
+
 
 @dataclass
 class _LayerModeComputationContext:
@@ -82,7 +95,7 @@ class LayerScatterModes:
     @classmethod
     def _compute_layer_input_mode(cls, context: _LayerModeComputationContext):
         if context.layer_id == 0:
-            return ScatterMode.TP_ATTN_FULL
+            return ScatterMode.model_input_output()
         return cls._compute_layer_output_mode(context.previous_layer())
 
     @classmethod
@@ -113,7 +126,7 @@ class LayerScatterModes:
     def _compute_layer_output_mode(cls, context: _LayerModeComputationContext):
         mlp_mode = cls._compute_mlp_mode(context)
         if context.layer_id == context.num_layers - 1:
-            return ScatterMode.TP_ATTN_FULL
+            return ScatterMode.model_input_output()
         if mlp_mode == ScatterMode.SCATTERED:
             return ScatterMode.SCATTERED
         if mlp_mode == ScatterMode.FULL:
@@ -136,30 +149,14 @@ class LayerCommunicator:
         self.input_layernorm = input_layernorm
         self.post_attention_layernorm = post_attention_layernorm
 
-        self.attn_tp_rank = get_attention_tp_rank()
-        self.attn_tp_size = get_attention_tp_size()
-        self.local_attn_dp_size = get_local_attention_dp_size()
-        self.tp_size = get_tensor_model_parallel_world_size()
-        self.process_group_sizes = {
-            ScatterMode.SCATTERED: 1,
-            ScatterMode.TP_ATTN_FULL: self.attn_tp_size,
-            ScatterMode.FULL: self.tp_size,
-        }
-
-        self._context = _Context(
-            process_group_sizes=self.process_group_sizes,
-            attn_tp_rank=self.attn_tp_rank,
-            attn_tp_size=self.attn_tp_size,
-            local_attn_dp_size=self.local_attn_dp_size,
-            tp_size=self.tp_size,
-        )
-        self._communicate_simple_fn = _CommunicateSimpleFn.get_fn(
+        self._context = CommunicateContext.init_new()
+        self._communicate_simple_fn = CommunicateSimpleFn.get_fn(
             input_mode=self.layer_scatter_modes.layer_input_mode,
             output_mode=self.layer_scatter_modes.attn_mode,
             context=self._context,
         )
         self._communicate_with_all_reduce_and_layer_norm_fn = (
-            _CommunicateWithAllReduceAndLayerNormFn.get_fn(
+            CommunicateWithAllReduceAndLayerNormFn.get_fn(
                 hidden_states_input_mode=self.layer_scatter_modes.attn_mode,
                 residual_input_mode=self.layer_scatter_modes.layer_input_mode,
                 hidden_states_output_mode=self.layer_scatter_modes.mlp_mode,
@@ -168,7 +165,7 @@ class LayerCommunicator:
             )
         )
         self._communicate_summable_tensor_pair_fn = (
-            _CommunicateSummableTensorPairFn.get_fn(
+            CommunicateSummableTensorPairFn.get_fn(
                 hidden_states_input_mode=self.layer_scatter_modes.mlp_mode,
                 residual_input_mode=self.layer_scatter_modes.middle_residual_mode,
                 output_mode=self.layer_scatter_modes.layer_output_mode,
@@ -228,7 +225,7 @@ class LayerCommunicator:
 
 
 @dataclass
-class _Context:
+class CommunicateContext:
     process_group_sizes: Dict["ScatterMode", int]
     attn_tp_rank: int
     attn_tp_size: int
@@ -238,21 +235,40 @@
     def is_same_group_size(self, a: "ScatterMode", b: "ScatterMode"):
         return self.process_group_sizes[a] == self.process_group_sizes[b]
 
+    @classmethod
+    def init_new(cls):
+        attn_tp_rank = get_attention_tp_rank()
+        attn_tp_size = get_attention_tp_size()
+        local_attn_dp_size = get_local_attention_dp_size()
+        tp_size = get_tensor_model_parallel_world_size()
+        process_group_sizes = {
+            ScatterMode.SCATTERED: 1,
+            ScatterMode.TP_ATTN_FULL: attn_tp_size,
+            ScatterMode.FULL: tp_size,
+        }
+        return cls(
+            process_group_sizes=process_group_sizes,
+            attn_tp_rank=attn_tp_rank,
+            attn_tp_size=attn_tp_size,
+            local_attn_dp_size=local_attn_dp_size,
+            tp_size=tp_size,
+        )
 
-class _CommunicateSimpleFn:
+
+class CommunicateSimpleFn:
     @staticmethod
     def get_fn(
         input_mode: ScatterMode,
         output_mode: ScatterMode,
-        context: _Context,
+        context: CommunicateContext,
     ):
         if context.is_same_group_size(input_mode, output_mode):
-            return _CommunicateSimpleFn._trivial
+            return CommunicateSimpleFn._trivial
 
         if (input_mode == ScatterMode.SCATTERED) and (
            output_mode == ScatterMode.TP_ATTN_FULL
        ):
-            return _CommunicateSimpleFn._scattered_to_tp_attn_full
+            return CommunicateSimpleFn._scattered_to_tp_attn_full
 
         raise NotImplementedError(f"{input_mode=} {output_mode=}")
 
@@ -260,7 +276,7 @@ class _CommunicateSimpleFn:
     def _trivial(
         hidden_states: torch.Tensor,
         forward_batch: ForwardBatch,
-        context: _Context,
+        context: CommunicateContext,
     ) -> torch.Tensor:
         return hidden_states
 
@@ -268,7 +284,7 @@ class _CommunicateSimpleFn:
     def _scattered_to_tp_attn_full(
         hidden_states: torch.Tensor,
         forward_batch: ForwardBatch,
-        context: _Context,
+        context: CommunicateContext,
     ) -> torch.Tensor:
         hidden_states, local_hidden_states = (
             forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
@@ -281,7 +297,7 @@ class _CommunicateSimpleFn:
         return hidden_states
 
 
-class _CommunicateWithAllReduceAndLayerNormFn:
+class CommunicateWithAllReduceAndLayerNormFn:
     """Besides communication, needs to
     1. All reduce in tp_attn_group on hidden_states
     2. Apply layer norm
@@ -293,7 +309,7 @@ class _CommunicateWithAllReduceAndLayerNormFn:
         residual_input_mode: ScatterMode,
         hidden_states_output_mode: ScatterMode,
         residual_output_mode: ScatterMode,
-        context: _Context,
+        context: CommunicateContext,
     ):
 
         if (
@@ -303,7 +319,7 @@ class _CommunicateWithAllReduceAndLayerNormFn:
             and context.is_same_group_size(residual_input_mode, residual_output_mode)
             and context.attn_tp_size == 1
         ):
-            return _CommunicateWithAllReduceAndLayerNormFn._simple
+            return CommunicateWithAllReduceAndLayerNormFn._simple
 
         if (
             (hidden_states_input_mode == ScatterMode.TP_ATTN_FULL)
@@ -311,7 +327,7 @@ class _CommunicateWithAllReduceAndLayerNormFn:
             and (hidden_states_output_mode == ScatterMode.FULL)
             and (residual_output_mode == ScatterMode.TP_ATTN_FULL)
         ):
-            return _CommunicateWithAllReduceAndLayerNormFn._gather_hidden_states
+            return CommunicateWithAllReduceAndLayerNormFn._gather_hidden_states
 
         if (
             (hidden_states_input_mode == ScatterMode.TP_ATTN_FULL)
@@ -322,7 +338,7 @@ class _CommunicateWithAllReduceAndLayerNormFn:
             and (residual_output_mode == ScatterMode.SCATTERED)
         ):
             return partial(
-                _CommunicateWithAllReduceAndLayerNormFn._scatter_hidden_states_and_residual,
+                CommunicateWithAllReduceAndLayerNormFn._scatter_hidden_states_and_residual,
                 residual_input_mode=residual_input_mode,
             )
 
@@ -336,7 +352,7 @@ class _CommunicateWithAllReduceAndLayerNormFn:
         residual: torch.Tensor,
         forward_batch: ForwardBatch,
         layernorm: torch.nn.Module,
-        context: _Context,
+        context: CommunicateContext,
     ):
         # TODO move these `if shape != 0` into LayerNorm itself
         if hidden_states.shape[0] != 0:
@@ -349,7 +365,7 @@ class _CommunicateWithAllReduceAndLayerNormFn:
         residual: torch.Tensor,
         forward_batch: ForwardBatch,
         layernorm: torch.nn.Module,
-        context: _Context,
+        context: CommunicateContext,
     ):
         if context.local_attn_dp_size != 1:
             if context.attn_tp_rank == 0:
@@ -373,7 +389,7 @@ class _CommunicateWithAllReduceAndLayerNormFn:
         residual: torch.Tensor,
         forward_batch: ForwardBatch,
         layernorm: torch.nn.Module,
-        context: _Context,
+        context: CommunicateContext,
         *,
         residual_input_mode,
     ):
@@ -387,35 +403,50 @@ class _CommunicateWithAllReduceAndLayerNormFn:
         return hidden_states, residual
 
 
-class _CommunicateSummableTensorPairFn:
+class CommunicateSummableTensorPairFn:
+    """It is allowed to make (hidden_states, residual) := (hidden_states + residual, None) if needed."""
+
+    @classmethod
+    def execute(
+        cls,
+        hidden_states_input_mode,
+        residual_input_mode,
+        output_mode,
+        context,
+        **kwargs,
+    ):
+        return cls.get_fn(
+            hidden_states_input_mode=hidden_states_input_mode,
+            residual_input_mode=residual_input_mode,
+            output_mode=output_mode,
+            context=context,
+        )(context=context, **kwargs)
+
     @staticmethod
     def get_fn(
         hidden_states_input_mode: ScatterMode,
         residual_input_mode: ScatterMode,
         output_mode: ScatterMode,
-        context: _Context,
+        context: CommunicateContext,
     ):
-        """It is allowed to make (hidden_states, residual) := (hidden_states + residual, None) if needed."""
-
         if context.is_same_group_size(
             hidden_states_input_mode, output_mode
         ) and context.is_same_group_size(residual_input_mode, output_mode):
-            return _CommunicateSummableTensorPairFn._trivial
+            return CommunicateSummableTensorPairFn._trivial
 
         if (
             (hidden_states_input_mode == ScatterMode.FULL)
             and (residual_input_mode == ScatterMode.TP_ATTN_FULL)
             and (output_mode == ScatterMode.TP_ATTN_FULL)
         ):
-            return _CommunicateSummableTensorPairFn._scatter_hidden_states
+            return CommunicateSummableTensorPairFn._scatter_hidden_states
 
         if (
             (hidden_states_input_mode == ScatterMode.SCATTERED)
             and (residual_input_mode == ScatterMode.SCATTERED)
             and (output_mode == ScatterMode.TP_ATTN_FULL)
         ):
-            return _CommunicateSummableTensorPairFn._gather
+            return CommunicateSummableTensorPairFn._gather
 
         raise NotImplementedError(
             f"{hidden_states_input_mode=} {residual_input_mode=} {output_mode=}"
         )
@@ -426,7 +457,7 @@ class _CommunicateSummableTensorPairFn:
         hidden_states: torch.Tensor,
         residual: torch.Tensor,
         forward_batch: ForwardBatch,
-        context: _Context,
+        context: CommunicateContext,
     ):
         return hidden_states, residual
 
@@ -435,7 +466,7 @@ class _CommunicateSummableTensorPairFn:
         hidden_states: torch.Tensor,
         residual: torch.Tensor,
         forward_batch: ForwardBatch,
-        context: _Context,
+        context: CommunicateContext,
     ):
         # TODO(ch-wan): use reduce-scatter in MLP to avoid this scatter
         # important: forward batch.gathered_buffer is used both after scatter and after gather.
@@ -452,7 +483,7 @@ class _CommunicateSummableTensorPairFn:
         hidden_states: torch.Tensor,
         residual: torch.Tensor,
         forward_batch: ForwardBatch,
-        context: _Context,
+        context: CommunicateContext,
     ):
         hidden_states += residual
         residual = None
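Reviewer note (not part of the patch): the new ScatterMode docstring can be read as a small table of per-rank layouts. Below is a minimal Python sketch of the docstring's TP=4, DP=2 example; the names `layouts` and the inline comments are illustrative only, not part of the sglang API.

# Per-rank data layouts for the docstring's TP=4, DP=2, dp-attention example,
# where the four ranks jointly handle sequences a, b, c, d.
layouts = {
    "SCATTERED": ["a", "b", "c", "d"],         # process group size 1
    "TP_ATTN_FULL": ["ab", "ab", "cd", "cd"],  # process group size attn_tp_size = 2
    "FULL": ["abcd", "abcd", "abcd", "abcd"],  # process group size tp_size = 4
}

# These group sizes are exactly what CommunicateContext.init_new() records in
# process_group_sizes, and what is_same_group_size() compares: two modes with
# equal group sizes need no communication (the _trivial branch in get_fn).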
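Reviewer note: the added CommunicateSummableTensorPairFn.execute() classmethod is sugar that resolves get_fn() and immediately calls the returned function. A hedged usage sketch, assuming an initialized distributed run; the tensors and `forward_batch` are placeholders for runtime state, and the mode choice mirrors the _scatter_hidden_states branch in the patch:

hidden_states, residual = CommunicateSummableTensorPairFn.execute(
    hidden_states_input_mode=ScatterMode.FULL,
    residual_input_mode=ScatterMode.TP_ATTN_FULL,
    output_mode=ScatterMode.TP_ATTN_FULL,
    context=context,                # a CommunicateContext from init_new()
    hidden_states=hidden_states,    # torch.Tensor produced by the MLP
    residual=residual,              # torch.Tensor residual stream
    forward_batch=forward_batch,    # the current ForwardBatch
)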