From 7b3921c498d64081e507e03eef6edba15e8428b1 Mon Sep 17 00:00:00 2001 From: ChenCangtao <50493711+ChenCangtao@users.noreply.github.com> Date: Wed, 4 Feb 2026 08:48:28 +0800 Subject: [PATCH] [bugfix][npugraph_ex]add the extra check for allreduce rmsnorm fusion pass (#6430) ### What this PR does / why we need it? The allreduce rmsnorm fusion pass has an additional precondition: the Fx graph should be fused only when the start of compile_range is greater than 512. We previously overlooked this check; this PR adds it to the pass's extra_check. ### Does this PR introduce _any_ user-facing change? ### How was this patch tested? - vLLM version: v0.14.1 - vLLM main: https://github.com/vllm-project/vllm/commit/dc917cceb877dfd13f98c538c4c96158047d98bd --------- Signed-off-by: chencangtao Co-authored-by: chencangtao --- .../graphex_allreduce_rmsnorm_fusion_pass.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/vllm_ascend/compilation/npugraph_ex_passes/graphex_allreduce_rmsnorm_fusion_pass.py b/vllm_ascend/compilation/npugraph_ex_passes/graphex_allreduce_rmsnorm_fusion_pass.py index 94a08389..f87413c8 100644 --- a/vllm_ascend/compilation/npugraph_ex_passes/graphex_allreduce_rmsnorm_fusion_pass.py +++ b/vllm_ascend/compilation/npugraph_ex_passes/graphex_allreduce_rmsnorm_fusion_pass.py @@ -16,6 +16,8 @@ # import torch import torchair +from torch._inductor.pattern_matcher import Match +from vllm.compilation.inductor_pass import get_pass_context from vllm.config import VllmConfig from vllm.config.compilation import Range from vllm.distributed import get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce @@ -27,6 +29,11 @@ from vllm_ascend.compilation.npugraph_ex_passes.utils.npugraph_ex_utils_check im ALLREDUCE_NORM_FUSE_THREHOLD = 512 +def extra_check_for_allreduce_rmsnorm_fusion_pass(match: Match) -> bool: + compile_range = get_pass_context().compile_range + return extra_stream_scope_check(match) and compile_range.start > ALLREDUCE_NORM_FUSE_THREHOLD + + class 
GraphEXMiddleLayerMatmulAllReduceAddRMSNormPattern: """ recognizing the Matmul + AllReduce + AddRMSNorm computation pattern @@ -80,7 +87,7 @@ class GraphEXMiddleLayerMatmulAllReduceAddRMSNormPattern: search_fn=pattern, replace_fn=replacement, example_inputs=self.get_inputs(), - extra_check=extra_stream_scope_check, + extra_check=extra_check_for_allreduce_rmsnorm_fusion_pass, ) @@ -130,7 +137,7 @@ class GraphEXLastLayerMatmulAllReduceAddRMSNormPattern: search_fn=pattern, replace_fn=replacement, example_inputs=self.get_inputs(), - extra_check=extra_stream_scope_check, + extra_check=extra_check_for_allreduce_rmsnorm_fusion_pass, )