diff --git a/vllm_ascend/attention.py b/vllm_ascend/attention.py
index 8456cb8..2aa915c 100644
--- a/vllm_ascend/attention.py
+++ b/vllm_ascend/attention.py
@@ -21,6 +21,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type
 
 import numpy as np
 import torch
+from torch.nn.functional import scaled_dot_product_attention
 
 try:
     import torch_npu  # noqa: F401
@@ -715,6 +716,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
             value = value.view(-1, self.num_kv_heads, self.head_size)
             # TODO: Remove this contiguous in the future.
             value = value.contiguous()
+            attn_type = self.attn_type
 
             output = torch.empty(num_tokens,
                                  self.num_heads,
@@ -758,23 +760,50 @@ class AscendAttentionBackendImpl(AttentionImpl):
                 if (attn_metadata.block_tables is None
                         or attn_metadata.block_tables.numel() == 0):
-                    assert attn_metadata.attn_mask is not None
-                    mask = attn_metadata.attn_mask
-                    assert attn_metadata.prefill_metadata is not None
-                    self.seq_lens_tensor_cpu = torch.from_numpy(
-                        np.array(
-                            attn_metadata.prefill_metadata.seq_lens).astype(
-                                np.int32))
-                    torch_npu._npu_flash_attention(
-                        query=query,
-                        key=key,
-                        value=value,
-                        mask=mask,
-                        seq_len=self.seq_lens_tensor_cpu,
-                        scale_value=self.scale,
-                        num_heads=self.num_heads,
-                        num_kv_heads=self.num_kv_heads,
-                        out=output)
+                    if attn_type == AttentionType.ENCODER_ONLY:
+                        # TODO: change to use torch_npu encoder attention op, instead
+                        # of torch sdpa
+                        query = query.movedim(0, query.dim() - 2)
+                        key = key.movedim(0, key.dim() - 2)
+                        value = value.movedim(0, value.dim() - 2)
+
+                        causal_attn = (attn_type == AttentionType.DECODER)
+                        if attn_metadata.seq_lens is not None:
+                            seq_lens_q = seq_lens_kv = attn_metadata.seq_lens
+                            attn_masks = [None] * len(seq_lens_q)
+
+                            start_q, start_kv = 0, 0
+                            for seq_len_q, seq_len_kv, mask in zip(
+                                    seq_lens_q, seq_lens_kv, attn_masks):
+                                end_q = start_q + seq_len_q
+                                end_kv = start_kv + seq_len_kv
+                                sub_out = scaled_dot_product_attention(
+                                    query[None, :, start_q:end_q, :],
+                                    key[None, :, start_kv:end_kv, :],
+                                    value[None, :, start_kv:end_kv, :],
+                                    attn_mask=mask,
+                                    dropout_p=0.0,
+                                    is_causal=causal_attn and mask is None,
+                                    scale=self.scale).squeeze(0).movedim(
+                                        query.dim() - 2, 0)
+                                output[start_q:end_q, :, :] = sub_out
+                                start_q, start_kv = end_q, end_kv
+                    else:
+                        assert attn_metadata.attn_mask is not None
+                        mask = attn_metadata.attn_mask
+                        assert attn_metadata.prefill_metadata is not None
+                        self.seq_lens_tensor_cpu = torch.from_numpy(
+                            np.array(attn_metadata.prefill_metadata.seq_lens).
+                            astype(np.int32))
+                        torch_npu._npu_flash_attention(
+                            query=query,
+                            key=key,
+                            value=value,
+                            mask=mask,
+                            seq_len=self.seq_lens_tensor_cpu,
+                            scale_value=self.scale,
+                            num_heads=self.num_heads,
+                            num_kv_heads=self.num_kv_heads,
+                            out=output)
                 else:
                     # TODO: Will support prefix cache and chunked prefill soon.
                     raise RuntimeError(
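
Note: the ENCODER_ONLY branch above runs torch's scaled_dot_product_attention once per request over the packed prefill tokens instead of a single torch_npu kernel call. The standalone sketch below illustrates that per-sequence loop in isolation; the shapes, sequence lengths, and names (seq_lens, num_heads, head_size) are made up for illustration and are not taken from the patch.

# Standalone sketch of the per-sequence SDPA loop used in the ENCODER_ONLY path.
import torch
from torch.nn.functional import scaled_dot_product_attention

num_heads, head_size = 4, 32
seq_lens = [5, 7, 3]                      # hypothetical per-request lengths
num_tokens = sum(seq_lens)

# Packed layout used by the backend: (num_tokens, num_heads, head_size).
query = torch.randn(num_tokens, num_heads, head_size)
key = torch.randn(num_tokens, num_heads, head_size)
value = torch.randn(num_tokens, num_heads, head_size)
output = torch.empty_like(query)

# SDPA expects (..., heads, seq, head_size), so move the token dim next to last.
q = query.movedim(0, query.dim() - 2)
k = key.movedim(0, key.dim() - 2)
v = value.movedim(0, value.dim() - 2)

is_causal = False  # encoder-only is bidirectional; True would mimic a decoder
start = 0
for seq_len in seq_lens:
    end = start + seq_len
    # Slice out one request from the packed batch and attend within it only.
    sub_out = scaled_dot_product_attention(
        q[None, :, start:end, :],
        k[None, :, start:end, :],
        v[None, :, start:end, :],
        dropout_p=0.0,
        is_causal=is_causal,
        scale=head_size**-0.5).squeeze(0).movedim(q.dim() - 2, 0)
    output[start:end, :, :] = sub_out
    start = end

print(output.shape)  # torch.Size([15, 4, 32])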