From 640ecd1b772b1c3dcdc57336b762cc02d011eba8 Mon Sep 17 00:00:00 2001
From: liuchen2026fly
Date: Thu, 5 Mar 2026 22:35:54 +0800
Subject: [PATCH] [BugFix] Fix muls_add fusion not working for GLM5 models
 (#6928)

### What this PR does / why we need it?

fix: support model-specific routed_scaling_factor in muls_add fusion

Previously, MulsAddFusionPass used a hardcoded scale=1.0, which failed to
match the x * routed_scaling_factor + y pattern in models like GLM5 that
use routed_scaling_factor=2.5. This caused the muls_add fusion to be
skipped, leaving unoptimized mul+add operations.

This fix reads routed_scaling_factor from the model config (defaulting to
1.0 for backward compatibility) and uses it as the pattern scale, enabling
correct fusion for GLM5 and other models with custom scaling factors.

Fixes: Unoptimized mul+add in GLM5 attention blocks
Tested: GLM5-W8A8 with routed_scaling_factor=2.5

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

- vLLM version: v0.16.0
- vLLM main: https://github.com/vllm-project/vllm/commit/15d76f74e2fdb12a95ea00f0ca283acf6219a2b7

Signed-off-by: liuchenbing
Co-authored-by: liuchenbing
---
 vllm_ascend/compilation/passes/muls_add_pass.py | 6 ++----
 1 file changed, 2 insertions(+), 4 deletions(-)

diff --git a/vllm_ascend/compilation/passes/muls_add_pass.py b/vllm_ascend/compilation/passes/muls_add_pass.py
index 3d4ed764..0a379d17 100644
--- a/vllm_ascend/compilation/passes/muls_add_pass.py
+++ b/vllm_ascend/compilation/passes/muls_add_pass.py
@@ -95,10 +95,8 @@ class MulsAddFusionPass(VllmInductorPass):
             logger.debug("MulsAdd fusion not enabled: unsupported dtype %s",
                          dtype)
             return
-        # Currently we only register a single pattern instance with a fixed
-        # scalar scale value. If needed, multiple instances with different
-        # scales can be added here in the future.
-        MulsAddPattern(vllm_config, scale=1.0).register(self.pattern_match_passes)
+        routed_scaling_factor = getattr(vllm_config.model_config.hf_text_config, "routed_scaling_factor", 1.0)
+        MulsAddPattern(vllm_config, scale=routed_scaling_factor).register(self.pattern_match_passes)

     def __call__(self, graph: torch.fx.Graph) -> None:  # type: ignore[override]
         self.begin()