[Feat]attention add sliding windows size (#2528)

### What this PR does / why we need it?
Add a sliding window size parameter to attention
### Does this PR introduce _any_ user-facing change?

### How was this patch tested?
Regarding the `Gemma3` model, set
additional_config={"ascend_scheduler_config": {"enabled":True}}, only
support AscendScheduler
test commond:`python3 -m vllm.entrypoints.openai.api_server --model
gemma3 --additional-config
'{"ascend_scheduler_config":{"enabled":true}}'`


- vLLM version: v0.10.1.1
- vLLM main:
6578e87365

---------

Signed-off-by: nsdie <yeyifan@huawei.com>
This commit is contained in:
yeyifan
2025-08-28 10:37:19 +08:00
committed by GitHub
parent c8d1df3a3f
commit 1191a64ae5
2 changed files with 149 additions and 18 deletions

View File

@@ -265,6 +265,20 @@ class AscendAttentionBackendImpl(AttentionImpl):
self.key_cache = None
self.value_cache = None
def _repeat_kv(self, hidden_states: torch.Tensor,
n_rep: int) -> torch.Tensor:
"""
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
"""
num_key_value_heads, slen, head_dim = hidden_states.shape
if n_rep == 1:
return hidden_states
hidden_states = hidden_states[:, None, :, :].expand(
num_key_value_heads, n_rep, slen, head_dim)
return hidden_states.reshape(num_key_value_heads * n_rep, slen,
head_dim)
def _forward_prefill_no_cache(
self,
query: torch.Tensor,
@@ -290,15 +304,34 @@ class AscendAttentionBackendImpl(AttentionImpl):
mask = torch_npu.npu_format_cast(mask.contiguous(),
ACL_FORMAT_FRACTAL_NZ)
torch_npu._npu_flash_attention(query=query,
key=key,
value=value,
mask=mask,
seq_len=attn_metadata.seq_lens,
scale_value=self.scale,
num_heads=self.num_heads,
num_kv_heads=self.num_kv_heads,
out=output)
if self.sliding_window is not None and \
attn_metadata.attn_mask.shape[0] > self.sliding_window:
key = self._repeat_kv(key, self.num_heads // self.num_kv_heads)
value = self._repeat_kv(value, self.num_heads // self.num_kv_heads)
output, _ = torch_npu.npu_fused_infer_attention_score(
query,
key,
value,
num_heads=self.num_heads,
num_key_value_heads=self.num_kv_heads,
input_layout="TND",
pre_tokens=self.sliding_window,
scale=self.scale,
actual_seq_lengths=attn_metadata.seq_lens,
actual_seq_lengths_kv=attn_metadata.seq_lens)
output = output.view(num_tokens, self.num_heads, self.head_size)
else:
torch_npu._npu_flash_attention(query=query,
key=key,
value=value,
mask=mask,
seq_len=attn_metadata.seq_lens,
scale_value=self.scale,
num_heads=self.num_heads,
num_kv_heads=self.num_kv_heads,
out=output)
assert output is not None
return output[:num_tokens, :, :]
@@ -339,16 +372,43 @@ class AscendAttentionBackendImpl(AttentionImpl):
# seq_lens_tensor needs to be transferred to the device for 310P.
attn_metadata.seq_lens = \
attn_metadata.seq_lens.to(device=query.device)
if self.sliding_window is not None:
batch_size = attn_metadata.seq_lens.shape[0]
block_size = 128
query = query.view(batch_size, 1, self.num_heads * self.head_size)
key = self.key_cache
value = self.value_cache
if self.key_cache is not None and self.value_cache is not None:
block_size = self.key_cache.shape[1]
key = self.key_cache.flatten(2, 3).contiguous()
value = self.value_cache.flatten(2, 3).contiguous()
torch_npu._npu_paged_attention(query=query,
key_cache=self.key_cache,
value_cache=self.value_cache,
num_kv_heads=self.num_kv_heads,
num_heads=self.num_heads,
scale_value=self.scale,
block_table=attn_metadata.block_tables,
context_lens=attn_metadata.seq_lens,
out=output)
output, _ = torch_npu.npu_fused_infer_attention_score(
query,
key,
value,
num_heads=self.num_heads,
num_key_value_heads=self.num_kv_heads,
input_layout="BSH",
block_size=block_size,
pre_tokens=self.sliding_window,
scale=self.scale,
block_table=attn_metadata.block_tables,
actual_seq_lengths=[1] * len(attn_metadata.seq_lens),
actual_seq_lengths_kv=attn_metadata.seq_lens)
output = output.view(batch_size, self.num_heads, self.head_size)
else:
torch_npu._npu_paged_attention(
query=query,
key_cache=self.key_cache,
value_cache=self.value_cache,
num_kv_heads=self.num_kv_heads,
num_heads=self.num_heads,
scale_value=self.scale,
block_table=attn_metadata.block_tables,
context_lens=attn_metadata.seq_lens,
out=output)
return output
def _forward_v1_style(