[CPU][Qwen3 MoE] Enable fused_topk CPU fusion and enhance FP8 TP padding (#7838)

This commit is contained in:
jianan-gu
2025-07-09 17:04:21 +08:00
committed by GitHub
parent ac80f4da57
commit d389bedf72
2 changed files with 26 additions and 4 deletions

View File

@@ -83,13 +83,18 @@ def fused_topk_cpu(
gating_output: torch.Tensor,
topk: int,
renormalize: bool,
num_token_non_padded: Optional[torch.Tensor] = None,
expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
):
return torch.ops.sgl_kernel.topk_softmax_cpu(
topk_weights, topk_ids = torch.ops.sgl_kernel.topk_softmax_cpu(
hidden_states=hidden_states,
gating_output=gating_output,
topk=topk,
renormalize=renormalize,
)
topk_ids = topk_ids_logical_to_physical(topk_ids, expert_location_dispatch_info)
_mask_topk_ids_padded_region(topk_ids, num_token_non_padded)
return topk_weights, topk_ids
def fused_topk(
@@ -411,6 +416,7 @@ if _is_cpu and _is_cpu_amx_available:
biased_grouped_topk = biased_grouped_topk_cpu
grouped_topk = grouped_topk_cpu
fused_topk_native = fused_topk_cpu
fused_topk = fused_topk_cpu
else:
biased_grouped_topk = biased_grouped_topk_gpu
grouped_topk = grouped_topk_gpu

View File

@@ -187,11 +187,27 @@ class _ColumnvLLMParameter(BasevLLMParameter):
param_data = self.data
shard_id = tp_rank if shard_id == "q" else tp_rank // num_heads
param_data = param_data.narrow(self.output_dim, shard_offset, shard_size)
if not use_presharded_weights:
loaded_weight = loaded_weight.narrow(
self.output_dim, shard_id * shard_size, shard_size
if _is_cpu:
from sglang.srt.model_loader.weight_utils import (
narrow_padded_param_and_loaded_weight,
)
param_data, loaded_weight = narrow_padded_param_and_loaded_weight(
param_data,
loaded_weight,
0, # param_data_start
shard_id * shard_size,
self.output_dim,
shard_size,
not use_presharded_weights,
)
else:
if not use_presharded_weights:
loaded_weight = loaded_weight.narrow(
self.output_dim, shard_id * shard_size, shard_size
)
assert (
param_data.shape == loaded_weight.shape
), f"{param_data.shape=}, {loaded_weight.shape=}"