[bugfix][npugraph_ex]duplicate pattern issue (#6513)

### What this PR does / why we need it?
When the draft model also uses vllmbackend for graph compilation, the
fusion pass registration occurs again, resulting in errors due to
duplicate patterns.
### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

- vLLM version: v0.15.0
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.15.0

---------

Signed-off-by: chencangtao <chencangtao@huawei.com>
Co-authored-by: chencangtao <chencangtao@huawei.com>
This commit is contained in:
ChenCangtao
2026-02-04 08:49:13 +08:00
committed by GitHub
parent 7b3921c498
commit fa56abea9f
4 changed files with 46 additions and 13 deletions

View File

@@ -23,7 +23,10 @@ from vllm.config.compilation import Range
from vllm.distributed import get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce
from vllm.distributed.parallel_state import get_tp_group
from vllm_ascend.compilation.npugraph_ex_passes.utils.npugraph_ex_utils_check import extra_stream_scope_check
from vllm_ascend.compilation.npugraph_ex_passes.utils.npugraph_ex_utils_check import (
check_and_register_fusion_pass,
extra_stream_scope_check,
)
# computation-communication tiling block is 512
ALLREDUCE_NORM_FUSE_THREHOLD = 512
@@ -143,8 +146,8 @@ class GraphEXLastLayerMatmulAllReduceAddRMSNormPattern:
class GraphEXMatmulAllReduceAddRMSNormPass:
def __init__(self, vllm_config: VllmConfig):
GraphEXMiddleLayerMatmulAllReduceAddRMSNormPattern(vllm_config).register()
GraphEXLastLayerMatmulAllReduceAddRMSNormPattern(vllm_config).register()
check_and_register_fusion_pass(GraphEXMiddleLayerMatmulAllReduceAddRMSNormPattern, vllm_config=vllm_config)
check_and_register_fusion_pass(GraphEXLastLayerMatmulAllReduceAddRMSNormPattern, vllm_config=vllm_config)
def __call__(self, graph: torch.fx.Graph):
pass

View File

@@ -22,7 +22,10 @@ from vllm.config import VllmConfig
from vllm.config.compilation import Range
from vllm.logger import logger
from vllm_ascend.compilation.npugraph_ex_passes.utils.npugraph_ex_utils_check import extra_stream_scope_check
from vllm_ascend.compilation.npugraph_ex_passes.utils.npugraph_ex_utils_check import (
check_and_register_fusion_pass,
extra_stream_scope_check,
)
class GraphEXAddRMSNormQuantPattern:
@@ -301,10 +304,10 @@ class GraphEXAddRMSNormFusionPass:
common_epsilons = [1e-5, 1e-6]
for eps in common_epsilons:
GraphEXAddRMSNormQuantPattern(vllm_config, eps=eps).register()
GraphEXAddRMSNormQuantPatternWithBias(vllm_config, eps=eps).register()
GraphEXAddRMSNormQuantSPPattern(vllm_config, eps=eps).register()
GraphEXAddRMSNormQuantSPPatternWithBias(vllm_config, eps=eps).register()
check_and_register_fusion_pass(GraphEXAddRMSNormQuantPattern, vllm_config=vllm_config, eps=eps)
check_and_register_fusion_pass(GraphEXAddRMSNormQuantPatternWithBias, vllm_config=vllm_config, eps=eps)
check_and_register_fusion_pass(GraphEXAddRMSNormQuantSPPattern, vllm_config=vllm_config, eps=eps)
check_and_register_fusion_pass(GraphEXAddRMSNormQuantSPPatternWithBias, vllm_config=vllm_config, eps=eps)
def __call__(self, graph: torch.fx.Graph):
pass

View File

@@ -23,7 +23,10 @@ from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.config.compilation import Range
from vllm.logger import logger
from vllm_ascend.compilation.npugraph_ex_passes.utils.npugraph_ex_utils_check import extra_stream_scope_check
from vllm_ascend.compilation.npugraph_ex_passes.utils.npugraph_ex_utils_check import (
check_and_register_fusion_pass,
extra_stream_scope_check,
)
class GraphEXQKNormRopeFusionPattern:
@@ -202,20 +205,22 @@ class GraphEXQKNormRopeFusionPass:
if layer.head_size != 128:
logger.debug("QKNorm and Rope fusion not enabled: head_dim %d is not equal of 128", layer.head_size)
continue
GraphEXQKNormRopeFusionPattern(
check_and_register_fusion_pass(
GraphEXQKNormRopeFusionPattern,
vllm_config=vllm_config,
head_dim=layer.head_size,
num_heads=layer.num_heads,
num_kv_heads=layer.num_kv_heads,
eps=epsilon,
).register()
GraphEXQKNormRopeFusionPatternWithBias(
)
check_and_register_fusion_pass(
GraphEXQKNormRopeFusionPatternWithBias,
vllm_config=vllm_config,
head_dim=layer.head_size,
num_heads=layer.num_heads,
num_kv_heads=layer.num_kv_heads,
eps=epsilon,
).register()
)
def __call__(self, graph: torch.fx.Graph):
pass

View File

@@ -51,3 +51,25 @@ def extra_stream_scope_check(match: Match) -> bool:
return False
return True
_register_patterns = set()
def check_and_register_fusion_pass(pattern_class: type, **kwargs):
global _register_patterns
eps = kwargs.get("eps", 1e-6)
pattern_key = str(pattern_class.__name__) + str(eps)
if pattern_key in _register_patterns:
return
pattern = pattern_class(**kwargs)
try:
pattern.register()
_register_patterns.add(pattern_key)
except RuntimeError as e:
if "Duplicate pattern" in str(e):
logger.warning(f"Pattern {pattern_class.__name__} eps {eps} has been registered")
_register_patterns.add(pattern_key)
else:
raise e