FA3 speed up: skip len operation and get batch size directly from forward batch (#5969)

Signed-off-by: Lifu Huang <lifu.hlf@gmail.com>
This commit is contained in:
Lifu Huang
2025-05-02 00:26:12 -07:00
committed by GitHub
parent 6ea1e6ac6e
commit 1acca3a2c6

View File

@@ -338,7 +338,7 @@ class FlashAttentionBackend(AttentionBackend):
"""Initialize forward metadata hence all layers in the forward pass can reuse it."""
metadata = FlashAttentionMetadata()
seqlens_in_batch = forward_batch.seq_lens
batch_size = len(seqlens_in_batch)
batch_size = forward_batch.batch_size
device = seqlens_in_batch.device
if forward_batch.forward_mode.is_decode_or_idle():