fix: apply cache size limit of attention mask for VisionAttention (#3657)
@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+from functools import lru_cache
 from typing import Optional
 
 import torch
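Note: the new functools.lru_cache import is what enforces the size limit. lru_cache keeps a bounded, least-recently-used mapping keyed on the decorated function's (hashable) arguments, unlike the unbounded class-level dict it replaces below. A minimal sketch of the behavior this relies on, with a made-up maxsize (not code from this repo):

from functools import lru_cache

@lru_cache(maxsize=2)  # only the 2 most recently used keys are kept
def square(key: int) -> int:
    print(f"computing {key}")
    return key * key

square(1)  # computes
square(1)  # cache hit, nothing printed
square(2)  # computes
square(3)  # computes and evicts key 1 (least recently used)
square(1)  # recomputed after eviction
print(square.cache_info())  # CacheInfo(hits=1, misses=4, maxsize=2, currsize=2)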
@@ -223,9 +224,6 @@ class VisionSdpaAttention(nn.Module):
     """
 
-    # TODO: Should it be released after used?
-    _mask_cache = {}
-
     def __init__(
         self,
        head_size: int,
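Note: the deleted _mask_cache dict had no eviction policy, which is what the removed TODO was asking about. One mask tensor stayed resident for every distinct (s, bsz, flatten_batch, cu_seqlens) key, so memory grew with every new image layout. A rough, hypothetical illustration of that failure mode (a standalone class, not this repo's code):

import torch

class OldStyleCache:
    _mask_cache = {}  # class-level dict: never evicted, one entry per distinct key

    @classmethod
    def get_mask(cls, s: int, cu_seqlens: tuple) -> torch.Tensor:
        key = (s, cu_seqlens)
        if key not in cls._mask_cache:
            cls._mask_cache[key] = torch.zeros(1, 1, s, s, dtype=torch.bool)
        return cls._mask_cache[key]

# Every distinct sequence length pins another (1, 1, s, s) tensor that is never released.
for s in range(256, 4096, 256):
    OldStyleCache.get_mask(s, (0, s))
print(len(OldStyleCache._mask_cache))  # 15 cached masks and counting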
@@ -239,76 +237,62 @@ class VisionSdpaAttention(nn.Module):
         self.use_full_precision_softmax = use_full_precision_softmax
         self.dropout = dropout
 
-    def generate_patch_attention_mask(
-        self,
-        s: int,
-        bsz: int,
-        device,
-        cu_seqlens: Optional[torch.Tensor],
-        flatten_batch: bool = False,
-        dtype=torch.bfloat16,
-    ) -> torch.Tensor:
-        r"""
-        Creates a non-causal 4D mask of shape `(b, 1, s, s)` or `(1, 1, s, s)`.
-
-        When `flatten_batch` is True:
-            - All sequences in the batch are flattened into a single dimension
-            - `s` represents the total number of tokens across all sequences in the batch
-            - Returns a unified mask of shape `(1, 1, s, s)`
-
-        When `flatten_batch` is False:
-            - Each sequence has its own attention mask
-            - `s` represents the maximum sequence length in the batch
-            - Returns separate masks of shape `(b, 1, s, s)`
-
-        Args:
-            flatten_batch: (bool):
-                If True, treats all sequences in the batch as a single flattened sequence
-                If False, generates separate masks for each sequence
-
-        Returns:
-            Tensor of shape `(b, 1, s, s)` or `(1, 1, s, s)`.
+    @staticmethod
+    @lru_cache(maxsize=128)
+    def _generate_mask_cache(
+        s: int, flatten_batch: bool, cu_seqlens: tuple
+    ) -> torch.BoolTensor:
+        """
+        Generate a boolean attention mask with caching mechanism.
+        Args:
+            s: sequence length
+            flatten_batch: whether to flatten batch dimension
+            cu_seqlens: tuple of cumulative sequence lengths
+        Returns:
+            attention mask tensor
         """
-
-        cache_key = (s, bsz, flatten_batch, tuple(cu_seqlens.cpu().tolist()))
-
-        if cache_key in VisionSdpaAttention._mask_cache:
-            cached_mask = VisionSdpaAttention._mask_cache[cache_key]
-            # print(f"cache hit for key: {cache_key}")
-            return cached_mask.to(device=device, dtype=dtype)
-
-        if cu_seqlens is None:
-            raise ValueError("Internal Error: cu_seqlens cannot be None")
-
         if flatten_batch:
-            mask = torch.zeros([1, s, s], device=device, dtype=torch.bool)
+            mask = torch.zeros([1, s, s], dtype=torch.bool)
             for i in range(1, len(cu_seqlens)):
                 start = cu_seqlens[i - 1]
                 end = cu_seqlens[i]
-                mask[
-                    ...,
-                    start:end,
-                    start:end,
-                ] = True
+                mask[..., start:end, start:end] = True
         else:
             # [1, 1, 1, s]
-            row_indices = torch.arange(s, device=device).view(1, 1, 1, s)
+            row_indices = torch.arange(s).view(1, 1, 1, s)
             # [1, 1, s, 1]
-            col_indices = torch.arange(s, device=device).view(1, 1, s, 1)
+            col_indices = torch.arange(s).view(1, 1, s, 1)
             # [b, 1, 1, 1]
-            seq_lens = (
-                (cu_seqlens[1:] - cu_seqlens[:-1]).to(device=device).view(-1, 1, 1, 1)
-            )
+            seq_lens = torch.tensor(
+                [end - start for start, end in zip(cu_seqlens[:-1], cu_seqlens[1:])],
+            ).view(-1, 1, 1, 1)
 
             mask = (row_indices < seq_lens) & (col_indices < seq_lens)
 
-        # Convert to attention mask format (False -> 0, True -> -inf)
-        mask = (~mask).to(dtype) * torch.finfo(dtype).min
-
-        VisionSdpaAttention._mask_cache[cache_key] = mask
-
         return mask
 
+    def generate_patch_attention_mask(
+        self,
+        s: int,
+        cu_seqlens: Optional[torch.Tensor],
+        flatten_batch: bool = False,
+    ) -> Optional[torch.Tensor]:
+        r"""
+        Creates a non-causal 4D mask of shape `(b, 1, s, s)` or `(1, 1, s, s)`.
+        Args:
+            s: sequence length
+            cu_seqlens: cumulative sequence lengths tensor. If not, returns an empty mask
+            flatten_batch: whether to flatten batch dimension
+        Returns:
+            attention mask tensor or None
+        """
+        if cu_seqlens is None:
+            return None
+
+        cu_seqlens_tuple = tuple(cu_seqlens.cpu().tolist())
+
+        return self._generate_mask_cache(s, flatten_batch, cu_seqlens_tuple)
+
     def forward(
         self,
         q: torch.Tensor,
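Note: the refactor splits mask handling in two. generate_patch_attention_mask normalizes cu_seqlens into a plain tuple (a torch.Tensor is not hashable by value, so it cannot serve as an lru_cache key), and the lru_cache-wrapped static method builds a boolean block-diagonal mask from those cumulative lengths. A small standalone sketch of the flatten_batch=True construction (same idea, outside the class):

import torch

def block_diagonal_mask(s: int, cu_seqlens: tuple) -> torch.Tensor:
    # True where the query and key tokens belong to the same image, False elsewhere.
    mask = torch.zeros(1, s, s, dtype=torch.bool)
    for start, end in zip(cu_seqlens[:-1], cu_seqlens[1:]):
        mask[..., start:end, start:end] = True
    return mask

# Two images of 3 and 2 patches flattened into one sequence of 5 tokens.
m = block_diagonal_mask(5, (0, 3, 5))
print(m[0].int())
# tensor([[1, 1, 1, 0, 0],
#         [1, 1, 1, 0, 0],
#         [1, 1, 1, 0, 0],
#         [0, 0, 0, 1, 1],
#         [0, 0, 0, 1, 1]])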
@@ -330,15 +314,23 @@ class VisionSdpaAttention(nn.Module):
         # [b, 1, s, s]
         if attention_mask is None:
             attention_mask = self.generate_patch_attention_mask(
-                s, bsz, q.device, cu_seqlens, self.flatten_batch, q.dtype
+                s, cu_seqlens, flatten_batch=self.flatten_batch
             )
 
+        if attention_mask is None:
+            if self.use_full_precision_softmax:
+                raise RuntimeError("Empty attention mask")
+        else:
+            attention_mask = attention_mask.to(device=q.device)
+
         q, k, v = [rearrange(x, "(b s) h d -> b h s d", b=bsz) for x in [q, k, v]]
         # [b, 1, s]
 
         if self.use_full_precision_softmax:
             scale = self.head_size**-0.5
             k_transposed = rearrange(k, "b h s d -> b h d s")
             attn_weights = torch.matmul(q, k_transposed) * scale
             del k, k_transposed
+            attention_mask = (~attention_mask) * torch.finfo(q.dtype).min
             attn_weights = attn_weights + attention_mask
             del attention_mask
             # full-precision
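Note: since the cached mask is now boolean rather than pre-baked additive (the old (~mask).to(dtype) * torch.finfo(dtype).min step no longer happens inside the cache), the full-precision-softmax branch converts it on the fly: allowed positions become 0 and blocked positions become the dtype's most negative value, which softmax then drives to essentially zero probability. A tiny numeric check of that conversion (illustrative values only):

import torch

allowed = torch.tensor([[True, True, False]])           # boolean mask: True = attend
additive = (~allowed) * torch.finfo(torch.float32).min  # [0, 0, -3.4028e+38]
scores = torch.tensor([[1.0, 2.0, 5.0]])
probs = torch.softmax(scores + additive, dim=-1)
print(probs)  # masked position gets ~0 weight: tensor([[0.2689, 0.7311, 0.0000]])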
@@ -354,7 +346,12 @@
             # SDPA
             # [b, h, s, head_size]
             output = F.scaled_dot_product_attention(
-                q, k, v, attention_mask, dropout_p=self.dropout
+                q,
+                k,
+                v,
+                attn_mask=attention_mask,
+                dropout_p=self.dropout,
+                is_causal=False,
             )
 
         # [b, h, s, head_size] --> [b * s, h, head_size]
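Note: with a boolean mask, F.scaled_dot_product_attention can consume it directly through attn_mask: True marks positions that are allowed to attend, False positions are excluded, and is_causal=False makes the non-causal intent explicit rather than relying on the default. A minimal sketch of the call with toy shapes (hypothetical dimensions, not this repo's code):

import torch
import torch.nn.functional as F

b, h, s, d = 1, 2, 5, 8
q = torch.randn(b, h, s, d)
k = torch.randn(b, h, s, d)
v = torch.randn(b, h, s, d)

# Boolean block-diagonal mask, broadcast over heads: True = attend, False = masked out.
attn_mask = torch.zeros(b, 1, s, s, dtype=torch.bool)
attn_mask[..., 0:3, 0:3] = True
attn_mask[..., 3:5, 3:5] = True

out = F.scaled_dot_product_attention(
    q, k, v, attn_mask=attn_mask, dropout_p=0.0, is_causal=False
)
print(out.shape)  # torch.Size([1, 2, 5, 8])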
@@ -380,7 +377,6 @@ class VisionTritonAttention(nn.Module):
         v: torch.Tensor,
-        _bsz: int,
         cu_seqlens: Optional[torch.Tensor],
         **kwargs,
     ) -> torch.Tensor:
         r"""
         Args:
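Note: putting the pieces together, a caller only needs cumulative sequence lengths, and repeating a layout now hits the bounded cache instead of allocating a fresh mask. A hypothetical usage sketch (the constructor arguments of VisionSdpaAttention are not shown in this diff, so the instance calls are left as comments):

import torch
import torch.nn.functional as F

# Per-image patch counts -> cumulative lengths -> the hashable key the lru_cache sees.
seq_lens = torch.tensor([3, 2])                      # two images: 3 and 2 patches
cu_seqlens = F.pad(seq_lens.cumsum(0), (1, 0))       # tensor([0, 3, 5])
cu_seqlens_tuple = tuple(cu_seqlens.cpu().tolist())  # (0, 3, 5)

# With a constructed VisionSdpaAttention instance, repeated calls with the same
# layout would be served from the cache:
#   mask = attn.generate_patch_attention_mask(5, cu_seqlens, flatten_batch=True)
#   VisionSdpaAttention._generate_mask_cache.cache_info()  # bounded at maxsize=128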