diff --git a/python/sglang/srt/layers/attention/flashattention_backend.py b/python/sglang/srt/layers/attention/flashattention_backend.py index 9579b19f2..c148ac159 100644 --- a/python/sglang/srt/layers/attention/flashattention_backend.py +++ b/python/sglang/srt/layers/attention/flashattention_backend.py @@ -338,7 +338,7 @@ class FlashAttentionBackend(AttentionBackend): """Initialize forward metadata hence all layers in the forward pass can reuse it.""" metadata = FlashAttentionMetadata() seqlens_in_batch = forward_batch.seq_lens - batch_size = len(seqlens_in_batch) + batch_size = forward_batch.batch_size device = seqlens_in_batch.device if forward_batch.forward_mode.is_decode_or_idle():