Remove unnecessary metadata_expand.max_seq_len_k operations in fa3 to… (#7140)

Author: Binyao Jiang (committed by GitHub)
Date: 2025-06-12 23:25:52 -07:00
parent b02df20a8d
commit 22a6b9fc05

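The title is truncated, so the following reading is an inference rather than part of the commit: each hunk below removes exactly one statement (the `@@` old/new line counts confirm this) that writes metadata_expand.max_seq_len_k, which suggests the fa3 kernel path invoked from this backend no longer reads that field. Two of the removed statements compute the value via cache_seqlens_int32.max().item(), and deleting those also removes a GPU-to-CPU synchronization point. The sketch below is a minimal illustration of that cost, written for this note rather than taken from the commit; the tensor name, shape, and values are made-up stand-ins.

    import torch

    # Hypothetical stand-in for the per-request KV-cache lengths that the
    # backend keeps on the GPU.
    if torch.cuda.is_available():
        cache_seqlens_int32 = torch.randint(
            1, 4096, (256,), dtype=torch.int32, device="cuda"
        )
        # .max() launches a reduction kernel on the GPU; .item() then copies
        # the scalar result to the host, blocking the CPU until every queued
        # kernel has finished.
        max_seq_len_k = cache_seqlens_int32.max().item()
        # If nothing downstream reads max_seq_len_k, both the reduction and
        # the device-to-host sync are pure overhead on the decode hot path.

The other two removed statements assign self.speculative_step_id + 1, a host-side constant, so dropping them is ordinary dead-code cleanup rather than a synchronization fix.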

@@ -394,7 +394,6 @@ class FlashAttentionBackend(AttentionBackend):
                 dtype=torch.int32,
             )
             metadata_expand.max_seq_len_q = 1
-            metadata_expand.max_seq_len_k = self.speculative_step_id + 1
             metadata_expand.cu_seqlens_q = torch.arange(
                 0,
                 metadata_expand.cache_seqlens_int32.numel() + 1,
@@ -550,9 +549,6 @@ class FlashAttentionBackend(AttentionBackend):
                 ),
                 (1, 0),
             )
-            metadata_expand.max_seq_len_k = (
-                metadata_expand.cache_seqlens_int32.max().item()
-            )
             self.forward_metadata_spec_decode_expand = metadata_expand
         elif forward_batch.forward_mode.is_extend_or_draft_extend_or_mixed():
             metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
@@ -1421,9 +1417,6 @@ class FlashAttentionBackend(AttentionBackend):
                 ]
             )
             metadata_expand.max_seq_len_q = 1
-            metadata_expand.max_seq_len_k = (
-                self.speculative_step_id + 1
-            )  # , do this in replay
             metadata_expand.cu_seqlens_q = (
                 self.draft_decode_metadata_topk_expand["cu_seqlens_q"][
                     : bs * self.topk + 1
@@ -1766,9 +1759,6 @@ class FlashAttentionBackend(AttentionBackend):
                     dtype=torch.int32,
                 )
             )
-            metadata_expand.max_seq_len_k = (
-                metadata_expand.cache_seqlens_int32.max().item()
-            )
         elif forward_mode.is_draft_extend():
             metadata = self.draft_extend_metadata[bs]
             metadata.cache_seqlens_int32.copy_(seq_lens)