From 4da46da9bf8df06c003cf65bd40ced3b4a2e8002 Mon Sep 17 00:00:00 2001
From: yeyifan
Date: Mon, 29 Dec 2025 14:56:25 +0800
Subject: [PATCH] [feature] fia support sliding windows (#5239)

Enable fia to support the sliding window function and adapt it to the
Gemma3 model.

- vLLM version: release/v0.13.0
- vLLM main: https://github.com/vllm-project/vllm/commit/ad32e3e19ccf0526cb6744a5fed09a138a5fb2f9

---------

Signed-off-by: nsdie
---
 vllm_ascend/attention/attention_mask.py | 10 ++++++++++
 vllm_ascend/attention/attention_v1.py   | 16 ++++++++++++++--
 vllm_ascend/attention/utils.py          |  3 +++
 vllm_ascend/worker/model_runner_v1.py   | 11 +++++++++++
 4 files changed, 38 insertions(+), 2 deletions(-)

diff --git a/vllm_ascend/attention/attention_mask.py b/vllm_ascend/attention/attention_mask.py
index a291a480..5bdfbd92 100644
--- a/vllm_ascend/attention/attention_mask.py
+++ b/vllm_ascend/attention/attention_mask.py
@@ -38,6 +38,7 @@ class AttentionMaskBuilder:
         self.mla_mask = None
         self.chunked_prefill_attn_mask = None
         self.pcp_mla_mask = None
+        self.swa_mask = None
 
     def get_attn_mask(self, max_seq_len: int, dtype: torch.dtype):
         if self.attn_mask_cache is None or max_seq_len > self._seq_len_cached:
@@ -73,3 +74,12 @@ class AttentionMaskBuilder:
             self.pcp_mla_mask = torch.triu(
                 torch.ones(512, 512, device=self.device, dtype=dtype), 1)
         return self.pcp_mla_mask
+
+    def get_swa_mask(self, dtype: torch.dtype, sliding_window):
+        if self.swa_mask is None or self.swa_mask.dtype != dtype:
+            if sliding_window is not None:
+                mask = torch.ones(2048, 2048, dtype=torch.bool)
+                triu_mask = torch.triu(mask, diagonal=1).to(self.device)
+                tril_mask = torch.tril(mask, -sliding_window).to(self.device)
+                self.swa_mask = triu_mask + tril_mask
+        return self.swa_mask
\ No newline at end of file
diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py
index 854ac033..a87ed516 100644
--- a/vllm_ascend/attention/attention_v1.py
+++ b/vllm_ascend/attention/attention_v1.py
@@ -44,6 +44,9 @@ from vllm_ascend.compilation.acl_graph import (
 from vllm_ascend.utils import (AscendDeviceType, get_ascend_device_type,
                                weak_ref_tensors)
 
+# default max value of sliding window size
+SWA_INT_MAX = 2147483647
+
 
 @register_backend(AttentionBackendEnum.CUSTOM, "ASCEND")
 class AscendAttentionBackend(AttentionBackend):
@@ -170,6 +173,9 @@ class AscendMetadata:
     # runner_type in model_config.
     model_runner_type: str = ""
 
+    # sliding window attention mask
+    swa_mask: Optional[torch.Tensor] = None
+
 
 class AscendAttentionMetadataBuilder(AttentionMetadataBuilder[AscendMetadata]):
     # Does this backend/builder support ACL Graphs for attention (default: no).
@@ -234,6 +240,7 @@ class AscendAttentionMetadataBuilder(AttentionMetadataBuilder[AscendMetadata]):
 
         slot_mapping = common_attn_metadata.slot_mapping[:num_actual_tokens]
         attn_mask = common_attn_metadata.attn_mask
+        swa_mask = common_attn_metadata.swa_mask
         attn_state = common_attn_metadata.attn_state
 
         # TODO: Yet another unnecessary H2D while we already have a query_start_loc on device
@@ -251,6 +258,7 @@ class AscendAttentionMetadataBuilder(AttentionMetadataBuilder[AscendMetadata]):
             actual_seq_lengths_q=query_start_loc_cpu[1:].tolist(),
             slot_mapping=slot_mapping,
             attn_mask=attn_mask,
+            swa_mask=swa_mask,
             attn_state=attn_state,
             num_prefills=num_prefills,
             num_decodes=num_decodes,
@@ -549,7 +557,11 @@ class AscendAttentionBackendImpl(AttentionImpl):
                 query=query,
                 key=key,
                 value=value,
-                atten_mask=attn_metadata.attn_mask,
+                pre_tokens=self.sliding_window
+                if self.sliding_window else SWA_INT_MAX,
+                next_tokens=0 if self.sliding_window else SWA_INT_MAX,
+                atten_mask=attn_metadata.swa_mask
+                if self.sliding_window else attn_metadata.attn_mask,
                 block_table=block_table,
                 input_layout="TND",
                 block_size=block_size,
@@ -558,7 +570,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
                 num_key_value_heads=self.num_kv_heads,
                 num_heads=self.num_heads,
                 scale=self.scale,
-                sparse_mode=3,
+                sparse_mode=4 if self.sliding_window else 3,
             )
 
         attn_output = attn_output.view(num_tokens, self.num_heads,
diff --git a/vllm_ascend/attention/utils.py b/vllm_ascend/attention/utils.py
index db0cc99d..ae5b5733 100644
--- a/vllm_ascend/attention/utils.py
+++ b/vllm_ascend/attention/utils.py
@@ -142,6 +142,8 @@ class AscendCommonAttentionMetadata(CommonAttentionMetadata):
 
     spec_attn_mask: torch.Tensor = None
 
+    swa_mask: torch.Tensor = None
+
     attn_state: Any = None
 
     graph_pad_size: int = -1
@@ -175,6 +177,7 @@ class AscendCommonAttentionMetadata(CommonAttentionMetadata):
             positions=self.positions[:num_actual_tokens],
             attn_mask=self.attn_mask,
             spec_attn_mask=self.spec_attn_mask,
+            swa_mask=self.swa_mask,
             attn_state=self.attn_state,
             graph_pad_size=-1,  # It should be -1 when not run in fullgraph mode.
             num_input_tokens=num_actual_tokens,
diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index 7a4ecac5..098418d4 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -247,6 +247,15 @@ class NPUModelRunner(GPUModelRunner):
 
         self._set_up_drafter()
 
+        # sliding window attn mask
+        self.swa_mask = None
+        is_swa = hasattr(self.vllm_config.model_config.hf_text_config,
+                         "sliding_window")
+        if self.model_config is not None and is_swa:
+            self.swa_mask = self.attn_mask_builder.get_swa_mask(
+                self.dtype,
+                self.vllm_config.model_config.hf_text_config.sliding_window)
+
         # kv role
         self.is_kv_producer = False
         self.is_kv_consumer = False
@@ -1062,6 +1071,7 @@ class NPUModelRunner(GPUModelRunner):
             positions=self.positions.gpu,
             attn_mask=self.attn_mask,
             spec_attn_mask=self.spec_attn_mask,
+            swa_mask=self.swa_mask,
             attn_state=self.attn_state,
             max_query_len=max_num_scheduled_tokens,
             decode_token_per_req=self.decode_token_per_req,
@@ -1874,6 +1884,7 @@ class NPUModelRunner(GPUModelRunner):
             positions=self.positions.gpu,
             attn_mask=self.attn_mask,
             spec_attn_mask=self.spec_attn_mask,
+            swa_mask=self.swa_mask,
             attn_state=self.attn_state,
             max_query_len=max_query_len,
             decode_token_per_req=self.decode_token_per_req,
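For reference, the sliding window mask that get_swa_mask builds is the union of
two boolean triangles: torch.triu(mask, diagonal=1) marks keys in the future of
a query, and torch.tril(mask, -sliding_window) marks keys more than
sliding_window positions in the past, so each query may attend only to a causal
window of at most sliding_window keys. The sketch below is not part of the
patch; the 8x8 size and the helper name build_swa_mask are illustrative only
(the patch allocates a fixed 2048x2048 bool mask and moves it to the NPU
device).

import torch


def build_swa_mask(size: int, sliding_window: int) -> torch.Tensor:
    """Return a bool mask where True marks positions a query must NOT attend to."""
    mask = torch.ones(size, size, dtype=torch.bool)
    future = torch.triu(mask, diagonal=1)        # keys after the query position
    too_old = torch.tril(mask, -sliding_window)  # keys more than sliding_window behind
    # `triu_mask + tril_mask` in the patch is logical OR on bool tensors.
    return future | too_old


if __name__ == "__main__":
    swa = build_swa_mask(size=8, sliding_window=3)
    for i in range(8):
        # query i attends only to keys j with i - sliding_window < j <= i
        visible = (~swa[i]).nonzero(as_tuple=True)[0].tolist()
        assert len(visible) == min(i + 1, 3)
        print(f"query {i}: attends to keys {visible}")

In the fused-infer-attention call, the patch pairs this mask with
pre_tokens=self.sliding_window, next_tokens=0, and sparse_mode=4 when a sliding
window is configured, and falls back to the existing causal attn_mask with
sparse_mode=3 otherwise.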