[Feat]attention add sliding windows size (#2528)

### What this PR does / why we need it?
Add a sliding window size parameter to attention
### Does this PR introduce _any_ user-facing change?

### How was this patch tested?
Regarding the `Gemma3` model, set
additional_config={"ascend_scheduler_config": {"enabled":True}}, only
support AscendScheduler
test commond:`python3 -m vllm.entrypoints.openai.api_server --model
gemma3 --additional-config
'{"ascend_scheduler_config":{"enabled":true}}'`


- vLLM version: v0.10.1.1
- vLLM main:
6578e87365

---------

Signed-off-by: nsdie <yeyifan@huawei.com>
This commit is contained in:
yeyifan
2025-08-28 10:37:19 +08:00
committed by GitHub
parent c8d1df3a3f
commit 1191a64ae5
2 changed files with 149 additions and 18 deletions

View File

@@ -228,6 +228,18 @@ class TestAscendAttentionBackendImpl(TestBase):
attn_type=None,
kv_sharing_target_layer_name=None)
self.impl_swa = AscendAttentionBackendImpl(
num_heads=8,
head_size=64,
scale=1.0,
num_kv_heads=8,
alibi_slopes=None,
sliding_window=1024,
kv_cache_dtype="float16",
logits_soft_cap=None,
attn_type=self.attention_type.DECODER,
kv_sharing_target_layer_name=None)
@patch('torch.ops.vllm.unified_ascend_attention_with_output')
def test_forward_trace_flag_true(self, mock_unified_attention):
"""Test forward pass when trace_flag is True"""
@@ -329,6 +341,36 @@ class TestAscendAttentionBackendImpl(TestBase):
mock_flash_attention.assert_called_once()
assert output.shape == (10, 8 * 64)
@patch('torch_npu._npu_reshape_and_cache')
@patch('torch_npu._npu_flash_attention')
def test_forward_prefill_no_cache_swa(self, mock_flash_attention,
mock_reshape_cache):
"""Test forward pass in PrefillNoCache state"""
query = torch.randn(10, 8 * 64)
key = torch.randn(10, 8 * 64)
value = torch.randn(10, 8 * 64)
kv_cache = torch.empty(2, 5, 128, 8, 64)
metadata = self.attn_metadata
metadata.attn_state = AscendAttentionState.PrefillNoCache
metadata.attn_mask = torch.randn(1, 1, 10, 10)
metadata.seq_lens = torch.tensor([10])
metadata.num_actual_tokens = 10
metadata.slot_mapping = torch.zeros(10, dtype=torch.long)
layer = self.layer_no_quant
# layer.quant_method.apply.return_value = metadata
print(self.layer_no_quant._v_scale_float)
output = self.impl_swa.forward(layer,
query,
key,
value,
kv_cache,
metadata,
trace_flag=False)
mock_reshape_cache.assert_called_once()
mock_flash_attention.assert_called_once()
assert output.shape == (10, 8 * 64)
@patch('torch_npu._npu_reshape_and_cache')
@patch('torch_npu._npu_flash_attention_qlens')
def test_forward_prefill_cache_hit(self, mock_flash_attention_qlens,
@@ -387,6 +429,35 @@ class TestAscendAttentionBackendImpl(TestBase):
mock_paged_attention.assert_called_once()
assert output.shape == (10, 8 * 64)
@patch('torch_npu._npu_reshape_and_cache')
@patch('torch_npu.npu_fused_infer_attention_score')
def test_forward_decode_only_swa(self, mock_fused_infer_attention_score,
mock_npu_reshape_and_cache):
"""Test forward pass in DecodeOnly state"""
query = torch.randn(10, 8 * 64)
key = torch.randn(10, 8 * 64)
value = torch.randn(10, 8 * 64)
kv_cache = torch.empty(2, 5, 128, 8, 64)
metadata = self.attn_metadata
metadata.attn_state = AscendAttentionState.DecodeOnly
metadata.seq_lens = torch.tensor([10] * 10)
metadata.block_tables = torch.zeros(1, 5, dtype=torch.long)
metadata.num_actual_tokens = 100
metadata.slot_mapping = torch.zeros(10, dtype=torch.long)
layer = self.layer_no_quant
mock_fused_infer_attention_score.return_value = (torch.ones(10, 8,
64), 1)
output = self.impl_swa.forward(layer,
query,
key,
value,
kv_cache,
metadata,
trace_flag=False)
print(output.shape)
mock_fused_infer_attention_score.assert_called_once()
assert output.shape == (10, 8 * 64)
@patch('vllm_ascend.attention.attention_v1.is_310p', return_value=False)
@patch('torch_npu._npu_reshape_and_cache')
@patch('vllm_ascend.attention.attention_v1.vanilla_chunked_prefill')