[feature] fia support sliding windows (#5239)
Enable fia to support the sliding window function and adapt it to the Gemma3
model.
- vLLM version: release/v0.13.0
- vLLM main: ad32e3e19c
---------
Signed-off-by: nsdie <yeyifan@huawei.com>
@@ -38,6 +38,7 @@ class AttentionMaskBuilder:
         self.mla_mask = None
         self.chunked_prefill_attn_mask = None
         self.pcp_mla_mask = None
+        self.swa_mask = None

     def get_attn_mask(self, max_seq_len: int, dtype: torch.dtype):
         if self.attn_mask_cache is None or max_seq_len > self._seq_len_cached:
@@ -73,3 +74,12 @@ class AttentionMaskBuilder:
             self.pcp_mla_mask = torch.triu(
                 torch.ones(512, 512, device=self.device, dtype=dtype), 1)
         return self.pcp_mla_mask
+
+    def get_swa_mask(self, dtype: torch.dtype, sliding_window):
+        if self.swa_mask is None or self.swa_mask.dtype != dtype:
+            if sliding_window is not None:
+                mask = torch.ones(2048, 2048, dtype=torch.bool)
+                triu_mask = torch.triu(mask, diagonal=1).to(self.device)
+                tril_mask = torch.tril(mask, -sliding_window).to(self.device)
+                self.swa_mask = triu_mask + tril_mask
+        return self.swa_mask
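For readers unfamiliar with the mask convention, the sketch below (not part of the patch) reproduces the same triu/tril construction on a small matrix so the effect of `sliding_window` is visible. It assumes `True` marks positions the attention kernel should mask out: future tokens plus tokens more than `sliding_window` positions in the past. The helper name `build_swa_mask` and the small sizes are illustrative only; the patch itself builds a fixed 2048x2048 mask.

    import torch

    def build_swa_mask(size: int, sliding_window: int, device: str = "cpu") -> torch.Tensor:
        # Same construction as get_swa_mask above (illustrative helper, not in the patch).
        # Assumption: True = "do not attend" for the consuming kernel.
        mask = torch.ones(size, size, dtype=torch.bool)
        # Mask future tokens (strict upper triangle) for causality.
        triu_mask = torch.triu(mask, diagonal=1).to(device)
        # Mask tokens more than `sliding_window` positions in the past.
        tril_mask = torch.tril(mask, -sliding_window).to(device)
        # Adding two bool tensors acts as a logical OR, matching the patch.
        return triu_mask + tril_mask

    # With sliding_window=2, query position i may attend only to keys i-1 and i:
    print(build_swa_mask(5, 2).int())
    # tensor([[0, 1, 1, 1, 1],
    #         [0, 0, 1, 1, 1],
    #         [1, 0, 0, 1, 1],
    #         [1, 1, 0, 0, 1],
    #         [1, 1, 1, 0, 0]], dtype=torch.int32)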