[Refactor] Move the metadata from attention_v1 to utils (preparing for the common_cp extraction) and make AscendCommonAttentionMetadata inherit from its vLLM parent class. (#5203)

RFC: https://github.com/vllm-project/vllm-ascend/issues/4629

1. Remove the PCP-related metadata code from attention_v1 (it now lives in attention/utils).
2. Make AscendCommonAttentionMetadata inherit from vLLM's CommonAttentionMetadata (see the sketch below).
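
A minimal sketch of the resulting layout, assembled from the utils.py hunks below (field list abridged to what the diff shows; everything else is inherited):

```python
from dataclasses import dataclass, field

import torch
from vllm.v1.attention.backends.utils import CommonAttentionMetadata


@dataclass
class AscendCommonAttentionMetadata(CommonAttentionMetadata):
    # Shared fields (query_start_loc, seq_lens, num_reqs, num_actual_tokens,
    # max_query_len, max_seq_len, block_table_tensor, slot_mapping, causal)
    # now come from the vLLM parent class instead of being redeclared here.
    seq_lens_cpu: torch.Tensor = None
    num_computed_tokens_cpu: torch.Tensor = None
    decode_token_per_req: int = 1
    actual_seq_lengths_q: list[int] = field(default_factory=list)
    positions: torch.Tensor = None
```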

TODO
1. Extract common_cp.
2. Move the CP metadata to common_cp.
3. Remove CommonAttentionMetadata for aclgraph.

- vLLM version: v0.12.0
- vLLM main: ad32e3e19c

---------

Signed-off-by: weijinqian_v1 <weijinqian@huawei.com>
Co-authored-by: weijinqian_v1 <weijinqian@huawei.com>
Author: weijinqian0
Date: 2025-12-23 00:10:52 +08:00
Committed by: GitHub
Parent: 3d9954eff0
Commit: 95e8a52156
9 changed files with 83 additions and 106 deletions

View File

@@ -117,7 +117,8 @@ class TestAscendAttentionMetadataBuilder(TestBase):
spec_attn_mask=None,
attn_state=AscendAttentionState.ChunkedPrefill,
num_computed_tokens_cpu=None,
-seq_lens=None)
+seq_lens=None,
+max_seq_len=6)
mock_model = MagicMock()
self.builder.build(1, common_attn_metadata, mock_model)
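
The only change here is the new `max_seq_len` argument: judging by the call sites updated throughout this commit, the inherited `CommonAttentionMetadata` declares `max_seq_len` without a default, so every construction site must now supply one. A self-contained toy of the dataclass rule at play (`Parent`/`Child` are stand-ins, not the real classes):

```python
from dataclasses import dataclass


@dataclass
class Parent:
    num_reqs: int
    max_seq_len: int  # no default -> required by every subclass constructor


@dataclass
class Child(Parent):
    decode_token_per_req: int = 1  # subclass extras keep their defaults


try:
    Child(num_reqs=1)  # omitting the inherited field fails
except TypeError as err:
    print(err)  # ...missing 1 required positional argument: 'max_seq_len'

print(Child(num_reqs=1, max_seq_len=6))  # mirrors max_seq_len=6 above
```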

View File

@@ -606,7 +606,8 @@ class TestAscendMLAMetadataBuilderBuild(TestBase):
spec_attn_mask=None,
attn_state=AscendAttentionState.PrefillNoCache,
num_computed_tokens_cpu=None,
-seq_lens=None)
+seq_lens=None,
+max_seq_len=6)
base_inputs = {
"num_actual_tokens": 10,
@@ -673,7 +674,8 @@ class TestAscendMLAMetadataBuilderBuild(TestBase):
spec_attn_mask=None,
attn_state=AscendAttentionState.ChunkedPrefill,
num_computed_tokens_cpu=None,
-seq_lens=None)
+seq_lens=None,
+max_seq_len=6)
base_inputs = {
"num_actual_tokens": 15,
@@ -729,7 +731,8 @@ class TestAscendMLAMetadataBuilderBuild(TestBase):
spec_attn_mask=None,
attn_state=AscendAttentionState.DecodeOnly,
num_computed_tokens_cpu=None,
-seq_lens=None)
+seq_lens=None,
+max_seq_len=6)
base_inputs = {
"num_actual_tokens": 3,
@@ -784,7 +787,8 @@ class TestAscendMLAMetadataBuilderBuild(TestBase):
spec_attn_mask=None,
attn_state=AscendAttentionState.DecodeOnly,
num_computed_tokens_cpu=None,
-seq_lens=None)
+seq_lens=None,
+max_seq_len=6)
base_inputs = {
"num_actual_tokens": 3,
@@ -839,7 +843,8 @@ class TestAscendMLAMetadataBuilderBuild(TestBase):
spec_attn_mask=None,
attn_state=AscendAttentionState.PrefillNoCache,
num_computed_tokens_cpu=None,
-seq_lens=None)
+seq_lens=None,
+max_seq_len=6)
builder = AscendMLAMetadataBuilder(kv_cache_spec=self.kv_cache_spec,
layer_names=["layer_0", "layer_1"],

View File

@@ -33,10 +33,10 @@ from vllm.v1.kv_cache_interface import AttentionSpec
from vllm_ascend.attention.attention_v1 import (AscendAttentionBackendImpl,
AscendAttentionMetadataBuilder,
-AscendMetadata,
-AscendMetadataForDecode,
-AscendMetadataForPrefill)
+AscendMetadata)
from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
+AscendMetadataForDecode,
+AscendMetadataForPrefill,
filter_chunked_req_indices,
split_decodes_and_prefills)
from vllm_ascend.compilation.acl_graph import (get_graph_params,

View File

@@ -34,7 +34,9 @@ from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
-enable_cp, split_decodes_and_prefills,
+AscendMetadataForDecode,
+AscendMetadataForPrefill, enable_cp,
+split_decodes_and_prefills,
using_paged_attention)
from vllm_ascend.compilation.acl_graph import (get_graph_params,
update_graph_params_workspaces)
@@ -118,51 +120,6 @@ class AscendAttentionState(Enum):
SpecDecoding = 4
-@dataclass
-class AscendMetadataForPrefill:
-    @dataclass
-    class AscendPCPMetadata:
-        q_head_idx: torch.Tensor = None
-        q_tail_idx: torch.Tensor = None
-        kv_with_q_head_nomask_idx: torch.Tensor = None
-        kv_with_q_head_mask_idx: torch.Tensor = None
-        kv_with_q_tail_nomask_idx: torch.Tensor = None
-        kv_with_q_tail_mask_idx: torch.Tensor = None
-        attn_mask_seqlens: torch.Tensor = None
-        head_attn_nomask_seqlens: torch.Tensor = None
-        tail_attn_nomask_seqlens: torch.Tensor = None
-        q_full_idx: torch.Tensor = None
-        pcp_prefill_mask: torch.Tensor = None
-    @dataclass
-    class ChunkedContextMetadata:
-        actual_chunk_seq_lengths: torch.Tensor
-        actual_seq_lengths_kv: torch.Tensor
-        starts: torch.Tensor
-        chunk_seq_mask_filtered_indices: torch.Tensor
-        chunked_req_mask: Optional[list[bool]] = None
-        local_context_lens_allranks: Optional[list[list[int]]] = None
-        cp_kv_recover_idx_for_chunk: Optional[list[int]] = None
-        kv_inverse_idx_for_chunk: Optional[list[int]] = None
-        batch_chunk_seq_mask: Optional[list[bool]] = None
-    """ Prefill Specific Metadata for Ascend"""
-    pcp_metadata: Optional[AscendPCPMetadata] = None
-    pcp_allgather_restore_idx: Optional[List[int]] = None
-    chunked_context: Optional[ChunkedContextMetadata] = None
-    block_tables: torch.Tensor = None
-    actual_seq_lengths_q: torch.Tensor = None
-@dataclass
-class AscendMetadataForDecode:
-    """ Decode Specific Metadata for Ascend"""
-    num_computed_tokens_of_pcp_dcp: Optional[list[list[list[int]]]] = None
-    batch_seq_mask: torch.Tensor = None
-    block_tables: torch.Tensor = None
@dataclass
class AscendMetadata:
# **************************** Basic Properties ************************** #
@@ -274,14 +231,7 @@ class AscendAttentionMetadataBuilder:
block_table = common_attn_metadata.block_table_tensor
seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs]
-long_seq_metadata = common_attn_metadata.prefill_context_parallel_metadata
-num_actual_tokens_pcp_padded = long_seq_metadata.num_actual_tokens_pcp_padded if long_seq_metadata else None
-if num_actual_tokens_pcp_padded is None:
-    num_actual_tokens_pcp_padded = num_actual_tokens
-slot_mapping = common_attn_metadata.slot_mapping[:
-    num_actual_tokens_pcp_padded]
+slot_mapping = common_attn_metadata.slot_mapping[:num_actual_tokens]
attn_mask = common_attn_metadata.attn_mask
attn_state = common_attn_metadata.attn_state
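
Condensed from the hunk above (a restatement of the diff, not new behavior): the builder's slot-mapping slice loses its PCP-padding fallback, which is the behavioral core of item 1 in the commit message.

```python
# Before: slice by the PCP-padded token count, falling back to
# num_actual_tokens when no prefill-context-parallel metadata exists.
long_seq_metadata = common_attn_metadata.prefill_context_parallel_metadata
num_actual_tokens_pcp_padded = (
    long_seq_metadata.num_actual_tokens_pcp_padded
    if long_seq_metadata else None)
if num_actual_tokens_pcp_padded is None:
    num_actual_tokens_pcp_padded = num_actual_tokens
slot_mapping = common_attn_metadata.slot_mapping[:num_actual_tokens_pcp_padded]

# After: attention_v1 no longer knows about PCP padding.
slot_mapping = common_attn_metadata.slot_mapping[:num_actual_tokens]
```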
@@ -292,7 +242,6 @@ class AscendAttentionMetadataBuilder:
attn_metadata = AscendMetadata(
num_actual_tokens=num_actual_tokens,
num_decode_tokens=num_decode_tokens,
-num_actual_tokens_pcp_padded=num_actual_tokens_pcp_padded,
block_tables=block_table,
query_start_loc=query_start_loc,
seq_lens=seq_lens,

View File

@@ -1,4 +1,4 @@
-from dataclasses import dataclass
+from dataclasses import dataclass, field
from functools import lru_cache
from typing import Any, List, Optional
@@ -9,6 +9,7 @@ from vllm.distributed.kv_transfer import (get_kv_transfer_group,
has_kv_transfer_group,
is_v1_kv_transfer_group)
from vllm.forward_context import ForwardContext, get_forward_context
+from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm_ascend.utils import (AscendDeviceType, get_ascend_config,
get_ascend_device_type)
@@ -34,6 +35,51 @@ def enable_cp():
or prefill_config.decode_context_parallel_size > 1
+@dataclass
+class AscendMetadataForPrefill:
+    @dataclass
+    class AscendPCPMetadata:
+        q_head_idx: torch.Tensor = None
+        q_tail_idx: torch.Tensor = None
+        kv_with_q_head_nomask_idx: torch.Tensor = None
+        kv_with_q_head_mask_idx: torch.Tensor = None
+        kv_with_q_tail_nomask_idx: torch.Tensor = None
+        kv_with_q_tail_mask_idx: torch.Tensor = None
+        attn_mask_seqlens: torch.Tensor = None
+        head_attn_nomask_seqlens: torch.Tensor = None
+        tail_attn_nomask_seqlens: torch.Tensor = None
+        q_full_idx: torch.Tensor = None
+        pcp_prefill_mask: torch.Tensor = None
+    @dataclass
+    class ChunkedContextMetadata:
+        actual_chunk_seq_lengths: torch.Tensor
+        actual_seq_lengths_kv: torch.Tensor
+        starts: torch.Tensor
+        chunk_seq_mask_filtered_indices: torch.Tensor
+        chunked_req_mask: Optional[list[bool]] = None
+        local_context_lens_allranks: Optional[list[list[int]]] = None
+        cp_kv_recover_idx_for_chunk: Optional[list[int]] = None
+        kv_inverse_idx_for_chunk: Optional[list[int]] = None
+        batch_chunk_seq_mask: Optional[list[bool]] = None
+    """ Prefill Specific Metadata for Ascend"""
+    pcp_metadata: Optional[AscendPCPMetadata] = None
+    pcp_allgather_restore_idx: Optional[List[int]] = None
+    chunked_context: Optional[ChunkedContextMetadata] = None
+    block_tables: torch.Tensor = None
+    actual_seq_lengths_q: torch.Tensor = None
+@dataclass
+class AscendMetadataForDecode:
+    """ Decode Specific Metadata for Ascend"""
+    num_computed_tokens_of_pcp_dcp: Optional[list[list[list[int]]]] = None
+    batch_seq_mask: torch.Tensor = None
+    block_tables: torch.Tensor = None
@dataclass
# class AscendCommonLongSequenceMetadata:
class AscendPrefillContextParallelMetadata:
@@ -75,45 +121,20 @@ class AscendPrefillContextParallelMetadata:
@dataclass
-class AscendCommonAttentionMetadata:
+class AscendCommonAttentionMetadata(CommonAttentionMetadata):
"""
Per-batch attention metadata, shared across layers and backends.
AttentionMetadataBuilder instances use it to construct per-layer metadata.
For many of the tensors we keep both NPU and CPU versions.
"""
+seq_lens_cpu: torch.Tensor = None
+num_computed_tokens_cpu: torch.Tensor = None
-query_start_loc: torch.Tensor
-query_start_loc_cpu: torch.Tensor
-"""(batch_size + 1,), the start location of each request in query Tensor"""
-seq_lens_cpu: torch.Tensor
-"""(batch_size,), the length of each request including both computed tokens
-and newly scheduled tokens"""
-seq_lens: torch.Tensor
-"""same to seq_lens_cpu, for compatibility with some new attn metadata
-(such as GDN)."""
-num_computed_tokens_cpu: torch.Tensor
-"""(batch_size,), the number of computed tokens for each request"""
-num_reqs: int
-"""Number of requests"""
-num_actual_tokens: int
-"""Total number of tokens in batch"""
-max_query_len: int
-"""Max token number of request in batch"""
-decode_token_per_req: int
+decode_token_per_req: int = 1
"""decode token number per request"""
-block_table_tensor: torch.Tensor
-slot_mapping: torch.Tensor
-actual_seq_lengths_q: list[int]
+actual_seq_lengths_q: list[int] = field(default_factory=list)
positions: torch.Tensor = None
@@ -132,8 +153,6 @@ class AscendCommonAttentionMetadata:
prefill_context_parallel_metadata: Optional[
AscendPrefillContextParallelMetadata] = None
-causal: bool = True
# TODO: Remove it when vLLM no longer uses this function.
def unpadded(self, num_actual_tokens: int,
num_actual_reqs: int) -> "AscendCommonAttentionMetadata":
@@ -161,7 +180,7 @@ class AscendCommonAttentionMetadata:
num_input_tokens=num_actual_tokens,
prefill_context_parallel_metadata=self.
prefill_context_parallel_metadata,
-)
+max_seq_len=self.max_seq_len)
def filter_chunked_req_indices(

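Because `unpadded()` rebuilds the dataclass field by field, the inherited `max_seq_len` has to be forwarded explicitly, which is what the last hunk adds. A usage sketch under stated assumptions (`padded_metadata`, `real_tokens`, and `real_reqs` are hypothetical locals; the slicing behavior is implied by the signature, not shown in this diff):

```python
# Sketch: trim graph-padded metadata back to the real batch sizes.
trimmed = padded_metadata.unpadded(
    num_actual_tokens=real_tokens,  # hypothetical local
    num_actual_reqs=real_reqs,      # hypothetical local
)
# max_seq_len is carried over unchanged by the new forwarding line.
assert trimmed.max_seq_len == padded_metadata.max_seq_len
```
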
View File

@@ -742,7 +742,7 @@ class EagleProposer(Proposer):
spec_attn_mask=self.runner.spec_attn_mask,
attn_state=self.runner.attn_state,
decode_token_per_req=self.runner.decode_token_per_req,
-)
+max_seq_len=0)
return spec_common_attn_metadata, token_indices
def prepare_inputs_padded(
@@ -800,7 +800,8 @@ class EagleProposer(Proposer):
decode_token_per_req=self.runner.decode_token_per_req,
num_computed_tokens_cpu=common_attn_metadata.
num_computed_tokens_cpu,
-seq_lens=common_attn_metadata.seq_lens)
+seq_lens=common_attn_metadata.seq_lens,
+max_seq_len=0)
token_indices_to_sample = (common_attn_metadata.query_start_loc[1:] -
1 - num_rejected_tokens_gpu)

View File

@@ -268,7 +268,7 @@ class MtpProposer(Proposer):
spec_attn_mask=self.runner.spec_attn_mask,
attn_state=self.runner.attn_state,
decode_token_per_req=self.runner.decode_token_per_req,
-)
+max_seq_len=0)
if self.pcp_size * self.dcp_size > 1:
# update long_seq related params and flatten block_table
common_attn_metadata.prefill_context_parallel_metadata = \
@@ -599,7 +599,7 @@ class MtpProposer(Proposer):
spec_attn_mask=self.runner.spec_attn_mask,
attn_state=self.runner.attn_state,
decode_token_per_req=self.runner.decode_token_per_req,
-)
+max_seq_len=0)
return spec_common_attn_metadata, token_indices
def _propose(
@@ -1221,7 +1221,8 @@ class MtpProposer(Proposer):
decode_token_per_req=self.runner.decode_token_per_req,
num_computed_tokens_cpu=common_attn_metadata.
num_computed_tokens_cpu,
-seq_lens=common_attn_metadata.seq_lens)
+seq_lens=common_attn_metadata.seq_lens,
+max_seq_len=0)
query_start_loc = common_attn_metadata.query_start_loc[
1:1 + num_rejected_tokens_gpu.shape[0]]

View File

@@ -1045,7 +1045,7 @@ class NPUModelRunner(GPUModelRunner):
max_query_len=max_num_scheduled_tokens,
decode_token_per_req=self.decode_token_per_req,
prefill_context_parallel_metadata=long_seq_metadata,
-)
+max_seq_len=0)
if self.speculative_config and self.pcp_size * self.dcp_size > 1:
# For pcp + spec decode, we flatten block_table
@@ -1874,7 +1874,7 @@ class NPUModelRunner(GPUModelRunner):
max_query_len=max_query_len,
decode_token_per_req=self.decode_token_per_req,
prefill_context_parallel_metadata=long_seq_metadata,
-)
+max_seq_len=0)
if self.pcp_size * self.dcp_size > 1:
common_attn_metadata.block_table_tensor = \
block_table_tensor[:num_reqs * self.decode_threshold]

View File

@@ -53,6 +53,7 @@ def build_attn_metadata(
"""Build attention metadata for Ascend NPUs."""
# TODO(Ronald1995): optimize AscendCommonAttentionMetadata.
max_query_len = int(query_start_loc_cpu.max())
+max_seq_len = int(seq_lens_cpu.max())
attn_metadata: dict[str, Any] = {}
kv_cache_groups = kv_cache_config.kv_cache_groups
@@ -80,7 +81,7 @@ def build_attn_metadata(
graph_pad_size=graph_pad_size,
num_input_tokens=num_input_tokens,
prefill_context_parallel_metadata=prefill_context_parallel_metadata,
-)
+max_seq_len=max_seq_len)
attn_metadata_builder = attn_metadata_builders[i]
metadata = attn_metadata_builder.build(
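
The new `max_seq_len = int(seq_lens_cpu.max())` line derives the batch-wide maximum from the per-request CPU sequence lengths; a minimal self-contained check of that expression (example values only):

```python
import torch

seq_lens_cpu = torch.tensor([3, 6, 2])  # per-request lengths (made up)
max_seq_len = int(seq_lens_cpu.max())   # same expression as the diff adds
assert max_seq_len == 6
```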