diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py
index 7d3d18f..fca31df 100644
--- a/vllm_ascend/attention/attention_v1.py
+++ b/vllm_ascend/attention/attention_v1.py
@@ -199,8 +199,8 @@ class AscendMetadata:
 
 
 class AscendAttentionMetadataBuilder:
-    # Does this backend/builder support CUDA Graphs for attention (default: no).
-    cudagraph_support: ClassVar[AttentionCGSupport] = \
+    # Does this backend/builder support ACL Graphs for attention (default: no).
+    aclgraph_support: ClassVar[AttentionCGSupport] = \
         AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
     # Does this backend/builder reorder the batch?
     # If not, set this to None. Otherwise set it to the query
diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py
index bae0cec..7a85bbb 100644
--- a/vllm_ascend/attention/mla_v1.py
+++ b/vllm_ascend/attention/mla_v1.py
@@ -1,5 +1,6 @@
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, NamedTuple, Optional, Tuple, Type, TypeVar
+from typing import (TYPE_CHECKING, ClassVar, NamedTuple, Optional, Tuple, Type,
+                    TypeVar)
 
 import torch
 import torch_npu
@@ -12,6 +13,7 @@ from vllm.distributed import get_tensor_model_parallel_world_size, get_tp_group
 from vllm.model_executor.layers.linear import (LinearBase,
                                                UnquantizedLinearMethod)
 from vllm.utils import cdiv, round_down
+from vllm.v1.attention.backends.utils import AttentionCGSupport
 
 from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.attention.attention_v1 import AscendAttentionState
@@ -163,6 +165,9 @@ M = TypeVar("M", bound=AscendMLAMetadata)
 
 
 class AscendMLAMetadataBuilder:
+    # Does this backend/builder support ACL Graphs for attention (default: no).
+    aclgraph_support: ClassVar[AttentionCGSupport] = \
+        AttentionCGSupport.NEVER
     """
     NOTE: Please read the comment at the top of the file before trying to
     understand this class
diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index 9421a98..0e77d43 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -3259,8 +3259,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                 builder = attn_group.metadata_builder
             else:
                 builder = attn_group.get_metadata_builder()
-            if builder.cudagraph_support.value < min_ag_support.value:
-                min_ag_support = builder.cudagraph_support
+            if builder.aclgraph_support.value < min_ag_support.value:
+                min_ag_support = builder.aclgraph_support
             min_ag_builder_name = builder.__class__.__name__
 
         # This is an imitation of compilation_config.splitting_ops_contain_attention()