diff --git a/python/sglang/srt/layers/communicator.py b/python/sglang/srt/layers/communicator.py
index fb3021706..5dad481cd 100644
--- a/python/sglang/srt/layers/communicator.py
+++ b/python/sglang/srt/layers/communicator.py
@@ -448,6 +448,13 @@ class CommunicateSummableTensorPairFn:
         ):
             return CommunicateSummableTensorPairFn._gather
 
+        if (
+            (hidden_states_input_mode == ScatterMode.TP_ATTN_FULL)
+            and (residual_input_mode == ScatterMode.TP_ATTN_FULL)
+            and (output_mode == ScatterMode.SCATTERED)
+        ):
+            return CommunicateSummableTensorPairFn._scatter
+
         raise NotImplementedError(
             f"{hidden_states_input_mode=} {residual_input_mode=} {output_mode=}"
         )
@@ -496,3 +503,15 @@ class CommunicateSummableTensorPairFn:
             local_hidden_states,
         )
         return hidden_states, residual
+
+    @staticmethod
+    def _scatter(
+        hidden_states: torch.Tensor,
+        residual: torch.Tensor,
+        forward_batch: ForwardBatch,
+        context: CommunicateContext,
+    ):
+        assert residual is None, "not yet handled residual!=None"
+        tensor_list = list(hidden_states.tensor_split(context.attn_tp_size))
+        hidden_states = tensor_list[context.attn_tp_rank]
+        return hidden_states, residual
diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py
index 330ad5133..fa492277a 100644
--- a/python/sglang/srt/models/deepseek_v2.py
+++ b/python/sglang/srt/models/deepseek_v2.py
@@ -1613,6 +1613,9 @@ class DeepseekV2Model(nn.Module):
                 forward_batch=forward_batch,
                 hidden_states=hidden_states,
                 residual=residual,
+                input_data_scatter_mode=self.layers[
+                    normal_num_layers - 1
+                ].layer_scatter_modes.layer_output_mode,
                 zero_allocator=zero_allocator,
             )
 
diff --git a/python/sglang/srt/models/qwen2_moe.py b/python/sglang/srt/models/qwen2_moe.py
index fe6b00685..203fc0e82 100644
--- a/python/sglang/srt/models/qwen2_moe.py
+++ b/python/sglang/srt/models/qwen2_moe.py
@@ -32,7 +32,11 @@ from sglang.srt.distributed import (
     tensor_model_parallel_all_reduce,
 )
 from sglang.srt.layers.activation import SiluAndMul
-from sglang.srt.layers.communicator import LayerCommunicator, LayerScatterModes
+from sglang.srt.layers.communicator import (
+    LayerCommunicator,
+    LayerScatterModes,
+    ScatterMode,
+)
 from sglang.srt.layers.dp_attention import (
     attn_tp_all_gather,
     attn_tp_reduce_scatter,
@@ -447,6 +451,7 @@ class Qwen2MoeModel(nn.Module):
             hidden_states, residual = model_forward_maybe_tbo(
                 layers=self.layers,
                 enable_tbo=True,
+                input_data_scatter_mode=ScatterMode.model_input_output(),
                 positions=positions,
                 forward_batch=forward_batch,
                 hidden_states=hidden_states,
diff --git a/python/sglang/srt/two_batch_overlap.py b/python/sglang/srt/two_batch_overlap.py
index afdb5fce0..bb527aaa6 100644
--- a/python/sglang/srt/two_batch_overlap.py
+++ b/python/sglang/srt/two_batch_overlap.py
@@ -5,6 +5,12 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Sequence
 import torch
 
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
+from sglang.srt.layers.communicator import (
+    CommunicateContext,
+    CommunicateSimpleFn,
+    CommunicateSummableTensorPairFn,
+    ScatterMode,
+)
 from sglang.srt.layers.dp_attention import get_attention_tp_size
 from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
 from sglang.srt.layers.quantization.deep_gemm import configure_deep_gemm_num_sms
@@ -355,6 +361,7 @@ def model_forward_maybe_tbo(
     positions: torch.Tensor,
     forward_batch: ForwardBatch,
     hidden_states: torch.Tensor,
+    input_data_scatter_mode: ScatterMode,
     residual: Optional[torch.Tensor],
     zero_allocator: Optional[BumpAllocator] = None,
 ):
@@ -365,20 +372,32 @@ def model_forward_maybe_tbo(
         residual=residual,
         **(dict(zero_allocator=zero_allocator) if zero_allocator is not None else {}),
     )
+    layer_input_scatter_mode = layers[0].layer_scatter_modes.layer_input_mode
     operations_strategy = OperationsStrategy.init_new_tbo(
         layers, forward_batch.global_forward_mode
     )
     if enable_tbo:
-        return _model_forward_tbo(inputs, operations_strategy)
+        return _model_forward_tbo(
+            inputs=inputs,
+            operations_strategy=operations_strategy,
+            input_data_scatter_mode=input_data_scatter_mode,
+            layer_input_scatter_mode=layer_input_scatter_mode,
+        )
     else:
         return _model_forward_non_tbo(inputs, operations_strategy)
 
 
-def _model_forward_tbo(inputs, operations_strategy: OperationsStrategy):
-    # The attn_tp_size!=1 case is not yet extracted to master
-    assert get_attention_tp_size() == 1
-
-    inputs_arr = _model_forward_tbo_split_inputs(**inputs)
+def _model_forward_tbo(
+    inputs,
+    operations_strategy: OperationsStrategy,
+    input_data_scatter_mode: ScatterMode,
+    layer_input_scatter_mode: ScatterMode,
+):
+    inputs_arr = _model_forward_tbo_split_inputs(
+        **inputs,
+        input_data_scatter_mode=input_data_scatter_mode,
+        layer_input_scatter_mode=layer_input_scatter_mode,
+    )
     del inputs
 
     with configure_deep_gemm_num_sms(operations_strategy.deep_gemm_num_sms):
@@ -401,7 +420,57 @@ def _model_forward_tbo_split_inputs(
     residual: torch.Tensor,
     positions: torch.Tensor,
     forward_batch: ForwardBatch,
-    zero_allocator: Optional[BumpAllocator] = None,
+    zero_allocator: Optional[BumpAllocator],
+    input_data_scatter_mode: ScatterMode,
+    layer_input_scatter_mode: ScatterMode,
+) -> List[Dict]:
+    tbo_splitter_scatter_mode = ScatterMode.TP_ATTN_FULL
+    context = CommunicateContext.init_new()
+
+    hidden_states, residual = CommunicateSummableTensorPairFn.execute(
+        hidden_states_input_mode=input_data_scatter_mode,
+        residual_input_mode=input_data_scatter_mode,
+        output_mode=tbo_splitter_scatter_mode,
+        hidden_states=hidden_states,
+        residual=residual,
+        forward_batch=forward_batch,
+        context=context,
+    )
+
+    inputs_arr = _model_forward_tbo_split_inputs_raw(
+        hidden_states=hidden_states,
+        residual=residual,
+        positions=positions,
+        forward_batch=forward_batch,
+        zero_allocator=zero_allocator,
+    )
+
+    def _post_transform(hidden_states, residual, forward_batch, **kwargs):
+        hidden_states, residual = CommunicateSummableTensorPairFn.execute(
+            hidden_states_input_mode=tbo_splitter_scatter_mode,
+            residual_input_mode=tbo_splitter_scatter_mode,
+            output_mode=layer_input_scatter_mode,
+            hidden_states=hidden_states,
+            residual=residual,
+            forward_batch=forward_batch,
+            context=context,
+        )
+        return dict(
+            hidden_states=hidden_states,
+            residual=residual,
+            forward_batch=forward_batch,
+            **kwargs,
+        )
+
+    return [_post_transform(**inputs) for inputs in inputs_arr]
+
+
+def _model_forward_tbo_split_inputs_raw(
+    hidden_states: torch.Tensor,
+    residual: torch.Tensor,
+    positions: torch.Tensor,
+    forward_batch: ForwardBatch,
+    zero_allocator: Optional[BumpAllocator],
 ) -> List[Dict]:
     return [
         dict(
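
Reviewer note on the new CommunicateSummableTensorPairFn._scatter path: it is the inverse of the existing _gather branch. A minimal standalone sketch in plain PyTorch (the TP size, rank, and tensor shape below are made-up illustration values, not taken from this patch):

import torch

# Hypothetical attention-TP group of 4 ranks; pretend this process is rank 1.
attn_tp_size, attn_tp_rank = 4, 1
hidden_states = torch.randn(10, 8)  # TP_ATTN_FULL: all 10 tokens, hidden size 8

# tensor_split tolerates token counts not divisible by the TP size:
# the shard sizes along dim 0 here are 3, 3, 2, 2.
shards = list(hidden_states.tensor_split(attn_tp_size))
local_hidden_states = shards[attn_tp_rank]  # SCATTERED: this rank keeps 3 tokens
assert local_hidden_states.shape == (3, 8)

# Concatenating all shards (what an all-gather over the attn-TP group does)
# recovers the TP_ATTN_FULL tensor, so _scatter round-trips with _gather.
assert torch.equal(torch.cat(shards, dim=0), hidden_states)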
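
The ordering inside _model_forward_tbo_split_inputs also matters: inputs are first normalized to TP_ATTN_FULL (every rank sees the full token dimension), then split into the two TBO micro-batches, and only then does _post_transform move each micro-batch to the mode the first layer expects (layers[0].layer_scatter_modes.layer_input_mode). A toy sketch of that sequence, with an invented split index and shapes:

import torch

attn_tp_size = 2
hidden_states = torch.arange(12.0).reshape(6, 2)  # TP_ATTN_FULL: 6 tokens
tbo_split_index = 4  # first 4 tokens form micro-batch 0, the rest micro-batch 1

# Step 1: split the FULL token dimension into the two TBO micro-batches.
micro_batches = [hidden_states[:tbo_split_index], hidden_states[tbo_split_index:]]

# Step 2 (the _post_transform step): scatter each micro-batch across the
# attn-TP ranks, via the same tensor_split mechanism as _scatter above.
def scatter(x: torch.Tensor, rank: int) -> torch.Tensor:
    return list(x.tensor_split(attn_tp_size))[rank]

per_rank = [scatter(micro_batches[0], r) for r in range(attn_tp_size)]
assert torch.equal(torch.cat(per_rank, dim=0), micro_batches[0])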