diff --git a/tests/ut/attention/test_attention_v1.py b/tests/ut/attention/test_attention_v1.py index 556c8d7..b742897 100644 --- a/tests/ut/attention/test_attention_v1.py +++ b/tests/ut/attention/test_attention_v1.py @@ -341,36 +341,6 @@ class TestAscendAttentionBackendImpl(TestBase): mock_flash_attention.assert_called_once() assert output.shape == (10, 8 * 64) - @patch('torch_npu._npu_reshape_and_cache') - @patch('torch_npu._npu_flash_attention') - def test_forward_prefill_no_cache_swa(self, mock_flash_attention, - mock_reshape_cache): - """Test forward pass in PrefillNoCache state""" - query = torch.randn(10, 8 * 64) - key = torch.randn(10, 8 * 64) - value = torch.randn(10, 8 * 64) - kv_cache = torch.empty(2, 5, 128, 8, 64) - metadata = self.attn_metadata - metadata.attn_state = AscendAttentionState.PrefillNoCache - metadata.attn_mask = torch.randn(1, 1, 10, 10) - metadata.seq_lens = torch.tensor([10]) - metadata.num_actual_tokens = 10 - metadata.slot_mapping = torch.zeros(10, dtype=torch.long) - layer = self.layer_no_quant - # layer.quant_method.apply.return_value = metadata - print(self.layer_no_quant._v_scale_float) - output = self.impl_swa.forward(layer, - query, - key, - value, - kv_cache, - metadata, - trace_flag=False) - - mock_reshape_cache.assert_called_once() - mock_flash_attention.assert_called_once() - assert output.shape == (10, 8 * 64) - @patch('torch_npu._npu_reshape_and_cache') @patch('torch_npu._npu_flash_attention_qlens') def test_forward_prefill_cache_hit(self, mock_flash_attention_qlens, diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py index 5460b94..0915cc3 100644 --- a/vllm_ascend/attention/attention_v1.py +++ b/vllm_ascend/attention/attention_v1.py @@ -265,20 +265,6 @@ class AscendAttentionBackendImpl(AttentionImpl): self.key_cache = None self.value_cache = None - def _repeat_kv(self, hidden_states: torch.Tensor, - n_rep: int) -> torch.Tensor: - """ - This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, - num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) - """ - num_key_value_heads, slen, head_dim = hidden_states.shape - if n_rep == 1: - return hidden_states - hidden_states = hidden_states[:, None, :, :].expand( - num_key_value_heads, n_rep, slen, head_dim) - return hidden_states.reshape(num_key_value_heads * n_rep, slen, - head_dim) - def _forward_prefill_no_cache( self, query: torch.Tensor, @@ -304,34 +290,15 @@ class AscendAttentionBackendImpl(AttentionImpl): mask = torch_npu.npu_format_cast(mask.contiguous(), ACL_FORMAT_FRACTAL_NZ) - if self.sliding_window is not None and \ - attn_metadata.attn_mask.shape[0] > self.sliding_window: - - key = self._repeat_kv(key, self.num_heads // self.num_kv_heads) - value = self._repeat_kv(value, self.num_heads // self.num_kv_heads) - - output, _ = torch_npu.npu_fused_infer_attention_score( - query, - key, - value, - num_heads=self.num_heads, - num_key_value_heads=self.num_kv_heads, - input_layout="TND", - pre_tokens=self.sliding_window, - scale=self.scale, - actual_seq_lengths=attn_metadata.seq_lens, - actual_seq_lengths_kv=attn_metadata.seq_lens) - output = output.view(num_tokens, self.num_heads, self.head_size) - else: - torch_npu._npu_flash_attention(query=query, - key=key, - value=value, - mask=mask, - seq_len=attn_metadata.seq_lens, - scale_value=self.scale, - num_heads=self.num_heads, - num_kv_heads=self.num_kv_heads, - out=output) + torch_npu._npu_flash_attention(query=query, + key=key, + value=value, + mask=mask, + seq_len=attn_metadata.seq_lens, + scale_value=self.scale, + num_heads=self.num_heads, + num_kv_heads=self.num_kv_heads, + out=output) assert output is not None return output[:num_tokens, :, :]