[Refactor] remove some metadata variables in attention_v1. (#5160)

RFC: https://github.com/vllm-project/vllm-ascend/issues/4629

Reason:

The metadata dataclass contains an excessive number of variables. We now inherit the community (upstream vLLM) metadata and remove the variables that are no longer needed.
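
For context, the two removed fields are redundant with data the metadata already carries: query_start_loc_list is just query_start_loc without the leading zero (which is exactly what actual_seq_lengths_q holds), and query_lens is its first difference. A minimal, self-contained sketch of that equivalence (illustrative tensors only, not the builder code):

import torch

# Hypothetical per-request query lengths for one batch (illustrative only).
query_lens = torch.tensor([3, 1, 5])

# query_start_loc is the exclusive prefix sum: [0, 3, 4, 9].
query_start_loc = torch.zeros(len(query_lens) + 1, dtype=torch.int64)
query_start_loc[1:] = torch.cumsum(query_lens, dim=0)

# Dropping the leading zero gives the same list the removed
# query_start_loc_list used to hold; this is what actual_seq_lengths_q
# already stores, and its last entry is the total token count.
actual_seq_lengths_q = query_start_loc[1:].tolist()
assert actual_seq_lengths_q == [3, 4, 9]
assert actual_seq_lengths_q[-1] == int(query_lens.sum())

# The removed query_lens tensor is just the first difference of the
# prefix sum, so it can be recomputed on demand where it is still needed.
recovered = query_start_loc[1:] - query_start_loc[:-1]
assert torch.equal(recovered, query_lens)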

Todo:
1. Partially remove attn_state.

- vLLM version: v0.12.0
- vLLM main: ad32e3e19c

---------

Signed-off-by: weijinqian_v1 <weijinqian@huawei.com>
Co-authored-by: weijinqian_v1 <weijinqian@huawei.com>
Authored by weijinqian0 on 2025-12-19 14:57:09 +08:00; committed by GitHub
parent bc05a81bf2
commit 35ad11b637
9 changed files with 41 additions and 53 deletions

View File

@@ -235,8 +235,6 @@ class AscendAttentionCPMetadataBuilder(AscendAttentionMetadataBuilder):
             num_actual_tokens_pcp_padded=num_actual_tokens_pcp_padded,
             block_tables=block_table,
             query_start_loc=query_start_loc,
-            query_start_loc_list=query_start_loc_cpu[1:].tolist(),
-            query_lens=query_lens,
             seq_lens=seq_lens,
             seq_lens_list=seq_lens.tolist(),
             max_query_len=common_attn_metadata.max_query_len,

View File

@@ -34,7 +34,7 @@ from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.kv_cache_interface import AttentionSpec
 from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
-                                          split_decodes_and_prefills,
+                                          enable_cp, split_decodes_and_prefills,
                                           using_paged_attention)
 from vllm_ascend.compilation.acl_graph import (get_graph_params,
                                                update_graph_params_workspaces)
@@ -52,9 +52,7 @@ class AscendAttentionBackend(AttentionBackend):
     @staticmethod
     def get_impl_cls() -> Type["AscendAttentionBackendImpl"]:
-        prefill_config = get_current_vllm_config().parallel_config
-        if (prefill_config.prefill_context_parallel_size > 1
-                or prefill_config.decode_context_parallel_size > 1):
+        if enable_cp():
             from vllm_ascend.attention.attention_cp import \
                 AscendAttentionCPImpl
             return AscendAttentionCPImpl
@@ -62,9 +60,7 @@ class AscendAttentionBackend(AttentionBackend):
     @staticmethod
     def get_builder_cls() -> type["AscendAttentionMetadataBuilder"]:
-        prefill_config = get_current_vllm_config().parallel_config
-        if (prefill_config.prefill_context_parallel_size > 1
-                or prefill_config.decode_context_parallel_size > 1):
+        if enable_cp():
             from vllm_ascend.attention.attention_cp import \
                 AscendAttentionCPMetadataBuilder
             return AscendAttentionCPMetadataBuilder
@@ -191,10 +187,8 @@ class AscendMetadata:
     seq_lens: torch.Tensor = None
     seq_lens_list: List[int] = None  # type: ignore
     actual_seq_lengths_q: List[int] = None  # type: ignore
-    query_start_loc_list: List[int] = None  # type: ignore
     query_start_loc: torch.Tensor = None
-    query_lens: torch.Tensor = None

     # Maximum query length in the batch (None for decoding).
     max_query_len: Optional[int] = None
@@ -214,9 +208,9 @@ class AscendMetadata:
     # dcp
     decode_meta: Optional[AscendMetadataForDecode] = None

-    # Whether is the pooling model with causal attention,
-    # used to guide the attention computation for pooling models.
-    is_causal_pooling: Optional[bool] = None
+    causal: bool = True
+    # runner_type in model_config.
+    model_runner_type: str = ""


 class AscendAttentionMetadataBuilder:
@@ -276,11 +270,8 @@ class AscendAttentionMetadataBuilder:
         num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \
             split_decodes_and_prefills(common_attn_metadata, decode_threshold=self.decode_threshold)
-        assert num_decodes + num_prefills == num_reqs
-        assert num_decode_tokens + num_prefill_tokens == num_actual_tokens

         block_table = common_attn_metadata.block_table_tensor
-        query_lens = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
         seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs]

         long_seq_metadata = common_attn_metadata.prefill_context_parallel_metadata
@@ -297,10 +288,6 @@ class AscendAttentionMetadataBuilder:
         # TODO: Yet another unnecessary H2D while we already have a query_start_loc on device
         query_start_loc = query_start_loc_cpu.pin_memory().to(
             self.device, non_blocking=True)
-        is_causal_pooling = None
-        if self.model_config.runner_type == "pooling":
-            is_causal_pooling = common_attn_metadata.causal if hasattr(
-                common_attn_metadata, 'causal') else True

         attn_metadata = AscendMetadata(
             num_actual_tokens=num_actual_tokens,
@@ -308,8 +295,6 @@ class AscendAttentionMetadataBuilder:
             num_actual_tokens_pcp_padded=num_actual_tokens_pcp_padded,
             block_tables=block_table,
             query_start_loc=query_start_loc,
-            query_start_loc_list=query_start_loc_cpu[1:].tolist(),
-            query_lens=query_lens,
             seq_lens=seq_lens,
             seq_lens_list=seq_lens.tolist(),
             max_query_len=common_attn_metadata.max_query_len,
@@ -319,7 +304,8 @@ class AscendAttentionMetadataBuilder:
             attn_state=attn_state,
             num_prefills=num_prefills,
             num_decodes=num_decodes,
-            is_causal_pooling=is_causal_pooling)
+            causal=common_attn_metadata.causal,
+            model_runner_type=self.model_config.runner_type)
         return attn_metadata

     def build_for_graph_capture(
@@ -384,9 +370,9 @@ class AscendAttentionBackendImpl(AttentionImpl):
             key, value, block_size, block_table, actual_seq_lengths_kv \
                 = self._get_fia_params(key, value, attn_metadata)
-            num_tokens = attn_metadata.query_start_loc_list[-1]
+            num_tokens = attn_metadata.actual_seq_lengths_q[-1]
             graph_params = get_graph_params()
-            query_start_loc = attn_metadata.query_start_loc_list
+            actual_seq_lengths_q = attn_metadata.actual_seq_lengths_q

             # Prepare tensors for attention output
             # TODO: Refactor this to step-level instead of layer-level
@@ -402,7 +388,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
                 block_table=block_table,
                 input_layout="TND",
                 block_size=block_size,
-                actual_seq_lengths=query_start_loc,
+                actual_seq_lengths=actual_seq_lengths_q,
                 actual_seq_lengths_kv=actual_seq_lengths_kv,
                 num_key_value_heads=self.num_kv_heads,
                 num_heads=self.num_heads,
@@ -422,7 +408,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
                 (weak_ref_tensors(query), weak_ref_tensors(key),
                  weak_ref_tensors(value), weak_ref_tensors(block_table),
                  weak_ref_tensors(attn_metadata.attn_mask), block_size,
-                 actual_seq_lengths_kv, query_start_loc, self.num_kv_heads,
+                 actual_seq_lengths_kv, actual_seq_lengths_q, self.num_kv_heads,
                  self.num_heads, self.scale, weak_ref_tensors(output),
                  weak_ref_tensors(softmax_lse)))
@@ -435,7 +421,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
                 block_table=block_table,
                 input_layout="TND",
                 block_size=block_size,
-                actual_seq_lengths=query_start_loc,
+                actual_seq_lengths=actual_seq_lengths_q,
                 actual_seq_lengths_kv=actual_seq_lengths_kv,
                 num_key_value_heads=self.num_kv_heads,
                 num_heads=self.num_heads,
@@ -518,10 +504,10 @@ class AscendAttentionBackendImpl(AttentionImpl):
         if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
             block_size = 128
             block_table = None
-            actual_seq_lengths_kv = attn_metadata.query_start_loc_list
+            actual_seq_lengths_kv = attn_metadata.actual_seq_lengths_q
         elif attn_metadata.attn_state == \
                 AscendAttentionState.PrefillCacheHit:
-            batch_size = attn_metadata.query_lens.shape[0]
+            batch_size = attn_metadata.seq_lens.shape[0]
             block_table = attn_metadata.block_tables[:batch_size, :]
             num_block, block_size, _, _ = self.key_cache.shape  # type: ignore
             key = self.key_cache.view(  # type: ignore
@@ -644,9 +630,8 @@ class AscendAttentionBackendImpl(AttentionImpl):
                                    attn_metadata: AscendMetadata,
                                    _: torch.Tensor) -> torch.Tensor:
         assert attn_metadata is not None
-        assert attn_metadata.is_causal_pooling is not None

-        if attn_metadata.is_causal_pooling:
+        if attn_metadata.causal:
             # use sparse_mode 3 in causal scenario
             return torch_npu.npu_fusion_attention(
                 query=query,
@@ -768,7 +753,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
             key, value = self.reshape_and_cache(key, value, kv_cache,
                                                 attn_metadata)
         # pooling model branch
-        if isinstance(attn_metadata.is_causal_pooling, bool):
+        if attn_metadata.model_runner_type == "pooling":
             attn_output = self._forward_encoder_attention(
                 query, key, value, attn_metadata, output)
             output[:num_tokens] = attn_output[:num_tokens]
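
A quick reading of the is_causal_pooling removal above: the single tri-state field is split into causal (always populated from the common metadata) and model_runner_type (copied from model_config.runner_type), so the "is this a pooling model" check and the "is attention causal" check are now independent. A condensed, hypothetical sketch of the resulting control flow (simplified stand-ins, not the real classes):

from dataclasses import dataclass

@dataclass
class _Meta:
    # Simplified stand-in for AscendMetadata after this change.
    causal: bool = True          # replaces the removed is_causal_pooling
    model_runner_type: str = ""  # copied from model_config.runner_type by the builder

def _forward(meta: _Meta) -> str:
    # Pooling models take the encoder-attention path; everything else
    # stays on the regular decoder-attention path.
    if meta.model_runner_type == "pooling":
        # Inside the encoder path, meta.causal selects the masking mode
        # (the diff uses sparse_mode 3 for the causal case).
        return "encoder, causal" if meta.causal else "encoder, bidirectional"
    return "decoder"

assert _forward(_Meta(model_runner_type="pooling", causal=False)) == "encoder, bidirectional"
assert _forward(_Meta()) == "decoder"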

View File

@@ -23,6 +23,7 @@ from vllm_ascend import envs
 from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.attention.attention_v1 import AscendAttentionState
 from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
+                                          enable_cp,
                                           maybe_save_kv_layer_to_connector,
                                           split_decodes_and_prefills,
                                           trans_rope_weight, transdata,
@@ -57,8 +58,7 @@ class AscendMLABackend(AttentionBackend):
     @staticmethod
     def get_builder_cls():
-        prefill_config = get_current_vllm_config().parallel_config
-        if prefill_config.prefill_context_parallel_size > 1 or prefill_config.decode_context_parallel_size > 1:
+        if enable_cp():
             from vllm_ascend.attention.mla_cp import AscendMlaCPMetadataBuilder
             return AscendMlaCPMetadataBuilder
         return AscendMLAMetadataBuilder
@@ -70,8 +70,7 @@ class AscendMLABackend(AttentionBackend):
     @staticmethod
     def get_impl_cls() -> Type["MLAAttentionImpl"]:
-        prefill_config = get_current_vllm_config().parallel_config
-        if prefill_config.prefill_context_parallel_size > 1 or prefill_config.decode_context_parallel_size > 1:
+        if enable_cp():
             from vllm_ascend.attention.mla_cp import AscendMlaCPImpl
             return AscendMlaCPImpl
         return AscendMLAImpl

View File

@@ -1,9 +1,10 @@
 from dataclasses import dataclass
+from functools import lru_cache
 from typing import Any, List, Optional

 import torch
 import torch.nn.functional as F
-from vllm.config import VllmConfig
+from vllm.config import VllmConfig, get_current_vllm_config
 from vllm.distributed.kv_transfer import (get_kv_transfer_group,
                                           has_kv_transfer_group,
                                           is_v1_kv_transfer_group)
@@ -26,6 +27,13 @@ def using_paged_attention(runtime_shape: int, vllm_config: VllmConfig) -> bool:
     return runtime_shape in get_ascend_config().pa_shape_list


+@lru_cache(maxsize=1)
+def enable_cp():
+    prefill_config = get_current_vllm_config().parallel_config
+    return prefill_config.prefill_context_parallel_size > 1 \
+        or prefill_config.decode_context_parallel_size > 1
+
+
 @dataclass
 # class AscendCommonLongSequenceMetadata:
 class AscendPrefillContextParallelMetadata:
@@ -66,7 +74,7 @@ class AscendCommonAttentionMetadata:
     Per-batch attention metadata, shared across layers and backends.
     AttentionMetadataBuilder instances use it to construct per-layer metadata.

-    For many of the tensors we keep both GPU and CPU versions.
+    For many of the tensors we keep both NPU and CPU versions.
     """

     query_start_loc: torch.Tensor
@@ -109,8 +117,6 @@ class AscendCommonAttentionMetadata:
     attn_state: Any = None

-    is_only_prefill: bool = False
-
     graph_pad_size: int = -1

     # num_input_tokens refers to total number of tokens including
@@ -120,6 +126,8 @@ class AscendCommonAttentionMetadata:
     prefill_context_parallel_metadata: Optional[
         AscendPrefillContextParallelMetadata] = None

+    causal: bool = True
+
     # TODO: Remove it when vLLM no longer uses this function.
     def unpadded(self, num_actual_tokens: int,
                  num_actual_reqs: int) -> "AscendCommonAttentionMetadata":
@@ -137,12 +145,12 @@ class AscendCommonAttentionMetadata:
             decode_token_per_req=self.decode_token_per_req,
             block_table_tensor=self.block_table_tensor[:num_actual_reqs],
             slot_mapping=self.slot_mapping[:num_actual_tokens],
+            causal=self.causal,
             actual_seq_lengths_q=self.actual_seq_lengths_q[:num_actual_tokens],
             positions=self.positions[:num_actual_tokens],
             attn_mask=self.attn_mask,
             spec_attn_mask=self.spec_attn_mask,
             attn_state=self.attn_state,
-            is_only_prefill=self.is_only_prefill,
             graph_pad_size=-1,  # It should be -1 when not run in fullgraph mode.
             num_input_tokens=num_actual_tokens,
             prefill_context_parallel_metadata=self.
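
One note on the new helper: because enable_cp() takes no arguments, lru_cache(maxsize=1) memoizes its single result, so get_current_vllm_config() is queried once per process and every later call returns the cached value. A self-contained illustration of that pattern with a dummy config loader (the names below are placeholders, not the real vLLM API):

from functools import lru_cache

calls = 0

def _load_config():
    # Stand-in for get_current_vllm_config(); counts how often it runs.
    global calls
    calls += 1
    return {"prefill_context_parallel_size": 1,
            "decode_context_parallel_size": 2}

@lru_cache(maxsize=1)
def enable_cp():
    cfg = _load_config()
    return (cfg["prefill_context_parallel_size"] > 1
            or cfg["decode_context_parallel_size"] > 1)

assert enable_cp() is True
assert enable_cp() is True  # served from the cache
assert calls == 1           # the config was only read once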

View File

@@ -271,8 +271,8 @@ def _update_attn_fia_params(update_stream, forward_context, runtime_shape):
              attn_output, softmax_lse) = param
             seq_lens = forward_context.attn_metadata[key].seq_lens_list
-            query_start_loc = forward_context.attn_metadata[
-                key].query_start_loc_list
+            actual_seq_lengths_q = forward_context.attn_metadata[
+                key].actual_seq_lengths_q
             torch.npu.graph_task_update_begin(update_stream, handle)
             torch_npu.npu_fused_infer_attention_score.out(
                 query=query,
@@ -282,7 +282,7 @@ def _update_attn_fia_params(update_stream, forward_context, runtime_shape):
                 atten_mask=attn_mask,
                 input_layout="TND",
                 block_size=block_size,
-                actual_seq_lengths=query_start_loc,
+                actual_seq_lengths=actual_seq_lengths_q,
                 actual_seq_lengths_kv=seq_lens,
                 num_key_value_heads=num_kv_heads,
                 num_heads=num_heads,

View File

@@ -384,12 +384,11 @@ class EagleProposer(Proposer):
         attn_metadata.num_actual_tokens = batch_size
         attn_metadata.max_query_len = 1
         attn_metadata.query_start_loc = self.arange_cpu[:batch_size + 1]
-        attn_metadata.query_start_loc_list = attn_metadata.query_start_loc[
-            1:].tolist()
         attn_metadata.num_decodes, attn_metadata.num_prefills, attn_metadata.num_decode_tokens, attn_metadata.num_prefill_tokens = 0, batch_size, 0, batch_size
         attn_metadata.num_actual_tokens_pcp_padded = attn_metadata.num_decode_tokens + attn_metadata.num_prefill_tokens
-        attn_metadata.actual_seq_lengths_q = [1 + i for i in range(batch_size)]
+        attn_metadata.actual_seq_lengths_q = attn_metadata.query_start_loc[
+            1:].tolist()
         attn_metadata.seq_lens_list = attn_metadata.seq_lens.tolist()
         attn_metadata.attn_state = AscendAttentionState.ChunkedPrefill

         for now_speculative in range(
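
The swap of the list comprehension for query_start_loc[1:].tolist() is an equivalence rather than a behavior change: in this branch query_start_loc is an arange over batch_size + 1 entries, so dropping the leading zero yields exactly [1, 2, ..., batch_size]. A tiny sketch of that identity (standalone, with an arbitrary batch size):

import torch

batch_size = 4  # arbitrary example value

# Every request contributes one token in this branch, so the start
# locations are simply 0, 1, ..., batch_size.
query_start_loc = torch.arange(batch_size + 1)

# Dropping the leading zero reproduces the old list comprehension.
actual_seq_lengths_q = query_start_loc[1:].tolist()
assert actual_seq_lengths_q == [1 + i for i in range(batch_size)]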

View File

@@ -1042,7 +1042,6 @@ class NPUModelRunner(GPUModelRunner):
             attn_mask=self.attn_mask,
             spec_attn_mask=self.spec_attn_mask,
             attn_state=self.attn_state,
-            is_only_prefill=bool(np.all(num_valid_tokens != 1)),
             max_query_len=max_num_scheduled_tokens,
             decode_token_per_req=self.decode_token_per_req,
             prefill_context_parallel_metadata=long_seq_metadata,

View File

@@ -45,7 +45,6 @@ def build_attn_metadata(
        | None = None,
        spec_attn_mask: torch.Tensor | None = None,
        attn_state: Any | None = None,
-       is_only_prefill: bool = False,
        graph_pad_size: int = -1,
        num_input_tokens: int = 0,
        prefill_context_parallel_metadata: AscendPrefillContextParallelMetadata
@@ -78,7 +77,6 @@ def build_attn_metadata(
        attn_mask=attn_mask,
        spec_attn_mask=spec_attn_mask,
        attn_state=attn_state,
-       is_only_prefill=is_only_prefill,
        graph_pad_size=graph_pad_size,
        num_input_tokens=num_input_tokens,
        prefill_context_parallel_metadata=prefill_context_parallel_metadata,

View File

@@ -247,7 +247,9 @@ class XliteWrapper:
         if not with_prefill or self.full_mode:
             batch = attn_metadata.num_prefills + attn_metadata.num_decodes
             seq_lens = attn_metadata.seq_lens[:batch]
-            query_lens = attn_metadata.query_lens[:batch]
+            query_lens = attn_metadata.query_start_loc_cpu[
+                1:] - attn_metadata.query_start_loc_cpu[:-1]
+            query_lens = query_lens[:batch]
             cached_lens = seq_lens - query_lens
             xlite_attn_metadata = ModelAttnMeta()