[CPU][Qwen3 MoE] Enable fused_topk CPU fusion and enhance FP8 TP padding (#7838)
@@ -83,13 +83,18 @@ def fused_topk_cpu(
     gating_output: torch.Tensor,
     topk: int,
     renormalize: bool,
+    num_token_non_padded: Optional[torch.Tensor] = None,
+    expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
 ):
-    return torch.ops.sgl_kernel.topk_softmax_cpu(
+    topk_weights, topk_ids = torch.ops.sgl_kernel.topk_softmax_cpu(
         hidden_states=hidden_states,
         gating_output=gating_output,
         topk=topk,
         renormalize=renormalize,
     )
+    topk_ids = topk_ids_logical_to_physical(topk_ids, expert_location_dispatch_info)
+    _mask_topk_ids_padded_region(topk_ids, num_token_non_padded)
+    return topk_weights, topk_ids


 def fused_topk(
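The two new post-processing calls bring the CPU path in line with what the GPU path already does after its top-k kernel: expert-location dispatch may remap logical expert ids to physical ones, and batch padding must not route fake tokens to experts. Below is a minimal sketch of the padded-region masking, assuming the convention that padded rows get an invalid expert id of -1; the function name with the `_sketch` suffix is hypothetical, and the real helper is this module's own _mask_topk_ids_padded_region.

import torch
from typing import Optional

def mask_topk_ids_padded_region_sketch(
    topk_ids: torch.Tensor,  # [num_tokens, topk] expert ids
    num_token_non_padded: Optional[torch.Tensor] = None,
) -> None:
    # No-op when the batch carries no padding.
    if num_token_non_padded is None:
        return
    # Rows at or beyond the real token count are padding; mark their
    # expert ids invalid so dispatch skips them. Mutates in place.
    row = torch.arange(topk_ids.shape[0], device=topk_ids.device)
    topk_ids[row >= num_token_non_padded] = -1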
@@ -411,6 +416,7 @@ if _is_cpu and _is_cpu_amx_available:
     biased_grouped_topk = biased_grouped_topk_cpu
     grouped_topk = grouped_topk_cpu
     fused_topk_native = fused_topk_cpu
+    fused_topk = fused_topk_cpu
 else:
     biased_grouped_topk = biased_grouped_topk_gpu
     grouped_topk = grouped_topk_gpu
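With fused_topk now rebound on AMX-capable CPUs, any caller that imports the name fused_topk transparently gets the fused sgl_kernel op instead of the eager PyTorch fallback. A hedged usage sketch follows; the tensor shapes and the import path sglang.srt.layers.moe.topk are illustrative assumptions, not taken from this diff.

import torch
from sglang.srt.layers.moe.topk import fused_topk  # assumed module path

# Hypothetical routing problem: 8 tokens, 64 experts, top-2 selection.
hidden_states = torch.randn(8, 4096)
gating_output = torch.randn(8, 64)

# On CPU with AMX available, this resolves to fused_topk_cpu, which runs
# softmax + top-k + renormalization inside a single fused kernel.
topk_weights, topk_ids = fused_topk(
    hidden_states=hidden_states,
    gating_output=gating_output,
    topk=2,
    renormalize=True,
)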
@@ -187,11 +187,27 @@ class _ColumnvLLMParameter(BasevLLMParameter):
         param_data = self.data
         shard_id = tp_rank if shard_id == "q" else tp_rank // num_heads
         param_data = param_data.narrow(self.output_dim, shard_offset, shard_size)
-        if not use_presharded_weights:
-            loaded_weight = loaded_weight.narrow(
-                self.output_dim, shard_id * shard_size, shard_size
-            )
+
+        if _is_cpu:
+            from sglang.srt.model_loader.weight_utils import (
+                narrow_padded_param_and_loaded_weight,
+            )
+
+            param_data, loaded_weight = narrow_padded_param_and_loaded_weight(
+                param_data,
+                loaded_weight,
+                0,  # param_data_start
+                shard_id * shard_size,
+                self.output_dim,
+                shard_size,
+                not use_presharded_weights,
+            )
+        else:
+            if not use_presharded_weights:
+                loaded_weight = loaded_weight.narrow(
+                    self.output_dim, shard_id * shard_size, shard_size
+                )

         assert (
             param_data.shape == loaded_weight.shape
         ), f"{param_data.shape=}, {loaded_weight.shape=}"
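narrow_padded_param_and_loaded_weight is imported from weight_utils, but its body is not part of this diff. The sketch below is a contract inferred from the call site, not the actual implementation: with FP8 tensor-parallel padding the destination parameter can be longer along output_dim than the checkpoint shard, so both tensors are narrowed to the extent they actually share before the shape assert runs. All names and the clamping logic here are assumptions.

import torch

def narrow_padded_param_and_loaded_weight_sketch(
    param_data: torch.Tensor,
    loaded_weight: torch.Tensor,
    param_data_start: int,   # offset into the (possibly padded) param
    weight_start: int,       # offset into the full loaded weight
    dim: int,                # dim to narrow, here self.output_dim
    shard_size: int,
    narrow_weight: bool = True,  # False when weights are presharded
):
    if narrow_weight:
        # Clamp to what the checkpoint provides: a padded param may
        # extend past the end of this rank's slice of the weight.
        valid = min(shard_size, loaded_weight.shape[dim] - weight_start)
        loaded_weight = loaded_weight.narrow(dim, weight_start, valid)
    else:
        # Presharded checkpoints already carry only this rank's slice.
        valid = min(shard_size, loaded_weight.shape[dim])
    param_data = param_data.narrow(dim, param_data_start, valid)
    return param_data, loaded_weight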