[bugfix][npugraph_ex]duplicate pattern issue (#6513)
### What this PR does / why we need it? When the draft model also uses the vLLM backend for graph compilation, fusion-pass registration runs a second time, and the duplicate patterns cause registration errors. This change routes all pattern registrations through `check_and_register_fusion_pass`, which skips patterns that are already registered. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? - vLLM version: v0.15.0 - vLLM main: https://github.com/vllm-project/vllm/commit/v0.15.0 --------- Signed-off-by: chencangtao <chencangtao@huawei.com> Co-authored-by: chencangtao <chencangtao@huawei.com>
This commit is contained in:
@@ -23,7 +23,10 @@ from vllm.config.compilation import Range
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce
|
||||
from vllm.distributed.parallel_state import get_tp_group
|
||||
|
||||
from vllm_ascend.compilation.npugraph_ex_passes.utils.npugraph_ex_utils_check import extra_stream_scope_check
|
||||
from vllm_ascend.compilation.npugraph_ex_passes.utils.npugraph_ex_utils_check import (
|
||||
check_and_register_fusion_pass,
|
||||
extra_stream_scope_check,
|
||||
)
|
||||
|
||||
# computation-communication tiling block is 512
|
||||
ALLREDUCE_NORM_FUSE_THREHOLD = 512
|
||||
@@ -143,8 +146,8 @@ class GraphEXLastLayerMatmulAllReduceAddRMSNormPattern:
|
||||
|
||||
class GraphEXMatmulAllReduceAddRMSNormPass:
    """Fusion pass for Matmul + AllReduce + AddRMSNorm subgraphs.

    Pattern registration is routed exclusively through
    ``check_and_register_fusion_pass`` so that constructing this pass more
    than once in the same process (e.g. when a draft model is also compiled
    with the vLLM backend) does not register duplicate patterns.
    """

    def __init__(self, vllm_config: VllmConfig):
        # Register the middle-layer and last-layer pattern variants exactly
        # once; the helper deduplicates repeated registrations process-wide.
        # Do NOT also call ``pattern.register()`` directly here — that would
        # bypass the dedup bookkeeping and reintroduce the duplicate-pattern
        # error this helper exists to prevent.
        check_and_register_fusion_pass(GraphEXMiddleLayerMatmulAllReduceAddRMSNormPattern, vllm_config=vllm_config)
        check_and_register_fusion_pass(GraphEXLastLayerMatmulAllReduceAddRMSNormPattern, vllm_config=vllm_config)

    def __call__(self, graph: torch.fx.Graph):
        # Intentionally a no-op: the registered patterns are applied by the
        # pattern-matcher infrastructure, not by this callable itself.
        pass
|
||||
|
||||
@@ -22,7 +22,10 @@ from vllm.config import VllmConfig
|
||||
from vllm.config.compilation import Range
|
||||
from vllm.logger import logger
|
||||
|
||||
from vllm_ascend.compilation.npugraph_ex_passes.utils.npugraph_ex_utils_check import extra_stream_scope_check
|
||||
from vllm_ascend.compilation.npugraph_ex_passes.utils.npugraph_ex_utils_check import (
|
||||
check_and_register_fusion_pass,
|
||||
extra_stream_scope_check,
|
||||
)
|
||||
|
||||
|
||||
class GraphEXAddRMSNormQuantPattern:
|
||||
@@ -301,10 +304,10 @@ class GraphEXAddRMSNormFusionPass:
|
||||
|
||||
common_epsilons = [1e-5, 1e-6]
|
||||
for eps in common_epsilons:
|
||||
GraphEXAddRMSNormQuantPattern(vllm_config, eps=eps).register()
|
||||
GraphEXAddRMSNormQuantPatternWithBias(vllm_config, eps=eps).register()
|
||||
GraphEXAddRMSNormQuantSPPattern(vllm_config, eps=eps).register()
|
||||
GraphEXAddRMSNormQuantSPPatternWithBias(vllm_config, eps=eps).register()
|
||||
check_and_register_fusion_pass(GraphEXAddRMSNormQuantPattern, vllm_config=vllm_config, eps=eps)
|
||||
check_and_register_fusion_pass(GraphEXAddRMSNormQuantPatternWithBias, vllm_config=vllm_config, eps=eps)
|
||||
check_and_register_fusion_pass(GraphEXAddRMSNormQuantSPPattern, vllm_config=vllm_config, eps=eps)
|
||||
check_and_register_fusion_pass(GraphEXAddRMSNormQuantSPPatternWithBias, vllm_config=vllm_config, eps=eps)
|
||||
|
||||
def __call__(self, graph: torch.fx.Graph):
|
||||
pass
|
||||
|
||||
@@ -23,7 +23,10 @@ from vllm.config import VllmConfig, get_layers_from_vllm_config
|
||||
from vllm.config.compilation import Range
|
||||
from vllm.logger import logger
|
||||
|
||||
from vllm_ascend.compilation.npugraph_ex_passes.utils.npugraph_ex_utils_check import extra_stream_scope_check
|
||||
from vllm_ascend.compilation.npugraph_ex_passes.utils.npugraph_ex_utils_check import (
|
||||
check_and_register_fusion_pass,
|
||||
extra_stream_scope_check,
|
||||
)
|
||||
|
||||
|
||||
class GraphEXQKNormRopeFusionPattern:
|
||||
@@ -202,20 +205,22 @@ class GraphEXQKNormRopeFusionPass:
|
||||
if layer.head_size != 128:
|
||||
logger.debug("QKNorm and Rope fusion not enabled: head_dim %d is not equal of 128", layer.head_size)
|
||||
continue
|
||||
GraphEXQKNormRopeFusionPattern(
|
||||
check_and_register_fusion_pass(
|
||||
GraphEXQKNormRopeFusionPattern,
|
||||
vllm_config=vllm_config,
|
||||
head_dim=layer.head_size,
|
||||
num_heads=layer.num_heads,
|
||||
num_kv_heads=layer.num_kv_heads,
|
||||
eps=epsilon,
|
||||
).register()
|
||||
GraphEXQKNormRopeFusionPatternWithBias(
|
||||
)
|
||||
check_and_register_fusion_pass(
|
||||
GraphEXQKNormRopeFusionPatternWithBias,
|
||||
vllm_config=vllm_config,
|
||||
head_dim=layer.head_size,
|
||||
num_heads=layer.num_heads,
|
||||
num_kv_heads=layer.num_kv_heads,
|
||||
eps=epsilon,
|
||||
).register()
|
||||
)
|
||||
|
||||
def __call__(self, graph: torch.fx.Graph):
|
||||
pass
|
||||
|
||||
@@ -51,3 +51,25 @@ def extra_stream_scope_check(match: Match) -> bool:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
_register_patterns = set()
|
||||
|
||||
|
||||
def check_and_register_fusion_pass(pattern_class: type, **kwargs):
|
||||
global _register_patterns
|
||||
eps = kwargs.get("eps", 1e-6)
|
||||
pattern_key = str(pattern_class.__name__) + str(eps)
|
||||
if pattern_key in _register_patterns:
|
||||
return
|
||||
|
||||
pattern = pattern_class(**kwargs)
|
||||
try:
|
||||
pattern.register()
|
||||
_register_patterns.add(pattern_key)
|
||||
except RuntimeError as e:
|
||||
if "Duplicate pattern" in str(e):
|
||||
logger.warning(f"Pattern {pattern_class.__name__} eps {eps} has been registered")
|
||||
_register_patterns.add(pattern_key)
|
||||
else:
|
||||
raise e
|
||||
|
||||
Reference in New Issue
Block a user