From 55d0790597af72d996b9c1ad5a6592f334c646b9 Mon Sep 17 00:00:00 2001
From: Shanshan Shen <467638484@qq.com>
Date: Thu, 14 Aug 2025 09:32:41 +0800
Subject: [PATCH] [2/N][Refactor] Refactor V1 attention for better
 extensibility (#1995)

### What this PR does / why we need it?
Refactor V1 attention for better extensibility, in preparation for the
torchair attention refactor.

**Main changes:**
- Move each kind of forward pass into its own method, e.g.
  `_forward_prefill_no_cache()`, `_forward_prefill_cache_hit()`,
  `_forward_decode_only()`, `_forward_v1_style()`, so that `forward()`
  only dispatches on the attention state (see the sketch below).

### Does this PR introduce _any_ user-facing change?
No.

- vLLM version: v0.10.0
- vLLM main: https://github.com/vllm-project/vllm/commit/14a5d903ab826b723a24a2d89631006394de76a1
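For illustration, here is a condensed sketch of the resulting dispatch in
`forward()`; the helper names and call shapes are taken directly from the
diff below, with the surrounding setup and teardown elided:

```python
# Condensed from the diff below: forward() now only routes on attn_state.
if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
    output = self._forward_prefill_no_cache(query, key, value, attn_metadata,
                                            output, num_tokens)
elif attn_metadata.attn_state == AscendAttentionState.PrefillCacheHit:
    output = self._forward_prefill_cache_hit(query, attn_metadata, output)
elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
    output = self._forward_decode_only(query, attn_metadata, output)
else:
    # Normal V1 situation (chunked prefill / splitfuse).
    output = self._forward_v1_style(query, attn_metadata, output)
```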
Signed-off-by: shen-shanshan <467638484@qq.com>
---
 vllm_ascend/attention/attention_v1.py | 252 +++++++++++++++-----------
 1 file changed, 150 insertions(+), 102 deletions(-)

diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py
index 6e031ea..15a7759 100644
--- a/vllm_ascend/attention/attention_v1.py
+++ b/vllm_ascend/attention/attention_v1.py
@@ -120,7 +120,7 @@ class AscendAttentionState(Enum):
 
 @dataclass
 class AscendMetadata:
-    # **************************** Basic Properties ****************************
+    # **************************** Basic Properties ************************** #
     attn_mask: Optional[torch.Tensor] = None
     # Current state of this attention run.
     attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill
@@ -138,7 +138,7 @@ class AscendMetadata:
     # Maximum query length in the batch (None for decoding).
     max_query_len: Optional[int] = None
 
-    # ********************** KV Cache Related Properties ***********************
+    # ********************** KV Cache Related Properties ********************* #
     # Block addresses per sequence (Seq id -> list of physical block).
     # (batch_size, max_blocks_per_seq)
     block_tables: torch.Tensor = None
@@ -150,6 +150,7 @@ class AscendMetadata:
     # (num_tokens,)
     slot_mapping: torch.Tensor = None
 
+    # *************************** Other Properties *************************** #
     enable_dbo_across_dp: bool = False
 
     is_only_prefill: bool = False
@@ -245,6 +246,144 @@ class AscendAttentionBackendImpl(AttentionImpl):
         self.key_cache = None
         self.value_cache = None
 
+    def _forward_prefill_no_cache(
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        attn_metadata: AscendMetadata,
+        output: Optional[torch.Tensor] = None,
+        num_tokens=0,
+    ) -> torch.Tensor:
+        assert attn_metadata is not None
+        assert attn_metadata.attn_mask is not None
+
+        mask = attn_metadata.attn_mask
+
+        if is_310p():
+            # Align q, k, v and output tensors.
+            query = aligned_16(query)
+            key = aligned_16(key)
+            value = aligned_16(value)
+            output = aligned_16(output)
+            # Do reformat in case of broadcasted tensors.
+            mask = mask.repeat(attn_metadata.seq_lens.size(0), 1, 1, 1)
+            mask = torch_npu.npu_format_cast(mask.contiguous(),
+                                             ACL_FORMAT_FRACTAL_NZ)
+
+        torch_npu._npu_flash_attention(query=query,
+                                       key=key,
+                                       value=value,
+                                       mask=mask,
+                                       seq_len=attn_metadata.seq_lens,
+                                       scale_value=self.scale,
+                                       num_heads=self.num_heads,
+                                       num_kv_heads=self.num_kv_heads,
+                                       out=output)
+        assert output is not None
+        return output[:num_tokens, :, :]
+
+    def _forward_prefill_cache_hit(
+        self,
+        query: torch.Tensor,
+        attn_metadata: AscendMetadata,
+        output: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        assert attn_metadata is not None
+        assert attn_metadata.attn_mask is not None
+
+        compress_mask = attn_metadata.attn_mask
+        batch_size = attn_metadata.query_lens.shape[0]
+        block_table = attn_metadata.block_tables[:batch_size, :]
+
+        torch_npu._npu_flash_attention_qlens(
+            query=query,
+            key_cache=self.key_cache,
+            value_cache=self.value_cache,
+            block_table=block_table,
+            mask=compress_mask,
+            seq_len=attn_metadata.query_lens,
+            context_lens=attn_metadata.seq_lens,
+            num_kv_heads=self.num_kv_heads,
+            num_heads=self.num_heads,
+            scale_value=self.scale,
+            out=output)
+        return output
+
+    def _forward_decode_only(
+        self,
+        query: torch.Tensor,
+        attn_metadata: AscendMetadata,
+        output: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        if is_310p():
+            # seq_lens_tensor needs to be transferred to the device for 310P.
+            attn_metadata.seq_lens = \
+                attn_metadata.seq_lens.to(device=query.device)
+
+        torch_npu._npu_paged_attention(query=query,
+                                       key_cache=self.key_cache,
+                                       value_cache=self.value_cache,
+                                       num_kv_heads=self.num_kv_heads,
+                                       num_heads=self.num_heads,
+                                       scale_value=self.scale,
+                                       block_table=attn_metadata.block_tables,
+                                       context_lens=attn_metadata.seq_lens,
+                                       out=output)
+        return output
+
+    def _forward_v1_style(
+        self,
+        query: torch.Tensor,
+        attn_metadata: AscendMetadata,
+        output: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        # Use chunked prefill for the head_size 192 scenario (e.g. deepseek);
+        # paged_attention_splitfuse may crash in that scenario.
+        # TODO: The vanilla path will be removed once the kernel supports
+        # head_size 192.
+        if self.head_size == 192:
+            cu_seqlen_q = [0] + attn_metadata.query_lens.tolist()
+            cu_seqlen_k = [0] + attn_metadata.seq_lens.tolist()
+            cu_seqlen_q = torch.tensor(cu_seqlen_q, device=query.device)
+            cu_seqlen_k = torch.tensor(cu_seqlen_k, device=query.device)
+            cu_seqlen_q = torch.cumsum(cu_seqlen_q, dim=0)
+            cu_seqlen_k = torch.cumsum(cu_seqlen_k, dim=0)
+            max_seqlen_q = torch.max(attn_metadata.query_lens)
+            max_seqlen_k = torch.max(attn_metadata.seq_lens)
+            vanilla_chunked_prefill(output, query, self.key_cache,
+                                    self.value_cache,
+                                    attn_metadata.block_tables, cu_seqlen_q,
+                                    cu_seqlen_k, max_seqlen_q, max_seqlen_k,
+                                    self.scale, None, True)
+            return output
+
+        # Use paged attention.
+        assert attn_metadata is not None
+        assert attn_metadata.attn_mask is not None
+
+        if is_310p():
+            # Do reformat in case of broadcasted tensors.
+            attn_metadata.attn_mask = \
+                torch_npu.npu_format_cast(attn_metadata.attn_mask.contiguous(),
+                                          ACL_FORMAT_FRACTAL_NZ)
+            attn_metadata.seq_lens = \
+                attn_metadata.seq_lens.to(device=query.device)
+
+        torch_npu._npu_paged_attention_splitfuse(
+            query=query,
+            key_cache=self.key_cache,
+            value_cache=self.value_cache,
+            mask=attn_metadata.attn_mask,
+            block_table=attn_metadata.block_tables,
+            seq_len=attn_metadata.query_lens,
+            context_lens=attn_metadata.seq_lens,
+            num_kv_heads=self.num_kv_heads,
+            num_heads=self.num_heads,
+            scale_value=self.scale,
+            out=output)
+        return output
+
     def forward(
         self,
         layer: AttentionLayer,
@@ -325,109 +464,18 @@ class AscendAttentionBackendImpl(AttentionImpl):
 
         # V0-Style scheduler situation.
         if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
-            assert attn_metadata is not None
-            assert attn_metadata.attn_mask is not None
-            mask = attn_metadata.attn_mask
-            if is_310p():
-                # align q k v output tensors
-                query = aligned_16(query)
-                key = aligned_16(key)
-                value = aligned_16(value)
-                output = aligned_16(output)
-
-                # do reformat in case of broadcasted tensors
-                mask = mask.repeat(attn_metadata.seq_lens.size(0), 1, 1, 1)
-                mask = torch_npu.npu_format_cast(mask.contiguous(),
-                                                 ACL_FORMAT_FRACTAL_NZ)
-
-            torch_npu._npu_flash_attention(query=query,
-                                           key=key,
-                                           value=value,
-                                           mask=mask,
-                                           seq_len=attn_metadata.seq_lens,
-                                           scale_value=self.scale,
-                                           num_heads=self.num_heads,
-                                           num_kv_heads=self.num_kv_heads,
-                                           out=output)
-            output = output[:num_tokens, :, :]
-        elif attn_metadata.attn_state == AscendAttentionState.PrefillCacheHit:
-            assert attn_metadata is not None
-            assert attn_metadata.attn_mask is not None
-            compress_mask = attn_metadata.attn_mask
-            batch_size = attn_metadata.query_lens.shape[0]
-            block_table = attn_metadata.block_tables[:batch_size, :]
-            torch_npu._npu_flash_attention_qlens(
-                query=query,
-                key_cache=self.key_cache,
-                value_cache=self.value_cache,
-                block_table=block_table,
-                mask=compress_mask,
-                seq_len=attn_metadata.query_lens,
-                context_lens=attn_metadata.seq_lens,
-                num_kv_heads=self.num_kv_heads,
-                num_heads=self.num_heads,
-                scale_value=self.scale,
-                out=output)
+            output = self._forward_prefill_no_cache(
+                query, key, value, attn_metadata, output, num_tokens)
+        elif attn_metadata.attn_state == \
+                AscendAttentionState.PrefillCacheHit:
+            output = self._forward_prefill_cache_hit(
+                query, attn_metadata, output)
         elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
-            if is_310p():
-                # # seq_lens_tensor needs to be transferred to the device for 310P
-                attn_metadata.seq_lens = \
-                    attn_metadata.seq_lens.to(device=query.device)
-            torch_npu._npu_paged_attention(
-                query=query,
-                key_cache=self.key_cache,
-                value_cache=self.value_cache,
-                num_kv_heads=self.num_kv_heads,
-                num_heads=self.num_heads,
-                scale_value=self.scale,
-                block_table=attn_metadata.block_tables,
-                context_lens=attn_metadata.seq_lens,
-                out=output)
+            output = self._forward_decode_only(query, attn_metadata,
+                                               output)
         # Normal V1 situation.
         else:
-            # use chunked prefill for head size 192 scenario, like deepseek
-            # paged_attention_splitfuse maybe crash at such scenario
-            # TODO: vanilla path will be removed after the kernel support
-            # head_size 192 scenario
-            if self.head_size == 192:
-                cu_seqlen_q = [0] + attn_metadata.query_lens.tolist()
-                cu_seqlen_k = [0] + attn_metadata.seq_lens.tolist()
-                cu_seqlen_q = torch.tensor(cu_seqlen_q,
-                                           device=query.device)
-                cu_seqlen_k = torch.tensor(cu_seqlen_k,
-                                           device=query.device)
-                cu_seqlen_q = torch.cumsum(cu_seqlen_q, dim=0)
-                cu_seqlen_k = torch.cumsum(cu_seqlen_k, dim=0)
-                max_seqlen_q = torch.max(attn_metadata.query_lens)
-                max_seqlen_k = torch.max(attn_metadata.seq_lens)
-                vanilla_chunked_prefill(output, query, self.key_cache,
-                                        self.value_cache,
-                                        attn_metadata.block_tables,
-                                        cu_seqlen_q, cu_seqlen_k,
-                                        max_seqlen_q, max_seqlen_k,
-                                        self.scale, None, True)
-            else:
-                # use paged attention
-                assert attn_metadata is not None
-                assert attn_metadata.attn_mask is not None
-                if is_310p():
-                    # do reformat in case of broadcasted tensors
-                    attn_metadata.attn_mask = \
-                        torch_npu.npu_format_cast(attn_metadata.attn_mask.contiguous(), ACL_FORMAT_FRACTAL_NZ)
-                    attn_metadata.seq_lens = \
-                        attn_metadata.seq_lens.to(device=query.device)
-                torch_npu._npu_paged_attention_splitfuse(
-                    query=query,
-                    key_cache=self.key_cache,
-                    value_cache=self.value_cache,
-                    mask=attn_metadata.attn_mask,
-                    block_table=attn_metadata.block_tables,
-                    seq_len=attn_metadata.query_lens,
-                    context_lens=attn_metadata.seq_lens,
-                    num_kv_heads=self.num_kv_heads,
-                    num_heads=self.num_heads,
-                    scale_value=self.scale,
-                    out=output)
+            output = self._forward_v1_style(query, attn_metadata, output)
 
         # to make in-place change to the output tensor
         if hasattr(layer, 'quant_method') and use_kv_cache_int8:
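
As a closing note on the extensibility goal (this patch is preparation for the
torchair attention refactor): with the per-state helpers split out, a future
backend can override a single path instead of re-implementing the whole
`forward()` dispatch. A minimal hypothetical sketch follows; the subclass name
and the decode kernel placeholder are assumptions for illustration, not part
of this patch:

```python
# Hypothetical sketch only: shows how the refactor is intended to be reused.
# The class name and the torchair decode call are assumptions, not this patch.
class TorchairLikeAttentionImpl(AscendAttentionBackendImpl):

    def _forward_decode_only(self, query, attn_metadata, output=None):
        # A torchair-specific decode kernel would be invoked here; the prefill
        # paths and the attn_state dispatch in forward() are inherited as-is.
        raise NotImplementedError("placeholder for a torchair decode kernel")
```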