[CPU][Llama4] Fix Llama4 MoE inputs with "apply_router_weight_on_input" (#7889)
This commit is contained in:
@@ -115,5 +115,7 @@ def adjust_config_with_unaligned_cpu_tp(
|
||||
model_config = update_intermediate_size(
|
||||
model_config, "intermediate_size", intermediate_padding_size
|
||||
)
|
||||
|
||||
model_config = update_intermediate_size(
|
||||
model_config, "intermediate_size_mlp", intermediate_padding_size
|
||||
)
|
||||
return model_config
|
||||
|
||||
@@ -93,6 +93,19 @@ def fused_topk_cpu(
|
||||
return topk_weights, topk_ids
|
||||
|
||||
|
||||
def apply_topk_weights_cpu(need_apply, topk_weights, inputs):
|
||||
if not need_apply:
|
||||
return inputs, topk_weights
|
||||
|
||||
# TODO: fuse below processing in fused_experts_cpu kernel
|
||||
inputs = inputs * topk_weights.to(inputs.dtype)
|
||||
topk_weights = torch.ones_like(
|
||||
topk_weights, dtype=torch.float32
|
||||
) # clear topk_weights as already applied
|
||||
|
||||
return inputs, topk_weights
|
||||
|
||||
|
||||
def fused_topk(
|
||||
hidden_states: torch.Tensor,
|
||||
gating_output: torch.Tensor,
|
||||
|
||||
@@ -1005,6 +1005,12 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
)
|
||||
|
||||
if use_intel_amx_backend(layer):
|
||||
from sglang.srt.layers.moe.topk import apply_topk_weights_cpu
|
||||
|
||||
x, topk_weights = apply_topk_weights_cpu(
|
||||
apply_router_weight_on_input, topk_weights, x
|
||||
)
|
||||
|
||||
return torch.ops.sgl_kernel.fused_experts_cpu(
|
||||
x,
|
||||
layer.w13_weight,
|
||||
|
||||
@@ -344,9 +344,12 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
) -> torch.Tensor:
|
||||
assert activation == "silu", f"activation = {activation} is not supported."
|
||||
|
||||
if use_intel_amx_backend(layer) and not apply_router_weight_on_input:
|
||||
if use_intel_amx_backend(layer):
|
||||
|
||||
from sglang.srt.layers.moe.topk import select_experts
|
||||
from sglang.srt.layers.moe.topk import (
|
||||
select_experts,
|
||||
apply_topk_weights_cpu,
|
||||
)
|
||||
|
||||
topk_weights, topk_ids = select_experts(
|
||||
hidden_states=x,
|
||||
@@ -361,8 +364,10 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
correction_bias=correction_bias,
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
)
|
||||
x, topk_weights = apply_topk_weights_cpu(
|
||||
apply_router_weight_on_input, topk_weights, x
|
||||
)
|
||||
|
||||
# TODO: support apply_router_weight_on_input in the fused_experts_cpu kernel
|
||||
return torch.ops.sgl_kernel.fused_experts_cpu(
|
||||
x,
|
||||
layer.w13_weight,
|
||||
|
||||
@@ -497,6 +497,11 @@ class W8A8Int8MoEMethod(FusedMoEMethodBase):
|
||||
)
|
||||
|
||||
if use_intel_amx_backend(layer):
|
||||
from sglang.srt.layers.moe.topk import apply_topk_weights_cpu
|
||||
|
||||
x, topk_weights = apply_topk_weights_cpu(
|
||||
apply_router_weight_on_input, topk_weights, x
|
||||
)
|
||||
return torch.ops.sgl_kernel.fused_experts_cpu(
|
||||
x,
|
||||
layer.w13_weight,
|
||||
|
||||
Reference in New Issue
Block a user