From aee62d744b08d83db8d7b55753b41cc9ebfb1155 Mon Sep 17 00:00:00 2001
From: Chang Su
Date: Fri, 11 Apr 2025 00:34:17 -0700
Subject: [PATCH] Optimize GPU memory usage in FlashAttentionBackend's strided
 indexing (#5262)

Co-authored-by: ch-wan
---
 .../sglang/srt/layers/attention/flashattention_backend.py | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/python/sglang/srt/layers/attention/flashattention_backend.py b/python/sglang/srt/layers/attention/flashattention_backend.py
index 9e6365cbf..2ab7f6cb5 100644
--- a/python/sglang/srt/layers/attention/flashattention_backend.py
+++ b/python/sglang/srt/layers/attention/flashattention_backend.py
@@ -977,10 +977,12 @@ class FlashAttentionBackend(AttentionBackend):
                     metadata.max_seq_len_k + self.page_size - 1
                 ) // self.page_size
                 page_indices = self.req_to_token[
-                    :,
-                    self.decode_cuda_graph_metadata["strided_indices"][:max_seq_pages],
+                    req_pool_indices[:, None],
+                    self.decode_cuda_graph_metadata["strided_indices"][:max_seq_pages][
+                        None, :
+                    ],
                 ]
-                page_indices = page_indices[req_pool_indices] // self.page_size
+                page_indices //= self.page_size
                 metadata.page_table[:, :max_seq_pages].copy_(page_indices)
                 metadata.page_table[:, max_seq_pages:].fill_(0)