Optimize GPU memory usage in FlashAttentionBackend's strided indexing (#5262)

Co-authored-by: ch-wan <cwan39@gatech.edu>
This commit is contained in:
Chang Su
2025-04-11 00:34:17 -07:00
committed by GitHub
parent cd7e32e2cb
commit aee62d744b

View File

@@ -977,10 +977,12 @@ class FlashAttentionBackend(AttentionBackend):
         metadata.max_seq_len_k + self.page_size - 1
     ) // self.page_size
     page_indices = self.req_to_token[
-        :,
-        self.decode_cuda_graph_metadata["strided_indices"][:max_seq_pages],
+        req_pool_indices[:, None],
+        self.decode_cuda_graph_metadata["strided_indices"][:max_seq_pages][
+            None, :
+        ],
     ]
-    page_indices = page_indices[req_pool_indices] // self.page_size
+    page_indices //= self.page_size
     metadata.page_table[:, :max_seq_pages].copy_(page_indices)
     metadata.page_table[:, max_seq_pages:].fill_(0)