[FA3] Init Spec Page Table only when Spec is enabled to save ~40MB (#9455)

This commit is contained in:
Stefan He
2025-08-21 15:11:38 -07:00
committed by GitHub
parent 275f9df381
commit cded039b57

View File

@@ -1163,6 +1163,8 @@ class FlashAttentionBackend(AttentionBackend):
This creates fixed-size tensors that will be reused during CUDA graph replay
to avoid memory allocations.
"""
max_num_pages = (self.max_context_len + self.page_size - 1) // self.page_size
# This is being used by normal decode and draft decode when topk == 1
self.decode_cuda_graph_metadata = {
"cache_seqlens": torch.zeros(max_bs, dtype=torch.int32, device=self.device),
@@ -1174,13 +1176,7 @@ class FlashAttentionBackend(AttentionBackend):
),
"page_table": torch.zeros(
max_bs,
(self.max_context_len + self.page_size - 1) // self.page_size,
dtype=torch.int32,
device=self.device,
),
"page_table_draft_decode": torch.zeros(
max_bs,
(self.max_context_len + self.page_size - 1) // self.page_size,
max_num_pages,
dtype=torch.int32,
device=self.device,
),
@@ -1188,7 +1184,6 @@ class FlashAttentionBackend(AttentionBackend):
0, self.max_context_len, self.page_size, device=self.device
),
}
# Only allocate local attention buffers if local attention is enabled
# This prevents OOM errors when local attention is not being used
if self.attention_chunk_size is not None:
@@ -1274,6 +1269,14 @@ class FlashAttentionBackend(AttentionBackend):
self.speculative_num_draft_tokens is not None
and self.speculative_num_draft_tokens > 0
):
# "page_table_draft_decode" will be set only when spec decoding enabled to save memory
self.decode_cuda_graph_metadata["page_table_draft_decode"] = torch.zeros(
max_bs,
max_num_pages,
dtype=torch.int32,
device=self.device,
)
self.target_verify_metadata = {
"cache_seqlens": torch.zeros(
max_bs, dtype=torch.int32, device=self.device
@@ -1290,7 +1293,7 @@ class FlashAttentionBackend(AttentionBackend):
),
"page_table": torch.zeros(
max_bs,
(self.max_context_len + self.page_size - 1) // self.page_size,
max_num_pages,
dtype=torch.int32,
device=self.device,
),
@@ -1313,7 +1316,7 @@ class FlashAttentionBackend(AttentionBackend):
),
"page_table": torch.zeros(
max_bs,
(self.max_context_len + self.page_size - 1) // self.page_size,
max_num_pages,
dtype=torch.int32,
device=self.device,
),