From 48c1fa7bb6950b81788a84da32c3c42bc7c77e67 Mon Sep 17 00:00:00 2001 From: jianan-gu Date: Fri, 18 Jul 2025 12:43:25 +0800 Subject: [PATCH] [CPU][Llama4] Fix Llama4 MoE inputs with "apply_router_weight_on_input" (#7889) --- python/sglang/srt/configs/update_config.py | 4 +++- python/sglang/srt/layers/moe/topk.py | 13 +++++++++++++ python/sglang/srt/layers/quantization/fp8.py | 6 ++++++ python/sglang/srt/layers/quantization/unquant.py | 11 ++++++++--- python/sglang/srt/layers/quantization/w8a8_int8.py | 5 +++++ 5 files changed, 35 insertions(+), 4 deletions(-) diff --git a/python/sglang/srt/configs/update_config.py b/python/sglang/srt/configs/update_config.py index f9e6d15a8..241d9566a 100644 --- a/python/sglang/srt/configs/update_config.py +++ b/python/sglang/srt/configs/update_config.py @@ -115,5 +115,7 @@ def adjust_config_with_unaligned_cpu_tp( model_config = update_intermediate_size( model_config, "intermediate_size", intermediate_padding_size ) - + model_config = update_intermediate_size( + model_config, "intermediate_size_mlp", intermediate_padding_size + ) return model_config diff --git a/python/sglang/srt/layers/moe/topk.py b/python/sglang/srt/layers/moe/topk.py index 1c8d219e4..40fc0b61f 100644 --- a/python/sglang/srt/layers/moe/topk.py +++ b/python/sglang/srt/layers/moe/topk.py @@ -93,6 +93,19 @@ def fused_topk_cpu( return topk_weights, topk_ids +def apply_topk_weights_cpu(need_apply, topk_weights, inputs): + if not need_apply: + return inputs, topk_weights + + # TODO: fuse below processing in fused_experts_cpu kernel + inputs = inputs * topk_weights.to(inputs.dtype) + topk_weights = torch.ones_like( + topk_weights, dtype=torch.float32 + ) # clear topk_weights as already applied + + return inputs, topk_weights + + def fused_topk( hidden_states: torch.Tensor, gating_output: torch.Tensor, diff --git a/python/sglang/srt/layers/quantization/fp8.py b/python/sglang/srt/layers/quantization/fp8.py index 38588c809..7275ea430 100644 --- a/python/sglang/srt/layers/quantization/fp8.py +++ b/python/sglang/srt/layers/quantization/fp8.py @@ -1005,6 +1005,12 @@ class Fp8MoEMethod(FusedMoEMethodBase): ) if use_intel_amx_backend(layer): + from sglang.srt.layers.moe.topk import apply_topk_weights_cpu + + x, topk_weights = apply_topk_weights_cpu( + apply_router_weight_on_input, topk_weights, x + ) + return torch.ops.sgl_kernel.fused_experts_cpu( x, layer.w13_weight, diff --git a/python/sglang/srt/layers/quantization/unquant.py b/python/sglang/srt/layers/quantization/unquant.py index 28d006255..821b1cb85 100644 --- a/python/sglang/srt/layers/quantization/unquant.py +++ b/python/sglang/srt/layers/quantization/unquant.py @@ -344,9 +344,12 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ) -> torch.Tensor: assert activation == "silu", f"activation = {activation} is not supported." - if use_intel_amx_backend(layer) and not apply_router_weight_on_input: + if use_intel_amx_backend(layer): - from sglang.srt.layers.moe.topk import select_experts + from sglang.srt.layers.moe.topk import ( + select_experts, + apply_topk_weights_cpu, + ) topk_weights, topk_ids = select_experts( hidden_states=x, @@ -361,8 +364,10 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): correction_bias=correction_bias, routed_scaling_factor=routed_scaling_factor, ) + x, topk_weights = apply_topk_weights_cpu( + apply_router_weight_on_input, topk_weights, x + ) - # TODO: support apply_router_weight_on_input in the fused_experts_cpu kernel return torch.ops.sgl_kernel.fused_experts_cpu( x, layer.w13_weight, diff --git a/python/sglang/srt/layers/quantization/w8a8_int8.py b/python/sglang/srt/layers/quantization/w8a8_int8.py index c8a024bf3..56ac26c57 100644 --- a/python/sglang/srt/layers/quantization/w8a8_int8.py +++ b/python/sglang/srt/layers/quantization/w8a8_int8.py @@ -497,6 +497,11 @@ class W8A8Int8MoEMethod(FusedMoEMethodBase): ) if use_intel_amx_backend(layer): + from sglang.srt.layers.moe.topk import apply_topk_weights_cpu + + x, topk_weights = apply_topk_weights_cpu( + apply_router_weight_on_input, topk_weights, x + ) return torch.ops.sgl_kernel.fused_experts_cpu( x, layer.w13_weight,