diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py
index fa2a528..1551565 100644
--- a/vllm_ascend/attention/attention_v1.py
+++ b/vllm_ascend/attention/attention_v1.py
@@ -121,28 +121,40 @@ class AscendAttentionState(Enum):
 
 @dataclass
 class AscendMetadata:
-    num_actual_tokens: int  # Number of tokens excluding padding.
-    # (batch_size, max_blocks_per_seq).
-    # Block addresses per sequence. (Seq id -> list of physical block)
-    block_tables: torch.Tensor
-    # (batch_size,). The sequence length per sequence. Sequence length means
-    # the computed tokens + new tokens None if it is a decoding.
-    query_start_loc: torch.Tensor
-    query_lens: torch.Tensor
-    seq_lens: torch.Tensor
-    # max value of number of tokens across dp group
-    max_num_tokens_across_dp: int = 0
-    # Maximum query length in the batch. None for decoding.
-    max_query_len: Optional[int] = None
-    # (num_tokens,). The indices of the token slots that input tokens will be
-    # stored into. E.g., if `slot_mapping` is [35, 2, 17] and the block size
-    # is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot
-    # in block 0, and 1st slot in block 1, respectively.
-    slot_mapping: torch.Tensor = None
+    # **************************** Basic Properties ****************************
+    attn_mask: Optional[torch.Tensor] = None
     # Current state of this attention run.
     attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill
-    attn_mask: Optional[torch.Tensor] = None
+
+    # Number of tokens excluding padding.
+    num_actual_tokens: int = 0
+
+    # The sequence length per sequence. Sequence length means the computed
+    # tokens + new tokens (is None if it is a decoding).
+    # (batch_size,)
+    seq_lens: torch.Tensor = None
+
+    query_start_loc: torch.Tensor = None
+    query_lens: torch.Tensor = None
+    # Maximum query length in the batch (None for decoding).
+    max_query_len: Optional[int] = None
+
+    # ********************** KV Cache Related Properties ***********************
+    # Block addresses per sequence (Seq id -> list of physical block).
+    # (batch_size, max_blocks_per_seq)
+    block_tables: torch.Tensor = None
+
+    # The indices of the token slots that input tokens will be stored into.
+    # E.g., if `slot_mapping` is [35, 2, 17] and the block size is 16, the
+    # three tokens are stored in the 3rd slot in block 2, 2nd slot in block 0,
+    # and 1st slot in block 1, respectively.
+    # (num_tokens,)
+    slot_mapping: torch.Tensor = None
+
+    # ************************* DP Related Properties **************************
     with_prefill_across_dp: bool = False
+    # Maximum number of tokens across dp group
+    max_num_tokens_across_dp: int = 0
 
 
 class AscendAttentionMetadataBuilder: