diff --git a/vllm_ascend/attention/attention_mask.py b/vllm_ascend/attention/attention_mask.py
index b1da7232..5a94b82b 100644
--- a/vllm_ascend/attention/attention_mask.py
+++ b/vllm_ascend/attention/attention_mask.py
@@ -68,6 +68,8 @@ class AttentionMaskBuilder:
 
     def get_attn_mask(self, max_seq_len: int, dtype: torch.dtype,
                       device: torch.device):
+        if max_seq_len == 2048 and torch.version.cann.startswith("8.3"):
+            return self.chunked_prefill_attn_mask.to(torch.bool)
         self._update_attn_cache(max_seq_len, dtype)
         return self.attn_mask_cache[:max_seq_len, :max_seq_len].contiguous(
         ).to(device, non_blocking=True)
diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py
index 62bca309..07ef6d91 100644
--- a/vllm_ascend/attention/attention_v1.py
+++ b/vllm_ascend/attention/attention_v1.py
@@ -491,19 +491,44 @@ class AscendAttentionBackendImpl(AttentionImpl):
         compress_mask = attn_metadata.attn_mask
         batch_size = attn_metadata.query_lens.shape[0]
         block_table = attn_metadata.block_tables[:batch_size, :]
+        num_block, block_size, _, _ = self.key_cache.shape  # type: ignore
 
-        torch_npu._npu_flash_attention_qlens(
-            query=query,
-            key_cache=self.key_cache,
-            value_cache=self.value_cache,
-            block_table=block_table,
-            mask=compress_mask,
-            seq_len=attn_metadata.query_lens,
-            context_lens=attn_metadata.seq_lens,
-            num_kv_heads=self.num_kv_heads,
-            num_heads=self.num_heads,
-            scale_value=self.scale,
-            out=output)
+        if torch.version.cann.startswith("8.3") and block_size == 128:
+            # TODO: The npu_fused_infer_attention_score op is planned to
+            # be utilized in a wider range in upcoming versions.
+            key = self.key_cache.view(  # type: ignore
+                num_block, block_size, -1)
+            value = self.value_cache.view(  # type: ignore
+                num_block, block_size, -1)
+
+            output, _ = torch_npu.npu_fused_infer_attention_score(
+                query=query,
+                key=key,
+                value=value,
+                atten_mask=compress_mask,
+                block_table=block_table,
+                input_layout="TND",
+                block_size=block_size,
+                actual_seq_lengths=attn_metadata.actual_seq_lengths_q,
+                actual_seq_lengths_kv=attn_metadata.seq_lens_list,
+                num_key_value_heads=self.num_kv_heads,
+                num_heads=self.num_heads,
+                scale=self.scale,
+                sparse_mode=3,
+            )
+        else:
+            torch_npu._npu_flash_attention_qlens(
+                query=query,
+                key_cache=self.key_cache,
+                value_cache=self.value_cache,
+                block_table=block_table,
+                mask=compress_mask,
+                seq_len=attn_metadata.query_lens,
+                context_lens=attn_metadata.seq_lens,
+                num_kv_heads=self.num_kv_heads,
+                num_heads=self.num_heads,
+                scale_value=self.scale,
+                out=output)
         return output
 
     def _forward_decode_only(
diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index f30a9a39..bf0465c8 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -962,8 +962,12 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                 max_seq_len, self.dtype, self.device)
         # Prefill with cache hit.
         elif attn_state == AscendAttentionState.PrefillCacheHit:
-            return self.attn_mask_builder.get_attn_mask(
-                128, self.dtype, self.device)
+            if torch.version.cann.startswith("8.3"):
+                return self.attn_mask_builder.get_attn_mask(
+                    2048, self.dtype, self.device)
+            else:
+                return self.attn_mask_builder.get_attn_mask(
+                    128, self.dtype, self.device)
         # Decode-only situation.
         else:
             return None
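
Note on the version gate used throughout the patch: all three hunks branch on torch.version.cann.startswith("8.3"), the CANN toolkit version string that torch_npu exposes on Ascend builds. A minimal standalone sketch of that check is shown below; the helper name is_cann_83 and the getattr fallback (for environments where torch.version has no cann attribute) are illustrative assumptions, not part of the diff.

    import torch

    def is_cann_83() -> bool:
        # CANN toolkit version string reported by torch_npu; may be absent
        # or None on non-Ascend builds, hence the getattr guard.
        cann_version = getattr(torch.version, "cann", None)
        return bool(cann_version) and cann_version.startswith("8.3")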