[Feat]attention add sliding windows size (#2528)
### What this PR does / why we need it?
Add a sliding window size parameter to attention
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
Regarding the `Gemma3` model, set
additional_config={"ascend_scheduler_config": {"enabled":True}}, only
support AscendScheduler
test commond:`python3 -m vllm.entrypoints.openai.api_server --model
gemma3 --additional-config
'{"ascend_scheduler_config":{"enabled":true}}'`
- vLLM version: v0.10.1.1
- vLLM main:
6578e87365
---------
Signed-off-by: nsdie <yeyifan@huawei.com>
This commit is contained in:
@@ -228,6 +228,18 @@ class TestAscendAttentionBackendImpl(TestBase):
|
||||
attn_type=None,
|
||||
kv_sharing_target_layer_name=None)
|
||||
|
||||
self.impl_swa = AscendAttentionBackendImpl(
|
||||
num_heads=8,
|
||||
head_size=64,
|
||||
scale=1.0,
|
||||
num_kv_heads=8,
|
||||
alibi_slopes=None,
|
||||
sliding_window=1024,
|
||||
kv_cache_dtype="float16",
|
||||
logits_soft_cap=None,
|
||||
attn_type=self.attention_type.DECODER,
|
||||
kv_sharing_target_layer_name=None)
|
||||
|
||||
@patch('torch.ops.vllm.unified_ascend_attention_with_output')
|
||||
def test_forward_trace_flag_true(self, mock_unified_attention):
|
||||
"""Test forward pass when trace_flag is True"""
|
||||
@@ -329,6 +341,36 @@ class TestAscendAttentionBackendImpl(TestBase):
|
||||
mock_flash_attention.assert_called_once()
|
||||
assert output.shape == (10, 8 * 64)
|
||||
|
||||
@patch('torch_npu._npu_reshape_and_cache')
|
||||
@patch('torch_npu._npu_flash_attention')
|
||||
def test_forward_prefill_no_cache_swa(self, mock_flash_attention,
|
||||
mock_reshape_cache):
|
||||
"""Test forward pass in PrefillNoCache state"""
|
||||
query = torch.randn(10, 8 * 64)
|
||||
key = torch.randn(10, 8 * 64)
|
||||
value = torch.randn(10, 8 * 64)
|
||||
kv_cache = torch.empty(2, 5, 128, 8, 64)
|
||||
metadata = self.attn_metadata
|
||||
metadata.attn_state = AscendAttentionState.PrefillNoCache
|
||||
metadata.attn_mask = torch.randn(1, 1, 10, 10)
|
||||
metadata.seq_lens = torch.tensor([10])
|
||||
metadata.num_actual_tokens = 10
|
||||
metadata.slot_mapping = torch.zeros(10, dtype=torch.long)
|
||||
layer = self.layer_no_quant
|
||||
# layer.quant_method.apply.return_value = metadata
|
||||
print(self.layer_no_quant._v_scale_float)
|
||||
output = self.impl_swa.forward(layer,
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
kv_cache,
|
||||
metadata,
|
||||
trace_flag=False)
|
||||
|
||||
mock_reshape_cache.assert_called_once()
|
||||
mock_flash_attention.assert_called_once()
|
||||
assert output.shape == (10, 8 * 64)
|
||||
|
||||
@patch('torch_npu._npu_reshape_and_cache')
|
||||
@patch('torch_npu._npu_flash_attention_qlens')
|
||||
def test_forward_prefill_cache_hit(self, mock_flash_attention_qlens,
|
||||
@@ -387,6 +429,35 @@ class TestAscendAttentionBackendImpl(TestBase):
|
||||
mock_paged_attention.assert_called_once()
|
||||
assert output.shape == (10, 8 * 64)
|
||||
|
||||
@patch('torch_npu._npu_reshape_and_cache')
|
||||
@patch('torch_npu.npu_fused_infer_attention_score')
|
||||
def test_forward_decode_only_swa(self, mock_fused_infer_attention_score,
|
||||
mock_npu_reshape_and_cache):
|
||||
"""Test forward pass in DecodeOnly state"""
|
||||
query = torch.randn(10, 8 * 64)
|
||||
key = torch.randn(10, 8 * 64)
|
||||
value = torch.randn(10, 8 * 64)
|
||||
kv_cache = torch.empty(2, 5, 128, 8, 64)
|
||||
metadata = self.attn_metadata
|
||||
metadata.attn_state = AscendAttentionState.DecodeOnly
|
||||
metadata.seq_lens = torch.tensor([10] * 10)
|
||||
metadata.block_tables = torch.zeros(1, 5, dtype=torch.long)
|
||||
metadata.num_actual_tokens = 100
|
||||
metadata.slot_mapping = torch.zeros(10, dtype=torch.long)
|
||||
layer = self.layer_no_quant
|
||||
mock_fused_infer_attention_score.return_value = (torch.ones(10, 8,
|
||||
64), 1)
|
||||
output = self.impl_swa.forward(layer,
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
kv_cache,
|
||||
metadata,
|
||||
trace_flag=False)
|
||||
print(output.shape)
|
||||
mock_fused_infer_attention_score.assert_called_once()
|
||||
assert output.shape == (10, 8 * 64)
|
||||
|
||||
@patch('vllm_ascend.attention.attention_v1.is_310p', return_value=False)
|
||||
@patch('torch_npu._npu_reshape_and_cache')
|
||||
@patch('vllm_ascend.attention.attention_v1.vanilla_chunked_prefill')
|
||||
|
||||
@@ -265,6 +265,20 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
||||
self.key_cache = None
|
||||
self.value_cache = None
|
||||
|
||||
def _repeat_kv(self, hidden_states: torch.Tensor,
|
||||
n_rep: int) -> torch.Tensor:
|
||||
"""
|
||||
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
|
||||
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
|
||||
"""
|
||||
num_key_value_heads, slen, head_dim = hidden_states.shape
|
||||
if n_rep == 1:
|
||||
return hidden_states
|
||||
hidden_states = hidden_states[:, None, :, :].expand(
|
||||
num_key_value_heads, n_rep, slen, head_dim)
|
||||
return hidden_states.reshape(num_key_value_heads * n_rep, slen,
|
||||
head_dim)
|
||||
|
||||
def _forward_prefill_no_cache(
|
||||
self,
|
||||
query: torch.Tensor,
|
||||
@@ -290,15 +304,34 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
||||
mask = torch_npu.npu_format_cast(mask.contiguous(),
|
||||
ACL_FORMAT_FRACTAL_NZ)
|
||||
|
||||
torch_npu._npu_flash_attention(query=query,
|
||||
key=key,
|
||||
value=value,
|
||||
mask=mask,
|
||||
seq_len=attn_metadata.seq_lens,
|
||||
scale_value=self.scale,
|
||||
num_heads=self.num_heads,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
out=output)
|
||||
if self.sliding_window is not None and \
|
||||
attn_metadata.attn_mask.shape[0] > self.sliding_window:
|
||||
|
||||
key = self._repeat_kv(key, self.num_heads // self.num_kv_heads)
|
||||
value = self._repeat_kv(value, self.num_heads // self.num_kv_heads)
|
||||
|
||||
output, _ = torch_npu.npu_fused_infer_attention_score(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
num_heads=self.num_heads,
|
||||
num_key_value_heads=self.num_kv_heads,
|
||||
input_layout="TND",
|
||||
pre_tokens=self.sliding_window,
|
||||
scale=self.scale,
|
||||
actual_seq_lengths=attn_metadata.seq_lens,
|
||||
actual_seq_lengths_kv=attn_metadata.seq_lens)
|
||||
output = output.view(num_tokens, self.num_heads, self.head_size)
|
||||
else:
|
||||
torch_npu._npu_flash_attention(query=query,
|
||||
key=key,
|
||||
value=value,
|
||||
mask=mask,
|
||||
seq_len=attn_metadata.seq_lens,
|
||||
scale_value=self.scale,
|
||||
num_heads=self.num_heads,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
out=output)
|
||||
assert output is not None
|
||||
return output[:num_tokens, :, :]
|
||||
|
||||
@@ -339,16 +372,43 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
||||
# seq_lens_tensor needs to be transferred to the device for 310P.
|
||||
attn_metadata.seq_lens = \
|
||||
attn_metadata.seq_lens.to(device=query.device)
|
||||
if self.sliding_window is not None:
|
||||
batch_size = attn_metadata.seq_lens.shape[0]
|
||||
block_size = 128
|
||||
query = query.view(batch_size, 1, self.num_heads * self.head_size)
|
||||
key = self.key_cache
|
||||
value = self.value_cache
|
||||
if self.key_cache is not None and self.value_cache is not None:
|
||||
block_size = self.key_cache.shape[1]
|
||||
key = self.key_cache.flatten(2, 3).contiguous()
|
||||
value = self.value_cache.flatten(2, 3).contiguous()
|
||||
|
||||
torch_npu._npu_paged_attention(query=query,
|
||||
key_cache=self.key_cache,
|
||||
value_cache=self.value_cache,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
num_heads=self.num_heads,
|
||||
scale_value=self.scale,
|
||||
block_table=attn_metadata.block_tables,
|
||||
context_lens=attn_metadata.seq_lens,
|
||||
out=output)
|
||||
output, _ = torch_npu.npu_fused_infer_attention_score(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
num_heads=self.num_heads,
|
||||
num_key_value_heads=self.num_kv_heads,
|
||||
input_layout="BSH",
|
||||
block_size=block_size,
|
||||
pre_tokens=self.sliding_window,
|
||||
scale=self.scale,
|
||||
block_table=attn_metadata.block_tables,
|
||||
actual_seq_lengths=[1] * len(attn_metadata.seq_lens),
|
||||
actual_seq_lengths_kv=attn_metadata.seq_lens)
|
||||
|
||||
output = output.view(batch_size, self.num_heads, self.head_size)
|
||||
else:
|
||||
torch_npu._npu_paged_attention(
|
||||
query=query,
|
||||
key_cache=self.key_cache,
|
||||
value_cache=self.value_cache,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
num_heads=self.num_heads,
|
||||
scale_value=self.scale,
|
||||
block_table=attn_metadata.block_tables,
|
||||
context_lens=attn_metadata.seq_lens,
|
||||
out=output)
|
||||
return output
|
||||
|
||||
def _forward_v1_style(
|
||||
|
||||
Reference in New Issue
Block a user