diff --git a/python/sglang/srt/layers/communicator.py b/python/sglang/srt/layers/communicator.py
index 73a9030f7..3f8973830 100644
--- a/python/sglang/srt/layers/communicator.py
+++ b/python/sglang/srt/layers/communicator.py
@@ -34,6 +34,7 @@ from sglang.srt.layers.dp_attention import (
     get_attention_tp_size,
     get_global_dp_buffer,
     get_local_dp_buffer,
+    is_dp_attention_enabled,
 )
 from sglang.srt.layers.moe import (
     get_moe_a2a_backend,
@@ -47,6 +48,8 @@ from sglang.srt.utils import is_cuda, is_flashinfer_available
 _is_flashinfer_available = is_flashinfer_available()
 _is_sm100_supported = is_cuda() and is_sm100_supported()
 
+FUSE_ALLREDUCE_MAX_BATCH_SIZE = 2048
+
 
 class ScatterMode(Enum):
     """
@@ -162,11 +165,13 @@ class LayerCommunicator:
         post_attention_layernorm: torch.nn.Module,
         # Reduce scatter requires skipping all-reduce in model code after MoE/MLP, so only enable for models which have that implemented. Remove flag once done for all models that use LayerCommunicator.
         allow_reduce_scatter: bool = False,
+        is_last_layer: bool = False,
     ):
         self.layer_scatter_modes = layer_scatter_modes
         self.input_layernorm = input_layernorm
         self.post_attention_layernorm = post_attention_layernorm
         self.allow_reduce_scatter = allow_reduce_scatter
+        self.is_last_layer = is_last_layer
 
         self._context = CommunicateContext.init_new()
         self._communicate_simple_fn = CommunicateSimpleFn.get_fn(
@@ -264,6 +269,42 @@ class LayerCommunicator:
             and forward_batch.dp_padding_mode.is_max_len()
         )
 
+    def should_fuse_mlp_allreduce_with_next_layer(
+        self, forward_batch: ForwardBatch
+    ) -> bool:
+        speculative_algo = global_server_args_dict.get("speculative_algorithm", None)
+        if (
+            is_dp_attention_enabled()
+            and speculative_algo is not None
+            and speculative_algo.is_eagle()
+        ):
+            return False
+
+        batch_size = (
+            forward_batch.input_ids.shape[0]
+            if hasattr(forward_batch, "input_ids")
+            else 0
+        )
+        if batch_size > FUSE_ALLREDUCE_MAX_BATCH_SIZE:
+            return False
+
+        static_conditions_met = (
+            (not self.is_last_layer)
+            and (self._context.tp_size > 1)
+            and global_server_args_dict.get("enable_flashinfer_allreduce_fusion", False)
+            and _is_sm100_supported
+            and _is_flashinfer_available
+        )
+
+        if not static_conditions_met:
+            return False
+
+        return (
+            batch_size > 0
+            and batch_size <= FUSE_ALLREDUCE_MAX_BATCH_SIZE
+            and (not self.is_last_layer)
+        )
+
 
 @dataclass
 class CommunicateContext:
diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py
index eeebe1863..3bec16bfc 100644
--- a/python/sglang/srt/models/deepseek_v2.py
+++ b/python/sglang/srt/models/deepseek_v2.py
@@ -1852,10 +1852,11 @@ class DeepseekV2DecoderLayer(nn.Module):
             input_layernorm=self.input_layernorm,
             post_attention_layernorm=self.post_attention_layernorm,
             allow_reduce_scatter=True,
+            is_last_layer=(
+                is_nextn or (self.layer_id == self.config.num_hidden_layers - 1)
+            ),
         )
 
-        self._fuse_allreduce_lookup_table = self._build_fuse_allreduce_lookup_table()
-
     def _is_layer_sparse(self, layer_id: int, is_nextn: bool) -> bool:
         return is_nextn or (
             self.config.n_routed_experts is not None
@@ -1863,20 +1864,6 @@ class DeepseekV2DecoderLayer(nn.Module):
             and layer_id >= self.config.first_k_dense_replace
             and layer_id % self.config.moe_layer_freq == 0
         )
 
-    def _should_fuse_mlp_allreduce_with_next_layer(self, forward_batch) -> bool:
-        """Check if MLP allreduce can be fused with next layer's residual_rmsnorm"""
-
-        batch_size = (
-            forward_batch.input_ids.shape[0]
-            if hasattr(forward_batch, "input_ids")
-            else 0
-        )
-
-        if batch_size > 128:
-            return False
-
-        return self._fuse_allreduce_lookup_table.get(batch_size, False)
-
     def forward(
         self,
         positions: torch.Tensor,
@@ -1902,11 +1889,9 @@ class DeepseekV2DecoderLayer(nn.Module):
         )
 
         should_allreduce_fusion = (
-            self._should_fuse_mlp_allreduce_with_next_layer(forward_batch)
-            and not (
-                is_dp_attention_enabled() and self.speculative_algorithm.is_eagle()
+            self.layer_communicator.should_fuse_mlp_allreduce_with_next_layer(
+                forward_batch
             )
-            and not self.is_nextn
         )
 
         # For DP with padding, reduce scatter can be used instead of all-reduce.
@@ -1997,26 +1982,6 @@ class DeepseekV2DecoderLayer(nn.Module):
         )
         return output
 
-    def _build_fuse_allreduce_lookup_table(self):
-        static_conditions_met = (
-            self.layer_id != self.config.num_hidden_layers - 1
-            and get_tensor_model_parallel_world_size() > 1
-            and global_server_args_dict.get("enable_flashinfer_allreduce_fusion", False)
-            and _is_sm100_supported
-            and _is_flashinfer_available
-        )
-
-        if not static_conditions_met:
-            return {}
-
-        lookup_table = {}
-        for batch_size in range(129):  # 0 to 128
-            is_last_layer = self.layer_id == self.config.num_hidden_layers - 1
-            should_fuse = batch_size > 0 and batch_size <= 128 and not is_last_layer
-            lookup_table[batch_size] = should_fuse
-
-        return lookup_table
-
 
 class DeepseekV2Model(nn.Module):
     fall_back_to_pt_during_load = False
diff --git a/python/sglang/srt/models/gpt_oss.py b/python/sglang/srt/models/gpt_oss.py
index 93c4bda49..ff34f1eea 100644
--- a/python/sglang/srt/models/gpt_oss.py
+++ b/python/sglang/srt/models/gpt_oss.py
@@ -453,44 +453,11 @@ class GptOssDecoderLayer(nn.Module):
             layer_scatter_modes=self.layer_scatter_modes,
             input_layernorm=self.input_layernorm,
             post_attention_layernorm=self.post_attention_layernorm,
+            is_last_layer=(
+                self.is_nextn or (self.layer_id == self.config.num_hidden_layers - 1)
+            ),
         )
 
