diff --git a/python/sglang/srt/layers/communicator.py b/python/sglang/srt/layers/communicator.py index a8b1223cc..5e0931ead 100644 --- a/python/sglang/srt/layers/communicator.py +++ b/python/sglang/srt/layers/communicator.py @@ -187,11 +187,24 @@ class LayerCommunicator: if hidden_states.shape[0] == 0: residual = hidden_states else: - if residual is None: - residual = hidden_states - hidden_states = self.input_layernorm(hidden_states) + if ( + residual is not None + and hasattr(hidden_states, "_sglang_needs_allreduce_fusion") + and hidden_states._sglang_needs_allreduce_fusion + ): + hidden_states, residual = ( + self.input_layernorm.forward_with_allreduce_fusion( + hidden_states, residual + ) + ) else: - hidden_states, residual = self.input_layernorm(hidden_states, residual) + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual + ) hidden_states = self._communicate_simple_fn( hidden_states=hidden_states, diff --git a/python/sglang/srt/layers/linear.py b/python/sglang/srt/layers/linear.py index 2ce049359..0cc44be55 100644 --- a/python/sglang/srt/layers/linear.py +++ b/python/sglang/srt/layers/linear.py @@ -1367,7 +1367,7 @@ class RowParallelLinear(LinearBase): # It does not support additional parameters. param.load_row_parallel_weight(loaded_weight) - def forward(self, input_): + def forward(self, input_, can_fuse_mlp_allreduce=False): if self.input_is_parallel: input_parallel = input_ else: @@ -1382,7 +1382,7 @@ class RowParallelLinear(LinearBase): # bias will not get added more than once in TP>1 case) bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias output_parallel = self.quant_method.apply(self, input_parallel, bias=bias_) - if self.reduce_results and self.tp_size > 1: + if self.reduce_results and self.tp_size > 1 and not can_fuse_mlp_allreduce: output = tensor_model_parallel_all_reduce(output_parallel) else: output = output_parallel diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 18c408bc7..1a5165030 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -77,6 +77,7 @@ from sglang.srt.layers.quantization.int8_utils import ( ) from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.rotary_embedding import get_rope, get_rope_wrapper +from sglang.srt.layers.utils import is_sm100_supported from sglang.srt.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, @@ -100,6 +101,7 @@ from sglang.srt.utils import ( get_int_env_var, is_cpu, is_cuda, + is_flashinfer_available, is_hip, is_non_idle_and_non_empty, log_info_on_rank0, @@ -132,6 +134,9 @@ if _is_hip: decode_attention_fwd_grouped_rope, ) +_is_flashinfer_available = is_flashinfer_available() +_is_sm100_supported = is_cuda() and is_sm100_supported() + logger = logging.getLogger(__name__) @@ -195,13 +200,13 @@ class DeepseekV2MLP(nn.Module): ) self.act_fn = SiluAndMul() - def forward(self, x, forward_batch=None): + def forward(self, x, forward_batch=None, can_fuse_mlp_allreduce=False): if (self.tp_size == 1) and x.shape[0] == 0: return x gate_up, _ = self.gate_up_proj(x) x = self.act_fn(gate_up) - x, _ = self.down_proj(x) + x, _ = self.down_proj(x, can_fuse_mlp_allreduce=can_fuse_mlp_allreduce) return x @@ -409,7 +414,10 @@ class DeepseekV2MoE(nn.Module): ] def forward( - self, hidden_states: torch.Tensor, forward_batch: Optional[ForwardBatch] = None + self, + hidden_states: torch.Tensor, + forward_batch: Optional[ForwardBatch] = None, + can_fuse_mlp_allreduce: bool = False, ) -> torch.Tensor: if not self._enable_deepep_moe: DUAL_STREAM_TOKEN_THRESHOLD = 1024 @@ -418,13 +426,17 @@ class DeepseekV2MoE(nn.Module): and self.num_fused_shared_experts == 0 and hidden_states.shape[0] <= DUAL_STREAM_TOKEN_THRESHOLD ): - return self.forward_normal_dual_stream(hidden_states) + return self.forward_normal_dual_stream( + hidden_states, can_fuse_mlp_allreduce + ) else: - return self.forward_normal(hidden_states) + return self.forward_normal(hidden_states, can_fuse_mlp_allreduce) else: return self.forward_deepep(hidden_states, forward_batch) - def forward_normal_dual_stream(self, hidden_states: torch.Tensor) -> torch.Tensor: + def forward_normal_dual_stream( + self, hidden_states: torch.Tensor, can_fuse_mlp_allreduce: bool = False + ) -> torch.Tensor: # router_logits: (num_tokens, n_experts) router_logits = self.gate(hidden_states) @@ -440,11 +452,13 @@ class DeepseekV2MoE(nn.Module): final_hidden_states *= self.routed_scaling_factor current_stream.wait_stream(self.alt_stream) final_hidden_states = final_hidden_states + shared_output - if self.tp_size > 1: + if self.tp_size > 1 and not can_fuse_mlp_allreduce: final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states) return final_hidden_states - def forward_normal(self, hidden_states: torch.Tensor) -> torch.Tensor: + def forward_normal( + self, hidden_states: torch.Tensor, can_fuse_mlp_allreduce: bool = False + ) -> torch.Tensor: if hasattr(self, "shared_experts") and use_intel_amx_backend( self.shared_experts.gate_up_proj ): @@ -461,7 +475,7 @@ class DeepseekV2MoE(nn.Module): final_hidden_states *= self.routed_scaling_factor if shared_output is not None: final_hidden_states = final_hidden_states + shared_output - if self.tp_size > 1: + if self.tp_size > 1 and not can_fuse_mlp_allreduce: final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states) return final_hidden_states @@ -514,7 +528,7 @@ class DeepseekV2MoE(nn.Module): None, # a2_scale True, # is_vnni ) - if self.tp_size > 1: + if self.tp_size > 1 and not self.can_fuse_mlp_allreduce: final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states) return final_hidden_states @@ -1818,6 +1832,29 @@ class DeepseekV2DecoderLayer(nn.Module): and layer_id % self.config.moe_layer_freq == 0 ) + def _should_fuse_mlp_allreduce_with_next_layer(self, forward_batch) -> bool: + """Check if MLP allreduce can be fused with next layer's add_rmsnorm""" + + if ( + self.layer_id == self.config.num_hidden_layers - 1 + or get_tensor_model_parallel_world_size() <= 1 + ): + return False + + if not global_server_args_dict.get("enable_flashinfer_allreduce_fusion", False): + return False + + if not _is_sm100_supported or not _is_flashinfer_available: + return False + + if hasattr(forward_batch, "input_ids") and ( + forward_batch.input_ids.shape[0] == 0 + or forward_batch.input_ids.shape[0] > 128 + ): + return False + + return True + def forward( self, positions: torch.Tensor, @@ -1842,12 +1879,27 @@ class DeepseekV2DecoderLayer(nn.Module): hidden_states, residual, forward_batch ) - hidden_states = self.mlp(hidden_states, forward_batch) - - hidden_states, residual = self.layer_communicator.postprocess_layer( - hidden_states, residual, forward_batch + can_fuse_mlp_allreduce = ( + self._should_fuse_mlp_allreduce_with_next_layer(forward_batch) + and not (self.enable_dp_attention and self.speculative_algorithm.is_eagle()) + and not self.is_nextn ) + hidden_states = self.mlp(hidden_states, forward_batch, can_fuse_mlp_allreduce) + + if can_fuse_mlp_allreduce: + hidden_states._sglang_needs_allreduce_fusion = True + + if not can_fuse_mlp_allreduce: + hidden_states, residual = self.layer_communicator.postprocess_layer( + hidden_states, residual, forward_batch + ) + + if self.enable_dp_attention and self.speculative_algorithm.is_eagle(): + # NOTE: this line resolves the degradation of MTP reception rate for non-zero DP ranks. + # See discussion here (https://github.com/sgl-project/sglang/pull/6081#discussion_r2147452251). + hidden_states = hidden_states.clone() + return hidden_states, residual def op_comm_prepare_attn(