[feature] fia support sliding windows (#5239)
Enable FIA (fused infer attention) to support sliding-window attention and adapt it to the Gemma3 model.
- vLLM version: release/v0.13.0
- vLLM main: ad32e3e19c
---------
Signed-off-by: nsdie <yeyifan@huawei.com>
@@ -38,6 +38,7 @@ class AttentionMaskBuilder:
         self.mla_mask = None
         self.chunked_prefill_attn_mask = None
         self.pcp_mla_mask = None
+        self.swa_mask = None
 
     def get_attn_mask(self, max_seq_len: int, dtype: torch.dtype):
         if self.attn_mask_cache is None or max_seq_len > self._seq_len_cached:
@@ -73,3 +74,12 @@ class AttentionMaskBuilder:
         self.pcp_mla_mask = torch.triu(
             torch.ones(512, 512, device=self.device, dtype=dtype), 1)
         return self.pcp_mla_mask
+
+    def get_swa_mask(self, dtype: torch.dtype, sliding_window):
+        if self.swa_mask is None or self.swa_mask.dtype != dtype:
+            if sliding_window is not None:
+                mask = torch.ones(2048, 2048, dtype=torch.bool)
+                triu_mask = torch.triu(mask, diagonal=1).to(self.device)
+                tril_mask = torch.tril(mask, -sliding_window).to(self.device)
+                self.swa_mask = triu_mask + tril_mask
+        return self.swa_mask
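For intuition: entry (i, j) of the mask above is True, meaning "do not attend", when key j is in the future (j > i) or more than sliding_window positions in the past (i - j >= sliding_window), so each query sees at most its last sliding_window tokens. A minimal standalone sketch (toy sizes, not part of the patch) checking the triu/tril construction against that definition:

    import torch

    size, sliding_window = 8, 3  # toy values; the patch uses 2048 and the model's window

    # construction from the patch: future positions OR positions too far in the past
    mask = torch.ones(size, size, dtype=torch.bool)
    swa_mask = torch.triu(mask, diagonal=1) + torch.tril(mask, -sliding_window)

    # reference definition: True where attention must be blocked
    i = torch.arange(size).unsqueeze(1)
    j = torch.arange(size).unsqueeze(0)
    assert torch.equal(swa_mask, (j > i) | (i - j >= sliding_window))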
@@ -44,6 +44,9 @@ from vllm_ascend.compilation.acl_graph import (
 from vllm_ascend.utils import (AscendDeviceType, get_ascend_device_type,
                                weak_ref_tensors)
 
+# default max value of sliding window size
+SWA_INT_MAX = 2147483647
+
 
 @register_backend(AttentionBackendEnum.CUSTOM, "ASCEND")
 class AscendAttentionBackend(AttentionBackend):
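SWA_INT_MAX is INT32_MAX (2^31 - 1); it is passed further down as pre_tokens/next_tokens for layers without a sliding window, where it effectively means "no window limit". A one-line sanity check:

    import torch
    assert 2147483647 == torch.iinfo(torch.int32).max  # SWA_INT_MAX == INT32_MAX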
@@ -170,6 +173,9 @@ class AscendMetadata:
     # runner_type in model_config.
     model_runner_type: str = ""
 
+    # sliding window attention mask
+    swa_mask: Optional[torch.Tensor] = None
+
 
 class AscendAttentionMetadataBuilder(AttentionMetadataBuilder[AscendMetadata]):
     # Does this backend/builder support ACL Graphs for attention (default: no).
@@ -234,6 +240,7 @@ class AscendAttentionMetadataBuilder(AttentionMetadataBuilder[AscendMetadata]):
 
         slot_mapping = common_attn_metadata.slot_mapping[:num_actual_tokens]
         attn_mask = common_attn_metadata.attn_mask
+        swa_mask = common_attn_metadata.swa_mask
        attn_state = common_attn_metadata.attn_state
 
         # TODO: Yet another unnecessary H2D while we already have a query_start_loc on device
@@ -251,6 +258,7 @@ class AscendAttentionMetadataBuilder(AttentionMetadataBuilder[AscendMetadata]):
             actual_seq_lengths_q=query_start_loc_cpu[1:].tolist(),
             slot_mapping=slot_mapping,
             attn_mask=attn_mask,
+            swa_mask=swa_mask,
             attn_state=attn_state,
             num_prefills=num_prefills,
             num_decodes=num_decodes,
@@ -549,7 +557,11 @@ class AscendAttentionBackendImpl(AttentionImpl):
                 query=query,
                 key=key,
                 value=value,
-                atten_mask=attn_metadata.attn_mask,
+                pre_tokens=self.sliding_window
+                if self.sliding_window else SWA_INT_MAX,
+                next_tokens=0 if self.sliding_window else SWA_INT_MAX,
+                atten_mask=attn_metadata.swa_mask
+                if self.sliding_window else attn_metadata.attn_mask,
                 block_table=block_table,
                 input_layout="TND",
                 block_size=block_size,
@@ -558,7 +570,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
                 num_key_value_heads=self.num_kv_heads,
                 num_heads=self.num_heads,
                 scale=self.scale,
-                sparse_mode=3,
+                sparse_mode=4 if self.sliding_window else 3,
             )
 
         attn_output = attn_output.view(num_tokens, self.num_heads,
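The two hunks above switch the FIA call per layer: sliding-window layers pass the band mask, pre_tokens equal to the window size, next_tokens=0 (strictly causal), and sparse_mode=4 (the band pattern in the Ascend operator docs; 3 is the plain causal pattern), while full-attention layers keep the causal mask with unbounded token ranges. A sketch of that selection as a standalone helper (the helper name and dict form are illustrative, not part of the patch):

    # illustrative helper, not from the patch: choose FIA masking kwargs per layer
    def fia_mask_kwargs(sliding_window, swa_mask, attn_mask, SWA_INT_MAX=2147483647):
        if sliding_window:  # sliding-window layer (e.g. Gemma3 local layers)
            return dict(
                pre_tokens=sliding_window,  # look at most this far into the past
                next_tokens=0,              # no future tokens: strictly causal
                atten_mask=swa_mask,        # band mask from get_swa_mask
                sparse_mode=4,              # band sparse pattern
            )
        return dict(                        # full-attention layer
            pre_tokens=SWA_INT_MAX,         # unbounded history
            next_tokens=SWA_INT_MAX,        # future handled by the causal mask
            atten_mask=attn_mask,           # plain causal mask
            sparse_mode=3,                  # causal sparse pattern
        )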
@@ -142,6 +142,8 @@ class AscendCommonAttentionMetadata(CommonAttentionMetadata):
 
     spec_attn_mask: torch.Tensor = None
 
+    swa_mask: torch.Tensor = None
+
     attn_state: Any = None
 
     graph_pad_size: int = -1
@@ -175,6 +177,7 @@ class AscendCommonAttentionMetadata(CommonAttentionMetadata):
             positions=self.positions[:num_actual_tokens],
             attn_mask=self.attn_mask,
             spec_attn_mask=self.spec_attn_mask,
+            swa_mask=self.swa_mask,
             attn_state=self.attn_state,
             graph_pad_size=-1,  # It should be -1 when not run in fullgraph mode.
             num_input_tokens=num_actual_tokens,
@@ -247,6 +247,15 @@ class NPUModelRunner(GPUModelRunner):
 
         self._set_up_drafter()
 
+        # sliding window attn mask
+        self.swa_mask = None
+        is_swa = hasattr(self.vllm_config.model_config.hf_text_config,
+                         "sliding_window")
+        if self.model_config is not None and is_swa:
+            self.swa_mask = self.attn_mask_builder.get_swa_mask(
+                self.dtype,
+                self.vllm_config.model_config.hf_text_config.sliding_window)
+
         # kv role
         self.is_kv_producer = False
         self.is_kv_consumer = False
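The hunk above builds the SWA mask once at runner init, keyed off the HF text config: Gemma3's text config carries a sliding_window value, and get_swa_mask itself guards against sliding_window being None, so configs that merely define the attribute without setting it fall through to no mask. An equivalent standalone detection (getattr-based; names outside the patch are illustrative):

    # illustrative, not from the patch: equivalent sliding-window detection
    def detect_sliding_window(hf_text_config):
        # returns the window size, or None when the model has no SWA layers
        return getattr(hf_text_config, "sliding_window", None)

    # usage (hypothetical objects mirroring the patch):
    #   window = detect_sliding_window(model_config.hf_text_config)
    #   if window is not None:
    #       swa_mask = attn_mask_builder.get_swa_mask(dtype, window)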
@@ -1062,6 +1071,7 @@ class NPUModelRunner(GPUModelRunner):
             positions=self.positions.gpu,
             attn_mask=self.attn_mask,
             spec_attn_mask=self.spec_attn_mask,
+            swa_mask=self.swa_mask,
             attn_state=self.attn_state,
             max_query_len=max_num_scheduled_tokens,
             decode_token_per_req=self.decode_token_per_req,
@@ -1874,6 +1884,7 @@ class NPUModelRunner(GPUModelRunner):
             positions=self.positions.gpu,
             attn_mask=self.attn_mask,
             spec_attn_mask=self.spec_attn_mask,
+            swa_mask=self.swa_mask,
             attn_state=self.attn_state,
             max_query_len=max_query_len,
             decode_token_per_req=self.decode_token_per_req,
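Taken together, the patch threads the mask through four stages; a comment-only summary of the data path (no new behavior):

    # 1. NPUModelRunner.__init__ builds the 2048x2048 band mask once via
    #    AttentionMaskBuilder.get_swa_mask(dtype, sliding_window).
    # 2. AscendCommonAttentionMetadata carries swa_mask alongside attn_mask.
    # 3. AscendAttentionMetadataBuilder copies it into AscendMetadata each step.
    # 4. AscendAttentionBackendImpl selects swa_mask, pre_tokens/next_tokens and
    #    sparse_mode=4 whenever self.sliding_window is set on the layer.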