[Refactor] remove some metadata variables in attention_v1. (#5160)

RFC: https://github.com/vllm-project/vllm-ascend/issues/4629

Reason:

The metadata dataclass contains an excessive number of variables. We
now inherit the upstream (community) vLLM metadata and remove the
variables that are no longer needed.

Todo:
1. Partially remove attn_state.

- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c

---------

Signed-off-by: weijinqian_v1 <weijinqian@huawei.com>
Co-authored-by: weijinqian_v1 <weijinqian@huawei.com>
This commit is contained in:
weijinqian0
2025-12-19 14:57:09 +08:00
committed by GitHub
parent bc05a81bf2
commit 35ad11b637
9 changed files with 41 additions and 53 deletions

View File

@@ -235,8 +235,6 @@ class AscendAttentionCPMetadataBuilder(AscendAttentionMetadataBuilder):
num_actual_tokens_pcp_padded=num_actual_tokens_pcp_padded,
block_tables=block_table,
query_start_loc=query_start_loc,
query_start_loc_list=query_start_loc_cpu[1:].tolist(),
query_lens=query_lens,
seq_lens=seq_lens,
seq_lens_list=seq_lens.tolist(),
max_query_len=common_attn_metadata.max_query_len,

View File

@@ -34,7 +34,7 @@ from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
split_decodes_and_prefills,
enable_cp, split_decodes_and_prefills,
using_paged_attention)
from vllm_ascend.compilation.acl_graph import (get_graph_params,
update_graph_params_workspaces)
@@ -52,9 +52,7 @@ class AscendAttentionBackend(AttentionBackend):
@staticmethod
def get_impl_cls() -> Type["AscendAttentionBackendImpl"]:
prefill_config = get_current_vllm_config().parallel_config
if (prefill_config.prefill_context_parallel_size > 1
or prefill_config.decode_context_parallel_size > 1):
if enable_cp():
from vllm_ascend.attention.attention_cp import \
AscendAttentionCPImpl
return AscendAttentionCPImpl
@@ -62,9 +60,7 @@ class AscendAttentionBackend(AttentionBackend):
@staticmethod
def get_builder_cls() -> type["AscendAttentionMetadataBuilder"]:
prefill_config = get_current_vllm_config().parallel_config
if (prefill_config.prefill_context_parallel_size > 1
or prefill_config.decode_context_parallel_size > 1):
if enable_cp():
from vllm_ascend.attention.attention_cp import \
AscendAttentionCPMetadataBuilder
return AscendAttentionCPMetadataBuilder
@@ -191,10 +187,8 @@ class AscendMetadata:
seq_lens: torch.Tensor = None
seq_lens_list: List[int] = None # type: ignore
actual_seq_lengths_q: List[int] = None # type: ignore
query_start_loc_list: List[int] = None # type: ignore
query_start_loc: torch.Tensor = None
query_lens: torch.Tensor = None
# Maximum query length in the batch (None for decoding).
max_query_len: Optional[int] = None
@@ -214,9 +208,9 @@ class AscendMetadata:
# dcp
decode_meta: Optional[AscendMetadataForDecode] = None
# Whether is the pooling model with causal attention,
# used to guide the attention computation for pooling models.
is_causal_pooling: Optional[bool] = None
causal: bool = True
# runner_type in model_config.
model_runner_type: str = ""
class AscendAttentionMetadataBuilder:
@@ -276,11 +270,8 @@ class AscendAttentionMetadataBuilder:
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \
split_decodes_and_prefills(common_attn_metadata, decode_threshold=self.decode_threshold)
assert num_decodes + num_prefills == num_reqs
assert num_decode_tokens + num_prefill_tokens == num_actual_tokens
block_table = common_attn_metadata.block_table_tensor
query_lens = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs]
long_seq_metadata = common_attn_metadata.prefill_context_parallel_metadata
@@ -297,10 +288,6 @@ class AscendAttentionMetadataBuilder:
# TODO: Yet another unnecessary H2D while we already have a query_start_loc on device
query_start_loc = query_start_loc_cpu.pin_memory().to(
self.device, non_blocking=True)
is_causal_pooling = None
if self.model_config.runner_type == "pooling":
is_causal_pooling = common_attn_metadata.causal if hasattr(
common_attn_metadata, 'causal') else True
attn_metadata = AscendMetadata(
num_actual_tokens=num_actual_tokens,
@@ -308,8 +295,6 @@ class AscendAttentionMetadataBuilder:
num_actual_tokens_pcp_padded=num_actual_tokens_pcp_padded,
block_tables=block_table,
query_start_loc=query_start_loc,
query_start_loc_list=query_start_loc_cpu[1:].tolist(),
query_lens=query_lens,
seq_lens=seq_lens,
seq_lens_list=seq_lens.tolist(),
max_query_len=common_attn_metadata.max_query_len,
@@ -319,7 +304,8 @@ class AscendAttentionMetadataBuilder:
attn_state=attn_state,
num_prefills=num_prefills,
num_decodes=num_decodes,
is_causal_pooling=is_causal_pooling)
causal=common_attn_metadata.causal,
model_runner_type=self.model_config.runner_type)
return attn_metadata
def build_for_graph_capture(
@@ -384,9 +370,9 @@ class AscendAttentionBackendImpl(AttentionImpl):
key, value, block_size, block_table, actual_seq_lengths_kv \
= self._get_fia_params(key, value, attn_metadata)
num_tokens = attn_metadata.query_start_loc_list[-1]
num_tokens = attn_metadata.actual_seq_lengths_q[-1]
graph_params = get_graph_params()
query_start_loc = attn_metadata.query_start_loc_list
actual_seq_lengths_q = attn_metadata.actual_seq_lengths_q
# Prepare tensors for attention output
# TODO: Refactor this to step-level instead of layer-level
@@ -402,7 +388,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
block_table=block_table,
input_layout="TND",
block_size=block_size,
actual_seq_lengths=query_start_loc,
actual_seq_lengths=actual_seq_lengths_q,
actual_seq_lengths_kv=actual_seq_lengths_kv,
num_key_value_heads=self.num_kv_heads,
num_heads=self.num_heads,
@@ -422,7 +408,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
(weak_ref_tensors(query), weak_ref_tensors(key),
weak_ref_tensors(value), weak_ref_tensors(block_table),
weak_ref_tensors(attn_metadata.attn_mask), block_size,
actual_seq_lengths_kv, query_start_loc, self.num_kv_heads,
actual_seq_lengths_kv, actual_seq_lengths_q, self.num_kv_heads,
self.num_heads, self.scale, weak_ref_tensors(output),
weak_ref_tensors(softmax_lse)))
@@ -435,7 +421,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
block_table=block_table,
input_layout="TND",
block_size=block_size,
actual_seq_lengths=query_start_loc,
actual_seq_lengths=actual_seq_lengths_q,
actual_seq_lengths_kv=actual_seq_lengths_kv,
num_key_value_heads=self.num_kv_heads,
num_heads=self.num_heads,
@@ -518,10 +504,10 @@ class AscendAttentionBackendImpl(AttentionImpl):
if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
block_size = 128
block_table = None
actual_seq_lengths_kv = attn_metadata.query_start_loc_list
actual_seq_lengths_kv = attn_metadata.actual_seq_lengths_q
elif attn_metadata.attn_state == \
AscendAttentionState.PrefillCacheHit:
batch_size = attn_metadata.query_lens.shape[0]
batch_size = attn_metadata.seq_lens.shape[0]
block_table = attn_metadata.block_tables[:batch_size, :]
num_block, block_size, _, _ = self.key_cache.shape # type: ignore
key = self.key_cache.view( # type: ignore
@@ -644,9 +630,8 @@ class AscendAttentionBackendImpl(AttentionImpl):
attn_metadata: AscendMetadata,
_: torch.Tensor) -> torch.Tensor:
assert attn_metadata is not None
assert attn_metadata.is_causal_pooling is not None
if attn_metadata.is_causal_pooling:
if attn_metadata.causal:
# use sparse_mode 3 in causal scenario
return torch_npu.npu_fusion_attention(
query=query,
@@ -768,7 +753,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
key, value = self.reshape_and_cache(key, value, kv_cache,
attn_metadata)
# pooling model branch
if isinstance(attn_metadata.is_causal_pooling, bool):
if attn_metadata.model_runner_type == "pooling":
attn_output = self._forward_encoder_attention(
query, key, value, attn_metadata, output)
output[:num_tokens] = attn_output[:num_tokens]

View File

@@ -23,6 +23,7 @@ from vllm_ascend import envs
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
enable_cp,
maybe_save_kv_layer_to_connector,
split_decodes_and_prefills,
trans_rope_weight, transdata,
@@ -57,8 +58,7 @@ class AscendMLABackend(AttentionBackend):
@staticmethod
def get_builder_cls():
prefill_config = get_current_vllm_config().parallel_config
if prefill_config.prefill_context_parallel_size > 1 or prefill_config.decode_context_parallel_size > 1:
if enable_cp():
from vllm_ascend.attention.mla_cp import AscendMlaCPMetadataBuilder
return AscendMlaCPMetadataBuilder
return AscendMLAMetadataBuilder
@@ -70,8 +70,7 @@ class AscendMLABackend(AttentionBackend):
@staticmethod
def get_impl_cls() -> Type["MLAAttentionImpl"]:
prefill_config = get_current_vllm_config().parallel_config
if prefill_config.prefill_context_parallel_size > 1 or prefill_config.decode_context_parallel_size > 1:
if enable_cp():
from vllm_ascend.attention.mla_cp import AscendMlaCPImpl
return AscendMlaCPImpl
return AscendMLAImpl

View File

@@ -1,9 +1,10 @@
from dataclasses import dataclass
from functools import lru_cache
from typing import Any, List, Optional
import torch
import torch.nn.functional as F
from vllm.config import VllmConfig
from vllm.config import VllmConfig, get_current_vllm_config
from vllm.distributed.kv_transfer import (get_kv_transfer_group,
has_kv_transfer_group,
is_v1_kv_transfer_group)
@@ -26,6 +27,13 @@ def using_paged_attention(runtime_shape: int, vllm_config: VllmConfig) -> bool:
return runtime_shape in get_ascend_config().pa_shape_list
@lru_cache(maxsize=1)
def enable_cp():
prefill_config = get_current_vllm_config().parallel_config
return prefill_config.prefill_context_parallel_size > 1 \
or prefill_config.decode_context_parallel_size > 1
@dataclass
# class AscendCommonLongSequenceMetadata:
class AscendPrefillContextParallelMetadata:
@@ -66,7 +74,7 @@ class AscendCommonAttentionMetadata:
Per-batch attention metadata, shared across layers and backends.
AttentionMetadataBuilder instances use it to construct per-layer metadata.
For many of the tensors we keep both GPU and CPU versions.
For many of the tensors we keep both NPU and CPU versions.
"""
query_start_loc: torch.Tensor
@@ -109,8 +117,6 @@ class AscendCommonAttentionMetadata:
attn_state: Any = None
is_only_prefill: bool = False
graph_pad_size: int = -1
# num_input_tokens refers to total number of tokens including
@@ -120,6 +126,8 @@ class AscendCommonAttentionMetadata:
prefill_context_parallel_metadata: Optional[
AscendPrefillContextParallelMetadata] = None
causal: bool = True
# TODO: Remove it when vLLM no longer uses this function.
def unpadded(self, num_actual_tokens: int,
num_actual_reqs: int) -> "AscendCommonAttentionMetadata":
@@ -137,12 +145,12 @@ class AscendCommonAttentionMetadata:
decode_token_per_req=self.decode_token_per_req,
block_table_tensor=self.block_table_tensor[:num_actual_reqs],
slot_mapping=self.slot_mapping[:num_actual_tokens],
causal=self.causal,
actual_seq_lengths_q=self.actual_seq_lengths_q[:num_actual_tokens],
positions=self.positions[:num_actual_tokens],
attn_mask=self.attn_mask,
spec_attn_mask=self.spec_attn_mask,
attn_state=self.attn_state,
is_only_prefill=self.is_only_prefill,
graph_pad_size=-1, # It should be -1 when not run in fullgraph mode.
num_input_tokens=num_actual_tokens,
prefill_context_parallel_metadata=self.

View File

@@ -271,8 +271,8 @@ def _update_attn_fia_params(update_stream, forward_context, runtime_shape):
attn_output, softmax_lse) = param
seq_lens = forward_context.attn_metadata[key].seq_lens_list
query_start_loc = forward_context.attn_metadata[
key].query_start_loc_list
actual_seq_lengths_q = forward_context.attn_metadata[
key].actual_seq_lengths_q
torch.npu.graph_task_update_begin(update_stream, handle)
torch_npu.npu_fused_infer_attention_score.out(
query=query,
@@ -282,7 +282,7 @@ def _update_attn_fia_params(update_stream, forward_context, runtime_shape):
atten_mask=attn_mask,
input_layout="TND",
block_size=block_size,
actual_seq_lengths=query_start_loc,
actual_seq_lengths=actual_seq_lengths_q,
actual_seq_lengths_kv=seq_lens,
num_key_value_heads=num_kv_heads,
num_heads=num_heads,

View File

@@ -384,12 +384,11 @@ class EagleProposer(Proposer):
attn_metadata.num_actual_tokens = batch_size
attn_metadata.max_query_len = 1
attn_metadata.query_start_loc = self.arange_cpu[:batch_size + 1]
attn_metadata.query_start_loc_list = attn_metadata.query_start_loc[
1:].tolist()
attn_metadata.num_decodes, attn_metadata.num_prefills, attn_metadata.num_decode_tokens, attn_metadata.num_prefill_tokens = 0, batch_size, 0, batch_size
attn_metadata.num_actual_tokens_pcp_padded = attn_metadata.num_decode_tokens + attn_metadata.num_prefill_tokens
attn_metadata.actual_seq_lengths_q = [1 + i for i in range(batch_size)]
attn_metadata.actual_seq_lengths_q = attn_metadata.query_start_loc[
1:].tolist()
attn_metadata.seq_lens_list = attn_metadata.seq_lens.tolist()
attn_metadata.attn_state = AscendAttentionState.ChunkedPrefill
for now_speculative in range(

View File

@@ -1042,7 +1042,6 @@ class NPUModelRunner(GPUModelRunner):
attn_mask=self.attn_mask,
spec_attn_mask=self.spec_attn_mask,
attn_state=self.attn_state,
is_only_prefill=bool(np.all(num_valid_tokens != 1)),
max_query_len=max_num_scheduled_tokens,
decode_token_per_req=self.decode_token_per_req,
prefill_context_parallel_metadata=long_seq_metadata,

View File

@@ -45,7 +45,6 @@ def build_attn_metadata(
| None = None,
spec_attn_mask: torch.Tensor | None = None,
attn_state: Any | None = None,
is_only_prefill: bool = False,
graph_pad_size: int = -1,
num_input_tokens: int = 0,
prefill_context_parallel_metadata: AscendPrefillContextParallelMetadata
@@ -78,7 +77,6 @@ def build_attn_metadata(
attn_mask=attn_mask,
spec_attn_mask=spec_attn_mask,
attn_state=attn_state,
is_only_prefill=is_only_prefill,
graph_pad_size=graph_pad_size,
num_input_tokens=num_input_tokens,
prefill_context_parallel_metadata=prefill_context_parallel_metadata,

View File

@@ -247,7 +247,9 @@ class XliteWrapper:
if not with_prefill or self.full_mode:
batch = attn_metadata.num_prefills + attn_metadata.num_decodes
seq_lens = attn_metadata.seq_lens[:batch]
query_lens = attn_metadata.query_lens[:batch]
query_lens = attn_metadata.query_start_loc_cpu[
1:] - attn_metadata.query_start_loc_cpu[:-1]
query_lens = query_lens[:batch]
cached_lens = seq_lens - query_lens
xlite_attn_metadata = ModelAttnMeta()