Remove unnecessary metadata_expand.max_seq_len_k operations in fa3 to… (#7140)
This commit is contained in:
@@ -394,7 +394,6 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
dtype=torch.int32,
|
dtype=torch.int32,
|
||||||
)
|
)
|
||||||
metadata_expand.max_seq_len_q = 1
|
metadata_expand.max_seq_len_q = 1
|
||||||
metadata_expand.max_seq_len_k = self.speculative_step_id + 1
|
|
||||||
metadata_expand.cu_seqlens_q = torch.arange(
|
metadata_expand.cu_seqlens_q = torch.arange(
|
||||||
0,
|
0,
|
||||||
metadata_expand.cache_seqlens_int32.numel() + 1,
|
metadata_expand.cache_seqlens_int32.numel() + 1,
|
||||||
@@ -550,9 +549,6 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
),
|
),
|
||||||
(1, 0),
|
(1, 0),
|
||||||
)
|
)
|
||||||
metadata_expand.max_seq_len_k = (
|
|
||||||
metadata_expand.cache_seqlens_int32.max().item()
|
|
||||||
)
|
|
||||||
self.forward_metadata_spec_decode_expand = metadata_expand
|
self.forward_metadata_spec_decode_expand = metadata_expand
|
||||||
elif forward_batch.forward_mode.is_extend_or_draft_extend_or_mixed():
|
elif forward_batch.forward_mode.is_extend_or_draft_extend_or_mixed():
|
||||||
metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
|
metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
|
||||||
@@ -1421,9 +1417,6 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
]
|
]
|
||||||
)
|
)
|
||||||
metadata_expand.max_seq_len_q = 1
|
metadata_expand.max_seq_len_q = 1
|
||||||
metadata_expand.max_seq_len_k = (
|
|
||||||
self.speculative_step_id + 1
|
|
||||||
) # , do this in replay
|
|
||||||
metadata_expand.cu_seqlens_q = (
|
metadata_expand.cu_seqlens_q = (
|
||||||
self.draft_decode_metadata_topk_expand["cu_seqlens_q"][
|
self.draft_decode_metadata_topk_expand["cu_seqlens_q"][
|
||||||
: bs * self.topk + 1
|
: bs * self.topk + 1
|
||||||
@@ -1766,9 +1759,6 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
dtype=torch.int32,
|
dtype=torch.int32,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
metadata_expand.max_seq_len_k = (
|
|
||||||
metadata_expand.cache_seqlens_int32.max().item()
|
|
||||||
)
|
|
||||||
elif forward_mode.is_draft_extend():
|
elif forward_mode.is_draft_extend():
|
||||||
metadata = self.draft_extend_metadata[bs]
|
metadata = self.draft_extend_metadata[bs]
|
||||||
metadata.cache_seqlens_int32.copy_(seq_lens)
|
metadata.cache_seqlens_int32.copy_(seq_lens)
|
||||||
|
|||||||
Reference in New Issue
Block a user