From 5aff1e9392d0181e580050ae36ee93c31884e4c6 Mon Sep 17 00:00:00 2001 From: fzyzcjy <5236035+fzyzcjy@users.noreply.github.com> Date: Thu, 5 Jun 2025 15:04:59 +0800 Subject: [PATCH] Fix Qwen3MoE missing token padding optimization (#6820) --- python/sglang/srt/layers/moe/topk.py | 6 +++--- python/sglang/srt/models/qwen3_moe.py | 2 ++ 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/python/sglang/srt/layers/moe/topk.py b/python/sglang/srt/layers/moe/topk.py index 1041f007d..14c07b642 100644 --- a/python/sglang/srt/layers/moe/topk.py +++ b/python/sglang/srt/layers/moe/topk.py @@ -66,6 +66,7 @@ def fused_topk( gating_output: torch.Tensor, topk: int, renormalize: bool, + num_token_non_padded: Optional[torch.Tensor] = None, expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None, ): assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch" @@ -91,6 +92,7 @@ def fused_topk( if renormalize: topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) topk_ids = topk_ids_logical_to_physical(topk_ids, expert_location_dispatch_info) + _mask_topk_ids_padded_region(topk_ids, num_token_non_padded) return topk_weights, topk_ids @@ -363,15 +365,13 @@ def select_experts( renormalize=renormalize, ) elif custom_routing_function is None: - assert ( - num_token_non_padded is None - ), "num_token_non_padded is not yet supported in fused_topk" # Qwen3MOE uses fused_topk topk_weights, topk_ids = fused_topk( hidden_states=hidden_states, gating_output=router_logits, topk=top_k, renormalize=renormalize, + num_token_non_padded=num_token_non_padded, expert_location_dispatch_info=expert_location_dispatch_info, ) else: diff --git a/python/sglang/srt/models/qwen3_moe.py b/python/sglang/srt/models/qwen3_moe.py index 78112bb6a..0724ea779 100644 --- a/python/sglang/srt/models/qwen3_moe.py +++ b/python/sglang/srt/models/qwen3_moe.py @@ -193,6 +193,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module): top_k=self.top_k, use_grouped_topk=False, renormalize=self.renormalize, + num_token_non_padded=forward_batch.num_token_non_padded, expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new( layer_id=self.layer_id, ), @@ -260,6 +261,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module): top_k=self.top_k, use_grouped_topk=False, renormalize=self.renormalize, + num_token_non_padded=state.forward_batch.num_token_non_padded, expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new( layer_id=self.layer_id, ),