From 1acca3a2c685221cdb181c2abda4f635e1ead435 Mon Sep 17 00:00:00 2001
From: Lifu Huang
Date: Fri, 2 May 2025 00:26:12 -0700
Subject: [PATCH] FA3 speed up: skip len operation and get batch size directly
 from forward batch (#5969)

Signed-off-by: Lifu Huang
---
 python/sglang/srt/layers/attention/flashattention_backend.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python/sglang/srt/layers/attention/flashattention_backend.py b/python/sglang/srt/layers/attention/flashattention_backend.py
index 9579b19f2..c148ac159 100644
--- a/python/sglang/srt/layers/attention/flashattention_backend.py
+++ b/python/sglang/srt/layers/attention/flashattention_backend.py
@@ -338,7 +338,7 @@ class FlashAttentionBackend(AttentionBackend):
         """Initialize forward metadata hence all layers in the forward pass can reuse it."""
         metadata = FlashAttentionMetadata()
         seqlens_in_batch = forward_batch.seq_lens
-        batch_size = len(seqlens_in_batch)
+        batch_size = forward_batch.batch_size
         device = seqlens_in_batch.device
 
         if forward_batch.forward_mode.is_decode_or_idle():