[CPU][Llama4] Fix Llama4 MoE inputs with "apply_router_weight_on_input" (#7889)

This commit is contained in:
jianan-gu
2025-07-18 12:43:25 +08:00
committed by GitHub
parent 8aa5ae6b04
commit 48c1fa7bb6
5 changed files with 35 additions and 4 deletions

View File

@@ -115,5 +115,7 @@ def adjust_config_with_unaligned_cpu_tp(
model_config = update_intermediate_size(
model_config, "intermediate_size", intermediate_padding_size
)
model_config = update_intermediate_size(
model_config, "intermediate_size_mlp", intermediate_padding_size
)
return model_config

View File

@@ -93,6 +93,19 @@ def fused_topk_cpu(
return topk_weights, topk_ids
def apply_topk_weights_cpu(need_apply, topk_weights, inputs):
    """Pre-apply router (top-k) weights to MoE inputs on CPU.

    When ``need_apply`` is falsy, both arguments are returned unchanged.
    Otherwise the hidden states are scaled element-wise by the routing
    weights, and the weights are replaced with all-ones so a downstream
    fused kernel does not apply them a second time.

    Returns:
        A ``(inputs, topk_weights)`` pair (possibly the originals).
    """
    if not need_apply:
        return inputs, topk_weights
    # TODO: fuse below processing in fused_experts_cpu kernel
    scaled_inputs = inputs * topk_weights.to(inputs.dtype)
    # All-ones weights signal "already applied" to the consumer kernel.
    neutral_weights = torch.ones_like(topk_weights, dtype=torch.float32)
    return scaled_inputs, neutral_weights
def fused_topk(
hidden_states: torch.Tensor,
gating_output: torch.Tensor,

View File

@@ -1005,6 +1005,12 @@ class Fp8MoEMethod(FusedMoEMethodBase):
)
if use_intel_amx_backend(layer):
from sglang.srt.layers.moe.topk import apply_topk_weights_cpu
x, topk_weights = apply_topk_weights_cpu(
apply_router_weight_on_input, topk_weights, x
)
return torch.ops.sgl_kernel.fused_experts_cpu(
x,
layer.w13_weight,

View File

@@ -344,9 +344,12 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
) -> torch.Tensor:
assert activation == "silu", f"activation = {activation} is not supported."
if use_intel_amx_backend(layer) and not apply_router_weight_on_input:
if use_intel_amx_backend(layer):
from sglang.srt.layers.moe.topk import select_experts
from sglang.srt.layers.moe.topk import (
select_experts,
apply_topk_weights_cpu,
)
topk_weights, topk_ids = select_experts(
hidden_states=x,
@@ -361,8 +364,10 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
correction_bias=correction_bias,
routed_scaling_factor=routed_scaling_factor,
)
x, topk_weights = apply_topk_weights_cpu(
apply_router_weight_on_input, topk_weights, x
)
# TODO: support apply_router_weight_on_input in the fused_experts_cpu kernel
return torch.ops.sgl_kernel.fused_experts_cpu(
x,
layer.w13_weight,

View File

@@ -497,6 +497,11 @@ class W8A8Int8MoEMethod(FusedMoEMethodBase):
)
if use_intel_amx_backend(layer):
from sglang.srt.layers.moe.topk import apply_topk_weights_cpu
x, topk_weights = apply_topk_weights_cpu(
apply_router_weight_on_input, topk_weights, x
)
return torch.ops.sgl_kernel.fused_experts_cpu(
x,
layer.w13_weight,