[CPU][Llama4] Fix Llama4 MoE inputs with "apply_router_weight_on_input" (#7889)
This commit is contained in:
@@ -115,5 +115,7 @@ def adjust_config_with_unaligned_cpu_tp(
|
|||||||
model_config = update_intermediate_size(
|
model_config = update_intermediate_size(
|
||||||
model_config, "intermediate_size", intermediate_padding_size
|
model_config, "intermediate_size", intermediate_padding_size
|
||||||
)
|
)
|
||||||
|
model_config = update_intermediate_size(
|
||||||
|
model_config, "intermediate_size_mlp", intermediate_padding_size
|
||||||
|
)
|
||||||
return model_config
|
return model_config
|
||||||
|
|||||||
@@ -93,6 +93,19 @@ def fused_topk_cpu(
|
|||||||
return topk_weights, topk_ids
|
return topk_weights, topk_ids
|
||||||
|
|
||||||
|
|
||||||
|
def apply_topk_weights_cpu(
    need_apply: bool,
    topk_weights: torch.Tensor,
    inputs: torch.Tensor,
) -> "tuple[torch.Tensor, torch.Tensor]":
    """Optionally pre-multiply MoE inputs by their router weights.

    When ``need_apply`` is false, both tensors are returned untouched.
    Otherwise ``inputs`` is scaled by ``topk_weights`` (cast to the dtype of
    ``inputs``) and the returned weights are replaced by float32 ones, so
    that a later stage which multiplies by the weights does not apply them
    a second time.

    Args:
        need_apply: Whether the router weights should be applied to the
            inputs (the ``apply_router_weight_on_input`` flag at the call
            sites).
        topk_weights: Router weights. Presumably shaped
            ``(num_tokens, top_k)`` -- TODO confirm against ``select_experts``.
        inputs: Activations to be scaled. Presumably shaped
            ``(num_tokens, hidden_size)`` -- TODO confirm.

    Returns:
        A ``(inputs, topk_weights)`` tuple: either the original objects, or
        the scaled inputs together with all-ones float32 weights.

    Raises:
        ValueError: If ``need_apply`` is true and ``topk_weights`` carries
            more than one weight per token. Scaling the input up front is
            only well defined for ``top_k == 1``; with a larger ``top_k`` the
            broadcast would fail, or silently produce wrong values when
            ``top_k`` happens to equal the hidden size.
    """
    if not need_apply:
        return inputs, topk_weights

    # A single input row cannot be scaled by several different expert
    # weights, so reject top_k > 1 instead of relying on broadcasting.
    if topk_weights.dim() > 0 and topk_weights.shape[-1] != 1:
        raise ValueError(
            "apply_router_weight_on_input is only supported for topk=1, "
            f"got topk_weights with shape {tuple(topk_weights.shape)}"
        )

    # TODO: fuse below processing in fused_experts_cpu kernel
    inputs = inputs * topk_weights.to(inputs.dtype)
    topk_weights = torch.ones_like(
        topk_weights, dtype=torch.float32
    )  # clear topk_weights as already applied

    return inputs, topk_weights
|
||||||
|
|
||||||
|
|
||||||
def fused_topk(
|
def fused_topk(
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
gating_output: torch.Tensor,
|
gating_output: torch.Tensor,
|
||||||
|
|||||||
@@ -1005,6 +1005,12 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if use_intel_amx_backend(layer):
|
if use_intel_amx_backend(layer):
|
||||||
|
from sglang.srt.layers.moe.topk import apply_topk_weights_cpu
|
||||||
|
|
||||||
|
x, topk_weights = apply_topk_weights_cpu(
|
||||||
|
apply_router_weight_on_input, topk_weights, x
|
||||||
|
)
|
||||||
|
|
||||||
return torch.ops.sgl_kernel.fused_experts_cpu(
|
return torch.ops.sgl_kernel.fused_experts_cpu(
|
||||||
x,
|
x,
|
||||||
layer.w13_weight,
|
layer.w13_weight,
|
||||||
|
|||||||
@@ -344,9 +344,12 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
|||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
assert activation == "silu", f"activation = {activation} is not supported."
|
assert activation == "silu", f"activation = {activation} is not supported."
|
||||||
|
|
||||||
if use_intel_amx_backend(layer) and not apply_router_weight_on_input:
|
if use_intel_amx_backend(layer):
|
||||||
|
|
||||||
from sglang.srt.layers.moe.topk import select_experts
|
from sglang.srt.layers.moe.topk import (
|
||||||
|
select_experts,
|
||||||
|
apply_topk_weights_cpu,
|
||||||
|
)
|
||||||
|
|
||||||
topk_weights, topk_ids = select_experts(
|
topk_weights, topk_ids = select_experts(
|
||||||
hidden_states=x,
|
hidden_states=x,
|
||||||
@@ -361,8 +364,10 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
|||||||
correction_bias=correction_bias,
|
correction_bias=correction_bias,
|
||||||
routed_scaling_factor=routed_scaling_factor,
|
routed_scaling_factor=routed_scaling_factor,
|
||||||
)
|
)
|
||||||
|
x, topk_weights = apply_topk_weights_cpu(
|
||||||
|
apply_router_weight_on_input, topk_weights, x
|
||||||
|
)
|
||||||
|
|
||||||
# TODO: support apply_router_weight_on_input in the fused_experts_cpu kernel
|
|
||||||
return torch.ops.sgl_kernel.fused_experts_cpu(
|
return torch.ops.sgl_kernel.fused_experts_cpu(
|
||||||
x,
|
x,
|
||||||
layer.w13_weight,
|
layer.w13_weight,
|
||||||
|
|||||||
@@ -497,6 +497,11 @@ class W8A8Int8MoEMethod(FusedMoEMethodBase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if use_intel_amx_backend(layer):
|
if use_intel_amx_backend(layer):
|
||||||
|
from sglang.srt.layers.moe.topk import apply_topk_weights_cpu
|
||||||
|
|
||||||
|
x, topk_weights = apply_topk_weights_cpu(
|
||||||
|
apply_router_weight_on_input, topk_weights, x
|
||||||
|
)
|
||||||
return torch.ops.sgl_kernel.fused_experts_cpu(
|
return torch.ops.sgl_kernel.fused_experts_cpu(
|
||||||
x,
|
x,
|
||||||
layer.w13_weight,
|
layer.w13_weight,
|
||||||
|
|||||||
Reference in New Issue
Block a user