From d389bedf72a618e349b7acb0c01ca8852b2f8f9c Mon Sep 17 00:00:00 2001
From: jianan-gu
Date: Wed, 9 Jul 2025 17:04:21 +0800
Subject: [PATCH] [CPU][Qwen3 MoE] Enable fused_topk CPU fusion and enhance
 FP8 TP padding (#7838)

---
 python/sglang/srt/layers/moe/topk.py  |  8 +++++++-
 python/sglang/srt/layers/parameter.py | 22 +++++++++++++++++++---
 2 files changed, 26 insertions(+), 4 deletions(-)

diff --git a/python/sglang/srt/layers/moe/topk.py b/python/sglang/srt/layers/moe/topk.py
index 9cb6d6a0c..ebb959aba 100644
--- a/python/sglang/srt/layers/moe/topk.py
+++ b/python/sglang/srt/layers/moe/topk.py
@@ -83,13 +83,18 @@ def fused_topk_cpu(
     gating_output: torch.Tensor,
     topk: int,
     renormalize: bool,
+    num_token_non_padded: Optional[torch.Tensor] = None,
+    expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
 ):
-    return torch.ops.sgl_kernel.topk_softmax_cpu(
+    topk_weights, topk_ids = torch.ops.sgl_kernel.topk_softmax_cpu(
         hidden_states=hidden_states,
         gating_output=gating_output,
         topk=topk,
         renormalize=renormalize,
     )
+    topk_ids = topk_ids_logical_to_physical(topk_ids, expert_location_dispatch_info)
+    _mask_topk_ids_padded_region(topk_ids, num_token_non_padded)
+    return topk_weights, topk_ids
 
 
 def fused_topk(
@@ -411,6 +416,7 @@ if _is_cpu and _is_cpu_amx_available:
     biased_grouped_topk = biased_grouped_topk_cpu
     grouped_topk = grouped_topk_cpu
     fused_topk_native = fused_topk_cpu
+    fused_topk = fused_topk_cpu
 else:
     biased_grouped_topk = biased_grouped_topk_gpu
     grouped_topk = grouped_topk_gpu
diff --git a/python/sglang/srt/layers/parameter.py b/python/sglang/srt/layers/parameter.py
index d0ba43326..1ea75d70c 100644
--- a/python/sglang/srt/layers/parameter.py
+++ b/python/sglang/srt/layers/parameter.py
@@ -187,11 +187,27 @@ class _ColumnvLLMParameter(BasevLLMParameter):
         param_data = self.data
         shard_id = tp_rank if shard_id == "q" else tp_rank // num_heads
         param_data = param_data.narrow(self.output_dim, shard_offset, shard_size)
-        if not use_presharded_weights:
-            loaded_weight = loaded_weight.narrow(
-                self.output_dim, shard_id * shard_size, shard_size
+
+        if _is_cpu:
+            from sglang.srt.model_loader.weight_utils import (
+                narrow_padded_param_and_loaded_weight,
             )
+
+            param_data, loaded_weight = narrow_padded_param_and_loaded_weight(
+                param_data,
+                loaded_weight,
+                0,  # param_data_start
+                shard_id * shard_size,
+                self.output_dim,
+                shard_size,
+                not use_presharded_weights,
+            )
+        else:
+            if not use_presharded_weights:
+                loaded_weight = loaded_weight.narrow(
+                    self.output_dim, shard_id * shard_size, shard_size
+                )
 
         assert (
             param_data.shape == loaded_weight.shape
         ), f"{param_data.shape=}, {loaded_weight.shape=}"
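
Note (not part of the patch): a minimal sketch of how the rebound fused_topk
can be exercised after the topk.py change. It assumes a CPU build of
sgl-kernel with AMX support, so that fused_topk resolves to fused_topk_cpu;
the tensor shapes are illustrative only.

    import torch

    from sglang.srt.layers.moe.topk import fused_topk

    num_tokens, hidden_size, num_experts, topk = 4, 128, 8, 2
    hidden_states = torch.randn(num_tokens, hidden_size)
    gating_output = torch.randn(num_tokens, num_experts)

    # The two new keyword arguments default to None, so existing call sites
    # that pass only the original four arguments keep working unchanged.
    topk_weights, topk_ids = fused_topk(
        hidden_states=hidden_states,
        gating_output=gating_output,
        topk=topk,
        renormalize=True,
    )
    assert topk_weights.shape == topk_ids.shape == (num_tokens, topk)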
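
Note (not part of the patch): the parameter.py branch delegates to
narrow_padded_param_and_loaded_weight from
sglang.srt.model_loader.weight_utils. The hypothetical pure-torch helper
below (names and behavior are illustrative assumptions, not the real
implementation) sketches the semantics the call relies on: when the CPU
backend pads a parameter along the sharded dimension for TP alignment, the
checkpoint shard can be shorter than the padded parameter shard, so both
sides are narrowed to the overlapping valid region before the shape check
and copy.

    import torch

    def narrow_padded(param_shard, loaded_weight, weight_start, dim, shard_size):
        # The parameter may be padded along `dim`, so the checkpoint may hold
        # fewer than `shard_size` valid rows for this rank.
        valid = min(shard_size, max(loaded_weight.size(dim) - weight_start, 0))
        return (
            param_shard.narrow(dim, 0, valid),
            loaded_weight.narrow(dim, weight_start, valid),
        )

    # Example: 100 output rows padded to 112 across tp_size=4 (28 rows per
    # rank); rank 3 has only 100 - 3 * 28 = 16 real rows in the checkpoint.
    param = torch.empty(28, 64)        # padded shard owned by rank 3
    checkpoint = torch.randn(100, 64)  # unpadded checkpoint weight
    p, w = narrow_padded(param, checkpoint, weight_start=3 * 28, dim=0, shard_size=28)
    assert p.shape == w.shape == (16, 64)
    p.copy_(w)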