Optimize GPU memory usage in FlashAttentionBackend's strided indexing (#5262)
Co-authored-by: ch-wan <cwan39@gatech.edu>
This commit is contained in:
@@ -977,10 +977,12 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
metadata.max_seq_len_k + self.page_size - 1
|
metadata.max_seq_len_k + self.page_size - 1
|
||||||
) // self.page_size
|
) // self.page_size
|
||||||
page_indices = self.req_to_token[
|
page_indices = self.req_to_token[
|
||||||
:,
|
req_pool_indices[:, None],
|
||||||
self.decode_cuda_graph_metadata["strided_indices"][:max_seq_pages],
|
self.decode_cuda_graph_metadata["strided_indices"][:max_seq_pages][
|
||||||
|
None, :
|
||||||
|
],
|
||||||
]
|
]
|
||||||
page_indices = page_indices[req_pool_indices] // self.page_size
|
page_indices //= self.page_size
|
||||||
metadata.page_table[:, :max_seq_pages].copy_(page_indices)
|
metadata.page_table[:, :max_seq_pages].copy_(page_indices)
|
||||||
metadata.page_table[:, max_seq_pages:].fill_(0)
|
metadata.page_table[:, max_seq_pages:].fill_(0)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user