[bugfix][npugraph_ex]add the extra check for allreduce rmsnorm fusion pass (#6430)

### What this PR does / why we need it?
The allreduce + rmsnorm fusion pass has an additional check condition: the FX
graph should be fused only when the start of `compile_range` is greater
than 512. We previously overlooked this check.

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

- vLLM version: v0.14.1
- vLLM main:
dc917cceb8

---------

Signed-off-by: chencangtao <chencangtao@huawei.com>
Co-authored-by: chencangtao <chencangtao@huawei.com>
This commit is contained in:
ChenCangtao
2026-02-04 08:48:28 +08:00
committed by GitHub
parent a80e524fbc
commit 7b3921c498

View File

@@ -16,6 +16,8 @@
#
import torch
import torchair
from torch._inductor.pattern_matcher import Match
from vllm.compilation.inductor_pass import get_pass_context
from vllm.config import VllmConfig
from vllm.config.compilation import Range
from vllm.distributed import get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce
@@ -27,6 +29,11 @@ from vllm_ascend.compilation.npugraph_ex_passes.utils.npugraph_ex_utils_check im
# Minimum compile-range start for which the allreduce + rmsnorm fusion pays off.
ALLREDUCE_NORM_FUSE_THREHOLD = 512


def extra_check_for_allreduce_rmsnorm_fusion_pass(match: Match) -> bool:
    """Extra predicate gating the allreduce + rmsnorm fusion.

    The pattern is fused only when the stream-scope check on *match*
    passes AND the current compile range starts above
    ``ALLREDUCE_NORM_FUSE_THREHOLD``.
    """
    compile_range = get_pass_context().compile_range
    scope_ok = extra_stream_scope_check(match)
    return scope_ok and compile_range.start > ALLREDUCE_NORM_FUSE_THREHOLD
class GraphEXMiddleLayerMatmulAllReduceAddRMSNormPattern:
"""
recognizing the Matmul + AllReduce + AddRMSNorm computation pattern
@@ -80,7 +87,7 @@ class GraphEXMiddleLayerMatmulAllReduceAddRMSNormPattern:
search_fn=pattern,
replace_fn=replacement,
example_inputs=self.get_inputs(),
extra_check=extra_stream_scope_check,
extra_check=extra_check_for_allreduce_rmsnorm_fusion_pass,
)
@@ -130,7 +137,7 @@ class GraphEXLastLayerMatmulAllReduceAddRMSNormPattern:
search_fn=pattern,
replace_fn=replacement,
example_inputs=self.get_inputs(),
extra_check=extra_stream_scope_check,
extra_check=extra_check_for_allreduce_rmsnorm_fusion_pass,
)