From ec15c8360e736e823b50eecfcc9ae793b4636709 Mon Sep 17 00:00:00 2001
From: Yuan Luo
Date: Thu, 4 Sep 2025 20:48:53 +0800
Subject: [PATCH] Optimize Qwen3-moe model by using flashinfer fused allreduce
 (#9973)

Co-authored-by: luoyuan.luo
---
 python/sglang/srt/layers/communicator.py | 12 ++++--
 python/sglang/srt/models/qwen2_moe.py    |  5 ++-
 python/sglang/srt/models/qwen3_moe.py    | 47 ++++++++++++++++++++----
 3 files changed, 52 insertions(+), 12 deletions(-)

diff --git a/python/sglang/srt/layers/communicator.py b/python/sglang/srt/layers/communicator.py
index 4e422a360..320e87962 100644
--- a/python/sglang/srt/layers/communicator.py
+++ b/python/sglang/srt/layers/communicator.py
@@ -42,9 +42,15 @@ from sglang.srt.layers.moe import (
 )
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
-from sglang.srt.utils import is_cuda, is_flashinfer_available, is_sm100_supported
+from sglang.srt.utils import (
+    is_cuda,
+    is_flashinfer_available,
+    is_sm90_supported,
+    is_sm100_supported,
+)
 
 _is_flashinfer_available = is_flashinfer_available()
+_is_sm90_supported = is_cuda() and is_sm90_supported()
 _is_sm100_supported = is_cuda() and is_sm100_supported()
 
 FUSE_ALLREDUCE_MAX_BATCH_SIZE = 2048
@@ -484,11 +490,11 @@ class CommunicateWithAllReduceAndLayerNormFn:
         # According to the discussion in https://github.com/flashinfer-ai/flashinfer/issues/1223#issuecomment-3047256465
         # We set the max token num to 128 for allreduce fusion with min-latency case(use_oneshot=True).
         if (
-            _is_sm100_supported
+            (_is_sm100_supported or _is_sm90_supported)
            and _is_flashinfer_available
             and hasattr(layernorm, "forward_with_allreduce_fusion")
             and global_server_args_dict["enable_flashinfer_allreduce_fusion"]
-            and hidden_states.shape[0] <= 2048
+            and hidden_states.shape[0] <= 4096
         ):
             hidden_states, residual = layernorm.forward_with_allreduce_fusion(
                 hidden_states, residual
diff --git a/python/sglang/srt/models/qwen2_moe.py b/python/sglang/srt/models/qwen2_moe.py
index 56ac79a7f..194e513ac 100644
--- a/python/sglang/srt/models/qwen2_moe.py
+++ b/python/sglang/srt/models/qwen2_moe.py
@@ -105,11 +105,14 @@ class Qwen2MoeMLP(nn.Module):
     def forward(
         self,
         x,
+        should_allreduce_fusion: bool = False,
         use_reduce_scatter: bool = False,
     ):
         gate_up, _ = self.gate_up_proj(x)
         x = self.act_fn(gate_up)
-        x, _ = self.down_proj(x, skip_all_reduce=use_reduce_scatter)
+        x, _ = self.down_proj(
+            x, skip_all_reduce=should_allreduce_fusion or use_reduce_scatter
+        )
         return x
 
 
diff --git a/python/sglang/srt/models/qwen3_moe.py b/python/sglang/srt/models/qwen3_moe.py
index fcb45b947..c1c4c3638 100644
--- a/python/sglang/srt/models/qwen3_moe.py
+++ b/python/sglang/srt/models/qwen3_moe.py
@@ -42,7 +42,10 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
-from sglang.srt.layers.moe import get_moe_a2a_backend
+from sglang.srt.layers.moe import (
+    get_moe_a2a_backend,
+    should_use_flashinfer_cutlass_moe_fp4_allgather,
+)
 from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
 from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
 from sglang.srt.layers.moe.topk import TopK
@@ -57,10 +60,17 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTe
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.qwen2_moe import Qwen2MoeMLP as Qwen3MoeMLP
 from sglang.srt.models.qwen2_moe import Qwen2MoeModel
-from sglang.srt.utils import add_prefix, is_cuda, is_non_idle_and_non_empty
+from sglang.srt.utils import (
+    add_prefix,
+    is_cuda,
+    is_flashinfer_available,
+    is_non_idle_and_non_empty,
+)
 
 Qwen3MoeConfig = None
 
+_is_flashinfer_available = is_flashinfer_available()
+
 logger = logging.getLogger(__name__)
 
 _is_cuda = is_cuda()
@@ -119,11 +129,14 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
         self,
         hidden_states: torch.Tensor,
         forward_batch: Optional[ForwardBatch] = None,
+        should_allreduce_fusion: bool = False,
         use_reduce_scatter: bool = False,
     ) -> torch.Tensor:
 
         if not get_moe_a2a_backend().is_deepep():
-            return self.forward_normal(hidden_states, use_reduce_scatter)
+            return self.forward_normal(
+                hidden_states, should_allreduce_fusion, use_reduce_scatter
+            )
         else:
             return self.forward_deepep(hidden_states, forward_batch)
 
@@ -137,6 +150,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
     def forward_normal(
         self,
         hidden_states: torch.Tensor,
+        should_allreduce_fusion: bool = False,
         use_reduce_scatter: bool = False,
     ) -> torch.Tensor:
         num_tokens, hidden_dim = hidden_states.shape
@@ -146,7 +160,12 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
         router_logits, _ = self.gate(hidden_states)
         topk_output = self.topk(hidden_states, router_logits)
         final_hidden_states = self.experts(hidden_states, topk_output)
-        if self.tp_size > 1 and not use_reduce_scatter:
+        if (
+            self.tp_size > 1
+            and not should_allreduce_fusion
+            and not use_reduce_scatter
+            and not should_use_flashinfer_cutlass_moe_fp4_allgather()
+        ):
             final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
         return final_hidden_states.view(num_tokens, hidden_dim)
 
@@ -500,6 +519,7 @@ class Qwen3MoeDecoderLayer(nn.Module):
             input_layernorm=self.input_layernorm,
             post_attention_layernorm=self.post_attention_layernorm,
             allow_reduce_scatter=True,
+            is_last_layer=(self.layer_id == self.config.num_hidden_layers - 1),
         )
 
     def forward(
@@ -525,17 +545,28 @@
             hidden_states, residual, forward_batch
         )
 
+        should_allreduce_fusion = (
+            self.layer_communicator.should_fuse_mlp_allreduce_with_next_layer(
+                forward_batch
+            )
+        )
+
         # For DP with padding, reduce scatter can be used instead of all-reduce.
         use_reduce_scatter = self.layer_communicator.should_use_reduce_scatter(
             forward_batch
         )
 
-        hidden_states = self.mlp(hidden_states, forward_batch, use_reduce_scatter)
-
-        hidden_states, residual = self.layer_communicator.postprocess_layer(
-            hidden_states, residual, forward_batch
+        hidden_states = self.mlp(
+            hidden_states, forward_batch, should_allreduce_fusion, use_reduce_scatter
         )
 
+        if should_allreduce_fusion:
+            hidden_states._sglang_needs_allreduce_fusion = True
+        else:
+            hidden_states, residual = self.layer_communicator.postprocess_layer(
+                hidden_states, residual, forward_batch
+            )
+
         return hidden_states, residual
 
     def op_comm_prepare_attn(
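
Reviewer note (not part of the commit): the condition guarding the fused path
in communicator.py can be read as a single predicate. The sketch below restates
it under that reading; the function name is hypothetical, and the parameters
stand in for the module-level flags used in the diff. The patch widens the
hardware gate from SM100-only to SM90-or-SM100 and raises the token cap from
2048 to 4096, while the feature stays opt-in behind the
enable_flashinfer_allreduce_fusion server argument.

    # Hypothetical restatement of the fusion gate patched in communicator.py.
    # Each parameter mirrors one clause of the `if` in the diff above.
    def can_fuse_allreduce_layernorm(
        sm90_or_sm100_supported: bool,  # _is_sm90_supported or _is_sm100_supported
        flashinfer_available: bool,     # _is_flashinfer_available
        layernorm_has_fusion: bool,     # hasattr(layernorm, "forward_with_allreduce_fusion")
        fusion_enabled: bool,           # global_server_args_dict["enable_flashinfer_allreduce_fusion"]
        num_tokens: int,                # hidden_states.shape[0]
    ) -> bool:
        return (
            sm90_or_sm100_supported
            and flashinfer_available
            and layernorm_has_fusion
            and fusion_enabled
            and num_tokens <= 4096
        )

When the gate holds, Qwen3MoeDecoderLayer.forward marks the MLP output with
_sglang_needs_allreduce_fusion and skips postprocess_layer, so the
tensor-parallel all-reduce is deferred and fused with the next layer's
layernorm via forward_with_allreduce_fusion instead of being issued as a
separate kernel.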