[Refactor] Add comments for Metadata classes in attention module (#5789)
### What this PR does / why we need it?

Add docstrings for Metadata and MetadataBuilder classes in the attention module to improve code readability.

Related to #5463 (Item 11: Add some comments for CommonMetadata and others)

**Modified files:**

- `vllm_ascend/attention/context_parallel/common_cp.py`: Added comments for `AscendPCPMetadata`, `CPChunkedContextMetadata`, `AscendMetadataForPrefill`, `AscendMetadataForDecode`
- `vllm_ascend/attention/utils.py`: Added comments for `AscendPrefillContextParallelMetadata`
- `vllm_ascend/attention/mla_v1.py`: Added comments for `ChunkedContextMetadata`, `AscendMLADecodeMetadata`
- `vllm_ascend/attention/attention_v1.py`: Added comments for `AscendMetadata`, `AscendAttentionMetadataBuilder`
- `vllm_ascend/attention/context_parallel/attention_cp.py`: Added comments for `AscendAttentionCPMetadataBuilder`

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Documentation only, no functional changes.

Signed-off-by: lico67373 <918688502@qq.com>
@@ -36,8 +36,12 @@ def enable_cp():


 @dataclass
-# class AscendCommonLongSequenceMetadata:
 class AscendPrefillContextParallelMetadata:
+    """
+    Metadata for Prefill Context Parallelism (PCP) in CommonAttentionMetadata.
+
+    Contains index tensors and sequence lengths for PCP operations.
+    """
     pcp_allgather_restore_idx: torch.Tensor = None

     cp_kv_recover_idx_for_chunk: torch.Tensor = None
@@ -81,24 +85,36 @@ class AscendCommonAttentionMetadata(CommonAttentionMetadata):

     For many of the tensors we keep both NPU and CPU versions.
     """
+    # CPU tensor of sequence lengths for host-side operations.
+    # E.g., tensor([128, 256, 64]) for 3 requests with different seq lengths.
     seq_lens_cpu: torch.Tensor = None
+
+    # CPU tensor of already computed tokens count per request.
+    # E.g., tensor([100, 200, 50]) means req0 has 100 tokens already computed.
     num_computed_tokens_cpu: torch.Tensor = None

+    # Number of decode tokens per request, used for speculative decoding.
+    # E.g., 1 for normal decoding, >1 for speculative decoding.
     decode_token_per_req: int = 1
-    """decode token number per request"""

+    # Actual query sequence lengths for each token in the batch (CPU list).
+    # E.g., [1, 1, 1, 128] for 3 decode tokens and 1 prefill with 128 tokens.
     actual_seq_lengths_q: list[int] = field(default_factory=list)

+    # NPU tensor of position indices for rotary embeddings computation.
+    # E.g., tensor([0, 1, 2, ...]) indicating token positions in sequence.
     positions: torch.Tensor = None

+    # Current attention state (e.g., ChunkedPrefill, DecodeOnly).
     attn_state: Any = None

+    # Padding size for graph capture, -1 means not in graph mode.
     graph_pad_size: int = -1

-    # num_input_tokens refers to total number of tokens including
-    # padding tokens. It is used to handle some padding operations.
+    # Total number of tokens including padding, used for padding operations.
     num_input_tokens: int = 0

+    # Metadata for Prefill Context Parallelism (PCP) operations.
     prefill_context_parallel_metadata: Optional[
         AscendPrefillContextParallelMetadata] = None

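To make the field comments above concrete, here is a minimal sketch built from the example values they quote. `MetadataSketch` and its two helper methods are hypothetical stand-ins written for illustration only; they are not part of `vllm_ascend`, and plain Python lists replace the CPU tensors so the snippet runs without `torch`. The relation "sequence length = already computed tokens + tokens handled in this step" is an assumption about how the two fields relate, not something this commit states.

```python
from dataclasses import dataclass, field


@dataclass
class MetadataSketch:
    """Hypothetical stand-in for illustration; not the vllm_ascend class."""

    # Plain lists stand in for the CPU tensors described in the diff.
    seq_lens_cpu: list[int] = field(default_factory=list)
    num_computed_tokens_cpu: list[int] = field(default_factory=list)
    decode_token_per_req: int = 1
    graph_pad_size: int = -1

    def new_tokens_per_request(self) -> list[int]:
        # Assumption: seq_len = already computed tokens + tokens in this step.
        return [
            seq_len - computed
            for seq_len, computed in zip(self.seq_lens_cpu,
                                         self.num_computed_tokens_cpu)
        ]

    def in_graph_mode(self) -> bool:
        # Per the field comment, -1 means "not in graph mode".
        return self.graph_pad_size != -1


# Example values taken from the comments added in this commit.
meta = MetadataSketch(
    seq_lens_cpu=[128, 256, 64],
    num_computed_tokens_cpu=[100, 200, 50],
)
print(meta.new_tokens_per_request())  # [28, 56, 14]
print(meta.in_graph_mode())  # False
```

Under that assumption, the three example requests would process 28, 56, and 14 new tokens in the current step, and the default `graph_pad_size` of -1 marks the batch as running outside graph mode.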