From 4c22897a66ab6f222f60c7aed999fe047598e3c3 Mon Sep 17 00:00:00 2001
From: wxzhoucs <66296518+Misaka9468@users.noreply.github.com>
Date: Thu, 14 Aug 2025 12:10:29 +0800
Subject: [PATCH] Feature: support qwen and llama4 reducescatter for dp attention padding (#9101)

---
 python/sglang/srt/lora/layers.py      |  8 ++++++--
 python/sglang/srt/models/llama.py     | 12 ++++++++++--
 python/sglang/srt/models/llama4.py    | 19 ++++++++++++++++---
 python/sglang/srt/models/qwen2_moe.py | 22 ++++++++++++++++++----
 python/sglang/srt/models/qwen3_moe.py | 23 ++++++++++++++++++-----
 5 files changed, 68 insertions(+), 16 deletions(-)

diff --git a/python/sglang/srt/lora/layers.py b/python/sglang/srt/lora/layers.py
index 4328a7601..f9a877cd5 100644
--- a/python/sglang/srt/lora/layers.py
+++ b/python/sglang/srt/lora/layers.py
@@ -253,7 +253,7 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
         )
         return lora_output
 
-    def forward(self, input_: torch.Tensor):
+    def forward(self, input_: torch.Tensor, skip_all_reduce=False):
         # duplicate the logic in RowParallelLinear
         if self.base_layer.input_is_parallel:
             input_parallel = input_
@@ -270,7 +270,11 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
         if self.set_lora:
             output_parallel = self.apply_lora(output_parallel, input_parallel)
 
-        if self.base_layer.reduce_results and self.base_layer.tp_size > 1:
+        if (
+            self.base_layer.reduce_results
+            and self.base_layer.tp_size > 1
+            and not skip_all_reduce
+        ):
             output_ = tensor_model_parallel_all_reduce(output_parallel)
         else:
             output_ = output_parallel
diff --git a/python/sglang/srt/models/llama.py b/python/sglang/srt/models/llama.py
index 4efbc48fd..fc0ce930a 100644
--- a/python/sglang/srt/models/llama.py
+++ b/python/sglang/srt/models/llama.py
@@ -91,10 +91,18 @@ class LlamaMLP(nn.Module):
             )
         self.act_fn = SiluAndMul()
 
-    def forward(self, x, forward_batch=None):
+    def forward(
+        self,
+        x,
+        forward_batch=None,
+        use_reduce_scatter: bool = False,
+    ):
         gate_up, _ = self.gate_up_proj(x)
         x = self.act_fn(gate_up)
-        x, _ = self.down_proj(x)
+        x, _ = self.down_proj(
+            x,
+            skip_all_reduce=use_reduce_scatter,
+        )
         return x
 
 
diff --git a/python/sglang/srt/models/llama4.py b/python/sglang/srt/models/llama4.py
index c0a2be43d..cf851bd1e 100644
--- a/python/sglang/srt/models/llama4.py
+++ b/python/sglang/srt/models/llama4.py
@@ -131,14 +131,19 @@ class Llama4MoE(nn.Module):
             reduce_results=False,  # We need to do scatter before reduce
         )
 
-    def forward(self, hidden_states, forward_batch: ForwardBatch):
+    def forward(
+        self,
+        hidden_states,
+        forward_batch: ForwardBatch,
+        use_reduce_scatter: bool = False,
+    ):
         shared_out, routed_out = self._forward_core(
             hidden_states, forward_batch.forward_mode
         )
 
         out_aD = routed_out + shared_out
 
-        if self.tp_size > 1:
+        if self.tp_size > 1 and not use_reduce_scatter:
             out_aD = tensor_model_parallel_all_reduce(out_aD)
 
         return out_aD
@@ -412,6 +417,7 @@ class Llama4DecoderLayer(nn.Module):
             layer_scatter_modes=self.layer_scatter_modes,
             input_layernorm=self.input_layernorm,
             post_attention_layernorm=self.post_attention_layernorm,
+            allow_reduce_scatter=True,
         )
 
     def _is_moe_layer(self, layer_id: int) -> bool:
@@ -441,8 +447,15 @@ class Llama4DecoderLayer(nn.Module):
             hidden_states, residual, forward_batch
         )
 
+        # For DP with padding, reduce scatter can be used instead of all-reduce.
+        use_reduce_scatter = self.layer_communicator.should_use_reduce_scatter(
+            forward_batch
+        )
+
         # Fully Connected
-        hidden_states = self.feed_forward(hidden_states, forward_batch)
+        hidden_states = self.feed_forward(
+            hidden_states, forward_batch, use_reduce_scatter
+        )
         hidden_states, residual = self.layer_communicator.postprocess_layer(
             hidden_states, residual, forward_batch
         )
diff --git a/python/sglang/srt/models/qwen2_moe.py b/python/sglang/srt/models/qwen2_moe.py
index da7936c4d..81cd97c0e 100644
--- a/python/sglang/srt/models/qwen2_moe.py
+++ b/python/sglang/srt/models/qwen2_moe.py
@@ -108,10 +108,14 @@ class Qwen2MoeMLP(nn.Module):
             )
         self.act_fn = SiluAndMul()
 
