From fa56abea9faf2d755ca4a800f59666bcf7dd34a6 Mon Sep 17 00:00:00 2001 From: ChenCangtao <50493711+ChenCangtao@users.noreply.github.com> Date: Wed, 4 Feb 2026 08:49:13 +0800 Subject: [PATCH] [bugfix][npugraph_ex]duplicate pattern issue (#6513) ### What this PR does / why we need it? When the draft model also uses vllmbackend for graph compilation, the fusion pass registration occurs again, resulting in errors due to duplicate patterns. ### Does this PR introduce _any_ user-facing change? ### How was this patch tested? - vLLM version: v0.15.0 - vLLM main: https://github.com/vllm-project/vllm/commit/v0.15.0 --------- Signed-off-by: chencangtao Co-authored-by: chencangtao --- .../graphex_allreduce_rmsnorm_fusion_pass.py | 9 +++++--- .../graphex_norm_quant_fusion_pass.py | 13 ++++++----- .../graphex_qknorm_rope_fusion_pass.py | 15 ++++++++----- .../utils/npugraph_ex_utils_check.py | 22 +++++++++++++++++++ 4 files changed, 46 insertions(+), 13 deletions(-) diff --git a/vllm_ascend/compilation/npugraph_ex_passes/graphex_allreduce_rmsnorm_fusion_pass.py b/vllm_ascend/compilation/npugraph_ex_passes/graphex_allreduce_rmsnorm_fusion_pass.py index f87413c8..250e7df7 100644 --- a/vllm_ascend/compilation/npugraph_ex_passes/graphex_allreduce_rmsnorm_fusion_pass.py +++ b/vllm_ascend/compilation/npugraph_ex_passes/graphex_allreduce_rmsnorm_fusion_pass.py @@ -23,7 +23,10 @@ from vllm.config.compilation import Range from vllm.distributed import get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce from vllm.distributed.parallel_state import get_tp_group -from vllm_ascend.compilation.npugraph_ex_passes.utils.npugraph_ex_utils_check import extra_stream_scope_check +from vllm_ascend.compilation.npugraph_ex_passes.utils.npugraph_ex_utils_check import ( + check_and_register_fusion_pass, + extra_stream_scope_check, +) # computation-communication tiling block is 512 ALLREDUCE_NORM_FUSE_THREHOLD = 512 @@ -143,8 +146,8 @@ class 
GraphEXLastLayerMatmulAllReduceAddRMSNormPattern: class GraphEXMatmulAllReduceAddRMSNormPass: def __init__(self, vllm_config: VllmConfig): - GraphEXMiddleLayerMatmulAllReduceAddRMSNormPattern(vllm_config).register() - GraphEXLastLayerMatmulAllReduceAddRMSNormPattern(vllm_config).register() + check_and_register_fusion_pass(GraphEXMiddleLayerMatmulAllReduceAddRMSNormPattern, vllm_config=vllm_config) + check_and_register_fusion_pass(GraphEXLastLayerMatmulAllReduceAddRMSNormPattern, vllm_config=vllm_config) def __call__(self, graph: torch.fx.Graph): pass diff --git a/vllm_ascend/compilation/npugraph_ex_passes/graphex_norm_quant_fusion_pass.py b/vllm_ascend/compilation/npugraph_ex_passes/graphex_norm_quant_fusion_pass.py index 5c41100a..1534b038 100644 --- a/vllm_ascend/compilation/npugraph_ex_passes/graphex_norm_quant_fusion_pass.py +++ b/vllm_ascend/compilation/npugraph_ex_passes/graphex_norm_quant_fusion_pass.py @@ -22,7 +22,10 @@ from vllm.config import VllmConfig from vllm.config.compilation import Range from vllm.logger import logger -from vllm_ascend.compilation.npugraph_ex_passes.utils.npugraph_ex_utils_check import extra_stream_scope_check +from vllm_ascend.compilation.npugraph_ex_passes.utils.npugraph_ex_utils_check import ( + check_and_register_fusion_pass, + extra_stream_scope_check, +) class GraphEXAddRMSNormQuantPattern: @@ -301,10 +304,10 @@ class GraphEXAddRMSNormFusionPass: common_epsilons = [1e-5, 1e-6] for eps in common_epsilons: - GraphEXAddRMSNormQuantPattern(vllm_config, eps=eps).register() - GraphEXAddRMSNormQuantPatternWithBias(vllm_config, eps=eps).register() - GraphEXAddRMSNormQuantSPPattern(vllm_config, eps=eps).register() - GraphEXAddRMSNormQuantSPPatternWithBias(vllm_config, eps=eps).register() + check_and_register_fusion_pass(GraphEXAddRMSNormQuantPattern, vllm_config=vllm_config, eps=eps) + check_and_register_fusion_pass(GraphEXAddRMSNormQuantPatternWithBias, vllm_config=vllm_config, eps=eps) + 
check_and_register_fusion_pass(GraphEXAddRMSNormQuantSPPattern, vllm_config=vllm_config, eps=eps) + check_and_register_fusion_pass(GraphEXAddRMSNormQuantSPPatternWithBias, vllm_config=vllm_config, eps=eps) def __call__(self, graph: torch.fx.Graph): pass diff --git a/vllm_ascend/compilation/npugraph_ex_passes/graphex_qknorm_rope_fusion_pass.py b/vllm_ascend/compilation/npugraph_ex_passes/graphex_qknorm_rope_fusion_pass.py index 3317d132..8586e6d9 100644 --- a/vllm_ascend/compilation/npugraph_ex_passes/graphex_qknorm_rope_fusion_pass.py +++ b/vllm_ascend/compilation/npugraph_ex_passes/graphex_qknorm_rope_fusion_pass.py @@ -23,7 +23,10 @@ from vllm.config import VllmConfig, get_layers_from_vllm_config from vllm.config.compilation import Range from vllm.logger import logger -from vllm_ascend.compilation.npugraph_ex_passes.utils.npugraph_ex_utils_check import extra_stream_scope_check +from vllm_ascend.compilation.npugraph_ex_passes.utils.npugraph_ex_utils_check import ( + check_and_register_fusion_pass, + extra_stream_scope_check, +) class GraphEXQKNormRopeFusionPattern: @@ -202,20 +205,22 @@ class GraphEXQKNormRopeFusionPass: if layer.head_size != 128: logger.debug("QKNorm and Rope fusion not enabled: head_dim %d is not equal of 128", layer.head_size) continue - GraphEXQKNormRopeFusionPattern( + check_and_register_fusion_pass( + GraphEXQKNormRopeFusionPattern, vllm_config=vllm_config, head_dim=layer.head_size, num_heads=layer.num_heads, num_kv_heads=layer.num_kv_heads, eps=epsilon, - ).register() - GraphEXQKNormRopeFusionPatternWithBias( + ) + check_and_register_fusion_pass( + GraphEXQKNormRopeFusionPatternWithBias, vllm_config=vllm_config, head_dim=layer.head_size, num_heads=layer.num_heads, num_kv_heads=layer.num_kv_heads, eps=epsilon, - ).register() + ) def __call__(self, graph: torch.fx.Graph): pass diff --git a/vllm_ascend/compilation/npugraph_ex_passes/utils/npugraph_ex_utils_check.py b/vllm_ascend/compilation/npugraph_ex_passes/utils/npugraph_ex_utils_check.py 
index 481a16ed..a81dbcc9 100644
--- a/vllm_ascend/compilation/npugraph_ex_passes/utils/npugraph_ex_utils_check.py
+++ b/vllm_ascend/compilation/npugraph_ex_passes/utils/npugraph_ex_utils_check.py
@@ -51,3 +51,42 @@ def extra_stream_scope_check(match: Match) -> bool:
         return False
 
     return True
+
+
+# Keys of the fusion patterns that were already registered in this process.
+_register_patterns = set()
+
+
+def check_and_register_fusion_pass(pattern_class: type, **kwargs):
+    """Register ``pattern_class(**kwargs)`` once per distinct configuration.
+
+    A pattern is identified by its class name together with every keyword
+    argument except ``vllm_config``. Keying on the class name and ``eps``
+    alone would silently skip a pattern that differs only in, for example,
+    ``head_dim`` or ``num_heads``.
+
+    Args:
+        pattern_class: Pattern class to instantiate; the instance must
+            provide a ``register()`` method.
+        **kwargs: Keyword arguments forwarded unchanged to ``pattern_class``.
+
+    Raises:
+        RuntimeError: Re-raised from ``register()`` unless its message
+            contains "Duplicate pattern".
+    """
+    eps = kwargs.get("eps", 1e-6)
+    # A missing ``eps`` keeps sharing a key with an explicit ``eps=1e-6``.
+    key_items = {name: repr(value) for name, value in kwargs.items() if name != "vllm_config"}
+    key_items["eps"] = repr(eps)
+    pattern_key = (pattern_class.__name__, tuple(sorted(key_items.items())))
+    if pattern_key in _register_patterns:
+        return
+
+    pattern = pattern_class(**kwargs)
+    try:
+        pattern.register()
+    except RuntimeError as e:
+        if "Duplicate pattern" not in str(e):
+            raise
+        logger.warning("Pattern %s eps %s has been registered", pattern_class.__name__, eps)
+    _register_patterns.add(pattern_key)