From 99c1b9d2ee5c8b76434b37afa5e2a9f177de78e7 Mon Sep 17 00:00:00 2001
From: Mick
Date: Wed, 19 Feb 2025 20:16:48 +0800
Subject: [PATCH] fix: apply cache size limit of attention mask for VisionAttention (#3657)

---
 python/sglang/srt/layers/attention/vision.py | 122 +++++++++----------
 1 file changed, 59 insertions(+), 63 deletions(-)

diff --git a/python/sglang/srt/layers/attention/vision.py b/python/sglang/srt/layers/attention/vision.py
index 03c4cfb46..0377277e6 100644
--- a/python/sglang/srt/layers/attention/vision.py
+++ b/python/sglang/srt/layers/attention/vision.py
@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+from functools import lru_cache
 from typing import Optional
 
 import torch
@@ -223,9 +224,6 @@ class VisionSdpaAttention(nn.Module):
 
     """
 
-    # TODO: Should it be released after used?
-    _mask_cache = {}
-
     def __init__(
         self,
         head_size: int,
@@ -239,76 +237,62 @@ class VisionSdpaAttention(nn.Module):
         self.use_full_precision_softmax = use_full_precision_softmax
         self.dropout = dropout
 
-    def generate_patch_attention_mask(
-        self,
-        s: int,
-        bsz: int,
-        device,
-        cu_seqlens: Optional[torch.Tensor],
-        flatten_batch: bool = False,
-        dtype=torch.bfloat16,
-    ) -> torch.Tensor:
-        r"""
-        Creates a non-causal 4D mask of shape `(b, 1, s, s)` or `(1, 1, s, s)`.
-
-        When `flatten_batch` is True:
-            - All sequences in the batch are flattened into a single dimension
-            - `s` represents the total number of tokens across all sequences in the batch
-            - Returns a unified mask of shape `(1, 1, s, s)`
-
-        When `flatten_batch` is False:
-            - Each sequence has its own attention mask
-            - `s` represents the maximum sequence length in the batch
-            - Returns separate masks of shape `(b, 1, s, s)`
-
-        Args:
-            flatten_batch: (bool):
-                If True, treats all sequences in the batch as a single flattened sequence
-                If False, generates separate masks for each sequence
-
-        Returns:
-            Tensor of shape `(b, 1, s, s)` or `(1, 1, s, s)`.
+    @staticmethod
+    @lru_cache(maxsize=128)
+    def _generate_mask_cache(
+        s: int, flatten_batch: bool, cu_seqlens: tuple
+    ) -> torch.BoolTensor:
+        """
+        Generate a boolean attention mask with caching mechanism.
+        Args:
+            s: sequence length
+            flatten_batch: whether to flatten batch dimension
+            cu_seqlens: tuple of cumulative sequence lengths
+        Returns:
+            attention mask tensor
         """
-
-        cache_key = (s, bsz, flatten_batch, tuple(cu_seqlens.cpu().tolist()))
-
-        if cache_key in VisionSdpaAttention._mask_cache:
-            cached_mask = VisionSdpaAttention._mask_cache[cache_key]
-            # print(f"cache hit for key: {cache_key}")
-            return cached_mask.to(device=device, dtype=dtype)
-
-        if cu_seqlens is None:
-            raise ValueError("Internal Error: cu_seqlens cannot be None")
-
         if flatten_batch:
-            mask = torch.zeros([1, s, s], device=device, dtype=torch.bool)
+            mask = torch.zeros([1, s, s], dtype=torch.bool)
             for i in range(1, len(cu_seqlens)):
                 start = cu_seqlens[i - 1]
                 end = cu_seqlens[i]
-                mask[
-                    ...,
-                    start:end,
-                    start:end,
-                ] = True
+                mask[..., start:end, start:end] = True
         else:
             # [1, 1, 1, s]
-            row_indices = torch.arange(s, device=device).view(1, 1, 1, s)
+            row_indices = torch.arange(s).view(1, 1, 1, s)
             # [1, 1, s, 1]
-            col_indices = torch.arange(s, device=device).view(1, 1, s, 1)
+            col_indices = torch.arange(s).view(1, 1, s, 1)
             # [b, 1, 1, 1]
-            seq_lens = (
-                (cu_seqlens[1:] - cu_seqlens[:-1]).to(device=device).view(-1, 1, 1, 1)
-            )
+            seq_lens = torch.tensor(
+                [end - start for start, end in zip(cu_seqlens[:-1], cu_seqlens[1:])],
+            ).view(-1, 1, 1, 1)
             mask = (row_indices < seq_lens) & (col_indices < seq_lens)
 
-        # Convert to attention mask format (False -> 0, True -> -inf)
-        mask = (~mask).to(dtype) * torch.finfo(dtype).min
-
-        VisionSdpaAttention._mask_cache[cache_key] = mask
         return mask
 
+    def generate_patch_attention_mask(
+        self,
+        s: int,
+        cu_seqlens: Optional[torch.Tensor],
+        flatten_batch: bool = False,
+    ) -> Optional[torch.Tensor]:
+        r"""
+        Creates a non-causal 4D mask of shape `(b, 1, s, s)` or `(1, 1, s, s)`.
+        Args:
+            s: sequence length
+            cu_seqlens: cumulative sequence lengths tensor; if None, returns None
+            flatten_batch: whether to flatten batch dimension
+        Returns:
+            attention mask tensor or None
+        """
+        if cu_seqlens is None:
+            return None
+
+        cu_seqlens_tuple = tuple(cu_seqlens.cpu().tolist())
+
+        return self._generate_mask_cache(s, flatten_batch, cu_seqlens_tuple)
+
     def forward(
         self,
         q: torch.Tensor,
@@ -330,15 +314,23 @@
         # [b, 1, s, s]
         if attention_mask is None:
             attention_mask = self.generate_patch_attention_mask(
-                s, bsz, q.device, cu_seqlens, self.flatten_batch, q.dtype
+                s, cu_seqlens, flatten_batch=self.flatten_batch
            )
+
+        if attention_mask is None:
+            if self.use_full_precision_softmax:
+                raise RuntimeError("Empty attention mask")
+        else:
+            attention_mask = attention_mask.to(device=q.device)
+
         q, k, v = [rearrange(x, "(b s) h d -> b h s d", b=bsz) for x in [q, k, v]]
-        # [b, 1, s]
+
         if self.use_full_precision_softmax:
             scale = self.head_size**-0.5
             k_transposed = rearrange(k, "b h s d -> b h d s")
             attn_weights = torch.matmul(q, k_transposed) * scale
             del k, k_transposed
+            attention_mask = (~attention_mask) * torch.finfo(q.dtype).min
             attn_weights = attn_weights + attention_mask
             del attention_mask
             # full-precision
@@ -354,7 +346,12 @@
             # SDPA
             # [b, h, s, head_size]
             output = F.scaled_dot_product_attention(
-                q, k, v, attention_mask, dropout_p=self.dropout
+                q,
+                k,
+                v,
+                attn_mask=attention_mask,
+                dropout_p=self.dropout,
+                is_causal=False,
             )
 
         # [b, h, s, head_size] --> [b * s, h, head_size]
@@ -380,7 +377,6 @@ class VisionTritonAttention(nn.Module):
         v: torch.Tensor,
         _bsz: int,
         cu_seqlens: Optional[torch.Tensor],
-        **kwargs,
     ) -> torch.Tensor:
         r"""
         Args:
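
The core of this change is swapping the unbounded class-level `_mask_cache` dict for a `functools.lru_cache(maxsize=128)` over a static method, keyed on hashable arguments (`s`, `flatten_batch`, and `cu_seqlens` converted to a tuple). The cached mask stays a device-agnostic boolean tensor; `forward` moves it to the query's device and only converts it to an additive `-inf` mask on the full-precision-softmax path. The sketch below reproduces that pattern outside sglang for readers who want to try it; `build_block_diagonal_mask` and the toy shapes are illustrative assumptions, not sglang APIs, and only `functools.lru_cache` and `torch.nn.functional.scaled_dot_product_attention` are real library calls.

from functools import lru_cache
from typing import Tuple

import torch
import torch.nn.functional as F


@lru_cache(maxsize=128)
def build_block_diagonal_mask(s: int, cu_seqlens: Tuple[int, ...]) -> torch.Tensor:
    # Boolean [1, s, s] mask: True only where query and key belong to the same
    # sequence (the `flatten_batch` case in the patch). Cached by (s, cu_seqlens),
    # so repeated shapes reuse the same tensor and at most 128 entries are kept.
    mask = torch.zeros(1, s, s, dtype=torch.bool)
    for start, end in zip(cu_seqlens[:-1], cu_seqlens[1:]):
        mask[..., start:end, start:end] = True
    return mask


if __name__ == "__main__":
    # Two flattened sequences of lengths 3 and 2 -> s = 5 tokens total.
    cu_seqlens = torch.tensor([0, 3, 5])
    key = tuple(cu_seqlens.tolist())  # lru_cache needs hashable arguments
    mask = build_block_diagonal_mask(5, key)
    assert build_block_diagonal_mask(5, key) is mask  # second call is a cache hit

    q = k = v = torch.randn(1, 2, 5, 8)  # [batch, heads, tokens, head_dim]
    # SDPA accepts the boolean mask directly (True = "may attend"); move it to
    # q's device only at use time, as the patched forward() does.
    out = F.scaled_dot_product_attention(
        q, k, v, attn_mask=mask.unsqueeze(0).to(q.device), is_causal=False
    )
    print(out.shape)  # torch.Size([1, 2, 5, 8])

Compared with the old dict, the important differences are the size bound (at most 128 cached shapes instead of unbounded growth) and the fact that the cached value is a boolean mask, so the dtype-specific additive conversion no longer needs to be stored per cache entry.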