diff --git a/python/sglang/srt/layers/attention/flashattention_backend.py b/python/sglang/srt/layers/attention/flashattention_backend.py
index 2d4e4b263..50e952e22 100644
--- a/python/sglang/srt/layers/attention/flashattention_backend.py
+++ b/python/sglang/srt/layers/attention/flashattention_backend.py
@@ -1163,6 +1163,8 @@ class FlashAttentionBackend(AttentionBackend):
         This creates fixed-size tensors that will be reused during CUDA graph replay
         to avoid memory allocations.
         """
+        max_num_pages = (self.max_context_len + self.page_size - 1) // self.page_size
+
         # This is being used by normal decode and draft decode when topk == 1
         self.decode_cuda_graph_metadata = {
             "cache_seqlens": torch.zeros(max_bs, dtype=torch.int32, device=self.device),
@@ -1174,13 +1176,7 @@ class FlashAttentionBackend(AttentionBackend):
             ),
             "page_table": torch.zeros(
                 max_bs,
-                (self.max_context_len + self.page_size - 1) // self.page_size,
-                dtype=torch.int32,
-                device=self.device,
-            ),
-            "page_table_draft_decode": torch.zeros(
-                max_bs,
-                (self.max_context_len + self.page_size - 1) // self.page_size,
+                max_num_pages,
                 dtype=torch.int32,
                 device=self.device,
             ),
@@ -1188,7 +1184,6 @@ class FlashAttentionBackend(AttentionBackend):
                 0, self.max_context_len, self.page_size, device=self.device
             ),
         }
-
         # Only allocate local attention buffers if local attention is enabled
         # This prevents OOM errors when local attention is not being used
         if self.attention_chunk_size is not None:
@@ -1274,6 +1269,14 @@ class FlashAttentionBackend(AttentionBackend):
             self.speculative_num_draft_tokens is not None
             and self.speculative_num_draft_tokens > 0
         ):
+            # "page_table_draft_decode" is allocated only when speculative decoding is enabled, to save memory
+            self.decode_cuda_graph_metadata["page_table_draft_decode"] = torch.zeros(
+                max_bs,
+                max_num_pages,
+                dtype=torch.int32,
+                device=self.device,
+            )
+
             self.target_verify_metadata = {
                 "cache_seqlens": torch.zeros(
                     max_bs, dtype=torch.int32, device=self.device
@@ -1290,7 +1293,7 @@ class FlashAttentionBackend(AttentionBackend):
                 ),
                 "page_table": torch.zeros(
                     max_bs,
-                    (self.max_context_len + self.page_size - 1) // self.page_size,
+                    max_num_pages,
                     dtype=torch.int32,
                     device=self.device,
                 ),
@@ -1313,7 +1316,7 @@ class FlashAttentionBackend(AttentionBackend):
                 ),
                 "page_table": torch.zeros(
                     max_bs,
-                    (self.max_context_len + self.page_size - 1) // self.page_size,
+                    max_num_pages,
                     dtype=torch.int32,
                     device=self.device,
                 ),
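
The diff does two things: it hoists the repeated ceiling-division page count into a single `max_num_pages`, and it defers allocating `page_table_draft_decode` until speculative decoding is actually enabled. Below is a minimal standalone sketch of both patterns, assuming toy sizes; the helper name `ceil_div`, the concrete values, and the plain `metadata` dict are illustrative stand-ins, not code from the PR.

```python
import torch

def ceil_div(numerator: int, denominator: int) -> int:
    # Integer ceiling division: a partially filled page still needs
    # a full slot in the page table.
    return (numerator + denominator - 1) // denominator

# Hypothetical sizes, chosen only for illustration.
max_bs = 8
max_context_len = 10
page_size = 4
speculative_num_draft_tokens = None  # speculative decoding disabled

# Compute the page count once and reuse it for every page table,
# rather than repeating the ceiling-division expression inline.
max_num_pages = ceil_div(max_context_len, page_size)  # (10 + 4 - 1) // 4 == 3

metadata = {
    "page_table": torch.zeros(max_bs, max_num_pages, dtype=torch.int32),
}

# Allocate the draft-decode page table lazily, mirroring the PR's
# pattern: non-speculative runs never pay for this buffer.
if speculative_num_draft_tokens is not None and speculative_num_draft_tokens > 0:
    metadata["page_table_draft_decode"] = torch.zeros(
        max_bs, max_num_pages, dtype=torch.int32
    )
```

The saving from the lazy allocation scales as `max_bs * max_num_pages * 4` bytes (int32), which is why the buffer is materialized only inside the speculative-decoding branch.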