From 1344ebc8333df0bb5463a65b2d71f981659e071f Mon Sep 17 00:00:00 2001
From: Yi Zhang <1109276519@qq.com>
Date: Fri, 19 Sep 2025 02:36:22 +0800
Subject: [PATCH] support qwen3-next-fp8 deepep (#10622)

---
 python/sglang/srt/models/qwen2_moe.py  | 65 +++++++++++++++++++++++++-
 python/sglang/srt/models/qwen3_next.py | 37 +++++++++++----
 2 files changed, 93 insertions(+), 9 deletions(-)

diff --git a/python/sglang/srt/models/qwen2_moe.py b/python/sglang/srt/models/qwen2_moe.py
index 0375ac478..f00610454 100644
--- a/python/sglang/srt/models/qwen2_moe.py
+++ b/python/sglang/srt/models/qwen2_moe.py
@@ -25,12 +25,14 @@ from torch import nn
 from transformers import PretrainedConfig
 
 from sglang.srt.distributed import (
+    get_moe_expert_parallel_world_size,
     get_pp_group,
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
 )
 from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
 from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation
+from sglang.srt.eplb.expert_location_dispatch import ExpertLocationDispatchInfo
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.communicator import (
     LayerCommunicator,
@@ -50,6 +52,7 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.moe import get_moe_a2a_backend
 from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
 from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 from sglang.srt.layers.moe.topk import TopK
@@ -82,6 +85,8 @@ class Qwen2MoeMLP(nn.Module):
         quant_config: Optional[QuantizationConfig] = None,
         reduce_results: bool = True,
         prefix: str = "",
+        tp_rank: Optional[int] = None,
+        tp_size: Optional[int] = None,
     ) -> None:
         super().__init__()
         self.gate_up_proj = MergedColumnParallelLinear(
@@ -90,6 +95,8 @@ class Qwen2MoeMLP(nn.Module):
             bias=False,
             quant_config=quant_config,
             prefix=add_prefix("gate_up_proj", prefix),
+            tp_rank=tp_rank,
+            tp_size=tp_size,
         )
         self.down_proj = RowParallelLinear(
             intermediate_size,
@@ -98,6 +105,8 @@ class Qwen2MoeMLP(nn.Module):
             quant_config=quant_config,
             reduce_results=reduce_results,
             prefix=add_prefix("down_proj", prefix),
+            tp_rank=tp_rank,
+            tp_size=tp_size,
         )
         if hidden_act != "silu":
             raise ValueError(
@@ -146,7 +155,8 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
         self.experts = get_moe_impl_class(quant_config)(
             layer_id=self.layer_id,
             top_k=config.num_experts_per_tok,
-            num_experts=config.num_experts,
+            num_experts=config.num_experts
+            + global_server_args_dict["ep_num_redundant_experts"],
             hidden_size=config.hidden_size,
             intermediate_size=config.moe_intermediate_size,
             quant_config=quant_config,
@@ -168,11 +178,31 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
                 quant_config=quant_config,
                 reduce_results=False,
                 prefix=add_prefix("shared_expert", prefix),
+                **(
+                    dict(tp_rank=0, tp_size=1)
+                    if get_moe_a2a_backend().is_deepep()
+                    else {}
+                ),
             )
         else:
             self.shared_expert = None
         self.shared_expert_gate = torch.nn.Linear(config.hidden_size, 1, bias=False)
 
+        if get_moe_a2a_backend().is_deepep():
+            # TODO: we will support tp < ep in the future
+            self.ep_size = get_moe_expert_parallel_world_size()
+            self.num_experts = (
+                config.num_experts + global_server_args_dict["ep_num_redundant_experts"]
+            )
+            self.top_k = config.num_experts_per_tok
+
+    def get_moe_weights(self):
+        return [
+            x.data
+            for name, x in self.experts.named_parameters()
+            if name not in ["correction_bias"]
+        ]
+
     def _forward_shared_experts(self, hidden_states: torch.Tensor):
         shared_output = None
         if self.shared_expert is not None:
@@ -183,6 +213,36 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
                 )
         return shared_output
 
+    def _forward_deepep(self, hidden_states: torch.Tensor, forward_batch: ForwardBatch):
+        shared_output = None
+        if hidden_states.shape[0] > 0:
+            # router_logits: (num_tokens, n_experts)
+            router_logits, _ = self.gate(hidden_states)
+            shared_output = self._forward_shared_experts(hidden_states)
+            topk_weights, topk_idx, _ = self.topk(
+                hidden_states,
+                router_logits,
+                num_token_non_padded=forward_batch.num_token_non_padded,
+                expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new(
+                    layer_id=self.layer_id,
+                ),
+            )
+        else:
+            topk_weights, topk_idx, _ = self.topk.empty_topk_output(
+                hidden_states.device
+            )
+        final_hidden_states = self.experts(
+            hidden_states=hidden_states,
+            topk_idx=topk_idx,
+            topk_weights=topk_weights,
+            forward_batch=forward_batch,
+        )
+
+        if shared_output is not None:
+            final_hidden_states.add_(shared_output)
+
+        return final_hidden_states
+
     def _forward_router_experts(self, hidden_states: torch.Tensor):
         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)
@@ -213,6 +273,9 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
         num_tokens, hidden_dim = hidden_states.shape
         hidden_states = hidden_states.view(-1, hidden_dim)
 
+        if get_moe_a2a_backend().is_deepep():
+            return self._forward_deepep(hidden_states, forward_batch)
+
         DUAL_STREAM_TOKEN_THRESHOLD = 1024
         if (
             self.alt_stream is not None
diff --git a/python/sglang/srt/models/qwen3_next.py b/python/sglang/srt/models/qwen3_next.py
index 6e6a99cf8..2a1b9d48c 100644
--- a/python/sglang/srt/models/qwen3_next.py
+++ b/python/sglang/srt/models/qwen3_next.py
@@ -13,6 +13,7 @@ from sglang.srt.distributed import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
 )
+from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
 from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation
 from sglang.srt.layers.attention.fla.layernorm_gated import RMSNorm as RMSNormGated
 from sglang.srt.layers.attention.mamba.mamba import mamba_v2_sharded_weight_loader
@@ -46,7 +47,14 @@ from sglang.srt.model_loader.weight_utils import (
     sharded_weight_loader,
 )
 from sglang.srt.models.qwen2_moe import Qwen2MoeMLP, Qwen2MoeSparseMoeBlock
-from sglang.srt.utils import add_prefix, is_cuda, is_npu, make_layers, set_weight_attrs
+from sglang.srt.utils import (
+    LazyValue,
+    add_prefix,
+    is_cuda,
+    is_npu,
+    make_layers,
+    set_weight_attrs,
+)
 
 logger = logging.getLogger(__name__)
 _is_cuda = is_cuda()
@@ -849,13 +857,14 @@ class Qwen3NextModel(nn.Module):
         residual = None
         for i in range(len(self.layers)):
             layer = self.layers[i]
-            hidden_states, residual = layer(
-                layer_id=i,
-                positions=positions,
-                hidden_states=hidden_states,
-                residual=residual,
-                forward_batch=forward_batch,
-            )
+            with get_global_expert_distribution_recorder().with_current_layer(i):
+                hidden_states, residual = layer(
+                    layer_id=i,
+                    positions=positions,
+                    hidden_states=hidden_states,
+                    residual=residual,
+                    forward_batch=forward_batch,
+                )
 
         if not forward_batch.forward_mode.is_idle():
             if residual is None:
@@ -901,6 +910,18 @@ class Qwen3NextForCausalLM(nn.Module):
             self.lm_head = self.lm_head.float()
         self.logits_processor = LogitsProcessor(config)
 
+        self._routed_experts_weights_of_layer = LazyValue(
+            lambda: {
+                layer_id: layer.mlp.get_moe_weights()
+                for layer_id, layer in enumerate(self.model.layers)
+                if isinstance(layer.mlp, Qwen2MoeSparseMoeBlock)
+            }
+        )
+
+    @property
+    def routed_experts_weights_of_layer(self):
+        return self._routed_experts_weights_of_layer.value
+
     @torch.no_grad()
     def forward(
         self,
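
[Editor's note, not part of the patch] The new routed_experts_weights_of_layer property wraps the
{layer_id: routed-expert weights} mapping in sglang.srt.utils.LazyValue so that the dictionary is
only built when EPLB first reads it, i.e. on first access rather than inside __init__. Below is a
minimal, self-contained Python sketch of that compute-once-then-cache pattern; _LazyValue and the
fake layer data are illustrative stand-ins, not sglang's actual implementation.

from typing import Callable, Generic, Optional, TypeVar

T = TypeVar("T")


class _LazyValue(Generic[T]):
    # Stand-in for sglang.srt.utils.LazyValue (assumed behavior): call the factory
    # once on first .value access, cache the result, and return the cache afterwards.
    def __init__(self, factory: Callable[[], T]) -> None:
        self._factory = factory
        self._computed = False
        self._value: Optional[T] = None

    @property
    def value(self) -> T:
        if not self._computed:
            self._value = self._factory()
            self._computed = True
        return self._value


if __name__ == "__main__":
    # Fake per-layer routed-expert weights, standing in for layer.mlp.get_moe_weights().
    fake_moe_layers = {0: ["w13_weight", "w2_weight"], 2: ["w13_weight", "w2_weight"]}
    lazy_map = _LazyValue(lambda: dict(fake_moe_layers))
    print(lazy_map.value)                    # factory runs here, on first access
    print(lazy_map.value is lazy_map.value)  # True: later reads return the cached dict

The deferred construction matters because the lambda in the patch walks self.model.layers and calls
get_moe_weights(); evaluating it lazily means it runs only once the MoE layers exist and, in typical
use, after their weights have been loaded.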