diff --git a/tests/ut/attention/test_attention_v1.py b/tests/ut/attention/test_attention_v1.py index 94a34e9..54be152 100644 --- a/tests/ut/attention/test_attention_v1.py +++ b/tests/ut/attention/test_attention_v1.py @@ -435,6 +435,41 @@ class TestAscendAttentionBackendImpl(TestBase): mock_fused_infer_attention_score.assert_called_once() assert output.shape == (10, 8 * 64) + @patch('torch_npu._npu_reshape_and_cache') + @patch('torch_npu._npu_paged_attention') + @patch('torch_npu.npu_fused_infer_attention_score') + def test_forward_decode_only_swa_seq_len_mismatch( + self, mock_fused_infer_attention_score, mock_paged_attention, + mock_npu_reshape_and_cache): + """Test forward pass in DecodeOnly state when seq)len_mismatch""" + query = torch.randn(10, 8 * 64) + key = torch.randn(10, 8 * 64) + value = torch.randn(10, 8 * 64) + kv_cache = torch.empty(2, 5, 128, 8, 64) + + metadata = self.attn_metadata + metadata.attn_state = AscendAttentionState.DecodeOnly + metadata.seq_lens = torch.tensor([10]) # len == 1 != query.size(0)==10 + metadata.block_tables = torch.zeros(1, 5, dtype=torch.long) + metadata.num_actual_tokens = 10 + metadata.slot_mapping = torch.zeros(10, dtype=torch.long) + + mock_fused_infer_attention_score.return_value = (torch.ones(10, 8, + 64), 1) + + output = self.impl_swa.forward(self.layer_no_quant, + query, + key, + value, + kv_cache, + metadata, + trace_flag=False) + + mock_paged_attention.assert_called_once() + mock_fused_infer_attention_score.assert_not_called() + + assert output.shape == (10, 8 * 64) + @patch('vllm_ascend.attention.attention_v1.is_310p', return_value=False) @patch('torch_npu._npu_reshape_and_cache') @patch('vllm_ascend.attention.attention_v1.vanilla_chunked_prefill') diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py index e0905fa..10a2f6a 100644 --- a/vllm_ascend/attention/attention_v1.py +++ b/vllm_ascend/attention/attention_v1.py @@ -378,7 +378,8 @@ class AscendAttentionBackendImpl(AttentionImpl): # seq_lens_tensor needs to be transferred to the device for 310P. attn_metadata.seq_lens = \ attn_metadata.seq_lens.to(device=query.device) - if self.sliding_window is not None: + if self.sliding_window is not None and attn_metadata.seq_lens.shape[ + 0] == query.size(0): batch_size = attn_metadata.seq_lens.shape[0] block_size = 128 query = query.view(batch_size, 1, self.num_heads * self.head_size)