FA3 speed up: skip len operation and get batch size directly from forward batch (#5969)
Signed-off-by: Lifu Huang <lifu.hlf@gmail.com>
This commit is contained in:
@@ -338,7 +338,7 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
"""Initialize forward metadata hence all layers in the forward pass can reuse it."""
|
"""Initialize forward metadata hence all layers in the forward pass can reuse it."""
|
||||||
metadata = FlashAttentionMetadata()
|
metadata = FlashAttentionMetadata()
|
||||||
seqlens_in_batch = forward_batch.seq_lens
|
seqlens_in_batch = forward_batch.seq_lens
|
||||||
batch_size = len(seqlens_in_batch)
|
batch_size = forward_batch.batch_size
|
||||||
device = seqlens_in_batch.device
|
device = seqlens_in_batch.device
|
||||||
|
|
||||||
if forward_batch.forward_mode.is_decode_or_idle():
|
if forward_batch.forward_mode.is_decode_or_idle():
|
||||||
|
|||||||
Reference in New Issue
Block a user