FA3 speed up: skip len operation and get batch size directly from forward batch (#5969)
Signed-off-by: Lifu Huang <lifu.hlf@gmail.com>
This commit is contained in:
@@ -338,7 +338,7 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
"""Initialize forward metadata hence all layers in the forward pass can reuse it."""
|
||||
metadata = FlashAttentionMetadata()
|
||||
seqlens_in_batch = forward_batch.seq_lens
|
||||
- batch_size = len(seqlens_in_batch)
|
||||
+ batch_size = forward_batch.batch_size
|
||||
device = seqlens_in_batch.device
|
||||
|
||||
if forward_batch.forward_mode.is_decode_or_idle():
|
||||
|
||||
Reference in New Issue
Block a user