-    def forward(self, x):
+    def forward(
+        self,
+        x,
+        use_reduce_scatter: bool = False,
+    ):
         gate_up, _ = self.gate_up_proj(x)
         x = self.act_fn(gate_up)
-        x, _ = self.down_proj(x)
+        x, _ = self.down_proj(x, skip_all_reduce=use_reduce_scatter)
         return x
 
 
@@ -176,7 +180,10 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
         self.shared_expert_gate = torch.nn.Linear(config.hidden_size, 1, bias=False)
 
     def forward(
-        self, hidden_states: torch.Tensor, forward_batch: Optional[ForwardBatch] = None
+        self,
+        hidden_states: torch.Tensor,
+        forward_batch: Optional[ForwardBatch] = None,
+        use_reduce_scatter: bool = False,
     ) -> torch.Tensor:
         num_tokens, hidden_dim = hidden_states.shape
         hidden_states = hidden_states.view(-1, hidden_dim)
@@ -194,6 +201,7 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
         final_hidden_states = self.experts(hidden_states, topk_output)
         if shared_output is not None:
             final_hidden_states = final_hidden_states + shared_output
+        if self.tp_size > 1 and not use_reduce_scatter:
             final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
 
         return final_hidden_states.view(num_tokens, hidden_dim)
@@ -368,6 +376,7 @@ class Qwen2MoeDecoderLayer(nn.Module):
             layer_scatter_modes=self.layer_scatter_modes,
             input_layernorm=self.input_layernorm,
             post_attention_layernorm=self.post_attention_layernorm,
+            allow_reduce_scatter=True,
         )
 
     def forward(
@@ -393,7 +402,12 @@ class Qwen2MoeDecoderLayer(nn.Module):
             hidden_states, residual, forward_batch
         )
 
-        hidden_states = self.mlp(hidden_states, forward_batch)
+        # For DP with padding, reduce scatter can be used instead of all-reduce.
+        use_reduce_scatter = self.layer_communicator.should_use_reduce_scatter(
+            forward_batch
+        )
+
+        hidden_states = self.mlp(hidden_states, forward_batch, use_reduce_scatter)
 
         hidden_states, residual = self.layer_communicator.postprocess_layer(
             hidden_states, residual, forward_batch
diff --git a/python/sglang/srt/models/qwen3_moe.py b/python/sglang/srt/models/qwen3_moe.py
index d7c9290b2..c17402863 100644
--- a/python/sglang/srt/models/qwen3_moe.py
+++ b/python/sglang/srt/models/qwen3_moe.py
@@ -144,11 +144,14 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
             self.top_k = config.num_experts_per_tok
 
     def forward(
-        self, hidden_states: torch.Tensor, forward_batch: Optional[ForwardBatch] = None
+        self,
+        hidden_states: torch.Tensor,
+        forward_batch: Optional[ForwardBatch] = None,
+        use_reduce_scatter: bool = False,
     ) -> torch.Tensor:
 
         if not global_server_args_dict["moe_a2a_backend"].is_deepep():
-            return self.forward_normal(hidden_states)
+            return self.forward_normal(hidden_states, use_reduce_scatter)
         else:
             return self.forward_deepep(hidden_states, forward_batch)
 
@@ -159,7 +162,11 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
             if name not in ["correction_bias"]
         ]
 
-    def forward_normal(self, hidden_states: torch.Tensor) -> torch.Tensor:
+    def forward_normal(
+        self,
+        hidden_states: torch.Tensor,
+        use_reduce_scatter: bool = False,
+    ) -> torch.Tensor:
         num_tokens, hidden_dim = hidden_states.shape
         hidden_states = hidden_states.view(-1, hidden_dim)
 
@@ -167,7 +174,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
         router_logits, _ = self.gate(hidden_states)
         topk_output = self.topk(hidden_states, router_logits)
         final_hidden_states = self.experts(hidden_states, topk_output)
-        if self.tp_size > 1:
+        if self.tp_size > 1 and not use_reduce_scatter:
             final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
 
         return final_hidden_states.view(num_tokens, hidden_dim)
@@ -521,6 +528,7 @@ class Qwen3MoeDecoderLayer(nn.Module):
             layer_scatter_modes=self.layer_scatter_modes,
             input_layernorm=self.input_layernorm,
             post_attention_layernorm=self.post_attention_layernorm,
+            allow_reduce_scatter=True,
         )
 
     def forward(
@@ -546,7 +554,12 @@ class Qwen3MoeDecoderLayer(nn.Module):
             hidden_states, residual, forward_batch
         )
 
-        hidden_states = self.mlp(hidden_states, forward_batch)
+        # For DP with padding, reduce scatter can be used instead of all-reduce.
+        use_reduce_scatter = self.layer_communicator.should_use_reduce_scatter(
+            forward_batch
+        )
+
+        hidden_states = self.mlp(hidden_states, forward_batch, use_reduce_scatter)
 
         hidden_states, residual = self.layer_communicator.postprocess_layer(
             hidden_states, residual, forward_batch