diff --git a/python/sglang/srt/layers/communicator.py b/python/sglang/srt/layers/communicator.py index 2e20c01bd..4ef752d75 100644 --- a/python/sglang/srt/layers/communicator.py +++ b/python/sglang/srt/layers/communicator.py @@ -27,6 +27,7 @@ from sglang.srt.layers.dp_attention import ( attn_tp_all_gather_into_tensor, attn_tp_reduce_scatter_tensor, dp_gather_partial, + dp_reduce_scatter_tensor, dp_scatter, get_attention_dp_size, get_attention_tp_rank, @@ -149,10 +150,13 @@ class LayerCommunicator: layer_scatter_modes: LayerScatterModes, input_layernorm: torch.nn.Module, post_attention_layernorm: torch.nn.Module, + # Reduce scatter requires skipping all-reduce in model code after MoE/MLP, so only enable for models which have that implemented. Remove flag once done for all models that use LayerCommunicator. + allow_reduce_scatter: bool = False, ): self.layer_scatter_modes = layer_scatter_modes self.input_layernorm = input_layernorm self.post_attention_layernorm = post_attention_layernorm + self.allow_reduce_scatter = allow_reduce_scatter self._context = CommunicateContext.init_new() self._communicate_simple_fn = CommunicateSimpleFn.get_fn( @@ -239,6 +243,15 @@ class LayerCommunicator: residual=residual, forward_batch=forward_batch, context=self._context, + allow_reduce_scatter=self.allow_reduce_scatter, + ) + + def should_use_reduce_scatter(self, forward_batch: ForwardBatch): + return ( + self.allow_reduce_scatter + and self._communicate_summable_tensor_pair_fn + is CommunicateSummableTensorPairFn._scatter_hidden_states + and forward_batch.dp_padding_mode.is_max_len() ) @@ -524,6 +537,7 @@ class CommunicateSummableTensorPairFn: residual: torch.Tensor, forward_batch: ForwardBatch, context: CommunicateContext, + **kwargs, ): return hidden_states, residual @@ -533,15 +547,17 @@ class CommunicateSummableTensorPairFn: residual: torch.Tensor, forward_batch: ForwardBatch, context: CommunicateContext, + allow_reduce_scatter: bool = False, ): - # TODO(ch-wan): use reduce-scatter in MLP to avoid this scatter - # important: forward batch.gathered_buffer is used both after scatter and after gather. - # be careful about this! hidden_states, global_hidden_states = ( forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]], hidden_states, ) - dp_scatter(hidden_states, global_hidden_states, forward_batch) + if allow_reduce_scatter and forward_batch.dp_padding_mode.is_max_len(): + # When using padding, all_reduce is skipped after MLP and MOE and reduce scatter is used here instead. + dp_reduce_scatter_tensor(hidden_states, global_hidden_states) + else: + dp_scatter(hidden_states, global_hidden_states, forward_batch) return hidden_states, residual @staticmethod @@ -550,6 +566,7 @@ class CommunicateSummableTensorPairFn: residual: torch.Tensor, forward_batch: ForwardBatch, context: CommunicateContext, + **kwargs, ): hidden_states += residual residual = None diff --git a/python/sglang/srt/layers/dp_attention.py b/python/sglang/srt/layers/dp_attention.py index 55db13336..79397cce5 100644 --- a/python/sglang/srt/layers/dp_attention.py +++ b/python/sglang/srt/layers/dp_attention.py @@ -12,6 +12,7 @@ import triton.language as tl from sglang.srt.distributed import ( GroupCoordinator, + get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, get_tp_group, tensor_model_parallel_all_reduce, @@ -355,6 +356,17 @@ def dp_scatter( ) +def dp_reduce_scatter_tensor(output: torch.Tensor, input: torch.Tensor): + if get_tensor_model_parallel_world_size() == get_attention_dp_size(): + get_tp_group().reduce_scatter_tensor(output, input) + else: + scattered_local_tokens = input.tensor_split( + get_tensor_model_parallel_world_size() + )[get_tensor_model_parallel_rank()] + get_tp_group().reduce_scatter_tensor(scattered_local_tokens, input) + get_attention_tp_group().all_gather_into_tensor(output, scattered_local_tokens) + + def attn_tp_reduce_scatter_tensor(output: torch.Tensor, input: torch.Tensor): return get_attention_tp_group().reduce_scatter_tensor(output, input) diff --git a/python/sglang/srt/layers/linear.py b/python/sglang/srt/layers/linear.py index 782699749..2a9dfda59 100644 --- a/python/sglang/srt/layers/linear.py +++ b/python/sglang/srt/layers/linear.py @@ -1277,7 +1277,7 @@ class RowParallelLinear(LinearBase): # It does not support additional parameters. param.load_row_parallel_weight(loaded_weight) - def forward(self, input_, can_fuse_mlp_allreduce=False): + def forward(self, input_, skip_all_reduce=False): if self.input_is_parallel: input_parallel = input_ else: @@ -1294,7 +1294,7 @@ class RowParallelLinear(LinearBase): with use_symmetric_memory(parallel_state.get_tp_group()) as sm: output_parallel = self.quant_method.apply(self, input_parallel, bias=bias_) sm.tag(output_parallel) - if self.reduce_results and self.tp_size > 1 and not can_fuse_mlp_allreduce: + if self.reduce_results and self.tp_size > 1 and not skip_all_reduce: output = tensor_model_parallel_all_reduce(output_parallel) else: output = output_parallel diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py index 6d09f1fdb..4c47f319d 100644 --- a/python/sglang/srt/model_executor/forward_batch_info.py +++ b/python/sglang/srt/model_executor/forward_batch_info.py @@ -628,8 +628,10 @@ class ForwardBatch: self.dp_padding_mode = dp_padding_mode if dp_padding_mode.is_max_len(): - # when DP gather mode is all gather, we will use all_gather_into_tensor to gather hidden states, - # where transferred tokens should be padded to the same length. + # when DP gather mode is all gather, we will use + # all_gather_into_tensor to gather hidden states, where transferred + # tokens should be padded to the same length. We will also use + # reduce-scatter instead of all-reduce after MLP. max_num_tokens = max(global_num_tokens) global_num_tokens = [max_num_tokens] * sync_group_size buffer_len = max_num_tokens * sync_group_size diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 913764b45..04acda746 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -208,13 +208,21 @@ class DeepseekV2MLP(nn.Module): ) self.act_fn = SiluAndMul() - def forward(self, x, forward_batch=None, can_fuse_mlp_allreduce=False): + def forward( + self, + x, + forward_batch=None, + can_fuse_mlp_allreduce: bool = False, + use_reduce_scatter: bool = False, + ): if (self.tp_size == 1) and x.shape[0] == 0: return x gate_up, _ = self.gate_up_proj(x) x = self.act_fn(gate_up) - x, _ = self.down_proj(x, can_fuse_mlp_allreduce=can_fuse_mlp_allreduce) + x, _ = self.down_proj( + x, skip_all_reduce=can_fuse_mlp_allreduce or use_reduce_scatter + ) return x @@ -441,6 +449,7 @@ class DeepseekV2MoE(nn.Module): hidden_states: torch.Tensor, forward_batch: Optional[ForwardBatch] = None, can_fuse_mlp_allreduce: bool = False, + use_reduce_scatter: bool = False, ) -> torch.Tensor: if not self._enable_deepep_moe: DUAL_STREAM_TOKEN_THRESHOLD = 1024 @@ -450,15 +459,20 @@ class DeepseekV2MoE(nn.Module): and hidden_states.shape[0] <= DUAL_STREAM_TOKEN_THRESHOLD ): return self.forward_normal_dual_stream( - hidden_states, can_fuse_mlp_allreduce + hidden_states, can_fuse_mlp_allreduce, use_reduce_scatter ) else: - return self.forward_normal(hidden_states, can_fuse_mlp_allreduce) + return self.forward_normal( + hidden_states, can_fuse_mlp_allreduce, use_reduce_scatter + ) else: return self.forward_deepep(hidden_states, forward_batch) def forward_normal_dual_stream( - self, hidden_states: torch.Tensor, can_fuse_mlp_allreduce: bool = False + self, + hidden_states: torch.Tensor, + can_fuse_mlp_allreduce: bool = False, + use_reduce_scatter: bool = False, ) -> torch.Tensor: current_stream = torch.cuda.current_stream() @@ -486,12 +500,15 @@ class DeepseekV2MoE(nn.Module): torch.add(final_hidden_states, shared_output, out=final_hidden_states_out) final_hidden_states = final_hidden_states_out sm.tag(final_hidden_states) - if self.tp_size > 1 and not can_fuse_mlp_allreduce: + if self.tp_size > 1 and not can_fuse_mlp_allreduce and not use_reduce_scatter: final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states) return final_hidden_states def forward_normal( - self, hidden_states: torch.Tensor, can_fuse_mlp_allreduce: bool = False + self, + hidden_states: torch.Tensor, + can_fuse_mlp_allreduce: bool = False, + use_reduce_scatter: bool = False, ) -> torch.Tensor: if hasattr(self, "shared_experts") and use_intel_amx_backend( self.shared_experts.gate_up_proj @@ -520,7 +537,7 @@ class DeepseekV2MoE(nn.Module): torch.add(final_hidden_states, shared_output, out=final_hidden_states_out) final_hidden_states = final_hidden_states_out sm.tag(final_hidden_states) - if self.tp_size > 1 and not can_fuse_mlp_allreduce: + if self.tp_size > 1 and not can_fuse_mlp_allreduce and not use_reduce_scatter: final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states) return final_hidden_states @@ -1822,6 +1839,7 @@ class DeepseekV2DecoderLayer(nn.Module): layer_scatter_modes=self.layer_scatter_modes, input_layernorm=self.input_layernorm, post_attention_layernorm=self.post_attention_layernorm, + allow_reduce_scatter=True, ) def _is_layer_sparse(self, layer_id: int, is_nextn: bool) -> bool: @@ -1884,7 +1902,13 @@ class DeepseekV2DecoderLayer(nn.Module): and not self.is_nextn ) - hidden_states = self.mlp(hidden_states, forward_batch, can_fuse_mlp_allreduce) + # For DP with padding, reduce scatter can be used instead of all-reduce. + use_reduce_scatter = self.layer_communicator.should_use_reduce_scatter( + forward_batch + ) + hidden_states = self.mlp( + hidden_states, forward_batch, can_fuse_mlp_allreduce, use_reduce_scatter + ) if can_fuse_mlp_allreduce: hidden_states._sglang_needs_allreduce_fusion = True diff --git a/python/sglang/srt/models/glm4_moe.py b/python/sglang/srt/models/glm4_moe.py index 32cf01362..4744c0c31 100644 --- a/python/sglang/srt/models/glm4_moe.py +++ b/python/sglang/srt/models/glm4_moe.py @@ -160,7 +160,7 @@ class Glm4MoeMLP(nn.Module): gate_up, _ = self.gate_up_proj(x) x = self.act_fn(gate_up) - x, _ = self.down_proj(x, can_fuse_mlp_allreduce=can_fuse_mlp_allreduce) + x, _ = self.down_proj(x, skip_all_reduce=can_fuse_mlp_allreduce) return x