diff --git a/python/sglang/srt/layers/communicator.py b/python/sglang/srt/layers/communicator.py index 42491d0d2..f03b6333f 100644 --- a/python/sglang/srt/layers/communicator.py +++ b/python/sglang/srt/layers/communicator.py @@ -441,7 +441,7 @@ class CommunicateWithAllReduceAndLayerNormFn: and _is_flashinfer_available and hasattr(layernorm, "forward_with_allreduce_fusion") and global_server_args_dict["enable_flashinfer_allreduce_fusion"] - and hidden_states.shape[0] <= 128 + and hidden_states.shape[0] <= 2048 ): hidden_states, residual = layernorm.forward_with_allreduce_fusion( hidden_states, residual diff --git a/python/sglang/srt/layers/flashinfer_comm_fusion.py b/python/sglang/srt/layers/flashinfer_comm_fusion.py index a800b2a97..023db709c 100644 --- a/python/sglang/srt/layers/flashinfer_comm_fusion.py +++ b/python/sglang/srt/layers/flashinfer_comm_fusion.py @@ -92,7 +92,7 @@ _workspace_manager = FlashInferWorkspaceManager() def ensure_workspace_initialized( - max_token_num: int = 128, hidden_dim: int = 4096, use_fp32_lamport: bool = False + max_token_num: int = 2048, hidden_dim: int = 4096, use_fp32_lamport: bool = False ): """Ensure workspace is initialized""" if not is_flashinfer_available() or _flashinfer_comm is None: @@ -124,8 +124,8 @@ def flashinfer_allreduce_residual_rmsnorm( residual: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6, - max_token_num: int = 128, - use_oneshot: bool = True, + max_token_num: int = 2048, + use_oneshot: Optional[bool] = None, trigger_completion_at_end: bool = False, fp32_acc: bool = False, ) -> Tuple[torch.Tensor, torch.Tensor]: diff --git a/python/sglang/srt/layers/linear.py b/python/sglang/srt/layers/linear.py index 2a9dfda59..9a3104fc2 100644 --- a/python/sglang/srt/layers/linear.py +++ b/python/sglang/srt/layers/linear.py @@ -1294,6 +1294,7 @@ class RowParallelLinear(LinearBase): with use_symmetric_memory(parallel_state.get_tp_group()) as sm: output_parallel = self.quant_method.apply(self, input_parallel, bias=bias_) sm.tag(output_parallel) + if self.reduce_results and self.tp_size > 1 and not skip_all_reduce: output = tensor_model_parallel_all_reduce(output_parallel) else: diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py index 8bde5ac8d..8aa57bbf2 100644 --- a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py @@ -847,10 +847,14 @@ class FusedMoE(torch.nn.Module): ) sm.tag(final_hidden_states) + final_hidden_states = final_hidden_states[ + ..., :origin_hidden_states_dim + ].contiguous() + if self.reduce_results and (self.moe_tp_size > 1 or self.moe_ep_size > 1): final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states) - return final_hidden_states[..., :origin_hidden_states_dim].contiguous() + return final_hidden_states @classmethod def make_expert_params_mapping( diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 04acda746..f59d7f248 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -212,7 +212,7 @@ class DeepseekV2MLP(nn.Module): self, x, forward_batch=None, - can_fuse_mlp_allreduce: bool = False, + should_allreduce_fusion: bool = False, use_reduce_scatter: bool = False, ): if (self.tp_size == 1) and x.shape[0] == 0: @@ -221,7 +221,7 @@ class DeepseekV2MLP(nn.Module): gate_up, _ = self.gate_up_proj(x) x = self.act_fn(gate_up) x, _ = self.down_proj( - x, skip_all_reduce=can_fuse_mlp_allreduce or use_reduce_scatter + x, skip_all_reduce=should_allreduce_fusion or use_reduce_scatter ) return x @@ -448,7 +448,7 @@ class DeepseekV2MoE(nn.Module): self, hidden_states: torch.Tensor, forward_batch: Optional[ForwardBatch] = None, - can_fuse_mlp_allreduce: bool = False, + should_allreduce_fusion: bool = False, use_reduce_scatter: bool = False, ) -> torch.Tensor: if not self._enable_deepep_moe: @@ -459,11 +459,11 @@ class DeepseekV2MoE(nn.Module): and hidden_states.shape[0] <= DUAL_STREAM_TOKEN_THRESHOLD ): return self.forward_normal_dual_stream( - hidden_states, can_fuse_mlp_allreduce, use_reduce_scatter + hidden_states, should_allreduce_fusion, use_reduce_scatter ) else: return self.forward_normal( - hidden_states, can_fuse_mlp_allreduce, use_reduce_scatter + hidden_states, should_allreduce_fusion, use_reduce_scatter ) else: return self.forward_deepep(hidden_states, forward_batch) @@ -471,7 +471,7 @@ class DeepseekV2MoE(nn.Module): def forward_normal_dual_stream( self, hidden_states: torch.Tensor, - can_fuse_mlp_allreduce: bool = False, + should_allreduce_fusion: bool = False, use_reduce_scatter: bool = False, ) -> torch.Tensor: @@ -500,20 +500,20 @@ class DeepseekV2MoE(nn.Module): torch.add(final_hidden_states, shared_output, out=final_hidden_states_out) final_hidden_states = final_hidden_states_out sm.tag(final_hidden_states) - if self.tp_size > 1 and not can_fuse_mlp_allreduce and not use_reduce_scatter: + if self.tp_size > 1 and not should_allreduce_fusion and not use_reduce_scatter: final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states) return final_hidden_states def forward_normal( self, hidden_states: torch.Tensor, - can_fuse_mlp_allreduce: bool = False, + should_allreduce_fusion: bool = False, use_reduce_scatter: bool = False, ) -> torch.Tensor: if hasattr(self, "shared_experts") and use_intel_amx_backend( self.shared_experts.gate_up_proj ): - return self.forward_cpu(hidden_states, can_fuse_mlp_allreduce) + return self.forward_cpu(hidden_states, should_allreduce_fusion) shared_output = self._forward_shared_experts(hidden_states) # router_logits: (num_tokens, n_experts) @@ -537,12 +537,14 @@ class DeepseekV2MoE(nn.Module): torch.add(final_hidden_states, shared_output, out=final_hidden_states_out) final_hidden_states = final_hidden_states_out sm.tag(final_hidden_states) - if self.tp_size > 1 and not can_fuse_mlp_allreduce and not use_reduce_scatter: + if self.tp_size > 1 and not should_allreduce_fusion and not use_reduce_scatter: final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states) return final_hidden_states def forward_cpu( - self, hidden_states: torch.Tensor, can_fuse_mlp_allreduce: bool = False + self, + hidden_states: torch.Tensor, + should_allreduce_fusion: bool = False, ) -> torch.Tensor: # router_logits: (num_tokens, n_experts) router_logits = self.gate(hidden_states) @@ -593,7 +595,7 @@ class DeepseekV2MoE(nn.Module): None, # a2_scale True, # is_vnni ) - if self.tp_size > 1 and not can_fuse_mlp_allreduce: + if self.tp_size > 1 and not should_allreduce_fusion: final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states) return final_hidden_states @@ -1842,6 +1844,8 @@ class DeepseekV2DecoderLayer(nn.Module): allow_reduce_scatter=True, ) + self._fuse_allreduce_lookup_table = self._build_fuse_allreduce_lookup_table() + def _is_layer_sparse(self, layer_id: int, is_nextn: bool) -> bool: return is_nextn or ( self.config.n_routed_experts is not None @@ -1850,27 +1854,18 @@ class DeepseekV2DecoderLayer(nn.Module): ) def _should_fuse_mlp_allreduce_with_next_layer(self, forward_batch) -> bool: - """Check if MLP allreduce can be fused with next layer's add_rmsnorm""" + """Check if MLP allreduce can be fused with next layer's residual_rmsnorm""" - if ( - self.layer_id == self.config.num_hidden_layers - 1 - or get_tensor_model_parallel_world_size() <= 1 - ): + batch_size = ( + forward_batch.input_ids.shape[0] + if hasattr(forward_batch, "input_ids") + else 0 + ) + + if batch_size > 128: return False - if not global_server_args_dict.get("enable_flashinfer_allreduce_fusion", False): - return False - - if not _is_sm100_supported or not _is_flashinfer_available: - return False - - if hasattr(forward_batch, "input_ids") and ( - forward_batch.input_ids.shape[0] == 0 - or forward_batch.input_ids.shape[0] > 128 - ): - return False - - return True + return self._fuse_allreduce_lookup_table.get(batch_size, False) def forward( self, @@ -1896,7 +1891,7 @@ class DeepseekV2DecoderLayer(nn.Module): hidden_states, residual, forward_batch ) - can_fuse_mlp_allreduce = ( + should_allreduce_fusion = ( self._should_fuse_mlp_allreduce_with_next_layer(forward_batch) and not (self.enable_dp_attention and self.speculative_algorithm.is_eagle()) and not self.is_nextn @@ -1907,13 +1902,13 @@ class DeepseekV2DecoderLayer(nn.Module): forward_batch ) hidden_states = self.mlp( - hidden_states, forward_batch, can_fuse_mlp_allreduce, use_reduce_scatter + hidden_states, forward_batch, should_allreduce_fusion, use_reduce_scatter ) - if can_fuse_mlp_allreduce: + if should_allreduce_fusion: hidden_states._sglang_needs_allreduce_fusion = True - if not can_fuse_mlp_allreduce: + if not should_allreduce_fusion: hidden_states, residual = self.layer_communicator.postprocess_layer( hidden_states, residual, forward_batch ) @@ -1990,6 +1985,26 @@ class DeepseekV2DecoderLayer(nn.Module): ) return output + def _build_fuse_allreduce_lookup_table(self): + static_conditions_met = ( + self.layer_id != self.config.num_hidden_layers - 1 + and get_tensor_model_parallel_world_size() > 1 + and global_server_args_dict.get("enable_flashinfer_allreduce_fusion", False) + and _is_sm100_supported + and _is_flashinfer_available + ) + + if not static_conditions_met: + return {} + + lookup_table = {} + for batch_size in range(129): # 0 to 128 + is_last_layer = self.layer_id == self.config.num_hidden_layers - 1 + should_fuse = batch_size > 0 and batch_size <= 128 and not is_last_layer + lookup_table[batch_size] = should_fuse + + return lookup_table + class DeepseekV2Model(nn.Module): fall_back_to_pt_during_load = False diff --git a/python/sglang/srt/models/glm4_moe.py b/python/sglang/srt/models/glm4_moe.py index 67ef6ca79..7727e9605 100644 --- a/python/sglang/srt/models/glm4_moe.py +++ b/python/sglang/srt/models/glm4_moe.py @@ -154,13 +154,13 @@ class Glm4MoeMLP(nn.Module): ) self.act_fn = SiluAndMul() - def forward(self, x, forward_batch=None, can_fuse_mlp_allreduce=False): + def forward(self, x, forward_batch=None, should_allreduce_fusion=False): if (self.tp_size == 1) and x.shape[0] == 0: return x gate_up, _ = self.gate_up_proj(x) x = self.act_fn(gate_up) - x, _ = self.down_proj(x, skip_all_reduce=can_fuse_mlp_allreduce) + x, _ = self.down_proj(x, skip_all_reduce=should_allreduce_fusion) return x @@ -529,7 +529,7 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE): def forward_normal_dual_stream( self, hidden_states: torch.Tensor, - can_fuse_mlp_allreduce: bool = False, + should_allreduce_fusion: bool = False, use_reduce_scatter: bool = False, ) -> torch.Tensor: @@ -553,7 +553,7 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE): if self.ep_size > 1: if ( self.tp_size > 1 - and not can_fuse_mlp_allreduce + and not should_allreduce_fusion and not use_reduce_scatter ): final_hidden_states = tensor_model_parallel_all_reduce( @@ -564,7 +564,7 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE): final_hidden_states += shared_output if ( self.tp_size > 1 - and not can_fuse_mlp_allreduce + and not should_allreduce_fusion and not use_reduce_scatter ): final_hidden_states = tensor_model_parallel_all_reduce( @@ -575,13 +575,13 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE): def forward_normal( self, hidden_states: torch.Tensor, - can_fuse_mlp_allreduce: bool = False, + should_allreduce_fusion: bool = False, use_reduce_scatter: bool = False, ) -> torch.Tensor: if hasattr(self, "shared_experts") and use_intel_amx_backend( self.shared_experts.gate_up_proj ): - return self.forward_cpu(hidden_states, can_fuse_mlp_allreduce) + return self.forward_cpu(hidden_states, should_allreduce_fusion) shared_output = self._forward_shared_experts(hidden_states) # router_logits: (num_tokens, n_experts) @@ -596,7 +596,7 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE): # fused in biased_grouped_topk so we can skip here final_hidden_states *= self.routed_scaling_factor if self.ep_size > 1: - if self.tp_size > 1 and not can_fuse_mlp_allreduce: + if self.tp_size > 1 and not should_allreduce_fusion: final_hidden_states = tensor_model_parallel_all_reduce( final_hidden_states ) @@ -605,7 +605,7 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE): else: if shared_output is not None: final_hidden_states += shared_output - if self.tp_size > 1 and not can_fuse_mlp_allreduce: + if self.tp_size > 1 and not should_allreduce_fusion: final_hidden_states = tensor_model_parallel_all_reduce( final_hidden_states ) diff --git a/python/sglang/srt/models/gpt_oss.py b/python/sglang/srt/models/gpt_oss.py index 6691cf944..90ca08607 100644 --- a/python/sglang/srt/models/gpt_oss.py +++ b/python/sglang/srt/models/gpt_oss.py @@ -56,7 +56,7 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.quantization.fp8_utils import dequant_mxfp4 from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.rotary_embedding import get_rope -from sglang.srt.layers.utils import PPMissingLayer, get_layer_id +from sglang.srt.layers.utils import PPMissingLayer, get_layer_id, is_sm100_supported from sglang.srt.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, @@ -64,7 +64,10 @@ from sglang.srt.layers.vocab_parallel_embedding import ( from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors from sglang.srt.model_loader.weight_utils import default_weight_loader -from sglang.srt.utils import add_prefix, make_layers +from sglang.srt.utils import add_prefix, is_cuda, is_flashinfer_available, make_layers + +_is_flashinfer_available = is_flashinfer_available() +_is_sm100_supported = is_cuda() and is_sm100_supported() class GptOssConfig(PretrainedConfig): @@ -151,10 +154,13 @@ class GptOssSparseMoeBlock(nn.Module): ) def forward( - self, hidden_states: torch.Tensor, forward_batch: Optional[ForwardBatch] = None + self, + hidden_states: torch.Tensor, + forward_batch: Optional[ForwardBatch] = None, + should_allreduce_fusion: bool = False, ) -> torch.Tensor: if not global_server_args_dict["moe_a2a_backend"].is_deepep(): - return self.forward_normal(hidden_states) + return self.forward_normal(hidden_states, should_allreduce_fusion) else: raise Exception("forward_deepep branch not implemented yet") @@ -165,7 +171,11 @@ class GptOssSparseMoeBlock(nn.Module): if name not in ["correction_bias"] ] - def forward_normal(self, hidden_states: torch.Tensor) -> torch.Tensor: + def forward_normal( + self, + hidden_states: torch.Tensor, + should_allreduce_fusion: bool = False, + ) -> torch.Tensor: num_tokens, hidden_dim = hidden_states.shape hidden_states = hidden_states.view(-1, hidden_dim) @@ -179,7 +189,7 @@ class GptOssSparseMoeBlock(nn.Module): kwargs["topk_output"] = (self.top_k, router_logits) final_hidden_states = self.experts(**kwargs) - if self.tp_size > 1: + if self.tp_size > 1 and not should_allreduce_fusion: final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states) ans = final_hidden_states.view(num_tokens, hidden_dim) @@ -370,6 +380,7 @@ class GptOssDecoderLayer(nn.Module): # GptOss all layers are sparse and have no nextn now self.is_layer_sparse = True + self.is_nextn = False is_previous_layer_sparse = True self.layer_scatter_modes = LayerScatterModes.init_new( @@ -402,6 +413,42 @@ class GptOssDecoderLayer(nn.Module): post_attention_layernorm=self.post_attention_layernorm, ) + self._fuse_allreduce_lookup_table = self._build_fuse_allreduce_lookup_table() + + def _should_fuse_mlp_allreduce_with_next_layer(self, forward_batch) -> bool: + """Check if MLP allreduce can be fused with next layer's residual_rmsnorm""" + + batch_size = ( + forward_batch.input_ids.shape[0] + if hasattr(forward_batch, "input_ids") + else 0 + ) + + if batch_size > 128: + return False + + return self._fuse_allreduce_lookup_table.get(batch_size, False) + + def _build_fuse_allreduce_lookup_table(self): + static_conditions_met = ( + self.layer_id != self.config.num_hidden_layers - 1 + and get_tensor_model_parallel_world_size() > 1 + and global_server_args_dict.get("enable_flashinfer_allreduce_fusion", False) + and _is_sm100_supported + and _is_flashinfer_available + ) + + if not static_conditions_met: + return {} + + lookup_table = {} + for batch_size in range(129): # 0 to 128 + is_last_layer = self.layer_id == self.config.num_hidden_layers - 1 + should_fuse = batch_size > 0 and batch_size <= 128 and not is_last_layer + lookup_table[batch_size] = should_fuse + + return lookup_table + def forward( self, positions: torch.Tensor, @@ -424,12 +471,21 @@ class GptOssDecoderLayer(nn.Module): hidden_states, residual, forward_batch ) - hidden_states = self.mlp(hidden_states, forward_batch) - - hidden_states, residual = self.layer_communicator.postprocess_layer( - hidden_states, residual, forward_batch + should_allreduce_fusion = ( + self._should_fuse_mlp_allreduce_with_next_layer(forward_batch) + and not self.is_nextn ) + hidden_states = self.mlp(hidden_states, forward_batch, should_allreduce_fusion) + + if should_allreduce_fusion: + hidden_states._sglang_needs_allreduce_fusion = True + + if not should_allreduce_fusion: + hidden_states, residual = self.layer_communicator.postprocess_layer( + hidden_states, residual, forward_batch + ) + return hidden_states, residual diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 98a0369e6..b4fd46748 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -1435,7 +1435,7 @@ class ServerArgs: parser.add_argument( "--enable-flashinfer-allreduce-fusion", action="store_true", - help="Enable FlashInfer allreduce fusion for Add_RMSNorm.", + help="Enable FlashInfer allreduce fusion with Residual RMSNorm.", ) parser.add_argument( "--deepep-mode",