Add an assertion to enhance the robustness of the operator (#5736)
This commit is contained in:
@@ -271,6 +271,8 @@ class VisionSdpaAttention(nn.Module):
|
|||||||
Returns:
|
Returns:
|
||||||
[b * s, h, head_size]
|
[b * s, h, head_size]
|
||||||
"""
|
"""
|
||||||
|
if self.flatten_batch:
|
||||||
|
assert bsz == 1, "flatten_batch is True, bsz must be 1"
|
||||||
|
|
||||||
s = q.shape[0] // bsz
|
s = q.shape[0] // bsz
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user