fix: apply cache size limit of attention mask for VisionAttention (#3657)
This commit is contained in:
@@ -1,5 +1,6 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from functools import lru_cache
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -223,9 +224,6 @@ class VisionSdpaAttention(nn.Module):
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# TODO: Should it be released after used?
|
|
||||||
_mask_cache = {}
|
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
head_size: int,
|
head_size: int,
|
||||||
@@ -239,76 +237,62 @@ class VisionSdpaAttention(nn.Module):
|
|||||||
self.use_full_precision_softmax = use_full_precision_softmax
|
self.use_full_precision_softmax = use_full_precision_softmax
|
||||||
self.dropout = dropout
|
self.dropout = dropout
|
||||||
|
|
||||||
def generate_patch_attention_mask(
|
@staticmethod
|
||||||
self,
|
@lru_cache(maxsize=128)
|
||||||
s: int,
|
def _generate_mask_cache(
|
||||||
bsz: int,
|
s: int, flatten_batch: bool, cu_seqlens: tuple
|
||||||
device,
|
) -> torch.BoolTensor:
|
||||||
cu_seqlens: Optional[torch.Tensor],
|
"""
|
||||||
flatten_batch: bool = False,
|
Generate a boolean attention mask with caching mechanism.
|
||||||
dtype=torch.bfloat16,
|
Args:
|
||||||
) -> torch.Tensor:
|
s: sequence length
|
||||||
r"""
|
flatten_batch: whether to flatten batch dimension
|
||||||
Creates a non-causal 4D mask of shape `(b, 1, s, s)` or `(1, 1, s, s)`.
|
cu_seqlens: tuple of cumulative sequence lengths
|
||||||
|
Returns:
|
||||||
When `flatten_batch` is True:
|
attention mask tensor
|
||||||
- All sequences in the batch are flattened into a single dimension
|
|
||||||
- `s` represents the total number of tokens across all sequences in the batch
|
|
||||||
- Returns a unified mask of shape `(1, 1, s, s)`
|
|
||||||
|
|
||||||
When `flatten_batch` is False:
|
|
||||||
- Each sequence has its own attention mask
|
|
||||||
- `s` represents the maximum sequence length in the batch
|
|
||||||
- Returns separate masks of shape `(b, 1, s, s)`
|
|
||||||
|
|
||||||
Args:
|
|
||||||
flatten_batch: (bool):
|
|
||||||
If True, treats all sequences in the batch as a single flattened sequence
|
|
||||||
If False, generates separate masks for each sequence
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tensor of shape `(b, 1, s, s)` or `(1, 1, s, s)`.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
cache_key = (s, bsz, flatten_batch, tuple(cu_seqlens.cpu().tolist()))
|
|
||||||
|
|
||||||
if cache_key in VisionSdpaAttention._mask_cache:
|
|
||||||
cached_mask = VisionSdpaAttention._mask_cache[cache_key]
|
|
||||||
# print(f"cache hit for key: {cache_key}")
|
|
||||||
return cached_mask.to(device=device, dtype=dtype)
|
|
||||||
|
|
||||||
if cu_seqlens is None:
|
|
||||||
raise ValueError("Internal Error: cu_seqlens cannot be None")
|
|
||||||
|
|
||||||
if flatten_batch:
|
if flatten_batch:
|
||||||
mask = torch.zeros([1, s, s], device=device, dtype=torch.bool)
|
mask = torch.zeros([1, s, s], dtype=torch.bool)
|
||||||
for i in range(1, len(cu_seqlens)):
|
for i in range(1, len(cu_seqlens)):
|
||||||
start = cu_seqlens[i - 1]
|
start = cu_seqlens[i - 1]
|
||||||
end = cu_seqlens[i]
|
end = cu_seqlens[i]
|
||||||
mask[
|
mask[..., start:end, start:end] = True
|
||||||
...,
|
|
||||||
start:end,
|
|
||||||
start:end,
|
|
||||||
] = True
|
|
||||||
else:
|
else:
|
||||||
# [1, 1, 1, s]
|
# [1, 1, 1, s]
|
||||||
row_indices = torch.arange(s, device=device).view(1, 1, 1, s)
|
row_indices = torch.arange(s).view(1, 1, 1, s)
|
||||||
# [1, 1, s, 1]
|
# [1, 1, s, 1]
|
||||||
col_indices = torch.arange(s, device=device).view(1, 1, s, 1)
|
col_indices = torch.arange(s).view(1, 1, s, 1)
|
||||||
# [b, 1, 1, 1]
|
# [b, 1, 1, 1]
|
||||||
seq_lens = (
|
seq_lens = torch.tensor(
|
||||||
(cu_seqlens[1:] - cu_seqlens[:-1]).to(device=device).view(-1, 1, 1, 1)
|
[end - start for start, end in zip(cu_seqlens[:-1], cu_seqlens[1:])],
|
||||||
)
|
).view(-1, 1, 1, 1)
|
||||||
|
|
||||||
mask = (row_indices < seq_lens) & (col_indices < seq_lens)
|
mask = (row_indices < seq_lens) & (col_indices < seq_lens)
|
||||||
|
|
||||||
# Convert to additive attention mask format (valid positions, True -> 0; masked positions, False -> most negative finite value of dtype)
|
|
||||||
mask = (~mask).to(dtype) * torch.finfo(dtype).min
|
|
||||||
|
|
||||||
VisionSdpaAttention._mask_cache[cache_key] = mask
|
|
||||||
|
|
||||||
return mask
|
return mask
|
||||||
|
|
||||||
|
def generate_patch_attention_mask(
|
||||||
|
self,
|
||||||
|
s: int,
|
||||||
|
cu_seqlens: Optional[torch.Tensor],
|
||||||
|
flatten_batch: bool = False,
|
||||||
|
) -> Optional[torch.Tensor]:
|
||||||
|
r"""
|
||||||
|
Creates a non-causal 4D mask of shape `(b, 1, s, s)` or `(1, 1, s, s)`.
|
||||||
|
Args:
|
||||||
|
s: sequence length
|
||||||
|
cu_seqlens: cumulative sequence lengths tensor. If None, no mask is generated and None is returned
|
||||||
|
flatten_batch: whether to flatten batch dimension
|
||||||
|
Returns:
|
||||||
|
attention mask tensor or None
|
||||||
|
"""
|
||||||
|
if cu_seqlens is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
cu_seqlens_tuple = tuple(cu_seqlens.cpu().tolist())
|
||||||
|
|
||||||
|
return self._generate_mask_cache(s, flatten_batch, cu_seqlens_tuple)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
q: torch.Tensor,
|
q: torch.Tensor,
|
||||||
@@ -330,15 +314,23 @@ class VisionSdpaAttention(nn.Module):
|
|||||||
# [b, 1, s, s]
|
# [b, 1, s, s]
|
||||||
if attention_mask is None:
|
if attention_mask is None:
|
||||||
attention_mask = self.generate_patch_attention_mask(
|
attention_mask = self.generate_patch_attention_mask(
|
||||||
s, bsz, q.device, cu_seqlens, self.flatten_batch, q.dtype
|
s, cu_seqlens, flatten_batch=self.flatten_batch
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if attention_mask is None:
|
||||||
|
if self.use_full_precision_softmax:
|
||||||
|
raise RuntimeError("Empty attention mask")
|
||||||
|
else:
|
||||||
|
attention_mask = attention_mask.to(device=q.device)
|
||||||
|
|
||||||
q, k, v = [rearrange(x, "(b s) h d -> b h s d", b=bsz) for x in [q, k, v]]
|
q, k, v = [rearrange(x, "(b s) h d -> b h s d", b=bsz) for x in [q, k, v]]
|
||||||
# [b, 1, s]
|
|
||||||
if self.use_full_precision_softmax:
|
if self.use_full_precision_softmax:
|
||||||
scale = self.head_size**-0.5
|
scale = self.head_size**-0.5
|
||||||
k_transposed = rearrange(k, "b h s d -> b h d s")
|
k_transposed = rearrange(k, "b h s d -> b h d s")
|
||||||
attn_weights = torch.matmul(q, k_transposed) * scale
|
attn_weights = torch.matmul(q, k_transposed) * scale
|
||||||
del k, k_transposed
|
del k, k_transposed
|
||||||
|
attention_mask = (~attention_mask) * torch.finfo(q.dtype).min
|
||||||
attn_weights = attn_weights + attention_mask
|
attn_weights = attn_weights + attention_mask
|
||||||
del attention_mask
|
del attention_mask
|
||||||
# full-precision
|
# full-precision
|
||||||
@@ -354,7 +346,12 @@ class VisionSdpaAttention(nn.Module):
|
|||||||
# SDPA
|
# SDPA
|
||||||
# [b, h, s, head_size]
|
# [b, h, s, head_size]
|
||||||
output = F.scaled_dot_product_attention(
|
output = F.scaled_dot_product_attention(
|
||||||
q, k, v, attention_mask, dropout_p=self.dropout
|
q,
|
||||||
|
k,
|
||||||
|
v,
|
||||||
|
attn_mask=attention_mask,
|
||||||
|
dropout_p=self.dropout,
|
||||||
|
is_causal=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
# [b, h, s, head_size] --> [b * s, h, head_size]
|
# [b, h, s, head_size] --> [b * s, h, head_size]
|
||||||
@@ -380,7 +377,6 @@ class VisionTritonAttention(nn.Module):
|
|||||||
v: torch.Tensor,
|
v: torch.Tensor,
|
||||||
_bsz: int,
|
_bsz: int,
|
||||||
cu_seqlens: Optional[torch.Tensor],
|
cu_seqlens: Optional[torch.Tensor],
|
||||||
**kwargs,
|
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
r"""
|
r"""
|
||||||
Args:
|
Args:
|
||||||
|
|||||||
Reference in New Issue
Block a user