Optimize GPU memory usage in FlashAttentionBackend's strided indexing (#5262)

Co-authored-by: ch-wan <cwan39@gatech.edu>
This commit is contained in:
Chang Su
2025-04-11 00:34:17 -07:00
committed by GitHub
parent cd7e32e2cb
commit aee62d744b

View File

@@ -977,10 +977,12 @@ class FlashAttentionBackend(AttentionBackend):
metadata.max_seq_len_k + self.page_size - 1
) // self.page_size
page_indices = self.req_to_token[
:,
self.decode_cuda_graph_metadata["strided_indices"][:max_seq_pages],
req_pool_indices[:, None],
self.decode_cuda_graph_metadata["strided_indices"][:max_seq_pages][
None, :
],
]
page_indices = page_indices[req_pool_indices] // self.page_size
page_indices //= self.page_size
metadata.page_table[:, :max_seq_pages].copy_(page_indices)
metadata.page_table[:, max_seq_pages:].fill_(0)