[Fusion] [Graph]Add Matmul Allreduce Rmsnorm fusion Pass (#5034)

This PR add `MatmulAllreduceRmsnorm` operator and introduces a graph
fusion pass for `matmul_allreduce_rmsnorm` operations. The
implementation includes a new configuration flag, a pattern matching
pass using `torch._inductor.pattern_matcher`.

Co-authored-by: Trunrain [270250579@qq.com](mailto:270250579@qq.com)

- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c

---------

Signed-off-by: wxsIcey <1790571317@qq.com>
Signed-off-by: tongrunze <t00574058@china.huawei.com>
This commit is contained in:
Icey
2026-01-19 09:28:07 +08:00
committed by GitHub
parent 9cad1a8349
commit c929bd1e8d
8 changed files with 251 additions and 1 deletions

View File

@@ -192,6 +192,18 @@ class NPUPlatform(Platform):
else ascend_compilation_config
)
if vllm_config.additional_config.get("ascend_compilation_config", {}).get("fuse_allreduce_rms", True):
from vllm_ascend.compilation.passes.allreduce_rmsnorm_fusion_pass import ALLREDUCE_NORM_FUSE_THREHOLD
new_compile_ranges_split_points = vllm_config.compilation_config.compile_ranges_split_points
new_compile_ranges_split_points.append(ALLREDUCE_NORM_FUSE_THREHOLD)
new_compile_ranges_split_points = sorted(new_compile_ranges_split_points)
vllm_config.compilation_config.compile_ranges_split_points = new_compile_ranges_split_points
logger.debug(
"set compile_ranges_split_points to "
"{new_compile_ranges_split_points} for matmul and allreduce fusion"
)
elif model_config and hasattr(model_config.hf_text_config, "index_topk"):
vllm_config.cache_config.cache_dtype = str(model_config.dtype).replace("torch.", "")
if model_config is None: