From 7a9133014910fec69e3013b91d5d0d0a3f9b418e Mon Sep 17 00:00:00 2001 From: Cheng Wan <54331508+ch-wan@users.noreply.github.com> Date: Sun, 3 Aug 2025 03:06:31 -0700 Subject: [PATCH] Save cuda graph memory for fa3 (#8567) --- .../layers/attention/flashattention_backend.py | 18 +++++++----------- 1 file changed, 7 insertions(+), 11 deletions(-) diff --git a/python/sglang/srt/layers/attention/flashattention_backend.py b/python/sglang/srt/layers/attention/flashattention_backend.py index 740b46b6b..785cbf1d8 100644 --- a/python/sglang/srt/layers/attention/flashattention_backend.py +++ b/python/sglang/srt/layers/attention/flashattention_backend.py @@ -1406,7 +1406,7 @@ class FlashAttentionBackend(AttentionBackend): ) metadata.page_table = self.decode_cuda_graph_metadata[ "page_table_draft_decode" - ][req_pool_indices, :] + ][:bs, :] self.decode_cuda_graph_metadata[bs] = metadata else: # When top k > 1, we need two specific draft decode metadata, and then merge states @@ -1424,7 +1424,7 @@ class FlashAttentionBackend(AttentionBackend): ][: bs + 1] metadata.page_table = self.draft_decode_metadata_topk_normal[ "page_table" - ][req_pool_indices, :] + ][:bs, :] # 2. The second half of metadata for draft tokens (per_batch_num_tokens = topk) metadata_expand.cache_seqlens_int32 = ( @@ -1461,7 +1461,7 @@ class FlashAttentionBackend(AttentionBackend): metadata.max_seq_len_k = seq_lens.max().item() # Precompute page table metadata.page_table = self.decode_cuda_graph_metadata["page_table"][ - req_pool_indices, : + :bs, : ] # Precompute cumulative sequence lengths metadata.cu_seqlens_q = torch.arange( @@ -1498,9 +1498,7 @@ class FlashAttentionBackend(AttentionBackend): : (bs + 1) ] - metadata.page_table = self.target_verify_metadata["page_table"][ - req_pool_indices, : - ] + metadata.page_table = self.target_verify_metadata["page_table"][:bs, :] self.target_verify_metadata[bs] = metadata else: @@ -1519,7 +1517,7 @@ class FlashAttentionBackend(AttentionBackend): ][: bs + 1] metadata.page_table = self.target_verify_metadata_topk_normal[ "page_table" - ][req_pool_indices, :] + ][:bs, :] # 2. The second half of metadata for draft tokens (per_batch_num_tokens = topk) metadata_expand.cache_seqlens_int32 = ( @@ -1562,9 +1560,7 @@ class FlashAttentionBackend(AttentionBackend): metadata.cu_seqlens_k = self.draft_extend_metadata["cu_seqlens_k"][ : (bs + 1) ] - metadata.page_table = self.draft_extend_metadata["page_table"][ - req_pool_indices, : - ] + metadata.page_table = self.draft_extend_metadata["page_table"][:bs, :] self.draft_extend_metadata[bs] = metadata @@ -1578,7 +1574,7 @@ class FlashAttentionBackend(AttentionBackend): ][: (encoder_bs + 1)] metadata.encoder_page_table = self.encoder_metadata["encoder_page_table"][ - req_pool_indices, : + :bs, : ] self.forward_metadata = metadata