[Misc] Clean up uesless code in attention (#1933)

Before do attention module refactor, we can do some code cleanup to make
the next step easier.

What this PR does:

1. remove uesless `common_prefix_len` for attention builder
2. remove uesless `is_only_prefill` and `num_input_tokens` in attention
metadata.
3. remove `CommonAttentionMetadata` and ues `query_start_loc` instead,
`CommonAttentionMetadata` is over designed and uesless
4. update the attention backend input parameters to keep the same as
vLLM.
5. Rename attention name to the same style with `ASCEND` prefix

- vLLM version: v0.9.2
- vLLM main:
107111a859

Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
This commit is contained in:
wangxiyuan
2025-07-24 10:23:34 +08:00
committed by GitHub
parent b5ad70e1a6
commit 846555cdb5
7 changed files with 41 additions and 93 deletions

View File

@@ -130,10 +130,8 @@ class AscendMetadata:
query_start_loc: torch.Tensor
query_lens: torch.Tensor
seq_lens: torch.Tensor
# max value of number of tokens across dp group
max_num_tokens_across_dp: int = 0
# Maximum query length in the batch. None for decoding.
max_query_len: Optional[int] = None
# (num_tokens,). The indices of the token slots that input tokens will be
@@ -141,18 +139,9 @@ class AscendMetadata:
# is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot
# in block 0, and 1st slot in block 1, respectively.
slot_mapping: torch.Tensor = None
# TODO: Indicates whether there are only prefill requests.
# FlashAttention can be used when there are only prefill requests.
# FlashAttention has better performance than PageAtttention,
# but it does not support decode requests.
is_only_prefill: bool = False
# Current state of this attention run.
attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill
attn_mask: Optional[torch.Tensor] = None
# For logging.
num_input_tokens: int = 0 # Number of tokens including padding.
with_prefill_across_dp: bool = False
@@ -169,7 +158,6 @@ class AscendAttentionMetadataBuilder:
num_reqs,
num_actual_tokens,
max_query_len,
common_prefix_len,
max_num_tokens_across_dp: int = 0,
with_prefill_across_dp: bool = False):
@@ -224,10 +212,10 @@ class AscendAttentionBackendImpl(AttentionImpl):
alibi_slopes: Optional[List[float]],
sliding_window: Optional[int],
kv_cache_dtype: str,
logits_soft_cap: Optional[float] = None,
attn_type: str = AttentionType.DECODER,
kv_sharing_target_layer_name: Optional[str] = None,
use_irope: bool = False,
logits_soft_cap: Optional[float],
attn_type: str,
kv_sharing_target_layer_name: Optional[str],
**kwargs,
) -> None:
self.num_heads = num_heads
self.head_size = head_size

View File

@@ -37,7 +37,7 @@ class AscendAttentionTorchairBackend(AttentionBackend):
@staticmethod
def get_name() -> str:
return "ASCEND"
return "ASCEND_TORCHAIR"
@staticmethod
def get_impl_cls() -> Type["AscendAttentionTorchairBackendImpl"]:
@@ -129,10 +129,8 @@ class AscendTorchairMetadata:
query_start_loc: torch.Tensor
query_lens: torch.Tensor
seq_lens: torch.Tensor
# max value of number of tokens across dp group
max_num_tokens_across_dp: int = 0
# Maximum query length in the batch. None for decoding.
max_query_len: Optional[int] = None
# (num_tokens,). The indices of the token slots that input tokens will be
@@ -140,20 +138,10 @@ class AscendTorchairMetadata:
# is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot
# in block 0, and 1st slot in block 1, respectively.
slot_mapping: torch.Tensor = None
# TODO: Indicates whether there are only prefill requests.
# FlashAttention can be used when there are only prefill requests.
# FlashAttention has better performance than PageAtttention,
# but it does not support decode requests.
is_only_prefill: bool = False
# Current state of this attention run.
attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill
attn_mask: Optional[torch.Tensor] = None
# For logging.
num_input_tokens: int = 0 # Number of tokens including padding.
with_prefill_across_dp: bool = False
decode: Optional[AscendDecodeMetadata] = None
@@ -236,7 +224,6 @@ class AscendAttentionTorchairMetadataBuilder:
num_reqs,
num_actual_tokens,
max_query_len,
common_prefix_len,
graph_pad_size: int = -1,
max_num_tokens_across_dp: int = 0,
with_prefill_across_dp: bool = False):
@@ -335,10 +322,10 @@ class AscendAttentionTorchairBackendImpl(AttentionImpl):
alibi_slopes: Optional[List[float]],
sliding_window: Optional[int],
kv_cache_dtype: str,
logits_soft_cap: Optional[float] = None,
attn_type: str = AttentionType.DECODER,
kv_sharing_target_layer_name: Optional[str] = None,
use_irope: bool = False,
logits_soft_cap: Optional[float],
attn_type: str,
kv_sharing_target_layer_name: Optional[str],
**kwargs,
) -> None:
self.num_heads = num_heads
self.head_size = head_size

