Fix fa3 cuda graph page_size > 1 precision and page_size=1 speed (#4855)
This commit is contained in:
@@ -322,10 +322,13 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
torch.cumsum(seq_lens_in_batch, dim=0, dtype=torch.int32), (1, 0)
|
||||
)
|
||||
|
||||
metadata.page_table = self.req_to_token[
|
||||
:, self.decode_cuda_graph_metadata["strided_indices"]
|
||||
max_seq_pages = (metadata.max_seq_len_k + self.page_size - 1) // self.page_size
|
||||
page_indices = self.req_to_token[
|
||||
:, self.decode_cuda_graph_metadata["strided_indices"][:max_seq_pages]
|
||||
]
|
||||
metadata.page_table = metadata.page_table[req_pool_indices[:bs]]
|
||||
page_indices = page_indices[req_pool_indices[:bs]] // self.page_size
|
||||
metadata.page_table[:, :max_seq_pages].copy_(page_indices)
|
||||
metadata.page_table[:, max_seq_pages:].fill_(0)
|
||||
self.forward_metadata = metadata
|
||||
|
||||
def get_cuda_graph_seq_len_fill_value(self):
|
||||
|
||||
Reference in New Issue
Block a user