From 95e8a52156073b901613e7d2d6052f4aed0fdc5c Mon Sep 17 00:00:00 2001
From: weijinqian0 <1184188277@qq.com>
Date: Tue, 23 Dec 2025 00:10:52 +0800
Subject: [PATCH] [Refactor] move the metadata from attention_v1 to utils
 (ready for extracting common_cp) & make AscendMetadata inherit from the
 parent class. (#5203)

RFC: https://github.com/vllm-project/vllm-ascend/issues/4629

1. Remove the pcp-related code from attention_v1.
2. Establish the inheritance relationship with CommonAttentionMetadata.

TODO:
1. Extract common_cp.
2. Move cp metadata to common_cp.
3. Remove CommonAttentionMetadata for aclgraph.

- vLLM version: v0.12.0
- vLLM main: https://github.com/vllm-project/vllm/commit/ad32e3e19ccf0526cb6744a5fed09a138a5fb2f9
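Illustrative sketch of the caller-side shape after this refactor (not part of
the diff below; the right-hand-side locals are placeholders).
AscendCommonAttentionMetadata now inherits the per-batch basics from vLLM's
CommonAttentionMetadata and keeps only the Ascend-specific fields, so callers
build it as before, except that max_seq_len is now supplied at every
construction site:

    from vllm_ascend.attention.utils import AscendCommonAttentionMetadata

    common_attn_metadata = AscendCommonAttentionMetadata(
        query_start_loc=query_start_loc,
        query_start_loc_cpu=query_start_loc_cpu,
        seq_lens=seq_lens,
        seq_lens_cpu=seq_lens_cpu,
        num_computed_tokens_cpu=num_computed_tokens_cpu,
        num_reqs=num_reqs,
        num_actual_tokens=num_actual_tokens,
        max_query_len=max_query_len,
        # new in this patch: passed explicitly by every caller
        max_seq_len=int(seq_lens_cpu.max()),
        block_table_tensor=block_table_tensor,
        slot_mapping=slot_mapping,
        decode_token_per_req=decode_token_per_req,
    )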
---------

Signed-off-by: weijinqian_v1
Co-authored-by: weijinqian_v1
---
 tests/ut/attention/test_attention_v1.py   |  3 +-
 tests/ut/attention/test_mla_v1.py         | 15 ++--
 vllm_ascend/attention/attention_cp.py     |  6 +-
 vllm_ascend/attention/attention_v1.py     | 59 ++-------------
 vllm_ascend/attention/utils.py            | 87 ++++++++++++++--------
 vllm_ascend/spec_decode/eagle_proposer.py |  5 +-
 vllm_ascend/spec_decode/mtp_proposer.py   |  7 +-
 vllm_ascend/worker/model_runner_v1.py     |  4 +-
 vllm_ascend/worker/v2/attn_utils.py       |  3 +-
 9 files changed, 83 insertions(+), 106 deletions(-)

diff --git a/tests/ut/attention/test_attention_v1.py b/tests/ut/attention/test_attention_v1.py
index f78445cf..5746099c 100644
--- a/tests/ut/attention/test_attention_v1.py
+++ b/tests/ut/attention/test_attention_v1.py
@@ -117,7 +117,8 @@ class TestAscendAttentionMetadataBuilder(TestBase):
             spec_attn_mask=None,
             attn_state=AscendAttentionState.ChunkedPrefill,
             num_computed_tokens_cpu=None,
-            seq_lens=None)
+            seq_lens=None,
+            max_seq_len=6)

         mock_model = MagicMock()
         self.builder.build(1, common_attn_metadata, mock_model)
diff --git a/tests/ut/attention/test_mla_v1.py b/tests/ut/attention/test_mla_v1.py
index 3ee8116c..b07a9a55 100755
--- a/tests/ut/attention/test_mla_v1.py
+++ b/tests/ut/attention/test_mla_v1.py
@@ -606,7 +606,8 @@ class TestAscendMLAMetadataBuilderBuild(TestBase):
             spec_attn_mask=None,
             attn_state=AscendAttentionState.PrefillNoCache,
             num_computed_tokens_cpu=None,
-            seq_lens=None)
+            seq_lens=None,
+            max_seq_len=6)

         base_inputs = {
             "num_actual_tokens": 10,
@@ -673,7 +674,8 @@ class TestAscendMLAMetadataBuilderBuild(TestBase):
             spec_attn_mask=None,
             attn_state=AscendAttentionState.ChunkedPrefill,
             num_computed_tokens_cpu=None,
-            seq_lens=None)
+            seq_lens=None,
+            max_seq_len=6)

         base_inputs = {
             "num_actual_tokens": 15,
@@ -729,7 +731,8 @@ class TestAscendMLAMetadataBuilderBuild(TestBase):
             spec_attn_mask=None,
             attn_state=AscendAttentionState.DecodeOnly,
             num_computed_tokens_cpu=None,
-            seq_lens=None)
+            seq_lens=None,
+            max_seq_len=6)

         base_inputs = {
             "num_actual_tokens": 3,
@@ -784,7 +787,8 @@ class TestAscendMLAMetadataBuilderBuild(TestBase):
             spec_attn_mask=None,
             attn_state=AscendAttentionState.DecodeOnly,
             num_computed_tokens_cpu=None,
-            seq_lens=None)
+            seq_lens=None,
+            max_seq_len=6)

         base_inputs = {
             "num_actual_tokens": 3,
@@ -839,7 +843,8 @@ class TestAscendMLAMetadataBuilderBuild(TestBase):
             spec_attn_mask=None,
             attn_state=AscendAttentionState.PrefillNoCache,
             num_computed_tokens_cpu=None,
-            seq_lens=None)
+            seq_lens=None,
+            max_seq_len=6)

         builder = AscendMLAMetadataBuilder(kv_cache_spec=self.kv_cache_spec,
                                            layer_names=["layer_0", "layer_1"],
diff --git a/vllm_ascend/attention/attention_cp.py b/vllm_ascend/attention/attention_cp.py
index 86a71cae..22c58369 100644
--- a/vllm_ascend/attention/attention_cp.py
+++ b/vllm_ascend/attention/attention_cp.py
@@ -33,10 +33,10 @@ from vllm.v1.kv_cache_interface import AttentionSpec

 from vllm_ascend.attention.attention_v1 import (AscendAttentionBackendImpl,
                                                 AscendAttentionMetadataBuilder,
-                                                AscendMetadata,
-                                                AscendMetadataForDecode,
-                                                AscendMetadataForPrefill)
+                                                AscendMetadata)
 from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
+                                         AscendMetadataForDecode,
+                                         AscendMetadataForPrefill,
                                          filter_chunked_req_indices,
                                          split_decodes_and_prefills)
 from vllm_ascend.compilation.acl_graph import (get_graph_params,
diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py
index f4107aa2..001d58fb 100644
--- a/vllm_ascend/attention/attention_v1.py
+++ b/vllm_ascend/attention/attention_v1.py
@@ -34,7 +34,9 @@ from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.kv_cache_interface import AttentionSpec

 from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
-                                         enable_cp, split_decodes_and_prefills,
+                                         AscendMetadataForDecode,
+                                         AscendMetadataForPrefill, enable_cp,
+                                         split_decodes_and_prefills,
                                          using_paged_attention)
 from vllm_ascend.compilation.acl_graph import (get_graph_params,
                                                update_graph_params_workspaces)
@@ -118,51 +120,6 @@ class AscendAttentionState(Enum):
     SpecDecoding = 4


-@dataclass
-class AscendMetadataForPrefill:
-
-    @dataclass
-    class AscendPCPMetadata:
-        q_head_idx: torch.Tensor = None
-        q_tail_idx: torch.Tensor = None
-        kv_with_q_head_nomask_idx: torch.Tensor = None
-        kv_with_q_head_mask_idx: torch.Tensor = None
-        kv_with_q_tail_nomask_idx: torch.Tensor = None
-        kv_with_q_tail_mask_idx: torch.Tensor = None
-        attn_mask_seqlens: torch.Tensor = None
-        head_attn_nomask_seqlens: torch.Tensor = None
-        tail_attn_nomask_seqlens: torch.Tensor = None
-        q_full_idx: torch.Tensor = None
-        pcp_prefill_mask: torch.Tensor = None
-
-    @dataclass
-    class ChunkedContextMetadata:
-        actual_chunk_seq_lengths: torch.Tensor
-        actual_seq_lengths_kv: torch.Tensor
-        starts: torch.Tensor
-        chunk_seq_mask_filtered_indices: torch.Tensor
-        chunked_req_mask: Optional[list[bool]] = None
-        local_context_lens_allranks: Optional[list[list[int]]] = None
-        cp_kv_recover_idx_for_chunk: Optional[list[int]] = None
-        kv_inverse_idx_for_chunk: Optional[list[int]] = None
-        batch_chunk_seq_mask: Optional[list[bool]] = None
-
-    """ Prefill Specific Metadata for Ascend"""
-    pcp_metadata: Optional[AscendPCPMetadata] = None
-    pcp_allgather_restore_idx: Optional[List[int]] = None
-    chunked_context: Optional[ChunkedContextMetadata] = None
-    block_tables: torch.Tensor = None
-    actual_seq_lengths_q: torch.Tensor = None
-
-
-@dataclass
-class AscendMetadataForDecode:
-    """ Decode Specific Metadata for Ascend"""
-    num_computed_tokens_of_pcp_dcp: Optional[list[list[list[int]]]] = None
-    batch_seq_mask: torch.Tensor = None
-    block_tables: torch.Tensor = None
-
-
 @dataclass
 class AscendMetadata:
     # **************************** Basic Properties ************************** #
@@ -274,14 +231,7 @@ class AscendAttentionMetadataBuilder:
         block_table = common_attn_metadata.block_table_tensor
         seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs]

-        long_seq_metadata = common_attn_metadata.prefill_context_parallel_metadata
-        num_actual_tokens_pcp_padded = long_seq_metadata.num_actual_tokens_pcp_padded if long_seq_metadata else None
-        if num_actual_tokens_pcp_padded is None:
-            num_actual_tokens_pcp_padded = num_actual_tokens
-
-        slot_mapping = common_attn_metadata.slot_mapping[:
-                                                         num_actual_tokens_pcp_padded]
-
+        slot_mapping = common_attn_metadata.slot_mapping[:num_actual_tokens]
         attn_mask = common_attn_metadata.attn_mask
         attn_state = common_attn_metadata.attn_state

@@ -292,7 +242,6 @@ class AscendAttentionMetadataBuilder:
         attn_metadata = AscendMetadata(
             num_actual_tokens=num_actual_tokens,
             num_decode_tokens=num_decode_tokens,
-            num_actual_tokens_pcp_padded=num_actual_tokens_pcp_padded,
             block_tables=block_table,
             query_start_loc=query_start_loc,
             seq_lens=seq_lens,
diff --git a/vllm_ascend/attention/utils.py b/vllm_ascend/attention/utils.py
index 81c92b37..40175e5c 100644
--- a/vllm_ascend/attention/utils.py
+++ b/vllm_ascend/attention/utils.py
@@ -1,4 +1,4 @@
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from functools import lru_cache
 from typing import Any, List, Optional

@@ -9,6 +9,7 @@ from vllm.distributed.kv_transfer import (get_kv_transfer_group,
                                           has_kv_transfer_group,
                                           is_v1_kv_transfer_group)
 from vllm.forward_context import ForwardContext, get_forward_context
+from vllm.v1.attention.backends.utils import CommonAttentionMetadata

 from vllm_ascend.utils import (AscendDeviceType, get_ascend_config,
                                get_ascend_device_type)
@@ -34,6 +35,51 @@ def enable_cp():
         or prefill_config.decode_context_parallel_size > 1


+@dataclass
+class AscendMetadataForPrefill:
+
+    @dataclass
+    class AscendPCPMetadata:
+        q_head_idx: torch.Tensor = None
+        q_tail_idx: torch.Tensor = None
+        kv_with_q_head_nomask_idx: torch.Tensor = None
+        kv_with_q_head_mask_idx: torch.Tensor = None
+        kv_with_q_tail_nomask_idx: torch.Tensor = None
+        kv_with_q_tail_mask_idx: torch.Tensor = None
+        attn_mask_seqlens: torch.Tensor = None
+        head_attn_nomask_seqlens: torch.Tensor = None
+        tail_attn_nomask_seqlens: torch.Tensor = None
+        q_full_idx: torch.Tensor = None
+        pcp_prefill_mask: torch.Tensor = None
+
+    @dataclass
+    class ChunkedContextMetadata:
+        actual_chunk_seq_lengths: torch.Tensor
+        actual_seq_lengths_kv: torch.Tensor
+        starts: torch.Tensor
+        chunk_seq_mask_filtered_indices: torch.Tensor
+        chunked_req_mask: Optional[list[bool]] = None
+        local_context_lens_allranks: Optional[list[list[int]]] = None
+        cp_kv_recover_idx_for_chunk: Optional[list[int]] = None
+        kv_inverse_idx_for_chunk: Optional[list[int]] = None
+        batch_chunk_seq_mask: Optional[list[bool]] = None
+
+    """ Prefill Specific Metadata for Ascend"""
+    pcp_metadata: Optional[AscendPCPMetadata] = None
+    pcp_allgather_restore_idx: Optional[List[int]] = None
+    chunked_context: Optional[ChunkedContextMetadata] = None
+    block_tables: torch.Tensor = None
+    actual_seq_lengths_q: torch.Tensor = None
+
+
+@dataclass
+class AscendMetadataForDecode:
+    """ Decode Specific Metadata for Ascend"""
+    num_computed_tokens_of_pcp_dcp: Optional[list[list[list[int]]]] = None
+    batch_seq_mask: torch.Tensor = None
+    block_tables: torch.Tensor = None
+
+
 @dataclass
 # class AscendCommonLongSequenceMetadata:
 class AscendPrefillContextParallelMetadata:
@@ -75,45 +121,20 @@ class AscendPrefillContextParallelMetadata:


 @dataclass
-class AscendCommonAttentionMetadata:
+class AscendCommonAttentionMetadata(CommonAttentionMetadata):
     """
     Per-batch attention metadata, shared across layers and backends.
     AttentionMetadataBuilder instances use it to construct per-layer metadata.

     For many of the tensors we keep both NPU and CPU versions.
""" + seq_lens_cpu: torch.Tensor = None + num_computed_tokens_cpu: torch.Tensor = None - query_start_loc: torch.Tensor - query_start_loc_cpu: torch.Tensor - """(batch_size + 1,), the start location of each request in query Tensor""" - - seq_lens_cpu: torch.Tensor - """(batch_size,), the length of each request including both computed tokens - and newly scheduled tokens""" - - seq_lens: torch.Tensor - """same to seq_lens_cpu, for compatibility with some new attn metadata - (such as GDN).""" - - num_computed_tokens_cpu: torch.Tensor - """(batch_size,), the number of computed tokens for each request""" - - num_reqs: int - """Number of requests""" - num_actual_tokens: int - """Total number of tokens in batch""" - - max_query_len: int - """Max token number of request in batch""" - - decode_token_per_req: int + decode_token_per_req: int = 1 """decode token number per request""" - block_table_tensor: torch.Tensor - - slot_mapping: torch.Tensor - - actual_seq_lengths_q: list[int] + actual_seq_lengths_q: list[int] = field(default_factory=list) positions: torch.Tensor = None @@ -132,8 +153,6 @@ class AscendCommonAttentionMetadata: prefill_context_parallel_metadata: Optional[ AscendPrefillContextParallelMetadata] = None - causal: bool = True - # TODO: Remove it when vLLM no longer uses this function. def unpadded(self, num_actual_tokens: int, num_actual_reqs: int) -> "AscendCommonAttentionMetadata": @@ -161,7 +180,7 @@ class AscendCommonAttentionMetadata: num_input_tokens=num_actual_tokens, prefill_context_parallel_metadata=self. prefill_context_parallel_metadata, - ) + max_seq_len=self.max_seq_len) def filter_chunked_req_indices( diff --git a/vllm_ascend/spec_decode/eagle_proposer.py b/vllm_ascend/spec_decode/eagle_proposer.py index 7c4f2e3b..a231246e 100644 --- a/vllm_ascend/spec_decode/eagle_proposer.py +++ b/vllm_ascend/spec_decode/eagle_proposer.py @@ -742,7 +742,7 @@ class EagleProposer(Proposer): spec_attn_mask=self.runner.spec_attn_mask, attn_state=self.runner.attn_state, decode_token_per_req=self.runner.decode_token_per_req, - ) + max_seq_len=0) return spec_common_attn_metadata, token_indices def prepare_inputs_padded( @@ -800,7 +800,8 @@ class EagleProposer(Proposer): decode_token_per_req=self.runner.decode_token_per_req, num_computed_tokens_cpu=common_attn_metadata. num_computed_tokens_cpu, - seq_lens=common_attn_metadata.seq_lens) + seq_lens=common_attn_metadata.seq_lens, + max_seq_len=0) token_indices_to_sample = (common_attn_metadata.query_start_loc[1:] - 1 - num_rejected_tokens_gpu) diff --git a/vllm_ascend/spec_decode/mtp_proposer.py b/vllm_ascend/spec_decode/mtp_proposer.py index f17cea10..9f71183b 100644 --- a/vllm_ascend/spec_decode/mtp_proposer.py +++ b/vllm_ascend/spec_decode/mtp_proposer.py @@ -268,7 +268,7 @@ class MtpProposer(Proposer): spec_attn_mask=self.runner.spec_attn_mask, attn_state=self.runner.attn_state, decode_token_per_req=self.runner.decode_token_per_req, - ) + max_seq_len=0) if self.pcp_size * self.dcp_size > 1: # update long_seq related params and flatten block_table common_attn_metadata.prefill_context_parallel_metadata = \ @@ -599,7 +599,7 @@ class MtpProposer(Proposer): spec_attn_mask=self.runner.spec_attn_mask, attn_state=self.runner.attn_state, decode_token_per_req=self.runner.decode_token_per_req, - ) + max_seq_len=0) return spec_common_attn_metadata, token_indices def _propose( @@ -1221,7 +1221,8 @@ class MtpProposer(Proposer): decode_token_per_req=self.runner.decode_token_per_req, num_computed_tokens_cpu=common_attn_metadata. 
             num_computed_tokens_cpu,
-            seq_lens=common_attn_metadata.seq_lens)
+            seq_lens=common_attn_metadata.seq_lens,
+            max_seq_len=0)

         query_start_loc = common_attn_metadata.query_start_loc[
             1:1 + num_rejected_tokens_gpu.shape[0]]
diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index b2c48978..82fc6861 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -1045,7 +1045,7 @@ class NPUModelRunner(GPUModelRunner):
             max_query_len=max_num_scheduled_tokens,
             decode_token_per_req=self.decode_token_per_req,
             prefill_context_parallel_metadata=long_seq_metadata,
-        )
+            max_seq_len=0)

         if self.speculative_config and self.pcp_size * self.dcp_size > 1:
             # For pcp + spec decode, we flatten block_table
@@ -1874,7 +1874,7 @@ class NPUModelRunner(GPUModelRunner):
             max_query_len=max_query_len,
             decode_token_per_req=self.decode_token_per_req,
             prefill_context_parallel_metadata=long_seq_metadata,
-        )
+            max_seq_len=0)
         if self.pcp_size * self.dcp_size > 1:
             common_attn_metadata.block_table_tensor = \
                 block_table_tensor[:num_reqs * self.decode_threshold]
diff --git a/vllm_ascend/worker/v2/attn_utils.py b/vllm_ascend/worker/v2/attn_utils.py
index eb536eaa..a66e5a21 100644
--- a/vllm_ascend/worker/v2/attn_utils.py
+++ b/vllm_ascend/worker/v2/attn_utils.py
@@ -53,6 +53,7 @@ def build_attn_metadata(
     """Build attention metadata for Ascend NPUs."""
     # TODO(Ronald1995): optimize AscendCommonAttentionMetadata.
     max_query_len = int(query_start_loc_cpu.max())
+    max_seq_len = int(seq_lens_cpu.max())

     attn_metadata: dict[str, Any] = {}
     kv_cache_groups = kv_cache_config.kv_cache_groups
@@ -80,7 +81,7 @@ def build_attn_metadata(
             graph_pad_size=graph_pad_size,
             num_input_tokens=num_input_tokens,
             prefill_context_parallel_metadata=prefill_context_parallel_metadata,
-        )
+            max_seq_len=max_seq_len)

         attn_metadata_builder = attn_metadata_builders[i]
         metadata = attn_metadata_builder.build(