-        self._fuse_allreduce_lookup_table = self._build_fuse_allreduce_lookup_table()
-
-    def _should_fuse_mlp_allreduce_with_next_layer(self, forward_batch) -> bool:
-        """Check if MLP allreduce can be fused with next layer's residual_rmsnorm"""
-
-        batch_size = (
-            forward_batch.input_ids.shape[0]
-            if hasattr(forward_batch, "input_ids")
-            else 0
-        )
-
-        if batch_size > 128:
-            return False
-
-        return self._fuse_allreduce_lookup_table.get(batch_size, False)
-
-    def _build_fuse_allreduce_lookup_table(self):
-        static_conditions_met = (
-            self.layer_id != self.config.num_hidden_layers - 1
-            and get_tensor_model_parallel_world_size() > 1
-            and global_server_args_dict.get("enable_flashinfer_allreduce_fusion", False)
-            and _is_sm100_supported
-            and _is_flashinfer_available
-        )
-
-        if not static_conditions_met:
-            return {}
-
-        lookup_table = {}
-        for batch_size in range(129):  # 0 to 128
-            is_last_layer = self.layer_id == self.config.num_hidden_layers - 1
-            should_fuse = batch_size > 0 and batch_size <= 128 and not is_last_layer
-            lookup_table[batch_size] = should_fuse
-
-        return lookup_table
-
     def forward(
         self,
         positions: torch.Tensor,
@@ -514,8 +481,9 @@ class GptOssDecoderLayer(nn.Module):
         )
 
         should_allreduce_fusion = (
-            self._should_fuse_mlp_allreduce_with_next_layer(forward_batch)
-            and not self.is_nextn
+            self.layer_communicator.should_fuse_mlp_allreduce_with_next_layer(
+                forward_batch
+            )
         )
 
         hidden_states = self.mlp(hidden_states, forward_batch, should_allreduce_fusion)