From c8a324ab73478043528b772d588ca6ffa2ec091f Mon Sep 17 00:00:00 2001
From: LICO67373 <110013619+LICO1314@users.noreply.github.com>
Date: Tue, 13 Jan 2026 08:46:50 +0800
Subject: [PATCH] [Refactor] Add comments for Metadata classes in attention module (#5789)

### What this PR does / why we need it?

Add docstrings for Metadata and MetadataBuilder classes in the attention module to improve code readability.

Related to #5463 (Item 11: Add some comments for CommonMetadata and others)

**Modified files:**
- `vllm_ascend/attention/context_parallel/common_cp.py`: Added comments for `AscendPCPMetadata`, `CPChunkedContextMetadata`, `AscendMetadataForPrefill`, `AscendMetadataForDecode`
- `vllm_ascend/attention/utils.py`: Added comments for `AscendPrefillContextParallelMetadata`
- `vllm_ascend/attention/mla_v1.py`: Added comments for `ChunkedContextMetadata`, `AscendMLADecodeMetadata`
- `vllm_ascend/attention/attention_v1.py`: Added comments for `AscendMetadata`, `AscendAttentionMetadataBuilder`
- `vllm_ascend/attention/context_parallel/attention_cp.py`: Added comments for `AscendAttentionCPMetadataBuilder`

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Documentation only, no functional changes.

Signed-off-by: lico67373 <918688502@qq.com>
---
 vllm_ascend/attention/attention_v1.py         | 13 +++++++++-
 .../context_parallel/attention_cp.py          |  6 ++++-
 .../attention/context_parallel/common_cp.py   | 16 +++++++++++--
 vllm_ascend/attention/mla_v1.py               | 10 +++++---
 vllm_ascend/attention/utils.py                | 24 +++++++++++++++----
 5 files changed, 58 insertions(+), 11 deletions(-)

diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py
index 6a016366..8dc013f3 100644
--- a/vllm_ascend/attention/attention_v1.py
+++ b/vllm_ascend/attention/attention_v1.py
@@ -132,6 +132,12 @@ class AscendAttentionState(Enum):
 
 @dataclass
 class AscendMetadata:
+    """
+    Per-layer attention metadata for Ascend FlashAttention backend.
+
+    Contains attention masks, token counts, sequence lengths and KV cache
+    related properties for attention computation.
+    """
     # **************************** Basic Properties ************************** #
     attn_mask: Optional[torch.Tensor] = None
     # Current state of this attention run.
@@ -186,7 +192,12 @@ class AscendMetadata:
 
 
 class AscendAttentionMetadataBuilder(AttentionMetadataBuilder[AscendMetadata]):
-    # AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
+    """
+    Builder for constructing AscendMetadata from CommonAttentionMetadata.
+
+    Handles attention mask generation and metadata preparation for
+    Ascend FlashAttention backend.
+    """
     # Does this backend/builder reorder the batch?
     # If not, set this to None. Otherwise set it to the query
     # length that will be pulled into the front of the batch.
diff --git a/vllm_ascend/attention/context_parallel/attention_cp.py b/vllm_ascend/attention/context_parallel/attention_cp.py
index affcf643..088cd0e4 100644
--- a/vllm_ascend/attention/context_parallel/attention_cp.py
+++ b/vllm_ascend/attention/context_parallel/attention_cp.py
@@ -45,7 +45,11 @@ from vllm_ascend.utils import cp_chunkedprefill_comm_stream, weak_ref_tensors
 
 
 class AscendAttentionCPMetadataBuilder(AscendAttentionMetadataBuilder):
-    # AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
+    """
+    Builder for constructing AscendMetadata with Context Parallelism support.
+
+    Extends AscendAttentionMetadataBuilder with PCP/DCP metadata handling.
+    """
     # Does this backend/builder reorder the batch?
     # If not, set this to None. Otherwise set it to the query
     # length that will be pulled into the front of the batch.
diff --git a/vllm_ascend/attention/context_parallel/common_cp.py b/vllm_ascend/attention/context_parallel/common_cp.py
index 018919c0..89debf11 100644
--- a/vllm_ascend/attention/context_parallel/common_cp.py
+++ b/vllm_ascend/attention/context_parallel/common_cp.py
@@ -11,6 +11,12 @@ from vllm.distributed import (get_dcp_group,
 
 @dataclass
 class AscendPCPMetadata:
+    """
+    Metadata for Prefill Context Parallelism (PCP) on Ascend devices.
+
+    Stores index tensors and sequence lengths for routing attention
+    computations across PCP ranks during long sequence processing.
+    """
     q_head_idx: torch.Tensor = None
     q_tail_idx: torch.Tensor = None
     kv_with_q_head_nomask_idx: torch.Tensor = None
@@ -26,7 +32,11 @@ class AscendPCPMetadata:
 
 @dataclass
 class CPChunkedContextMetadata:
-    # New for MLA (compared to FlashAttention)
+    """
+    Metadata for chunked context handling in Context Parallelism (CP).
+
+    Extends chunked prefill with per-rank chunk information for PCP/DCP.
+    """
     # For handling chunked prefill
     cu_seq_lens: torch.Tensor
     starts: torch.Tensor
@@ -46,9 +56,11 @@ class CPChunkedContextMetadata:
 
 @dataclass
 class AscendMetadataForPrefill:
+    """ Prefill-specific metadata for Ascend attention with Context Parallelism."""
 
     @dataclass
     class ChunkedContextMetadata:
+        """Metadata for chunked context processing within prefill phase."""
         actual_chunk_seq_lengths: torch.Tensor
         actual_seq_lengths_kv: torch.Tensor
         starts: torch.Tensor
@@ -69,7 +81,7 @@ class AscendMetadataForPrefill:
 
 @dataclass
 class AscendMetadataForDecode:
-    """ Decode Specific Metadata for Ascend"""
+    """ Decode-specific metadata for Ascend attention with Context Parallelism."""
     num_computed_tokens_of_pcp_dcp: Optional[list[list[list[int]]]] = None
     batch_seq_mask: torch.Tensor = None
     block_tables: torch.Tensor = None
diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py
index 38cc7fd3..c98e3d5a 100644
--- a/vllm_ascend/attention/mla_v1.py
+++ b/vllm_ascend/attention/mla_v1.py
@@ -84,8 +84,11 @@ class AscendMLABackend(AttentionBackend):
 
 @dataclass
 class ChunkedContextMetadata:
-    # New for MLA (compared to FlashAttention)
-    # For handling chunked prefill
+    """
+    Metadata for chunked context handling in MLA attention.
+
+    Manages sequence boundaries and workspace for chunked prefill processing.
+    """
     cu_seq_lens: torch.Tensor
     starts: torch.Tensor
     seq_tot: list[int]
@@ -116,7 +119,8 @@ class AscendMLAPrefillMetadata:
 
 @dataclass
 class AscendMLADecodeMetadata:
-    # Input positions for rotrary embeddings since for MLA the rotary
+    """ Decode-specific metadata for Ascend MLA attention."""
+    # Input positions for rotary embeddings since for MLA the rotary
     # position embeddings are applied inside the attention backend
     input_positions: torch.Tensor
     block_table: torch.Tensor
diff --git a/vllm_ascend/attention/utils.py b/vllm_ascend/attention/utils.py
index be073c46..826c91a5 100644
--- a/vllm_ascend/attention/utils.py
+++ b/vllm_ascend/attention/utils.py
@@ -36,8 +36,12 @@ def enable_cp():
 
 
 @dataclass
-# class AscendCommonLongSequenceMetadata:
 class AscendPrefillContextParallelMetadata:
+    """
+    Metadata for Prefill Context Parallelism (PCP) in CommonAttentionMetadata.
+
+    Contains index tensors and sequence lengths for PCP operations.
+ """ pcp_allgather_restore_idx: torch.Tensor = None cp_kv_recover_idx_for_chunk: torch.Tensor = None @@ -81,24 +85,36 @@ class AscendCommonAttentionMetadata(CommonAttentionMetadata): For many of the tensors we keep both NPU and CPU versions. """ + # CPU tensor of sequence lengths for host-side operations. + # E.g., tensor([128, 256, 64]) for 3 requests with different seq lengths. seq_lens_cpu: torch.Tensor = None + + # CPU tensor of already computed tokens count per request. + # E.g., tensor([100, 200, 50]) means req0 has 100 tokens already computed. num_computed_tokens_cpu: torch.Tensor = None + # Number of decode tokens per request, used for speculative decoding. + # E.g., 1 for normal decoding, >1 for speculative decoding. decode_token_per_req: int = 1 - """decode token number per request""" + # Actual query sequence lengths for each token in the batch (CPU list). + # E.g., [1, 1, 1, 128] for 3 decode tokens and 1 prefill with 128 tokens. actual_seq_lengths_q: list[int] = field(default_factory=list) + # NPU tensor of position indices for rotary embeddings computation. + # E.g., tensor([0, 1, 2, ...]) indicating token positions in sequence. positions: torch.Tensor = None + # Current attention state (e.g., ChunkedPrefill, DecodeOnly). attn_state: Any = None + # Padding size for graph capture, -1 means not in graph mode. graph_pad_size: int = -1 - # num_input_tokens refers to total number of tokens including - # padding tokens. It is used to handle some padding operations. + # Total number of tokens including padding, used for padding operations. num_input_tokens: int = 0 + # Metadata for Prefill Context Parallelism (PCP) operations. prefill_context_parallel_metadata: Optional[ AscendPrefillContextParallelMetadata] = None