From 35ad11b6378fff64554c6905dc25c866b0c0a353 Mon Sep 17 00:00:00 2001 From: weijinqian0 <1184188277@qq.com> Date: Fri, 19 Dec 2025 14:57:09 +0800 Subject: [PATCH] [Refactor] remove some metadata variables in attention_v1. (#5160) RFC: https://github.com/vllm-project/vllm-ascend/issues/4629 Reason: The metadata data class contains an excessive number of variables. We will inherit the metadata of the community and simultaneously remove some variables that are no longer needed at present. Todo: 1. remove attn_state partly. - vLLM version: v0.12.0 - vLLM main: https://github.com/vllm-project/vllm/commit/ad32e3e19ccf0526cb6744a5fed09a138a5fb2f9 --------- Signed-off-by: weijinqian_v1 Co-authored-by: weijinqian_v1 --- vllm_ascend/attention/attention_cp.py | 2 - vllm_ascend/attention/attention_v1.py | 49 ++++++++--------------- vllm_ascend/attention/mla_v1.py | 7 ++-- vllm_ascend/attention/utils.py | 18 ++++++--- vllm_ascend/compilation/acl_graph.py | 6 +-- vllm_ascend/spec_decode/eagle_proposer.py | 5 +-- vllm_ascend/worker/model_runner_v1.py | 1 - vllm_ascend/worker/v2/attn_utils.py | 2 - vllm_ascend/xlite/xlite.py | 4 +- 9 files changed, 41 insertions(+), 53 deletions(-) diff --git a/vllm_ascend/attention/attention_cp.py b/vllm_ascend/attention/attention_cp.py index 66961a52..95ac7fcb 100644 --- a/vllm_ascend/attention/attention_cp.py +++ b/vllm_ascend/attention/attention_cp.py @@ -235,8 +235,6 @@ class AscendAttentionCPMetadataBuilder(AscendAttentionMetadataBuilder): num_actual_tokens_pcp_padded=num_actual_tokens_pcp_padded, block_tables=block_table, query_start_loc=query_start_loc, - query_start_loc_list=query_start_loc_cpu[1:].tolist(), - query_lens=query_lens, seq_lens=seq_lens, seq_lens_list=seq_lens.tolist(), max_query_len=common_attn_metadata.max_query_len, diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py index 2abc18f9..f4107aa2 100644 --- a/vllm_ascend/attention/attention_v1.py +++ b/vllm_ascend/attention/attention_v1.py @@ -34,7 +34,7 @@ from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.kv_cache_interface import AttentionSpec from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata, - split_decodes_and_prefills, + enable_cp, split_decodes_and_prefills, using_paged_attention) from vllm_ascend.compilation.acl_graph import (get_graph_params, update_graph_params_workspaces) @@ -52,9 +52,7 @@ class AscendAttentionBackend(AttentionBackend): @staticmethod def get_impl_cls() -> Type["AscendAttentionBackendImpl"]: - prefill_config = get_current_vllm_config().parallel_config - if (prefill_config.prefill_context_parallel_size > 1 - or prefill_config.decode_context_parallel_size > 1): + if enable_cp(): from vllm_ascend.attention.attention_cp import \ AscendAttentionCPImpl return AscendAttentionCPImpl @@ -62,9 +60,7 @@ class AscendAttentionBackend(AttentionBackend): @staticmethod def get_builder_cls() -> type["AscendAttentionMetadataBuilder"]: - prefill_config = get_current_vllm_config().parallel_config - if (prefill_config.prefill_context_parallel_size > 1 - or prefill_config.decode_context_parallel_size > 1): + if enable_cp(): from vllm_ascend.attention.attention_cp import \ AscendAttentionCPMetadataBuilder return AscendAttentionCPMetadataBuilder @@ -191,10 +187,8 @@ class AscendMetadata: seq_lens: torch.Tensor = None seq_lens_list: List[int] = None # type: ignore actual_seq_lengths_q: List[int] = None # type: ignore - query_start_loc_list: List[int] = None # type: ignore query_start_loc: torch.Tensor = None - query_lens: torch.Tensor = None # Maximum query length in the batch (None for decoding). max_query_len: Optional[int] = None @@ -214,9 +208,9 @@ class AscendMetadata: # dcp decode_meta: Optional[AscendMetadataForDecode] = None - # Whether is the pooling model with causal attention, - # used to guide the attention computation for pooling models. - is_causal_pooling: Optional[bool] = None + causal: bool = True + # runner_type in model_config. + model_runner_type: str = "" class AscendAttentionMetadataBuilder: @@ -276,11 +270,8 @@ class AscendAttentionMetadataBuilder: num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \ split_decodes_and_prefills(common_attn_metadata, decode_threshold=self.decode_threshold) - assert num_decodes + num_prefills == num_reqs - assert num_decode_tokens + num_prefill_tokens == num_actual_tokens block_table = common_attn_metadata.block_table_tensor - query_lens = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1] seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs] long_seq_metadata = common_attn_metadata.prefill_context_parallel_metadata @@ -297,10 +288,6 @@ class AscendAttentionMetadataBuilder: # TODO: Yet another unnecessary H2D while we already have a query_start_loc on device query_start_loc = query_start_loc_cpu.pin_memory().to( self.device, non_blocking=True) - is_causal_pooling = None - if self.model_config.runner_type == "pooling": - is_causal_pooling = common_attn_metadata.causal if hasattr( - common_attn_metadata, 'causal') else True attn_metadata = AscendMetadata( num_actual_tokens=num_actual_tokens, @@ -308,8 +295,6 @@ class AscendAttentionMetadataBuilder: num_actual_tokens_pcp_padded=num_actual_tokens_pcp_padded, block_tables=block_table, query_start_loc=query_start_loc, - query_start_loc_list=query_start_loc_cpu[1:].tolist(), - query_lens=query_lens, seq_lens=seq_lens, seq_lens_list=seq_lens.tolist(), max_query_len=common_attn_metadata.max_query_len, @@ -319,7 +304,8 @@ class AscendAttentionMetadataBuilder: attn_state=attn_state, num_prefills=num_prefills, num_decodes=num_decodes, - is_causal_pooling=is_causal_pooling) + causal=common_attn_metadata.causal, + model_runner_type=self.model_config.runner_type) return attn_metadata def build_for_graph_capture( @@ -384,9 +370,9 @@ class AscendAttentionBackendImpl(AttentionImpl): key, value, block_size, block_table, actual_seq_lengths_kv \ = self._get_fia_params(key, value, attn_metadata) - num_tokens = attn_metadata.query_start_loc_list[-1] + num_tokens = attn_metadata.actual_seq_lengths_q[-1] graph_params = get_graph_params() - query_start_loc = attn_metadata.query_start_loc_list + actual_seq_lengths_q = attn_metadata.actual_seq_lengths_q # Prepare tensors for attention output # TODO: Refactor this to step-level instead of layer-level @@ -402,7 +388,7 @@ class AscendAttentionBackendImpl(AttentionImpl): block_table=block_table, input_layout="TND", block_size=block_size, - actual_seq_lengths=query_start_loc, + actual_seq_lengths=actual_seq_lengths_q, actual_seq_lengths_kv=actual_seq_lengths_kv, num_key_value_heads=self.num_kv_heads, num_heads=self.num_heads, @@ -422,7 +408,7 @@ class AscendAttentionBackendImpl(AttentionImpl): (weak_ref_tensors(query), weak_ref_tensors(key), weak_ref_tensors(value), weak_ref_tensors(block_table), weak_ref_tensors(attn_metadata.attn_mask), block_size, - actual_seq_lengths_kv, query_start_loc, self.num_kv_heads, + actual_seq_lengths_kv, actual_seq_lengths_q, self.num_kv_heads, self.num_heads, self.scale, weak_ref_tensors(output), weak_ref_tensors(softmax_lse))) @@ -435,7 +421,7 @@ class AscendAttentionBackendImpl(AttentionImpl): block_table=block_table, input_layout="TND", block_size=block_size, - actual_seq_lengths=query_start_loc, + actual_seq_lengths=actual_seq_lengths_q, actual_seq_lengths_kv=actual_seq_lengths_kv, num_key_value_heads=self.num_kv_heads, num_heads=self.num_heads, @@ -518,10 +504,10 @@ class AscendAttentionBackendImpl(AttentionImpl): if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache: block_size = 128 block_table = None - actual_seq_lengths_kv = attn_metadata.query_start_loc_list + actual_seq_lengths_kv = attn_metadata.actual_seq_lengths_q elif attn_metadata.attn_state == \ AscendAttentionState.PrefillCacheHit: - batch_size = attn_metadata.query_lens.shape[0] + batch_size = attn_metadata.seq_lens.shape[0] block_table = attn_metadata.block_tables[:batch_size, :] num_block, block_size, _, _ = self.key_cache.shape # type: ignore key = self.key_cache.view( # type: ignore @@ -644,9 +630,8 @@ class AscendAttentionBackendImpl(AttentionImpl): attn_metadata: AscendMetadata, _: torch.Tensor) -> torch.Tensor: assert attn_metadata is not None - assert attn_metadata.is_causal_pooling is not None - if attn_metadata.is_causal_pooling: + if attn_metadata.causal: # use sparse_mode 3 in causal scenario return torch_npu.npu_fusion_attention( query=query, @@ -768,7 +753,7 @@ class AscendAttentionBackendImpl(AttentionImpl): key, value = self.reshape_and_cache(key, value, kv_cache, attn_metadata) # pooling model branch - if isinstance(attn_metadata.is_causal_pooling, bool): + if attn_metadata.model_runner_type == "pooling": attn_output = self._forward_encoder_attention( query, key, value, attn_metadata, output) output[:num_tokens] = attn_output[:num_tokens] diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py index 17a5f04e..352f1d33 100644 --- a/vllm_ascend/attention/mla_v1.py +++ b/vllm_ascend/attention/mla_v1.py @@ -23,6 +23,7 @@ from vllm_ascend import envs from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.attention.attention_v1 import AscendAttentionState from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata, + enable_cp, maybe_save_kv_layer_to_connector, split_decodes_and_prefills, trans_rope_weight, transdata, @@ -57,8 +58,7 @@ class AscendMLABackend(AttentionBackend): @staticmethod def get_builder_cls(): - prefill_config = get_current_vllm_config().parallel_config - if prefill_config.prefill_context_parallel_size > 1 or prefill_config.decode_context_parallel_size > 1: + if enable_cp(): from vllm_ascend.attention.mla_cp import AscendMlaCPMetadataBuilder return AscendMlaCPMetadataBuilder return AscendMLAMetadataBuilder @@ -70,8 +70,7 @@ class AscendMLABackend(AttentionBackend): @staticmethod def get_impl_cls() -> Type["MLAAttentionImpl"]: - prefill_config = get_current_vllm_config().parallel_config - if prefill_config.prefill_context_parallel_size > 1 or prefill_config.decode_context_parallel_size > 1: + if enable_cp(): from vllm_ascend.attention.mla_cp import AscendMlaCPImpl return AscendMlaCPImpl return AscendMLAImpl diff --git a/vllm_ascend/attention/utils.py b/vllm_ascend/attention/utils.py index ac19dc9d..22200256 100644 --- a/vllm_ascend/attention/utils.py +++ b/vllm_ascend/attention/utils.py @@ -1,9 +1,10 @@ from dataclasses import dataclass +from functools import lru_cache from typing import Any, List, Optional import torch import torch.nn.functional as F -from vllm.config import VllmConfig +from vllm.config import VllmConfig, get_current_vllm_config from vllm.distributed.kv_transfer import (get_kv_transfer_group, has_kv_transfer_group, is_v1_kv_transfer_group) @@ -26,6 +27,13 @@ def using_paged_attention(runtime_shape: int, vllm_config: VllmConfig) -> bool: return runtime_shape in get_ascend_config().pa_shape_list +@lru_cache(maxsize=1) +def enable_cp(): + prefill_config = get_current_vllm_config().parallel_config + return prefill_config.prefill_context_parallel_size > 1 \ + or prefill_config.decode_context_parallel_size > 1 + + @dataclass # class AscendCommonLongSequenceMetadata: class AscendPrefillContextParallelMetadata: @@ -66,7 +74,7 @@ class AscendCommonAttentionMetadata: Per-batch attention metadata, shared across layers and backends. AttentionMetadataBuilder instances use it to construct per-layer metadata. - For many of the tensors we keep both GPU and CPU versions. + For many of the tensors we keep both NPU and CPU versions. """ query_start_loc: torch.Tensor @@ -109,8 +117,6 @@ class AscendCommonAttentionMetadata: attn_state: Any = None - is_only_prefill: bool = False - graph_pad_size: int = -1 # num_input_tokens refers to total number of tokens including @@ -120,6 +126,8 @@ class AscendCommonAttentionMetadata: prefill_context_parallel_metadata: Optional[ AscendPrefillContextParallelMetadata] = None + causal: bool = True + # TODO: Remove it when vLLM no longer uses this function. def unpadded(self, num_actual_tokens: int, num_actual_reqs: int) -> "AscendCommonAttentionMetadata": @@ -137,12 +145,12 @@ class AscendCommonAttentionMetadata: decode_token_per_req=self.decode_token_per_req, block_table_tensor=self.block_table_tensor[:num_actual_reqs], slot_mapping=self.slot_mapping[:num_actual_tokens], + causal=self.causal, actual_seq_lengths_q=self.actual_seq_lengths_q[:num_actual_tokens], positions=self.positions[:num_actual_tokens], attn_mask=self.attn_mask, spec_attn_mask=self.spec_attn_mask, attn_state=self.attn_state, - is_only_prefill=self.is_only_prefill, graph_pad_size=-1, # It should be -1 when not run in fullgraph mode. num_input_tokens=num_actual_tokens, prefill_context_parallel_metadata=self. diff --git a/vllm_ascend/compilation/acl_graph.py b/vllm_ascend/compilation/acl_graph.py index c4990497..c3c52cb2 100644 --- a/vllm_ascend/compilation/acl_graph.py +++ b/vllm_ascend/compilation/acl_graph.py @@ -271,8 +271,8 @@ def _update_attn_fia_params(update_stream, forward_context, runtime_shape): attn_output, softmax_lse) = param seq_lens = forward_context.attn_metadata[key].seq_lens_list - query_start_loc = forward_context.attn_metadata[ - key].query_start_loc_list + actual_seq_lengths_q = forward_context.attn_metadata[ + key].actual_seq_lengths_q torch.npu.graph_task_update_begin(update_stream, handle) torch_npu.npu_fused_infer_attention_score.out( query=query, @@ -282,7 +282,7 @@ def _update_attn_fia_params(update_stream, forward_context, runtime_shape): atten_mask=attn_mask, input_layout="TND", block_size=block_size, - actual_seq_lengths=query_start_loc, + actual_seq_lengths=actual_seq_lengths_q, actual_seq_lengths_kv=seq_lens, num_key_value_heads=num_kv_heads, num_heads=num_heads, diff --git a/vllm_ascend/spec_decode/eagle_proposer.py b/vllm_ascend/spec_decode/eagle_proposer.py index 0518aa4a..356efad6 100644 --- a/vllm_ascend/spec_decode/eagle_proposer.py +++ b/vllm_ascend/spec_decode/eagle_proposer.py @@ -384,12 +384,11 @@ class EagleProposer(Proposer): attn_metadata.num_actual_tokens = batch_size attn_metadata.max_query_len = 1 attn_metadata.query_start_loc = self.arange_cpu[:batch_size + 1] - attn_metadata.query_start_loc_list = attn_metadata.query_start_loc[ - 1:].tolist() attn_metadata.num_decodes, attn_metadata.num_prefills, attn_metadata.num_decode_tokens, attn_metadata.num_prefill_tokens = 0, batch_size, 0, batch_size attn_metadata.num_actual_tokens_pcp_padded = attn_metadata.num_decode_tokens + attn_metadata.num_prefill_tokens - attn_metadata.actual_seq_lengths_q = [1 + i for i in range(batch_size)] + attn_metadata.actual_seq_lengths_q = attn_metadata.query_start_loc[ + 1:].tolist() attn_metadata.seq_lens_list = attn_metadata.seq_lens.tolist() attn_metadata.attn_state = AscendAttentionState.ChunkedPrefill for now_speculative in range( diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 94ae0916..38cfcd0c 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -1042,7 +1042,6 @@ class NPUModelRunner(GPUModelRunner): attn_mask=self.attn_mask, spec_attn_mask=self.spec_attn_mask, attn_state=self.attn_state, - is_only_prefill=bool(np.all(num_valid_tokens != 1)), max_query_len=max_num_scheduled_tokens, decode_token_per_req=self.decode_token_per_req, prefill_context_parallel_metadata=long_seq_metadata, diff --git a/vllm_ascend/worker/v2/attn_utils.py b/vllm_ascend/worker/v2/attn_utils.py index ad2caa55..eb536eaa 100644 --- a/vllm_ascend/worker/v2/attn_utils.py +++ b/vllm_ascend/worker/v2/attn_utils.py @@ -45,7 +45,6 @@ def build_attn_metadata( | None = None, spec_attn_mask: torch.Tensor | None = None, attn_state: Any | None = None, - is_only_prefill: bool = False, graph_pad_size: int = -1, num_input_tokens: int = 0, prefill_context_parallel_metadata: AscendPrefillContextParallelMetadata @@ -78,7 +77,6 @@ def build_attn_metadata( attn_mask=attn_mask, spec_attn_mask=spec_attn_mask, attn_state=attn_state, - is_only_prefill=is_only_prefill, graph_pad_size=graph_pad_size, num_input_tokens=num_input_tokens, prefill_context_parallel_metadata=prefill_context_parallel_metadata, diff --git a/vllm_ascend/xlite/xlite.py b/vllm_ascend/xlite/xlite.py index 6ffac6db..c41734a4 100644 --- a/vllm_ascend/xlite/xlite.py +++ b/vllm_ascend/xlite/xlite.py @@ -247,7 +247,9 @@ class XliteWrapper: if not with_prefill or self.full_mode: batch = attn_metadata.num_prefills + attn_metadata.num_decodes seq_lens = attn_metadata.seq_lens[:batch] - query_lens = attn_metadata.query_lens[:batch] + query_lens = attn_metadata.query_start_loc_cpu[ + 1:] - attn_metadata.query_start_loc_cpu[:-1] + query_lens = query_lens[:batch] cached_lens = seq_lens - query_lens xlite_attn_metadata = ModelAttnMeta()