Save CUDA graph memory for FA3 (#8567)
This commit is contained in:
@@ -1406,7 +1406,7 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
)
|
)
|
||||||
metadata.page_table = self.decode_cuda_graph_metadata[
|
metadata.page_table = self.decode_cuda_graph_metadata[
|
||||||
"page_table_draft_decode"
|
"page_table_draft_decode"
|
||||||
][req_pool_indices, :]
|
][:bs, :]
|
||||||
self.decode_cuda_graph_metadata[bs] = metadata
|
self.decode_cuda_graph_metadata[bs] = metadata
|
||||||
else:
|
else:
|
||||||
# When topk > 1, we need two separate draft-decode metadata objects, whose states are then merged
|
# When topk > 1, we need two separate draft-decode metadata objects, whose states are then merged
|
||||||
@@ -1424,7 +1424,7 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
][: bs + 1]
|
][: bs + 1]
|
||||||
metadata.page_table = self.draft_decode_metadata_topk_normal[
|
metadata.page_table = self.draft_decode_metadata_topk_normal[
|
||||||
"page_table"
|
"page_table"
|
||||||
][req_pool_indices, :]
|
][:bs, :]
|
||||||
|
|
||||||
# 2. The second half of metadata for draft tokens (per_batch_num_tokens = topk)
|
# 2. The second half of metadata for draft tokens (per_batch_num_tokens = topk)
|
||||||
metadata_expand.cache_seqlens_int32 = (
|
metadata_expand.cache_seqlens_int32 = (
|
||||||
@@ -1461,7 +1461,7 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
metadata.max_seq_len_k = seq_lens.max().item()
|
metadata.max_seq_len_k = seq_lens.max().item()
|
||||||
# Precompute page table
|
# Precompute page table
|
||||||
metadata.page_table = self.decode_cuda_graph_metadata["page_table"][
|
metadata.page_table = self.decode_cuda_graph_metadata["page_table"][
|
||||||
req_pool_indices, :
|
:bs, :
|
||||||
]
|
]
|
||||||
# Precompute cumulative sequence lengths
|
# Precompute cumulative sequence lengths
|
||||||
metadata.cu_seqlens_q = torch.arange(
|
metadata.cu_seqlens_q = torch.arange(
|
||||||
@@ -1498,9 +1498,7 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
: (bs + 1)
|
: (bs + 1)
|
||||||
]
|
]
|
||||||
|
|
||||||
metadata.page_table = self.target_verify_metadata["page_table"][
|
metadata.page_table = self.target_verify_metadata["page_table"][:bs, :]
|
||||||
req_pool_indices, :
|
|
||||||
]
|
|
||||||
|
|
||||||
self.target_verify_metadata[bs] = metadata
|
self.target_verify_metadata[bs] = metadata
|
||||||
else:
|
else:
|
||||||
@@ -1519,7 +1517,7 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
][: bs + 1]
|
][: bs + 1]
|
||||||
metadata.page_table = self.target_verify_metadata_topk_normal[
|
metadata.page_table = self.target_verify_metadata_topk_normal[
|
||||||
"page_table"
|
"page_table"
|
||||||
][req_pool_indices, :]
|
][:bs, :]
|
||||||
|
|
||||||
# 2. The second half of metadata for draft tokens (per_batch_num_tokens = topk)
|
# 2. The second half of metadata for draft tokens (per_batch_num_tokens = topk)
|
||||||
metadata_expand.cache_seqlens_int32 = (
|
metadata_expand.cache_seqlens_int32 = (
|
||||||
@@ -1562,9 +1560,7 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
metadata.cu_seqlens_k = self.draft_extend_metadata["cu_seqlens_k"][
|
metadata.cu_seqlens_k = self.draft_extend_metadata["cu_seqlens_k"][
|
||||||
: (bs + 1)
|
: (bs + 1)
|
||||||
]
|
]
|
||||||
metadata.page_table = self.draft_extend_metadata["page_table"][
|
metadata.page_table = self.draft_extend_metadata["page_table"][:bs, :]
|
||||||
req_pool_indices, :
|
|
||||||
]
|
|
||||||
|
|
||||||
self.draft_extend_metadata[bs] = metadata
|
self.draft_extend_metadata[bs] = metadata
|
||||||
|
|
||||||
@@ -1578,7 +1574,7 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
][: (encoder_bs + 1)]
|
][: (encoder_bs + 1)]
|
||||||
|
|
||||||
metadata.encoder_page_table = self.encoder_metadata["encoder_page_table"][
|
metadata.encoder_page_table = self.encoder_metadata["encoder_page_table"][
|
||||||
req_pool_indices, :
|
:bs, :
|
||||||
]
|
]
|
||||||
|
|
||||||
self.forward_metadata = metadata
|
self.forward_metadata = metadata
|
||||||
|
|||||||
Reference in New Issue
Block a user