refactor: minor refactors regarding multimodal processing (#6187)
This commit is contained in:
@@ -120,7 +120,7 @@ class VisionSdpaAttention(nn.Module):
|
||||
flatten_batch: bool = False,
|
||||
) -> Optional[torch.Tensor]:
|
||||
r"""
|
||||
Creates a non-causal 4D mask of shape `(b, 1, s, s)` or `(1, s, s)`.
|
||||
Creates a non-causal 4D mask of shape `(b, 1, s, s)` or `(1, 1, s, s)`.
|
||||
Args:
|
||||
s: sequence length
|
||||
cu_seqlens: cumulative sequence lengths tensor. If not provided, returns an empty mask
|
||||
|
||||
Reference in New Issue
Block a user