[Refactor] remove some metadata variables in attention_v1. (#5160)

RFC: https://github.com/vllm-project/vllm-ascend/issues/4629

Reason:

The metadata dataclass contains an excessive number of variables. We will
inherit the community (upstream vLLM) metadata and, at the same time, remove
some variables that are no longer needed.
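Below is a rough, non-authoritative sketch of the direction described above: keep only the Ascend-specific fields locally and inherit the shared ones from upstream vLLM. The base class name and import path (`CommonAttentionMetadata` from `vllm.v1.attention.backends.utils`) are assumptions and may differ by vLLM version; the field names are taken from the diff below.

```python
from dataclasses import dataclass
from typing import Any, Optional

import torch
# Assumed upstream location; adjust to the actual vLLM version in use.
from vllm.v1.attention.backends.utils import CommonAttentionMetadata


@dataclass
class AscendCommonAttentionMetadata(CommonAttentionMetadata):
    # Ascend-only extras; shared fields such as query_start_loc,
    # slot_mapping and block_table_tensor would be inherited from
    # the upstream dataclass.
    attn_mask: Optional[torch.Tensor] = None
    spec_attn_mask: Optional[torch.Tensor] = None
    attn_state: Any = None  # still kept for now, see Todo below
    is_only_prefill: bool = False
    graph_pad_size: int = -1
```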

Todo:
1. Partially remove attn_state.

- vLLM version: v0.12.0
- vLLM main: ad32e3e19c

---------

Signed-off-by: weijinqian_v1 <weijinqian@huawei.com>
Co-authored-by: weijinqian_v1 <weijinqian@huawei.com>
weijinqian0 authored on 2025-12-19 14:57:09 +08:00, committed by GitHub
parent bc05a81bf2, commit 35ad11b637
9 changed files with 41 additions and 53 deletions


@@ -1,9 +1,10 @@
from dataclasses import dataclass
from functools import lru_cache
from typing import Any, List, Optional
import torch
import torch.nn.functional as F
from vllm.config import VllmConfig
from vllm.config import VllmConfig, get_current_vllm_config
from vllm.distributed.kv_transfer import (get_kv_transfer_group,
                                          has_kv_transfer_group,
                                          is_v1_kv_transfer_group)
@@ -26,6 +27,13 @@ def using_paged_attention(runtime_shape: int, vllm_config: VllmConfig) -> bool:
    return runtime_shape in get_ascend_config().pa_shape_list
@lru_cache(maxsize=1)
def enable_cp():
    prefill_config = get_current_vllm_config().parallel_config
    return prefill_config.prefill_context_parallel_size > 1 \
        or prefill_config.decode_context_parallel_size > 1
@dataclass
# class AscendCommonLongSequenceMetadata:
class AscendPrefillContextParallelMetadata:
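As background on the new `enable_cp()` helper above: `lru_cache(maxsize=1)` means the parallel config is read once and the boolean is reused on every later call, so the check stays cheap on hot paths but will not observe config changes made after the first call. A minimal standalone illustration of the same pattern (the dict below stands in for `get_current_vllm_config().parallel_config`; it is not the real API):

```python
from functools import lru_cache

# Illustrative stand-in for the vLLM parallel config read in the diff.
_PARALLEL_CONFIG = {"prefill_context_parallel_size": 1,
                    "decode_context_parallel_size": 2}


@lru_cache(maxsize=1)
def enable_cp() -> bool:
    cfg = _PARALLEL_CONFIG
    return (cfg["prefill_context_parallel_size"] > 1
            or cfg["decode_context_parallel_size"] > 1)


print(enable_cp())  # True; subsequent calls reuse the cached result.
```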
@@ -66,7 +74,7 @@ class AscendCommonAttentionMetadata:
    Per-batch attention metadata, shared across layers and backends.
    AttentionMetadataBuilder instances use it to construct per-layer metadata.
    For many of the tensors we keep both GPU and CPU versions.
    For many of the tensors we keep both NPU and CPU versions.
    """
    query_start_loc: torch.Tensor
@@ -109,8 +117,6 @@ class AscendCommonAttentionMetadata:
    attn_state: Any = None
    is_only_prefill: bool = False
    graph_pad_size: int = -1
    # num_input_tokens refers to total number of tokens including
@@ -120,6 +126,8 @@ class AscendCommonAttentionMetadata:
    prefill_context_parallel_metadata: Optional[
        AscendPrefillContextParallelMetadata] = None
    causal: bool = True
    # TODO: Remove it when vLLM no longer uses this function.
    def unpadded(self, num_actual_tokens: int,
                 num_actual_reqs: int) -> "AscendCommonAttentionMetadata":
@@ -137,12 +145,12 @@ class AscendCommonAttentionMetadata:
            decode_token_per_req=self.decode_token_per_req,
            block_table_tensor=self.block_table_tensor[:num_actual_reqs],
            slot_mapping=self.slot_mapping[:num_actual_tokens],
            causal=self.causal,
            actual_seq_lengths_q=self.actual_seq_lengths_q[:num_actual_tokens],
            positions=self.positions[:num_actual_tokens],
            attn_mask=self.attn_mask,
            spec_attn_mask=self.spec_attn_mask,
            attn_state=self.attn_state,
            is_only_prefill=self.is_only_prefill,
            graph_pad_size=-1,  # It should be -1 when not run in fullgraph mode.
            num_input_tokens=num_actual_tokens,
            prefill_context_parallel_metadata=self.
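For readers skimming the diff: `unpadded()` returns a copy of this metadata with the per-token and per-request tensors sliced back to the real batch sizes after graph padding. A hedged usage sketch, where only the `unpadded()` signature comes from this diff and the surrounding variable names and the padding check are illustrative assumptions:

```python
# common_metadata is an AscendCommonAttentionMetadata built for a padded
# (full-graph) run; trim it back to the real batch before building
# per-layer metadata. Variable names here are illustrative.
if common_metadata.graph_pad_size > 0:
    common_metadata = common_metadata.unpadded(
        num_actual_tokens=num_actual_tokens,
        num_actual_reqs=num_actual_reqs,
    )
```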