diff --git a/python/sglang/srt/layers/attention/flashattention_backend.py b/python/sglang/srt/layers/attention/flashattention_backend.py
index 2f974ea9a..a626ff0d8 100644
--- a/python/sglang/srt/layers/attention/flashattention_backend.py
+++ b/python/sglang/srt/layers/attention/flashattention_backend.py
@@ -1434,19 +1434,7 @@ class FlashAttentionBackend(AttentionBackend):
             self.decode_cuda_graph_metadata[bs] = metadata
 
             if self.attention_chunk_size is not None:
-                metadata.local_attn_metadata = FlashAttentionMetadata.LocalAttentionMetadata(
-                    local_query_start_loc=self.decode_cuda_graph_local_attn_metadata[
-                        "local_query_start_loc"
-                    ],
-                    local_seqused_k=self.decode_cuda_graph_local_attn_metadata[
-                        "local_seqused_k"
-                    ],
-                    local_block_table=self.decode_cuda_graph_local_attn_metadata[
-                        "local_block_table"
-                    ],
-                    local_max_query_len=1,
-                    local_max_seq_len=1,
-                )
+                self._update_local_attn_metadata_for_capture(metadata, batch_size)
 
         elif forward_mode.is_target_verify():
             if self.topk <= 1:
@@ -1807,6 +1795,62 @@ class FlashAttentionBackend(AttentionBackend):
         )
         metadata.local_attn_metadata = local_metadata
 
+    def _update_local_attn_metadata_for_capture(
+        self, metadata: FlashAttentionMetadata, bs: int
+    ):
+        """Update local attention metadata during the CUDA graph capture phase.
+
+        This method computes the exact buffer sizes needed for local attention
+        metadata during capture and creates views of the pre-allocated buffers
+        with exactly those sizes, so only the memory actually needed is used.
+        """
+        seq_lens_capture = metadata.cache_seqlens_int32
+        max_seq_len = int(seq_lens_capture.max().item())
+        page_table_capture = metadata.page_table
+
+        cu_seqlens_q_np = metadata.cu_seqlens_q.cpu().numpy()
+        seqlens_np = seq_lens_capture.cpu().numpy()
+        (
+            seqlens_q_local_np,
+            cu_seqlens_q_local_np,
+            seqlens_k_local_np,
+            block_table_local_np,
+        ) = make_local_attention_virtual_batches(
+            self.attention_chunk_size,
+            cu_seqlens_q_np,
+            seqlens_np,
+            page_table_capture,
+            self.page_size,
+        )
+
+        # Get exact dimensions from the calculation
+        q_len = len(cu_seqlens_q_local_np)
+        k_len = len(seqlens_k_local_np)
+        b0 = block_table_local_np.shape[0] if block_table_local_np.shape[0] > 0 else bs
+        b1 = block_table_local_np.shape[1] if block_table_local_np.shape[1] > 0 else 1
+
+        # Create views of the pre-allocated buffers with exactly these sizes.
+        # This is the key optimization: we only use the memory we actually need.
+        local_query_start_loc = self.decode_cuda_graph_local_attn_metadata[
+            "local_query_start_loc"
+        ][:q_len]
+
+        local_seqused_k = self.decode_cuda_graph_local_attn_metadata["local_seqused_k"][
+            :k_len
+        ]
+
+        local_block_table = self.decode_cuda_graph_local_attn_metadata[
+            "local_block_table"
+        ][:b0, :b1]
+
+        metadata.local_attn_metadata = FlashAttentionMetadata.LocalAttentionMetadata(
+            local_query_start_loc=local_query_start_loc,
+            local_seqused_k=local_seqused_k,
+            local_block_table=local_block_table,
+            local_max_query_len=1,
+            local_max_seq_len=max_seq_len,
+        )
+
     def _update_local_attn_metadata_for_replay(
         self, metadata: FlashAttentionMetadata, bs: int
     ):
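
The new capture-path helper boils down to one pattern: allocate metadata buffers once at their maximum size, then hand the CUDA graph views sliced to exactly the dimensions the current capture needs, so later replays can write into the same storage. The sketch below illustrates that pattern in isolation; the buffer names, sizes, and the `capture_views` helper are hypothetical stand-ins for illustration, not the actual SGLang API.

```python
import torch

# Hypothetical persistent buffers, allocated once at their maximum possible size
# (mirroring decode_cuda_graph_local_attn_metadata in spirit, not in detail).
MAX_VIRTUAL_BATCHES = 4096
MAX_PAGES_PER_SEQ = 512

buffers = {
    "local_query_start_loc": torch.zeros(MAX_VIRTUAL_BATCHES + 1, dtype=torch.int32),
    "local_seqused_k": torch.zeros(MAX_VIRTUAL_BATCHES, dtype=torch.int32),
    "local_block_table": torch.zeros(
        MAX_VIRTUAL_BATCHES, MAX_PAGES_PER_SEQ, dtype=torch.int32
    ),
}


def capture_views(q_len: int, k_len: int, b0: int, b1: int) -> dict:
    """Return views (not copies) of the persistent buffers, sized exactly for
    this capture. Writing into the views later, at replay time, updates the
    same memory the CUDA graph was captured against."""
    return {
        "local_query_start_loc": buffers["local_query_start_loc"][:q_len],
        "local_seqused_k": buffers["local_seqused_k"][:k_len],
        "local_block_table": buffers["local_block_table"][:b0, :b1],
    }


views = capture_views(q_len=9, k_len=8, b0=8, b1=4)
# Same storage: mutating the view is visible through the full buffer.
views["local_seqused_k"].fill_(7)
assert buffers["local_seqused_k"][:8].eq(7).all()
```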