Optimize GPU memory usage in FlashAttentionBackend's strided indexing (#5262)
Co-authored-by: ch-wan <cwan39@gatech.edu>
This commit is contained in:
@@ -977,10 +977,12 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
metadata.max_seq_len_k + self.page_size - 1
|
||||
) // self.page_size
|
||||
page_indices = self.req_to_token[
|
||||
:,
|
||||
self.decode_cuda_graph_metadata["strided_indices"][:max_seq_pages],
|
||||
req_pool_indices[:, None],
|
||||
self.decode_cuda_graph_metadata["strided_indices"][:max_seq_pages][
|
||||
None, :
|
||||
],
|
||||
]
|
||||
page_indices = page_indices[req_pool_indices] // self.page_size
|
||||
page_indices //= self.page_size
|
||||
metadata.page_table[:, :max_seq_pages].copy_(page_indices)
|
||||
metadata.page_table[:, max_seq_pages:].fill_(0)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user