[Refactor] Add comments for Metadata classes in attention module (#5789)
### What this PR does / why we need it? Add docstrings for Metadata and MetadataBuilder classes in the attention module to improve code readability. Related to #5463 (Item 11: Add some comments for CommonMetadata and others) **Modified files:** - `vllm_ascend/attention/context_parallel/common_cp.py`: Added comments for `AscendPCPMetadata`, `CPChunkedContextMetadata`, `AscendMetadataForPrefill`, `AscendMetadataForDecode` - `vllm_ascend/attention/utils.py`: Added comments for `AscendPrefillContextParallelMetadata` - `vllm_ascend/attention/mla_v1.py`: Added comments for `ChunkedContextMetadata`, `AscendMLADecodeMetadata` - `vllm_ascend/attention/attention_v1.py`: Added comments for `AscendMetadata`, `AscendAttentionMetadataBuilder` - `vllm_ascend/attention/context_parallel/attention_cp.py`: Added comments for `AscendAttentionCPMetadataBuilder` ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Documentation only, no functional changes. Signed-off-by: lico67373 <918688502@qq.com>
This commit is contained in:
@@ -45,7 +45,11 @@ from vllm_ascend.utils import cp_chunkedprefill_comm_stream, weak_ref_tensors
|
||||
|
||||
|
||||
class AscendAttentionCPMetadataBuilder(AscendAttentionMetadataBuilder):
|
||||
# AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
|
||||
"""
|
||||
Builder for constructing AscendMetadata with Context Parallelism support.
|
||||
|
||||
Extends AscendAttentionMetadataBuilder with PCP/DCP metadata handling.
|
||||
"""
|
||||
# Does this backend/builder reorder the batch?
|
||||
# If not, set this to None. Otherwise set it to the query
|
||||
# length that will be pulled into the front of the batch.
|
||||
|
||||
@@ -11,6 +11,12 @@ from vllm.distributed import (get_dcp_group,
|
||||
|
||||
@dataclass
|
||||
class AscendPCPMetadata:
|
||||
"""
|
||||
Metadata for Prefill Context Parallelism (PCP) on Ascend devices.
|
||||
|
||||
Stores index tensors and sequence lengths for routing attention
|
||||
computations across PCP ranks during long sequence processing.
|
||||
"""
|
||||
q_head_idx: torch.Tensor = None
|
||||
q_tail_idx: torch.Tensor = None
|
||||
kv_with_q_head_nomask_idx: torch.Tensor = None
|
||||
@@ -26,7 +32,11 @@ class AscendPCPMetadata:
|
||||
|
||||
@dataclass
|
||||
class CPChunkedContextMetadata:
|
||||
# New for MLA (compared to FlashAttention)
|
||||
"""
|
||||
Metadata for chunked context handling in Context Parallelism (CP).
|
||||
|
||||
Extends chunked prefill with per-rank chunk information for PCP/DCP.
|
||||
"""
|
||||
# For handling chunked prefill
|
||||
cu_seq_lens: torch.Tensor
|
||||
starts: torch.Tensor
|
||||
@@ -46,9 +56,11 @@ class CPChunkedContextMetadata:
|
||||
|
||||
@dataclass
|
||||
class AscendMetadataForPrefill:
|
||||
""" Prefill-specific metadata for Ascend attention with Context Parallelism."""
|
||||
|
||||
@dataclass
|
||||
class ChunkedContextMetadata:
|
||||
"""Metadata for chunked context processing within prefill phase."""
|
||||
actual_chunk_seq_lengths: torch.Tensor
|
||||
actual_seq_lengths_kv: torch.Tensor
|
||||
starts: torch.Tensor
|
||||
@@ -69,7 +81,7 @@ class AscendMetadataForPrefill:
|
||||
|
||||
@dataclass
|
||||
class AscendMetadataForDecode:
|
||||
""" Decode Specific Metadata for Ascend"""
|
||||
""" Decode-specific metadata for Ascend attention with Context Parallelism."""
|
||||
num_computed_tokens_of_pcp_dcp: Optional[list[list[list[int]]]] = None
|
||||
batch_seq_mask: torch.Tensor = None
|
||||
block_tables: torch.Tensor = None
|
||||
|
||||
Reference in New Issue
Block a user