[BugFix] Fix muls_add fusion not working for GLM5 models (#6928)

### What this PR does / why we need it?
fix: support model-specific routed_scaling_factor in muls_add fusion

Previously, `MulsAddFusionPass` registered its pattern with a hardcoded
`scale=1.0`, which failed to match the `x * routed_scaling_factor + y`
pattern in models such as GLM5 that use `routed_scaling_factor=2.5`. As a
result, the muls_add fusion was skipped, leaving unfused `mul` + `add`
operations in the graph.
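
To make the mismatch concrete, here is a minimal, self-contained sketch using `torch.fx.replace_pattern` (illustrative only; the real pass registers `MulsAddPattern` through vLLM's inductor-pass machinery, and `Block` is a made-up stand-in for a model layer). A pattern traced with the literal `1.0` only matches graphs that multiply by exactly `1.0`, so a GLM5-style `2.5` never fuses:

```python
import torch
import torch.fx as fx

def pattern(x, y):
    # Hardcoded scale, as in the old pass.
    return x * 1.0 + y

def replacement(x, y):
    # Stand-in for the fused muls_add op.
    return torch.add(y, x)

class Block(torch.nn.Module):
    def forward(self, x, y):
        # GLM5-style routed_scaling_factor.
        return x * 2.5 + y

gm = fx.symbolic_trace(Block())
matches = fx.replace_pattern(gm, pattern, replacement)
print(len(matches))  # 0 -- the literal 2.5 in the graph does not match 1.0
```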

This fix reads `routed_scaling_factor` from the model config (defaulting
to 1.0 for backward compatibility) and uses it as the pattern scale,
enabling correct fusion for GLM5 and other models with custom scaling
factors.
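
The lookup itself is a guarded `getattr` with a backward-compatible default, as shown in the diff below. A sketch, with `SimpleNamespace` standing in for `vllm_config.model_config.hf_text_config`:

```python
from types import SimpleNamespace

glm5_cfg = SimpleNamespace(routed_scaling_factor=2.5)  # GLM5-style text config
dense_cfg = SimpleNamespace()                          # model without the field

# GLM5 picks up its own factor; anything else falls back to the old 1.0.
assert getattr(glm5_cfg, "routed_scaling_factor", 1.0) == 2.5
assert getattr(dense_cfg, "routed_scaling_factor", 1.0) == 1.0
```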

Fixes: Unoptimized mul+add in GLM5 attention blocks
Tested: GLM5-W8A8 with routed_scaling_factor=2.5
### Does this PR introduce _any_ user-facing change?

No. Models without `routed_scaling_factor` in their config keep the previous
default of 1.0, so existing behavior is unchanged.
### How was this patch tested?

- vLLM version: v0.16.0
- vLLM main: 15d76f74e2

Signed-off-by: liuchenbing <chenliumail@163.com>
Co-authored-by: liuchenbing <chenliumail@163.com>
Author: liuchen2026fly
Date: 2026-03-05 22:35:54 +08:00
Committed by: GitHub
Parent: ae394767d4
Commit: 640ecd1b77


```diff
@@ -95,10 +95,8 @@ class MulsAddFusionPass(VllmInductorPass):
             logger.debug("MulsAdd fusion not enabled: unsupported dtype %s", dtype)
             return
-        # Currently we only register a single pattern instance with a fixed
-        # scalar scale value. If needed, multiple instances with different
-        # scales can be added here in the future.
-        MulsAddPattern(vllm_config, scale=1.0).register(self.pattern_match_passes)
+        routed_scaling_factor = getattr(vllm_config.model_config.hf_text_config, "routed_scaling_factor", 1.0)
+        MulsAddPattern(vllm_config, scale=routed_scaling_factor).register(self.pattern_match_passes)
 
     def __call__(self, graph: torch.fx.Graph) -> None:  # type: ignore[override]
         self.begin()
```
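
For illustration, here is a hypothetical scale-parameterized pattern (the real `MulsAddPattern` internals live in vllm-ascend and differ; `make_pattern` and `Block` are made-up names). Closing over the configured scale bakes the model's actual factor into the traced pattern, so the GLM5 graph now matches:

```python
import torch
import torch.fx as fx

def make_pattern(scale: float):
    def pattern(x, y):
        # The closed-over scale becomes a literal in the traced pattern graph.
        return x * scale + y
    def replacement(x, y):
        # Stand-in for the fused muls_add kernel.
        return torch.add(y, x, alpha=scale)
    return pattern, replacement

class Block(torch.nn.Module):
    def forward(self, x, y):
        return x * 2.5 + y

gm = fx.symbolic_trace(Block())
# Scale read from the model config, as in the diff above.
pattern, replacement = make_pattern(2.5)
print(len(fx.replace_pattern(gm, pattern, replacement)))  # 1 -- now it fuses
```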