[feature] fia support sliding windows (#5239)

Enable FIA to support the sliding-window attention function and adapt it to the Gemma3
model.

- vLLM version: release/v0.13.0
- vLLM main:
ad32e3e19c
---------
Signed-off-by: nsdie <yeyifan@huawei.com>
This commit is contained in:
yeyifan
2025-12-29 14:56:25 +08:00
committed by GitHub
parent d8e15dae6c
commit 4da46da9bf
4 changed files with 38 additions and 2 deletions

View File

@@ -38,6 +38,7 @@ class AttentionMaskBuilder:
self.mla_mask = None
self.chunked_prefill_attn_mask = None
self.pcp_mla_mask = None
self.swa_mask = None
def get_attn_mask(self, max_seq_len: int, dtype: torch.dtype):
if self.attn_mask_cache is None or max_seq_len > self._seq_len_cached:
@@ -73,3 +74,12 @@ class AttentionMaskBuilder:
self.pcp_mla_mask = torch.triu(
torch.ones(512, 512, device=self.device, dtype=dtype), 1)
return self.pcp_mla_mask
def get_swa_mask(self, dtype: torch.dtype, sliding_window):
    """Build (and cache) the sliding-window attention (SWA) mask.

    The mask is a 2048 x 2048 boolean tensor on ``self.device`` in which
    ``True`` marks positions to mask out: the strict upper triangle
    (future tokens) plus everything more than ``sliding_window`` tokens
    in the past (``torch.tril`` with diagonal ``-sliding_window``).

    Args:
        dtype: runtime dtype of the caller; used only as a cache key —
            the mask itself is always ``torch.bool``.
        sliding_window: window size in tokens. When ``None`` no mask is
            built and the currently cached value (possibly ``None``) is
            returned unchanged.

    Returns:
        The cached boolean mask, or ``None`` if no mask has been built.

    NOTE(review): a changed ``sliding_window`` does not invalidate the
    cache — callers appear to pass a fixed per-model value; confirm.
    """
    # Fix: the previous check compared the cached mask's dtype (always
    # torch.bool) against the requested runtime dtype, so the cache
    # never hit for non-bool dtypes and the mask was rebuilt on every
    # call. Track the dtype the mask was requested for separately.
    cached_for = getattr(self, "_swa_mask_dtype", None)
    if (self.swa_mask is None or cached_for != dtype) \
            and sliding_window is not None:
        mask = torch.ones(2048, 2048, dtype=torch.bool)
        # Strict upper triangle: mask out future tokens.
        triu_mask = torch.triu(mask, diagonal=1).to(self.device)
        # Lower triangle shifted down by the window: mask out tokens
        # farther than `sliding_window` in the past.
        tril_mask = torch.tril(mask, -sliding_window).to(self.device)
        # On bool tensors "+" is logical OR; "|" states the intent.
        self.swa_mask = triu_mask | tril_mask
        self._swa_mask_dtype = dtype
    return self.swa_mask

View File

@@ -44,6 +44,9 @@ from vllm_ascend.compilation.acl_graph import (
from vllm_ascend.utils import (AscendDeviceType, get_ascend_device_type,
weak_ref_tensors)
# Default max value for the sliding-window span: INT32_MAX (2**31 - 1).
# Passed as pre_tokens/next_tokens to the fused attention call when no
# sliding window is configured, i.e. an effectively unlimited span.
SWA_INT_MAX = 2147483647
@register_backend(AttentionBackendEnum.CUSTOM, "ASCEND")
class AscendAttentionBackend(AttentionBackend):
@@ -170,6 +173,9 @@ class AscendMetadata:
# runner_type in model_config.
model_runner_type: str = ""
# sliding window attention mask
swa_mask: Optional[torch.Tensor] = None
class AscendAttentionMetadataBuilder(AttentionMetadataBuilder[AscendMetadata]):
# Does this backend/builder support ACL Graphs for attention (default: no).
@@ -234,6 +240,7 @@ class AscendAttentionMetadataBuilder(AttentionMetadataBuilder[AscendMetadata]):
slot_mapping = common_attn_metadata.slot_mapping[:num_actual_tokens]
attn_mask = common_attn_metadata.attn_mask
swa_mask = common_attn_metadata.swa_mask
attn_state = common_attn_metadata.attn_state
# TODO: Yet another unnecessary H2D while we already have a query_start_loc on device
@@ -251,6 +258,7 @@ class AscendAttentionMetadataBuilder(AttentionMetadataBuilder[AscendMetadata]):
actual_seq_lengths_q=query_start_loc_cpu[1:].tolist(),
slot_mapping=slot_mapping,
attn_mask=attn_mask,
swa_mask=swa_mask,
attn_state=attn_state,
num_prefills=num_prefills,
num_decodes=num_decodes,
@@ -549,7 +557,11 @@ class AscendAttentionBackendImpl(AttentionImpl):
query=query,
key=key,
value=value,
atten_mask=attn_metadata.attn_mask,
pre_tokens=self.sliding_window
if self.sliding_window else SWA_INT_MAX,
next_tokens=0 if self.sliding_window else SWA_INT_MAX,
atten_mask=attn_metadata.swa_mask
if self.sliding_window else attn_metadata.attn_mask,
block_table=block_table,
input_layout="TND",
block_size=block_size,
@@ -558,7 +570,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
num_key_value_heads=self.num_kv_heads,
num_heads=self.num_heads,
scale=self.scale,
sparse_mode=3,
sparse_mode=4 if self.sliding_window else 3,
)
attn_output = attn_output.view(num_tokens, self.num_heads,

View File

@@ -142,6 +142,8 @@ class AscendCommonAttentionMetadata(CommonAttentionMetadata):
spec_attn_mask: torch.Tensor = None
swa_mask: torch.Tensor = None
attn_state: Any = None
graph_pad_size: int = -1
@@ -175,6 +177,7 @@ class AscendCommonAttentionMetadata(CommonAttentionMetadata):
positions=self.positions[:num_actual_tokens],
attn_mask=self.attn_mask,
spec_attn_mask=self.spec_attn_mask,
swa_mask=self.swa_mask,
attn_state=self.attn_state,
graph_pad_size=-1, # It should be -1 when not run in fullgraph mode.
num_input_tokens=num_actual_tokens,

View File

@@ -247,6 +247,15 @@ class NPUModelRunner(GPUModelRunner):
self._set_up_drafter()
# sliding window attn mask
self.swa_mask = None
is_swa = hasattr(self.vllm_config.model_config.hf_text_config,
"sliding_window")
if self.model_config is not None and is_swa:
self.swa_mask = self.attn_mask_builder.get_swa_mask(
self.dtype,
self.vllm_config.model_config.hf_text_config.sliding_window)
# kv role
self.is_kv_producer = False
self.is_kv_consumer = False
@@ -1062,6 +1071,7 @@ class NPUModelRunner(GPUModelRunner):
positions=self.positions.gpu,
attn_mask=self.attn_mask,
spec_attn_mask=self.spec_attn_mask,
swa_mask=self.swa_mask,
attn_state=self.attn_state,
max_query_len=max_num_scheduled_tokens,
decode_token_per_req=self.decode_token_per_req,
@@ -1874,6 +1884,7 @@ class NPUModelRunner(GPUModelRunner):
positions=self.positions.gpu,
attn_mask=self.attn_mask,
spec_attn_mask=self.spec_attn_mask,
swa_mask=self.swa_mask,
attn_state=self.attn_state,
max_query_len=max_query_len,
decode_token_per_req=self.decode_token_per_req,