Add an assertion to enhance the robustness of the operator (#5736)

This commit is contained in:
liwenju0
2025-04-27 09:09:12 +08:00
committed by GitHub
parent 155890e4d1
commit 4d1e52abea

View File

@@ -271,6 +271,8 @@ class VisionSdpaAttention(nn.Module):
Returns:
[b * s, h, head_size]
"""
if self.flatten_batch:
assert bsz == 1, "flatten_batch is True, bsz must be 1"
s = q.shape[0] // bsz