[FA3] Init Spec Page Table only when Spec is enabled to save ~40MB (#9455)
This commit is contained in:
@@ -1163,6 +1163,8 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
This creates fixed-size tensors that will be reused during CUDA graph replay
|
This creates fixed-size tensors that will be reused during CUDA graph replay
|
||||||
to avoid memory allocations.
|
to avoid memory allocations.
|
||||||
"""
|
"""
|
||||||
|
max_num_pages = (self.max_context_len + self.page_size - 1) // self.page_size
|
||||||
|
|
||||||
# This is being used by normal decode and draft decode when topk == 1
|
# This is being used by normal decode and draft decode when topk == 1
|
||||||
self.decode_cuda_graph_metadata = {
|
self.decode_cuda_graph_metadata = {
|
||||||
"cache_seqlens": torch.zeros(max_bs, dtype=torch.int32, device=self.device),
|
"cache_seqlens": torch.zeros(max_bs, dtype=torch.int32, device=self.device),
|
||||||
@@ -1174,13 +1176,7 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
),
|
),
|
||||||
"page_table": torch.zeros(
|
"page_table": torch.zeros(
|
||||||
max_bs,
|
max_bs,
|
||||||
(self.max_context_len + self.page_size - 1) // self.page_size,
|
max_num_pages,
|
||||||
dtype=torch.int32,
|
|
||||||
device=self.device,
|
|
||||||
),
|
|
||||||
"page_table_draft_decode": torch.zeros(
|
|
||||||
max_bs,
|
|
||||||
(self.max_context_len + self.page_size - 1) // self.page_size,
|
|
||||||
dtype=torch.int32,
|
dtype=torch.int32,
|
||||||
device=self.device,
|
device=self.device,
|
||||||
),
|
),
|
||||||
@@ -1188,7 +1184,6 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
0, self.max_context_len, self.page_size, device=self.device
|
0, self.max_context_len, self.page_size, device=self.device
|
||||||
),
|
),
|
||||||
}
|
}
|
||||||
|
|
||||||
# Only allocate local attention buffers if local attention is enabled
|
# Only allocate local attention buffers if local attention is enabled
|
||||||
# This prevents OOM errors when local attention is not being used
|
# This prevents OOM errors when local attention is not being used
|
||||||
if self.attention_chunk_size is not None:
|
if self.attention_chunk_size is not None:
|
||||||
@@ -1274,6 +1269,14 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
self.speculative_num_draft_tokens is not None
|
self.speculative_num_draft_tokens is not None
|
||||||
and self.speculative_num_draft_tokens > 0
|
and self.speculative_num_draft_tokens > 0
|
||||||
):
|
):
|
||||||
|
# "page_table_draft_decode" will be set only when spec decoding enabled to save memory
|
||||||
|
self.decode_cuda_graph_metadata["page_table_draft_decode"] = torch.zeros(
|
||||||
|
max_bs,
|
||||||
|
max_num_pages,
|
||||||
|
dtype=torch.int32,
|
||||||
|
device=self.device,
|
||||||
|
)
|
||||||
|
|
||||||
self.target_verify_metadata = {
|
self.target_verify_metadata = {
|
||||||
"cache_seqlens": torch.zeros(
|
"cache_seqlens": torch.zeros(
|
||||||
max_bs, dtype=torch.int32, device=self.device
|
max_bs, dtype=torch.int32, device=self.device
|
||||||
@@ -1290,7 +1293,7 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
),
|
),
|
||||||
"page_table": torch.zeros(
|
"page_table": torch.zeros(
|
||||||
max_bs,
|
max_bs,
|
||||||
(self.max_context_len + self.page_size - 1) // self.page_size,
|
max_num_pages,
|
||||||
dtype=torch.int32,
|
dtype=torch.int32,
|
||||||
device=self.device,
|
device=self.device,
|
||||||
),
|
),
|
||||||
@@ -1313,7 +1316,7 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
),
|
),
|
||||||
"page_table": torch.zeros(
|
"page_table": torch.zeros(
|
||||||
max_bs,
|
max_bs,
|
||||||
(self.max_context_len + self.page_size - 1) // self.page_size,
|
max_num_pages,
|
||||||
dtype=torch.int32,
|
dtype=torch.int32,
|
||||||
device=self.device,
|
device=self.device,
|
||||||
),
|
),
|
||||||
|
|||||||
Reference in New Issue
Block a user