[Refactor] move the metadata from attention_v1 to util(ready for extract common_cp) & realize Ascendmetadata inherit from the parent class. (#5203)
RFC: https://github.com/vllm-project/vllm-ascend/issues/4629
1. Remove the pcp-related code from attention_v1.
2. Establish the inheritance relationship of CommonAttentionMetadata.
TODO
1. extract common_cp
2. move cp metadata to common_cp.
3. remove commonAttentionMetadata for aclgraph.
- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c
---------
Signed-off-by: weijinqian_v1 <weijinqian@huawei.com>
Co-authored-by: weijinqian_v1 <weijinqian@huawei.com>
This commit is contained in:
@@ -117,7 +117,8 @@ class TestAscendAttentionMetadataBuilder(TestBase):
|
||||
spec_attn_mask=None,
|
||||
attn_state=AscendAttentionState.ChunkedPrefill,
|
||||
num_computed_tokens_cpu=None,
|
||||
seq_lens=None)
|
||||
seq_lens=None,
|
||||
max_seq_len=6)
|
||||
mock_model = MagicMock()
|
||||
|
||||
self.builder.build(1, common_attn_metadata, mock_model)
|
||||
|
||||
@@ -606,7 +606,8 @@ class TestAscendMLAMetadataBuilderBuild(TestBase):
|
||||
spec_attn_mask=None,
|
||||
attn_state=AscendAttentionState.PrefillNoCache,
|
||||
num_computed_tokens_cpu=None,
|
||||
seq_lens=None)
|
||||
seq_lens=None,
|
||||
max_seq_len=6)
|
||||
|
||||
base_inputs = {
|
||||
"num_actual_tokens": 10,
|
||||
@@ -673,7 +674,8 @@ class TestAscendMLAMetadataBuilderBuild(TestBase):
|
||||
spec_attn_mask=None,
|
||||
attn_state=AscendAttentionState.ChunkedPrefill,
|
||||
num_computed_tokens_cpu=None,
|
||||
seq_lens=None)
|
||||
seq_lens=None,
|
||||
max_seq_len=6)
|
||||
|
||||
base_inputs = {
|
||||
"num_actual_tokens": 15,
|
||||
@@ -729,7 +731,8 @@ class TestAscendMLAMetadataBuilderBuild(TestBase):
|
||||
spec_attn_mask=None,
|
||||
attn_state=AscendAttentionState.DecodeOnly,
|
||||
num_computed_tokens_cpu=None,
|
||||
seq_lens=None)
|
||||
seq_lens=None,
|
||||
max_seq_len=6)
|
||||
|
||||
base_inputs = {
|
||||
"num_actual_tokens": 3,
|
||||
@@ -784,7 +787,8 @@ class TestAscendMLAMetadataBuilderBuild(TestBase):
|
||||
spec_attn_mask=None,
|
||||
attn_state=AscendAttentionState.DecodeOnly,
|
||||
num_computed_tokens_cpu=None,
|
||||
seq_lens=None)
|
||||
seq_lens=None,
|
||||
max_seq_len=6)
|
||||
|
||||
base_inputs = {
|
||||
"num_actual_tokens": 3,
|
||||
@@ -839,7 +843,8 @@ class TestAscendMLAMetadataBuilderBuild(TestBase):
|
||||
spec_attn_mask=None,
|
||||
attn_state=AscendAttentionState.PrefillNoCache,
|
||||
num_computed_tokens_cpu=None,
|
||||
seq_lens=None)
|
||||
seq_lens=None,
|
||||
max_seq_len=6)
|
||||
|
||||
builder = AscendMLAMetadataBuilder(kv_cache_spec=self.kv_cache_spec,
|
||||
layer_names=["layer_0", "layer_1"],
|
||||
|
||||
@@ -33,10 +33,10 @@ from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
|
||||
from vllm_ascend.attention.attention_v1 import (AscendAttentionBackendImpl,
|
||||
AscendAttentionMetadataBuilder,
|
||||
AscendMetadata,
|
||||
AscendMetadataForDecode,
|
||||
AscendMetadataForPrefill)
|
||||
AscendMetadata)
|
||||
from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
|
||||
AscendMetadataForDecode,
|
||||
AscendMetadataForPrefill,
|
||||
filter_chunked_req_indices,
|
||||
split_decodes_and_prefills)
|
||||
from vllm_ascend.compilation.acl_graph import (get_graph_params,
|
||||
|
||||
@@ -34,7 +34,9 @@ from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
|
||||
from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
|
||||
enable_cp, split_decodes_and_prefills,
|
||||
AscendMetadataForDecode,
|
||||
AscendMetadataForPrefill, enable_cp,
|
||||
split_decodes_and_prefills,
|
||||
using_paged_attention)
|
||||
from vllm_ascend.compilation.acl_graph import (get_graph_params,
|
||||
update_graph_params_workspaces)
|
||||
@@ -118,51 +120,6 @@ class AscendAttentionState(Enum):
|
||||
SpecDecoding = 4
|
||||
|
||||
|
||||
@dataclass
|
||||
class AscendMetadataForPrefill:
|
||||
|
||||
@dataclass
|
||||
class AscendPCPMetadata:
|
||||
q_head_idx: torch.Tensor = None
|
||||
q_tail_idx: torch.Tensor = None
|
||||
kv_with_q_head_nomask_idx: torch.Tensor = None
|
||||
kv_with_q_head_mask_idx: torch.Tensor = None
|
||||
kv_with_q_tail_nomask_idx: torch.Tensor = None
|
||||
kv_with_q_tail_mask_idx: torch.Tensor = None
|
||||
attn_mask_seqlens: torch.Tensor = None
|
||||
head_attn_nomask_seqlens: torch.Tensor = None
|
||||
tail_attn_nomask_seqlens: torch.Tensor = None
|
||||
q_full_idx: torch.Tensor = None
|
||||
pcp_prefill_mask: torch.Tensor = None
|
||||
|
||||
@dataclass
|
||||
class ChunkedContextMetadata:
|
||||
actual_chunk_seq_lengths: torch.Tensor
|
||||
actual_seq_lengths_kv: torch.Tensor
|
||||
starts: torch.Tensor
|
||||
chunk_seq_mask_filtered_indices: torch.Tensor
|
||||
chunked_req_mask: Optional[list[bool]] = None
|
||||
local_context_lens_allranks: Optional[list[list[int]]] = None
|
||||
cp_kv_recover_idx_for_chunk: Optional[list[int]] = None
|
||||
kv_inverse_idx_for_chunk: Optional[list[int]] = None
|
||||
batch_chunk_seq_mask: Optional[list[bool]] = None
|
||||
|
||||
""" Prefill Specific Metadata for Ascend"""
|
||||
pcp_metadata: Optional[AscendPCPMetadata] = None
|
||||
pcp_allgather_restore_idx: Optional[List[int]] = None
|
||||
chunked_context: Optional[ChunkedContextMetadata] = None
|
||||
block_tables: torch.Tensor = None
|
||||
actual_seq_lengths_q: torch.Tensor = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class AscendMetadataForDecode:
|
||||
""" Decode Specific Metadata for Ascend"""
|
||||
num_computed_tokens_of_pcp_dcp: Optional[list[list[list[int]]]] = None
|
||||
batch_seq_mask: torch.Tensor = None
|
||||
block_tables: torch.Tensor = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class AscendMetadata:
|
||||
# **************************** Basic Properties ************************** #
|
||||
@@ -274,14 +231,7 @@ class AscendAttentionMetadataBuilder:
|
||||
block_table = common_attn_metadata.block_table_tensor
|
||||
seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs]
|
||||
|
||||
long_seq_metadata = common_attn_metadata.prefill_context_parallel_metadata
|
||||
num_actual_tokens_pcp_padded = long_seq_metadata.num_actual_tokens_pcp_padded if long_seq_metadata else None
|
||||
if num_actual_tokens_pcp_padded is None:
|
||||
num_actual_tokens_pcp_padded = num_actual_tokens
|
||||
|
||||
slot_mapping = common_attn_metadata.slot_mapping[:
|
||||
num_actual_tokens_pcp_padded]
|
||||
|
||||
slot_mapping = common_attn_metadata.slot_mapping[:num_actual_tokens]
|
||||
attn_mask = common_attn_metadata.attn_mask
|
||||
attn_state = common_attn_metadata.attn_state
|
||||
|
||||
@@ -292,7 +242,6 @@ class AscendAttentionMetadataBuilder:
|
||||
attn_metadata = AscendMetadata(
|
||||
num_actual_tokens=num_actual_tokens,
|
||||
num_decode_tokens=num_decode_tokens,
|
||||
num_actual_tokens_pcp_padded=num_actual_tokens_pcp_padded,
|
||||
block_tables=block_table,
|
||||
query_start_loc=query_start_loc,
|
||||
seq_lens=seq_lens,
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import dataclass, field
|
||||
from functools import lru_cache
|
||||
from typing import Any, List, Optional
|
||||
|
||||
@@ -9,6 +9,7 @@ from vllm.distributed.kv_transfer import (get_kv_transfer_group,
|
||||
has_kv_transfer_group,
|
||||
is_v1_kv_transfer_group)
|
||||
from vllm.forward_context import ForwardContext, get_forward_context
|
||||
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
|
||||
|
||||
from vllm_ascend.utils import (AscendDeviceType, get_ascend_config,
|
||||
get_ascend_device_type)
|
||||
@@ -34,6 +35,51 @@ def enable_cp():
|
||||
or prefill_config.decode_context_parallel_size > 1
|
||||
|
||||
|
||||
@dataclass
|
||||
class AscendMetadataForPrefill:
|
||||
|
||||
@dataclass
|
||||
class AscendPCPMetadata:
|
||||
q_head_idx: torch.Tensor = None
|
||||
q_tail_idx: torch.Tensor = None
|
||||
kv_with_q_head_nomask_idx: torch.Tensor = None
|
||||
kv_with_q_head_mask_idx: torch.Tensor = None
|
||||
kv_with_q_tail_nomask_idx: torch.Tensor = None
|
||||
kv_with_q_tail_mask_idx: torch.Tensor = None
|
||||
attn_mask_seqlens: torch.Tensor = None
|
||||
head_attn_nomask_seqlens: torch.Tensor = None
|
||||
tail_attn_nomask_seqlens: torch.Tensor = None
|
||||
q_full_idx: torch.Tensor = None
|
||||
pcp_prefill_mask: torch.Tensor = None
|
||||
|
||||
@dataclass
|
||||
class ChunkedContextMetadata:
|
||||
actual_chunk_seq_lengths: torch.Tensor
|
||||
actual_seq_lengths_kv: torch.Tensor
|
||||
starts: torch.Tensor
|
||||
chunk_seq_mask_filtered_indices: torch.Tensor
|
||||
chunked_req_mask: Optional[list[bool]] = None
|
||||
local_context_lens_allranks: Optional[list[list[int]]] = None
|
||||
cp_kv_recover_idx_for_chunk: Optional[list[int]] = None
|
||||
kv_inverse_idx_for_chunk: Optional[list[int]] = None
|
||||
batch_chunk_seq_mask: Optional[list[bool]] = None
|
||||
|
||||
""" Prefill Specific Metadata for Ascend"""
|
||||
pcp_metadata: Optional[AscendPCPMetadata] = None
|
||||
pcp_allgather_restore_idx: Optional[List[int]] = None
|
||||
chunked_context: Optional[ChunkedContextMetadata] = None
|
||||
block_tables: torch.Tensor = None
|
||||
actual_seq_lengths_q: torch.Tensor = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class AscendMetadataForDecode:
|
||||
""" Decode Specific Metadata for Ascend"""
|
||||
num_computed_tokens_of_pcp_dcp: Optional[list[list[list[int]]]] = None
|
||||
batch_seq_mask: torch.Tensor = None
|
||||
block_tables: torch.Tensor = None
|
||||
|
||||
|
||||
@dataclass
|
||||
# class AscendCommonLongSequenceMetadata:
|
||||
class AscendPrefillContextParallelMetadata:
|
||||
@@ -75,45 +121,20 @@ class AscendPrefillContextParallelMetadata:
|
||||
|
||||
|
||||
@dataclass
|
||||
class AscendCommonAttentionMetadata:
|
||||
class AscendCommonAttentionMetadata(CommonAttentionMetadata):
|
||||
"""
|
||||
Per-batch attention metadata, shared across layers and backends.
|
||||
AttentionMetadataBuilder instances use it to construct per-layer metadata.
|
||||
|
||||
For many of the tensors we keep both NPU and CPU versions.
|
||||
"""
|
||||
seq_lens_cpu: torch.Tensor = None
|
||||
num_computed_tokens_cpu: torch.Tensor = None
|
||||
|
||||
query_start_loc: torch.Tensor
|
||||
query_start_loc_cpu: torch.Tensor
|
||||
"""(batch_size + 1,), the start location of each request in query Tensor"""
|
||||
|
||||
seq_lens_cpu: torch.Tensor
|
||||
"""(batch_size,), the length of each request including both computed tokens
|
||||
and newly scheduled tokens"""
|
||||
|
||||
seq_lens: torch.Tensor
|
||||
"""same to seq_lens_cpu, for compatibility with some new attn metadata
|
||||
(such as GDN)."""
|
||||
|
||||
num_computed_tokens_cpu: torch.Tensor
|
||||
"""(batch_size,), the number of computed tokens for each request"""
|
||||
|
||||
num_reqs: int
|
||||
"""Number of requests"""
|
||||
num_actual_tokens: int
|
||||
"""Total number of tokens in batch"""
|
||||
|
||||
max_query_len: int
|
||||
"""Max token number of request in batch"""
|
||||
|
||||
decode_token_per_req: int
|
||||
decode_token_per_req: int = 1
|
||||
"""decode token number per request"""
|
||||
|
||||
block_table_tensor: torch.Tensor
|
||||
|
||||
slot_mapping: torch.Tensor
|
||||
|
||||
actual_seq_lengths_q: list[int]
|
||||
actual_seq_lengths_q: list[int] = field(default_factory=list)
|
||||
|
||||
positions: torch.Tensor = None
|
||||
|
||||
@@ -132,8 +153,6 @@ class AscendCommonAttentionMetadata:
|
||||
prefill_context_parallel_metadata: Optional[
|
||||
AscendPrefillContextParallelMetadata] = None
|
||||
|
||||
causal: bool = True
|
||||
|
||||
# TODO: Remove it when vLLM no longer uses this function.
|
||||
def unpadded(self, num_actual_tokens: int,
|
||||
num_actual_reqs: int) -> "AscendCommonAttentionMetadata":
|
||||
@@ -161,7 +180,7 @@ class AscendCommonAttentionMetadata:
|
||||
num_input_tokens=num_actual_tokens,
|
||||
prefill_context_parallel_metadata=self.
|
||||
prefill_context_parallel_metadata,
|
||||
)
|
||||
max_seq_len=self.max_seq_len)
|
||||
|
||||
|
||||
def filter_chunked_req_indices(
|
||||
|
||||
@@ -742,7 +742,7 @@ class EagleProposer(Proposer):
|
||||
spec_attn_mask=self.runner.spec_attn_mask,
|
||||
attn_state=self.runner.attn_state,
|
||||
decode_token_per_req=self.runner.decode_token_per_req,
|
||||
)
|
||||
max_seq_len=0)
|
||||
return spec_common_attn_metadata, token_indices
|
||||
|
||||
def prepare_inputs_padded(
|
||||
@@ -800,7 +800,8 @@ class EagleProposer(Proposer):
|
||||
decode_token_per_req=self.runner.decode_token_per_req,
|
||||
num_computed_tokens_cpu=common_attn_metadata.
|
||||
num_computed_tokens_cpu,
|
||||
seq_lens=common_attn_metadata.seq_lens)
|
||||
seq_lens=common_attn_metadata.seq_lens,
|
||||
max_seq_len=0)
|
||||
|
||||
token_indices_to_sample = (common_attn_metadata.query_start_loc[1:] -
|
||||
1 - num_rejected_tokens_gpu)
|
||||
|
||||
@@ -268,7 +268,7 @@ class MtpProposer(Proposer):
|
||||
spec_attn_mask=self.runner.spec_attn_mask,
|
||||
attn_state=self.runner.attn_state,
|
||||
decode_token_per_req=self.runner.decode_token_per_req,
|
||||
)
|
||||
max_seq_len=0)
|
||||
if self.pcp_size * self.dcp_size > 1:
|
||||
# update long_seq related params and flatten block_table
|
||||
common_attn_metadata.prefill_context_parallel_metadata = \
|
||||
@@ -599,7 +599,7 @@ class MtpProposer(Proposer):
|
||||
spec_attn_mask=self.runner.spec_attn_mask,
|
||||
attn_state=self.runner.attn_state,
|
||||
decode_token_per_req=self.runner.decode_token_per_req,
|
||||
)
|
||||
max_seq_len=0)
|
||||
return spec_common_attn_metadata, token_indices
|
||||
|
||||
def _propose(
|
||||
@@ -1221,7 +1221,8 @@ class MtpProposer(Proposer):
|
||||
decode_token_per_req=self.runner.decode_token_per_req,
|
||||
num_computed_tokens_cpu=common_attn_metadata.
|
||||
num_computed_tokens_cpu,
|
||||
seq_lens=common_attn_metadata.seq_lens)
|
||||
seq_lens=common_attn_metadata.seq_lens,
|
||||
max_seq_len=0)
|
||||
|
||||
query_start_loc = common_attn_metadata.query_start_loc[
|
||||
1:1 + num_rejected_tokens_gpu.shape[0]]
|
||||
|
||||
@@ -1045,7 +1045,7 @@ class NPUModelRunner(GPUModelRunner):
|
||||
max_query_len=max_num_scheduled_tokens,
|
||||
decode_token_per_req=self.decode_token_per_req,
|
||||
prefill_context_parallel_metadata=long_seq_metadata,
|
||||
)
|
||||
max_seq_len=0)
|
||||
|
||||
if self.speculative_config and self.pcp_size * self.dcp_size > 1:
|
||||
# For pcp + spec decode, we flatten block_table
|
||||
@@ -1874,7 +1874,7 @@ class NPUModelRunner(GPUModelRunner):
|
||||
max_query_len=max_query_len,
|
||||
decode_token_per_req=self.decode_token_per_req,
|
||||
prefill_context_parallel_metadata=long_seq_metadata,
|
||||
)
|
||||
max_seq_len=0)
|
||||
if self.pcp_size * self.dcp_size > 1:
|
||||
common_attn_metadata.block_table_tensor = \
|
||||
block_table_tensor[:num_reqs * self.decode_threshold]
|
||||
|
||||
@@ -53,6 +53,7 @@ def build_attn_metadata(
|
||||
"""Build attention metadata for Ascend NPUs."""
|
||||
# TODO(Ronald1995): optimize AscendCommonAttentionMetadata.
|
||||
max_query_len = int(query_start_loc_cpu.max())
|
||||
max_seq_len = int(seq_lens_cpu.max())
|
||||
|
||||
attn_metadata: dict[str, Any] = {}
|
||||
kv_cache_groups = kv_cache_config.kv_cache_groups
|
||||
@@ -80,7 +81,7 @@ def build_attn_metadata(
|
||||
graph_pad_size=graph_pad_size,
|
||||
num_input_tokens=num_input_tokens,
|
||||
prefill_context_parallel_metadata=prefill_context_parallel_metadata,
|
||||
)
|
||||
max_seq_len=max_seq_len)
|
||||
|
||||
attn_metadata_builder = attn_metadata_builders[i]
|
||||
metadata = attn_metadata_builder.build(
|
||||
|
||||
Reference in New Issue
Block a user