[fix] prefill unsupport sliding window attention (#2758)
### What this PR does / why we need it?
Fix a prefill attention bug: sliding window attention is not supported.
`npu_fused_infer_attention_score` only supports head_dim equal to 128; other
values are not supported.
### Does this PR introduce _any_ user-facing change?
Removes the use of `npu_fused_infer_attention_score` from the prefill phase.
### How was this patch tested?
- vLLM version: v0.10.1.1
- vLLM main:
e599e2c65e
---------
Signed-off-by: nsdie <yeyifan@huawei.com>
This commit is contained in:
@@ -341,36 +341,6 @@ class TestAscendAttentionBackendImpl(TestBase):
|
|||||||
mock_flash_attention.assert_called_once()
|
mock_flash_attention.assert_called_once()
|
||||||
assert output.shape == (10, 8 * 64)
|
assert output.shape == (10, 8 * 64)
|
||||||
|
|
||||||
@patch('torch_npu._npu_reshape_and_cache')
@patch('torch_npu._npu_flash_attention')
def test_forward_prefill_no_cache_swa(self, mock_flash_attention,
                                      mock_reshape_cache):
    """Test the SWA impl's forward pass in the PrefillNoCache state.

    With a sliding window configured, prefill must fall back to
    torch_npu._npu_flash_attention (npu_fused_infer_attention_score only
    supports head_dim == 128), so both mocked kernels should fire once.
    """
    # 10 tokens, 8 heads x 64 head_dim flattened to the last axis.
    query = torch.randn(10, 8 * 64)
    key = torch.randn(10, 8 * 64)
    value = torch.randn(10, 8 * 64)
    kv_cache = torch.empty(2, 5, 128, 8, 64)

    metadata = self.attn_metadata
    metadata.attn_state = AscendAttentionState.PrefillNoCache
    metadata.attn_mask = torch.randn(1, 1, 10, 10)
    metadata.seq_lens = torch.tensor([10])
    metadata.num_actual_tokens = 10
    metadata.slot_mapping = torch.zeros(10, dtype=torch.long)
    layer = self.layer_no_quant

    output = self.impl_swa.forward(layer,
                                   query,
                                   key,
                                   value,
                                   kv_cache,
                                   metadata,
                                   trace_flag=False)

    # The cache-write kernel and the flash-attention kernel each run once.
    mock_reshape_cache.assert_called_once()
    mock_flash_attention.assert_called_once()
    assert output.shape == (10, 8 * 64)
|
|
||||||
|
|
||||||
@patch('torch_npu._npu_reshape_and_cache')
|
@patch('torch_npu._npu_reshape_and_cache')
|
||||||
@patch('torch_npu._npu_flash_attention_qlens')
|
@patch('torch_npu._npu_flash_attention_qlens')
|
||||||
def test_forward_prefill_cache_hit(self, mock_flash_attention_qlens,
|
def test_forward_prefill_cache_hit(self, mock_flash_attention_qlens,
|
||||||
|
|||||||
@@ -265,20 +265,6 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
|||||||
self.key_cache = None
|
self.key_cache = None
|
||||||
self.value_cache = None
|
self.value_cache = None
|
||||||
|
|
||||||
def _repeat_kv(self, hidden_states: torch.Tensor,
|
|
||||||
n_rep: int) -> torch.Tensor:
|
|
||||||
"""
|
|
||||||
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
|
|
||||||
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
|
|
||||||
"""
|
|
||||||
num_key_value_heads, slen, head_dim = hidden_states.shape
|
|
||||||
if n_rep == 1:
|
|
||||||
return hidden_states
|
|
||||||
hidden_states = hidden_states[:, None, :, :].expand(
|
|
||||||
num_key_value_heads, n_rep, slen, head_dim)
|
|
||||||
return hidden_states.reshape(num_key_value_heads * n_rep, slen,
|
|
||||||
head_dim)
|
|
||||||
|
|
||||||
def _forward_prefill_no_cache(
|
def _forward_prefill_no_cache(
|
||||||
self,
|
self,
|
||||||
query: torch.Tensor,
|
query: torch.Tensor,
|
||||||
@@ -304,34 +290,15 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
|||||||
mask = torch_npu.npu_format_cast(mask.contiguous(),
|
mask = torch_npu.npu_format_cast(mask.contiguous(),
|
||||||
ACL_FORMAT_FRACTAL_NZ)
|
ACL_FORMAT_FRACTAL_NZ)
|
||||||
|
|
||||||
if self.sliding_window is not None and \
|
torch_npu._npu_flash_attention(query=query,
|
||||||
attn_metadata.attn_mask.shape[0] > self.sliding_window:
|
key=key,
|
||||||
|
value=value,
|
||||||
key = self._repeat_kv(key, self.num_heads // self.num_kv_heads)
|
mask=mask,
|
||||||
value = self._repeat_kv(value, self.num_heads // self.num_kv_heads)
|
seq_len=attn_metadata.seq_lens,
|
||||||
|
scale_value=self.scale,
|
||||||
output, _ = torch_npu.npu_fused_infer_attention_score(
|
num_heads=self.num_heads,
|
||||||
query,
|
num_kv_heads=self.num_kv_heads,
|
||||||
key,
|
out=output)
|
||||||
value,
|
|
||||||
num_heads=self.num_heads,
|
|
||||||
num_key_value_heads=self.num_kv_heads,
|
|
||||||
input_layout="TND",
|
|
||||||
pre_tokens=self.sliding_window,
|
|
||||||
scale=self.scale,
|
|
||||||
actual_seq_lengths=attn_metadata.seq_lens,
|
|
||||||
actual_seq_lengths_kv=attn_metadata.seq_lens)
|
|
||||||
output = output.view(num_tokens, self.num_heads, self.head_size)
|
|
||||||
else:
|
|
||||||
torch_npu._npu_flash_attention(query=query,
|
|
||||||
key=key,
|
|
||||||
value=value,
|
|
||||||
mask=mask,
|
|
||||||
seq_len=attn_metadata.seq_lens,
|
|
||||||
scale_value=self.scale,
|
|
||||||
num_heads=self.num_heads,
|
|
||||||
num_kv_heads=self.num_kv_heads,
|
|
||||||
out=output)
|
|
||||||
assert output is not None
|
assert output is not None
|
||||||
return output[:num_tokens, :, :]
|
return output[:num_tokens, :, :]
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user