refactor: minor refactors regarding multimodal processing (#6187)

This commit is contained in:
Mick
2025-05-18 13:53:20 +08:00
committed by GitHub
parent b3f3d610fd
commit 01dd39bac1
15 changed files with 140 additions and 98 deletions

View File

@@ -120,7 +120,7 @@ class VisionSdpaAttention(nn.Module):
flatten_batch: bool = False,
) -> Optional[torch.Tensor]:
r"""
Creates a non-causal 4D mask of shape `(b, 1, s, s)` or `(1, s, s)`.
Creates a non-causal 4D mask of shape `(b, 1, s, s)` or `(1, 1, s, s)`.
Args:
s: sequence length
cu_seqlens: cumulative sequence lengths tensor. If not, returns an empty mask