[Refactor] move the metadata from attention_v1 to util(ready for extract common_cp) & realize Ascendmetadata inherit from the parent class. (#5203)
RFC: https://github.com/vllm-project/vllm-ascend/issues/4629
1. Remove the pcp-related code from attention_v1.
2. Establish the inheritance relationship of CommonAttentionMetadata.
TODO
1. extract common_cp
2. move cp metadata to common_cp.
3. remove commonAttentionMetadata for aclgraph.
- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c
---------
Signed-off-by: weijinqian_v1 <weijinqian@huawei.com>
Co-authored-by: weijinqian_v1 <weijinqian@huawei.com>
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import dataclass, field
|
||||
from functools import lru_cache
|
||||
from typing import Any, List, Optional
|
||||
|
||||
@@ -9,6 +9,7 @@ from vllm.distributed.kv_transfer import (get_kv_transfer_group,
|
||||
has_kv_transfer_group,
|
||||
is_v1_kv_transfer_group)
|
||||
from vllm.forward_context import ForwardContext, get_forward_context
|
||||
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
|
||||
|
||||
from vllm_ascend.utils import (AscendDeviceType, get_ascend_config,
|
||||
get_ascend_device_type)
|
||||
@@ -34,6 +35,51 @@ def enable_cp():
|
||||
or prefill_config.decode_context_parallel_size > 1
|
||||
|
||||
|
||||
@dataclass
|
||||
class AscendMetadataForPrefill:
|
||||
|
||||
@dataclass
|
||||
class AscendPCPMetadata:
|
||||
q_head_idx: torch.Tensor = None
|
||||
q_tail_idx: torch.Tensor = None
|
||||
kv_with_q_head_nomask_idx: torch.Tensor = None
|
||||
kv_with_q_head_mask_idx: torch.Tensor = None
|
||||
kv_with_q_tail_nomask_idx: torch.Tensor = None
|
||||
kv_with_q_tail_mask_idx: torch.Tensor = None
|
||||
attn_mask_seqlens: torch.Tensor = None
|
||||
head_attn_nomask_seqlens: torch.Tensor = None
|
||||
tail_attn_nomask_seqlens: torch.Tensor = None
|
||||
q_full_idx: torch.Tensor = None
|
||||
pcp_prefill_mask: torch.Tensor = None
|
||||
|
||||
@dataclass
|
||||
class ChunkedContextMetadata:
|
||||
actual_chunk_seq_lengths: torch.Tensor
|
||||
actual_seq_lengths_kv: torch.Tensor
|
||||
starts: torch.Tensor
|
||||
chunk_seq_mask_filtered_indices: torch.Tensor
|
||||
chunked_req_mask: Optional[list[bool]] = None
|
||||
local_context_lens_allranks: Optional[list[list[int]]] = None
|
||||
cp_kv_recover_idx_for_chunk: Optional[list[int]] = None
|
||||
kv_inverse_idx_for_chunk: Optional[list[int]] = None
|
||||
batch_chunk_seq_mask: Optional[list[bool]] = None
|
||||
|
||||
""" Prefill Specific Metadata for Ascend"""
|
||||
pcp_metadata: Optional[AscendPCPMetadata] = None
|
||||
pcp_allgather_restore_idx: Optional[List[int]] = None
|
||||
chunked_context: Optional[ChunkedContextMetadata] = None
|
||||
block_tables: torch.Tensor = None
|
||||
actual_seq_lengths_q: torch.Tensor = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class AscendMetadataForDecode:
|
||||
""" Decode Specific Metadata for Ascend"""
|
||||
num_computed_tokens_of_pcp_dcp: Optional[list[list[list[int]]]] = None
|
||||
batch_seq_mask: torch.Tensor = None
|
||||
block_tables: torch.Tensor = None
|
||||
|
||||
|
||||
@dataclass
|
||||
# class AscendCommonLongSequenceMetadata:
|
||||
class AscendPrefillContextParallelMetadata:
|
||||
@@ -75,45 +121,20 @@ class AscendPrefillContextParallelMetadata:
|
||||
|
||||
|
||||
@dataclass
|
||||
class AscendCommonAttentionMetadata:
|
||||
class AscendCommonAttentionMetadata(CommonAttentionMetadata):
|
||||
"""
|
||||
Per-batch attention metadata, shared across layers and backends.
|
||||
AttentionMetadataBuilder instances use it to construct per-layer metadata.
|
||||
|
||||
For many of the tensors we keep both NPU and CPU versions.
|
||||
"""
|
||||
seq_lens_cpu: torch.Tensor = None
|
||||
num_computed_tokens_cpu: torch.Tensor = None
|
||||
|
||||
query_start_loc: torch.Tensor
|
||||
query_start_loc_cpu: torch.Tensor
|
||||
"""(batch_size + 1,), the start location of each request in query Tensor"""
|
||||
|
||||
seq_lens_cpu: torch.Tensor
|
||||
"""(batch_size,), the length of each request including both computed tokens
|
||||
and newly scheduled tokens"""
|
||||
|
||||
seq_lens: torch.Tensor
|
||||
"""same to seq_lens_cpu, for compatibility with some new attn metadata
|
||||
(such as GDN)."""
|
||||
|
||||
num_computed_tokens_cpu: torch.Tensor
|
||||
"""(batch_size,), the number of computed tokens for each request"""
|
||||
|
||||
num_reqs: int
|
||||
"""Number of requests"""
|
||||
num_actual_tokens: int
|
||||
"""Total number of tokens in batch"""
|
||||
|
||||
max_query_len: int
|
||||
"""Max token number of request in batch"""
|
||||
|
||||
decode_token_per_req: int
|
||||
decode_token_per_req: int = 1
|
||||
"""decode token number per request"""
|
||||
|
||||
block_table_tensor: torch.Tensor
|
||||
|
||||
slot_mapping: torch.Tensor
|
||||
|
||||
actual_seq_lengths_q: list[int]
|
||||
actual_seq_lengths_q: list[int] = field(default_factory=list)
|
||||
|
||||
positions: torch.Tensor = None
|
||||
|
||||
@@ -132,8 +153,6 @@ class AscendCommonAttentionMetadata:
|
||||
prefill_context_parallel_metadata: Optional[
|
||||
AscendPrefillContextParallelMetadata] = None
|
||||
|
||||
causal: bool = True
|
||||
|
||||
# TODO: Remove it when vLLM no longer uses this function.
|
||||
def unpadded(self, num_actual_tokens: int,
|
||||
num_actual_reqs: int) -> "AscendCommonAttentionMetadata":
|
||||
@@ -161,7 +180,7 @@ class AscendCommonAttentionMetadata:
|
||||
num_input_tokens=num_actual_tokens,
|
||||
prefill_context_parallel_metadata=self.
|
||||
prefill_context_parallel_metadata,
|
||||
)
|
||||
max_seq_len=self.max_seq_len)
|
||||
|
||||
|
||||
def filter_chunked_req_indices(
|
||||
|
||||
Reference in New Issue
Block a user