Add an assertion to enhance the robustness of the operator (#5736)
This commit is contained in:
@@ -271,6 +271,8 @@ class VisionSdpaAttention(nn.Module):
|
||||
Returns:
|
||||
[b * s, h, head_size]
|
||||
"""
|
||||
if self.flatten_batch:
|
||||
assert bsz == 1, "flatten_batch is True, bsz must be 1"
|
||||
|
||||
s = q.shape[0] // bsz
|
||||
|
||||
|
||||
Reference in New Issue
Block a user