View File

@@ -31,27 +31,13 @@ if TYPE_CHECKING:
_ALLOWED_NUM_QUERIES_PER_KV = [32, 64, 128]
@dataclass
class CommonAttentionMetadata:
"""
Attention metadata attributes that can be shared by layers in different KV
cache groups and thus having different block table.
"""
query_start_loc: torch.Tensor
"""(batch_size + 1,), the start location of each request in query Tensor"""
seq_lens: torch.Tensor
"""(batch_size,), the length of each request including both computed tokens
and newly scheduled tokens"""
class AscendMLABackend(AttentionBackend):
accept_output_buffer: bool = True
@staticmethod
def get_name() -> str:
return "VLLM_ASCEND_MLA"
return "ASCEND_MLA"
@staticmethod
def get_metadata_cls() -> type["AttentionMetadata"]:
@@ -368,11 +354,10 @@ class AscendMLAMetadataBuilder:
num_reqs: int,
num_actual_tokens: int,
max_query_len: int,
common_attn_metadata: CommonAttentionMetadata,
common_prefix_len: Optional[int] = None,
graph_pad_size: int = -1,
max_num_tokens_across_dp: int = 0,
with_prefill_across_dp: bool = False,
query_start_loc: torch.Tensor = None,
) -> AscendMLAMetadata:
assert self._num_decodes + self._num_prefills == num_reqs
@@ -394,7 +379,6 @@ class AscendMLAMetadataBuilder:
seq_lens = seq_lens_cpu
max_query_len = query_lens.max().item()
max_seq_lens = seq_lens.max().item()
query_start_loc = common_attn_metadata.query_start_loc
prefill_metadata = None
chunked_context_metadata = None
@@ -403,7 +387,6 @@ class AscendMLAMetadataBuilder:
tokens_start = self._num_decode_tokens
max_query_len = query_lens[tokens_start:].max().item()
max_seq_lens = seq_lens[tokens_start:].max().item()
query_start_loc = common_attn_metadata.query_start_loc
prefill_query_start_loc = query_start_loc[
reqs_start:] - query_start_loc[reqs_start]
@@ -539,7 +522,7 @@ class AscendMLAImpl(MLAAttentionImpl):
kv_cache_dtype: str,
logits_soft_cap: Optional[float],
attn_type: str,
kv_sharing_target_layer_name: Optional[str] = None,
kv_sharing_target_layer_name: Optional[str],
**kwargs,
) -> None:
self.num_heads = num_heads

View File

@@ -132,7 +132,6 @@ class EagleProposer:
num_reqs=batch_size,
num_actual_tokens=num_tokens,
max_query_len=max_query_len,
common_prefix_len=0,
)
if self.use_cuda_graph and \
num_tokens <= self.cudagraph_batch_sizes[-1]:

View File

