diff --git a/vllm_ascend/compilation/npugraph_ex_passes/graphex_allreduce_rmsnorm_fusion_pass.py b/vllm_ascend/compilation/npugraph_ex_passes/graphex_allreduce_rmsnorm_fusion_pass.py index 94a08389..f87413c8 100644 --- a/vllm_ascend/compilation/npugraph_ex_passes/graphex_allreduce_rmsnorm_fusion_pass.py +++ b/vllm_ascend/compilation/npugraph_ex_passes/graphex_allreduce_rmsnorm_fusion_pass.py @@ -16,6 +16,8 @@ # import torch import torchair +from torch._inductor.pattern_matcher import Match +from vllm.compilation.inductor_pass import get_pass_context from vllm.config import VllmConfig from vllm.config.compilation import Range from vllm.distributed import get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce @@ -27,6 +29,11 @@ from vllm_ascend.compilation.npugraph_ex_passes.utils.npugraph_ex_utils_check im ALLREDUCE_NORM_FUSE_THREHOLD = 512 +def extra_check_for_allreduce_rmsnorm_fusion_pass(match: Match) -> bool: + compile_range = get_pass_context().compile_range + return extra_stream_scope_check(match) and compile_range.start > ALLREDUCE_NORM_FUSE_THREHOLD + + class GraphEXMiddleLayerMatmulAllReduceAddRMSNormPattern: """ recognizing the Matmul + AllReduce + AddRMSNorm computation pattern @@ -80,7 +87,7 @@ class GraphEXMiddleLayerMatmulAllReduceAddRMSNormPattern: search_fn=pattern, replace_fn=replacement, example_inputs=self.get_inputs(), - extra_check=extra_stream_scope_check, + extra_check=extra_check_for_allreduce_rmsnorm_fusion_pass, ) @@ -130,7 +137,7 @@ class GraphEXLastLayerMatmulAllReduceAddRMSNormPattern: search_fn=pattern, replace_fn=replacement, example_inputs=self.get_inputs(), - extra_check=extra_stream_scope_check, + extra_check=extra_check_for_allreduce_rmsnorm_fusion_pass, )