[FA3] Init Spec Page Table only when Spec is enabled to save ~40MB (#9455)
This commit is contained in:
@@ -1163,6 +1163,8 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
This creates fixed-size tensors that will be reused during CUDA graph replay
|
||||
to avoid memory allocations.
|
||||
"""
|
||||
max_num_pages = (self.max_context_len + self.page_size - 1) // self.page_size
|
||||
|
||||
# This is being used by normal decode and draft decode when topk == 1
|
||||
self.decode_cuda_graph_metadata = {
|
||||
"cache_seqlens": torch.zeros(max_bs, dtype=torch.int32, device=self.device),
|
||||
@@ -1174,13 +1176,7 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
),
|
||||
"page_table": torch.zeros(
|
||||
max_bs,
|
||||
(self.max_context_len + self.page_size - 1) // self.page_size,
|
||||
dtype=torch.int32,
|
||||
device=self.device,
|
||||
),
|
||||
"page_table_draft_decode": torch.zeros(
|
||||
max_bs,
|
||||
(self.max_context_len + self.page_size - 1) // self.page_size,
|
||||
max_num_pages,
|
||||
dtype=torch.int32,
|
||||
device=self.device,
|
||||
),
|
||||
@@ -1188,7 +1184,6 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
0, self.max_context_len, self.page_size, device=self.device
|
||||
),
|
||||
}
|
||||
|
||||
# Only allocate local attention buffers if local attention is enabled
|
||||
# This prevents OOM errors when local attention is not being used
|
||||
if self.attention_chunk_size is not None:
|
||||
@@ -1274,6 +1269,14 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
self.speculative_num_draft_tokens is not None
|
||||
and self.speculative_num_draft_tokens > 0
|
||||
):
|
||||
# "page_table_draft_decode" will be set only when spec decoding enabled to save memory
|
||||
self.decode_cuda_graph_metadata["page_table_draft_decode"] = torch.zeros(
|
||||
max_bs,
|
||||
max_num_pages,
|
||||
dtype=torch.int32,
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
self.target_verify_metadata = {
|
||||
"cache_seqlens": torch.zeros(
|
||||
max_bs, dtype=torch.int32, device=self.device
|
||||
@@ -1290,7 +1293,7 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
),
|
||||
"page_table": torch.zeros(
|
||||
max_bs,
|
||||
(self.max_context_len + self.page_size - 1) // self.page_size,
|
||||
max_num_pages,
|
||||
dtype=torch.int32,
|
||||
device=self.device,
|
||||
),
|
||||
@@ -1313,7 +1316,7 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
),
|
||||
"page_table": torch.zeros(
|
||||
max_bs,
|
||||
(self.max_context_len + self.page_size - 1) // self.page_size,
|
||||
max_num_pages,
|
||||
dtype=torch.int32,
|
||||
device=self.device,
|
||||
),
|
||||
|
||||
Reference in New Issue
Block a user