Remove unnecessary metadata_expand.max_seq_len_k operations in fa3 to… (#7140)
@@ -394,7 +394,6 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
dtype=torch.int32,
|
||||
)
|
||||
metadata_expand.max_seq_len_q = 1
|
||||
metadata_expand.max_seq_len_k = self.speculative_step_id + 1
|
||||
metadata_expand.cu_seqlens_q = torch.arange(
|
||||
0,
|
||||
metadata_expand.cache_seqlens_int32.numel() + 1,
|
||||
@@ -550,9 +549,6 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
),
|
||||
(1, 0),
|
||||
)
|
||||
metadata_expand.max_seq_len_k = (
|
||||
metadata_expand.cache_seqlens_int32.max().item()
|
||||
)
|
||||
self.forward_metadata_spec_decode_expand = metadata_expand
|
||||
elif forward_batch.forward_mode.is_extend_or_draft_extend_or_mixed():
|
||||
metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
|
||||
@@ -1421,9 +1417,6 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
]
|
||||
)
|
||||
metadata_expand.max_seq_len_q = 1
|
||||
metadata_expand.max_seq_len_k = (
|
||||
self.speculative_step_id + 1
|
||||
) # , do this in replay
|
||||
metadata_expand.cu_seqlens_q = (
|
||||
self.draft_decode_metadata_topk_expand["cu_seqlens_q"][
|
||||
: bs * self.topk + 1
|
||||
@@ -1766,9 +1759,6 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
dtype=torch.int32,
|
||||
)
|
||||
)
|
||||
metadata_expand.max_seq_len_k = (
|
||||
metadata_expand.cache_seqlens_int32.max().item()
|
||||
)
|
||||
elif forward_mode.is_draft_extend():
|
||||
metadata = self.draft_extend_metadata[bs]
|
||||
metadata.cache_seqlens_int32.copy_(seq_lens)
|
||||
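A hedged note on the removed lines: Tensor.max().item() on a CUDA tensor launches a reduction kernel and then blocks the host until the scalar is copied back, so eagerly materializing max_seq_len_k is pure overhead whenever the downstream attention call never reads it. The snippet below is a standalone sketch of that cost pattern, not code from this repository; the tensor name, shape, and values are hypothetical.

import torch

# Standalone sketch (not sglang code); names and sizes here are made up.
device = "cuda" if torch.cuda.is_available() else "cpu"
cache_seqlens_int32 = torch.randint(
    1, 4096, (256,), dtype=torch.int32, device=device
)

# The pattern the removed lines used: a device-side reduction followed by
# .item(), which copies the scalar to the host and synchronizes the stream.
max_seq_len_k = cache_seqlens_int32.max().item()

# If no later kernel launch actually consumes max_seq_len_k, both the
# reduction and the forced host/device synchronization are avoidable overhead.
print(max_seq_len_k)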