diff --git a/python/sglang/srt/layers/attention/vision.py b/python/sglang/srt/layers/attention/vision.py index 860994913..d65104beb 100644 --- a/python/sglang/srt/layers/attention/vision.py +++ b/python/sglang/srt/layers/attention/vision.py @@ -271,6 +271,8 @@ class VisionSdpaAttention(nn.Module): Returns: [b * s, h, head_size] """ + if self.flatten_batch: + assert bsz == 1, "flatten_batch is True, bsz must be 1" s = q.shape[0] // bsz