[Refactor] Add comments for Metadata classes in attention module (#5789)
### What this PR does / why we need it?

Add docstrings for Metadata and MetadataBuilder classes in the attention module to improve code readability. Related to #5463 (Item 11: Add some comments for CommonMetadata and others).

**Modified files:**

- `vllm_ascend/attention/context_parallel/common_cp.py`: Added comments for `AscendPCPMetadata`, `CPChunkedContextMetadata`, `AscendMetadataForPrefill`, `AscendMetadataForDecode`
- `vllm_ascend/attention/utils.py`: Added comments for `AscendPrefillContextParallelMetadata`
- `vllm_ascend/attention/mla_v1.py`: Added comments for `ChunkedContextMetadata`, `AscendMLADecodeMetadata`
- `vllm_ascend/attention/attention_v1.py`: Added comments for `AscendMetadata`, `AscendAttentionMetadataBuilder`
- `vllm_ascend/attention/context_parallel/attention_cp.py`: Added comments for `AscendAttentionCPMetadataBuilder`

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Documentation only, no functional changes.

Signed-off-by: lico67373 <918688502@qq.com>
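All of the added documentation follows one convention: a short class-level docstring plus per-field `#` comments with concrete example values. A minimal, self-contained sketch of that pattern (the class below is illustrative only, not one of the classes touched by this PR):

```python
from dataclasses import dataclass, field
from typing import Optional

import torch


@dataclass
class ExampleMetadata:
    """
    One-line summary of what the metadata holds.

    A short elaboration on which part of the attention
    computation consumes these fields.
    """

    # Optional attention mask; None when no masking is needed.
    attn_mask: Optional[torch.Tensor] = None

    # Per-request query lengths, e.g. [1, 1, 128] for two decode
    # tokens and one prefill of 128 tokens.
    actual_seq_lengths_q: list[int] = field(default_factory=list)


# Class docstrings are what help() and IDE tooltips surface:
print(ExampleMetadata.__doc__)
```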
`vllm_ascend/attention/attention_v1.py`:

```diff
@@ -132,6 +132,12 @@ class AscendAttentionState(Enum):
 
 @dataclass
 class AscendMetadata:
+    """
+    Per-layer attention metadata for Ascend FlashAttention backend.
+
+    Contains attention masks, token counts, sequence lengths and KV cache
+    related properties for attention computation.
+    """
     # **************************** Basic Properties ************************** #
     attn_mask: Optional[torch.Tensor] = None
     # Current state of this attention run.
@@ -186,7 +192,12 @@ class AscendMetadata:
 
 
 class AscendAttentionMetadataBuilder(AttentionMetadataBuilder[AscendMetadata]):
-    # AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
+    """
+    Builder for constructing AscendMetadata from CommonAttentionMetadata.
+
+    Handles attention mask generation and metadata preparation for
+    Ascend FlashAttention backend.
+    """
     # Does this backend/builder reorder the batch?
     # If not, set this to None. Otherwise set it to the query
     # length that will be pulled into the front of the batch.
```
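For readers new to the builder classes documented above: a metadata builder consumes scheduler-level common metadata and emits the backend-specific dataclass, and the reorder-batch comment describes a contract where `None` means the batch order is left untouched. A hedged sketch of that general shape (all class and field names below are hypothetical stand-ins, not the real vllm_ascend API):

```python
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class CommonMetaSketch:
    # Hypothetical stand-in for vLLM's CommonAttentionMetadata.
    seq_lens_cpu: torch.Tensor


@dataclass
class BackendMetaSketch:
    # Hypothetical stand-in for AscendMetadata.
    attn_mask: Optional[torch.Tensor]
    seq_lens: torch.Tensor


class MetadataBuilderSketch:
    # Mirrors the documented contract: None means this builder does
    # not pull any query length to the front of the batch.
    reorder_batch_threshold: Optional[int] = None

    def build(self, common: CommonMetaSketch) -> BackendMetaSketch:
        # Derive a causal mask sized to the longest sequence.
        max_len = int(common.seq_lens_cpu.max())
        mask = torch.tril(torch.ones(max_len, max_len, dtype=torch.bool))
        return BackendMetaSketch(attn_mask=mask, seq_lens=common.seq_lens_cpu)


meta = MetadataBuilderSketch().build(
    CommonMetaSketch(seq_lens_cpu=torch.tensor([128, 256, 64])))
print(meta.attn_mask.shape)  # torch.Size([256, 256])
```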
`vllm_ascend/attention/context_parallel/attention_cp.py`:

```diff
@@ -45,7 +45,11 @@ from vllm_ascend.utils import cp_chunkedprefill_comm_stream, weak_ref_tensors
 
 
 class AscendAttentionCPMetadataBuilder(AscendAttentionMetadataBuilder):
-    # AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
+    """
+    Builder for constructing AscendMetadata with Context Parallelism support.
+
+    Extends AscendAttentionMetadataBuilder with PCP/DCP metadata handling.
+    """
     # Does this backend/builder reorder the batch?
     # If not, set this to None. Otherwise set it to the query
     # length that will be pulled into the front of the batch.
```
`vllm_ascend/attention/context_parallel/common_cp.py`:

```diff
@@ -11,6 +11,12 @@ from vllm.distributed import (get_dcp_group,
 
 @dataclass
 class AscendPCPMetadata:
+    """
+    Metadata for Prefill Context Parallelism (PCP) on Ascend devices.
+
+    Stores index tensors and sequence lengths for routing attention
+    computations across PCP ranks during long sequence processing.
+    """
     q_head_idx: torch.Tensor = None
     q_tail_idx: torch.Tensor = None
     kv_with_q_head_nomask_idx: torch.Tensor = None
@@ -26,7 +32,11 @@ class AscendPCPMetadata:
 
 @dataclass
 class CPChunkedContextMetadata:
-    # New for MLA (compared to FlashAttention)
+    """
+    Metadata for chunked context handling in Context Parallelism (CP).
+
+    Extends chunked prefill with per-rank chunk information for PCP/DCP.
+    """
     # For handling chunked prefill
     cu_seq_lens: torch.Tensor
     starts: torch.Tensor
@@ -46,9 +56,11 @@ class CPChunkedContextMetadata:
 
 @dataclass
 class AscendMetadataForPrefill:
+    """ Prefill-specific metadata for Ascend attention with Context Parallelism."""
 
     @dataclass
     class ChunkedContextMetadata:
+        """Metadata for chunked context processing within prefill phase."""
         actual_chunk_seq_lengths: torch.Tensor
         actual_seq_lengths_kv: torch.Tensor
         starts: torch.Tensor
@@ -69,7 +81,7 @@ class AscendMetadataForPrefill:
 
 @dataclass
 class AscendMetadataForDecode:
-    """ Decode Specific Metadata for Ascend"""
+    """ Decode-specific metadata for Ascend attention with Context Parallelism."""
     num_computed_tokens_of_pcp_dcp: Optional[list[list[list[int]]]] = None
     batch_seq_mask: torch.Tensor = None
     block_tables: torch.Tensor = None
```
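The triply nested `num_computed_tokens_of_pcp_dcp` field reads easier with a concrete shape in mind. A hypothetical illustration, assuming the three levels are ordered [request][pcp_rank][dcp_rank]; the actual ordering is defined in common_cp.py, not in this diff:

```python
# Hypothetical layout sketch: the level ordering below is an
# assumption, not confirmed by this PR.
num_pcp_ranks, num_dcp_ranks, num_reqs = 2, 2, 3

# One computed-token count per (request, pcp_rank, dcp_rank) triple.
num_computed_tokens_of_pcp_dcp = [
    [[64, 64], [64, 64]],  # request 0: 256 context tokens split 4 ways
    [[32, 32], [32, 32]],  # request 1: 128 context tokens split 4 ways
    [[0, 0], [0, 0]],      # request 2: no computed context yet
]

# Total computed tokens for request 0 across all CP ranks:
total_req0 = sum(sum(per_dcp) for per_dcp in num_computed_tokens_of_pcp_dcp[0])
print(total_req0)  # 256
```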
`vllm_ascend/attention/mla_v1.py`:

```diff
@@ -84,8 +84,11 @@ class AscendMLABackend(AttentionBackend):
 
 @dataclass
 class ChunkedContextMetadata:
-    # New for MLA (compared to FlashAttention)
-    # For handling chunked prefill
+    """
+    Metadata for chunked context handling in MLA attention.
+
+    Manages sequence boundaries and workspace for chunked prefill processing.
+    """
     cu_seq_lens: torch.Tensor
     starts: torch.Tensor
     seq_tot: list[int]
@@ -116,7 +119,8 @@ class AscendMLAPrefillMetadata:
 
 @dataclass
 class AscendMLADecodeMetadata:
-    # Input positions for rotrary embeddings since for MLA the rotary
+    """ Decode-specific metadata for Ascend MLA attention."""
+    # Input positions for rotary embeddings since for MLA the rotary
     # position embeddings are applied inside the attention backend
     input_positions: torch.Tensor
     block_table: torch.Tensor
```
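The `input_positions` comment notes that for MLA the rotary embedding is applied inside the attention backend rather than ahead of it. A minimal sketch of what consuming such positions looks like, using generic RoPE math (all names here are illustrative, not the mla_v1.py implementation):

```python
import torch


def apply_rope(x: torch.Tensor, input_positions: torch.Tensor,
               base: float = 10000.0) -> torch.Tensor:
    """Rotate pairs of channels of x by position-dependent angles."""
    # x: [num_tokens, head_dim]; input_positions: [num_tokens]
    head_dim = x.shape[-1]
    inv_freq = 1.0 / (base**(torch.arange(0, head_dim, 2).float() / head_dim))
    angles = input_positions.float()[:, None] * inv_freq[None, :]
    cos, sin = angles.cos(), angles.sin()
    x1, x2 = x[..., 0::2], x[..., 1::2]
    out = torch.empty_like(x)
    out[..., 0::2] = x1 * cos - x2 * sin
    out[..., 1::2] = x1 * sin + x2 * cos
    return out


# One decode token per request, at positions 100, 200 and 50:
q = torch.randn(3, 64)
q_rot = apply_rope(q, torch.tensor([100, 200, 50]))
print(q_rot.shape)  # torch.Size([3, 64])
```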
`vllm_ascend/attention/utils.py`:

```diff
@@ -36,8 +36,12 @@ def enable_cp():
 
 
 @dataclass
-# class AscendCommonLongSequenceMetadata:
 class AscendPrefillContextParallelMetadata:
+    """
+    Metadata for Prefill Context Parallelism (PCP) in CommonAttentionMetadata.
+
+    Contains index tensors and sequence lengths for PCP operations.
+    """
     pcp_allgather_restore_idx: torch.Tensor = None
 
     cp_kv_recover_idx_for_chunk: torch.Tensor = None
@@ -81,24 +85,36 @@ class AscendCommonAttentionMetadata(CommonAttentionMetadata):
     For many of the tensors we keep both NPU and CPU versions.
     """
+    # CPU tensor of sequence lengths for host-side operations.
+    # E.g., tensor([128, 256, 64]) for 3 requests with different seq lengths.
     seq_lens_cpu: torch.Tensor = None
 
+    # CPU tensor of already computed tokens count per request.
+    # E.g., tensor([100, 200, 50]) means req0 has 100 tokens already computed.
     num_computed_tokens_cpu: torch.Tensor = None
 
+    # Number of decode tokens per request, used for speculative decoding.
+    # E.g., 1 for normal decoding, >1 for speculative decoding.
     decode_token_per_req: int = 1
-    """decode token number per request"""
 
+    # Actual query sequence lengths for each token in the batch (CPU list).
+    # E.g., [1, 1, 1, 128] for 3 decode tokens and 1 prefill with 128 tokens.
     actual_seq_lengths_q: list[int] = field(default_factory=list)
 
+    # NPU tensor of position indices for rotary embeddings computation.
+    # E.g., tensor([0, 1, 2, ...]) indicating token positions in sequence.
     positions: torch.Tensor = None
 
+    # Current attention state (e.g., ChunkedPrefill, DecodeOnly).
     attn_state: Any = None
 
+    # Padding size for graph capture, -1 means not in graph mode.
     graph_pad_size: int = -1
 
-    # num_input_tokens refers to total number of tokens including
-    # padding tokens. It is used to handle some padding operations.
+    # Total number of tokens including padding, used for padding operations.
     num_input_tokens: int = 0
 
+    # Metadata for Prefill Context Parallelism (PCP) operations.
     prefill_context_parallel_metadata: Optional[
         AscendPrefillContextParallelMetadata] = None
```
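Taking the example values straight from the new field comments, a populated instance would look roughly like the following sketch (a stand-in dataclass keeps the snippet self-contained; the real `AscendCommonAttentionMetadata` in `vllm_ascend/attention/utils.py` inherits from `CommonAttentionMetadata`):

```python
from dataclasses import dataclass, field
from typing import Any, Optional

import torch


@dataclass
class CommonMetaSketch:
    # Hypothetical mirror of the documented AscendCommonAttentionMetadata fields.
    seq_lens_cpu: Optional[torch.Tensor] = None
    num_computed_tokens_cpu: Optional[torch.Tensor] = None
    decode_token_per_req: int = 1
    actual_seq_lengths_q: list[int] = field(default_factory=list)
    positions: Optional[torch.Tensor] = None
    attn_state: Any = None
    graph_pad_size: int = -1
    num_input_tokens: int = 0


# Values taken from the examples in the field comments above:
meta = CommonMetaSketch(
    seq_lens_cpu=torch.tensor([128, 256, 64]),             # 3 requests
    num_computed_tokens_cpu=torch.tensor([100, 200, 50]),  # already computed
    decode_token_per_req=1,                                # no speculative decoding
    actual_seq_lengths_q=[1, 1, 1, 128],                   # 3 decodes + 1 prefill
    graph_pad_size=-1,                                     # not in graph mode
)
print(int(meta.seq_lens_cpu.sum()))  # 448 total sequence tokens
```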