@@ -75,8 +75,7 @@ from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
from vllm_ascend.attention.attention_v1 import (AscendAttentionState,
AscendMetadata)
from vllm_ascend.attention.attention_v1_torchair import AscendTorchairMetadata
from vllm_ascend.attention.mla_v1 import (AscendMLAMetadata,
CommonAttentionMetadata)
from vllm_ascend.attention.mla_v1 import AscendMLAMetadata
from vllm_ascend.platform import NPUPlatform
from vllm_ascend.sample.rejection_sampler import AscendRejectionSampler
from vllm_ascend.torchair.utils import (check_torchair_cache_exist,
@@ -694,15 +693,10 @@ class NPUModelRunner(LoRAModelRunnerMixin):
# in the same group share the same metadata.
for kv_cache_group_id, kv_cache_group_spec in enumerate(
self.kv_cache_config.kv_cache_groups):
# Prepare for cascade attention if enabled & beneficial.
common_prefix_len = 0
attn_metadata_i = self.attn_metadata_builder.build(
num_reqs=num_reqs,
num_actual_tokens=total_num_scheduled_tokens,
max_query_len=max_num_scheduled_tokens,
common_prefix_len=common_prefix_len,
)
for layer_name in kv_cache_group_spec.layer_names:
attn_metadata[layer_name] = attn_metadata_i
@@ -1049,27 +1043,22 @@ class NPUModelRunner(LoRAModelRunnerMixin):
extra_builder_kwargs['graph_pad_size'] = graph_pad_size
if self.vllm_config.model_config.use_mla:
query_start_loc = self.query_start_loc[:num_reqs + 1]
seq_lens = self.seq_lens[:num_reqs]
common_attn_metadata = CommonAttentionMetadata(
query_start_loc=query_start_loc, seq_lens=seq_lens)
extra_builder_kwargs[
"query_start_loc"] = self.query_start_loc[:num_reqs + 1]
attn_metadata = self.attn_metadata_builder.build( # type: ignore
num_reqs=num_reqs,
num_actual_tokens=total_num_scheduled_tokens,
max_query_len=max_num_scheduled_tokens,
common_attn_metadata=common_attn_metadata,
common_prefix_len=None,
**extra_builder_kwargs,
)
attn_metadata.num_input_tokens = num_input_tokens
else:
attn_metadata = self.attn_metadata_builder.build( # type: ignore
num_reqs=num_reqs,
num_actual_tokens=total_num_scheduled_tokens,
max_query_len=max_num_scheduled_tokens,
common_prefix_len=None,
**extra_builder_kwargs,
)
attn_metadata.num_input_tokens = num_input_tokens
# Prepare input_ids
token_indices = (positions_np +

View File

@@ -8,7 +8,6 @@ from vllm.model_executor.model_loader.utils import (
process_weights_after_loading, set_default_torch_dtype)
from vllm.v1.sample.metadata import SamplingMetadata
from vllm_ascend.attention.mla_v1 import CommonAttentionMetadata
from vllm_ascend.models.deepseek_mtp import CustomDeepSeekMTP
@@ -100,11 +99,6 @@ class MtpProposer:
query_lens = cu_num_tokens[1:] - cu_num_tokens[:-1]
max_query_len = query_lens.max().item()
seq_lens = (target_positions[last_token_indices] + 1)
common_attn_metadata = CommonAttentionMetadata(
query_start_loc=cu_num_tokens, seq_lens=seq_lens)
# FIXME: reorder_batch() needs to be called before build()
# because fields of attn_metadata_builder needs to be updated.
# However, currently reorder_batch() takes input_batch and
@@ -120,8 +114,7 @@ class MtpProposer:
num_reqs=batch_size,
num_actual_tokens=num_tokens,
max_query_len=max_query_len,
common_prefix_len=0,
common_attn_metadata=common_attn_metadata,
query_start_loc=cu_num_tokens,
)
with set_forward_context(attn_metadata, self.vllm_config):